[MLton] user-level elimination of array bounds checks

Stephen Weeks MLton@mlton.org
Thu, 22 Jan 2004 15:17:05 -0800


I was thinking about how to use the SML type system to express
elimination of array bounds checks and came up with the following.
The idea is to make the bounds check more explicit and separate it
from the operations of sub and update.  In short, suppose we had an
array signature like this.

------------------------------------------------------------
signature ARRAY =
   sig
      structure Elt:
	 sig
	    type 'a t
	 
	    val get: 'a t -> 'a
	    val set: 'a t * 'a -> unit
	 end

      type 'a t
	 
      val elt: 'a t * int -> 'a Elt.t option
   end
------------------------------------------------------------

The idea is that "elt (a, i)" does the bounds check "0 <= i < length
a" and if i is in range, then returns "SOME e".  Then, at a later
time, you can use "get e" instead of "sub (a, i)" and "set (e, x)"
instead of "update (a, i, x)".  Of course, the whole point is that
Elt.get and Elt.set do *not* perform bounds checks, since the type
system will guarantee that they are only passed elements where the
index is in range.
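
To make this concrete, here is a minimal sketch of one way to
implement and use that signature.  The names MINI_ARRAY, MiniArray,
and fromList are made up for the example (the signature above has no
way to create an array, so fromList is an added assumption), and get
and set still go through the checked Array.sub and Array.update where
a real implementation would presumably use unchecked primitives.

```sml
(* Hypothetical names: MINI_ARRAY, MiniArray, fromList. *)
signature MINI_ARRAY =
   sig
      structure Elt:
         sig
            type 'a t

            val get: 'a t -> 'a
            val set: 'a t * 'a -> unit
         end

      type 'a t

      val fromList: 'a list -> 'a t
      val elt: 'a t * int -> 'a Elt.t option
   end

structure MiniArray:> MINI_ARRAY =
   struct
      type 'a t = 'a Array.array

      structure Elt =
         struct
            (* An element is an array paired with an index that elt
             * has already checked. *)
            type 'a t = 'a Array.array * int

            (* Array.sub and Array.update still check bounds; they stand
             * in for unchecked primitives in this sketch. *)
            fun get (a, i) = Array.sub (a, i)
            fun set ((a, i), x) = Array.update (a, i, x)
         end

      val fromList = Array.fromList

      (* The only place where an index is tested against the bounds. *)
      fun elt (a, i) =
         if 0 <= i andalso i < Array.length a
            then SOME (a, i)
         else NONE
   end

val a = MiniArray.fromList [10, 20, 30]
val inRange = MiniArray.elt (a, 1)      (* SOME e *)
val outOfRange = MiniArray.elt (a, 3)   (* NONE *)

(* One test in elt, then a get and a set on the same element. *)
val () =
   case inRange of
      NONE => ()
    | SOME e => MiniArray.Elt.set (e, MiniArray.Elt.get e + 1)
```

Because the ascription is opaque, a client cannot build an Elt.t
except by going through elt, which is what lets get and set trust the
index.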

Now, we can do things with elements instead of array indices to avoid
bounds checks.  For example, a swap function that switches two
elements in an array would normally make four bounds checks (assuming
no optimization, but probably only two with MLton).  But we can swap
two elements without any bounds checks and can use this to eliminate
two of the bounds checks from the usual array swap.

------------------------------------------------------------
functor Swap (A: ARRAY) =
   struct
      open A

      structure Elt =
	 struct
	    open Elt
	       
	    fun swap (e, e') =
	       let
		  val x = get e
		  val _ = set (e, get e')
		  val _ = set (e', x)
	       in
		  ()
	       end

	    fun force eo =
	       case eo of
		  NONE => raise Subscript
		| SOME e => e
	 end

      val force = Elt.force
	 
      fun swap (a, i, i') = Elt.swap (force (elt (a, i)), force (elt (a, i')))
   end
------------------------------------------------------------

Let's flesh out the ARRAY signature a bit.

------------------------------------------------------------
signature ARRAY =
   sig
      type 'a t

      structure Elt:
	 sig
	    type 'a t

	    val equals: 'a t * 'a t -> bool
	    val force: 'a t option -> 'a t (* may raise Subscript *)
	    val get: 'a t -> 'a
	    val index: 'a t -> int
	    val next: 'a t -> 'a t option
	    val prev: 'a t -> 'a t option
	    val set: 'a t * 'a -> unit
	    val swap: 'a t * 'a t -> unit
	 end

      val elt: 'a t * int -> 'a Elt.t option
      val first: 'a t -> 'a Elt.t option
      val last: 'a t -> 'a Elt.t option
   end
------------------------------------------------------------

Now, we can define the usual Array.modify function, and implement it
without any bounds checks.

------------------------------------------------------------
functor Modify (A: ARRAY) =
   struct
      open A
	 
      fun modify (a: 'a t, f) =
	 let
	    fun loop e =
	       (Elt.set (e, f (Elt.get e))
		; (case Elt.next e of
		      NONE => ()
		    | SOME e => loop e))
	 in
	    case elt (a, 0) of
	       NONE => ()
	     | SOME e => loop e
	 end
   end      
------------------------------------------------------------

Compare this to the usual implementation, which takes two bounds
checks per loop iteration.

------------------------------------------------------------
fun modify (a, f) =
   let
      val n = Array.length a
      fun loop i =
	 if i = n
	    then ()
	 else (Array.update (a, i, f (Array.sub (a, i)))
	       ; loop (i + 1))
   in
      loop 0
   end
------------------------------------------------------------
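
As a rough check of that count, here is a sketch that wraps the Basis
operations in counting versions and runs an index-based modify that
advances with loop (i + 1).  The names checks, countedSub, and
countedUpdate are made up for the example; each call to them stands
for one bounds check inside Array.sub or Array.update.

```sml
(* Hypothetical instrumentation: every call to countedSub or
 * countedUpdate corresponds to one bounds check in the Basis Array. *)
val checks = ref 0

fun countedSub (a, i) = (checks := !checks + 1; Array.sub (a, i))
fun countedUpdate (a, i, x) = (checks := !checks + 1; Array.update (a, i, x))

fun modify (a, f) =
   let
      val n = Array.length a
      fun loop i =
         if i = n
            then ()
         else (countedUpdate (a, i, f (countedSub (a, i)))
               ; loop (i + 1))
   in
      loop 0
   end

val a = Array.fromList [1, 2, 3]
val () = modify (a, fn x => x * 2)
val contents = Array.foldr (op ::) [] a   (* [2, 4, 6] *)
val total = !checks                       (* 6: two checks per element *)
```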

There are three reasons we were able to eliminate all the bounds checks.
First, the call to Elt.next replaces what would usually be the test on
the loop variable (i = n): it both performs the test and produces an
element.  Second, we can use that element to get the value.  Finally,
once we have our hands on an element to get (sub) the value, we can
also set (update) the same element without another bounds check.

Getting and then setting the same element is a very common idiom, and
although MLton's optimizer will certainly eliminate the second
(redundant) bounds check, it's nice to be able to do it in source
code.

This approach doesn't only eliminate redundant tests that MLton's
optimizer would already get.  It can be used to eliminate *all* the
bounds checks from insertion sort.  Here's the code.

------------------------------------------------------------
functor InsertionSort (A: ARRAY) =
   struct
      open A
	 
      fun insertionSort (a: 'a t, op <= : 'a * 'a -> bool): unit =
	 let
	    fun loop (i: 'a Elt.t): unit =
	       let
		  val t = Elt.get i
		  fun sift (j: 'a Elt.t): 'a Elt.t =
		     case Elt.prev j of
			NONE => j
		      | SOME j' =>
			   let
			      val z = Elt.get j'
			   in
			      if t <= z
				 then (Elt.set (j, z); sift j')
			      else j
			   end
		  val _ = Elt.set (sift i, t)
	       in
		  case Elt.next i of
		     NONE => ()
		   | SOME i => loop i
	       end
	 in
	    case elt (a, 1) of
	       NONE => ()
	     | SOME e => loop e
	 end
   end
------------------------------------------------------------

This is patterned off the code currently in the MLton library

	lib/mlton/basic/insertion-sort.sml

That code does two bounds checks per iteration of sift (corresponding
to the get and the set in sift), as well as two per iteration of the
outer loop.  Redundant test elimination would not get any of these,
since the redundancy is across calls and returns of functions.
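
For comparison, here is a rough index-based insertion sort over the
Basis Array.  It is a sketch written for illustration, not the code in
the MLton library, but it has the same shape as the element version
above, and the comments mark where each bounds check happens.

```sml
(* Index-based insertion sort; a sketch for comparison, not the
 * MLton library code. *)
fun insertionSort (a: 'a array, op <= : 'a * 'a -> bool): unit =
   let
      val n = Array.length a
      fun loop i =
         if i >= n
            then ()
         else
            let
               val t = Array.sub (a, i)              (* outer check 1 *)
               fun sift j =
                  if j = 0
                     then j
                  else
                     let
                        val z = Array.sub (a, j - 1) (* sift check 1 *)
                     in
                        if t <= z
                           then (Array.update (a, j, z) (* sift check 2 *)
                                 ; sift (j - 1))
                        else j
                     end
               val _ = Array.update (a, sift i, t)   (* outer check 2 *)
            in
               loop (i + 1)
            end
   in
      loop 1
   end

val a = Array.fromList [3, 1, 2]
val () = insertionSort (a, op <=)
val sorted = Array.foldr (op ::) [] a   (* [1, 2, 3] *)
```

The index returned by sift is checked again by the final update, even
though sift only ever returns an index it has already visited.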

Insertion sort is ideal for this approach.  How about quick sort?
Here's the code, again based on the MLton library.

------------------------------------------------------------
functor QuickSort (structure A: ARRAY
		   val randInt: int * int -> int) =
   struct
      structure A = InsertionSort (A)
      open A

      val force = Elt.force

      fun quickSort (a: 'a t, op <= : 'a * 'a -> bool): unit =
	 let
	    val cutoff = 20
	    fun qsort (l: 'a Elt.t, u: 'a Elt.t): unit =
	       if Elt.index u - Elt.index l > cutoff
		  then
		     let
			val _ =
			   Elt.swap (l, force (elt (a, randInt (Elt.index l,
								Elt.index u))))
			val t = Elt.get l
			fun loop (i: 'a Elt.t, m: 'a Elt.t): 'a Elt.t =
			   let
			      val m =
				 if Elt.get i <= t
				    then
				       let
					  val m = force (Elt.next m)
					  val _ = Elt.swap (m, i)
				       in
					  m
				       end
				 else m
			   in
			      if Elt.equals (i, u)
				 then m
			      else loop (force (Elt.next i), m)
			   end
			val m = loop (force (Elt.next l), l)
			val _ = Elt.swap (l, m)
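			(* m may be the first or last element of the array,
			 * in which case there is nothing on that side. *)
			val _ =
			   case Elt.prev m of
			      NONE => ()
			    | SOME p => qsort (l, p)
			val _ =
			   case Elt.next m of
			      NONE => ()
			    | SOME n => qsort (n, u)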
		     in ()
		     end
	       else ()
	    val _ =
	       case (first a, last a) of
		  (SOME l, SOME u) => qsort (l, u)
		| _ => () (* empty array *)
	    val _ = insertionSort (a, op <=)
	 in
	    ()
	 end
   end
------------------------------------------------------------

Things aren't quite so nice as with insertion sort.  But we still have
eliminated some bounds checks.  For example, we only do one test for
the swap of the random partition element, instead of two (or four with
no optimization).  That test is hard to eliminate, since it depends on
randInt returning an element in range.

Within the main partitioning loop, we only do one test for each swap
instead of the usual two (or four :-).  The intuition for why we are
able to eliminate the test here is that the use of elements has
captured the invariant that the arguments to loop are always valid
array indices.  This lets us reuse the bounds check when we bump the
index (force (Elt.next m)) in later iterations of the loop.

That's all my examples for now.  As to practicality, I'm not sure.
Here's how I'd implement ARRAY.

------------------------------------------------------------
structure Array:> ARRAY =
   struct
      type 'a t = 'a Array.array
	 
      structure Elt =
	 struct
	    datatype 'a t = T of 'a Array.array * int
	 end
      
      fun elt (a, i) =
	 if 0 <= i andalso i < Array.length a
	    then SOME (Elt.T (a, i))
	 else NONE

      structure Elt =
	 struct
	    open Elt

	    fun equals (T (_, i), T (_, i')) = i = i'

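	    (* Array.sub and Array.update still check bounds; the intent is
	     * for get and set to use unchecked primitives instead. *)
	    fun get (T (a, i)) = Array.sub (a, i)

	    fun set (T (a, i), x) = Array.update (a, i, x)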

	    fun index (T (_, i)) = i

	    fun next (T (a, i)) = elt (a, i + 1)

	    fun prev (T (a, i)) = elt (a, i - 1)

	    fun force (eo: 'a t option): 'a t =
	       case eo of
		  NONE => raise Subscript
		| SOME e => e

	    fun swap (e, e') =
	       let
		  val x = get e
		  val _ = set (e, get e')
		  val _ = set (e', x)
	       in
		  ()
	       end
	 end

      fun first a = elt (a, 0)

      fun last a = elt (a, Array.length a - 1)

      val force = Elt.force
	 
      fun swap (a, i, i') = Elt.swap (force (elt (a, i)), force (elt (a, i')))
   end
------------------------------------------------------------
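
To exercise first, last, next, and prev from the client side, here is
a small sketch that sums an int array from the last element back to
the first.  So that it runs on its own, it restates a cut-down version
of the structure under the made-up names CURSOR_ARRAY and CursorArray,
represents elements as plain pairs rather than with the T constructor,
and adds a fromList that is not in the signature above.

```sml
(* Hypothetical names: CURSOR_ARRAY, CursorArray, fromList, sumFromEnd. *)
signature CURSOR_ARRAY =
   sig
      type 'a t

      structure Elt:
         sig
            type 'a t

            val get: 'a t -> 'a
            val index: 'a t -> int
            val next: 'a t -> 'a t option
            val prev: 'a t -> 'a t option
         end

      val fromList: 'a list -> 'a t
      val first: 'a t -> 'a Elt.t option
      val last: 'a t -> 'a Elt.t option
   end

structure CursorArray:> CURSOR_ARRAY =
   struct
      type 'a t = 'a Array.array

      (* The single bounds test; not exported in this cut-down version. *)
      fun elt (a, i) =
         if 0 <= i andalso i < Array.length a
            then SOME (a, i)
         else NONE

      structure Elt =
         struct
            type 'a t = 'a Array.array * int

            (* Array.sub stands in for an unchecked primitive. *)
            fun get (a, i) = Array.sub (a, i)
            fun index (_, i) = i
            fun next (a, i) = elt (a, i + 1)
            fun prev (a, i) = elt (a, i - 1)
         end

      val fromList = Array.fromList
      fun first a = elt (a, 0)
      fun last a = elt (a, Array.length a - 1)
   end

(* Walk from the last element to the first, adding as we go. *)
fun sumFromEnd (a: int CursorArray.t): int =
   let
      fun loop (e, acc) =
         let
            val acc = acc + CursorArray.Elt.get e
         in
            case CursorArray.Elt.prev e of
               NONE => acc
             | SOME e' => loop (e', acc)
         end
   in
      case CursorArray.last a of
         NONE => 0
       | SOME e => loop (e, 0)
   end

val total = sumFromEnd (CursorArray.fromList [1, 2, 3, 4])   (* 10 *)
val empty = sumFromEnd (CursorArray.fromList [])             (* 0 *)
```

On an empty array last returns NONE, so the loop is never entered and
no element is ever created.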

One unanswered question is whether MLton's optimizer will be able to
eliminate all the allocation of Elt.t's.  If not, then this approach
probably isn't worth it, since allocation is more expensive than a
bounds check.  If so, then this might be worthwhile, especially if we
can figure out how to push the approach to eliminate even more checks.