[MLton-user] Stream Fusion
Vesa Karvonen
vesa.a.j.k at gmail.com
Sun Apr 29 06:37:07 PDT 2007
The core of the technique for stream fusion presented in the paper
  Stream Fusion: From Lists to Streams to Nothing at All
  Duncan Coutts, Roman Leshchinskiy, and Don Stewart
  http://mlton.org/References#CouttsEtAl07
translates fairly straightforwardly to SML and yields a stream
implementation with some attractive properties. The rest of this post
assumes that the reader has read the above paper.
As an example, using a translation of the stream fusion technique to SML,
the code:
  val sum =
      foldl op +
            0.0
            (map Math.sqrt
                 (unfoldr (fn i => if i <= n then SOME (i, i+1.0) else NONE) 1.0))
is compiled by MLton to a simple loop (x86 code):
  L_1049:
          fldL (0+((16*1)+(0*1)))(%ebp)
          fcomp %st(2)
          fnstsw %ax
          testw $0x500,%ax
          jnz L_3595
  L_1933:
          fldL (globalReal64+((0*8)+(0*1)))
          fadd %st(2), %st
          fxch %st(2)
          fsqrt
          faddp %st, %st(1)
          jmp L_1049
The basic idea of the technique is to represent streams as unfolds, that
is, as a stepper function paired with an initial state:
  data Stream a = forall s. Unlifted s =>
                  Stream !(s -> Step a s)  -- ^ a stepper function
                         !s                -- ^ an initial state
  data Step a s = Yield a !s
                | Skip    !s
                | Done
where the stepper function is non-recursive. Recursion is used only in
the implementation of stream consumers (folds). Roughly, this means that
all stream folds can be compiled to simple, non-nested loops.
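To make the representation concrete, here is a minimal sketch in Standard ML using only the Standard Basis. It rests on a simplifying assumption and is not the post's encoding: the state type is exposed as an extra type parameter 's instead of being hidden, and the names step, stream, DONE, SKIP, YIELD, enumFromTo, map, filter and foldl are chosen here for illustration. The producer and the transformers only wrap steppers; foldl is the single recursive function.

```sml
(* Illustrative sketch: the paper's Step/Stream with the state type
   exposed as a type parameter 's (plain SML has no existentials).
   All names are hypothetical, not taken from the post. *)
datatype ('a, 's) step = DONE | SKIP of 's | YIELD of 'a * 's

(* A stream is a non-recursive stepper function plus an initial state. *)
datatype ('a, 's) stream = STREAM of ('s -> ('a, 's) step) * 's

(* Producer: the integers lo, lo+1, ..., hi. *)
fun enumFromTo (lo, hi) =
    let fun next i = if i > hi then DONE else YIELD (i, i + 1)
    in STREAM (next, lo) end

(* Transformer: map wraps the stepper; no recursion. *)
fun map f (STREAM (g, s0)) =
    let fun next s =
            case g s of
                DONE => DONE
              | SKIP s' => SKIP s'
              | YIELD (x, s') => YIELD (f x, s')
    in STREAM (next, s0) end

(* Transformer: filter turns rejected elements into SKIP steps, so its
   stepper never has to loop in search of the next accepted element. *)
fun filter p (STREAM (g, s0)) =
    let fun next s =
            case g s of
                DONE => DONE
              | SKIP s' => SKIP s'
              | YIELD (x, s') => if p x then YIELD (x, s') else SKIP s'
    in STREAM (next, s0) end

(* Consumer: the only recursive function in the pipeline. *)
fun foldl f acc (STREAM (g, s0)) =
    let fun loop (s, acc) =
            case g s of
                DONE => acc
              | SKIP s' => loop (s', acc)
              | YIELD (x, s') => loop (s', f (x, acc))
    in loop (s0, acc) end

(* Sum of the squares of the even numbers in 1..10. *)
val result =
    foldl op + 0 (map (fn x => x * x)
                      (filter (fn x => x mod 2 = 0) (enumFromTo (1, 10))))
```

Exposing 's makes each pipeline's state type part of its stream type (here (int, int) stream); the existential in the paper's definition is there to hide it.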
One way to translate the Haskell definitions above to SML is to use
exceptions (SML's open datatype, exn) to replace the existential type:
  datatype 'a step =
     DONE
   | GIVE of 'a * Exn.t
   | SKIP of Exn.t
  datatype 'a t = T of Exn.t * (Exn.t -> 'a step)
Perhaps surprisingly, MLton can sometimes completely eliminate the
manipulation of exceptions, as can be seen from the earlier example. See
the end of this message for the implementation (it uses the extended
basis library from mltonlib).
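The following sketch illustrates how a local exception plays the role of the existential: two streams with different private state types (int and int list) end up with the same type, int t. It mirrors the post's datatypes but writes the Standard Basis exn where the post writes Exn.t; the helpers countDown, fromList and toList are hypothetical. It does not show whether MLton eliminates the exception values in any particular program.

```sml
(* The post's representation, with Standard Basis exn in place of Exn.t. *)
datatype 'a step = DONE | GIVE of 'a * exn | SKIP of exn
datatype 'a t = T of exn * (exn -> 'a step)

(* Hypothetical producer whose state type, int, is hidden inside a local
   exception constructor that no other code can construct or match. *)
fun countDown n =
    let exception State of int
        fun next (State i) = if i <= 0 then DONE else GIVE (i, State (i - 1))
          | next _ = raise Fail "impossible"
    in T (State n, next) end

(* Hypothetical producer with state type 'a list, hidden the same way;
   'a is bound explicitly so the exception's type is well scoped. *)
fun 'a fromList (xs : 'a list) =
    let exception Rest of 'a list
        fun next (Rest []) = DONE
          | next (Rest (x :: rest)) = GIVE (x, Rest rest)
          | next _ = raise Fail "impossible"
    in T (Rest xs, next) end

(* A consumer never inspects the state; it only hands the opaque exn
   value back to the stepper that produced it. *)
fun toList (T (s0, g)) =
    let fun loop (s, acc) =
            case g s of
                DONE => rev acc
              | SKIP s' => loop (s', acc)
              | GIVE (x, s') => loop (s', x :: acc)
    in loop (s0, []) end

(* Different hidden state types, one stream type: int t list. *)
val streams = [countDown 3, fromList [10, 20]]
val lists = List.map toList streams
```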
Compared to traditional implementations of lazy sequences, such as
  datatype 'a t = T of Unit.t -> ('a * 'a t) Option.t
the stream fusion technique has the advantage that it almost always
results in a single loop, which can often be optimized further, while the
traditional techniques result in multiple nested loops, which tend to be
much harder to optimize further. On my laptop, the earlier example runs
over twice as fast with the stream fusion implementation as with the
above traditional sequence implementation.
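For comparison, here is a minimal sketch of the traditional representation quoted above, with the Standard Basis unit and option in place of Unit.t and Option.t; the helper names unfoldr, map and foldl are chosen to match the post, but this code is illustrative rather than the implementation that was timed. Each of the three functions is recursive, and every element costs a freshly allocated thunk, so the pipeline is not a single loop unless the compiler manages to fuse it. It computes the post's example with n fixed at 4.0.

```sml
(* The traditional lazy-sequence representation from the post, written
   with Standard Basis unit and option. *)
datatype 'a t = T of unit -> ('a * 'a t) option

(* Each combinator is recursive and allocates a new thunk per element. *)
fun unfoldr f s =
    T (fn () => case f s of
                    NONE => NONE
                  | SOME (x, s') => SOME (x, unfoldr f s'))

fun map f (T th) =
    T (fn () => case th () of
                    NONE => NONE
                  | SOME (x, rest) => SOME (f x, map f rest))

fun foldl f acc (T th) =
    case th () of
        NONE => acc
      | SOME (x, rest) => foldl f (f (x, acc)) rest

(* The post's example with n fixed at 4.0 (an assumption for testing):
   sqrt 1.0 + sqrt 2.0 + sqrt 3.0 + sqrt 4.0 *)
val n = 4.0
val sum =
    foldl op + 0.0
          (map Math.sqrt
               (unfoldr (fn i => if i <= n then SOME (i, i + 1.0) else NONE) 1.0))
```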
Unlike the loop combinators on the http://mlton.org/ForLoops page, the
stream fusion technique can express zip.
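Because a stream is pulled one step at a time, zip can advance two streams in lock step. The sketch below assumes an illustrative exposed-state encoding (hypothetical names, not the post's exception-based one) and keeps the element already taken from the first stream in an option, where the post's zip uses the separate SS and SAS exception constructors.

```sml
(* Illustrative exposed-state encoding; names are hypothetical. *)
datatype ('a, 's) step = DONE | SKIP of 's | YIELD of 'a * 's
datatype ('a, 's) stream = STREAM of ('s -> ('a, 's) step) * 's

fun fromList xs =
    let fun next [] = DONE
          | next (x :: rest) = YIELD (x, rest)
    in STREAM (next, xs) end

(* zip's state: both input states plus, optionally, an element taken from
   the first stream while waiting for one from the second (compare the
   post's SS and SAS).  The stepper itself is not recursive. *)
fun zip (STREAM (ga, sa0)) (STREAM (gb, sb0)) =
    let fun next (sa, sb, NONE) =
              (case ga sa of
                   DONE => DONE
                 | SKIP sa' => SKIP (sa', sb, NONE)
                 | YIELD (a, sa') => SKIP (sa', sb, SOME a))
          | next (sa, sb, SOME a) =
              (case gb sb of
                   DONE => DONE
                 | SKIP sb' => SKIP (sa, sb', SOME a)
                 | YIELD (b, sb') => YIELD ((a, b), (sa, sb', NONE)))
    in STREAM (next, (sa0, sb0, NONE)) end

fun toList (STREAM (g, s0)) =
    let fun loop (s, acc) =
            case g s of
                DONE => rev acc
              | SKIP s' => loop (s', acc)
              | YIELD (x, s') => loop (s', x :: acc)
    in loop (s0, []) end

(* The result is as long as the shorter input. *)
val pairs = toList (zip (fromList [1, 2, 3]) (fromList ["a", "b"]))
```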
AFAICT, one drawback of the technique that I did not see mentioned in the
paper is the potential for space leaks. Consider a stream expression of
the form:
  (map f s1) @ s2
Because @ (append) is implemented non-recursively, the stepper of the
result closes over the steppers of both operands, and the stepper of
map f s1 in turn refers to f. Hence the space taken by f may not be
reclaimed until the entire stream has been discarded, even after the
first operand has been exhausted (modulo exceedingly clever compiler
optimizations). This concern should also apply to the Haskell version,
even with the "unstream . streamFn . stream" wrapping (which, AFAICT,
effectively memoizes streams) used to implement list functions in terms
of streams.
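The retention described above is visible in a small sketch of a non-recursive append, written in an illustrative exposed-state encoding with hypothetical names rather than the post's exception-based one. The combined stepper refers to both input steppers, so the first one (and the f it captured) stays reachable after it has finished. The sketch only shows the closure structure; it does not measure memory use, and a compiler may or may not optimize the reference away.

```sml
(* Illustrative exposed-state encoding; names are hypothetical. *)
datatype ('a, 's) step = DONE | SKIP of 's | YIELD of 'a * 's
datatype ('a, 's) stream = STREAM of ('s -> ('a, 's) step) * 's
datatype ('s1, 's2) either = FST of 's1 | SND of 's2

fun fromList xs =
    let fun next [] = DONE
          | next (x :: rest) = YIELD (x, rest)
    in STREAM (next, xs) end

(* map's stepper captures f. *)
fun map f (STREAM (g, s0)) =
    let fun next s =
            case g s of
                DONE => DONE
              | SKIP s' => SKIP s'
              | YIELD (x, s') => YIELD (f x, s')
    in STREAM (next, s0) end

(* Non-recursive append.  'next' closes over BOTH ga and gb.  Once the
   state is SND, ga is never called again, yet ga (and the f captured by
   a preceding map) remains reachable through 'next' for as long as the
   appended stream is alive, unless the compiler removes the reference. *)
fun append (STREAM (ga, sa0)) (STREAM (gb, sb0)) =
    let fun next (FST sa) =
              (case ga sa of
                   DONE => SKIP (SND sb0)
                 | SKIP sa' => SKIP (FST sa')
                 | YIELD (x, sa') => YIELD (x, FST sa'))
          | next (SND sb) =
              (case gb sb of
                   DONE => DONE
                 | SKIP sb' => SKIP (SND sb')
                 | YIELD (x, sb') => YIELD (x, SND sb'))
    in STREAM (next, FST sa0) end

fun toList (STREAM (g, s0)) =
    let fun loop (s, acc) =
            case g s of
                DONE => rev acc
              | SKIP s' => loop (s', acc)
              | YIELD (x, s') => loop (s', x :: acc)
    in loop (s0, []) end

val xs = toList (append (map (fn x => x * 10) (fromList [1, 2])) (fromList [3]))
```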
-Vesa Karvonen
infixr :::

structure Stream :> sig
   type 'a t

   (** == Eliminating Streams == *)

   val get : 'a t -> ('a * 'a t) Option.t
   val app : 'a Effect.t -> 'a t Effect.t
   val foldl : ('a * 'b -> 'b) -> 'b -> 'a t -> 'b

   (** == Introducing Streams == *)

   val empty : 'a t
   val ::: : 'a * 'a t -> 'a t
   val singleton : 'a -> 'a t
   val @ : 'a t BinOp.t
   val unfoldr : ('s -> ('a * 's) Option.t) -> 's -> 'a t
   val iterate : 'a UnOp.t -> 'a -> 'a t

   (** == Combining Streams == *)

   val zip : 'a t -> 'b t -> ('a * 'b) t

   (** == Manipulating Streams == *)

   val map : ('a -> 'b) -> 'a t -> 'b t
   val filter : 'a UnPr.t -> 'a t UnOp.t
   val concatMap : ('a -> 'b t) -> 'a t -> 'b t

   (** == Monad Interface == *)

   structure Monad : MONADP where type 'a monad = 'a t

   (** == Conversions == *)

   val fromArray : 'a Array.t -> 'a t
   val fromList : 'a List.t -> 'a t
   val fromString : String.t -> Char.t t
   val fromVector : 'a Vector.t -> 'a t
   val toList : 'a t -> 'a List.t
end = struct
   datatype 'a step =
      DONE
    | GIVE of 'a * Exn.t
    | SKIP of Exn.t
   datatype 'a t = T of Exn.t * (Exn.t -> 'a step)

   (** == Eliminating Streams == *)

   fun get (T (s, g)) =
       recur s (fn lp =>
          fn s =>
             case g s
              of DONE => NONE
               | GIVE (a, s) => SOME (a, T (s, g))
               | SKIP s => lp s)

   fun foldl f b (T (s, g)) =
       recur (s, b) (fn lp =>
          fn (s, b) =>
             case g s
              of DONE => b
               | GIVE (a, s) => lp (s, f (a, b))
               | SKIP s => lp (s, b))

   fun app ef = foldl (ef o #1) ()
   (** == Introducing Streams == *)

   val empty = T (Empty,
                  fn Empty => DONE
                   | _ => fail "impossible")

   fun singleton x = let
      exception Some of 'a
      exception None
   in
      T (Some x,
         fn Some x => GIVE (x, None)
          | None => DONE
          | _ => fail "impossible")
   end

   fun (T (sa, ga)) @ (T (sb, gb)) = let
      (* Fst: still reading the first stream; Snd: reading the second. *)
      exception Fst of Exn.t
      exception Snd of Exn.t
   in
      T (Fst sa,
         fn Fst sa => (case ga sa
                        of DONE => SKIP (Snd sb)
                         | SKIP sa => SKIP (Fst sa)
                         | GIVE (a, sa) => GIVE (a, Fst sa))
          | Snd sb => (case gb sb
                        of DONE => DONE
                         | SKIP sb => SKIP (Snd sb)
                         | GIVE (b, sb) => GIVE (b, Snd sb))
          | _ => fail "impossible")
   end

   fun x ::: xs = singleton x @ xs

   fun unfoldr f s = let
      exception State of 's
   in
      T (State s,
         fn State s => (case f s
                         of NONE => DONE
                          | SOME (a, s) => GIVE (a, State s))
          | _ => fail "impossible")
   end

   fun iterate f = unfoldr (fn x => SOME (x, f x)) (* XXX too strict *)
   (** == Combining Streams == *)

   fun zip (T (sa, ga)) (T (sb, gb)) = let
      (* SS: need the next element of the first stream;
         SAS: holding a from the first stream, need one from the second. *)
      exception SS of Exn.t * Exn.t
      exception SAS of Exn.t * 'a * Exn.t
   in
      T (SS (sa, sb),
         fn SS (sa, sb) => (case ga sa
                             of DONE => DONE
                              | SKIP sa => SKIP (SS (sa, sb))
                              | GIVE (a, sa) => SKIP (SAS (sa, a, sb)))
          | SAS (sa, a, sb) => (case gb sb
                                 of DONE => DONE
                                  | SKIP sb => SKIP (SAS (sa, a, sb))
                                  | GIVE (b, sb) => GIVE ((a, b), SS (sa, sb)))
          | _ => fail "impossible")
   end
   (** == Manipulating Streams == *)

   fun map a2b (T (s, g)) =
       T (s,
          fn s =>
             case g s
              of DONE => DONE
               | GIVE (a, s) => GIVE (a2b a, s)
               | SKIP s => SKIP s)

   fun filter p (T (s, g)) =
       T (s,
          fn s =>
             case g s
              of GIVE (a, s) => if p a then GIVE (a, s) else SKIP s
               | otherwise => otherwise)

   fun concatMap a2bM (T (sa, ga)) = let
      (* SA: need the next element of the outer stream;
         SB: draining the inner stream (sb, gb), then resume the outer at sa. *)
      exception SA of Exn.t
      exception SB of Exn.t * (Exn.t -> 'b step) * Exn.t
   in
      T (SA sa,
         fn SA sa => (case ga sa
                       of DONE => DONE
                        | SKIP sa => SKIP (SA sa)
                        | GIVE (a, sa) =>
                          case a2bM a
                           of T (sb, gb) => SKIP (SB (sb, gb, sa)))
          | SB (sb, gb, sa) => (case gb sb
                                 of DONE => SKIP (SA sa)
                                  | SKIP sb => SKIP (SB (sb, gb, sa))
                                  | GIVE (b, sb) => GIVE (b, SB (sb, gb, sa)))
          | _ => fail "impossible")
   end
   (** == Monad Interface == *)

   structure Monad =
      MkMonadP
         (type 'a monad = 'a t
          val zero = empty
          val return = singleton
          fun aM >>= a2bM = concatMap a2bM aM
          val op <|> = op @)

   (** == Conversions == *)

   fun fromArray a = unfoldr ArraySlice.getItem (ArraySlice.full a)
   fun fromList l = unfoldr List.getItem l
   val fromString = unfoldr Substring.getc o Substring.full
   fun fromVector v = unfoldr VectorSlice.getItem (VectorSlice.full v)
   fun toList t = List.unfoldr get t
end
(* Example *)

val n = valOf (Real.fromString (hd (CommandLine.arguments ())))

open Stream

fun pr r = print (Real.toString r ^ "\n")

val sum =
    foldl op +
          0.0
          (map Math.sqrt
               (unfoldr (fn i => if i <= n then SOME (i, i+1.0) else NONE) 1.0))

val () = pr sum