RosettaCodeData/Task/Strassens-algorithm/Java/strassens-algorithm.java

282 lines
9.0 KiB
Java

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Locale;
import java.util.function.DoubleBinaryOperator;
/**
 * A dense matrix of {@code double} values held as a list of rows.
 *
 * <p>Supports element-wise addition and subtraction, the classic triple-loop product,
 * and Strassen's recursive product for square matrices whose size is a power of two.
 * Every arithmetic operation returns a new {@code Matrix}; the operands are only read.
 *
 * <p>All argument and shape violations are reported as {@link IllegalArgumentException},
 * which is a {@link RuntimeException}.
 */
class Matrix {
    /**
     * Row-major cell storage. The list handed to the constructor is stored as-is (no copy),
     * so later changes to it are visible here while {@link #rows} and {@link #cols} are not
     * recomputed.
     */
    public List<List<Double>> data;

    /** Number of rows, captured when the matrix was constructed. */
    public int rows;

    /** Number of columns (size of the first row, or 0 for an empty matrix), captured at construction. */
    public int cols;

    /**
     * Wraps the given rows as a matrix.
     *
     * @param data the rows of the matrix; stored by reference, not copied
     * @throws IllegalArgumentException if {@code data} is {@code null}, contains a {@code null}
     *     row, or its rows do not all have the same length
     */
    public Matrix(List<List<Double>> data) {
        if (data == null) {
            throw new IllegalArgumentException("Matrix data must not be null.");
        }
        this.data = data;
        rows = data.size();
        List<Double> first = (rows > 0) ? data.get(0) : null;
        if (rows > 0 && first == null) {
            throw new IllegalArgumentException("Matrix rows must not be null.");
        }
        cols = (first != null) ? first.size() : 0;
        // Reject jagged input here; otherwise it only surfaces later as an
        // IndexOutOfBoundsException somewhere inside add/multiply/strassen.
        for (List<Double> row : data) {
            if (row == null) {
                throw new IllegalArgumentException("Matrix rows must not be null.");
            }
            if (row.size() != cols) {
                throw new IllegalArgumentException("All rows must have the same number of columns.");
            }
        }
    }

    /** @return the number of rows */
    public int getRows() {
        return rows;
    }

    /** @return the number of columns */
    public int getCols() {
        return cols;
    }

    /**
     * Checks that {@code other} has exactly the same shape as this matrix.
     *
     * @param other the matrix to compare against
     * @throws IllegalArgumentException if the row or column counts differ
     */
    public void validateDimensions(Matrix other) {
        if (getRows() != other.getRows() || getCols() != other.getCols()) {
            throw new IllegalArgumentException("Matrices must have the same dimensions.");
        }
    }

    /**
     * Checks that {@code this * other} is defined, i.e. this matrix has as many columns as
     * {@code other} has rows.
     *
     * @param other the right-hand operand
     * @throws IllegalArgumentException if the inner dimensions do not match
     */
    public void validateMultiplication(Matrix other) {
        if (getCols() != other.getRows()) {
            throw new IllegalArgumentException("Cannot multiply these matrices.");
        }
    }

    /**
     * Checks that this matrix is square and that its size is a power of two (1, 2, 4, ...).
     *
     * @throws IllegalArgumentException if the matrix is not square, is empty, or its size is
     *     not a power of two
     */
    public void validateSquarePowerOfTwo() {
        if (getRows() != getCols()) {
            throw new IllegalArgumentException("Matrix must be square.");
        }
        int n = getRows();
        // n & (n - 1) clears the lowest set bit; the result is zero only for powers of two.
        if (n == 0 || (n & (n - 1)) != 0) {
            throw new IllegalArgumentException("Size of matrix must be a power of two.");
        }
    }

    /**
     * Returns the element-wise sum {@code this + other}.
     *
     * @param other a matrix of the same shape
     * @return a new matrix holding the sums
     * @throws IllegalArgumentException if the shapes differ
     */
    public Matrix add(Matrix other) {
        return elementWise(other, Double::sum);
    }

    /**
     * Returns the element-wise difference {@code this - other}.
     *
     * @param other a matrix of the same shape
     * @return a new matrix holding the differences
     * @throws IllegalArgumentException if the shapes differ
     */
    public Matrix subtract(Matrix other) {
        return elementWise(other, (x, y) -> x - y);
    }

    /** Applies {@code op} to each pair of corresponding cells of two same-shaped matrices. */
    private Matrix elementWise(Matrix other, DoubleBinaryOperator op) {
        validateDimensions(other);
        List<List<Double>> resultData = new ArrayList<>(rows);
        for (int i = 0; i < rows; ++i) {
            List<Double> mine = data.get(i);
            List<Double> theirs = other.data.get(i);
            List<Double> row = new ArrayList<>(cols);
            for (int j = 0; j < cols; ++j) {
                row.add(op.applyAsDouble(mine.get(j), theirs.get(j)));
            }
            resultData.add(row);
        }
        return new Matrix(resultData);
    }

    /**
     * Returns the matrix product {@code this * other} using the standard triple loop.
     *
     * @param other the right-hand operand; must have as many rows as this matrix has columns
     * @return a new matrix with this matrix's row count and {@code other}'s column count
     * @throws IllegalArgumentException if the inner dimensions do not match
     */
    public Matrix multiply(Matrix other) {
        validateMultiplication(other);
        List<List<Double>> resultData = new ArrayList<>(rows);
        for (int i = 0; i < rows; ++i) {
            List<Double> left = data.get(i);
            List<Double> row = new ArrayList<>(other.cols);
            for (int j = 0; j < other.cols; ++j) {
                double sum = 0.0;
                for (int k = 0; k < other.rows; ++k) {
                    sum += left.get(k) * other.data.get(k).get(j);
                }
                row.add(sum);
            }
            resultData.add(row);
        }
        return new Matrix(resultData);
    }

    /**
     * Renders the matrix one row per line as {@code [v0, v1, ...]}, each line terminated by
     * {@code '\n'}, using the default {@code Double} string form for every cell.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        for (List<Double> row : data) {
            sb.append("[");
            for (int i = 0; i < row.size(); ++i) {
                if (i > 0) {
                    sb.append(", ");
                }
                sb.append(row.get(i));
            }
            sb.append("]\n");
        }
        return sb.toString();
    }

    /**
     * Renders the matrix like {@link #toString()} but with every cell rounded to {@code p}
     * decimal places and printed in fixed-point notation.
     *
     * <p>The output always uses {@code '.'} as the decimal separator, independent of the JVM's
     * default locale, so it matches {@link #toString()}. A value that rounds to negative zero is
     * printed without the minus sign. NaN and infinite cells are printed as such.
     *
     * @param p number of digits after the decimal point; must not be negative
     * @return the formatted matrix, one row per line
     * @throws IllegalArgumentException if {@code p} is negative
     */
    public String toStringWithPrecision(int p) {
        if (p < 0) {
            throw new IllegalArgumentException("Precision must not be negative: " + p);
        }
        double pow = Math.pow(10.0, p);
        String pattern = "%." + p + "f";
        String zero = "0" + (p > 0 ? "." + "0".repeat(p) : "");
        String negativeZero = "-" + zero;
        StringBuilder sb = new StringBuilder();
        for (List<Double> row : data) {
            sb.append("[");
            for (int i = 0; i < row.size(); ++i) {
                if (i > 0) {
                    sb.append(", ");
                }
                // Locale.ROOT keeps '.' as the separator; with the default locale a
                // comma-decimal locale would also defeat the negative-zero comparison below.
                String formatted = String.format(Locale.ROOT, pattern, roundToScale(row.get(i), pow));
                sb.append(formatted.equals(negativeZero) ? zero : formatted);
            }
            sb.append("]\n");
        }
        return sb.toString();
    }

    /**
     * Rounds {@code value} to the decimal scale given by {@code pow} (a power of ten).
     *
     * <p>{@link Math#round(double)} returns a {@code long}: it maps NaN to 0 and saturates at
     * {@code Long.MIN_VALUE}/{@code Long.MAX_VALUE}. When the scaled value is outside that range
     * (or is NaN), the input is returned unchanged; at such magnitudes a double has no
     * fractional digits left at this scale, so there is nothing to round.
     */
    private static double roundToScale(double value, double pow) {
        double scaled = value * pow;
        // Written as a negated '<' so that NaN, which fails every comparison, also passes through.
        if (!(Math.abs(scaled) < (double) Long.MAX_VALUE)) {
            return value;
        }
        return Math.round(scaled) / pow;
    }

    /**
     * Splits the matrix into four equally sized blocks.
     *
     * @return the blocks in the order top-left, top-right, bottom-left, bottom-right; each is a
     *     new matrix with copied cells
     * @throws IllegalArgumentException if the row or column count is odd, since the last row or
     *     column would otherwise be silently dropped
     */
    public Matrix[] toQuarters() {
        if (getRows() % 2 != 0 || getCols() % 2 != 0) {
            throw new IllegalArgumentException("Matrix dimensions must be even to split into quarters.");
        }
        int r = getRows() / 2;
        int c = getCols() / 2;
        return new Matrix[] {
            block(0, 0, r, c),
            block(0, c, r, c),
            block(r, 0, r, c),
            block(r, c, r, c)
        };
    }

    /** Copies the {@code height x width} sub-matrix whose top-left cell is at (rowStart, colStart). */
    private Matrix block(int rowStart, int colStart, int height, int width) {
        List<List<Double>> blockData = new ArrayList<>(height);
        for (int i = 0; i < height; ++i) {
            List<Double> source = data.get(rowStart + i);
            blockData.add(new ArrayList<>(source.subList(colStart, colStart + width)));
        }
        return new Matrix(blockData);
    }

    /**
     * Reassembles a matrix from four equally sized blocks; the inverse of {@link #toQuarters()}.
     *
     * @param q the blocks in the order top-left, top-right, bottom-left, bottom-right
     * @return a new matrix with twice the rows and twice the columns of one block
     * @throws IllegalArgumentException if {@code q} is {@code null}, does not hold exactly four
     *     blocks, or the blocks differ in shape
     */
    public static Matrix fromQuarters(Matrix[] q) {
        if (q == null || q.length != 4) {
            throw new IllegalArgumentException("Exactly four quarters are required.");
        }
        int r = q[0].getRows();
        int c = q[0].getCols();
        for (Matrix quarter : q) {
            if (quarter.getRows() != r || quarter.getCols() != c) {
                throw new IllegalArgumentException("All quarters must have the same dimensions.");
            }
        }
        List<List<Double>> mData = new ArrayList<>(2 * r);
        // half 0 joins q[0]|q[1] into the top rows, half 1 joins q[2]|q[3] into the bottom rows.
        for (int half = 0; half < 2; ++half) {
            Matrix left = q[2 * half];
            Matrix right = q[2 * half + 1];
            for (int i = 0; i < r; ++i) {
                List<Double> row = new ArrayList<>(2 * c);
                row.addAll(left.data.get(i));
                row.addAll(right.data.get(i));
                mData.add(row);
            }
        }
        return new Matrix(mData);
    }

    /**
     * Returns the matrix product {@code this * other} computed with Strassen's algorithm:
     * both operands are split into quarters, seven recursive products of quarter-sized
     * matrices are formed, and the result quarters are built from sums and differences of
     * those products. Recursion stops at 1x1 matrices, which are multiplied directly.
     *
     * @param other the right-hand operand; must have the same size as this matrix
     * @return a new matrix holding the product
     * @throws IllegalArgumentException if either matrix is not square with a power-of-two size,
     *     or the two sizes differ
     */
    public Matrix strassen(Matrix other) {
        validateSquarePowerOfTwo();
        other.validateSquarePowerOfTwo();
        if (getRows() != other.getRows() || getCols() != other.getCols()) {
            throw new IllegalArgumentException(
                "Matrices must be square and of equal size for Strassen multiplication.");
        }
        if (getRows() == 1) {
            return this.multiply(other);
        }
        // Quarter indices: 0 = top-left, 1 = top-right, 2 = bottom-left, 3 = bottom-right.
        Matrix[] qa = toQuarters();
        Matrix[] qb = other.toQuarters();
        Matrix p1 = qa[1].subtract(qa[3]).strassen(qb[2].add(qb[3]));
        Matrix p2 = qa[0].add(qa[3]).strassen(qb[0].add(qb[3]));
        Matrix p3 = qa[0].subtract(qa[2]).strassen(qb[0].add(qb[1]));
        Matrix p4 = qa[0].add(qa[1]).strassen(qb[3]);
        Matrix p5 = qa[0].strassen(qb[1].subtract(qb[3]));
        Matrix p6 = qa[3].strassen(qb[2].subtract(qb[0]));
        Matrix p7 = qa[2].add(qa[3]).strassen(qb[0]);
        Matrix[] q = new Matrix[4];
        q[0] = p1.add(p2).subtract(p4).add(p6);
        q[1] = p4.add(p5);
        q[2] = p6.add(p7);
        q[3] = p2.subtract(p3).add(p5).subtract(p7);
        return fromQuarters(q);
    }
}
/**
 * Demo driver: builds six sample matrices and prints three products, first computed with
 * the ordinary {@code multiply} and then with {@code strassen}.
 */
public class Main {

    /** Boxes a grid of primitive doubles into the nested-list form the {@code Matrix} constructor takes. */
    private static Matrix matrixOf(double[][] grid) {
        List<List<Double>> boxed = new ArrayList<>();
        for (double[] line : grid) {
            List<Double> cells = new ArrayList<>();
            for (double cell : line) {
                cells.add(cell);
            }
            boxed.add(cells);
        }
        return new Matrix(boxed);
    }

    /**
     * Prints each product under both algorithms; the {@code c * d} result is shown rounded
     * to six decimal places, the others with the default cell formatting.
     *
     * @param args ignored
     */
    public static void main(String[] args) {
        Matrix a = matrixOf(new double[][] {
            {1.0, 2.0},
            {3.0, 4.0}
        });
        Matrix b = matrixOf(new double[][] {
            {5.0, 6.0},
            {7.0, 8.0}
        });
        Matrix c = matrixOf(new double[][] {
            {1.0, 1.0, 1.0, 1.0},
            {2.0, 4.0, 8.0, 16.0},
            {3.0, 9.0, 27.0, 81.0},
            {4.0, 16.0, 64.0, 256.0}
        });
        Matrix d = matrixOf(new double[][] {
            {4.0, -3.0, 4.0 / 3.0, -1.0 / 4.0},
            {-13.0 / 3.0, 19.0 / 4.0, -7.0 / 3.0, 11.0 / 24.0},
            {3.0 / 2.0, -2.0, 7.0 / 6.0, -1.0 / 4.0},
            {-1.0 / 6.0, 1.0 / 4.0, -1.0 / 6.0, 1.0 / 24.0}
        });
        Matrix e = matrixOf(new double[][] {
            {1.0, 2.0, 3.0, 4.0},
            {5.0, 6.0, 7.0, 8.0},
            {9.0, 10.0, 11.0, 12.0},
            {13.0, 14.0, 15.0, 16.0}
        });
        // f is the 4x4 identity.
        Matrix f = matrixOf(new double[][] {
            {1.0, 0.0, 0.0, 0.0},
            {0.0, 1.0, 0.0, 0.0},
            {0.0, 0.0, 1.0, 0.0},
            {0.0, 0.0, 0.0, 1.0}
        });

        System.out.println("Using 'normal' matrix multiplication:");
        Matrix plainAb = a.multiply(b);
        System.out.println(" a * b = " + plainAb);
        String plainCd = c.multiply(d).toStringWithPrecision(6);
        System.out.println(" c * d = " + plainCd);
        Matrix plainEf = e.multiply(f);
        System.out.println(" e * f = " + plainEf);

        System.out.println("\nUsing 'Strassen' matrix multiplication:");
        Matrix fastAb = a.strassen(b);
        System.out.println(" a * b = " + fastAb);
        String fastCd = c.strassen(d).toStringWithPrecision(6);
        System.out.println(" c * d = " + fastCd);
        Matrix fastEf = e.strassen(f);
        System.out.println(" e * f = " + fastEf);
    }
}