matrix.mlton

Stephen Weeks MLton@sourcelight.com
Tue, 10 Jul 2001 18:27:47 -0700


Here is a new, faster version of matrix.mlton.


(* Translated from matrix.ocaml. *)

(* Add one to the integer stored in the given ref cell. *)
fun incr (cell as ref n) = cell := n + 1
(* for (start, stop, f) applies f to start, start + 1, ..., stop in
   ascending order, discarding each result.  Nothing is applied when
   start > stop. *)
fun for (start, stop, f) =
   let
      fun go n = if n <= stop then (f n; go (n + 1)) else ()
   in
      go start
   end

(* A minimal two-dimensional array represented as an array of row arrays. *)
structure Array2 =
   struct
      datatype 'a t = T of 'a array array

      local
         (* Fetch the row array at index r. *)
         fun rowOf (T rows, r) = Array.sub (rows, r)
      in
         (* sub (m, r, c): the element at row r, column c. *)
         fun sub (m, r, c) = Array.sub (rowOf (m, r), c)

         (* subr (m, r): look up row r immediately and return a function
            from a column index to the element in that row. *)
         fun subr (m, r) =
            let val row = rowOf (m, r)
            in fn c => Array.sub (row, c)
            end

         (* update (m, r, c, x): overwrite the element at row r, column c. *)
         fun update (m, r, c, x) = Array.update (rowOf (m, r), c, x)
      end

      (* array (r, c, x): r freshly allocated rows, each holding c copies
         of x.  Every row is its own array, so rows never alias. *)
      fun array (r, c, x) =
         let
            fun freshRow _ = Array.array (c, x)
         in
            T (Array.tabulate (r, freshRow))
         end
   end
(* Unqualified shorthands for the Array2 element accessors. *)
val (sub, update) = (Array2.sub, Array2.update)

(* Number of rows and of columns of the square matrices built in main.
   Note: this binding shadows the Basis top-level `size` on strings. *)
val size : int = 30

(* mkmatrix (rows, cols) builds a rows x cols integer matrix whose cells
   hold 1, 2, 3, ... assigned in row-major order. *)
fun mkmatrix (rows, cols) =
   let
      val next = ref 1
      val lastCol = cols - 1
      val matrix = Array2.array (rows, cols, 0)
      (* Store the running counter into each cell of row i, left to right. *)
      fun fillRow i =
         for (0, lastCol, fn j => (update (matrix, i, j, !next); incr next))
   in
      for (0, rows - 1, fillRow);
      matrix
   end

(* mmult (rows, cols, m1, m2, m3) writes a matrix product into m3:
   for every 0 <= i < rows and 0 <= j < cols,
      m3[i][j] := sum for k = rows - 1 down to 0 of m1[i][k] * m2[k][j].
   The inner (k) dimension is `rows`, so columns 0 .. rows - 1 of m1 and
   rows 0 .. rows - 1 of m2 are read.  m1 and m2 are not modified.
   Returns unit. *)
fun mmult (rows, cols, m1, m2, m3) =
   let
      val last_col = cols - 1
      val last_row = rows - 1
   in
      for (0, last_row, fn i =>
           let
              (* Row i of m1 does not depend on j, so fetch it once per
                 row instead of once per output cell. *)
              val m1i = Array2.subr (m1, i)
              (* Accumulate the dot product of row i of m1 with column j
                 of m2, walking k downwards to 0. *)
              fun dot (j, k, sum) =
                 if k < 0
                    then sum
                 else dot (j, k - 1, sum + m1i k * sub (m2, k, j))
           in
              for (0, last_col, fn j =>
                   update (m3, i, j, dot (j, last_row, 0)))
           end)
   end

(* Parse s as an integer; yields 0 when Int.fromString finds no integer. *)
fun atoi s = getOpt (Int.fromString s, 0);
(* Print every string in the list, in order, followed by a newline. *)
fun printl strs = (List.app print strs; print "\n");

(* Benchmark driver.  The first command-line argument (default "1") is
   parsed with atoi to get n.  Two size x size matrices are built with
   mkmatrix and multiplied into a third: n - 1 times in the loop, then once
   more unconditionally, so at least one multiplication always runs.
   Four cells of the product are printed on one line, separated by spaces.
   Returns OS.Process.success. *)
fun main (name, args) =
  let
     val n = atoi (case args of first :: _ => first | [] => "1")
     val a = mkmatrix (size, size)
     val b = mkmatrix (size, size)
     val product = Array2.array (size, size, 0)
     fun multiply () = mmult (size, size, a, b, product)
     (* Decimal text of one cell of the product matrix. *)
     fun cell (r, c) = Int.toString (sub (product, r, c))
  in
     for (1, n - 1, fn _ => multiply ());
     multiply ();
     printl [cell (0, 0), " ", cell (2, 3), " ", cell (3, 2), " ", cell (4, 4)];
     OS.Process.success
  end

(* Program entry point: run main on the command line; its status result
   is discarded. *)
val _ =
   let
      val programName = CommandLine.name ()
      val arguments = CommandLine.arguments ()
   in
      main (programName, arguments)
   end