# RosettaCodeData/Task/Strassens-algorithm/Jq/strassens-algorithm.jq
#
# 114 lines
# 3.7 KiB
# Plaintext

### Generic utilities
# Sum of every item emitted by `stream`, accumulated left to right
# starting from 0 (so an empty stream yields 0).
def sigma(stream):
  reduce stream as $term
    (0;
     . + $term);
# Frobenius (Hilbert-Schmidt) norm of the input: the square root of the
# sum of the squares of all entries. The input is flattened first, so any
# nesting depth of numeric arrays is accepted.
def frobenius:
  flatten
  | sigma(.[] | . * .)
  | sqrt;
### Matrices
# Create an m x n matrix (an array of m rows, each an array of n cells).
#   m    - number of rows; must be >= 0, otherwise an error is raised
#   n    - number of columns
#   init - filter evaluated once per cell of a single row; that row is then
#          repeated m times
# m == 0 yields [] without evaluating `init`.
def matrix(m; n; init):
  if m == 0 then []
  elif m > 0 then
    # Build ONE flat row and repeat it. Binding matrix(1;n;init) here instead
    # would repeat a 1 x n matrix and produce an m x 1 x n structure.
    [range(0; n) | init] as $row
    | [range(0; m) | $row]
  else error("matrix(\(m);_;_) invalid")
  end;
# Element-wise sum of $A and $B, which must be (possibly nested) arrays of
# the same shape, or two scalars. Recurses through arrays and applies `+`
# at the leaves; the result has the shape of $A.
def vector_add($A; $B):
  if ($A | type) != "array" then
    $A + $B
  else
    [ range(0; $A | length) as $idx
      | vector_add($A[$idx]; $B[$idx]) ]
  end;
# Negate every leaf of the input, which may be a scalar or an arbitrarily
# nested array; the array structure is preserved.
def vector_negate:
  if type != "array" then
    - .
  else
    [ .[] | vector_negate ]
  end;
# Element-wise difference $A - $B for same-shaped (possibly nested) arrays,
# computed as $A + (-$B).
def vector_subtract($A; $B):
  ($B | vector_negate) as $minusB
  | vector_add($A; $minusB);
# Conventional matrix product. $A should be m by n and $B n by p.
# The accumulator starts as matrix(m; p; 0) and every cell (i, j) is then
# overwritten, row by row, with the inner product of row i of $A and
# column j of $B.
# Pre-allocating the resultant matrix results in a very small net speedup.
def multiply($A; $B):
  # Cij = inner product of row $i of A and column $j of B, over $len terms
  def cell($i; $j; $len):
    reduce range(0; $len) as $k
      (0; . + $A[$i][$k] * $B[$k][$j]);
  ($A | length) as $rows
  | ($A[0] | length) as $inner
  | ($B[0] | length) as $cols
  | reduce (range(0; $rows) as $i | range(0; $cols) as $j | [$i, $j]) as [$i, $j]
      (matrix($rows; $cols; 0);   # initialize to avoid resizing
       .[$i][$j] = cell($i; $j; $inner));
# From the input matrix, keep rows $m1 (inclusive) to $m2 (exclusive) and,
# within each kept row, columns $n1 (inclusive) to $n2 (exclusive).
def submatrix($m1; $m2; $n1; $n2):
  [ .[$m1:$m2][]
    | .[$n1:$n2] ];
# Same as the 4-argument form, but takes the matrix as the explicit first
# argument $A instead of reading it from the input: rows $m1 up to (not
# including) $m2, columns $n1 up to (not including) $n2.
def submatrix($A; $m1; $m2; $n1; $n2):
  $A[$m1:$m2]
  | map(.[$n1:$n2]);
# Place $B to the right of $A: row i of the result is row i of $A
# followed by row i of $B. The result has as many rows as $A.
def rowwise_extend($A; $B):
  [ range(0; $A | length) as $r
    | $A[$r] + $B[$r] ];
### Strassen multiplication of n*n square matrices.
# n is taken from the length of the first row of $A.
# While n is even, both operands are split into four n/2 x n/2 quadrants
# and the product is assembled from seven recursive products (P1..P7).
# When n is 0, 1 or odd, the call falls back to the conventional
# multiply($A; $B): an odd n cannot be halved into equal quadrants, and
# n == 0 would otherwise recurse without ever reaching a base case.
# For n a power of 2 the recursion therefore runs all the way down to 1x1.
def Strassen($A; $B):
  ($A[0] | length) as $n
  | if $n <= 1 or ($n % 2) != 0 then multiply($A; $B)
    else
      ($n / 2) as $h
      # Quadrants of A and B
      | submatrix($A;  0; $h;  0; $h) as $A11
      | submatrix($A;  0; $h; $h; $n) as $A12
      | submatrix($A; $h; $n;  0; $h) as $A21
      | submatrix($A; $h; $n; $h; $n) as $A22
      | submatrix($B;  0; $h;  0; $h) as $B11
      | submatrix($B;  0; $h; $h; $n) as $B12
      | submatrix($B; $h; $n;  0; $h) as $B21
      | submatrix($B; $h; $n; $h; $n) as $B22
      # The seven half-size products
      | Strassen( vector_subtract($A12; $A22); vector_add($B21; $B22)) as $P1
      | Strassen( vector_add($A11; $A22);      vector_add($B11; $B22)) as $P2
      | Strassen( vector_subtract($A11; $A21); vector_add($B11; $B12)) as $P3
      | Strassen( vector_add($A11; $A12);      $B22) as $P4
      | Strassen( $A11;                        vector_subtract($B12; $B22)) as $P5
      | Strassen( $A22;                        vector_subtract($B21; $B11)) as $P6
      | Strassen( vector_add($A21; $A22);      $B11) as $P7
      # Quadrants of the result
      | vector_add(vector_subtract(vector_add($P1; $P2); $P4); $P6) as $C11
      | vector_add($P4; $P5) as $C12
      | vector_add($P6; $P7) as $C21
      | vector_add(vector_subtract($P2; $P3); vector_subtract($P5; $P7)) as $C22
      # [C11 C12; C21 C22]
      | rowwise_extend($C11; $C12) + rowwise_extend($C21; $C22)
    end
;
# ## Examples
# 2x2 integer operands.
def A:
  [[1, 2],
   [3, 4]];

def B:
  [[5, 6],
   [7, 8]];

# 4x4 operands: row k of C holds the powers k^0 .. k^3 scaled by k, and D
# is given as rational entries (D looks like the inverse of C, so C*D is
# expected to be close to the identity up to rounding).
def C:
  [[1,  1,  1,   1],
   [2,  4,  8,  16],
   [3,  9, 27,  81],
   [4, 16, 64, 256]];

def D:
  [[    4,   -3,  4/3,  -1/4],
   [-13/3, 19/4, -7/3, 11/24],
   [  3/2,   -2,  7/6,  -1/4],
   [ -1/6,  1/4, -1/6,  1/24]];

# 4x4 integers 1..16 and the 4x4 identity.
def E:
  [[ 1,  2,  3,  4],
   [ 5,  6,  7,  8],
   [ 9, 10, 11, 12],
   [13, 14, 15, 16]];

def F:
  [[1, 0, 0, 0],
   [0, 1, 0, 0],
   [0, 0, 1, 0],
   [0, 0, 0, 1]];

# r = sqrt(2)/2; R is a 2x2 matrix built from +/- r.
def r: 2 | sqrt | . / 2;

def R:
  [[ r, r],
   [-r, r]];
# Compare the conventional product with the Strassen product for each
# example pair; C*D is compared via the Frobenius norm of the difference,
# the others by exact equality. Emits four report strings in order.
( (multiply(A; B) == Strassen(A; B)) as $same
  | "A*B == Strassen(A;B): \($same)" ),
( (vector_subtract(multiply(C; D); Strassen(C; D)) | frobenius) as $norm
  | "Frobenius norm for C*D - Strassen(C;D): \($norm)" ),
( (multiply(E; F) == Strassen(E; F)) as $same
  | "E*F == Strassen(E;F): \($same)" ),
( (multiply(R; R) == Strassen(R; R)) as $same
  | "R*R == Strassen(R;R): \($same)" )