[MLton] improved overloading for SML
Stephen Weeks
MLton@mlton.org
Tue, 18 Oct 2005 17:27:50 -0700
There is an aspect of the latest Num implementation that I did not
like. There were calls to the "bug" function, which raises an
exception, inside the implementation of Num. These corresponded to
code that should be unreachable, where the unreachability was
guaranteed by the invariants enforced by the phantom types and
signature constraints.
This mail has an implementation of Num that fixes the problem. All of
the potentially exception-raising code has been moved to the generic
TypeCase2 module, where it can be proved correct (i.e., shown never to
raise an exception) once and for all. Furthermore, TypeCase2 has
been made more general: it uses a variable-arity (varargs) fold to allow
typecase over any number of arguments (all of the same type). This is not
just a syntactic convenience -- it is essential for implementing operators
like + that require both arguments to be the same type. When typecase
could only operate on a single argument, the code for + had to account for
the possibility of getting, say, a real and an int. With the
code below, the implementation of + looks like this:
(* Excerpt: + dispatches on both operands at once, so only the int/int and
 * real/real cases need handling.  The full program below defines + through
 * a shared helper. *)
fun n1 + n2 =
   typeCase`n1`n2 $
   (fn (i1 & i2, I) => I (Int.+ (i1, i2)),
    fn (r1 & r2, R) => R (Real.+ (r1, r2)))
The type system guarantees that + gets either two reals or two
integers.
The complete code is below. It makes use of the Fold01N module that I
just explained in an earlier mail.
----------------------------------------------------------------------
(* Infix pair constructor, used to bundle the payloads gathered by a typecase. *)
datatype ('a, 'b) product = & of 'a * 'b
infix 4 &

(* Apply the function in the second component to the first component;
 * used as the final step that terminates a fold. *)
val $ = fn (x, k) => k x

(* Small combinators.  Each is a syntactic function value, so each binding
 * is fully polymorphic. *)
val const = fn x => fn _ => x
val curry = fn f => fn x => fn y => f (x, y)
val id = fn x => x
val ignore = fn _ => ()
val pass = fn x => fn k => k x
(* Generic fold combinators for variable-arity functions.  A fold value of
 * type t waits for a step function and hands it the current state paired
 * with a finishing function.
 *)
structure Fold =
struct
   (* A step receives (state, finisher) and produces whatever comes next.
    * The global $ fits this shape and ends a fold by applying the finisher
    * to the state. *)
   type ('a, 'b, 'c, 'd) step = 'a * ('b -> 'c) -> 'd
   (* A fold in progress: it is waiting for its next step. *)
   type ('a, 'b, 'c, 'd) t = ('a, 'b, 'c, 'd) step -> 'd
   (* A step that takes no extra argument and turns state 'a1 into 'a2. *)
   type ('a1, 'a2, 'b, 'c, 'd) step0 =
      ('a1, 'b, 'c, ('a2, 'b, 'c, 'd) t) step
   (* A step that consumes one extra argument of type 'a1 and turns state
    * 'a2 into 'a3. *)
   type ('a1, 'a2, 'a3, 'b, 'c, 'd) step1 =
      ('a2, 'b, 'c, 'a1 -> ('a3, 'b, 'c, 'd) t) step
   (* fold (state, finisher) starts a fold; pass makes it wait for a step. *)
   val fold = pass
   (* step0 h maps the current state through h, keeping the finisher. *)
   fun step0 h (a1, f) = fold (h a1, f)
   (* step1 h consumes the argument x that follows the step and replaces
    * the state a with h (x, a).  Note: $ here is a local name for the
    * (state, finisher) pair, not the global $. *)
   fun step1 h $ x = step0 (curry h x) $
end
(* Fold for functions taking zero, one, or many arguments that are merged
 * left to right.  The state is a pair (p, f): p () yields the value
 * accumulated so far, and f merges that value with the next argument.
 *)
structure Fold01N =
struct
   (* Initial state type (unit -> unit) * ('c -> 'c) matches (ignore, id). *)
   type ('a, 'b, 'c, 'd, 'e) t =
      ((unit -> unit) * ('c -> 'c), (unit -> 'a) * 'd, 'b, 'e) Fold.t
   (* A Fold.step1 whose extra argument has type 'z1 and whose new state is
    * (thunk of the accumulated 'a, merger of the next 'b into a 'c). *)
   type ('a, 'b, 'c, 'z1, 'z2, 'z3, 'z4, 'z5) step1 =
      ('z1,
       'z2 * ('z1 -> 'a),
       (unit -> 'a) * ('b -> 'c),
       'z3, 'z4, 'z5) Fold.step1
   (* fold finish: with no arguments finish receives (); with one argument x
    * it receives x; with x1, ..., xn it receives
    * combine (... combine (x1, x2) ..., xn). *)
   val fold: ('a -> 'b) -> ('a, 'b, 'c, 'd, 'e) t =
      fn finish =>
         Fold.fold ((ignore, id), fn (p, _) => finish (p ()))
   (* step1 combine consumes the next argument x.  The previous merger f is
    * applied to x: the first argument passes through id unchanged, and
    * later ones are merged with combine. *)
   val step1
      : ('a * 'b -> 'c) -> ('a, 'b, 'c, 'z1, 'z2, 'z3, 'z4, 'z5) step1 =
      fn combine =>
         Fold.step1 (fn (x, (_, f)) =>
                        (fn () => f x, fn x' => combine (f x, x')))
end
(* Two-way typecase over values tagged with a phantom type.  from1 builds a
 * (u1, 'b1, 'b2) t holding a 'b1, and from2 builds a (u2, 'b1, 'b2) t
 * holding a 'b2, so the tag 'a tells the type checker which case applies.
 *)
signature TYPE_CASE2 =
sig
   (* Phantom tags selecting the first or the second case. *)
   type u1
   type u2
   type ('a, 'b1, 'b2) t
   (* Injections and projections.  The tag in the type restricts to1/to2 to
    * values of the matching case, assuming the implementation keeps
    * constructor and tag in agreement. *)
   val from1: 'b1 -> (u1, 'b1, 'b2) t
   val from2: 'b2 -> (u2, 'b1, 'b2) t
   val to1: (u1, 'b1, 'b2) t -> 'b1
   val to2: (u2, 'b1, 'b2) t -> 'b2
   (* typeCase`x1`x2 ... `xn $ (f1, f2): a variable-arity fold whose
    * arguments all share the tag 'a.  One handler is chosen; it receives the
    * payloads combined with & and a function that wraps a new payload back
    * into a t with the same tag. *)
   val typeCase:
      (('a, 'c1, 'c2) t,
       ( ('c1 * ('b1 -> ('a, 'b1, 'b2) t) -> 'e)
        * ('c2 * ('b2 -> ('a, 'b1, 'b2) t) -> 'e)) -> 'e,
       'z1, 'z2, 'z3) Fold01N.t
   (* The argument step for typeCase: adds one more same-tag argument,
    * pairing its payload with the accumulated payload using &. *)
   val ` :
      (('a, 'c1, 'c2) t,
       ('a, 'b1, 'b2) t,
       ('a, ('c1, 'b1) product, ('c2, 'b2) product) t,
       'z1, 'z2, 'z3, 'z4, 'z5) Fold01N.step1
end
structure TypeCase2:> TYPE_CASE2 =
struct
   (* Phantom tags. *)
   datatype u1 = U1
   datatype u2 = U2

   (* Invariant: every value is X1 (U1, _) or X2 (U2, _); the constructor
    * always agrees with the tag.  Every bug () below is unreachable as long
    * as this holds, because only from1/from2 build values outside this
    * structure. *)
   datatype ('a, 'b1, 'b2) t =
      X1 of 'a * 'b1
    | X2 of 'a * 'b2

   (* Raised only if the invariant above is broken. *)
   fun bug () = raise Fail "bug"

   val from1 = fn v => X1 (U1, v)
   val from2 = fn v => X2 (U2, v)

   fun to1 (X1 (_, v)) = v
     | to1 _ = bug ()
   fun to2 (X2 (_, v)) = v
     | to2 _ = bug ()

   (* Finish the fold by dispatching on the constructor of the combined
    * value: the matching handler gets the payload together with a
    * rebuilder that re-wraps a result under the same tag. *)
   fun typeCase step =
      Fold01N.fold
      (fn X1 (tag, v) => (fn (f1, _) => f1 (v, fn x => X1 (tag, x)))
        | X2 (tag, v) => (fn (_, f2) => f2 (v, fn x => X2 (tag, x))))
      step

   local
      (* Pair the accumulated payload with the next one.  Both must use the
       * same constructor; a mix would violate the invariant. *)
      fun join (X1 (tag, acc), X1 (_, v)) = X1 (tag, acc & v)
        | join (X2 (tag, acc), X2 (_, v)) = X2 (tag, acc & v)
        | join _ = bug ()
   in
      fun ` step = Fold01N.step1 join step
   end
end
signature NUM =
sig
   (* Numbers are TypeCase2 values whose phantom tag records int or real. *)
   type 'a t = ('a, int, real) TypeCase2.t
   type i = TypeCase2.u1 t
   type r = TypeCase2.u2 t

   (* Conversions *)
   val fromInt: int -> i
   val fromReal: real -> r
   val toInt: i -> int
   val toReal: r -> real
   val real: 'a t -> r
   val round: 'a t -> i
   val toString: 'a t -> string

   (* Constants *)
   val e: r
   val pi: r

   (* Comparisons: both operands must have the same kind *)
   val < : 'a t * 'a t -> bool
   val <= : 'a t * 'a t -> bool
   val > : 'a t * 'a t -> bool
   val >= : 'a t * 'a t -> bool

   (* Arithmetic that preserves the operand kind *)
   val ~ : 'a t -> 'a t
   val abs: 'a t -> 'a t
   val + : 'a t * 'a t -> 'a t
   val - : 'a t * 'a t -> 'a t
   val * : 'a t * 'a t -> 'a t
   val max: 'a t * 'a t -> 'a t
   val min: 'a t * 'a t -> 'a t

   (* Integer-only operations *)
   val div: i * i -> i
   val mod: i * i -> i

   (* Operations whose result is always real *)
   val / : 'a t * 'b t -> r
   val sqrt: 'a t -> r
end
structure Num:> NUM =
struct
   open TypeCase2

   (* Tag u1 carries an int payload, tag u2 a real payload. *)
   type 'a t = ('a, int, real) t
   type i = u1 t
   type r = u2 t

   val fromInt: int -> i = from1
   val toInt: i -> int = to1
   val fromReal: real -> r = from2
   val toReal: r -> real = to2

   val pi = fromReal Real.Math.pi
   val e = fromReal Real.Math.e

   (* Apply onInt or onReal to the payload of a single number. *)
   fun unary (onInt, onReal) n =
      typeCase`n $ (fn (x, _) => onInt x, fn (x, _) => onReal x)

   fun toString n = unary (Int.toString, Real.toString) n
   fun real n = fromReal (unary (Real.fromInt, id) n)
   fun sqrt n = (fromReal o Real.Math.sqrt o toReal o real) n
   fun round n = unary (fromInt, fromInt o Real.round) n

   local
      (* Transform a single number, keeping its kind. *)
      fun lift (onInt, onReal) n =
         let
            fun rewrap g (x, inject) = inject (g x)
         in
            typeCase`n $ (rewrap onInt, rewrap onReal)
         end
   in
      fun abs n = lift (Int.abs, Real.abs) n
      fun ~ n = lift (Int.~, Real.~) n
   end

   local
      (* Relate two numbers of the same kind. *)
      fun relate (onInt, onReal) (a, b) =
         typeCase`a`b $
         (fn (x & y, _) => onInt (x, y),
          fn (x & y, _) => onReal (x, y))
   in
      fun a < b = relate (Int.<, Real.<) (a, b)
      fun a <= b = relate (Int.<=, Real.<=) (a, b)
      fun a > b = relate (Int.>, Real.>) (a, b)
      fun a >= b = relate (Int.>=, Real.>=) (a, b)
   end

   local
      (* Combine two numbers of the same kind into one of that kind. *)
      fun arith (onInt, onReal) (a, b) =
         typeCase`a`b $
         (fn (x & y, wrapI) => wrapI (onInt (x, y)),
          fn (x & y, wrapR) => wrapR (onReal (x, y)))
   in
      fun a + b = arith (Int.+, Real.+) (a, b)
      fun a - b = arith (Int.-, Real.-) (a, b)
      fun a * b = arith (Int.*, Real.* ) (a, b)
      fun max (a, b) = arith (Int.max, Real.max) (a, b)
      fun min (a, b) = arith (Int.min, Real.min) (a, b)
   end

   (* Division always yields a real.  asReal is a let-bound fun so it stays
    * polymorphic over the two independent operand tags. *)
   fun a / b =
      let
         fun asReal n = toReal (real n)
      in
         fromReal (Real./ (asReal a, asReal b))
      end

   local
      fun onInts f (a, b) = fromInt (f (toInt a, toInt b))
   in
      fun a div b = onInts Int.div (a, b)
      fun a mod b = onInts Int.mod (a, b)
   end
end
(* Smoke test: prints a few mixed int/real computations through NUM. *)
functor Test (Num: NUM) =
struct
   open Num

   (* Shorthand injections. *)
   val i = fromInt
   val r = fromReal

   (* Print a number followed by a newline. *)
   fun p n = print (toString n ^ "\n")

   (* Works for either kind, since + requires matching operands. *)
   fun double x = x + x

   val () =
      (p (i 1 + i 2)
       ; p (r 1.5 + r 2.5)
       ; p (round ((i 1 + i 2) / r 3.5))
       ; p (double (i 1))
       ; p (double (r 1.5 + pi)))
end
structure Z = Test (Num)