[MLton-commit] r6351

Matthew Fluet fluet at mlton.org
Thu Jan 24 12:28:32 PST 2008


Formatting
----------------------------------------------------------------------

U   mlton/trunk/mlton/ssa/poly-equal.fun

----------------------------------------------------------------------

Modified: mlton/trunk/mlton/ssa/poly-equal.fun
===================================================================
--- mlton/trunk/mlton/ssa/poly-equal.fun	2008-01-24 12:51:34 UTC (rev 6350)
+++ mlton/trunk/mlton/ssa/poly-equal.fun	2008-01-24 20:28:30 UTC (rev 6351)
@@ -1,4 +1,4 @@
-(* Copyright (C) 1999-2007 Henry Cejtin, Matthew Fluet, Suresh
+(* Copyright (C) 1999-2008 Henry Cejtin, Matthew Fluet, Suresh
  *    Jagannathan, and Stephen Weeks.
  * Copyright (C) 1997-2000 NEC Research Institute.
  *
@@ -17,7 +17,7 @@
  * This pass implements polymorphic equality.
  *
  * For each datatype tycon and vector type, it builds an equality function and
- * translates calls to = into calls to that function.
+ * translates calls to MLton_equal into calls to that function.
  *
  * Also generates calls to primitives intInfEqual and wordEqual.
  *
@@ -62,7 +62,7 @@
                 default = NONE,
                 ty = Type.bool}
 
-      fun disjoin (e1: t, e2:t ): t =
+      fun disjoin (e1: t, e2:t): t =
          casee {test = e1,
                 cases = Con (Vector.new2 ({con = Con.truee,
                                            args = Vector.new0 (),
@@ -72,11 +72,20 @@
                                            body = e2})),
                 default = NONE,
                 ty = Type.bool}
+
+      fun wordEqual (e1: t, e2: t, s): t =
+         primApp {prim = Prim.wordEqual s,
+                  targs = Vector.new0 (),
+                  args = Vector.new2 (e1, e2),
+                  ty = Type.bool}
    end
 
 fun polyEqual (Program.T {datatypes, globals, functions, main}) =
    let
       val shrink = shrinkFunction {globals = globals}
+      val {get = varInfo: Var.t -> {isConst: bool},
+           set = setVarInfo, ...} =
+         Property.getSetOnce (Var.plist, Property.initConst {isConst = false})
       val {get = tyconInfo: Tycon.t -> {isEnum: bool,
                                         cons: {con: Con.t,
                                                args: Type.t vector} vector},
@@ -85,9 +94,6 @@
          (Tycon.plist, Property.initRaise ("PolyEqual.info", Tycon.layout))
       val isEnum = #isEnum o tyconInfo
       val tyconCons = #cons o tyconInfo
-      val {get = varInfo: Var.t -> {isConst: bool},
-           set = setVarInfo, ...} =
-         Property.getSetOnce (Var.plist, Property.initConst {isConst = false})
       val _ =
          Vector.foreach
          (datatypes, fn Datatype.T {tycon, cons} =>
@@ -101,10 +107,11 @@
          Property.getSet (Tycon.plist, Property.initConst NONE)
       val {get = getVectorEqualFunc: Type.t -> Func.t option, 
            set = setVectorEqualFunc,
-           destroy = destroyType} =
+           destroy = destroyVectorEqualFunc} =
          Property.destGetSet (Type.plist, Property.initConst NONE)
       val returns = SOME (Vector.new1 Type.bool)
       val seqIndexWordSize = WordSize.seqIndex ()
+      val seqIndexTy = Type.word seqIndexWordSize
       fun newFunction z =
          List.push (newFunctions,
                     Function.profile (shrink (Function.new z),
@@ -158,16 +165,16 @@
                                  default = if 1 = Vector.length cons
                                               then NONE
                                            else SOME Dexp.falsee,
-                                              cases =
-                                              Dexp.Con
-                                              (Vector.new1
-                                               {con = con,
-                                                args = ys,
-                                                body =
-                                                Vector.fold2
-                                                (xs, ys, Dexp.truee,
-                                                 fn ((x, ty), (y, _), de) =>
-                                                 Dexp.conjoin (de, equal (x, y, ty)))})}}
+                                 cases = 
+                                 Dexp.Con
+                                 (Vector.new1
+                                  {con = con,
+                                   args = ys,
+                                   body =
+                                   Vector.fold2
+                                   (xs, ys, Dexp.truee,
+                                    fn ((x, ty), (y, _), de) =>
+                                    Dexp.conjoin (de, equal (x, y, ty)))})}}
                             end))})
                   val (start, blocks) = Dexp.linearize (body, Handler.Caller)
                   val blocks = Vector.fromList blocks
@@ -195,30 +202,38 @@
                   val loop = Func.newString "vectorEqualLoop"
                   val vty = Type.vector ty
                   local
-                     val v1 = (Var.newNoname (), vty)
-                     val v2 = (Var.newNoname (), vty)
-                     val args = Vector.new2 (v1, v2)
-                     val dv1 = Dexp.var v1
-                     val dv2 = Dexp.var v2
+                     val vec1 = (Var.newNoname (), vty)
+                     val vec2 = (Var.newNoname (), vty)
+                     val args = Vector.new2 (vec1, vec2)
+                     val dvec1 = Dexp.var vec1
+                     val dvec2 = Dexp.var vec2
+                     val len1 = (Var.newNoname (), seqIndexTy)
+                     val dlen1 = Dexp.var len1
+                     val len2 = (Var.newNoname (), seqIndexTy)
+                     val dlen2 = Dexp.var len2
+
                      val body =
                         let
-                          fun length x =
-                             Dexp.primApp {prim = Prim.vectorLength,
-                                           targs = Vector.new1 ty,
-                                           args = Vector.new1 x,
-                                           ty = Type.word seqIndexWordSize}
+                           fun length dvec =
+                              Dexp.primApp {prim = Prim.vectorLength,
+                                            targs = Vector.new1 ty,
+                                            args = Vector.new1 dvec,
+                                            ty = Type.word seqIndexWordSize}
                         in
                            Dexp.disjoin
-                           (Dexp.eq (Dexp.var v1, Dexp.var v2, vty),
-                            Dexp.conjoin
-                            (Dexp.eq (length dv1, length dv2, 
-                                      Type.word seqIndexWordSize),
-                             Dexp.call
-                             {func = loop,
-                              args = (Vector.new4 
-                                      (Dexp.word (WordX.zero seqIndexWordSize),
-                                       length dv1, dv1, dv2)),
-                              ty = Type.bool}))
+                           (Dexp.eq (dvec1, dvec2, vty),
+                            Dexp.lett
+                            {decs = [{var = #1 len1, exp = length dvec1},
+                                     {var = #1 len2, exp = length dvec2}],
+                             body =
+                             Dexp.conjoin
+                             (Dexp.wordEqual (dlen1, dlen2, seqIndexWordSize),
+                              Dexp.call
+                              {func = loop,
+                               args = (Vector.new4 
+                                       (dvec1, dvec2, dlen1,
+                                        Dexp.word (WordX.zero seqIndexWordSize))),
+                              ty = Type.bool})})
                         end
                      val (start, blocks) = Dexp.linearize (body, Handler.Caller)
                      val blocks = Vector.fromList blocks
@@ -233,33 +248,34 @@
                                      start = start}
                   end
                   local
-                     val i = (Var.newNoname (), Type.word seqIndexWordSize)
-                     val len = (Var.newNoname (), Type.word seqIndexWordSize)
-                     val v1 = (Var.newNoname (), vty)
-                     val v2 = (Var.newNoname (), vty)
-                     val args = Vector.new4 (i, len, v1, v2)
+                     val vec1 = (Var.newNoname (), vty)
+                     val vec2 = (Var.newNoname (), vty)
+                     val len = (Var.newNoname (), seqIndexTy)
+                     val i = (Var.newNoname (), seqIndexTy)
+                     val args = Vector.new4 (vec1, vec2, len, i)
+                     val dvec1 = Dexp.var vec1
+                     val dvec2 = Dexp.var vec2
+                     val dlen = Dexp.var len
                      val di = Dexp.var i
-                     val dlen = Dexp.var len
-                     val dv1 = Dexp.var v1
-                     val dv2 = Dexp.var v2
                      val body =
                         let
-                           fun sub (v, i) =
+                           fun sub (dvec, di) =
                               Dexp.primApp {prim = Prim.vectorSub,
                                             targs = Vector.new1 ty,
-                                            args = Vector.new2 (v, i),
+                                            args = Vector.new2 (dvec, di),
                                             ty = ty}
                            val args =
                               Vector.new4 
-                              (Dexp.add
+                              (dvec1, dvec2, dlen, 
+                               Dexp.add
                                (di, Dexp.word (WordX.one seqIndexWordSize), 
-                                seqIndexWordSize),
-                               dlen, dv1, dv2)
+                                seqIndexWordSize))
                         in
                            Dexp.disjoin 
-                           (Dexp.eq (di, dlen, Type.word seqIndexWordSize),
+                           (Dexp.wordEqual
+                            (di, dlen, seqIndexWordSize),
                             Dexp.conjoin
-                            (equalExp (sub (dv1, di), sub (dv2, di), ty),
+                            (equalExp (sub (dvec1, di), sub (dvec2, di), ty),
                              Dexp.call {args = args,
                                         func = loop,
                                         ty = Type.bool}))
@@ -280,17 +296,17 @@
                   name
                end
       and equalExp (e1: Dexp.t, e2: Dexp.t, ty: Type.t): Dexp.t =
-         Dexp.name (e1, fn x1 => Dexp.name (e2, fn x2 => equal (x1, x2, ty)))
+         Dexp.name (e1, fn x1 => 
+         Dexp.name (e2, fn x2 => equal (x1, x2, ty)))
       and equal (x1: Var.t, x2: Var.t, ty: Type.t): Dexp.t =
          let
             val dx1 = Dexp.var (x1, ty)
             val dx2 = Dexp.var (x2, ty)
-            fun primWithArgs (p, targs, dx1, dx2) =
+            fun prim (p, targs) =
                Dexp.primApp {prim = p,
                              targs = targs, 
                              args = Vector.new2 (dx1, dx2),
                              ty = Type.bool}
-            fun prim (p, targs) = primWithArgs (p, targs, dx1, dx2)
             fun eq () = prim (Prim.eq, Vector.new1 ty)
             fun hasConstArg () = #isConst (varInfo x1) orelse #isConst (varInfo x2)
          in
@@ -303,9 +319,10 @@
                   else Dexp.call {func = equalFunc tycon,
                                   args = Vector.new2 (dx1, dx2),
                                   ty = Type.bool}
-             | Type.IntInf => if hasConstArg ()
-                                 then eq ()
-                              else prim (Prim.intInfEqual, Vector.new0 ())
+             | Type.IntInf => 
+                  if hasConstArg ()
+                     then eq ()
+                  else prim (Prim.intInfEqual, Vector.new0 ())
              | Type.Real rs =>
                   let
                      val ws = WordSize.fromBits (RealSize.bits rs)
@@ -316,8 +333,7 @@
                          args = Vector.new1 dx,
                          ty = Type.word ws}
                   in
-                     primWithArgs (Prim.wordEqual ws, Vector.new0 (),
-                                   toWord dx1, toWord dx2)
+                     Dexp.wordEqual (toWord dx1, toWord dx2, ws)
                   end
              | Type.Ref _ => eq ()
              | Type.Thread => eq ()
@@ -330,8 +346,8 @@
                            then Dexp.truee
                         else let
                                 val ty = Vector.sub (tys, i)
-                                fun select tuple =
-                                   Dexp.select {tuple = tuple,
+                                fun select dx =
+                                   Dexp.select {tuple = dx,
                                                 offset = i,
                                                 ty = ty}
                              in
@@ -513,7 +529,7 @@
                     globals = globals,
                     functions = (!newFunctions) @ functions,
                     main = main}
-      val _ = destroyType ()
+      val _ = destroyVectorEqualFunc ()
       val _ = Program.clearTop program
    in
       program




More information about the MLton-commit mailing list