Map as a Recursion Scheme in OCaml

2018-03-18

Let us explore a simple recursion scheme in OCaml. To create motivation for it, we will write a few simple compiler passes for a toy language.

You might think: oh, crickets! Again these functional programmers with their compilers! Gimme some real problems!

First, compilers are the single most researched application of software, so there is existing terminology which can be quickly used to build up a realistic example.

Second, and more importantly, compiler construction methods are increasingly being applied to other fields, for example, financial instruments. It may even be that every software problem is a compiler problem.

Let’s say we have the following informally specified language:

(e)
| ()
| true | false
| 0 | 1 | 2 | ...
| id
| e / e
| e; e
| let id = e in e
| if e then e else e

We can represent it straightforwardly with this type:

module Syntax = struct
  type t =
    | Unit
    | Boolean of bool
    | Number of int
    | Id of string
    | Divide of t * t
    | Sequence of t * t
    | Let of {id: string; value: t; body: t}
    | If of {conditional: t; consequence: t; alternative: t}
end
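
As a quick illustration of the encoding, here is how the program "if true then 4 / 2 else ()" could be written as a value of this type. The name example is made up for this sketch.

```ocaml
module Syntax = struct
  type t =
    | Unit
    | Boolean of bool
    | Number of int
    | Id of string
    | Divide of t * t
    | Sequence of t * t
    | Let of {id: string; value: t; body: t}
    | If of {conditional: t; consequence: t; alternative: t}
end

open Syntax

(* if true then 4 / 2 else () *)
let example =
  If { conditional = Boolean true;
       consequence = Divide (Number 4, Number 2);
       alternative = Unit }
```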


1. Primitive recursion

Now, let’s say you want to eliminate dead code by creating a compiler pass for the following transformations:

if true  then x else y  ⇒  x
if false then x else y  ⇒  y

And here is how you can do it using primitive recursion:

module Dead_code_elimination = struct
  let rec pass = function
    | Unit | Boolean _ | Number _ | Id _ as t ->
        t
    | Divide (left, right) ->
        Divide (pass left, pass right)
    | Sequence (left, right) ->
        Sequence (pass left, pass right)
    | Let {id; value; body} ->
        Let {id; value=pass value; body=pass body}
    | If {conditional=Boolean true; consequence; _} ->
        pass consequence
    | If {conditional=Boolean false; alternative; _} ->
        pass alternative
    | If {conditional; consequence; alternative} ->
        let conditional = pass conditional in
        let consequence = pass consequence in
        let alternative = pass alternative in
        If {conditional; consequence; alternative}
end


Only the first two If cases, the ones that match a literal Boolean condition, are the actual transformation; the rest is boilerplate that makes sure that the transformation is applied recursively.

2. Factored recursion

This pattern of recursion can be captured by a map function that applies a function f recursively to the data structure:

let map f = function
  | Unit | Boolean _ | Number _ | Id _ as t ->
      t
  | Divide (left, right) ->
      Divide (f left, f right)
  | Sequence (left, right) ->
      Sequence (f left, f right)
  | Let {id; value; body} ->
      Let {id; value=f value; body=f body}
  | If {conditional; consequence; alternative} ->
      let conditional = f conditional in
      let consequence = f consequence in
      let alternative = f alternative in
      If {conditional; consequence; alternative}

Now we can rewrite our compiler pass to focus on the actual transformation and to delegate the recursive descent to map:

module Dead_code_elimination = struct
  let rec pass = function
    | If {conditional=Boolean true; consequence; _} ->
        pass consequence
    | If {conditional=Boolean false; alternative; _} ->
        pass alternative
    | other -> map pass other
end
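
To see the pass at work, here is a minimal self-contained sketch. It repeats the type and map from above (without the module wrappers, for brevity) and runs the pass on a small tree. The names input and output are made up for this example.

```ocaml
type t =
  | Unit
  | Boolean of bool
  | Number of int
  | Id of string
  | Divide of t * t
  | Sequence of t * t
  | Let of {id: string; value: t; body: t}
  | If of {conditional: t; consequence: t; alternative: t}

(* Apply f to every direct child; leaves are returned unchanged. *)
let map f = function
  | Unit | Boolean _ | Number _ | Id _ as t -> t
  | Divide (left, right) -> Divide (f left, f right)
  | Sequence (left, right) -> Sequence (f left, f right)
  | Let {id; value; body} -> Let {id; value = f value; body = f body}
  | If {conditional; consequence; alternative} ->
      let conditional = f conditional in
      let consequence = f consequence in
      let alternative = f alternative in
      If {conditional; consequence; alternative}

let rec pass = function
  | If {conditional = Boolean true; consequence; _} -> pass consequence
  | If {conditional = Boolean false; alternative; _} -> pass alternative
  | other -> map pass other

(* let x = (if true then 1 else 2) in
   if false then x else ((); x) *)
let input =
  Let { id = "x";
        value = If { conditional = Boolean true;
                     consequence = Number 1;
                     alternative = Number 2 };
        body = If { conditional = Boolean false;
                    consequence = Id "x";
                    alternative = Sequence (Unit, Id "x") } }

(* Both conditionals are gone: let x = 1 in ((); x) *)
let output = pass input
```

Note that the nested conditional inside the let value is reached only because the fallback case hands the recursion over to map.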


Now the pass is focused on the transformation.

Caveat: our map implementation above is peculiar. Instead of the regular signature:

val map : ('a -> 'b) -> 'a t -> 'b t

It has a monomorphic signature, where 'a, 'b, 'a t, 'b t are the same thing:

val map : (Syntax.t -> Syntax.t) -> Syntax.t -> Syntax.t

Thus, this map will only help us factor out recursion for passes of the form Syntax.t -> Syntax.t. We’ll see what we can do about other kinds of passes later on.

3. Recursion for free

So far the deal was that we could write map once and then use it in several passes. However, if your compiler has an intermediate representation with hundreds of different nodes (like many do), it might be tedious to write. Worse, what if you have many intermediate representations?

Fortunately, there is the ppx_deriving code generation framework, which can generate, among many other things, a map implementation for us, similar to Haskell’s deriving Functor.

However, there’s a caveat. Similar to deriving Functor in Haskell, deriving a map implementation using ppx_deriving requires a type with a single type parameter: 'a t. We will need to rewrite our Syntax.t type to use a type parameter instead of it being defined self-referentially. However, we’ll be able to reclaim our monomorphic map in no time. See here:

module Syntax = struct
  module Open = struct
    type 'a t = (* ❶ *)
      | Unit
      | Boolean of bool
      | Number of int
      | Id of string
      | Divide of 'a * 'a
      | Sequence of 'a * 'a
      | Let of {id: string; value: 'a; body: 'a}
      | If of {conditional: 'a; consequence: 'a; alternative: 'a}
      [@@deriving map] (* ❷ *)
  end

  type t = t Open.t (* ❸ *)

  let map: (t -> t) -> t -> t = Open.map (* ❹ *)
end




First, we called our new polymorphic type 'a Syntax.Open.t ❶. We’ll refer to it as an “open” type. Second, we used the deriving framework to get map for free ❷. We regained our monomorphic type by making a recursive type definition ❸. We’ll refer to this type as a “closed” type.

Unlike in Haskell, there is no need for a fix-point type and for the wrapping and unwrapping that is associated with it. OCaml supports such recursive type definitions using the -rectypes compiler flag. The resulting closed type Syntax.t is indistinguishable from our original Syntax.t, for all intents and purposes.

Interestingly enough, the -rectypes flag is not necessary for tying the recursive knot when the type is built from polymorphic variants. And there might be more reasons to use polymorphic variants for syntax trees and intermediate representations.
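
A minimal sketch of the polymorphic variant approach follows. It makes two assumptions that are not from the article: the Let and If payloads are tuples (polymorphic variants cannot carry inline records), and map is written by hand here rather than derived.

```ocaml
module Syntax = struct
  module Open = struct
    type 'a t =
      [ `Unit
      | `Boolean of bool
      | `Number of int
      | `Id of string
      | `Divide of 'a * 'a
      | `Sequence of 'a * 'a
      | `Let of string * 'a * 'a
      | `If of 'a * 'a * 'a ]

    (* Hand-written stand-in for a derived map. *)
    let map f : 'a t -> 'b t = function
      | `Unit -> `Unit
      | `Boolean b -> `Boolean b
      | `Number n -> `Number n
      | `Id id -> `Id id
      | `Divide (left, right) -> `Divide (f left, f right)
      | `Sequence (left, right) -> `Sequence (f left, f right)
      | `Let (id, value, body) -> `Let (id, f value, f body)
      | `If (c, t, e) -> `If (f c, f t, f e)
  end

  (* The recursion goes through a polymorphic variant type,
     so this definition is accepted without -rectypes. *)
  type t = t Open.t

  let map : (t -> t) -> t -> t = Open.map
end

let rec pass : Syntax.t -> Syntax.t = function
  | `If (`Boolean true, consequence, _) -> pass consequence
  | `If (`Boolean false, _, alternative) -> pass alternative
  | other -> Syntax.map pass other
```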

We can also regain our monomorphic map ❹ function, if necessary, by constraining Open.map with a signature.

Now we can write the same short version of the dead code elimination pass without writing the map function ourselves.

4. Monadic passes

So far we were only able to automate recursion for passes of the form Syntax.t -> Syntax.t. What about passes that return an option? A result type to signal errors? And how about passes that need to maintain a lexical environment?

Say we want a pass of the form Syntax.t -> (Syntax.t, 'error) result that checks for literal division by zero.

Let’s write it using primitive recursion first. We’ll use the result monad to compose our pass recursively. Normally you would use a library like Base.Result for that, but let’s spell it out here:

module Result = struct
  let return x = Ok x

  (* Rebuilding the error lets the Ok type change across a bind. *)
  let (>>=) result f =
    match result with
    | Ok ok -> f ok
    | Error error -> Error error
end
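
To make the short-circuiting behaviour of bind concrete, here is a small sketch. The helper safe_divide and the string error are assumptions made for this example only; the Result module is written so that the snippet stands on its own.

```ocaml
module Result = struct
  let return x = Ok x

  let (>>=) result f =
    match result with
    | Ok ok -> f ok
    | Error error -> Error error
end

open Result

(* Hypothetical helper: reports an error instead of raising. *)
let safe_divide a b =
  if b = 0 then Error "division by zero" else Ok (a / b)

(* Every step succeeds: 100 / 5 = 20, 20 / 2 = 10, 10 + 1 = 11. *)
let fine =
  safe_divide 100 5 >>= fun q ->
  safe_divide q 2 >>= fun r ->
  return (r + 1)

(* The first step fails, so the second division never runs. *)
let failed =
  safe_divide 100 0 >>= fun q ->
  safe_divide q 2
```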

And then we implement the pass itself:

module Check_literal_division_by_zero = struct
  open Result

  let rec pass = function
    | Unit | Boolean _ | Number _ | Id _ as t ->
        return t
    | Divide (_, Number 0) ->
        Error `Literal_division_by_zero
    | Divide (left, right) ->
        pass left >>= fun left ->
        pass right >>= fun right ->
        return (Divide (left, right))
    | Sequence (left, right) ->
        pass left >>= fun left ->
        pass right >>= fun right ->
        return (Sequence (left, right))
    | Let {id; value; body} ->
        pass value >>= fun value ->
        pass body >>= fun body ->
        return (Let {id; value; body})
    | If {conditional; consequence; alternative} ->
        pass conditional >>= fun conditional ->
        pass consequence >>= fun consequence ->
        pass alternative >>= fun alternative ->
        return (If {conditional; consequence; alternative})
end


As before, only the Divide (_, Number 0) case implements the check itself; the rest is boilerplate that implements recursion.

We have used a polymorphic variant `Literal_division_by_zero for our error value. To learn why this might be a good idea read about Composable Error Handling in OCaml.

Like we did before, we can factor the boilerplate out and, as a result, we get map_result, a function that maps Syntax.t -> (Syntax.t, 'error) result:

open Result

let map_result f = function
  | Unit | Boolean _ | Number _ | Id _ as t ->
      return t
  | Divide (left, right) ->
      f left >>= fun left ->
      f right >>= fun right ->
      return (Divide (left, right))
  | Sequence (left, right) ->
      f left >>= fun left ->
      f right >>= fun right ->
      return (Sequence (left, right))
  | Let {id; value; body} ->
      f value >>= fun value ->
      f body >>= fun body ->
      return (Let {id; value; body})
  | If {conditional; consequence; alternative} ->
      f conditional >>= fun conditional ->
      f consequence >>= fun consequence ->
      f alternative >>= fun alternative ->
      return (If {conditional; consequence; alternative})

We can convince ourselves that it works by rewriting the pass with recursion delegated to map_result:

module Check_literal_division_by_zero = struct
  let rec pass = function
    | Divide (_, Number 0) ->
        Error `Literal_division_by_zero
    | other -> map_result pass other
end


Much better now!

Looking at the map_result implementation, we can quickly discover that it has nothing specific to the result type. It only uses return and bind. So, instead, we can make a “map generator” function which is parametrized over return and bind to get a mapper for any monad:

let map_monad ~return ~bind:(>>=) f = function
  | Unit | Boolean _ | Number _ | Id _ as t ->
      return t
  | Divide (left, right) ->
      f left >>= fun left ->
      f right >>= fun right ->
      return (Divide (left, right))
  | Sequence (left, right) ->
      f left >>= fun left ->
      f right >>= fun right ->
      return (Sequence (left, right))
  | Let {id; value; body} ->
      f value >>= fun value ->
      f body >>= fun body ->
      return (Let {id; value; body})
  | If {conditional; consequence; alternative} ->
      f conditional >>= fun conditional ->
      f consequence >>= fun consequence ->
      f alternative >>= fun alternative ->
      return (If {conditional; consequence; alternative})

One can write a ppx_deriving plugin to automatically derive this function, just like map.

What can we do with it? We can use it to factor out recursion from any pass of the form Syntax.t -> Syntax.t m where m is a monad type.

We can generate map_result from the result monad to implement our literal division checker:

(* Eta-expanded so that the error type stays polymorphic. *)
let map_result f t =
  map_monad ~return:Result.return ~bind:Result.(>>=) f t

module Check_literal_division_by_zero = struct
  let rec pass = function
    | Divide (_, Number 0) ->
        Error `Literal_division_by_zero
    | other -> map_result pass other
end
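
The pieces above can be put together into one self-contained sketch and checked on two small trees. The type and map_monad are repeated from the article without the module wrappers; the names good and bad are made up for this example.

```ocaml
type t =
  | Unit
  | Boolean of bool
  | Number of int
  | Id of string
  | Divide of t * t
  | Sequence of t * t
  | Let of {id: string; value: t; body: t}
  | If of {conditional: t; consequence: t; alternative: t}

module Result = struct
  let return x = Ok x

  let (>>=) result f =
    match result with
    | Ok ok -> f ok
    | Error error -> Error error
end

let map_monad ~return ~bind:(>>=) f = function
  | Unit | Boolean _ | Number _ | Id _ as t ->
      return t
  | Divide (left, right) ->
      f left >>= fun left ->
      f right >>= fun right ->
      return (Divide (left, right))
  | Sequence (left, right) ->
      f left >>= fun left ->
      f right >>= fun right ->
      return (Sequence (left, right))
  | Let {id; value; body} ->
      f value >>= fun value ->
      f body >>= fun body ->
      return (Let {id; value; body})
  | If {conditional; consequence; alternative} ->
      f conditional >>= fun conditional ->
      f consequence >>= fun consequence ->
      f alternative >>= fun alternative ->
      return (If {conditional; consequence; alternative})

(* Eta-expanded so that the error type stays polymorphic. *)
let map_result f t =
  map_monad ~return:Result.return ~bind:Result.(>>=) f t

let rec pass = function
  | Divide (_, Number 0) -> Error `Literal_division_by_zero
  | other -> map_result pass other

(* let x = 4 / 2 in x *)
let good = Let {id = "x"; value = Divide (Number 4, Number 2); body = Id "x"}

(* (); if true then x / 0 else () *)
let bad =
  Sequence (Unit, If { conditional = Boolean true;
                       consequence = Divide (Id "x", Number 0);
                       alternative = Unit })

let () =
  assert (pass good = Ok good);
  assert (pass bad = Error `Literal_division_by_zero)
```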


We can generate map_option with the option monad to get Syntax.t -> Syntax.t option passes.

We can pass the identity monad to get our original map function and implement the dead code elimination pass:

let map =
  map_monad ~return:Identity.return ~bind:Identity.(>>=)

module Dead_code_elimination = struct
  let rec pass = function
    | If {conditional=Boolean true; consequence; _} ->
        pass consequence
    | If {conditional=Boolean false; alternative; _} ->
        pass alternative
    | other -> map pass other
end
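
Neither the Identity module nor an option monad is spelled out in this article, so the definitions below are assumptions: a minimal sketch of what they could look like, together with the two mappers generated from them. The pass reject_ids is hypothetical and only serves to exercise map_option.

```ocaml
type t =
  | Unit
  | Boolean of bool
  | Number of int
  | Id of string
  | Divide of t * t
  | Sequence of t * t
  | Let of {id: string; value: t; body: t}
  | If of {conditional: t; consequence: t; alternative: t}

let map_monad ~return ~bind:(>>=) f = function
  | Unit | Boolean _ | Number _ | Id _ as t ->
      return t
  | Divide (left, right) ->
      f left >>= fun left ->
      f right >>= fun right ->
      return (Divide (left, right))
  | Sequence (left, right) ->
      f left >>= fun left ->
      f right >>= fun right ->
      return (Sequence (left, right))
  | Let {id; value; body} ->
      f value >>= fun value ->
      f body >>= fun body ->
      return (Let {id; value; body})
  | If {conditional; consequence; alternative} ->
      f conditional >>= fun conditional ->
      f consequence >>= fun consequence ->
      f alternative >>= fun alternative ->
      return (If {conditional; consequence; alternative})

(* Assumed definition: the monad that adds no effect at all. *)
module Identity = struct
  let return x = x
  let (>>=) x f = f x
end

(* Assumed definition: a bare-bones option monad. *)
module Option_monad = struct
  let return x = Some x
  let (>>=) option f = match option with Some x -> f x | None -> None
end

let map = map_monad ~return:Identity.return ~bind:Identity.(>>=)

let map_option =
  map_monad ~return:Option_monad.return ~bind:Option_monad.(>>=)

let rec eliminate_dead_code = function
  | If {conditional = Boolean true; consequence; _} ->
      eliminate_dead_code consequence
  | If {conditional = Boolean false; alternative; _} ->
      eliminate_dead_code alternative
  | other -> map eliminate_dead_code other

(* Hypothetical pass: gives up as soon as the tree refers to a variable. *)
let rec reject_ids = function
  | Id _ -> None
  | other -> map_option reject_ids other
```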


5. State monad pass

Some passes need to maintain a symbol table to analyze lexical scope.

Consider a pass that finds all variables which are undefined according to the rules of lexical scope. When recursing, it needs to pass down the list of variables available in scope, and pass up the list of undefined variables that it found.

The following type can be used for this purpose:

module Environment = struct
  type t = {defined: string list; undefined: string list}

  let initial = {defined=[]; undefined=[]}
end


Since map_monad works for any monad, we could define a state monad for Environment.t and use it to factor out the recursive pattern.

When using a monad library in OCaml we can obtain a state monad from a type with no effort, and get a rich set of functions to work with it:

module Monad = StateMonad (Environment)

However, for illustrative purposes, let’s write a state monad for Environment.t manually while keeping in mind that we can get most of the code below for free:

module Monad = struct
  (* Brings the record fields and initial into scope. *)
  open Environment

  type 'a t = Environment.t -> 'a * Environment.t

  let return a env = a, env

  let (>>=) t callback env =
    let a, env = t env in
    callback a env

  let with_defined name t {defined; undefined} = (* ❶ *)
    let a, env = t {undefined; defined=name :: defined} in
    a, {env with defined}

  let check_id name env = (* ❷ *)
    if List.mem name env.defined then
      (), env
    else
      (), {env with undefined=name :: env.undefined}

  let undefined t = (snd (t initial)).undefined (* ❸ *)
end


We have also written three functions that are specific to our pass. The first ❶ is with_defined, which, given an identifier, adds it to the defined list in order to pass this information down. The second ❷ is check_id: given an identifier, it checks whether the identifier belongs to the defined list and, if it does not, adds it to the list of undefined variables. Finally, undefined ❸ extracts the list of undefined variables from a monadic value.
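
Before using these helpers in a pass, it may help to try them in isolation. The sketch below repeats Environment and Monad so that it stands on its own; the open Environment line is an addition that brings the record fields and initial into scope, and the name computation is made up for this example.

```ocaml
module Environment = struct
  type t = {defined: string list; undefined: string list}

  let initial = {defined=[]; undefined=[]}
end

module Monad = struct
  open Environment

  type 'a t = Environment.t -> 'a * Environment.t

  let return a env = a, env

  let (>>=) t callback env =
    let a, env = t env in
    callback a env

  (* Run t with name in scope, then restore the previous scope. *)
  let with_defined name t {defined; undefined} =
    let a, env = t {undefined; defined=name :: defined} in
    a, {env with defined}

  (* Record name as undefined unless it is currently in scope. *)
  let check_id name env =
    if List.mem name env.defined then
      (), env
    else
      (), {env with undefined=name :: env.undefined}

  let undefined t = (snd (t initial)).undefined
end

open Monad

(* "x" is checked while it is in scope; "y" is never defined. *)
let computation =
  with_defined "x" (check_id "x") >>= fun () ->
  check_id "y"

let () = assert (undefined computation = ["y"])
```

Only "y" ends up in the undefined list, because with_defined restores the previous defined list once its argument has run.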

Now, we have everything in place to write our pass that checks for undefined variables.

First, we generate map_environment that maps Syntax.t -> Syntax.t Monad.t:

let map_environment =
  map_monad ~return:Monad.return ~bind:Monad.(>>=)

And now, the pass itself:

module Collect_undefined_variables = struct
  open Monad

  let rec pass = function
    | Let {id; value; body} ->
        pass value >>= fun value ->
        with_defined id (pass body) >>= fun body ->
        return (Let {id; value; body})
    | Id id as t ->
        check_id id >>= fun () ->
        return t
    | other -> map_environment pass other
end


When our pass reaches a let-binding, it uses with_defined to pass down the information about the bound identifier to the body of the binding. If we had support for let rec we would also use with_defined for the value branch.

When we reach an identifier, we check that it is in scope using the check_id function, and if it is not, check_id will add it to the undefined list.

We delegate the recursion to map_environment.

Let us test this pass. Consider the following program in our toy language:

x;
let x = () in
let a = () in
a;
let b = a in
let y = y in
(let z = () in ());
a; b; z

Three variable uses fall outside of the lexical scope where the variable is defined: x on the first line, y on the right-hand side of let y = y, and z on the last line.

This program can be represented as the following syntax tree:

let tree =
  let (%) left right = Sequence (left, right) in
  Id "x" %
    Let {id="x"; value=Unit; body=
      Let {id="a"; value=Unit; body=
        Id "a" %
          Let {id="b"; value=Id "a"; body=
            Let {id="y"; value=Id "y"; body=
              Let {id="z"; value=Unit; body=Unit} %
                Id "a" % Id "b" % Id "z"}}}}

We write a test that applies our pass to the tree and extracts the list of undefined variables from the resulting value:

let () =
  let pass = Collect_undefined_variables.pass in
  assert (Monad.undefined (pass tree) = ["z"; "y"; "x"])

The assertion confirms that they are z, y, and x. If our syntax tree had location information, we could easily collect the precise locations of the undefined variables.

Summary

Code

You can find the code from this article in a GitHub gist along with more tests and examples that you can play with.

Limitations and further steps

The approach of using a map function as a recursion scheme works well when your pass works on a subset of variants and ignores the rest. However, it does not offer anything for the cases when a pass needs to touch every variant, which is common. For these cases there are other recursion schemes.

Not all passes map over the same syntax tree or intermediate representation type. Many useful passes work by converting one representation to a different one. In a future post we’ll explore what can be done for passes of the form a -> b m, for some a and b.

Resources

Adventures in Uncertainty is a blog about recursion schemes in Haskell.
