118 lines
3.9 KiB
Julia
118 lines
3.9 KiB
Julia
"""
|
|
Strassen's matrix multiplication algorithm.
|
|
Use dynamic padding in order to reduce required auxiliary memory.
|
|
"""
|
|
function strassen(x::Matrix, y::Matrix)
|
|
# Check that the sizes of these matrices match.
|
|
(r1, c1) = size(x)
|
|
(r2, c2) = size(y)
|
|
if c1 != r2
|
|
error("Multiplying $r1 x $c1 and $r2 x $c2 matrix: dimensions do not match.")
|
|
end
|
|
|
|
# Put a matrix into the top left of a matrix of zeros.
|
|
# `rows` and `cols` are the dimensions of the output matrix.
|
|
function embed(mat, rows, cols)
|
|
# If the matrix is already of the right dimensions, don't allocate new memory.
|
|
(r, c) = size(mat)
|
|
if (r, c) == (rows, cols)
|
|
return mat
|
|
end
|
|
|
|
# Pad the matrix with zeros to be the right size.
|
|
out = zeros(Int, rows, cols)
|
|
out[1:r, 1:c] = mat
|
|
out
|
|
end
|
|
|
|
# Make sure both matrices are the same size.
|
|
# This is exclusively for simplicity:
|
|
# this algorithm can be implemented with matrices of different sizes.
|
|
r = max(r1, r2); c = max(c1, c2)
|
|
x = embed(x, r, c)
|
|
y = embed(y, r, c)
|
|
|
|
# Our recursive multiplication function.
|
|
function block_mult(a, b, rows, cols)
|
|
# For small matrices, resort to naive multiplication.
|
|
# if rows <= 128 || cols <= 128
|
|
if rows == 1 && cols == 1
|
|
# if rows == 2 && cols == 2
|
|
return a * b
|
|
end
|
|
|
|
# Apply dynamic padding.
|
|
if rows % 2 == 1 && cols % 2 == 1
|
|
a = embed(a, rows + 1, cols + 1)
|
|
b = embed(b, rows + 1, cols + 1)
|
|
elseif rows % 2 == 1
|
|
a = embed(a, rows + 1, cols)
|
|
b = embed(b, rows + 1, cols)
|
|
elseif cols % 2 == 1
|
|
a = embed(a, rows, cols + 1)
|
|
b = embed(b, rows, cols + 1)
|
|
end
|
|
|
|
half_rows = Int(size(a, 1) / 2)
|
|
half_cols = Int(size(a, 2) / 2)
|
|
|
|
# Subdivide input matrices.
|
|
a11 = a[1:half_rows, 1:half_cols]
|
|
b11 = b[1:half_rows, 1:half_cols]
|
|
|
|
a12 = a[1:half_rows, half_cols+1:size(a, 2)]
|
|
b12 = b[1:half_rows, half_cols+1:size(a, 2)]
|
|
|
|
a21 = a[half_rows+1:size(a, 1), 1:half_cols]
|
|
b21 = b[half_rows+1:size(a, 1), 1:half_cols]
|
|
|
|
a22 = a[half_rows+1:size(a, 1), half_cols+1:size(a, 2)]
|
|
b22 = b[half_rows+1:size(a, 1), half_cols+1:size(a, 2)]
|
|
|
|
# Compute intermediate values.
|
|
multip(x, y) = block_mult(x, y, half_rows, half_cols)
|
|
m1 = multip(a11 + a22, b11 + b22)
|
|
m2 = multip(a21 + a22, b11)
|
|
m3 = multip(a11, b12 - b22)
|
|
m4 = multip(a22, b21 - b11)
|
|
m5 = multip(a11 + a12, b22)
|
|
m6 = multip(a21 - a11, b11 + b12)
|
|
m7 = multip(a12 - a22, b21 + b22)
|
|
|
|
# Combine intermediate values into the output.
|
|
c11 = m1 + m4 - m5 + m7
|
|
c12 = m3 + m5
|
|
c21 = m2 + m4
|
|
c22 = m1 - m2 + m3 + m6
|
|
|
|
# Crop output to the desired size (undo dynamic padding).
|
|
out = [c11 c12; c21 c22]
|
|
out[1:rows, 1:cols]
|
|
end
|
|
|
|
block_mult(x, y, r, c)
|
|
end
|
|
|
|
const A = [[1, 2] [3, 4]]
|
|
const B = [[5, 6] [7, 8]]
|
|
const C = [[1, 1, 1, 1] [2, 4, 8, 16] [3, 9, 27, 81] [4, 16, 64, 256]]
|
|
const D = [[4, -3, 4/3, -1/4] [-13/3, 19/4, -7/3, 11/24] [3/2, -2, 7/6, -1/4] [-1/6, 1/4, -1/6, 1/24]]
|
|
const E = [[1, 2, 3, 4] [5, 6, 7, 8] [9, 10, 11, 12] [13, 14, 15, 16]]
|
|
const F = [[1, 0, 0, 0] [0, 1, 0, 0] [0, 0, 1, 0] [0, 0, 0, 1]]
|
|
|
|
""" For pretty printing: change matrix to integer if it is within 0.00000001 of an integer """
|
|
intprint(s, mat) = println(s, map(x -> Int(round(x, digits=8)), mat)')
|
|
|
|
intprint("Regular multiply: ", A' * B')
|
|
intprint("Strassen multiply: ", strassen(Matrix(A'), Matrix(B')))
|
|
intprint("Regular multiply: ", C * D)
|
|
intprint("Strassen multiply: ", strassen(C, D))
|
|
intprint("Regular multiply: ", E * F)
|
|
intprint("Strassen multiply: ", strassen(E, F))
|
|
|
|
const r = sqrt(2)/2
|
|
const R = [[r, r] [-r, r]]
|
|
|
|
intprint("Regular multiply: ", R * R)
|
|
intprint("Strassen multiply: ", strassen(R,R))
|