Wednesday, April 28, 2010

Lazy evaluation for tree replacement in Haskell, OCaml, and SML

(The Haskell code in this post comes from Julia Lawall's paper "Implementing circularity using partial evaluation".  Note also that this post has been significantly rewritten thanks to Jake Donham, a former colleague at both NYU and CMU, who pointed out I didn't know what I was talking about.)

An interesting application of lazy evaluation is to replace every element of an int tree with the tree's smallest element without traversing the tree twice.  The pretty Haskell program that does this is:


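data Tree = Tip Int | Fork Tree Tree deriving Show

-- p's second component (the minimum) is fed back in as repmin's second argument.
rm :: Tree -> Tree
rm t = fst p
  where p = repmin t (snd p)
        repmin (Tip n) m = (Tip m, n)
        repmin (Fork l r) m = (Fork t1 t2, min m1 m2)
          where (t1, m1) = repmin l m
                (t2, m2) = repmin r m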


We can see this is correct by looking at an invariant of repmin: repmin t m returns a pair whose second component is the smallest element of t and whose first component is a tree with the same shape as t, with every leaf replaced by m.  Laziness allows us to pass in m, the minimal value, before it has been computed: the result tree is built around a reference to the as-yet-unevaluated minimum, and that minimum is only demanded when someone looks at a leaf, by which time the traversal has finished and the correct value is available.
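To make the invariant concrete, here is a strict OCaml sketch of my own (the names tree_min and rm_two_pass are hypothetical, not from Lawall's paper).  Without laziness, repmin needs m before it starts, so the obvious strict program walks the tree twice: once to find the minimum and once to rebuild.

```ocaml
type tree = Tip of int | Fork of tree * tree

(* Invariant: repmin t m = (t with every leaf replaced by m, smallest leaf of t).
   Here m is an ordinary, already-known value. *)
let rec repmin t m =
  match t with
  | Tip n -> (Tip m, n)
  | Fork (l, r) ->
      let (t1, m1) = repmin l m in
      let (t2, m2) = repmin r m in
      (Fork (t1, t2), min m1 m2)

(* First traversal: find the minimum. *)
let rec tree_min = function
  | Tip n -> n
  | Fork (l, r) -> min (tree_min l) (tree_min r)

(* Second traversal: rebuild the tree, now that the minimum is known. *)
let rm_two_pass t = fst (repmin t (tree_min t))

let () =
  let t = Fork (Tip 5, Fork (Tip 1, Tip 3)) in
  assert (rm_two_pass t = Fork (Tip 1, Fork (Tip 1, Tip 1)));
  assert (snd (repmin t 42) = 1)
```

The lazy Haskell program fuses these two passes by letting the second argument of repmin be the (not yet computed) second result.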

In this case laziness provides an elegant solution, which would be more cumbersome in SML or OCaml.  I wondered how much more cumbersome.  I find it difficult to understand the space usage of my Haskell programs, and despite Haskell's incredible elegance and advanced features, I prefer Standard ML when my code has to be fast.  I implemented nearly identical copies of my theorem prover for propositional intuitionistic logic in Haskell and SML.  The Haskell version uses ten times as much memory.  Of course, there is a high probability that this is because I'm not a very good Haskell programmer.  Yet the techniques experts use to achieve good performance can make beautiful code ugly.  I don't think I am alone in this view.  For instance, the Haskell examples at the Programming Languages Shootout are written by experts such as one of the authors of Real World Haskell.

The programs are laden with strictness annotations and "bang patterns", which force evaluation at the pattern-match site.  This suggests to me that laziness is often not the ideal default.  Here at CMU we teach laziness as a mode of use of strictness, by implementing it with closures and references (see below).  Still, the example above is beautiful and compelling.  How ugly does it become in a strict language?  Below are some experiments.  If you can do better, I'd be delighted to see your examples.

OCaml

OCaml has some syntactic support for lazy evaluation.  (Note the use of the special syntax `lazy'.)

open Lazy

type 'a tree = Tip of 'a | Fork of 'a tree * 'a tree

let snd (_, y) = y

let rec tmap f t = match t with
  | Tip a -> Tip (f a)
  | Fork (l, r) -> Fork (tmap f l, tmap f r)

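(* Apply f to a suspended value without forcing it yet. *)
let fmap f x = lazy (f (force x))

(* Without type annotations *)
let fm t =
  let rec repmin t m =
    match t with
    (* The print traces each leaf visit, so a second traversal would show up. *)
    | Tip n -> (print_int n; print_newline(); (Tip m, n))
    | Fork(l, r) ->
        let (t1, m1) = repmin l m in
        let (t2, m2) = repmin r m in
          (Fork (t1, t2), min m1 m2) in
  (* OCaml allows a let rec whose right-hand side is lazy, so p can refer to itself. *)
  let rec p = lazy (repmin t (fmap snd p)) in
    fst (Lazy.force p)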

(* With type annotations *)

let fm : int tree -> int t tree =
 fun t ->
    let rec repmin : int tree -> int t -> int t tree * int =
      fun t m -> match t with
        | Tip n -> (print_int n; print_newline(); (Tip m, n))
        | Fork(l, r) ->
            let (t1, m1) = repmin l m in
            let (t2, m2) = repmin r m in
              (Fork (t1, t2), min m1 m2) in
    let rec p = lazy (repmin t (fmap snd p)) in
      fst (Lazy.force p)

let _ = tmap force (fm (Fork(Tip 5, Tip 1)));;

What I like about this code, especially the version with type annotations, is that it makes clear that the second element of the returned pair is strict.  The first version is closer to the Haskell implementation.  It's not so bad, except for the need for a special application (`fmap') and having to wrap `p' in `lazy' and force it explicitly.
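Here is a variant of the annotated version, assuming OCaml 4.13 or later, where the standard library's Lazy.map does the job of the hand-written fmap.  I've also swapped the print for a counter (visits is a hypothetical name) so the one-pass claim can be checked by assertion rather than by reading the output.  This is a sketch, not a replacement for the code above.

```ocaml
(* Assumes OCaml >= 4.13 for Lazy.map : ('a -> 'b) -> 'a Lazy.t -> 'b Lazy.t. *)
type 'a tree = Tip of 'a | Fork of 'a tree * 'a tree

let rec tmap f = function
  | Tip a -> Tip (f a)
  | Fork (l, r) -> Fork (tmap f l, tmap f r)

(* Counts leaf visits; used only to check that repmin runs once. *)
let visits = ref 0

let fm (t : int tree) : int Lazy.t tree =
  let rec repmin t m =
    match t with
    | Tip n -> incr visits; (Tip m, n)
    | Fork (l, r) ->
        let (t1, m1) = repmin l m in
        let (t2, m2) = repmin r m in
        (Fork (t1, t2), min m1 m2)
  in
  (* Forcing p runs repmin once; every leaf later reads the minimum
     out of the memoized p through the same Lazy.map suspension. *)
  let rec p = lazy (repmin t (Lazy.map snd p)) in
  fst (Lazy.force p)

let () =
  let result = tmap Lazy.force (fm (Fork (Tip 5, Fork (Tip 1, Tip 3)))) in
  assert (result = Fork (Tip 1, Fork (Tip 1, Tip 1)));
  assert (!visits = 3)
```

Three leaves, three visits: forcing the leaves after fm returns does not rerun repmin.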

Standard ML

To implement laziness in SML, we need to do it ourselves.  (SML/NJ has something similar to OCaml's Lazy module, but since my primary compiler is MLton I decided to write it from scratch.)  Update: this problem is harder than I imagined in the first version of this post.  It seems difficult to solve in SML without resorting to effects beyond those used to implement laziness, i.e., making a tree whose leaves are ref cells.  There is a rather easy solution using continuations to rebuild the tree once you know the smallest element, but it also traverses the tree twice.  The following solution doesn't work: as Jake pointed out, it too traverses the tree twice.


signature LAZY = 
sig
  type 'a t
  val inject_val : 'a -> 'a t
  val force : 'a t -> 'a
  val make : (unit -> 'a) -> 'a t
  val fmap : ('a -> 'b) -> 'a t -> 'b t
  val fmap2 : ('a * 'b -> 'c) -> 'a t * 'b t -> 'c t
  val concat : 'a t t -> 'a t
  val fix : ('a t -> 'a t) -> 'a t
end 

structure Lazy' :> LAZY =
struct 
  datatype 'a laz = Forced of 'a
                  | Thunk of unit -> 'a
  type 'a t = 'a laz ref
  fun inject_val x = ref (Forced x)
  fun force (ref (Forced x)) = x
    | force (l as (ref (Thunk f))) = 
      let
         val x = f ()
      in 
         l := Forced x
       ; x
      end
  fun make f = ref (Thunk f) 
  fun fmap f x = ref (Thunk (fn () => f (force x)))
  fun fmap2 f (x, y) = ref (Thunk (fn () => f (force x, force y)))              
  fun concat x = ref (Thunk (fn () => force (force x)))
  fun fix f = ref (Thunk (fn () => force (f (fix f))))
end

structure L = Lazy'

datatype 'a Tree = Tip of 'a | Fork of 'a Tree * 'a Tree

fun fst (x, y) = x
fun snd (x, y) = y

val fm : int Tree -> int L.t Tree =
 fn t =>
    let
       val rec repmin : int Tree -> int L.t -> (int L.t Tree * int) =
        fn Tip n => (fn m => (print (Int.toString n ^ "\n"); (Tip m, n)))
         | Fork(l, r) => (fn m =>
           let
              val (t1, m1) = repmin l m
              val (t2, m2) = repmin r m
           in
              (Fork(t1, t2), Int.min (m1, m2))
           end)
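       (* Broken: L.make p wraps a fresh call to p (), so forcing a leaf
          reruns repmin and the tree is traversed a second time. *)
       val rec p : unit -> int L.t Tree * int =
        fn () => repmin t (L.fmap snd (L.make p))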
    in
       fst (p ())
    end
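For comparison, here are OCaml sketches of the two strict alternatives mentioned at the start of this section.  Both are my own reconstructions, and the names repmin_k, rm_k, and rm_ref are hypothetical.  The first is one reading of the continuation approach: a single pass returns the minimum together with a function that rebuilds the tree, but calling that function walks a tree-shaped structure of closures, which amounts to a second traversal.  The second makes every leaf point at one shared ref cell and fills it in afterwards.

```ocaml
type 'a tree = Tip of 'a | Fork of 'a tree * 'a tree

let rec tmap f = function
  | Tip a -> Tip (f a)
  | Fork (l, r) -> Fork (tmap f l, tmap f r)

(* (1) Closure-based: return the minimum and a rebuild function.
   Applying the rebuild function follows the shape of the tree again. *)
let rec repmin_k (t : int tree) : int * (int -> int tree) =
  match t with
  | Tip n -> (n, fun m -> Tip m)
  | Fork (l, r) ->
      let (m1, k1) = repmin_k l in
      let (m2, k2) = repmin_k r in
      (min m1 m2, fun m -> Fork (k1 m, k2 m))

let rm_k t = let (m, k) = repmin_k t in k m

(* (2) Ref-cell leaves: one pass builds a tree whose leaves all share
   a single mutable cell; the minimum is written into it at the end. *)
let rm_ref (t : int tree) : int ref tree =
  let cell = ref 0 in   (* placeholder, overwritten below *)
  let rec go = function
    | Tip n -> (Tip cell, n)
    | Fork (l, r) ->
        let (t1, m1) = go l in
        let (t2, m2) = go r in
        (Fork (t1, t2), min m1 m2)
  in
  let (t', m) = go t in
  cell := m;
  t'

let () =
  let t = Fork (Tip 5, Fork (Tip 1, Tip 3)) in
  assert (rm_k t = Fork (Tip 1, Fork (Tip 1, Tip 1)));
  assert (tmap (!) (rm_ref t) = Fork (Tip 1, Fork (Tip 1, Tip 1)))
```

The ref-cell version really is one pass, at the cost of exposing mutation in the result type (int ref tree).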

The solution, again due to Jake, is a cleverer fixpoint operator, fix', which hands f the very cell that will hold its result, so the suspension is forced only once:

signature LAZY = 
sig
  ...
  val fix' : ('a t -> 'a) -> 'a t
end 


structure Lazy' :> LAZY =
struct 
  ...
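  (* Tie the knot: f receives the cell t itself, so any suspension built
     from t shares t's memoized result instead of recomputing it. *)
  fun fix' f =
      let
        (* Placeholder; it is overwritten before t is returned. *)
        val t = ref (Thunk (fn () => raise (Fail "fix': forced before initialization")))
      in
        t := Thunk (fn () => f t)
      ; t
      end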
end


val fm : int Tree -> int L.t Tree =
 fn t =>
    let
       ...
       val p : (int L.t Tree * int) L.t = L.fix' (fn p => repmin t (L.fmap snd p))
    in
      fst (L.force p)
    end
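To see the difference between the naive fix in the signature and fix', here is a small OCaml port of the Lazy' thunks (my own sketch; leaf_visits, fm_naive, and fm_knot are hypothetical names).  With fix, forcing a leaf unrolls the fixpoint again and reruns repmin, much like the L.make p attempt above.  With fix', f receives the cell that will hold its result, so the leaves read the memoized pair and the tree is visited once.

```ocaml
(* Memoizing thunks from closures and refs, ported from the SML Lazy'. *)
type 'a laz = Forced of 'a | Thunk of (unit -> 'a)
type 'a t = 'a laz ref

let force (l : 'a t) : 'a =
  match !l with
  | Forced x -> x
  | Thunk f -> let x = f () in l := Forced x; x

let fmap (f : 'a -> 'b) (x : 'a t) : 'b t = ref (Thunk (fun () -> f (force x)))

(* Naive fixpoint: each unfolding allocates a fresh cell, so nothing is shared. *)
let rec fix (f : 'a t -> 'a t) : 'a t = ref (Thunk (fun () -> force (f (fix f))))

(* Knot-tying fixpoint: f is handed the very cell its result is stored in. *)
let fix' (f : 'a t -> 'a) : 'a t =
  let cell = ref (Thunk (fun () -> failwith "fix': forced before initialization")) in
  cell := Thunk (fun () -> f cell);
  cell

type 'a tree = Tip of 'a | Fork of 'a tree * 'a tree

let rec tmap f = function
  | Tip a -> Tip (f a)
  | Fork (l, r) -> Fork (tmap f l, tmap f r)

let leaf_visits = ref 0   (* counts how often repmin examines a leaf *)

let rec repmin t m =
  match t with
  | Tip n -> incr leaf_visits; (Tip m, n)
  | Fork (l, r) ->
      let (t1, m1) = repmin l m in
      let (t2, m2) = repmin r m in
      (Fork (t1, t2), min m1 m2)

let fm_naive t = fst (force (fix (fun p -> ref (Forced (repmin t (fmap snd p))))))
let fm_knot t = fst (force (fix' (fun p -> repmin t (fmap snd p))))

let () =
  let t = Fork (Tip 5, Fork (Tip 1, Tip 3)) in
  leaf_visits := 0;
  assert (tmap force (fm_naive t) = Fork (Tip 1, Fork (Tip 1, Tip 1)));
  assert (!leaf_visits = 6);   (* two full traversals *)
  leaf_visits := 0;
  assert (tmap force (fm_knot t) = Fork (Tip 1, Fork (Tip 1, Tip 1)));
  assert (!leaf_visits = 3)    (* one traversal *)
```

Both produce the same tree; only the visit count differs, which is exactly the problem Jake pointed out.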




3 comments:

  1. Anonymous, 5:10 PM

    Here's another in MetaOCaml http://conway.rutgers.edu/~ccshan/wiki/blog/posts/Circularity/

  2. Hi Sean,

    Your code actually traverses the tree twice (once when lazy_from_fun p is forced, and once when you call p ()). To fix the OCaml version, you could say

    let rec p = lazy (repmin t (fmap snd p)) in
    fst (Lazy.force p)

    I'm not sure how to fix the SML version, since this relies on OCaml's special support for lazy on the RHS of a let rec.

    I agree that a little bit of syntactic overhead (in OCaml) when you want laziness is well worth efficiency and predictability when you don't.
