[MLton] HOL's Moscow ML implementation, and pushing MLton to emulate it

Stephen Weeks MLton@mlton.org
Mon, 11 Apr 2005 14:41:44 -0700


> Attached is an SML file that sketches how to do the interoperation,
> expressed as a source-to-source translation of code to be linked
> into the interpreter.
>
> It deals with recursive polymorphic structures.  Each coercion
> between the generic and hosted representation is constant time.

Thanks for the code, Dan.  It made sense.  Here's my take on what it
does and does not do, followed by an attempt to improve on it.

Your code uses a universal datatype in the hosted world, while
allowing native representations in the host world.  It allows
coercions of values from the host world to the hosted world.  It
achieves constant-time coercions by introducing a variant in the
universal datatype for delayed evaluation.  It uses recursion schemes
to share code between the host and the hosted world, by allowing the
same code to deal with different representations of a (possibly
recursive) type.
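For readers without the attachment, here is a minimal sketch of the
delaying trick.  It is a hypothetical reconstruction, not Dan's code:
the datatype univ and the names UInt, UNil, UCons, Delay, force, and
fromIntList are made up for illustration.

```sml
(* Hypothetical hosted universal datatype with a variant for delayed
 * evaluation; a reconstruction for illustration, not the attachment. *)
datatype univ =
   UInt of int
 | UNil
 | UCons of univ * univ
 | Delay of unit -> univ   (* a host value whose coercion is pending *)

(* Run pending coercions until the head constructor is known. *)
fun force (Delay th) = force (th ())
  | force u = u

(* Host-to-hosted coercion in constant time: nothing is traversed until
 * the interpreter forces the result, and then only one cell at a time. *)
fun fromIntList (l: int list): univ =
   Delay (fn () =>
          case l of
             [] => UNil
           | x :: xs => UCons (UInt x, fromIntList xs))

(* Interpreter-side length, forcing one cons cell per step. *)
fun univLength (u: univ): int =
   case force u of
      UNil => 0
    | UCons (_, tl) => 1 + univLength tl
    | _ => raise Fail "univLength: not a list"
```

Because fromIntList returns a Delay immediately, the cost of the
coercion is paid incrementally, as the interpreter inspects the list.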

Now, for some drawbacks.

Abstracting code to deal with different representations of lists has
the effect of adding an extra layer of polymorphism to all of the
code, whether it was originally polymorphic or monomorphic.  For
example, consider the Lst.sum code.  Initially, its type is

    val sum: int list -> int

After translation, it becomes Lst.sumF, and its type becomes
polymorphic:

    val sumF: ('a -> (int, 'a) listF) -> (int, 'a) listF -> int

Of course, this polymorphism is essential to be able to deal with the
different representations of host lists and hosted lists.  But its
cost, in the context of MLton, is that all the host code that deals
with the interpreter, including monomorphic code, is duplicated.  In
MLton, we're used to duplicating polymorphic code.  Duplicating
monomorphic code is a new thing, and it would be nice if it could be
avoided (or at least we had a choice so we could make a compile-time
tradeoff).
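To illustrate, here is a hypothetical sumF in the recursion-scheme
style, instantiated once for host int lists and once for a simplified
universal datatype.  The names listF, NilF, ConsF, outHost, outUniv,
and univ are assumptions; the attachment's may differ.  Under MLton's
whole-program monomorphisation, each instantiation becomes its own
copy of the code, even though the original sum was monomorphic.

```sml
(* Base functor for lists: 'a is the element type, 'r the representation
 * of the tail. *)
datatype ('a, 'r) listF = NilF | ConsF of 'a * 'r

(* sum, abstracted over how to unroll one layer of the representation.
 * Its type is ('r -> (int, 'r) listF) -> (int, 'r) listF -> int. *)
fun sumF (out: 'r -> (int, 'r) listF) (l: (int, 'r) listF): int =
   case l of
      NilF => 0
    | ConsF (x, r) => x + sumF out (out r)

(* Host representation: ordinary int lists. *)
fun outHost ([]: int list) = NilF
  | outHost (x :: xs) = ConsF (x, xs)

val sumHost: int list -> int = fn l => sumF outHost (outHost l)

(* Hosted representation: a simplified universal datatype. *)
datatype univ = UInt of int | UNil | UCons of univ * univ

fun outUniv (u: univ): (int, univ) listF =
   case u of
      UNil => NilF
    | UCons (UInt x, tl) => ConsF (x, tl)
    | _ => raise Fail "outUniv: not an int list"

(* A second instantiation of sumF, i.e. a second copy after MLton
 * monomorphises the program. *)
val sumUniv: univ -> int = fn u => sumF outUniv (outUniv u)
```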

Your approach is very similar to copying the entire program, feeding
it as source to the interpreter, and then running two parallel worlds;
there really isn't much communication between them.  One advantage of
your approach is that the source of the exported functions is compiled
by MLton, but the compiled code still deals exclusively with the
hosted universal datatype.

One thing I would like to see in a solution is that if I define a list
and a sum function in the host world

  val xs: int list
  val sum: int list -> int

and export both of these to the hosted world, and then in the hosted
world call "sum xs", it should effectively call "sum xs" in the host
world.  This will give one the freedom to experiment in the hosted
world while getting a lot of the speed and representation benefits of
the host world.  Your solution doesn't achieve this (not that I'm
claiming you could have known that I wanted it :-).  Your solution will
(lazily) coerce the int list to the universal type, and apply the
duplicate copy of sum for the hosted world, which deals only with the
universal type.

The only real communication between the worlds is the delaying trick
that allows constant-time coercion of host values to hosted values.
Unfortunately, there is *no* communication in the other direction.
The delaying trick doesn't work there, because there are no hooks in
the host types to introduce delaying, and I want to avoid changing
host representations because that would adversely affect host
performance.  Yet communication from the hosted world to the host
world is essential.  Suppose we had a variant of Lst.sum that saved
its argument in a ref cell.

  val r: int list ref = ref []
  val sumSave: int list -> int = fn l => (r := l; sum l)

Duplicating the ref cell is not an option (it is semantically
incorrect, since the two worlds would then see different cells), and
abstracting sumSave over the representation of lists doesn't help.
When sumSave is called in the hosted world, the only
choice is to completely coerce a hosted list to a host list.  If the
hosted world represents int lists by embedding them piece-by-piece
into a universal datatype, then this will obviously not be a
constant-time operation.  Things are even worse if the data structure
is more complicated than a list and contains mutable components of its
own.
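To make the cost concrete, here is a sketch of the hosted-to-host
coercion that a hosted call to sumSave would need under a
piece-by-piece embedding.  The datatype univ and the names toIntList
and sumSaveFromHosted are hypothetical.  The coercion visits every
cons cell, so it is linear in the length of the list, and the host
list it produces is a fresh copy.

```sml
(* Hypothetical piece-by-piece embedding of int lists. *)
datatype univ = UInt of int | UNil | UCons of univ * univ

(* Hosted-to-host coercion: one step per cons cell, i.e. O(n). *)
fun toIntList (u: univ): int list =
   case u of
      UNil => []
    | UCons (UInt x, tl) => x :: toIntList tl
    | _ => raise Fail "toIntList: not an int list"

(* The host function from above, with the ref cell it writes. *)
val r: int list ref = ref []
fun sumSave (l: int list): int = (r := l; foldl op + 0 l)

(* What a call from the hosted world has to do: copy, then call. *)
fun sumSaveFromHosted (u: univ): int = sumSave (toIntList u)
```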

My conclusion is that for a complete and efficient solution, one must
embed host values directly, not piece-by-piece, in the hosted
universal datatype.  Then, coercions both ways are trivially constant
time; either add a tag or remove it.
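Here is a minimal sketch of such a direct embedding, assuming a
universal datatype with the hypothetical variants Int, IntList, and
IntListRef.  Each coercion is a single constructor application or
match, and because the ref cell itself is embedded, an assignment made
through the hosted view is seen by the host.

```sml
(* Hypothetical universal datatype that embeds host values directly. *)
datatype univ =
   Int of int
 | IntList of int list
 | IntListRef of int list ref

(* Host to hosted: add a tag.  O(1). *)
val fromIntList: int list -> univ = IntList

(* Hosted to host: remove the tag.  O(1), returning the very same list. *)
fun toIntList (IntList l) = l
  | toIntList _ = raise Fail "toIntList"

fun toIntListRef (IntListRef r) = r
  | toIntListRef _ = raise Fail "toIntListRef"

(* A host ref cell exported to the hosted world keeps its identity, so
 * an assignment through the hosted view updates the host's cell. *)
val r: int list ref = ref []
val hostedR: univ = IntListRef r
val () = toIntListRef hostedR := toIntList (fromIntList [1, 2, 3])
```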

Here's an idea for how to do it.  For each host program, build a
universal datatype specialized to that program.  For every monomorphic
type "t" that occurs in the host program, add a variant "T of t" to
the universal type.  For every polymorphic type that occurs, add a
variant for the instance obtained by replacing all of its type
variables with the universal datatype itself (as with "UnivList of t
list" in the code below).

With this approach, monomorphic values are trivially embeddable in the
hosted interpreter, and can be passed both ways between worlds in
constant time.  There is no requirement to duplicate monomorphic code.
Furthermore, there is no problem with mutable data structures, since
they use a consistent representation in both worlds.  Another nice
benefit is that when a host function is applied to a host value via
the interpreter (as in my "sum xs" example above), the application
runs at full host speed once the interpreter has dispatched the call.
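The "sum xs" example then looks roughly as follows.  This is a sketch
with a cut-down universal datatype whose variant names mirror the
Value structure below; apply and the exported names xsV and sumV are
illustrative.  Evaluating the application matches one tag on the
function and one on the argument, after which the natively compiled
sum runs directly on the host list.

```sml
(* Hypothetical universal datatype with host-representation variants. *)
datatype univ =
   Arrow of univ -> univ
 | Int of int
 | IntList of int list

(* Host-world definitions, compiled natively. *)
val xs: int list = [1, 2, 3, 4]
val sum: int list -> int = foldl op + 0

(* Exporting: wrap each host value with its tag. *)
val xsV: univ = IntList xs
val sumV: univ =
   Arrow (fn IntList l => Int (sum l)
           | _ => raise Fail "sum: expected an int list")

(* Interpreter application: dispatch on the tag, then call host code. *)
fun apply (Arrow f, a) = f a
  | apply _ = raise Fail "apply: not a function"

(* Evaluating "sum xs" in the hosted world. *)
val result: univ = apply (sumV, xsV)
```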

To embed polymorphic values, we can use type passing to support the
use of the (non-uniform) host representations.  There is a range of
possibilities that trade off speed for code duplication.  For example,
consider List.length.  We could create a single generic version that
works for any list type.  Or, we could create a version that
dispatches to specializations of List.length at various types, which
gains speed but duplicates code.  It only makes sense to create
specialized versions for types that we have variants for in the
universal datatype.  But we do have a choice as to which of those to
create specializations for.  A simple heuristic is to include specialized
functions if they already appear in the program; otherwise use the
generic version.

Below is some code that I hope clarifies what I mean.  It's based on
an imaginary program that uses lists of integers and reals, and has a
length function.  I didn't go whole hog in adding variants for every
type, but just included enough to show how to embed the different
types of lists and to do type passing.

I'm not certain that this works for all of SML, but I am hopeful.  I
think we may have been able to go beyond where Nick went in his paper
because of the use of recursion schemes, type passing, and the use of
one variant per monotype in the universal datatype.

I'm sure my explanation here is lacking, so let me know what needs
clarification.

--------------------------------------------------------------------------------

structure Type =
   struct
      (* Representation of types.  Used for type passing. *)
      datatype t =
	 Arrow of t * t
       | Int
       | List of t
       | Real
       | Tuple of t list
       | Univ
   end
      
structure Value =
   struct
      (* The universal type used in the interpreter.
       * Specialized with a variant for each monomorphic type for which we
       * would like to use the host representation.
       * The "Poly" variant is used to embed polymorphic functions, which
       * use type passing to support non-uniform representations.
       *)
      datatype t =
	 Arrow of t -> t
       | Int of int
       | IntList of int list
       | Pair of t * t
       | Poly of Type.t -> t
       | Real of real
       | RealList of real list
       | UnivList of t list

      (* Used to implement function application inside the interpreter. *)
      val apply: t * t -> t =
	 fn (Arrow f, a) => f a
	  | _ => raise Fail "apply"

      (* Embedding into and extracting values from the universal type. *)
      type u = t
      structure Embed:
	 sig
	    type 'a t

	    val arrow: 'a t * 'b t -> ('a -> 'b) t
	    val int: int t
	    val intList: int list t
	    val pair: 'a t * 'b t -> ('a * 'b) t
	    val real: real t
	    val realList: real list t
	    val unwrap: 'a t * u -> 'a
	    val univ: u t
	    val univList: u list t
	    val wrap: 'a t * 'a -> u
	 end =
	 struct
	       
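	    (* An embedding for type 'a: a pair of coercions between 'a and
	     * the universal type u.  For the base types, wrapping and
	     * unwrapping just add or remove a constructor. *)
	    datatype 'a t = T of {unwrap: u -> 'a,
				  wrap: 'a -> u}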

	    val unwrap = fn (T {unwrap, ...}, u) => unwrap u

	    val wrap = fn (T {wrap, ...}, a) => wrap a

	    val univ =
	       T {unwrap = fn u => u,
		  wrap = fn u => u}

	    val int =
	       T {unwrap = fn Int x => x | _ => raise Fail "unwrapInt",
		  wrap = Int}

	    val real = 
	       T {unwrap = fn Real x => x | _ => raise Fail "unwrapReal",
		  wrap = Real}

	    val intList =
	       T {unwrap = fn IntList x => x | _ => raise Fail "unwrapIntList",
		  wrap = IntList}

	    val realList =
	       T {unwrap = fn RealList x => x | _ => raise Fail "unwrapRealList",
		  wrap = RealList}

	    val univList =
	       T {unwrap = fn UnivList x => x | _ => raise Fail "unwrapUnivList",
		  wrap = UnivList}

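	    (* Functions are coerced contravariantly in the argument: to
	     * unwrap a hosted function, wrap the host argument before the
	     * call and unwrap the result after it. *)
	    val arrow =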
	       fn (T {unwrap = unwrapA, wrap = wrapA},
		   T {unwrap = unwrapB, wrap = wrapB}) =>
	       let
		  val unwrap =
		     fn Arrow f => unwrapB o f o wrapA
		      | _ => raise Fail "unwrap arrow"
		  fun wrap f = Arrow (wrapB o f o unwrapA)
	       in
		  T {unwrap = unwrap, wrap = wrap}
	       end

	    val pair =
	       fn (T {unwrap = unwrapA, wrap = wrapA},
		   T {unwrap = unwrapB, wrap = wrapB}) =>
	       let
		  val unwrap =
		     fn Pair (a, b) => (unwrapA a, unwrapB b)
		      | _ => raise Fail "unwrap pair"
		  fun wrap (a, b) = Pair (wrapA a, wrapB b)
	       in
		  T {unwrap = unwrap, wrap = wrap}
	       end
	 end
 
      (* Polymorphic nil, cons, and length functions, implemented using type
       * passing to specialize for non-uniform representations.
       * These are exposed to the interpreter by wrapping them with the
       * "Poly" constructor.
       *)
      val nill: Type.t -> t =
	 fn Type.Int => IntList []
	  | Type.Real => RealList []
	  | _ => UnivList []

      fun cons (t: Type.t): t =
	 let
	    open Embed
	    fun doit (u, ul, cons) = wrap (arrow (pair (u, ul), ul), cons)
	 in
	    case t of
	       Type.Int => doit (int, intList, op ::)
	     | Type.Real => doit (real, realList, op ::)
	     | _ => doit (univ, univList, op ::)
	 end
      
      val specializedLength: Type.t -> t =
	 let
	    open Embed
	    fun doit ul = wrap (arrow (ul, int), List.length)
	 in
	    fn Type.Int => doit intList
	     | Type.Real => doit realList
	     | _ => doit univList
	 end

      (* An abstraction to support arbitrary list representations. *)
      structure Cons =
	 struct
	    datatype ('a, 'b) t =
	       Nil
	     | Cons of 'a * 'b
	 end

      (* Generic destructor function for lists embedded in the universal type.
       * We can use this to implement pattern matching on conses in the
       * interpreter.
       *)
      val dest: Type.t -> t -> (t, t) Cons.t =
	 fn t =>
	 let
	    open Embed
	    fun doit (elt, list) u =
	       case unwrap (list, u) of
		  [] => Cons.Nil
		| x :: xs => Cons.Cons (wrap (elt, x), wrap (list, xs))
	 in
	    case t of
	       Type.Int => doit (int, intList)
	     | Type.Real => doit (real, realList)
	     | _ => doit (univ, univList)
	 end

      (* Generic length function based on the generic destructor function.
       * Unlike specializedLength, this requires no code duplication, but at
       * the cost of slower operation.
       *)
      val genericLength: Type.t -> t =
	 fn t =>
	 let
	    val dest = dest t
	 in
	    Arrow
	    (fn l =>
	     let
		fun loop l =
		   case dest l of
		      Cons.Nil => 0
		    | Cons.Cons (_, xs) => 1 + loop xs
	     in
		Int (loop l)
	     end)
	 end

      (* A mix of the generic and specialized length functions. *)
      val mixedLength: Type.t -> t =
	 fn t =>
	 case t of
	    Type.Int => specializedLength t
	  | _ => genericLength t
   end