slow matrix multiply
Stephen Weeks
MLton@sourcelight.com
Tue, 10 Jul 2001 18:02:49 -0700
> One thing I was surprised by was that not only loop nests was slow: so was
> matrix multiply. I didn't look at it, but it was scary that it was slow
> as well as the nested loop.
I looked into matrix.{gcc, mlton, ocaml}. Here are the running times I see.
gcc 1.29
ocaml 1.40
mlton 4.84
Here's the source and annotated assembly for the hot loop for each of the three
compilers. Ocaml and gcc do better for several reasons.
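* they keep the loop index and sum in registers, whereas MLton reloads k and updates sum in memory on every iteration
* MLton completely recomputes the array offset (row * 30 + column) for each subscript
* MLton emits extra instructions (a cltd before every imull, register shuffling via movl/xchgl, ...)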
--------------------------------------------------------------------------------
MLton
--------------------------------------------------------------------------------
(* Inner loop of the multiply: counts k down, adding
 * sub (m1, i, k) * sub (m2, k, j) to sum, and returns sum once k < 0.
 * m1, m2, i and j are free here; they are bound outside this excerpt. *)
fun loop (k, sum) =
if k < 0
then sum
else loop (k - 1, sum + sub (m1, i, k) * sub (m2, k, j))
# Hot loop MLton emits for `loop` (AT&T syntax, 32-bit x86).
# Every value lives in a memory slot addressed off %edi:
#   144(%edi) = m1    160(%edi) = m2    188(%edi) = i
#   192(%edi) = j     196(%edi) = sum   200(%edi) = k
# %esp and %ebp are used as ordinary scratch registers in this loop.
# Both subscripts are computed from scratch as row * 30 + column.
loop_51:
movl (200*1)(%edi),%esp # %esp = k
cmpl $0,%esp # if k < 0
jl L_229 # then leave the loop (L_229 is outside this excerpt)
movl %esp,%ebp # %ebp = k
decl %ebp # %ebp = k - 1
movl (188*1)(%edi),%edx # %edx = i
movl %edx,%ecx # %ecx = i
movl %ecx,%eax # %eax = i
movl $30,%ecx # %ecx = 30
cltd # sign-extend %eax into %edx:%eax; one-operand imull does not read %edx, so this is redundant
imull %ecx # %eax = i * 30
addl %esp,%eax # %eax = i * 30 + k
movl %ebp,(200*1)(%edi) # store k - 1
xchgl %esp,%eax # %eax = k %esp = i * 30 + k
movl $30,%ebp # %ebp = 30
cltd # redundant for the same reason as above
imull %ebp # %eax = k * 30
addl (192*1)(%edi),%eax # %eax = k * 30 + j
movl (144*1)(%edi),%ebp # %ebp = m1
movl %esp,%edx # %edx = i * 30 + k
movl (%ebp,%edx,4),%esp # %esp = sub (m1, i, k)
xchgl %esp,%eax # %eax = sub(m1, i, k) %esp = k * 30 + j
movl (160*1)(%edi),%ebp # %ebp = m2
movl %esp,%ecx # %ecx = k * 30 + j
cltd # redundant for the same reason as above
imull (%ebp,%ecx,4) # %eax = sub (m1, i, k) * sub (m2, k, j)
addl %eax,(196*1)(%edi) # sum = sum + ...
jmp loop_51 # next iteration; k and sum are reloaded from memory
--------------------------------------------------------------------------------
ocaml
--------------------------------------------------------------------------------
(* Adds m1i.(k) * m2.(k).(j) to the running total v for k counting
   down to 0, and returns v once k < 0. The recursive call is in tail
   position. *)
let rec inner_loop k v m1i m2 j =
if k < 0 then v
else inner_loop (k - 1) (v + m1i.(k) * m2.(k).(j)) m1i m2 j
# Hot loop the OCaml compiler emits for inner_loop (AT&T syntax, 32-bit x86).
# Integers are tagged: n is held as 2n + 1. Hence $1 is 0, adding $-2
# subtracts 1, and the -2(base, index, 2) operands give a byte offset of
# 2 * (2n + 1) - 2 = 4n.
.L107: # %eax = k %ebx = v %ecx = m1i %edx = m2 %esi = j
cmpl $1, %eax # if k >= 0 ($1 is tagged 0)
jge .L106 # then run the loop body
movl %ebx, %eax # else return v
ret
.align 16
.L106:
movl -2(%edx, %eax, 2), %edi # %edi = m2.(k)
movl -2(%edi, %esi, 2), %ebp # %ebp = m2.(k).(j)
sarl $1, %ebp # untag: %ebp = plain value of m2.(k).(j)
movl -2(%ecx, %eax, 2), %edi # %edi = m1i.(k)
decl %edi # drop the tag bit: %edi = 2 * m1i.(k)
imull %ebp, %edi # m1i.(k) * m2.(k).(j), as 2 * product
addl %edi, %ebx # v + m1i.(k) * m2.(k).(j), v stays tagged
addl $-2, %eax # %eax = k - 1
jmp .L107
--------------------------------------------------------------------------------
gcc
--------------------------------------------------------------------------------
/* Inner loop: add m1[i][k] * m2[k][j] to val for k = 0 .. cols-1. */
for (k=0; k<cols; k++)
val += m1[i][k] * m2[k][j];
# Hot loop gcc emits for the k loop (AT&T syntax, 32-bit x86).
# The loop bound `cols` shows up as the constant 30.
# NOTE(review): the two-level load below indexes by k first and by %ebx
# second, which matches m2[k][j] in the source rather than m1[i][k].
# So %ebx is presumably j, 16(%esp) is m2, and 12(%esp) is the row
# pointer m1[i] hoisted out of the loop. Confirm against the full listing.
.L125: # %ebx = j %ecx = sum %edx = k
movl 16(%esp), %edi # %edi = m2
movl (%edi,%edx,4), %eax # %eax = m2[k]
movl (%eax,%ebx,4), %eax # %eax = m2[k][j]
movl 12(%esp), %edi # %edi = m1[i]
imull (%edi,%edx,4), %eax # %eax = m1[i][k] * m2[k][j]
incl %edx # %edx = k + 1
addl %eax, %ecx # %ecx = sum + m1[i][k] * m2[k][j]
cmpl $30, %edx # if k < 30
jl .L125 # then next iteration