new snapshot
Stephen Weeks
MLton@sourcelight.com
Thu, 16 Nov 2000 14:18:37 -0800 (PST)
>
> > I just put a new snapshot at http://www.star-lab.com/sweeks/src.tgz
> > This is the aforementioned stable snapshot.
>
> Do I win a prize for finding the first bug? ;)
>
> Try the regression suite with -DMLton_detectOverflow.
> I variously get
> mlton: toList(jump L_0) of unknown handler stack
> and
> /tmp/fileR7UCrO.c: In function `Chunk1':
> /tmp/fileR7UCrO.c:953: `L_198' undeclared (first use in this function)
> compile errors.
Oops. I never tested with overflow checking. There was at least one bug: the
CPS shrinker forgot to eliminate the now-useless label when an overflow-checking
primitive was simplified away. Here is a new cps/shrink.fun. With it, I get
through FuhMishra, abstype, and array, but the compiler dies on array2 in the
backend with
Int_addCheck
mlton: applyPrim: getDst
compilation of array2 failed
I assume that's because my version doesn't have valid overflow checking code.
I'm still looking into the second bug, which I think is unrelated to the first.
--------------------------------------------------------------------------------
(* Copyright (C) 1997-1999 NEC Research Institute.
* Please see the file LICENSE for license information.
*)
functor Shrink(S: SHRINK_STRUCTS): SHRINK =
struct
open S
open Dec PrimExp Transfer
type int = Int.t
structure Position =
   struct
      (* Where a value handed along by a jump comes from: the i-th formal
       * parameter of the label (Formal i, an index into its formals) or a
       * variable free in the label's body (Free x). *)
      datatype t =
         Formal of int
       | Free of Var.t

      fun layout (Formal n) = Int.layout n
        | layout (Free v) = Var.layout v

      fun equals (Formal a, Formal b) = a = b
        | equals (Free a, Free b) = Var.equals (a, b)
        | equals _ = false
   end
structure Positions =
   struct
      structure PS = MonoList(Position)
      open PS
      (* True iff at least one position refers to a formal parameter. *)
      fun usesFormal (ps: t): bool =
         not (List.forall (ps, fn Position.Free _ => true
                               | Position.Formal _ => false))
   end
(* Information about a jump label, built by the occurrence-counting pass in
 * shrinkExp and consulted while simplifying.  The meaning says what a jump
 * to the label does, so that jumps can be inlined, redirected to another
 * label, or replaced by the raise/return the label performs.
 *)
structure JumpInfo =
struct
(* IsLiftableCase iff f is of the form
* fun f(x) =
* let fun L1
* ...
* fun Ln
* in case x of ...
* end
* where x only occurs once.
* In this case, f will be rewritten as
* fun L1
* ...
* fun Ln
* fun f(x) = case x of ...
* The Li's will be called directly at places where the constructor is
* known.
* Where the constructor is not known, f will be called.
*)
(* !numOccurrences is the number of calls to the jump from outside its
* body.
*)
(* name:    the label this info was created for.  A label whose body merely
 *          forwards its formals to another label is given that label's
 *          info, and therefore its name.
 * numArgs: the number of formals of the label.
 *
 * Meanings:
 *  Code    the label keeps a real body.  isRecursive is set when the body
 *          jumps to the label itself; isLiftableCase as described above.
 *  Jump    jumping to the label amounts to jumping to dst (always a label
 *          with a Code meaning, see jump below), passing args.
 *  Raise   the label only raises the values at the given positions.
 *  Return  the label only returns the values at the given positions.
 *)
datatype t = T of {meaning: meaning,
name: Jump.t,
numArgs: int,
numOccurrences: int ref}
and meaning =
Code of {body: Exp.t,
formals: (Var.t * Type.t) list,
isLiftableCase: {cases: t Cases.t,
default: t option} option ref,
isRecursive: bool ref}
| Jump of {dst: t, args: Positions.t}
| Raise of Positions.t
| Return of Positions.t
(* Debugging layout of a label's info. *)
fun layout(T{meaning, name, numOccurrences, ...}) =
let open Layout
in record[("name", Jump.layout name),
("numOccurrences", Int.layout(!numOccurrences)),
("meaning",
case meaning of
Code{isRecursive, ...} =>
seq[str "Code",
record[("isRecursive", Bool.layout(!isRecursive))]]
| Jump{dst, args} =>
seq[str "Jump",
tuple[layout dst, Positions.layout args]]
| Raise ps => seq[str "Raise ", Positions.layout ps]
| Return ps => seq[str "Return ", Positions.layout ps])]
end
(* True if what the label does may depend on its formals.  A Code body is
 * conservatively assumed to. *)
fun usesFormal(T{meaning, ...}) =
case meaning of
Code _ => true
| Jump{args, ...} => Positions.usesFormal args
| Raise xs => Positions.usesFormal xs
| Return xs => Positions.usesFormal xs
local fun make s (T r) = s r
in
val name = make #name
end
(* Add n to the occurrence count.  Unlike shrinkExp's incNumOccurrences,
 * there is no non-negativity check here. *)
fun inc(T{numOccurrences = r, ...}, n: int): unit = r := n + !r
(* Conservative equality of meanings, used to detect case branches that all
 * do the same thing.  Code meanings are compared by their isRecursive ref
 * cells with =, which on refs is identity, so two Code meanings are equal
 * only if they share that cell.  A Jump with no arguments to such a Code
 * label also counts as equal to it. *)
fun equals(T{meaning = m, ...}, T{meaning = m', ...}) =
case (m, m') of
(Code{isRecursive = i, ...}, Code{isRecursive = i', ...}) =>
i = i'
| (Code{isRecursive = i, ...},
Jump{dst = T{meaning = Code{isRecursive = i', ...}, ...},
args = [], ...}) => i = i'
| (Jump{dst = T{meaning = Code{isRecursive = i', ...}, ...},
args = [], ...},
Code{isRecursive = i, ...}) => i = i'
| (Jump{dst = d, args = a}, Jump{dst = d', args = a'}) =>
equals(d, d') andalso Positions.equals(a, a')
| (Raise ps, Raise ps') => Positions.equals(ps, ps')
| (Return ps, Return ps') => Positions.equals(ps, ps')
| _ => false
(* jump(j, ps): the meaning of a jump to j that passes the values at
 * positions ps.  A Code label yields Jump{dst = j, ...}.  For any other
 * meaning, each Formal i in j's meaning is replaced by the i-th entry of
 * ps.  NOTE: extract always indexes into the parameter ps, even in the
 * Raise/Return arms, where the pattern variable ps shadows it. *)
fun jump(j as T{meaning, ...}, ps: Positions.t): meaning =
let
fun extract(ps': Positions.t): Positions.t =
let val a = Array.fromList ps
in List.map(ps', let open Position
in fn Free x => Free x
| Formal i => Array.sub(a, i)
end)
end
in case meaning of
Code _ => Jump{dst = j, args = ps}
| Jump{dst, args} => Jump{dst = dst, args = extract args}
| Raise ps => Raise(extract ps)
| Return ps => Return(extract ps)
end
(* True if the label just returns all its formals unchanged and in order.
 * A call using such a label as its continuation can drop the continuation
 * (see simplifyTransfer). *)
fun isTail(T{numArgs, meaning, ...}): bool =
case meaning of
Return ps =>
numArgs = List.length ps
andalso List.foralli(ps,
fn (i, Position.Formal i') => i = i'
| _ => false)
| _ => false
end
structure VarInfo =
   struct
      (* What the shrinker knows about a variable: the variable it currently
       * stands for, how many uses remain, and its value when known. *)
      datatype t = T of {var: Var.t,
                         numOccurrences: int ref,
                         value: value option ref}
      and value =
         Con of {con: Con.t, args: t list}
       | Const of Const.t
       | Tuple of t list

      fun equals (T {var = a, ...}, T {var = b, ...}) = Var.equals (a, b)

      fun layout (T {var = x, numOccurrences = n, value = v}) =
             Layout.record [("var", Var.layout x),
                            ("numOccurrences", Int.layout (!n)),
                            ("value", Option.layout layoutValue (!v))]
      and layoutValue (Con {con = c, args = xs}) =
             Layout.seq [Con.layout c, Layout.tuple (List.map (xs, layout))]
        | layoutValue (Const c) = Const.layout c
        | layoutValue (Tuple xs) = Layout.tuple (List.map (xs, layout))

      (* Fresh info: no recorded uses and no known value. *)
      fun new (x: Var.t) =
         T {var = x, numOccurrences = ref 0, value = ref NONE}
      fun var (T {var = x, ...}): Var.t = x
      fun numOccurrences (T {numOccurrences = n, ...}) = n
      fun value (T {value = v, ...}): value option = !v
   end
structure Value =
   struct
      datatype t = datatype VarInfo.value
      val layout = VarInfo.layoutValue
      (* The constructor value representing a boolean literal. *)
      fun fromBool (b: bool): t =
         let
            val c = if b then Con.truee else Con.falsee
         in
            Con {con = c, args = []}
         end
      (* Rebuild the primitive expression that produces a known value. *)
      fun toPrimExp (Con {con, args}): PrimExp.t =
             PrimExp.ConApp {con = con, args = List.map (args, VarInfo.var)}
        | toPrimExp (Const c) = PrimExp.Const c
        | toPrimExp (Tuple xs) = PrimExp.Tuple (List.map (xs, VarInfo.var))
   end
(* Trace wrappers around the shrinker's main functions (debug output only). *)
val traceApply =
   let
      val layoutInfo = Var.layout o VarInfo.var
      fun layoutCall (p, applyArgs, _: VarInfo.t * VarInfo.t -> bool) =
         Layout.seq [Prim.layout p,
                     List.layout (Prim.ApplyArg.layout layoutInfo) applyArgs]
   in
      Trace.trace ("Prim.apply", layoutCall,
                   Prim.ApplyResult.layout layoutInfo)
   end
val traceSimplifyBind =
   Trace.trace2 ("Shrink.simplifyBind", layoutBind,
                 Layout.ignore: (unit -> Exp.t) -> Layout.t,
                 Layout.ignore: Exp.t -> Layout.t)
val traceSimplifyExp = Trace.trace ("simplifyExp", Exp.layout, Exp.layout)
val traceSimplifyTransfer =
   Trace.trace ("Shrink.simplifyTransfer", Transfer.layout, Exp.layout)
val traceJump =
   let
      fun layoutJumpCall (dst, infos) =
         Layout.tuple [JumpInfo.layout dst, List.layout VarInfo.layout infos]
   in
      Trace.trace ("jump", layoutJumpCall, Exp.layout)
   end
val traceDeleteExp = Trace.trace ("deleteExp", Exp.layout, Unit.layout)
fun shrinkExp globals =
let
(* varInfo can't be getSetOnce because of setReplacement. *)
(* varInfo maps each variable to the info it currently stands for.  A
 * variable that was never set gets its info from VarInfo.new (via
 * Property.initFun).  The traced versions below shadow the raw accessors. *)
val {get = varInfo: Var.t -> VarInfo.t, set = setVarInfo} =
Property.getSet(Var.plist, Property.initFun VarInfo.new)
val varInfo = Trace.trace("varInfo", Var.layout, VarInfo.layout) varInfo
val setVarInfo =
Trace.trace2("setVarInfo", Var.layout, VarInfo.layout, Unit.layout)
setVarInfo
fun varInfos xs = List.map(xs, varInfo)
(* jumpInfo maps each label to its info.  Each label is set exactly once
 * (getSetOnce).  A label that only forwards its formals to another label
 * is set to that label's info.  Looking up an unset label presumably
 * raises (Property.initRaise) -- TODO confirm. *)
val {get = jumpInfo: Jump.t -> JumpInfo.t, set = setJumpInfo} =
Property.getSetOnce(Jump.plist, Property.initRaise("info", Jump.layout))
val jumpInfo =
Trace.trace("jumpInfo", Jump.layout, JumpInfo.layout) jumpInfo
(* The variable that x currently stands for (after any setReplacement). *)
fun simplifyVar(x: Var.t) = VarInfo.var(varInfo x)
val simplifyVar =
Trace.trace("simplifyVar", Var.layout, Var.layout) simplifyVar
fun simplifyVars xs = List.map(xs, simplifyVar)
(* Occurrence counts.  Every variable and every jump carries an int ref
 * counting its remaining uses; the helpers below adjust those counts.
 * A variable count must never become negative (checked with Assert).
 *)
fun incNumOccurrences(r: int ref, n: int): unit =
   let
      val new = n + !r
      val _ = Assert.assert("incNumOccurrences", fn () => new >= 0)
   in r := new
   end
fun incVarInfo(x: VarInfo.t, n: int): unit =
   incNumOccurrences(VarInfo.numOccurrences x, n)
fun incVar(x: Var.t, n: int): unit = incVarInfo(varInfo x, n)
val incVar =
   Trace.trace2("incVar", Var.layout, Int.layout, Unit.layout) incVar
fun deleteVarInfo i = incVarInfo(i, ~1)
(* Called only for its effect on the counts, so iterate with foreach rather
 * than building (and discarding) a unit list with map. *)
fun deleteVarInfos is = List.foreach(is, deleteVarInfo)
fun deleteVar x = incVar(x, ~1)
val deleteVar = Trace.trace("deleteVar", Var.layout, Unit.layout) deleteVar
fun numVarOccurrences(x: Var.t): int =
   !(VarInfo.numOccurrences(varInfo x))
(* Make x stand for i from now on: later lookups of x yield i, and the uses
 * x already had are transferred to i. *)
fun setReplacement(x: Var.t, i: VarInfo.t): unit =
   let val VarInfo.T{numOccurrences = r, ...} = varInfo x
   in incVarInfo(i, !r)
      ; setVarInfo(x, i)
   end
val setReplacement =
   Trace.trace2("setReplacement", Var.layout, VarInfo.layout, Unit.layout)
   setReplacement
fun addVar(x: Var.t): unit = incVar(x, 1)
fun addVarInfo(x: VarInfo.t): unit = incVarInfo(x, 1)
(* Count one use of every free variable among the positions.  Formal
 * positions refer to the label's own parameters and are not counted. *)
fun addPositions(ps: Positions.t): unit =
   List.foreach(ps,
                fn Position.Free x => addVar x
                 | _ => ())
fun addJumpInfo j = JumpInfo.inc(j, 1)
fun addJump j = addJumpInfo(jumpInfo j)
(* Count the uses made by a jump meaning itself: a Jump meaning uses its
 * target label, and every non-Code meaning uses its free variables. *)
fun addJumpMeaning(m: JumpInfo.meaning) =
   let datatype z = datatype JumpInfo.meaning
   in case m of
         Code _ => ()
       | Jump{dst, args} => (addJumpInfo dst
                             ; addPositions args)
       | Raise ps => addPositions ps
       | Return ps => addPositions ps
   end
(* Seed value information from the program's globals.  A global bound to a
 * constant, tuple or constructor application records that value; a global
 * that is just another variable shares that variable's info.
 *)
val _ =
   List.foreach
   (globals, fn {var = g, exp = e, ty = _} =>
    let
       fun known (v: Value.t): unit =
          setVarInfo (g, VarInfo.T {var = g,
                                    numOccurrences = ref 0,
                                    value = ref (SOME v)})
    in
       case e of
          Var y => setVarInfo (g, varInfo y)
        | Const c => known (Value.Const c)
        | Tuple xs => known (Value.Tuple (varInfos xs))
        | ConApp {con, args} =>
             known (Value.Con {con = con, args = varInfos args})
        | _ => ()
    end)
in fn (exp: Exp.t, mayDelete: bool) =>
let
(* Compute occurrence counts for both variables and jumps. *)
(* Walk the whole expression once: count every use of a variable or jump and
 * record a JumpInfo for every Fun declaration. *)
val _ =
let
(* Count the uses in the declarations, then in the transfer. *)
fun loopExp e =
let val {decs, transfer} = Exp.dest e
in List.foreach(decs, loopDec)
; Transfer.foreachJumpVar(transfer, addJump, addVar)
end
and loopDec d =
case d of
Bind{exp, ...} =>
PrimExp.foreachJumpVar(exp, addJump, addVar)
| Fun{name, args, body} =>
let
val {decs, transfer} = Exp.dest body
val numOccurrences = ref 0
(* Record meaning m for name and count the uses m itself makes. *)
fun set(m: JumpInfo.meaning) =
(addJumpMeaning m
; (setJumpInfo
(name,
JumpInfo.T{name = name,
numArgs = List.length args,
numOccurrences = numOccurrences,
meaning = m})))
(* Keep a real body (Code meaning).  The info is recorded before the
 * body is walked, so jumps to name from its own body are counted on
 * numOccurrences; any such jump marks the label recursive.  The
 * counter is then reset to 0 so it counts only uses from outside the
 * body.  A non-recursive label is marked as a liftable case when its
 * single formal x is the test of the body's Case, x occurs exactly
 * once, and all of the body's declarations are Funs. *)
fun normal() =
let
val isLiftableCase = ref NONE
val isRecursive = ref false
in set(JumpInfo.Code
{body = body,
formals = args,
isLiftableCase = isLiftableCase,
isRecursive = isRecursive})
; loopExp body
; if !numOccurrences > 0
then isRecursive := true
else
case (args, transfer) of
([(x, _)],
Case{test, cases, default, ...}) =>
if Var.equals(x, test)
andalso
List.forall(decs,
fn Fun _ => true
| _ => false)
andalso 1 = numVarOccurrences x
then
(isLiftableCase :=
SOME{cases =
Cases.map(cases, jumpInfo),
default =
Option.map(default,
jumpInfo)})
else ()
| _ => ()
; numOccurrences := 0
end
(* Translate actual variables into positions relative to this Fun's
 * formals: a formal becomes Formal i, any other variable Free x.  The
 * temporary property is destroyed before returning. *)
fun extract(actuals: Var.t list): Positions.t =
let
val {get: Var.t -> Position.t, set, destroy} =
Property.destGetSetOnce
(Var.plist, Property.initFun Position.Free)
val _ =
List.foreachi(args, fn (i, (x, _)) =>
set(x, Position.Formal i))
val ps =
List.fold(rev actuals, [], fn (x, ps) =>
(get x :: ps))
val _ = destroy()
in ps
end
(* True iff the actuals are this Fun's formals (compared with
 * List.equals). *)
fun sameAsArgs args' =
List.equals(args, args', fn ((x, _), x') =>
Var.equals(x, x'))
(* Classify the body by its shape.  In the Jump arm, args is the jump's
 * actuals and shadows the Fun's formals; sameAsArgs and extract were
 * defined above and still see the formals. *)
in case (decs, transfer) of
([], Jump{dst, args}) =>
if Jump.equals(dst, name)
then normal()
else
if sameAsArgs args
(* The body just forwards its formals to dst: share dst's info. *)
then setJumpInfo(name, jumpInfo dst)
else
set(JumpInfo.jump(jumpInfo dst,
extract args))
| ([], Raise xs) => set(JumpInfo.Raise(extract xs))
| ([], Return xs) => set(JumpInfo.Return(extract xs))
| _ => normal()
end
| HandlerPop => ()
(* Pushing a handler counts as a use of its label. *)
| HandlerPush h => addJump h
in loopExp exp
end
local
   (* Per-label flag, true while that label's body is being deleted. *)
   val {get = flagOf: Jump.t -> bool ref} =
      Property.get(Jump.plist, fn _ => ref false)
in
   fun amInJump (j: Jump.t): bool = !(flagOf j)
   (* Run f with j's flag set, clearing it again when f returns. *)
   fun withinJump(j: Jump.t, f: unit -> 'a): 'a =
      let
         val flag = flagOf j
         val _ = flag := true
         val result = f ()
         val _ = flag := false
      in
         result
      end
end
(* The body to process for a Code label.  For a liftable case only the
 * transfer is kept, because simplifyFun re-emits the body's declarations
 * as separate declarations. *)
fun makeBody(body, isLiftableCase) =
   case !isLiftableCase of
      NONE => body
    | SOME _ => Exp.make{decs = [], transfer = Exp.transfer body}
(* Remove one use of every free variable among the positions. *)
fun deletePositions(ps: Positions.t): unit =
   let
      fun deleteOne (Position.Free x) = deleteVar x
        | deleteOne (Position.Formal _) = ()
   in
      List.foreach(ps, deleteOne)
   end
(* Deleting jumps and expressions. *)
(* These functions release the uses held by code that is being discarded,
 * so that the occurrence counts stay accurate. *)
(* deleteJumpMeaning(m, j): release every use held by meaning m of label j.
 * A Code body is deleted while j is flagged (withinJump), so that jumps
 * from the body back to j are ignored by deleteJumpInfo. *)
fun deleteJumpMeaning(m: JumpInfo.meaning, j: Jump.t): unit =
let datatype z = datatype JumpInfo.meaning
in case m of
Code{body, isLiftableCase, ...} =>
withinJump(j, fn () =>
deleteExp(makeBody(body, isLiftableCase)))
| Jump{dst, args} => (deleteJumpInfo dst; deletePositions args)
| Raise ps => deletePositions ps
| Return ps => deletePositions ps
end
(* Remove one use of label j. *)
and deleteJump (j: Jump.t): unit = deleteJumpInfo (jumpInfo j)
(* Remove one use of a label.  When the last use goes, its meaning is
 * deleted as well.  Uses from inside the label's own body, while that body
 * is being deleted (amInJump), are ignored. *)
and deleteJumpInfo (JumpInfo.T {meaning, name, numOccurrences, ...})
: unit =
if amInJump name
then ()
else
let
val new = !numOccurrences - 1
val _ = Assert.assert("deleteJumpInfo", fn () => new >= 0)
val _ = numOccurrences := new
in if new = 0
then deleteJumpMeaning(meaning, name)
else ()
end
(* Release every use made by an expression that is being discarded. *)
and deleteExp arg : unit =
traceDeleteExp
(fn (exp: Exp.t) =>
let val {decs, transfer} = Exp.dest exp
in List.foreach(decs, deleteDec)
; Transfer.foreachJumpVar(transfer, deleteJump, deleteVar)
end) arg
(* Release the uses made by one declaration.  A Fun's meaning is deleted
 * here only when the label has no remaining uses and this Fun is the one
 * its info was created for; an alias records no meaning of its own.  A
 * label that still has uses has its meaning deleted by deleteJumpInfo once
 * its last use is removed. *)
and deleteDec d =
case d of
Bind{exp, ...} =>
PrimExp.foreachJumpVar(exp, deleteJump, deleteVar)
| Fun{name, body, ...} =>
let
val JumpInfo.T{meaning, name = n, numOccurrences, ...} =
jumpInfo name
in if 0 = !numOccurrences andalso Jump.equals(name, n)
then deleteJumpMeaning(meaning, name)
else ()
end
| HandlerPush h => deleteJump h
| _ => ()
(* Pre: the args counts are correct. *)
(* jump (info, args) simplifies a jump to the label described by info that
 * passes args.  The occurrence counts of args must already include this
 * jump.  Depending on the label's meaning, the jump is
 *  - inlined, when the label has a Code meaning, exactly one use, and is
 *    not recursive;
 *  - sent straight to a case branch, when the label is a liftable case
 *    (see JumpInfo) and its single argument has a known value;
 *  - redirected to the target of a Jump meaning;
 *  - replaced by the raise/return of a Raise/Return meaning;
 *  - otherwise kept as a jump to the label's canonical name.
 *)
fun jump arg: Exp.t =
   traceJump
   (fn (info as JumpInfo.T{meaning, name, numOccurrences, ...},
        args: VarInfo.t list) =>
    let
       (* Turn the label's positions into variables (Formal i is the i-th
        * element of args), counting one new use of each, since the caller
        * emits them in a new transfer. *)
       fun extract(ps: Positions.t): VarInfo.t list =
          let
             val a = Array.fromList args
          in
             List.map(ps, fn p =>
                      let
                         val i =
                            case p of
                               Position.Formal i => Array.sub(a, i)
                             | Position.Free x => varInfo x
                      in addVarInfo i; i
                      end)
          end
       (* Replace the jump by the raise/return the label performs. *)
       fun rr(f: Var.t list -> Transfer.t, ps: Positions.t): Exp.t =
          (deleteJumpInfo info
           ; Exp.fromTransfer(f(List.map(extract ps, VarInfo.var))))
    in
       case meaning of
          JumpInfo.Code{body, formals, isLiftableCase, isRecursive, ...} =>
             if 1 = !numOccurrences andalso not(!isRecursive)
                then (* Inline: the formals now stand for the actuals, whose
                      * uses at this jump disappear. *)
                     (numOccurrences := 0
                      ; List.foreach2(formals, args, fn ((x, _), i) =>
                                      setReplacement(x, i))
                      ; deleteVarInfos args
                      ; simplifyExp(makeBody(body, isLiftableCase)))
             else
                let
                   fun jump(j: Jump.t, args: VarInfo.t list): Exp.t =
                      Exp.fromTransfer
                      (Jump{dst = j, args = List.map(args, VarInfo.var)})
                in
                   case (!isLiftableCase, args) of
                      (SOME{cases, default},
                       [VarInfo.T{numOccurrences,
                                  value = ref(SOME v), ...}]) =>
                         let
                            fun doit(cases, is, args) =
                               let
                                  (* Whatever branch is chosen, the call to
                                   * the liftable label disappears, so its
                                   * single argument loses this use.  The
                                   * original released it only when a case
                                   * matched, leaving the count too high in
                                   * the default and bug paths. *)
                                  val _ = incNumOccurrences(numOccurrences, ~1)
                                  (* Jump straight to branch label j.  j
                                   * gains its use before info loses one, so
                                   * no deletion started from info can reach
                                   * j while it is about to be used. *)
                                  fun direct(j, args) =
                                     (JumpInfo.inc(j, 1)
                                      ; deleteJumpInfo info
                                      ; jump(JumpInfo.name j, args))
                               in
                                  case List.peek(cases, fn (i, _) => is i) of
                                     NONE =>
                                        (case default of
                                            NONE =>
                                               (* The call to the label is
                                                * gone here too. *)
                                               (deleteJumpInfo info; Exp.bug)
                                          | SOME j => direct(j, []))
                                   | SOME(_, j) =>
                                        (List.foreach(args, addVarInfo)
                                         ; direct(j, args))
                               end
                         in
                            case (cases, v) of
                               (Cases.Con cases, Value.Con{con, args}) =>
                                  doit(cases, fn c => Con.equals(c, con), args)
                             | (Cases.Int cases, Value.Const c) =>
                                  (case Const.node c of
                                      Const.Node.Int i =>
                                         doit(cases, fn i' => i = i', [])
                                    | _ =>
                                         Error.bug "strange constant for Cases.Int")
                             | _ => Error.bug "strange Case with constant test"
                         end
                    | _ => jump(name, args)
                end
        | JumpInfo.Jump{dst, args} =>
             (addJumpInfo dst
              ; deleteJumpInfo info
              ; jump(dst, extract args))
        | JumpInfo.Raise ps => rr(Raise, ps)
        | JumpInfo.Return ps => rr(Return, ps)
    end) arg
(* Simplify an expression: its declarations in order, then its transfer. *)
and simplifyExp arg : Exp.t =
   traceSimplifyExp
   (fn (e: Exp.t) =>
    let
       val {decs, transfer} = Exp.dest e
       fun simplifyTail () = simplifyTransfer transfer
    in
       simplifyDecs(decs, simplifyTail)
    end) arg
(* Simplify decs front to back.  Everything after a declaration (the later
 * declarations and finally rest()) is passed to it as a thunk, so a
 * declaration can simplify what follows before deciding whether it is
 * still needed. *)
and simplifyDecs(decs: Dec.t list, rest): Exp.t =
   case decs of
      [] => rest()
    | dec :: decs' => simplifyDec(dec, fn () => simplifyDecs(decs', rest))
(* Simplify one declaration followed by rest. *)
and simplifyDec(dec: Dec.t, rest: unit -> Exp.t): Exp.t =
   case dec of
      Bind r => simplifyBind(r, rest)
    | Fun r => simplifyFun(r, rest)
    | HandlerPop => Exp.prefix(rest(), dec)
    | HandlerPush h =>
         (* h may be an alias of another label, and simplifyFun drops the
          * Fun declaration of an alias.  Push the canonical label, just as
          * Call continuations, Case branches and Overflow labels are
          * renamed via JumpInfo.name; otherwise the pushed label could end
          * up undeclared. *)
         Exp.prefix(rest(), HandlerPush(JumpInfo.name(jumpInfo h)))
(* Simplify the binding var = exp, followed by rest.  Known values are
 * recorded in var's info.  Variable-to-variable bindings and folded
 * primitives are removed by replacing var.  Side-effect-free bindings are
 * dropped when unused and mayDelete is set. *)
and simplifyBind arg: Exp.t =
   traceSimplifyBind
   (fn ({var, ty, exp}, rest: unit -> Exp.t) =>
    let
       val VarInfo.T{numOccurrences, value, ...} = varInfo var
       (* Emit the binding in front of the already simplified rest. *)
       fun finish(exp: PrimExp.t, decs: Exp.t): Exp.t =
          Exp.prefix(decs, Bind{var = var, ty = ty, exp = exp})
       (* Emit a binding whose expression has no side effects.  If deletion
        * is enabled and var has no uses, the binding is dropped and every
        * use made by p is released.  This is checked before and again after
        * simplifying rest, because simplifying rest can remove uses. *)
       fun nonExpansive(p: PrimExp.t): Exp.t =
          let
             fun isUseless(): bool =
                mayDelete andalso 0 = !numOccurrences
             (* Release the jumps used by p (e.g. the Overflow label of a
              * PrimApp) as well as its variables, exactly as deleteDec does
              * for a Bind.  Releasing only the variables would leave such a
              * label's count too high and keep the label alive. *)
             fun delete() = PrimExp.foreachJumpVar(p, deleteJump, deleteVar)
          in if isUseless()
                then (delete(); rest())
             else let val rest = rest()
                  in if isUseless()
                        then (delete(); rest)
                     else finish(p, rest)
                  end
          end
       (* Record var's known value and bind it. *)
       fun construct(v: Value.t) =
          (value := SOME v; nonExpansive(Value.toPrimExp v))
       (* The binding is just x: make var stand for x and drop it. *)
       fun bindVar (x: VarInfo.t) =
          (setReplacement(var, x)
           ; deleteVarInfo x
           ; rest())
       fun primApp {prim, info, targs, args} =
          let
             (* Refer to the overflow label by its canonical name. *)
             val info =
                let open PrimInfo
                in case info of
                      None => None
                    | Overflow j => Overflow (JumpInfo.name (jumpInfo j))
                end
             val e =
                PrimApp {prim = prim,
                         info = info,
                         targs = targs,
                         args = List.map (args, VarInfo.var)}
          in if Prim.maySideEffect prim
                then finish (e, rest())
             else nonExpansive e
          end
    in case exp of
          Const c => construct (Value.Const c)
        | ConApp {con, args} =>
             construct (Value.Con {con = con, args = varInfos args})
        | PrimApp {prim, info, targs, args} =>
             let
                fun deleteInfo () =
                   let
                      datatype z = datatype PrimInfo.t
                   in case info of
                         None => ()
                       | Overflow j => deleteJump j
                   end
                val args = varInfos args
                (* The primitive was folded away: release its overflow label,
                 * if any, and the use of every operand.  The original
                 * released no operand for Bool/Const results and only the
                 * resulting variable for Var results, leaving the other
                 * operands' counts too high. *)
                fun deleteOperands () =
                   (deleteInfo ()
                    ; List.foreach (args, deleteVarInfo))
                fun normal () =
                   primApp {prim = prim,
                            info = info,
                            targs = targs,
                            args = args}
                fun loop (args: VarInfo.t list,
                          ac: VarInfo.t Prim.ApplyArg.t list) =
                   case args of
                      [] =>
                         let
                            val res =
                               traceApply Prim.apply
                               (prim, rev ac, VarInfo.equals)
                            datatype z = datatype Prim.ApplyResult.t
                         in case res of
                               Apply (p, args) =>
                                  primApp {prim = p, info = info,
                                           targs = [], args = args}
                             | Bool b => (deleteOperands ()
                                          ; construct (Value.fromBool b))
                             | Const c => (deleteOperands ()
                                           ; construct (Value.Const c))
                             | Unknown => normal()
                             | Var x => (deleteOperands ()
                                         ; setReplacement (var, x)
                                         ; rest ())
                         end
                    | arg :: args =>
                         let
                            (* Pass known constants to Prim.apply as such. *)
                            val arg =
                               case arg of
                                  VarInfo.T{value = ref (SOME
                                                         (Value.Const c)),
                                            ...} =>
                                     Prim.ApplyArg.Const (Const.node c)
                                | _ => Prim.ApplyArg.Var arg
                         in loop (args, arg :: ac)
                         end
             in case Prim.name prim of
                   Prim.Name.FFI _ => normal()
                 | _ => loop (args, [])
             end
        | Select{tuple, offset} =>
             let
                val VarInfo.T{var = tuple, numOccurrences, value} =
                   varInfo tuple
             in case !value of
                   SOME(Value.Tuple vs) =>
                      (* Selecting from a known tuple: var is that component. *)
                      (incNumOccurrences(numOccurrences, ~1)
                       ; setReplacement(var, List.nth(vs, offset))
                       ; rest())
                 | NONE => nonExpansive(Select{tuple = tuple,
                                               offset = offset})
                 | _ => Error.bug "select of non-tuple"
             end
        | Tuple xs => construct(Value.Tuple(varInfos xs))
        | Var x => bindVar (varInfo x)
    end) arg
(* Simplify the declaration of label name, followed by rest.
 *  - An alias (a label whose info was created for another label) is
 *    dropped.
 *  - A liftable case with local declarations is rewritten so that those
 *    declarations come first, followed by name with only its transfer as
 *    body (see JumpInfo); the result is then simplified.
 *  - Otherwise (doit) the label is dropped if it has no uses, checked both
 *    before and after simplifying rest.  A surviving label's body is
 *    rebuilt from its meaning.
 *)
and simplifyFun({name, args, body}, rest: unit -> Exp.t): Exp.t =
let
val JumpInfo.T{meaning, name = n, numOccurrences, ...} =
jumpInfo name
fun doit() =
let
fun isUseless() = 0 = !numOccurrences
(* Already unused: release what the meaning holds and go on.
 * Otherwise simplify rest first.  If that removes the last use,
 * deleteJumpInfo has already released the meaning. *)
in if isUseless()
then (deleteJumpMeaning(meaning, n); rest())
else
let val rest = rest()
in if isUseless()
then rest
else
let
(* Positions relative to this Fun's own formals. *)
fun getVars(ps: Positions.t): VarInfo.t list =
let val args = Array.fromList args
in List.map
(ps,
fn Position.Formal i =>
varInfo(#1(Array.sub(args, i)))
| Position.Free x => varInfo x)
end
fun rr(f: Var.t list -> Transfer.t, ps)
: Exp.t =
Exp.fromTransfer
(f(List.map(getVars ps, VarInfo.var)))
(* A Code body is simplified with name flagged (withinJump),
 * so that deleting jumps back to name from inside it is
 * ignored by deleteJumpInfo. *)
val body =
case meaning of
JumpInfo.Code _ =>
withinJump(name, fn () =>
simplifyExp body)
| JumpInfo.Jump{dst, args} =>
jump(dst, getVars args)
| JumpInfo.Raise ps => rr(Raise, ps)
| JumpInfo.Return ps => rr(Return, ps)
in Exp.prefix(rest, Fun{name = name, args = args,
body = body})
end
end
end
in case (Jump.equals(name, n), meaning) of
(true, JumpInfo.Code{isLiftableCase, ...}) =>
let val {decs, transfer} = Exp.dest body
in case (decs, !isLiftableCase) of
(_ :: _, SOME _) =>
simplifyDecs
(decs @ [Fun{name = name, args = args,
body = Exp.make{decs = [],
transfer = transfer}}],
rest)
| _ => doit()
end
| (true, _) => doit()
| _ => rest()
end
(* Simplify the transfer that ends an expression. *)
and simplifyTransfer arg : Exp.t =
   traceSimplifyTransfer
   (fn (t: Transfer.t) =>
    case t of
       Bug => Exp.bug
     | Call{func, args, cont} =>
          let
             (* A continuation that just returns its formals unchanged
              * (JumpInfo.isTail) is dropped, making this a tail call.  The
              * dropped continuation was counted as a use of the label, so
              * that use is released.  Otherwise the label's count stays
              * too high and an unused label survives in the output. *)
             val cont =
                case cont of
                   NONE => NONE
                 | SOME j =>
                      let val info = jumpInfo j
                      in if JumpInfo.isTail info
                            then (deleteJumpInfo info; NONE)
                         else SOME(JumpInfo.name info)
                      end
          in Exp.fromTransfer
             (Call{func = func, args = simplifyVars args,
                   cont = cont})
          end
     | Case r => simplifyCase r
     | Jump{dst, args} => jump(jumpInfo dst, varInfos args)
     | Raise xs => Exp.fromTransfer(Raise(simplifyVars xs))
     | Return xs => Exp.fromTransfer(Return(simplifyVars xs))
   ) arg
(* Simplify a case on test.
 *  - No branches: take the default, or the case is a bug.
 *  - test has a known constant or constructor: jump directly to the
 *    matching branch (or the default) and release the other branches.
 *  - All branches (and the default, if any) have the same meaning, and that
 *    meaning does not depend on the branch's formals: replace the case by
 *    that meaning.
 *  - Otherwise rebuild the case using canonical label names.
 *)
and simplifyCase(c as {cause, test, cases, default}) =
let val test = varInfo test
val cases = Cases.map(cases, jumpInfo)
val default = Option.map(default, jumpInfo)
in if Cases.isEmpty cases
then (case default of
NONE => (deleteVarInfo test; Exp.bug)
| SOME j => (deleteVarInfo test; jump(j, [])))
else
let
(* The test's value is known.  The test loses its use here.
 * Branches that are not taken, and the default when a branch
 * matches, each lose one use.  The constructor's arguments gain
 * one use each, since they are passed to the chosen branch. *)
fun findCase(cases, is, args) =
let
val _ = deleteVarInfo test
val rec loop =
fn [] => (case default of
NONE => Exp.bug
| SOME j => jump(j, []))
| (i, j) :: cases =>
if is i
then (List.foreach
(cases, deleteJumpInfo o #2)
; Option.app(default,
deleteJumpInfo)
; List.foreach(args, addVarInfo)
; jump(j, args))
else (deleteJumpInfo j; loop cases)
in loop cases
end
fun normal() =
Exp.fromTransfer
(Case{cause = cause,
test = VarInfo.var test,
cases = Cases.map(cases, JumpInfo.name),
default = Option.map(default, JumpInfo.name)})
in case (VarInfo.value test, cases) of
(SOME(Value.Const c), Cases.Int cases) =>
(case Const.node c of
Const.Node.Int i =>
findCase(cases, fn i' => i = i', [])
| _ => Error.bug "strange constant for Cases.Int")
| (SOME(Value.Con{con, args}), Cases.Con cases) =>
findCase(cases, fn c => Con.equals(con, c), args)
| (SOME _, _) => Error.bug "strange bind for case test"
| (NONE, _) =>
let
val info as JumpInfo.T{meaning, ...} = Cases.hd cases
(* If all cases are the same, eliminate the case. *)
fun isOk(i: JumpInfo.t): bool =
not(JumpInfo.usesFormal i)
andalso JumpInfo.equals(info, i)
(* The Option.fold ignores its accumulator, so a default branch,
 * when present, simply has to be ok as well. *)
in if (not(JumpInfo.usesFormal info)
andalso Cases.forall(cases, isOk)
andalso Option.fold(default, true, isOk o #1))
then
let
(* The shared meaning uses no formals, so every position
 * is a free variable; each gains one use. *)
fun getFree(ps: Positions.t)
: VarInfo.t list =
List.map
(ps,
fn Position.Formal _ =>
Error.bug "getFree"
| Position.Free x =>
let val i = varInfo x
in addVarInfo i; i
end)
(* Every branch label loses the use this case made of it. *)
fun delete() =
(Cases.foreach(cases, deleteJumpInfo)
; Option.app(default, deleteJumpInfo))
fun rr(f: Var.t list -> Transfer.t, ps)
: Exp.t =
(delete()
; (Exp.fromTransfer
(f(List.map(getFree ps,
VarInfo.var)))))
val _ = deleteVarInfo test
(* A Code meaning cannot reach here: usesFormal is true for
 * Code. *)
in case meaning of
JumpInfo.Code _ => Error.bug "Code?"
| JumpInfo.Jump{dst, args} =>
let
val _ = addJumpInfo dst
val _ = delete()
in jump(dst, getFree args)
end
| JumpInfo.Raise ps => rr(Raise, ps)
| JumpInfo.Return ps => rr(Return, ps)
end
else normal()
end
end
end
val exp = simplifyExp exp
in Exp.clear exp
; exp
end
end
(* Shrink an expression with no globals, without deleting unused bindings. *)
fun shrinkExpNoDelete (e: Exp.t): Exp.t = shrinkExp [] (e, false)
val traceShrinkExp = Trace.trace("shrinkExp", Exp.layout, Exp.layout)
(* From here on, shrinkExp globals is a traced shrinker that deletes unused
 * bindings.  This must stay a val (not a recursive fun), because its body
 * calls the shrinkExp defined above. *)
val shrinkExp =
   fn globals =>
      let
         val shrinkWithGlobals = shrinkExp globals
         fun shrinkDeleting e = shrinkWithGlobals (e, true)
      in
         traceShrinkExp shrinkDeleting
      end
(* Apply simplifyExp and then the shrinker to every function body of the
 * program.  List.revMap is used, so the function list comes out in reverse
 * order. *)
fun simplifyProgram simplifyExp
                    (Program.T{datatypes, globals, functions, main}) =
   let
      val shrinkBody = shrinkExp globals
      fun simplifyFunction {name, args, body, returns} =
         {name = name, args = args,
          body = shrinkBody (simplifyExp body),
          returns = returns}
   in
      Program.T{datatypes = datatypes,
                globals = globals,
                functions = List.revMap (functions, simplifyFunction),
                main = main}
   end
(* Shrink a program without any extra simplification pass. *)
fun shrink program = simplifyProgram (fn e => e) program
end