Max Hallinan

How does the continuation monad work?

I had trouble using the continuation monad until I understood how it works. Here is what I wish I had known from the beginning. The examples are written in PureScript.

What is a continuation?

A continuation is the next step in a computation. Consider this expression:

3 * (2 / 2) + 1

We start by evaluating the sub-expression 2 / 2. Let’s replace it with a placeholder for the value of that expression.

3 * ? + 1

The continuation is everything surrounding the placeholder. We can think of the continuation as a function of the placeholder's value. Calling that function with the value of 2 / 2 completes the computation.

-- k is a common name for the continuation

k n = 3 * n + 1

k (2 / 2)
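Here is a minimal runnable sketch of k as an ordinary PureScript function. The module name Main and the use of logShow from Effect.Console are assumptions added only so the snippet runs on its own.

```purescript
module Main where

import Prelude

import Effect (Effect)
import Effect.Console (logShow)

-- The continuation of the sub-expression 2 / 2 in 3 * (2 / 2) + 1:
-- everything that surrounds the placeholder, as a function of its value.
k :: Int -> Int
k n = 3 * n + 1

main :: Effect Unit
main = do
  -- Filling the placeholder with 2 / 2 completes the original computation.
  logShow (k (2 / 2)) -- prints 4
  -- The same continuation can finish the computation for a different value.
  logShow (k 5) -- prints 16
```

Because k is just a function, the rest of the computation can be reused with any value in the placeholder.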

There is always a next step, another continuation, until the computation is finished. But these steps are not something that you, the programmer, can always manipulate. Constructs like if/else and try/catch give you some control over the next step. The continuation pattern gives you full control.

What is continuation-passing style?

Continuation-passing style (CPS) is contrasted with direct style. A function written in direct style takes some arguments and returns a value. Returning a value advances the computation to the next step.

add x y = x + y

Here is the same expression, 3 * (2 / 2) + 1, written with direct-style functions.

mult x y = x * y

add x y = x + y

div x y = x / y

result = add 1 (mult 3 (div 2 2))

A CPS function takes an additional argument, a continuation. Instead of returning a value, it calls the continuation.

add x y k = k (x + y)
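A CPS function decides when, and whether, to call a continuation, which is where the extra control comes from. The sketch below is my own illustration, not part of the running example: safeDiv and compute are hypothetical names, and safeDiv takes a second continuation that it calls instead of the success continuation when the divisor is 0.

```purescript
module Main where

import Prelude

import Effect (Effect)
import Effect.Console (log)

-- Hypothetical: a CPS division with a failure continuation and a
-- success continuation. It calls exactly one of them.
safeDiv :: forall r. Int -> Int -> (String -> r) -> (Int -> r) -> r
safeDiv _ 0 onError _ = onError "division by zero"
safeDiv x y _ k = k (x / y)

-- 3 * (x / y) + 1, where the success continuation finishes the work
-- and the failure continuation skips it entirely.
compute :: Int -> Int -> String
compute x y =
  safeDiv x y (\err -> "failed: " <> err) \q ->
    show (3 * q + 1)

main :: Effect Unit
main = do
  log (compute 2 2) -- prints 4
  log (compute 2 0) -- prints failed: division by zero
```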

Here is the same computation written in continuation-passing style.

add x y k = k (x + y)

mult x y k = k (x * y)

div x y k = k (x / y)

step4 k =
  div 2 2 \x ->
    mult 3 x \y ->
      add 1 y k

Each call to div, mult, and add receives a continuation. Instead of returning its result, the function calls that continuation with it, and the continuation performs the next step of the computation.
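Here is a runnable sketch of the CPS version. It keeps the names add, mult, div, and step4 and adds type signatures; the module name, main, and the two final continuations are assumptions for illustration. Prelude already exports functions named add and div, so the import hides them.

```purescript
module Main where

-- Prelude exports add and div too; hide them to reuse the article's names.
import Prelude hiding (add, div)

import Effect (Effect)
import Effect.Console (log, logShow)

add :: forall r. Int -> Int -> (Int -> r) -> r
add x y k = k (x + y)

mult :: forall r. Int -> Int -> (Int -> r) -> r
mult x y k = k (x * y)

div :: forall r. Int -> Int -> (Int -> r) -> r
div x y k = k (x / y)

-- 3 * (2 / 2) + 1: each step passes its result to the next one,
-- and the last step passes its result to the caller's continuation k.
step4 :: forall r. (Int -> r) -> r
step4 k =
  div 2 2 \x ->
    mult 3 x \y ->
      add 1 y k

main :: Effect Unit
main = do
  -- identity as the final continuation simply returns the answer.
  logShow (step4 identity) -- prints 4
  -- A different final continuation changes what the whole chain produces.
  log (step4 \n -> "result: " <> show n) -- prints result: 4
```

The caller of step4 chooses the final continuation, so the same chain can produce an Int or a String.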

What is the ContT type?

Let’s look at the types of the above functions. The type r is the result of calling the continuation.

add :: forall r. Int -> Int -> (Int -> r) -> r

mult :: forall r. Int -> Int -> (Int -> r) -> r

div :: forall r. Int -> Int -> (Int -> r) -> r

We can think of these functions not as returning an r but as returning a function-that-takes-a-continuation-and-returns-an-r.

add :: forall r. Int -> Int -> ((Int -> r) -> r)

mult :: forall r. Int -> Int -> ((Int -> r) -> r)

div :: forall r. Int -> Int -> ((Int -> r) -> r)
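To see that reading in action, here is a small sketch (the names suspended and main are hypothetical) that partially applies add and only later supplies continuations with three different result types.

```purescript
module Main where

import Prelude hiding (add)

import Effect (Effect)
import Effect.Console (log, logShow)

-- Two Ints in, a computation that still waits for its continuation out.
add :: forall r. Int -> Int -> ((Int -> r) -> r)
add x y = \k -> k (x + y)

-- Not an Int yet: a function from continuations to results.
suspended :: forall r. (Int -> r) -> r
suspended = add 1 2

main :: Effect Unit
main = do
  -- The caller picks r when it finally supplies a continuation.
  logShow (suspended identity) -- r = Int, prints 3
  log (suspended show) -- r = String, prints 3
  logShow (suspended (_ > 2)) -- r = Boolean, prints true
```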

That way of thinking is the basis for the ContT type.

newtype ContT r m a = ContT ((a -> m r) -> m r)

ContT wraps a function that takes a continuation of type a -> m r and returns a result r wrapped in m, where m is usually a monad.

add
  :: forall r m
   . Int
  -> Int
  -> ContT r m Int
add x y = ContT \k -> k (x + y)

mult
  :: forall r m
   . Int
  -> Int
  -> ContT r m Int
mult x y = ContT \k -> k (x * y)

div
  :: forall r m
   . Int
  -> Int
  -> ContT r m Int
div x y = ContT \k -> k (x / y)

runContT
  :: forall r m a
   . ContT r m a
  -> (a -> m r)
  -> m r
runContT (ContT f) k = f k

result :: Identity Int
result =
  runContT (div 2 2) \x ->
    runContT (mult 3 x) \y ->
      runContT (add 1 y) pure
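The m parameter is chosen by whoever runs the computation. The sketch below repeats the definitions above in a self-contained module (a local ContT rather than the one in the transformers library) and runs the same three steps once with m = Identity and once with m = Effect, where the last continuation prints the answer instead of returning it. The name printed is hypothetical.

```purescript
module Main where

import Prelude hiding (add, div)

import Data.Identity (Identity)
import Data.Newtype (unwrap)
import Effect (Effect)
import Effect.Console (logShow)

-- A local copy of the definitions above.
newtype ContT r m a = ContT ((a -> m r) -> m r)

runContT :: forall r m a. ContT r m a -> (a -> m r) -> m r
runContT (ContT f) k = f k

add :: forall r m. Int -> Int -> ContT r m Int
add x y = ContT \k -> k (x + y)

mult :: forall r m. Int -> Int -> ContT r m Int
mult x y = ContT \k -> k (x * y)

div :: forall r m. Int -> Int -> ContT r m Int
div x y = ContT \k -> k (x / y)

-- m = Identity, r = Int: the last continuation is pure.
result :: Identity Int
result =
  runContT (div 2 2) \x ->
    runContT (mult 3 x) \y ->
      runContT (add 1 y) pure

-- m = Effect, r = Unit: the last continuation logs the answer.
printed :: Effect Unit
printed =
  runContT (div 2 2) \x ->
    runContT (mult 3 x) \y ->
      runContT (add 1 y) logShow

main :: Effect Unit
main = do
  logShow (unwrap result) -- prints 4
  printed -- prints 4
```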

Functions written in a continuation-passing style pass the continuation around explicitly. The use of runContT above is an awkward imitation of that style. We’d rather compose each step and then run them all together.

How does ContT compose?

Each step in our example depends on the result of the previous step. The previous step passes its result to the next step through a continuation. We’re looking for a way to combine steps without having to run ContT in between.

Let’s call this operator andThen.

andThen
  :: forall r m
   . ContT r m Int
  -> (Int -> ContT r m Int)
  -> ContT r m Int

We can generalize andThen to this type.

andThen
  :: forall r m a b
   . Monad m
  => ContT r m a
  -> (a -> ContT r m b)
  -> ContT r m b

Now we can see that andThen is bind, with ContT r m playing the role of bind's m.

bind
  :: forall m a b
   . Bind m
  => m a
  -> (a -> m b)
  -> m b

Defining bind for ContT will enable us to perform steps sequentially without so much plumbing.

How is bind defined for ContT?

The arguments to bind seem like they should connect, but it wasn’t immediately obvious to me how they could.

Start by destructuring ContT.

bindContT (ContT m) f = ...

We have these functions in scope.

m :: (a -> m r) -> m r
f :: a -> ContT r m b

And we want to return a value of this type.

ContT r m b

If we had access to an a, it would be easy. We could just apply f to a.

a :: a
f :: a -> ContT r m b
f a :: ContT r m b

But we don’t have direct access to an a. The only place we find a as an input is the continuation of m.

m :: (a -> m r) -> m r

To gain access to a, we must supply that continuation.

bindContT (ContT m) f = m \a ->

Now we can apply f to a.

bindContT (ContT m) f = m \a -> f a

But this doesn’t compile because the types are wrong. The continuation must be a function a -> m r, but we’ve supplied a function a -> ContT r m b.

We can try to fix this problem by running the result of f a.

bindContT (ContT m) f = m \a -> runContT (f a) ?

There are two problems. First, we must supply a continuation b -> m r to runContT, but we don’t have a value of that type in scope. Second, bindContT now returns m r, but it should return ContT r m b.

Both problems are fixed in the same way, by wrapping everything in another continuation-passing style function.

bindContT (ContT m) f =
  ContT \k -> m \a -> runContT (f a) k

Here is what happens:

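  • Calling m performs the first step, e.g. div 2 2
  • The continuation passed to m receives that step's result as a and performs the second step, f a, e.g. mult 3 1
  • Everything is wrapped in a new CPS function, so runContT (f a) k passes the second step's result to its continuation k
  • k performs whatever comes next, e.g. the third step, add 1 3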
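Here is a self-contained sketch that checks the derivation. It repeats the local ContT, defines bindContT exactly as above, and chains the three steps with plain calls to bindContT. The name composed is hypothetical, and bindContT is not yet a Bind instance here.

```purescript
module Main where

import Prelude hiding (add, div)

import Data.Identity (Identity)
import Data.Newtype (unwrap)
import Effect (Effect)
import Effect.Console (logShow)

newtype ContT r m a = ContT ((a -> m r) -> m r)

runContT :: forall r m a. ContT r m a -> (a -> m r) -> m r
runContT (ContT f) k = f k

-- Run the first step, feed its result to f, and hand f's result to k.
bindContT :: forall r m a b. ContT r m a -> (a -> ContT r m b) -> ContT r m b
bindContT (ContT m) f = ContT \k -> m \a -> runContT (f a) k

add :: forall r m. Int -> Int -> ContT r m Int
add x y = ContT \k -> k (x + y)

mult :: forall r m. Int -> Int -> ContT r m Int
mult x y = ContT \k -> k (x * y)

div :: forall r m. Int -> Int -> ContT r m Int
div x y = ContT \k -> k (x / y)

-- The inner bindContT runs div 2 2 and passes 1 to mult 3. Its k is the
-- continuation built by the outer bindContT, which runs add 1 3.
-- Nothing runs until a final continuation is supplied.
composed :: forall r m. ContT r m Int
composed = bindContT (bindContT (div 2 2) (mult 3)) (add 1)

main :: Effect Unit
main = logShow (unwrap (runContT composed pure :: Identity Int)) -- prints 4
```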

Ergonomic composition of CPS functions

Once bindContT is used as the bind for ContT, the continuation monad enables us to perform steps sequentially without threading the continuation explicitly.

result = div 2 2 >>= mult 3 >>= add 1

or

result = do
  step1 <- div 2 2
  step2 <- mult 3 step1
  add 1 step2
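In PureScript, >>= and do notation need a Bind instance, and the class hierarchy also asks for Functor, Apply, Applicative, and Monad instances. Here is a sketch that writes them for a local ContT, reusing bindContT as bind; the other instance bodies and the names result1 and result2 are my own, not from the text above. The transformers library's Control.Monad.Cont.Trans module provides a ContT with these instances already, so in a real project it is probably simpler to import that.

```purescript
module Main where

import Prelude hiding (add, div)

import Data.Identity (Identity)
import Data.Newtype (unwrap)
import Effect (Effect)
import Effect.Console (logShow)

newtype ContT r m a = ContT ((a -> m r) -> m r)

runContT :: forall r m a. ContT r m a -> (a -> m r) -> m r
runContT (ContT f) k = f k

-- map transforms the value before it reaches the continuation.
instance functorContT :: Functor (ContT r m) where
  map g (ContT m) = ContT \k -> m \a -> k (g a)

-- apply runs the function step, then the argument step.
instance applyContT :: Apply (ContT r m) where
  apply (ContT mf) (ContT ma) = ContT \k -> mf \g -> ma \a -> k (g a)

-- pure hands its value straight to the continuation.
instance applicativeContT :: Applicative (ContT r m) where
  pure a = ContT \k -> k a

-- bind is bindContT from the derivation above.
instance bindContT :: Bind (ContT r m) where
  bind (ContT m) f = ContT \k -> m \a -> runContT (f a) k

instance monadContT :: Monad (ContT r m)

add :: forall r m. Int -> Int -> ContT r m Int
add x y = ContT \k -> k (x + y)

mult :: forall r m. Int -> Int -> ContT r m Int
mult x y = ContT \k -> k (x * y)

div :: forall r m. Int -> Int -> ContT r m Int
div x y = ContT \k -> k (x / y)

-- >>= is infixl 1, so this reads (div 2 2 >>= mult 3) >>= add 1.
result1 :: forall r m. ContT r m Int
result1 = div 2 2 >>= mult 3 >>= add 1

-- do notation desugars to the same calls to bind.
result2 :: forall r m. ContT r m Int
result2 = do
  step1 <- div 2 2
  step2 <- mult 3 step1
  add 1 step2

main :: Effect Unit
main = do
  logShow (unwrap (runContT result1 pure :: Identity Int)) -- prints 4
  logShow (unwrap (runContT result2 pure :: Identity Int)) -- prints 4
```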