[MLton-user] Stream Fusion

Vesa Karvonen vesa.a.j.k at gmail.com
Sun Apr 29 06:37:07 PDT 2007

The core of the technique for stream fusion presented in the paper

  Stream Fusion: From Lists to Streams to Nothing at All
  Duncan Coutts, Roman Leshchinskiy, and Don Stewart

translates fairly straightforwardly to SML and yields a stream
implementation with some attractive properties.  The rest of this post
assumes that the reader has read the above paper.

As an example, using a translation of the stream fusion technique to SML,
the code:

val sum =
    foldl op + 0.0
          (map Math.sqrt
               (unfoldr (fn i => if i <= n then SOME (i, i+1.0) else NONE) 1.0))

is compiled by MLton to a simple loop (x86 code):

	fldL (0+((16*1)+(0*1)))(%ebp)
	fcomp %st(2)
	fnstsw %ax
	testw $0x500,%ax
	jnz L_3595
	fldL (globalReal64+((0*8)+(0*1)))
	fadd %st(2), %st
	fxch %st(2)
	faddp %st, %st(1)
	jmp L_1049

The basic idea of the technique is to represent streams as unfolds

  data Stream a = forall s. Unlifted s =>
                            Stream !(s -> Step a s)  -- ^ a stepper function
                                   !s                -- ^ an initial state

  data Step a s = Yield a !s
                | Skip    !s
                | Done

where the stepper function is non-recursive.  Recursion is only used in
the implementation of stream consumers (folds).  This roughly means that
(all) stream folds can be compiled to simple, non-nested loops.

One way to translate the above to SML is to use exceptions (open datatype)
to replace the existential type:

   datatype 'a step =
       DONE
     | GIVE of 'a * Exn.t
     | SKIP of Exn.t

   datatype 'a t = T of Exn.t * (Exn.t -> 'a step)

Perhaps surprisingly, MLton can sometimes completely eliminate the
manipulation of exceptions as can be seen from the earlier example.  See
the end of this message for the implementation (using the extended basis).
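
For readers without the extended basis at hand, the encoding can also be
sketched against the plain Basis library.  This is only an illustration,
not the implementation from this post: the names StreamSketch, stream and
sumSqrt are made up for the sketch, exn stands in for Exn.t, and a local
loop replaces recur.

```sml
(* Sketch only: plain Standard ML Basis, no extended basis. *)
structure StreamSketch = struct
   datatype 'a step =
       DONE
     | GIVE of 'a * exn
     | SKIP of exn

   (* A stream is an initial state plus a non-recursive stepper. *)
   datatype 'a stream = T of exn * (exn -> 'a step)

   (* Each call declares a fresh exception constructor that wraps the
      state, which hides the state type behind exn. *)
   fun unfoldr (f : 's -> ('a * 's) option) (s0 : 's) : 'a stream =
       let
          exception State of 's
       in
          T (State s0,
             fn State s => (case f s
                             of NONE         => DONE
                              | SOME (a, s') => GIVE (a, State s'))
              | _       => raise Fail "impossible")
       end

   (* map only wraps the stepper; it does not recurse. *)
   fun map f (T (s0, g)) =
       T (s0,
          fn s => case g s
                   of DONE         => DONE
                    | GIVE (a, s') => GIVE (f a, s')
                    | SKIP s'      => SKIP s')

   (* The consumer is the only place where recursion appears. *)
   fun foldl f b (T (s0, g)) =
       let
          fun lp (s, b) =
              case g s
               of DONE         => b
                | GIVE (a, s') => lp (s', f (a, b))
                | SKIP s'      => lp (s', b)
       in
          lp (s0, b)
       end

   (* The example from the top of this post, with n as a parameter. *)
   fun sumSqrt (n : real) =
       foldl op + 0.0
             (map Math.sqrt
                  (unfoldr (fn i => if i <= n then SOME (i, i + 1.0) else NONE)
                           1.0))
end
```

Whether a given compiler removes the exception wrapping here is not
something the sketch can promise; it only shows the shape of the encoding.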

Compared to traditional implementations of lazy sequences, such as

  datatype 'a t = T of Unit.t -> ('a * 'a t) Option.t

the stream fusion technique has the advantage that it basically always
results in a single loop, which can often be optimized further, while the
traditional techniques result in multiple nested loops, which tend to be
much more challenging to optimize further.  On my laptop, the earlier
example using the stream fusion implementation runs over twice as fast as
the same example using the above traditional sequence implementation.
Unlike the loop combinators on the http://mlton.org/ForLoops page,
the stream fusion technique can also express zip.
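
For contrast, here is a rough sketch of the same example written against
the traditional sequence type, using only the plain Basis.  The names
SeqSketch, seq and sumSqrt are made up for the sketch, and this is not the
code that was timed for the comparison above.

```sml
(* Sketch only: the traditional lazy sequence, for comparison. *)
structure SeqSketch = struct
   datatype 'a seq = Seq of unit -> ('a * 'a seq) option

   (* The producer is recursive: every element builds a new suspension. *)
   fun unfoldr f s =
       Seq (fn () => case f s
                      of NONE         => NONE
                       | SOME (a, s') => SOME (a, unfoldr f s'))

   (* The transformer is recursive as well. *)
   fun map f (Seq g) =
       Seq (fn () => case g ()
                      of NONE           => NONE
                       | SOME (a, rest) => SOME (f a, map f rest))

   (* The consumer forces one suspension per element. *)
   fun foldl f b (Seq g) =
       case g ()
        of NONE           => b
         | SOME (a, rest) => foldl f (f (a, b)) rest

   fun sumSqrt (n : real) =
       foldl op + 0.0
             (map Math.sqrt
                  (unfoldr (fn i => if i <= n then SOME (i, i + 1.0) else NONE)
                           1.0))
end
```

Here unfoldr, map and foldl are each recursive on their own, so the
compiler has to see through three cooperating recursions rather than one
loop around a non-recursive stepper.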

AFAICT, a drawback of the technique that I did not see mentioned in the
paper is a potential for space leaks.  Consider a stream expression of
the form:

  (map f s1) @ s2

Because @ (append) is implemented non-recursively, the stepper of the
combined stream refers to the steppers of both operands for its whole
lifetime.  The space taken by f may therefore not be reclaimed until the
entire stream has been discarded, even after s1 has been exhausted
(modulo exceedingly clever compiler optimizations).  This concern should
also apply to the Haskell version even with the "unstream . streamFn
. stream" wrapping (which, AFAICT, effectively memoizes streams) used to
implement list functions in terms of streams.
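
To make the concern concrete, here is a reduced sketch of append over the
plain Basis.  The names AppendSketch, stream and append are made up for
the sketch (append plays the role of @), and the comment marks the closure
that keeps the first stepper reachable.

```sml
(* Sketch only: shows which closure keeps the first stepper alive. *)
structure AppendSketch = struct
   datatype 'a step =
       DONE
     | GIVE of 'a * exn
     | SKIP of exn

   datatype 'a stream = T of exn * (exn -> 'a step)

   fun fromList (l : 'a list) : 'a stream =
       let
          exception State of 'a list
       in
          T (State l,
             fn State []        => DONE
              | State (x :: xs) => GIVE (x, State xs)
              | _               => raise Fail "impossible")
       end

   (* The stepper returned here closes over f. *)
   fun map f (T (s0, g)) =
       T (s0,
          fn s => case g s
                   of DONE         => DONE
                    | GIVE (a, s') => GIVE (f a, s')
                    | SKIP s'      => SKIP s')

   fun append (T (sa0, ga) : 'a stream) (T (sb0, gb) : 'a stream) : 'a stream =
       let
          exception Fst of exn
          exception Snd of exn
       in
          T (Fst sa0,
             (* This single closure mentions both ga and gb.  Once the
                state has become Snd, ga is never called again, yet it
                (and any f captured inside it) stays reachable for as
                long as this closure is live. *)
             fn Fst sa => (case ga sa
                            of DONE          => SKIP (Snd sb0)
                             | SKIP sa'      => SKIP (Fst sa')
                             | GIVE (a, sa') => GIVE (a, Fst sa'))
              | Snd sb => (case gb sb
                            of DONE          => DONE
                             | SKIP sb'      => SKIP (Snd sb')
                             | GIVE (b, sb') => GIVE (b, Snd sb'))
              | _      => raise Fail "impossible")
       end

   fun toList (T (s0, g)) =
       let
          fun lp (s, acc) =
              case g s
               of DONE         => List.rev acc
                | GIVE (a, s') => lp (s', a :: acc)
                | SKIP s'      => lp (s', acc)
       in
          lp (s0, [])
       end
end
```

A consumer such as toList holds on to the combined stepper until it
finishes, so in this sketch f cannot be collected before the whole
appended stream has been consumed.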

-Vesa Karvonen

infixr :::

structure Stream :> sig
   type 'a t

   (** == Eliminating Streams == *)

   val get : 'a t -> ('a * 'a t) Option.t

   val app : 'a Effect.t -> 'a t Effect.t
   val foldl : ('a * 'b -> 'b) -> 'b -> 'a t -> 'b

   (** == Introducing Streams == *)

   val empty : 'a t
   val ::: : 'a * 'a t -> 'a t

   val singleton : 'a -> 'a t

   val @ : 'a t BinOp.t

   val unfoldr : ('s -> ('a * 's) Option.t) -> 's -> 'a t
   val iterate : 'a UnOp.t -> 'a -> 'a t

   (** == Combining Streams == *)

   val zip : 'a t -> 'b t -> ('a * 'b) t

   (** == Manipulating Streams == *)

   val map : ('a -> 'b) -> 'a t -> 'b t
   val filter : 'a UnPr.t -> 'a t UnOp.t
   val concatMap : ('a -> 'b t) -> 'a t  -> 'b t

   (** == Monad Interface == *)

   structure Monad : MONADP where type 'a monad = 'a t

   (** == Conversions == *)

   val fromArray : 'a Array.t -> 'a t
   val fromList : 'a List.t -> 'a t
   val fromString : String.t -> Char.t t
   val fromVector : 'a Vector.t -> 'a t

   val toList : 'a t -> 'a List.t
end = struct
   datatype 'a step =
       DONE
     | GIVE of 'a * Exn.t
     | SKIP of Exn.t

   datatype 'a t = T of Exn.t * (Exn.t -> 'a step)

   (** == Eliminating Streams == *)

   fun get (T (s, g)) =
       recur s (fn lp =>
          fn s =>
             case g s
              of DONE        => NONE
               | GIVE (a, s) => SOME (a, T (s, g))
               | SKIP s      => lp s)

   fun foldl f b (T (s, g)) =
       recur (s, b) (fn lp =>
          fn (s, b) =>
             case g s
              of DONE        => b
               | GIVE (a, s) => lp (s, f (a, b))
               | SKIP s      => lp (s, b))

   fun app ef = foldl (ef o #1) ()

   (** == Introducing Streams == *)

   val empty = T (Empty,
                  fn Empty => DONE
                   | _     => fail "impossible")

   fun singleton x = let
      exception Some of 'a
      exception None
   in
      T (Some x,
         fn Some x => GIVE (x, None)
          | None   => DONE
          | _      => fail "impossible")
   end

   fun (T (sa, ga)) @ (T (sb, gb)) = let
      exception Fst of Exn.t
      exception Snd of Exn.t
   in
      T (Fst sa,
         fn Fst sa => (case ga sa
                        of DONE         => SKIP (Snd sb)
                         | SKIP sa      => SKIP (Fst sa)
                         | GIVE (a, sa) => GIVE (a, Fst sa))
          | Snd sb => (case gb sb
                        of DONE         => DONE
                         | SKIP sb      => SKIP (Snd sb)
                         | GIVE (b, sb) => GIVE (b, Snd sb))
          | _      => fail "impossible")
   end

   fun x ::: xs = singleton x @ xs

   fun unfoldr f s = let
      exception State of 's
   in
      T (State s,
         fn State s => (case f s
                         of NONE        => DONE
                          | SOME (a, s) => GIVE (a, State s))
          | _       => fail "impossible")
   end

   fun iterate f = unfoldr (fn x => SOME (x, f x)) (* XXX too strict *)

   (** == Combining Streams == *)

   fun zip (T (sa, ga)) (T (sb, gb)) = let
      exception SS of Exn.t * Exn.t
      exception SAS of Exn.t * 'a * Exn.t
   in
      T (SS (sa, sb),
         fn SS (sa, sb)     => (case ga sa
                                 of DONE         => DONE
                                  | SKIP sa      => SKIP (SS (sa, sb))
                                  | GIVE (a, sa) => SKIP (SAS (sa, a, sb)))
          | SAS (sa, a, sb) => (case gb sb
                                 of DONE         => DONE
                                  | SKIP sb      => SKIP (SAS (sa, a, sb))
                                  | GIVE (b, sb) => GIVE ((a, b), SS (sa, sb)))
          | _               => fail "impossible")
   end

   (** == Manipulating Streams == *)

   fun map a2b (T (s, g)) =
       T (s,
          fn s =>
             case g s
              of DONE        => DONE
               | GIVE (a, s) => GIVE (a2b a, s)
               | SKIP s      => SKIP s)

   fun filter p (T (s, g)) =
       T (s,
          fn s =>
             case g s
              of GIVE (a, s) => if p a then GIVE (a, s) else SKIP s
               | otherwise   => otherwise)

   fun concatMap a2bM (T (sa, ga)) = let
      exception SA of Exn.t
      exception SB of Exn.t * (Exn.t -> 'b step) * Exn.t
   in
      T (SA sa,
         fn SA sa           => (case ga sa
                                 of DONE         => DONE
                                  | SKIP sa      => SKIP (SA sa)
                                  | GIVE (a, sa) => case a2bM a of T (sb, gb)
                                                     => SKIP (SB (sb, gb, sa)))
          | SB (sb, gb, sa) => (case gb sb
                                 of DONE         => SKIP (SA sa)
                                  | SKIP sb      => SKIP (SB (sb, gb, sa))
                                  | GIVE (b, sb) => GIVE (b, SB (sb, gb, sa)))
          | _               => fail "impossible")
   end

   (** == Monad Interface == *)

   structure Monad =
      MkMonadP
         (type 'a monad = 'a t
          val zero = empty
          val return = singleton
          fun aM >>= a2bM = concatMap a2bM aM
          val op <|> = op @)

   (** == Conversions == *)

   fun fromArray a = unfoldr ArraySlice.getItem (ArraySlice.full a)
   fun fromList l = unfoldr List.getItem l
   val fromString = unfoldr Substring.getc o Substring.full
   fun fromVector v = unfoldr VectorSlice.getItem (VectorSlice.full v)

   fun toList t = List.unfoldr get t
end

(* Example *)

val n = valOf (Real.fromString (hd (CommandLine.arguments ())))

open Stream

fun pr r = print (Real.toString r ^ "\n")

val sum =
    foldl op + 0.0
          (map Math.sqrt
               (unfoldr (fn i => if i <= n then SOME (i, i+1.0) else NONE) 1.0))

val () = pr sum
