166 lines
4.6 KiB
Go
166 lines
4.6 KiB
Go
package main
|
|
|
|
import (
|
|
"fmt"
|
|
"log"
|
|
"math"
|
|
)
|
|
|
|
// Matrix is a dense, row-major matrix of float64 values: m[i][j] is the
// element in row i, column j. All rows are expected to have the same length;
// the column count is taken from the first row.
type Matrix [][]float64

// rows returns the number of rows in m.
func (m Matrix) rows() int { return len(m) }

// cols returns the number of columns in m, taken from its first row.
// An empty matrix has zero columns; checking for that explicitly avoids an
// index-out-of-range panic on m[0] so callers can report shape errors instead.
func (m Matrix) cols() int {
	if len(m) == 0 {
		return 0
	}
	return len(m[0])
}

// zipWith combines m and m2 element by element with op and returns the
// result as a newly allocated matrix (neither operand is modified).
// The program is terminated via log.Fatal if the dimensions differ.
func (m Matrix) zipWith(m2 Matrix, op func(x, y float64) float64) Matrix {
	if m.rows() != m2.rows() || m.cols() != m2.cols() {
		log.Fatal("Matrices must have the same dimensions.")
	}
	c := make(Matrix, m.rows())
	for i := range c {
		c[i] = make([]float64, m.cols())
		for j := range c[i] {
			c[i][j] = op(m[i][j], m2[i][j])
		}
	}
	return c
}

// add returns the element-wise sum m + m2 as a new matrix.
// The program is terminated via log.Fatal if the dimensions differ.
func (m Matrix) add(m2 Matrix) Matrix {
	return m.zipWith(m2, func(x, y float64) float64 { return x + y })
}

// sub returns the element-wise difference m - m2 as a new matrix.
// The program is terminated via log.Fatal if the dimensions differ.
func (m Matrix) sub(m2 Matrix) Matrix {
	return m.zipWith(m2, func(x, y float64) float64 { return x - y })
}

// mul returns the matrix product m × m2, an m.rows() × m2.cols() matrix.
// The program is terminated via log.Fatal unless m.cols() == m2.rows().
func (m Matrix) mul(m2 Matrix) Matrix {
	if m.cols() != m2.rows() {
		log.Fatal("Cannot multiply these matrices.")
	}
	inner, width := m2.rows(), m2.cols()
	c := make(Matrix, m.rows())
	for i, mRow := range m {
		cRow := make([]float64, width)
		// i-k-j loop order: the innermost loop walks a row of m2 and a row of
		// c sequentially instead of striding down a column of m2. Each
		// c[i][j] still accumulates its terms in increasing k, so the
		// summation order matches the textbook i-j-k loop.
		for k := 0; k < inner; k++ {
			mik := mRow[k]
			m2Row := m2[k]
			for j := range cRow {
				cRow[j] += mik * m2Row[j]
			}
		}
		c[i] = cRow
	}
	return c
}

// toString formats m with every element rounded to p decimal places, using
// the same bracketed layout fmt's %v produces for nested slices
// (e.g. "[[1 0.5] [0 2]]"). A value that rounds to negative zero is printed
// as "0" so tiny negative rounding residue does not show up as "-0".
func (m Matrix) toString(p int) string {
	pow := math.Pow(10, float64(p))
	rowStrs := make([]string, m.rows())
	for i := range rowStrs {
		cells := make([]string, m.cols())
		for j := range cells {
			r := math.Round(m[i][j]*pow) / pow
			cells[j] = fmt.Sprintf("%g", r)
			if cells[j] == "-0" {
				cells[j] = "0"
			}
		}
		rowStrs[i] = fmt.Sprintf("%v", cells)
	}
	return fmt.Sprintf("%v", rowStrs)
}
|
|
|
|
// params describes how a (2r)×(2c) matrix is split into four r×c quadrants,
// ordered top-left, top-right, bottom-left, bottom-right. Entry k holds
// {rowStart, rowEnd, colStart, colEnd, rowOffset, colOffset}: the half-open
// row and column ranges of quadrant k in the full matrix, and the offsets
// subtracted from a full-matrix index to get the quadrant-local index.
func params(r, c int) [4][6]int {
	var out [4][6]int
	for k := range out {
		// Quadrants 2 and 3 sit in the lower half; 1 and 3 in the right half.
		rowOff := (k / 2) * r
		colOff := (k % 2) * c
		out[k] = [6]int{rowOff, rowOff + r, colOff, colOff + c, rowOff, colOff}
	}
	return out
}
|
|
|
|
func toQuarters(m Matrix) [4]Matrix {
|
|
r := m.rows() / 2
|
|
c := m.cols() / 2
|
|
p := params(r, c)
|
|
var quarters [4]Matrix
|
|
for k := 0; k < 4; k++ {
|
|
q := make(Matrix, r)
|
|
for i := p[k][0]; i < p[k][1]; i++ {
|
|
q[i-p[k][4]] = make([]float64, c)
|
|
for j := p[k][2]; j < p[k][3]; j++ {
|
|
q[i-p[k][4]][j-p[k][5]] = m[i][j]
|
|
}
|
|
}
|
|
quarters[k] = q
|
|
}
|
|
return quarters
|
|
}
|
|
|
|
func fromQuarters(q [4]Matrix) Matrix {
|
|
r := q[0].rows()
|
|
c := q[0].cols()
|
|
p := params(r, c)
|
|
r *= 2
|
|
c *= 2
|
|
m := make(Matrix, r)
|
|
for i := 0; i < c; i++ {
|
|
m[i] = make([]float64, c)
|
|
}
|
|
for k := 0; k < 4; k++ {
|
|
for i := p[k][0]; i < p[k][1]; i++ {
|
|
for j := p[k][2]; j < p[k][3]; j++ {
|
|
m[i][j] = q[k][i-p[k][4]][j-p[k][5]]
|
|
}
|
|
}
|
|
}
|
|
return m
|
|
}
|
|
|
|
func strassen(a, b Matrix) Matrix {
|
|
if a.rows() != a.cols() || b.rows() != b.cols() || a.rows() != b.rows() {
|
|
log.Fatal("Matrices must be square and of equal size.")
|
|
}
|
|
if a.rows() == 0 || (a.rows()&(a.rows()-1)) != 0 {
|
|
log.Fatal("Size of matrices must be a power of two.")
|
|
}
|
|
if a.rows() == 1 {
|
|
return a.mul(b)
|
|
}
|
|
qa := toQuarters(a)
|
|
qb := toQuarters(b)
|
|
p1 := strassen(qa[1].sub(qa[3]), qb[2].add(qb[3]))
|
|
p2 := strassen(qa[0].add(qa[3]), qb[0].add(qb[3]))
|
|
p3 := strassen(qa[0].sub(qa[2]), qb[0].add(qb[1]))
|
|
p4 := strassen(qa[0].add(qa[1]), qb[3])
|
|
p5 := strassen(qa[0], qb[1].sub(qb[3]))
|
|
p6 := strassen(qa[3], qb[2].sub(qb[0]))
|
|
p7 := strassen(qa[2].add(qa[3]), qb[0])
|
|
var q [4]Matrix
|
|
q[0] = p1.add(p2).sub(p4).add(p6)
|
|
q[1] = p4.add(p5)
|
|
q[2] = p6.add(p7)
|
|
q[3] = p2.sub(p3).add(p5).sub(p7)
|
|
return fromQuarters(q)
|
|
}
|
|
|
|
func main() {
|
|
a := Matrix{{1, 2}, {3, 4}}
|
|
b := Matrix{{5, 6}, {7, 8}}
|
|
c := Matrix{{1, 1, 1, 1}, {2, 4, 8, 16}, {3, 9, 27, 81}, {4, 16, 64, 256}}
|
|
d := Matrix{{4, -3, 4.0 / 3, -1.0 / 4}, {-13.0 / 3, 19.0 / 4, -7.0 / 3, 11.0 / 24},
|
|
{3.0 / 2, -2, 7.0 / 6, -1.0 / 4}, {-1.0 / 6, 1.0 / 4, -1.0 / 6, 1.0 / 24}}
|
|
e := Matrix{{1, 2, 3, 4}, {5, 6, 7, 8}, {9, 10, 11, 12}, {13, 14, 15, 16}}
|
|
f := Matrix{{1, 0, 0, 0}, {0, 1, 0, 0}, {0, 0, 1, 0}, {0, 0, 0, 1}}
|
|
fmt.Println("Using 'normal' matrix multiplication:")
|
|
fmt.Printf(" a * b = %v\n", a.mul(b))
|
|
fmt.Printf(" c * d = %v\n", c.mul(d).toString(6))
|
|
fmt.Printf(" e * f = %v\n", e.mul(f))
|
|
fmt.Println("\nUsing 'Strassen' matrix multiplication:")
|
|
fmt.Printf(" a * b = %v\n", strassen(a, b))
|
|
fmt.Printf(" c * d = %v\n", strassen(c, d).toString(6))
|
|
fmt.Printf(" e * f = %v\n", strassen(e, f))
|
|
}
|