[MLton] recursive generics in SML
Stephen Weeks
MLton@mlton.org
Sun, 23 Oct 2005 23:36:02 -0700
I've sent earlier mails showing how to do overloading in SML via
typecase, where the overloaded operators worked on a flat sum type,
like int + real. This note shows how to extend that approach to
inductively defined type families via a safe (I think) typerec
function. For example, suppose one wanted to represent the type
family
   'b + 'b list + 'b list list + 'b list list list + ...
This approach represents the infinite sum via the type
   ('a, 'b) t
For a particular value of type "(u, b) t", u is an index indicating
which summand the value is in and b indicates the base type. The
summand index is represented by a family of types
   type u1
   type 'a u2
For example, a "string list list" is represented by the type
   (u1 u2 u2, string) t
As with the earlier non-recursive sum types, there are injections into
and projections out of the sum.
   val from1: 'b -> (u1, 'b) t
   val from2: ('a, 'b) t list -> ('a u2, 'b) t
   val to1: (u1, 'b) t -> 'b
   val to2: ('a u2, 'b) t -> ('a, 'b) t list
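To make the index encoding concrete, here is a minimal sketch, specialised to lists, of phantom indices layered over one recursive datatype. The names SUM, Sum and u are hypothetical, and the real TypeRec functor in the code at the end of this note is more general (it is parameterised by the container type tr), but the idea is the same: the index 'a never appears in the representation, so it costs nothing at run time, while the opaque signature stops, say, to1 from being applied at the wrong level.

```sml
(* Hypothetical list-only sketch of the index encoding. *)
signature SUM =
sig
   type u1
   type 'a u2
   type ('a, 'b) t
   val from1: 'b -> (u1, 'b) t
   val from2: ('a, 'b) t list -> ('a u2, 'b) t
   val to1: (u1, 'b) t -> 'b
   val to2: ('a u2, 'b) t -> ('a, 'b) t list
end

structure Sum :> SUM =
struct
   (* One recursive datatype holds every member of the family ... *)
   datatype 'b u = X1 of 'b | X2 of 'b u list
   (* ... and the index 'a is a phantom: it is dropped here, so it
      costs nothing at run time. *)
   type ('a, 'b) t = 'b u
   type u1 = unit
   type 'a u2 = unit
   val from1 = X1
   val from2 = X2
   (* Callers cannot reach the fallback cases with values built by
      from1 and from2, because the signature tracks the index. *)
   fun to1 (X1 b) = b
     | to1 _ = raise Fail "bug"
   fun to2 (X2 l) = l
     | to2 _ = raise Fail "bug"
end

(* A "string list list" lives at index u1 u2 u2. *)
val x : (Sum.u1 Sum.u2 Sum.u2, string) Sum.t =
   Sum.from2 [Sum.from2 [Sum.from1 "a", Sum.from1 "b"], Sum.from2 []]

(* Projecting level by level recovers ordinary lists. *)
val inner : string list list =
   List.map (fn y => List.map Sum.to1 (Sum.to2 y)) (Sum.to2 x)
```

With this signature, Sum.to1 x does not type check, because x has index u1 u2 u2 rather than u1.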
The fun part is the typeRec function, which allows one to define an
overloaded function that works over the entire sum type, via recursion
over the types. For example, one can define a generic print
function as follows.
   val print: ('a, string) t -> unit =
      fn x =>
      typeRec`x $ (TextIO.print, fn (l, print) => List.foreach (l, print))
There is no direct recursion here. After being supplied as many
values as desired (each introduced with `) and then $, typeRec takes
a pair of functions. The first, in this case TextIO.print, gives the
behavior at the base type. The second gives the behavior at higher
types, and can use the function provided by typeRec for recursive
calls at the next lower type. The definition of typeRec uses Fold01N
technology similar to that of one of the earlier typeCase
examples.
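For a single value, the shape of typeRec can be sketched without the Fold01N machinery. This is a simplification under stated assumptions, not the definition used in the code at the end of this note: it is hard-wired to lists, accepts exactly one value, and the names REC1, Rec1, rec1, printAll and flat are hypothetical. It shows the point of the paragraph above: the only recursion happens once, inside the module and at the universal type, while callers supply the base-type behavior and a higher-type behavior that receives a helper for the next lower level, whose index is hidden behind the abstract type z.

```sml
(* Hypothetical single-value, list-only simplification of typeRec. *)
signature REC1 =
sig
   type u1
   type 'a u2
   type z
   type ('a, 'b) t
   val from1: 'b -> (u1, 'b) t
   val from2: ('a, 'b) t list -> ('a u2, 'b) t
   (* rec1 (x, base, higher): base handles values of the base type;
      higher gets one level of x plus a helper for the level below. *)
   val rec1: ('a, 'b) t * ('b -> 'f)
             * ((z, 'b) t list * ((z, 'b) t -> 'f) -> 'f)
             -> 'f
end

structure Rec1 :> REC1 =
struct
   datatype 'b u = X1 of 'b | X2 of 'b u list
   type ('a, 'b) t = 'b u
   type u1 = unit
   type 'a u2 = unit
   type z = unit
   val from1 = X1
   val from2 = X2
   fun rec1 (x, base, higher) =
      let
         (* The only recursion, at the universal type 'b u. *)
         fun loop (X1 b) = base b
           | loop (X2 l) = higher (l, loop)
      in
         loop x
      end
end

val printAll: ('a, string) Rec1.t -> unit =
   fn x => Rec1.rec1 (x, TextIO.print, fn (l, pr) => List.app pr l)

val flat: ('a, 'b) Rec1.t -> 'b list =
   fn x => Rec1.rec1 (x, fn b => [b], fn (l, f) => List.concat (List.map f l))

val x3 =
   Rec1.from2 [Rec1.from2 [],
               Rec1.from2 [Rec1.from1 "hello, "],
               Rec1.from2 [Rec1.from1 "world\n"]]
val () = printAll x3 (* prints "hello, world" and a newline *)
```

Here flat builds its result with List.concat rather than the accumulator-passing style of flatOnto in the test code.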
Although this looks like polymorphic recursion on the outside, SML
doesn't have polymorphic recursion, so underneath there is a universal
type (hence we don't get the nice code duplication that we got with the
flat sum type). But that is completely hidden behind the interface. What
one sees from the outside is the infinite family of types, along with
a way of defining certain (apparently) polymorphically recursive
functions.
The code below also includes a couple of other examples from the
extensional polymorphism/GCaml papers: generic flatten and generic
equals. I think the other examples that recur on arrow types are
pretty easy to emulate as well.
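Generic equality needs two values at the same index, which the full version gets by supplying two values with `. Under the same simplifying assumptions as a list-only, non-variadic sketch, a hypothetical two-value recursor rec2 (in an illustrative module Rec2) could look like this; as in the original, a shape mismatch is reported with an exception, mirroring bug ().

```sml
(* Hypothetical two-value, list-only recursor for generic equality. *)
signature REC2 =
sig
   type u1
   type 'a u2
   type z
   type ('a, 'b) t
   val from1: 'b -> (u1, 'b) t
   val from2: ('a, 'b) t list -> ('a u2, 'b) t
   (* Both values must have the same index 'a. *)
   val rec2: ('a, 'b) t * ('a, 'b) t
             * ('b * 'b -> 'f)
             * ((z, 'b) t list * (z, 'b) t list
                * ((z, 'b) t * (z, 'b) t -> 'f) -> 'f)
             -> 'f
end

structure Rec2 :> REC2 =
struct
   datatype 'b u = X1 of 'b | X2 of 'b u list
   type ('a, 'b) t = 'b u
   type u1 = unit
   type 'a u2 = unit
   type z = unit
   val from1 = X1
   val from2 = X2
   fun rec2 (x, y, f1, f2) =
      let
         (* Walk both values in lockstep, one level at a time. *)
         fun loop (X1 a, X1 b) = f1 (a, b)
           | loop (X2 l1, X2 l2) = f2 (l1, l2, loop)
           | loop _ = raise Fail "Rec2: values have different shapes"
      in
         loop (x, y)
      end
end

(* Generic equality, given an equality on the base type. *)
fun equals (x, y, eq) =
   Rec2.rec2
      (x, y, eq,
       fn (l1, l2, recur) =>
          let
             fun loop ([], []) = true
               | loop (a :: r1, b :: r2) = recur (a, b) andalso loop (r1, r2)
               | loop _ = false
          in
             loop (l1, l2)
          end)

val S = Rec2.from1
val L = Rec2.from2
val x3 = L [L [], L [S "hello, "], L [S "world\n"]]
val x4 = L [L [], L [S "hello, "], L [S "world"]]
```

Because both arguments of rec2 share the index 'a, a call such as equals (x3, S "s", op =) is rejected by the type checker.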
There is one weakness of typeRec as given here: the same result type
must be returned for each member of the family. I think this
restriction is present in GCaml as well. It rules out some functions
that could be written with a more powerful typeRec (such as TIL's),
for example, reversing all the lists at every level within a value of type
('a, 'b) t. The essential missing ingredient is the ability to
construct a value whose type depends (in the same way at every level)
on the level of the value currently being processed. I thought that
perhaps the ideas about witnesses to type equivalences that I sent
earlier might help, but I couldn't figure out how to make that work
with the recursive helper provided by typeRec.
Here's the code.
----------------------------------------------------------------------
structure List =
   struct
      (* Tupled wrappers around the Basis List functions. *)
      fun fold (l, b, f) = List.foldl f b l
      fun foreach (l, f) = List.app f l
      fun map (l, f) = List.map f l
   end
datatype ('a, 'b) product = & of 'a * 'b
infix 4 &

(* $ ends a fold by applying the finisher f to the final accumulator a. *)
fun $ (a, f) = f a
fun const c _ = c
fun curry h x y = h (x, y)
fun id x = x
fun ignore _ = ()
fun pass x f = f x
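structure Fold =
   struct
      type ('a, 'b, 'c, 'd) step = 'a * ('b -> 'c) -> 'd
      type ('a, 'b, 'c, 'd) t = ('a, 'b, 'c, 'd) step -> 'd
      type ('a1, 'a2, 'b, 'c, 'd) step0 =
         ('a1, 'b, 'c, ('a2, 'b, 'c, 'd) t) step
      type ('a1, 'a2, 'a3, 'b, 'c, 'd) step1 =
         ('a2, 'b, 'c, 'a1 -> ('a3, 'b, 'c, 'd) t) step

      (* fold (a, f) g = g (a, f): hand the accumulator a and the
         finisher f to the next step. *)
      val fold = pass
      (* step0 updates the accumulator with h, taking no argument. *)
      fun step0 h (a1, f) = fold (h a1, f)
      (* step1 takes one argument x and updates the accumulator a to
         h (x, a); here $ is just a parameter bound to the
         (accumulator, finisher) pair. *)
      fun step1 h $ x = step0 (curry h x) $
   end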
structure Fold01N =
   struct
      type ('a, 'b, 'c, 'd, 'e, 'f) t =
         ((unit -> unit) * ('a -> 'b), (unit -> 'c) * 'e, 'd, 'f) Fold.t
      type ('a, 'b, 'c, 'z1, 'z2, 'z3, 'z4, 'z5) step1 =
         ('z1,
          'z2 * ('z1 -> 'a),
          (unit -> 'a) * ('b -> 'c),
          'z3, 'z4, 'z5) Fold.step1

      (* The accumulator pairs a thunk yielding the result so far with a
         function for the next argument: one handles the first argument,
         combine each later one, and finish receives the final result. *)
      val fold: ('a -> 'b) * ('c -> 'd) -> ('a, 'b, 'c, 'd, 'e, 'f) t =
         fn (one, finish) =>
         Fold.fold ((ignore, one), fn (p, _) => finish (p ()))

      val step1
         : ('a * 'b -> 'c) -> ('a, 'b, 'c, 'z1, 'z2, 'z3, 'z4, 'z5) step1 =
         fn combine =>
         Fold.step1 (fn (x, (_, f)) =>
                     (fn () => f x, fn x' => combine (f x, x')))
   end
signature TYPE_REC =
   sig
      type u1
      type 'a u2
      type 'a tr
      type z
      type ('a, 'b) t

      val from1: 'b -> (u1, 'b) t
      val from2: ('a, 'b) t tr -> ('a u2, 'b) t
      val to1: (u1, 'b) t -> 'b
      val to2: ('a u2, 'b) t -> ('a, 'b) t tr

      type ('a, 'b, 'c, 'd, 'e1, 'e2) split
      val typeRec:
         (('a, 'b) t,
          ('a, 'b) t
          * ('a, 'b, (z, 'b) t, ('a, 'b) t, 'b, (z, 'b) t tr) split,
          'd * ('a, 'b, 'c, 'd, 'e1, 'e2) split,
          (('e1 -> 'f) * ('e2 * ('c -> 'f) -> 'f)) -> 'f,
          'z1, 'z2) Fold01N.t
      val ` :
         ('d * ('a, 'b, 'c, 'd, 'e1, 'e2) split,
          ('a, 'b) t,
          ('d, ('a, 'b) t) product
          * ('a, 'b,
             ('c, (z, 'b) t) product,
             ('d, ('a, 'b) t) product,
             ('e1, 'b) product,
             ('e2, (z, 'b) t tr) product) split,
          'z1, 'z2, 'z3, 'z4, 'z5) Fold01N.step1
   end
functor TypeRec (type 'a t):> TYPE_REC where type 'a tr = 'a t =
   struct
      type 'a tr = 'a t

      (* The universal representation: a value at any index is a 'b t. *)
      datatype ('b1, 'b2) s =
         X1 of 'b1
       | X2 of 'b2
      datatype 'b t = T of ('b, 'b t tr) s

      fun from1 b = T (X1 b)
      fun from2 b = T (X2 b)

      fun bug () = raise Fail "bug"
      val to1 = fn T (X1 b) => b | _ => bug ()
      val to2 = fn T (X2 b) => b | _ => bug ()

      (* A split pairs a cast, from the representation one level down
         back to the one typeRec loops over, with a function that unpacks
         one level of every value at once (raising bug () on a mismatch). *)
      type ('a, 'b, 'c, 'd, 'e1, 'e2) split =
         ('c -> 'd) * ('d -> ('e1, 'e2) s)

      fun typeRec $ =
         Fold01N.fold
         (fn x => (x, (id, fn T x => x)),
          fn (p, (cast, split)) => fn (f1, f2) =>
          let
             (* The only recursion happens here, at the universal type. *)
             fun loop p =
                case split p of
                   X1 p => f1 p
                 | X2 p => f2 (p, loop o cast)
          in
             loop p
          end)
         $

      fun ` $ =
         Fold01N.step1
         (fn ((p, (cast, split)), x) =>
          (p & x,
           (fn p & x => cast p & x,
            fn p & T x =>
               case (split p, x) of
                  (X1 p, X1 x) => X1 (p & x)
                | (X2 p, X2 x) => X2 (p & x)
                | _ => bug ())))
         $

      type ('a, 'b) t = 'b t
      type u1 = unit
      type 'a u2 = unit
      type z = unit
   end
structure Test =
   struct
      structure ListRec = TypeRec (type 'a t = 'a list)
      open ListRec

      (* Generic flatten: flatOnto x ac puts the base values of x, in
         left-to-right order, in front of ac. *)
      val flatOnto: ('a, 'b) t -> 'b list -> 'b list =
         fn x =>
         typeRec`x $
         (fn b => fn ac => b :: ac,
          fn (l, flatOnto) => fn ac =>
          List.fold (rev l, ac, fn (x, ac) => flatOnto x ac))
      val flat: ('a, 'b) t -> 'b list = fn x => flatOnto x []

      (* Generic equality of two values at the same index, given an
         equality on the base type. *)
      val equals: ('a, 'b) t * ('a, 'b) t * ('b * 'b -> bool) -> bool =
         fn (x1, x2, equals) =>
         typeRec`x1`x2 $
         (fn y1 & y2 => equals (y1, y2),
          fn (l1 & l2, equals) =>
          let
             val rec loop =
                fn ([], []) => true
                 | (x1 :: l1, x2 :: l2) =>
                      equals (x1 & x2) andalso loop (l1, l2)
                 | _ => false
          in
             loop (l1, l2)
          end)

      (* Generic print, with string as the base type. *)
      val print: ('a, string) t -> unit =
         fn x =>
         typeRec`x $
         (TextIO.print, fn (l, print) => List.foreach (l, print))

      val S = from1
      val L = from2
      val x1 = S "s\n"
      val x2 = L [S "hello, ", S "world\n"]
      val x3 = L [L [], L [S "hello, "], L [S "world\n"]]
      val x4 = L [L [], L [S "hello, "], L [S "world"]]

      val () = (print x1; print x2; print x3)
      val () = List.foreach (flat x3, TextIO.print)

      fun test (x1, x2) =
         TextIO.print (concat [Bool.toString (equals (x1, x2, op =)), "\n"])
      val () = test (x1, x1)
      val () = test (x2, x2)
      val () = test (x3, x3)
      (* x4 lacks the final "\n", so this prints false. *)
      val () = test (x3, x4)
   end