167 lines
4.4 KiB
Python
167 lines
4.4 KiB
Python
"""Matrix multiplication using Strassen's algorithm. Requires Python >= 3.7."""
|
|
|
|
from __future__ import annotations
|
|
from itertools import chain
|
|
from typing import List
|
|
from typing import NamedTuple
|
|
from typing import Optional
|
|
|
|
|
|
class Shape(NamedTuple):
|
|
rows: int
|
|
cols: int
|
|
|
|
|
|
class Matrix(List):
|
|
"""A matrix implemented as a two-dimensional list."""
|
|
|
|
@classmethod
|
|
def block(cls, blocks) -> Matrix:
|
|
"""Return a new Matrix assembled from nested blocks."""
|
|
m = Matrix()
|
|
for hblock in blocks:
|
|
for row in zip(*hblock):
|
|
m.append(list(chain.from_iterable(row)))
|
|
|
|
return m
|
|
|
|
def dot(self, b: Matrix) -> Matrix:
|
|
"""Return a new Matrix that is the product of this matrix and matrix `b`.
|
|
Uses 'simple' or 'naive' matrix multiplication."""
|
|
assert self.shape.cols == b.shape.rows
|
|
m = Matrix()
|
|
for row in self:
|
|
new_row = []
|
|
for c in range(len(b[0])):
|
|
col = [b[r][c] for r in range(len(b))]
|
|
new_row.append(sum(x * y for x, y in zip(row, col)))
|
|
m.append(new_row)
|
|
return m
|
|
|
|
def __matmul__(self, b: Matrix) -> Matrix:
|
|
return self.dot(b)
|
|
|
|
def __add__(self, b: Matrix) -> Matrix:
|
|
"""Matrix addition."""
|
|
assert self.shape == b.shape
|
|
rows, cols = self.shape
|
|
return Matrix(
|
|
[[self[i][j] + b[i][j] for j in range(cols)] for i in range(rows)]
|
|
)
|
|
|
|
def __sub__(self, b: Matrix) -> Matrix:
|
|
"""Matrix subtraction."""
|
|
assert self.shape == b.shape
|
|
rows, cols = self.shape
|
|
return Matrix(
|
|
[[self[i][j] - b[i][j] for j in range(cols)] for i in range(rows)]
|
|
)
|
|
|
|
def strassen(self, b: Matrix) -> Matrix:
|
|
"""Return a new Matrix that is the product of this matrix and matrix `b`.
|
|
Uses strassen algorithm."""
|
|
rows, cols = self.shape
|
|
|
|
assert rows == cols, "matrices must be square"
|
|
assert self.shape == b.shape, "matrices must be the same shape"
|
|
assert rows and (rows & rows - 1) == 0, "shape must be a power of 2"
|
|
|
|
if rows == 1:
|
|
return self.dot(b)
|
|
|
|
p = rows // 2 # partition
|
|
|
|
a11 = Matrix([n[:p] for n in self[:p]])
|
|
a12 = Matrix([n[p:] for n in self[:p]])
|
|
a21 = Matrix([n[:p] for n in self[p:]])
|
|
a22 = Matrix([n[p:] for n in self[p:]])
|
|
|
|
b11 = Matrix([n[:p] for n in b[:p]])
|
|
b12 = Matrix([n[p:] for n in b[:p]])
|
|
b21 = Matrix([n[:p] for n in b[p:]])
|
|
b22 = Matrix([n[p:] for n in b[p:]])
|
|
|
|
m1 = (a11 + a22).strassen(b11 + b22)
|
|
m2 = (a21 + a22).strassen(b11)
|
|
m3 = a11.strassen(b12 - b22)
|
|
m4 = a22.strassen(b21 - b11)
|
|
m5 = (a11 + a12).strassen(b22)
|
|
m6 = (a21 - a11).strassen(b11 + b12)
|
|
m7 = (a12 - a22).strassen(b21 + b22)
|
|
|
|
c11 = m1 + m4 - m5 + m7
|
|
c12 = m3 + m5
|
|
c21 = m2 + m4
|
|
c22 = m1 - m2 + m3 + m6
|
|
|
|
return Matrix.block([[c11, c12], [c21, c22]])
|
|
|
|
def round(self, ndigits: Optional[int] = None) -> Matrix:
|
|
return Matrix([[round(i, ndigits) for i in row] for row in self])
|
|
|
|
@property
|
|
def shape(self) -> Shape:
|
|
cols = len(self[0]) if self else 0
|
|
return Shape(len(self), cols)
|
|
|
|
|
|
def examples():
|
|
a = Matrix(
|
|
[
|
|
[1, 2],
|
|
[3, 4],
|
|
]
|
|
)
|
|
b = Matrix(
|
|
[
|
|
[5, 6],
|
|
[7, 8],
|
|
]
|
|
)
|
|
c = Matrix(
|
|
[
|
|
[1, 1, 1, 1],
|
|
[2, 4, 8, 16],
|
|
[3, 9, 27, 81],
|
|
[4, 16, 64, 256],
|
|
]
|
|
)
|
|
d = Matrix(
|
|
[
|
|
[4, -3, 4 / 3, -1 / 4],
|
|
[-13 / 3, 19 / 4, -7 / 3, 11 / 24],
|
|
[3 / 2, -2, 7 / 6, -1 / 4],
|
|
[-1 / 6, 1 / 4, -1 / 6, 1 / 24],
|
|
]
|
|
)
|
|
e = Matrix(
|
|
[
|
|
[1, 2, 3, 4],
|
|
[5, 6, 7, 8],
|
|
[9, 10, 11, 12],
|
|
[13, 14, 15, 16],
|
|
]
|
|
)
|
|
f = Matrix(
|
|
[
|
|
[1, 0, 0, 0],
|
|
[0, 1, 0, 0],
|
|
[0, 0, 1, 0],
|
|
[0, 0, 0, 1],
|
|
]
|
|
)
|
|
|
|
print("Naive matrix multiplication:")
|
|
print(f" a * b = {a @ b}")
|
|
print(f" c * d = {(c @ d).round()}")
|
|
print(f" e * f = {e @ f}")
|
|
|
|
print("Strassen's matrix multiplication:")
|
|
print(f" a * b = {a.strassen(b)}")
|
|
print(f" c * d = {c.strassen(d).round()}")
|
|
print(f" e * f = {e.strassen(f)}")
|
|
|
|
|
|
if __name__ == "__main__":
|
|
examples()
|