[MLton] improved overloading for SML
Stephen Weeks
sweeks@sweeks.com
Sat, 22 Oct 2005 14:16:14 -0700
Here's a slightly different approach to doing typeCase. This approach
doesn't require any fold technology. Rather, it uses a generally
applicable trick for exposing dataflow information down the branch of
a conditional via a "witness" value. Here's a snippet of the
signature.
(* Excerpt of the TYPE_CASE2 signature; the complete signature appears in
 * the code further down. *)
structure Equiv:
sig
(* ('a, 'b) t is a witness that 'a and 'b are the same type. *)
type ('a, 'b) t
end
type u1
type u2
type ('a, 'b1, 'b2) t (* 'a is u1 or u2 *)
(* Re-tag a value from 'a to 'c, given a witness that 'a equals 'c. *)
val cast: ('a, 'b1, 'b2) t * ('a, 'c) Equiv.t -> ('c, 'b1, 'b2) t
(* Project the 'b1 payload; only accepts values statically tagged u1. *)
val to1: (u1, 'b1, 'b2) t -> 'b1
(* Branch on which payload is present, handing the chosen branch a
 * witness that 'a is u1 (first branch) or u2 (second branch). *)
val typeCase:
('a, 'b1, 'b2) t
* (('a, u1) Equiv.t -> 'c)
* (('a, u2) Equiv.t -> 'c)
-> 'c
Here, "('a, 'b1, 'b2) t" is either a value of type 'b1 or a value of
type 'b2. The "'a" specifies which type the value is. If the 'a is
u1 (a type constant) then the value is of type 'b1 and if 'a is u2
then the value is of type 'b2. In "typeCase (v, f1, f2)", if v holds a
value of type 'b1, we proceed down the f1 branch, supplying f1 with a witness
(of type ('a, u1) Equiv.t) that 'a is the same as u1. Similarly for
the other branch. Within the first branch, one can use the witness w
to extract the value of type 'b1 with "to1 (cast (v, w))". Underneath
the signature, cast is just the identity function (which is easy to
prove), but outside, the witnesses force us to always do type-safe
casts.
The complete code is below. I've also shown the changes required to
implement Num using this new approach.
One reason I like this approach is the very simple reasoning needed to
convince oneself that Fail is never raised.
BTW, all the Equiv stuff could be completely phantom, except that I
want to force MLton to duplicate code.
--------------------------------------------------------------------------------
(* const x is the function that ignores its argument and returns x. *)
fun const x = fn _ => x
(* The identity function; a fn value, so it is generalized (polymorphic). *)
val id = fn x => x
(* Two-way type case: a value of type ('a, 'b1, 'b2) t holds either a 'b1
 * or a 'b2, and its tag type 'a (u1 or u2) records which one. *)
signature TYPE_CASE2 =
sig
(* Type-equality witnesses, closed under reflexivity, symmetry and
 * transitivity. *)
structure Equiv:
sig
(* ('a, 'b) t witnesses that 'a and 'b are the same type. *)
type ('a, 'b) t
val reflexive: ('a, 'a) t
val symmetric: ('a, 'b) t -> ('b, 'a) t
val transitive: ('a, 'b) t * ('b, 'c) t -> ('a, 'c) t
end
(* Tag types: u1 marks a 'b1 payload, u2 marks a 'b2 payload. *)
type u1
type u2
type ('a, 'b1, 'b2) t (* 'a is u1 or u2 *)
(* Re-tag a value from 'a to 'c using a witness that 'a equals 'c. *)
val cast: ('a, 'b1, 'b2) t * ('a, 'c) Equiv.t -> ('c, 'b1, 'b2) t
(* Injections: wrap a payload with the matching tag. *)
val from1: 'b1 -> (u1, 'b1, 'b2) t
val from2: 'b2 -> (u2, 'b1, 'b2) t
(* Projections: only accept values whose tag is statically u1 / u2. *)
val to1: (u1, 'b1, 'b2) t -> 'b1
val to2: (u2, 'b1, 'b2) t -> 'b2
(* typeCase (v, f1, f2) calls f1 with an ('a, u1) witness when v holds a
 * 'b1, and f2 with an ('a, u2) witness when v holds a 'b2. *)
val typeCase:
('a, 'b1, 'b2) t
* (('a, u1) Equiv.t -> 'c)
* (('a, u2) Equiv.t -> 'c)
-> 'c
end
structure TypeCase2:> TYPE_CASE2 =
struct
   (* Tag types: each has a single value naming which arm is present. *)
   datatype u1 = U1
   datatype u2 = U2

   (* The payload itself: a 'b1 or a 'b2. *)
   datatype ('b1, 'b2) x = X1 of 'b1 | X2 of 'b2

   (* A payload paired with its tag; from1/from2 below pair U1 with X1 and
    * U2 with X2. *)
   datatype ('a, 'b1, 'b2) t = T of 'a * ('b1, 'b2) x

   structure Equiv =
   struct
      (* An equality witness, carried as coercions in both directions. *)
      datatype ('a, 'b) t = T of ('a -> 'b) * ('b -> 'a)

      val reflexive = T (id, id)

      fun symmetric (T (forward, backward)) = T (backward, forward)

      fun transitive (T (ab, ba), T (bc, cb)) = T (bc o ab, ba o cb)
   end

   (* Re-tag by sending the tag through the witness's forward coercion; the
    * payload is passed along untouched. *)
   fun cast (T (tag, payload), Equiv.T (coerce, _)) = T (coerce tag, payload)

   fun from1 v = T (U1, X1 v)
   fun from2 v = T (U2, X2 v)

   (* Raised when a projection meets the payload of the other arm. *)
   fun bug () = raise Fail "bug"

   fun to1 (T (_, X1 v)) = v
     | to1 _ = bug ()

   fun to2 (T (_, X2 v)) = v
     | to2 _ = bug ()

   (* Dispatch on the payload.  The witness for the chosen branch maps the
    * caller's tag type 'a to u1/u2 and back using constant functions. *)
   fun typeCase (T (tag, payload), onFirst, onSecond) =
      case payload of
         X1 _ => onFirst (Equiv.T (const U1, const tag))
       | X2 _ => onSecond (Equiv.T (const U2, const tag))
end
(* Overloaded numbers built on TypeCase2: an 'a t holds an int when 'a is
 * TypeCase2.u1 and a real when 'a is TypeCase2.u2. *)
signature NUM =
sig
type 'a t = ('a, int, real) TypeCase2.t
(* i: integer numbers; r: real numbers. *)
type i = TypeCase2.u1 t
type r = TypeCase2.u2 t
(* Comparisons and arithmetic below take two operands of the same kind 'a. *)
val < : 'a t * 'a t -> bool
val <= : 'a t * 'a t -> bool
val > : 'a t * 'a t -> bool
val >= : 'a t * 'a t -> bool
val ~ : 'a t -> 'a t
val + : 'a t * 'a t -> 'a t
val - : 'a t * 'a t -> 'a t
val * : 'a t * 'a t -> 'a t
(* Division accepts operands of any kinds and always yields a real. *)
val / : 'a t * 'b t -> r
val abs: 'a t -> 'a t
(* Integer-only. *)
val div: i * i -> i
val e: r
val fromInt: int -> i
val fromReal: real -> r
val max: 'a t * 'a t -> 'a t
val min: 'a t * 'a t -> 'a t
(* Integer-only. *)
val mod: i * i -> i
val pi: r
(* Conversions from either kind. *)
val real: 'a t -> r
val round: 'a t -> i
val sqrt: 'a t -> r
val toInt: i -> int
val toReal: r -> real
val toString: 'a t -> string
end
structure Num:> NUM =
struct
   open TypeCase2

   (* A number is a TypeCase2 value whose payload is an int (tag u1) or a
    * real (tag u2). *)
   type 'a t = ('a, int, real) t
   type i = u1 t
   type r = u2 t

   val fromInt: int -> i = from1
   val fromReal: real -> r = from2
   val toInt: i -> int = to1
   val toReal: r -> real = to2

   val e = fromReal Real.Math.e
   val pi = fromReal Real.Math.pi

   (* unary (fi, fr) n applies fi to n's int payload or fr to its real
    * payload, whichever is present; fi and fr must share a result type.
    * The witness supplied by typeCase is what allows casting n to a
    * u1- or u2-tagged value so toInt / toReal can project it. *)
   fun unary (fi, fr) n =
      let
         fun one (f, to) e n = f (to (cast (n, e)))
      in
         typeCase (n, one (fi, toInt), one (fr, toReal)) n
      end

   val toString = fn $ => unary (Int.toString, Real.toString) $
   fun real n = fromReal (unary (Real.fromInt, id) n)
   fun sqrt n = fromReal (Real.Math.sqrt (toReal (real n)))
   val round = fn $ => unary (fromInt, fromInt o Real.round) $

   (* Lift an int endofunction and a real endofunction to 'a t -> 'a t.
    * The symmetric witness casts the u1/u2-tagged result back to 'a. *)
   local
      fun make (fi, fr) n =
         let
            fun one (f, from, to) e =
               cast (from (f (to (cast (n, e)))), Equiv.symmetric e)
         in
            typeCase (n,
                      one (fi, fromInt, toInt),
                      one (fr, fromReal, toReal))
         end
   in
      val abs = fn $ => make (Int.abs, Real.abs) $
      val ~ = fn $ => make (Int.~, Real.~) $
   end

   (* Lift an int comparison and a real comparison to 'a t * 'a t -> bool.
    * Dispatching on n1 alone suffices: n1 and n2 share the tag 'a, so one
    * witness casts both.  The result is a plain bool, so unlike the
    * arithmetic lifts there is nothing to re-inject. *)
   local
      fun make (fi, fr) (n1, n2) =
         let
            fun one (f, to) e = f (to (cast (n1, e)), to (cast (n2, e)))
         in
            typeCase (n1, one (fi, toInt), one (fr, toReal))
         end
   in
      val op < = fn $ => make (Int.<, Real.<) $
      val op <= = fn $ => make (Int.<=, Real.<=) $
      val op > = fn $ => make (Int.>, Real.>) $
      val op >= = fn $ => make (Int.>=, Real.>=) $
   end

   (* Lift a pair of binary operations to 'a t * 'a t -> 'a t, dispatching
    * on n1 and casting the result back to tag 'a. *)
   local
      fun make (fi, fr) (n1, n2) =
         let
            fun one (f, from, to) e =
               cast (from (f (to (cast (n1, e)), to (cast (n2, e)))),
                     Equiv.symmetric e)
         in
            typeCase (n1,
                      one (fi, fromInt, toInt),
                      one (fr, fromReal, toReal))
         end
   in
      val op + = fn $ => make (Int.+, Real.+) $
      val op - = fn $ => make (Int.-, Real.-) $
      (* The space before the paren keeps "Real.*" from closing a comment. *)
      val op * = fn $ => make (Int.*, Real.* ) $
      val max = fn $ => make (Int.max, Real.max) $
      val min = fn $ => make (Int.min, Real.min) $
   end

   (* Real division: both operands are first converted with real. *)
   fun a / b = fromReal (Real./ (toReal (real a), toReal (real b)))

   (* Integer-only operations; no type case needed. *)
   local
      fun make f (n1, n2) = fromInt (f (toInt n1, toInt n2))
   in
      val op div = fn $ => make Int.div $
      val op mod = fn $ => make Int.mod $
   end
end
functor Test (Num: NUM) =
struct
   open Num

   (* Short constructors for test numbers. *)
   val i = fromInt
   val r = fromReal

   (* Print a number followed by a newline. *)
   fun p n = print (toString n ^ "\n")

   val () = p (i 1 + i 2)
   val () = p (r 1.5 + r 2.5)
   (* "/" mixes kinds and yields a real, which round turns back into an int. *)
   val () = p (round ((i 1 + i 2) / r 3.5))

   (* A single definition doubles both ints and reals. *)
   val double = fn x => x + x
   val () = p (double (i 1))
   val () = p (double (r 1.5 + pi))
end
(* Apply Test to Num; evaluating the functor body runs its print calls. *)
structure Z = Test (Num)