' Strassen matrix multiplication demo (FreeBASIC).
' Listing metadata from the original page: 205 lines, 6.6 KiB, plaintext.
' Dense matrix of doubles kept in one flat dynamic array.
' The routines in this file address element (r, c) as dato(r * cols + c).
Type Matrix
    rows As Integer         ' number of rows
    cols As Integer         ' number of columns
    dato(Any) As Double     ' element storage, sized by makeMatrix
End Type

' Build a Matrix with the given shape.
'   rows, cols : dimensions recorded in the result
' Returns a Matrix whose dato array spans indices 0 .. rows * cols - 1.
Function makeMatrix(rows As Integer, cols As Integer) As Matrix
    Dim As Matrix fresh
    Dim As Integer lastIndex = rows * cols - 1

    Redim fresh.dato(lastIndex)
    fresh.cols = cols
    fresh.rows = rows
    Return fresh
End Function

' Element-wise sum of two matrices.
'   m1, m2 : operands; they must have the same number of rows and columns
' Returns a new Matrix holding m1 + m2.
' On a shape mismatch a message is printed and the program is terminated.
Function matrixAdd(m1 As Matrix, m2 As Matrix) As Matrix
    If (m1.rows <> m2.rows) OrElse (m1.cols <> m2.cols) Then
        Print "Matrices must have the same dimensions."
        End
    End If

    Dim As Matrix total = makeMatrix(m1.rows, m1.cols)
    Dim As Integer lastIdx = m1.rows * m1.cols - 1
    Dim As Integer idx = 0

    ' Both operands share the same flat layout, so one linear pass suffices.
    Do While idx <= lastIdx
        total.dato(idx) = m1.dato(idx) + m2.dato(idx)
        idx += 1
    Loop
    Return total
End Function

' Element-wise difference of two matrices.
'   m1, m2 : operands; they must have the same number of rows and columns
' Returns a new Matrix holding m1 - m2.
' On a shape mismatch a message is printed and the program is terminated.
Function matrixSub(m1 As Matrix, m2 As Matrix) As Matrix
    If (m1.rows <> m2.rows) OrElse (m1.cols <> m2.cols) Then
        Print "Matrices must have the same dimensions."
        End
    End If

    Dim As Matrix diff = makeMatrix(m1.rows, m1.cols)
    Dim As Integer count = m1.rows * m1.cols

    ' Walk the flat storage of both operands in lock-step.
    For pos As Integer = 0 To count - 1
        diff.dato(pos) = m1.dato(pos) - m2.dato(pos)
    Next pos
    Return diff
End Function

' Conventional (triple-loop) matrix product.
'   m1 : left operand  (r x n)
'   m2 : right operand (n x c); m1.cols must equal m2.rows
' Returns a new r x c Matrix holding m1 * m2.
' If the inner dimensions differ a message is printed and the program ends.
Function matrixMul(m1 As Matrix, m2 As Matrix) As Matrix
    If m1.cols <> m2.rows Then
        Print "Cannot multiply these matrices."
        End
    End If

    Dim As Integer inner = m1.cols
    Dim As Matrix product = makeMatrix(m1.rows, m2.cols)

    For r As Integer = 0 To m1.rows - 1
        ' Flat offsets of row r in the left operand and in the result.
        Dim As Integer leftBase = r * inner
        Dim As Integer outBase = r * product.cols

        For c As Integer = 0 To m2.cols - 1
            ' Dot product of row r of m1 with column c of m2.
            Dim As Double acc = 0
            For t As Integer = 0 To inner - 1
                acc += m1.dato(leftBase + t) * m2.dato(t * m2.cols + c)
            Next t
            product.dato(outBase + c) = acc
        Next c
    Next r
    Return product
End Function

' Print a matrix on one line in the form [[a b] [c d]].
'   m         : matrix to display
'   precision : number of decimal places each value is rounded to (default 6)
' Values whose magnitude is below 1e-10 after rounding are shown as 0.
Sub printMatrix(m As Matrix, precision As Integer = 6)
    Dim As Double scale = 10 ^ precision

    Print "[[";
    For r As Integer = 0 To m.rows - 1
        ' Close the previous row and open the next one.
        If r > 0 Then Print "] [";

        For c As Integer = 0 To m.cols - 1
            If c > 0 Then Print " ";

            ' Round to the requested number of decimals.
            Dim As Double shown = m.dato(r * m.cols + c)
            shown = Int(shown * scale + 0.5) / scale
            If Abs(shown) < 1e-10 Then shown = 0

            Print Using "&"; shown;
        Next c
    Next r
    Print "]]"
End Sub

' Copy one quadrant of a matrix into a new matrix.
'   m       : source matrix
'   quarter : 0 = top-left, 1 = top-right, 2 = bottom-left, 3 = bottom-right
' Returns a (m.rows \ 2) x (m.cols \ 2) Matrix. The halves use integer
' division, so for odd dimensions the last row/column is not covered.
Function getQuarter(m As Matrix, quarter As Integer) As Matrix
    Dim As Integer subRows = m.rows \ 2
    Dim As Integer subCols = m.cols \ 2

    ' Top-left corner of the requested quadrant inside the source.
    Dim As Integer firstRow = (quarter \ 2) * subRows
    Dim As Integer firstCol = (quarter Mod 2) * subCols

    Dim As Matrix block = makeMatrix(subRows, subCols)
    For r As Integer = 0 To subRows - 1
        Dim As Integer srcBase = (r + firstRow) * m.cols + firstCol
        Dim As Integer dstBase = r * subCols
        For c As Integer = 0 To subCols - 1
            block.dato(dstBase + c) = m.dato(srcBase + c)
        Next c
    Next r
    Return block
End Function

' Assemble four equally sized square quadrants into one matrix.
'   q1 : top-left      q2 : top-right
'   q3 : bottom-left   q4 : bottom-right
' The quadrant size is taken from q1.rows; the result is twice that size in
' each dimension.
Function combineQuarters(q1 As Matrix, q2 As Matrix, q3 As Matrix, q4 As Matrix) As Matrix
    Dim As Integer half = q1.rows
    Dim As Matrix whole = makeMatrix(half * 2, half * 2)
    Dim As Integer stride = whole.cols

    For r As Integer = 0 To half - 1
        ' Flat offsets of row r in the upper and lower halves of the result,
        ' and of row r inside each quadrant.
        Dim As Integer topBase = r * stride
        Dim As Integer bottomBase = (r + half) * stride
        Dim As Integer srcBase = r * half

        For c As Integer = 0 To half - 1
            Dim As Integer src = srcBase + c
            whole.dato(topBase + c) = q1.dato(src)
            whole.dato(topBase + c + half) = q2.dato(src)
            whole.dato(bottomBase + c) = q3.dato(src)
            whole.dato(bottomBase + c + half) = q4.dato(src)
        Next c
    Next r
    Return whole
End Function

' Multiply two square matrices of equal size with Strassen's algorithm.
'   a, b : operands; both must be n x n with the same n
' Returns a new n x n Matrix holding a * b.
' If the operands are not square or differ in size, a message is printed and
' the program is terminated.
'
' The recursion splits each operand into four quadrants and forms the result
' from seven recursive products. A matrix can only be split into four equal
' quadrants when its size is even, so any level whose size is odd (or below 2)
' is handed to the conventional matrixMul instead. This keeps the result
' correct for sizes that are not a power of two and guarantees termination.
Function strassen(a As Matrix, b As Matrix) As Matrix
    If a.rows <> a.cols Or b.rows <> b.cols Or a.rows <> b.rows Then
        Print "Matrices must be square and of equal size."
        End
    End If

    ' Base case: sizes that cannot be halved evenly use the ordinary product.
    If a.rows < 2 OrElse (a.rows Mod 2) <> 0 Then Return matrixMul(a, b)

    ' Quadrants of both operands (0 = top-left ... 3 = bottom-right).
    Dim As Matrix a11 = getQuarter(a, 0)
    Dim As Matrix a12 = getQuarter(a, 1)
    Dim As Matrix a21 = getQuarter(a, 2)
    Dim As Matrix a22 = getQuarter(a, 3)
    Dim As Matrix b11 = getQuarter(b, 0)
    Dim As Matrix b12 = getQuarter(b, 1)
    Dim As Matrix b21 = getQuarter(b, 2)
    Dim As Matrix b22 = getQuarter(b, 3)

    ' The seven Strassen products.
    Dim As Matrix p1 = strassen(matrixSub(a12, a22), matrixAdd(b21, b22))
    Dim As Matrix p2 = strassen(matrixAdd(a11, a22), matrixAdd(b11, b22))
    Dim As Matrix p3 = strassen(matrixSub(a11, a21), matrixAdd(b11, b12))
    Dim As Matrix p4 = strassen(matrixAdd(a11, a12), b22)
    Dim As Matrix p5 = strassen(a11, matrixSub(b12, b22))
    Dim As Matrix p6 = strassen(a22, matrixSub(b21, b11))
    Dim As Matrix p7 = strassen(matrixAdd(a21, a22), b11)

    ' Quadrants of the result expressed through the seven products.
    Dim As Matrix c11 = matrixAdd(matrixSub(matrixAdd(p1, p2), p4), p6)
    Dim As Matrix c12 = matrixAdd(p4, p5)
    Dim As Matrix c21 = matrixAdd(p6, p7)
    Dim As Matrix c22 = matrixSub(matrixAdd(matrixSub(p2, p3), p5), p7)

    Return combineQuarters(c11, c12, c21, c22)
End Function

' Demo driver: builds six sample matrices and prints the products a*b, c*d
' and e*f, first with matrixMul and then with strassen.
Sub main()
    ' Matrix A (2x2)
    Dim As Double aVals(0 To 3) = {1, 2, 3, 4}
    Dim As Matrix a = makeMatrix(2, 2)
    For k As Integer = 0 To 3
        a.dato(k) = aVals(k)
    Next k

    ' Matrix B (2x2)
    Dim As Double bVals(0 To 3) = {5, 6, 7, 8}
    Dim As Matrix b = makeMatrix(2, 2)
    For k As Integer = 0 To 3
        b.dato(k) = bVals(k)
    Next k

    ' Matrix C (4x4)
    Dim As Double cVals(0 To 15) = { _
        1, 1, 1, 1, _
        2, 4, 8, 16, _
        3, 9, 27, 81, _
        4, 16, 64, 256 }
    Dim As Matrix c = makeMatrix(4, 4)
    For k As Integer = 0 To 15
        c.dato(k) = cVals(k)
    Next k

    ' Matrix D (4x4)
    Dim As Double dVals(0 To 15) = { _
        4, -3, 4/3, -1/4, _
        -13/3, 19/4, -7/3, 11/24, _
        3/2, -2, 7/6, -1/4, _
        -1/6, 1/4, -1/6, 1/24 }
    Dim As Matrix d = makeMatrix(4, 4)
    For k As Integer = 0 To 15
        d.dato(k) = dVals(k)
    Next k

    ' Matrix E (4x4): the values 1..16 in storage order
    Dim As Matrix e = makeMatrix(4, 4)
    For k As Integer = 0 To 15
        e.dato(k) = k + 1
    Next k

    ' Matrix F (Identity 4x4): diagonal entries sit at flat indices 0, 5, 10, 15
    Dim As Matrix f = makeMatrix(4, 4)
    For k As Integer = 0 To 15
        If k Mod 5 = 0 Then
            f.dato(k) = 1
        Else
            f.dato(k) = 0
        End If
    Next k

    Print "Using 'normal' matrix multiplication:"
    Print " a * b = ";
    printMatrix(matrixMul(a, b))
    Print " c * d = ";
    printMatrix(matrixMul(c, d), 6)
    Print " e * f = ";
    printMatrix(matrixMul(e, f))

    Print !"\nUsing 'Strassen' matrix multiplication:"
    Print " a * b = ";
    printMatrix(strassen(a, b))
    Print " c * d = ";
    printMatrix(strassen(c, d), 6)
    Print " e * f = ";
    printMatrix(strassen(e, f))
End Sub

' Program entry: run the demonstration, then pause with Sleep before exiting.
main

Sleep