[MLton] improved overloading for SML
Stephen Weeks
sweeks@sweeks.com
Mon, 17 Oct 2005 17:16:24 -0700
One problem with the implementation of Num that I sent earlier is that
one must look at all of the code inside the Num structure to convince
oneself of the invariant that a value of type "x t" is represented by a
value of type x underneath (where x is either real or int). Below is an
improved implementation that uses a functor to isolate the necessary
reasoning to a few lines of reusable code. The trick is to define a
typecase construct that can be used to implement all of the overloaded
functions. One need only convince oneself that typecase maintains the
invariant in order to believe the code works.
Using a generic typecase functor also isolates the trick needed to
make MLton do the right thing, making it easier to create new
overloaded types and to extend existing ones. It also makes clearer
the expressive power of this technique.
Too bad we don't have fold at the functor level; otherwise there would be
a way to avoid defining a whole family of TypeCase<N> functors, one for
each size N of overloadable type family.
----------------------------------------------------------------------
(* TypeCase2 builds a two-member overloaded type family over x1 and x2.
 * The intended invariant is that a value of type x1 t carries an x1
 * underneath and a value of type x2 t carries an x2.  from1/from2 are the
 * only constructors, and typeCase is the only eliminator.  The rebuild
 * function that typeCase hands to each branch re-tags a new payload with
 * the same phantom type as the inspected value, so checking this functor
 * is enough to trust the invariant. *)
functor TypeCase2 (type x1
                   type x2):
   sig
      type 'a t
      (* Inject a value of the first / second member type. *)
      val from1: x1 -> x1 t
      val from2: x2 -> x2 t
      (* typeCase (v, f1, f2) calls f1 or f2, depending on the payload of v.
       * The call receives the payload and a function that wraps a new
       * payload of the same type back into 'a t. *)
      val typeCase:
         'a t
         * (x1 * (x1 -> 'a t) -> 'b)
         * (x2 * (x2 -> 'a t) -> 'b)
         -> 'b
   end =
   struct
      datatype payload = First of x1 | Second of x2

      (* The option field is always NONE: from1/from2 create NONE and
       * typeCase only copies it.  It exists so that 'a occurs in the
       * representation.  NOTE(review): presumably this is the MLton trick
       * the accompanying message refers to; confirm before removing it. *)
      datatype 'a t = T of 'a option * payload

      fun from1 (v: x1): x1 t = T (NONE, First v)
      fun from2 (v: x2): x2 t = T (NONE, Second v)

      fun typeCase (T (tag, rep), onFirst, onSecond) =
         case rep of
            First v => onFirst (v, fn v' => T (tag, First v'))
          | Second v => onSecond (v, fn v' => T (tag, Second v'))
   end
(* Overloaded numbers.  A value of type int t stands for an int and a value
 * of type real t stands for a real.  Operations typed over 'a t work on
 * either kind.  Binary operations typed 'a t * 'a t require both operands
 * to be the same kind; / is the exception and accepts mixed operands. *)
signature NUM =
sig
   type 'a t

   (* Injections and constants. *)
   val fromInt: int -> int t
   val fromReal: real -> real t
   val e: real t
   val pi: real t

   (* Comparisons between numbers of the same kind. *)
   val < : 'a t * 'a t -> bool
   val <= : 'a t * 'a t -> bool
   val > : 'a t * 'a t -> bool
   val >= : 'a t * 'a t -> bool

   (* Arithmetic whose result has the same kind as its operands. *)
   val ~ : 'a t -> 'a t
   val abs: 'a t -> 'a t
   val + : 'a t * 'a t -> 'a t
   val - : 'a t * 'a t -> 'a t
   val * : 'a t * 'a t -> 'a t
   val max: 'a t * 'a t -> 'a t
   val min: 'a t * 'a t -> 'a t

   (* Integer-only operations. *)
   val div: int t * int t -> int t
   val mod: int t * int t -> int t

   (* Operations whose result is always real. *)
   val / : 'a t * 'b t -> real t
   val real: 'a t -> real t
   val sqrt: 'a t -> real t

   (* Other conversions. *)
   val round: 'a t -> int t
   val toString: 'a t -> string

   (* Case analysis: call the int branch or the real branch with the
    * underlying value and a function that wraps a new value of that
    * type back into 'a t. *)
   val typeCase:
      'a t
      * (int * (int -> 'a t) -> 'b)
      * (real * (real -> 'a t) -> 'b)
      -> 'b
end
(* Num implements NUM on top of TypeCase2, with int as the first member and
 * real as the second.  Every operation goes through typeCase, so the
 * "int t holds an int, real t holds a real" invariant only has to be
 * checked inside TypeCase2.  Definitions use fun with an explicit argument
 * so that they remain polymorphic under the value restriction. *)
structure Num:> NUM =
struct
   structure Z = TypeCase2 (type x1 = int
                            type x2 = real)
   open Z

   val fromInt = from1
   val fromReal = from2
   val e = fromReal Real.Math.e
   val pi = fromReal Real.Math.pi

   (* Branch for operands that carry different representations.  Both
    * operands share the type 'a t, so this should be unreachable while
    * the invariant holds. *)
   fun bug _ = raise Fail "bug"

   (* Eliminate a number: apply fi to an int payload or fr to a real one. *)
   fun unary (fi, fr) n =
      typeCase (n, fn (i, _) => fi i, fn (r, _) => fr r)

   (* Map a number to one of the same kind by rewrapping the result. *)
   fun preserve (fi, fr) n =
      typeCase (n, fn (i, wrap) => wrap (fi i), fn (r, wrap) => wrap (fr r))

   (* Eliminate a pair of numbers of the same kind.  fi/fr also receive the
    * first operand's rebuild function, so they can produce a result of
    * the operands' kind. *)
   fun binary (fi, fr) (n1, n2) =
      typeCase
      (n1,
       fn (i1, wrap) => typeCase (n2, fn (i2, _) => fi (i1, i2, wrap), bug),
       fn (r1, wrap) => typeCase (n2, bug, fn (r2, _) => fr (r1, r2, wrap)))

   (* Binary operation whose result type does not depend on 'a. *)
   fun test2 (fi, fr) =
      binary (fn (a, b, _) => fi (a, b), fn (a, b, _) => fr (a, b))

   (* Binary operation whose result has the operands' kind. *)
   fun closed2 (fi, fr) =
      binary (fn (a, b, wrap) => wrap (fi (a, b)),
              fn (a, b, wrap) => wrap (fr (a, b)))

   (* Integer-only binary operation; a real payload is unreachable for int t. *)
   fun intOnly f (n1, n2) =
      typeCase
      (n1,
       fn (i1, wrap) => typeCase (n2, fn (i2, _) => wrap (f (i1, i2)), bug),
       bug)

   fun toString n = unary (Int.toString, Real.toString) n
   fun toReal n = unary (Real.fromInt, fn r => r) n
   fun sqrt n = fromReal (Real.Math.sqrt (toReal n))
   fun real n = fromReal (toReal n)
   fun round n = unary (fromInt, fn r => fromInt (Real.round r)) n

   fun abs n = preserve (Int.abs, Real.abs) n
   fun ~ n = preserve (Int.~, Real.~) n

   fun a < b = test2 (Int.<, Real.<) (a, b)
   fun a <= b = test2 (Int.<=, Real.<=) (a, b)
   fun a > b = test2 (Int.>, Real.>) (a, b)
   fun a >= b = test2 (Int.>=, Real.>=) (a, b)

   fun a + b = closed2 (Int.+, Real.+) (a, b)
   fun a - b = closed2 (Int.-, Real.-) (a, b)
   fun a * b = closed2 (Int.*, Real.* ) (a, b)
   fun max p = closed2 (Int.max, Real.max) p
   fun min p = closed2 (Int.min, Real.min) p

   (* Division always converts both operands to real first. *)
   fun a / b = fromReal (Real./ (toReal a, toReal b))

   fun a div b = intOnly Int.div (a, b)
   fun a mod b = intOnly Int.mod (a, b)
end
(* Exercise an implementation of NUM through the abstract interface.
 * Applying the functor (the final declaration) prints five results, one
 * per line. *)
functor Test (Num: NUM) =
struct
   open Num
   val i = fromInt
   val r = fromReal

   (* Print a number followed by a newline. *)
   fun p n = print (toString n ^ "\n")

   (* Works at both int t and real t because + is overloaded via Num. *)
   fun double x = x + x

   val () =
      (p (i 1 + i 2)
       ; p (r 1.5 + r 2.5)
       ; p (round ((i 1 + i 2) / r 3.5))
       ; p (double (i 1))
       ; p (double (r 1.5 + pi)))
end

structure Z = Test (Num)