// Source: RosettaCodeData/Task/Strassens-algorithm/C-sharp/strassens-algorithm.cs
// (listing metadata from the original page: 324 lines, 9.1 KiB, C#)

using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
/// <summary>
/// A dense, rectangular matrix of doubles stored as a list of rows.
/// Supports element-wise addition/subtraction, the classic triple-loop product,
/// and Strassen multiplication for square matrices whose size is a power of two.
/// </summary>
class Matrix
{
    /// <summary>Row storage; each inner list is one row. The list passed to the constructor is kept, not copied.</summary>
    public List<List<double>> data;

    /// <summary>Number of rows, captured when the matrix is constructed.</summary>
    public int rows;

    /// <summary>Number of columns (length of every row; 0 for a matrix with no rows).</summary>
    public int cols;

    /// <summary>Wraps <paramref name="data"/> as a matrix.</summary>
    /// <param name="data">The rows of the matrix; every row must be non-null and of equal length.</param>
    /// <exception cref="ArgumentNullException"><paramref name="data"/> is null.</exception>
    /// <exception cref="ArgumentException">A row is null or the rows have differing lengths.</exception>
    public Matrix(List<List<double>> data)
    {
        if (data == null)
        {
            throw new ArgumentNullException(nameof(data));
        }

        int width = (data.Count > 0 && data[0] != null) ? data[0].Count : 0;
        foreach (List<double> row in data)
        {
            // Reject jagged input here instead of failing later with an index error.
            if (row == null || row.Count != width)
            {
                throw new ArgumentException("All rows must be non-null and have the same length.", nameof(data));
            }
        }

        this.data = data;
        rows = data.Count;
        cols = width;
    }

    /// <summary>Returns the number of rows.</summary>
    public int GetRows()
    {
        return rows;
    }

    /// <summary>Returns the number of columns.</summary>
    public int GetCols()
    {
        return cols;
    }

    /// <summary>Checks that <paramref name="other"/> has the same shape as this matrix.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="other"/> is null.</exception>
    /// <exception cref="InvalidOperationException">The shapes differ.</exception>
    public void ValidateDimensions(Matrix other)
    {
        if (other == null)
        {
            throw new ArgumentNullException(nameof(other));
        }

        if (GetRows() != other.GetRows() || GetCols() != other.GetCols())
        {
            throw new InvalidOperationException("Matrices must have the same dimensions.");
        }
    }

    /// <summary>Checks that this matrix can be multiplied on the right by <paramref name="other"/>.</summary>
    /// <exception cref="ArgumentNullException"><paramref name="other"/> is null.</exception>
    /// <exception cref="InvalidOperationException">This matrix's column count differs from the other's row count.</exception>
    public void ValidateMultiplication(Matrix other)
    {
        if (other == null)
        {
            throw new ArgumentNullException(nameof(other));
        }

        if (GetCols() != other.GetRows())
        {
            throw new InvalidOperationException("Cannot multiply these matrices.");
        }
    }

    /// <summary>Checks that this matrix is square with a side length that is a power of two.</summary>
    /// <exception cref="InvalidOperationException">The matrix is not square, is empty, or its size is not a power of two.</exception>
    public void ValidateSquarePowerOfTwo()
    {
        if (GetRows() != GetCols())
        {
            throw new InvalidOperationException("Matrix must be square.");
        }

        // n & (n - 1) clears the lowest set bit; the result is zero only for powers of two.
        if (GetRows() == 0 || (GetRows() & (GetRows() - 1)) != 0)
        {
            throw new InvalidOperationException("Size of matrix must be a power of two.");
        }
    }

    /// <summary>Element-wise sum of two equally shaped matrices.</summary>
    public static Matrix operator +(Matrix a, Matrix b)
    {
        return ElementWise(a, b, (x, y) => x + y);
    }

    /// <summary>Element-wise difference of two equally shaped matrices.</summary>
    public static Matrix operator -(Matrix a, Matrix b)
    {
        return ElementWise(a, b, (x, y) => x - y);
    }

    /// <summary>Standard (triple-loop) matrix product.</summary>
    /// <exception cref="ArgumentNullException">Either operand is null.</exception>
    /// <exception cref="InvalidOperationException">The inner dimensions do not agree.</exception>
    public static Matrix operator *(Matrix a, Matrix b)
    {
        if (a == null)
        {
            throw new ArgumentNullException(nameof(a));
        }

        a.ValidateMultiplication(b);

        var product = new List<List<double>>(a.rows);
        for (int i = 0; i < a.rows; ++i)
        {
            var row = new List<double>(b.cols);
            for (int j = 0; j < b.cols; ++j)
            {
                double sum = 0.0;
                for (int k = 0; k < b.rows; ++k)
                {
                    sum += a.data[i][k] * b.data[k][j];
                }
                row.Add(sum);
            }
            product.Add(row);
        }
        return new Matrix(product);
    }

    /// <summary>
    /// Formats the matrix one row per line as <c>[v1, v2, ...]</c>.
    /// Values are written with the invariant culture so the decimal point can
    /// never collide with the ", " element separator.
    /// </summary>
    public override string ToString()
    {
        var invariant = System.Globalization.CultureInfo.InvariantCulture;
        var sb = new StringBuilder();
        foreach (List<double> row in data)
        {
            sb.Append("[");
            sb.Append(string.Join(", ", row.Select(v => v.ToString(invariant))));
            sb.AppendLine("]");
        }
        return sb.ToString();
    }

    /// <summary>
    /// Formats the matrix like <see cref="ToString"/>, rounding every value to
    /// <paramref name="p"/> decimal places and printing negative zero as plain zero.
    /// </summary>
    /// <param name="p">Number of decimal places; expected to be non-negative.</param>
    public string ToStringWithPrecision(int p)
    {
        // Invariant culture keeps the decimal separator a '.', which the
        // negative-zero comparison below relies on.
        var invariant = System.Globalization.CultureInfo.InvariantCulture;
        string format = "F" + p.ToString(invariant);
        double scale = Math.Pow(10.0, p);
        string zero = p > 0 ? "0." + new string('0', p) : "0";
        string negativeZero = "-" + zero;

        var sb = new StringBuilder();
        foreach (List<double> row in data)
        {
            sb.Append("[");
            for (int i = 0; i < row.Count; ++i)
            {
                double rounded = Math.Round(row[i] * scale) / scale;
                string formatted = rounded.ToString(format, invariant);
                if (formatted == negativeZero)
                {
                    // Tiny negative values round to -0; show them as 0.
                    formatted = zero;
                }
                sb.Append(formatted);
                if (i < row.Count - 1)
                {
                    sb.Append(", ");
                }
            }
            sb.AppendLine("]");
        }
        return sb.ToString();
    }

    /// <summary>
    /// Splits the matrix into four equally sized blocks, returned in the order
    /// top-left, top-right, bottom-left, bottom-right.
    /// </summary>
    /// <exception cref="InvalidOperationException">A dimension is odd, so the matrix cannot be split evenly.</exception>
    public Matrix[] ToQuarters()
    {
        if ((GetRows() & 1) != 0 || (GetCols() & 1) != 0)
        {
            // Halving an odd dimension would silently drop the last row/column.
            throw new InvalidOperationException("Matrix dimensions must be even to split into quarters.");
        }

        int r = GetRows() / 2;
        int c = GetCols() / 2;
        return new[]
        {
            SubBlock(0, 0, r, c),
            SubBlock(0, c, r, c),
            SubBlock(r, 0, r, c),
            SubBlock(r, c, r, c)
        };
    }

    /// <summary>
    /// Reassembles a matrix from four blocks given in the order produced by
    /// <see cref="ToQuarters"/>. The block size is taken from <c>q[0]</c>.
    /// </summary>
    /// <exception cref="ArgumentNullException"><paramref name="q"/> is null.</exception>
    /// <exception cref="ArgumentException"><paramref name="q"/> has fewer than four elements.</exception>
    public static Matrix FromQuarters(Matrix[] q)
    {
        if (q == null)
        {
            throw new ArgumentNullException(nameof(q));
        }

        if (q.Length < 4)
        {
            throw new ArgumentException("Four quarter matrices are required.", nameof(q));
        }

        int r = q[0].GetRows();
        int c = q[0].GetCols();
        var merged = new List<List<double>>(2 * r);
        for (int i = 0; i < 2 * r; ++i)
        {
            bool top = i < r;
            Matrix left = top ? q[0] : q[2];
            Matrix right = top ? q[1] : q[3];
            int source = top ? i : i - r;

            var row = new List<double>(2 * c);
            row.AddRange(left.data[source].GetRange(0, c));
            row.AddRange(right.data[source].GetRange(0, c));
            merged.Add(row);
        }
        return new Matrix(merged);
    }

    /// <summary>
    /// Multiplies this matrix by <paramref name="other"/> using Strassen's algorithm
    /// (seven recursive half-size products instead of eight).
    /// </summary>
    /// <param name="other">Right-hand operand; must be square, the same size as this matrix, with a power-of-two side.</param>
    /// <exception cref="ArgumentNullException"><paramref name="other"/> is null.</exception>
    /// <exception cref="InvalidOperationException">Either matrix is not a square power-of-two matrix, or the sizes differ.</exception>
    public Matrix Strassen(Matrix other)
    {
        if (other == null)
        {
            throw new ArgumentNullException(nameof(other));
        }

        ValidateSquarePowerOfTwo();
        other.ValidateSquarePowerOfTwo();
        if (GetRows() != other.GetRows() || GetCols() != other.GetCols())
        {
            throw new InvalidOperationException("Matrices must be square and of equal size for Strassen multiplication.");
        }

        // Validation is done once here: halving a power-of-two square matrix
        // always yields power-of-two square blocks, so the recursion need not re-check.
        return StrassenRecursive(this, other);
    }

    /// <summary>Applies <paramref name="op"/> to corresponding elements of two equally shaped matrices.</summary>
    private static Matrix ElementWise(Matrix a, Matrix b, Func<double, double, double> op)
    {
        if (a == null)
        {
            throw new ArgumentNullException(nameof(a));
        }

        a.ValidateDimensions(b);

        var result = new List<List<double>>(a.rows);
        for (int i = 0; i < a.rows; ++i)
        {
            var row = new List<double>(a.cols);
            for (int j = 0; j < a.cols; ++j)
            {
                row.Add(op(a.data[i][j], b.data[i][j]));
            }
            result.Add(row);
        }
        return new Matrix(result);
    }

    /// <summary>Copies the height-by-width block whose top-left corner is at (rowStart, colStart).</summary>
    private Matrix SubBlock(int rowStart, int colStart, int height, int width)
    {
        var block = new List<List<double>>(height);
        for (int i = 0; i < height; ++i)
        {
            block.Add(data[rowStart + i].GetRange(colStart, width));
        }
        return new Matrix(block);
    }

    /// <summary>Recursive core of <see cref="Strassen"/>; assumes both operands were already validated.</summary>
    private static Matrix StrassenRecursive(Matrix a, Matrix b)
    {
        if (a.GetRows() == 1)
        {
            return a * b;
        }

        // Quarter order: [0] top-left, [1] top-right, [2] bottom-left, [3] bottom-right.
        Matrix[] qa = a.ToQuarters();
        Matrix[] qb = b.ToQuarters();

        Matrix p1 = StrassenRecursive(qa[1] - qa[3], qb[2] + qb[3]);
        Matrix p2 = StrassenRecursive(qa[0] + qa[3], qb[0] + qb[3]);
        Matrix p3 = StrassenRecursive(qa[0] - qa[2], qb[0] + qb[1]);
        Matrix p4 = StrassenRecursive(qa[0] + qa[1], qb[3]);
        Matrix p5 = StrassenRecursive(qa[0], qb[1] - qb[3]);
        Matrix p6 = StrassenRecursive(qa[3], qb[2] - qb[0]);
        Matrix p7 = StrassenRecursive(qa[2] + qa[3], qb[0]);

        Matrix[] q = new Matrix[4];
        q[0] = p1 + p2 - p4 + p6;
        q[1] = p4 + p5;
        q[2] = p6 + p7;
        q[3] = p2 - p3 + p5 - p7;
        return FromQuarters(q);
    }
}
/// <summary>
/// Console demo: multiplies three pairs of matrices with the ordinary product
/// and with Strassen's algorithm, printing both sets of results.
/// </summary>
class Program
{
    /// <summary>Builds a <see cref="Matrix"/> from rows given as plain arrays.</summary>
    private static Matrix Build(params double[][] rowValues)
    {
        return new Matrix(rowValues.Select(values => values.ToList()).ToList());
    }

    static void Main(string[] args)
    {
        // 2x2 integer-valued operands.
        Matrix a = Build(
            new[] { 1.0, 2.0 },
            new[] { 3.0, 4.0 });
        Matrix b = Build(
            new[] { 5.0, 6.0 },
            new[] { 7.0, 8.0 });

        // 4x4 operands with fractional entries; their product is printed rounded to 6 places.
        Matrix c = Build(
            new[] { 1.0, 1.0, 1.0, 1.0 },
            new[] { 2.0, 4.0, 8.0, 16.0 },
            new[] { 3.0, 9.0, 27.0, 81.0 },
            new[] { 4.0, 16.0, 64.0, 256.0 });
        Matrix d = Build(
            new[] { 4.0, -3.0, 4.0 / 3.0, -1.0 / 4.0 },
            new[] { -13.0 / 3.0, 19.0 / 4.0, -7.0 / 3.0, 11.0 / 24.0 },
            new[] { 3.0 / 2.0, -2.0, 7.0 / 6.0, -1.0 / 4.0 },
            new[] { -1.0 / 6.0, 1.0 / 4.0, -1.0 / 6.0, 1.0 / 24.0 });

        // 4x4 matrix times the 4x4 identity.
        Matrix e = Build(
            new[] { 1.0, 2.0, 3.0, 4.0 },
            new[] { 5.0, 6.0, 7.0, 8.0 },
            new[] { 9.0, 10.0, 11.0, 12.0 },
            new[] { 13.0, 14.0, 15.0, 16.0 });
        Matrix f = Build(
            new[] { 1.0, 0.0, 0.0, 0.0 },
            new[] { 0.0, 1.0, 0.0, 0.0 },
            new[] { 0.0, 0.0, 1.0, 0.0 },
            new[] { 0.0, 0.0, 0.0, 1.0 });

        Matrix normalAb = a * b;
        Matrix normalCd = c * d;
        Matrix normalEf = e * f;
        Console.WriteLine("Using 'normal' matrix multiplication:");
        Console.WriteLine($" a * b = {normalAb}");
        Console.WriteLine($" c * d = {normalCd.ToStringWithPrecision(6)}");
        Console.WriteLine($" e * f = {normalEf}");

        Matrix strassenAb = a.Strassen(b);
        Matrix strassenCd = c.Strassen(d);
        Matrix strassenEf = e.Strassen(f);
        Console.WriteLine("\nUsing 'Strassen' matrix multiplication:");
        Console.WriteLine($" a * b = {strassenAb}");
        Console.WriteLine($" c * d = {strassenCd.ToStringWithPrecision(6)}");
        Console.WriteLine($" e * f = {strassenEf}");
    }
}