RosettaCodeData/Task/Strassens-algorithm/JavaScript/strassens-algorithm-2.js

317 lines
12 KiB
JavaScript

class Matrix {
/** @type {number[][]} */
data;
/** @type {number} */
rows;
/** @type {number} */
cols;
/**
* @param {number[][]} data The matrix data as a 2D array.
*/
constructor(data) {
if (!Array.isArray(data) || (data.length > 0 && !Array.isArray(data[0]))) {
throw new Error("Input data must be a 2D array.");
}
// Optional: Deep copy to prevent external modifications
this.data = data.map(row => [...row]);
this.rows = data.length;
this.cols = (this.rows > 0) ? (data[0]?.length ?? 0) : 0; // Handle empty rows gracefully
// Optional: Validate that all rows have the same length
if (this.rows > 0) {
const firstRowLength = this.cols;
for (let i = 1; i < this.rows; i++) {
if (data[i].length !== firstRowLength) {
throw new Error("All rows in the matrix must have the same length.");
}
}
}
}
/** @returns {number} */
getRows() {
return this.rows;
}
/** @returns {number} */
getCols() {
return this.cols;
}
/** @param {Matrix} other */
validateDimensions(other) {
if (this.getRows() !== other.getRows() || this.getCols() !== other.getCols()) {
throw new Error("Matrices must have the same dimensions.");
}
}
/** @param {Matrix} other */
validateMultiplication(other) {
if (this.getCols() !== other.getRows()) {
throw new Error(`Cannot multiply matrices: (${this.getRows()}x${this.getCols()}) * (${other.getRows()}x${other.getCols()})`);
}
}
validateSquarePowerOfTwo() {
if (this.getRows() !== this.getCols()) {
throw new Error("Matrix must be square for this operation.");
}
const n = this.getRows();
// Check if n is 0 or not a power of two
// (n & (n - 1)) === 0 checks if n is a power of two (or 0)
if (n === 0 || (n & (n - 1)) !== 0) {
throw new Error("Size of matrix must be a power of two for Strassen.");
}
}
/**
* Adds another matrix to this matrix.
* @param {Matrix} other The matrix to add.
* @returns {Matrix} A new matrix representing the sum.
*/
add(other) {
this.validateDimensions(other);
const result_data = Array.from({ length: this.rows }, () => Array(this.cols).fill(0.0));
for (let i = 0; i < this.rows; ++i) {
for (let j = 0; j < this.cols; ++j) {
result_data[i][j] = this.data[i][j] + other.data[i][j];
}
}
return new Matrix(result_data);
}
/**
* Subtracts another matrix from this matrix.
* @param {Matrix} other The matrix to subtract.
* @returns {Matrix} A new matrix representing the difference.
*/
subtract(other) {
this.validateDimensions(other);
const result_data = Array.from({ length: this.rows }, () => Array(this.cols).fill(0.0));
for (let i = 0; i < this.rows; ++i) {
for (let j = 0; j < this.cols; ++j) {
result_data[i][j] = this.data[i][j] - other.data[i][j];
}
}
return new Matrix(result_data);
}
/**
* Multiplies this matrix by another matrix (standard algorithm).
* @param {Matrix} other The matrix to multiply by.
* @returns {Matrix} A new matrix representing the product.
*/
multiply(other) {
this.validateMultiplication(other);
const result_data = Array.from({ length: this.rows }, () => Array(other.cols).fill(0.0));
for (let i = 0; i < this.rows; ++i) {
for (let j = 0; j < other.cols; ++j) {
let sum = 0.0;
// K loops through columns of 'this' and rows of 'other'
for (let k = 0; k < this.cols; ++k) {
sum += this.data[i][k] * other.data[k][j];
}
result_data[i][j] = sum;
}
}
return new Matrix(result_data);
}
/**
* Returns a string representation of the matrix.
* @returns {string}
*/
toString() {
return this.data.map(row => `[${row.join(', ')}]`).join('\n');
}
/**
* Returns a string representation with specified precision, handling rounding and "-0".
* @param {number} p Precision (number of decimal places).
* @returns {string}
*/
toStringWithPrecision(p) {
let resultString = "";
const pow = Math.pow(10, p);
const zeroString = (0).toFixed(p);
const negZeroString = `-${zeroString}`;
for (const row of this.data) {
resultString += "[";
for (let i = 0; i < row.length; ++i) {
let val = row[i];
// Round like C++: round(val * 10^p) / 10^p
let roundedVal = Math.round(val * pow) / pow;
// Format to fixed precision
let formattedVal = roundedVal.toFixed(p);
// Handle the "-0.00..." case that toFixed might produce after rounding
if (formattedVal === negZeroString) {
formattedVal = zeroString;
}
resultString += formattedVal;
if (i < row.length - 1) {
resultString += ", ";
}
}
resultString += "]\n"; // Add newline after each row like C++ example
}
return resultString.trimEnd(); // Remove trailing newline
}
/**
* Helper function to get quadrant slicing parameters.
* @param {number} r Half rows
* @param {number} c Half columns
* @returns {number[][]} Array of [startRow, endRow, startCol, endCol, offsetRow, offsetCol]
*/
static params(r, c) {
// [startRow, endRow, startCol, endCol, resultOffsetRow, resultOffsetCol]
return [
[0, r, 0, c, 0, 0], // Top-left quadrant (0)
[0, r, c, 2 * c, 0, c], // Top-right quadrant (1)
[r, 2 * r, 0, c, r, 0], // Bottom-left quadrant (2)
[r, 2 * r, c, 2 * c, r, c] // Bottom-right quadrant (3)
];
}
/**
* Splits the matrix into four equally sized quadrants.
* Assumes matrix dimensions are even.
* @returns {Matrix[]} An array of four matrices [TopLeft, TopRight, BottomLeft, BottomRight].
*/
toQuarters() {
const r = this.getRows() / 2;
const c = this.getCols() / 2;
if (!Number.isInteger(r) || !Number.isInteger(c)) {
throw new Error("Matrix dimensions must be even for splitting into quarters.");
}
const p = Matrix.params(r, c);
const quarters = Array(4); // Will hold 4 Matrix objects
for (let k = 0; k < 4; ++k) {
const q_data = Array.from({ length: r }, () => Array(c));
const [startRow, endRow, startCol, endCol, offsetRow, offsetCol] = p[k];
for (let i = startRow; i < endRow; ++i) {
for (let j = startCol; j < endCol; ++j) {
// Adjust indices for the smaller quarter matrix
q_data[i - offsetRow][j - offsetCol] = this.data[i][j];
}
}
quarters[k] = new Matrix(q_data);
}
return quarters; // [TopLeft, TopRight, BottomLeft, BottomRight]
}
/**
* Creates a new matrix by combining four quadrant matrices.
* @param {Matrix[]} q An array of four matrices [TopLeft, TopRight, BottomLeft, BottomRight].
* @returns {Matrix} The combined matrix.
*/
static fromQuarters(q) {
if (q.length !== 4) throw new Error("Requires exactly four quadrant matrices.");
// Basic validation: Ensure quadrants have compatible dimensions
const r = q[0].getRows();
const c = q[0].getCols();
if (q[1].getRows() !== r || q[1].getCols() !== c ||
q[2].getRows() !== r || q[2].getCols() !== c ||
q[3].getRows() !== r || q[3].getCols() !== c) {
throw new Error("Quadrant matrices must have the same dimensions.");
}
const p = Matrix.params(r, c);
const rows = r * 2;
const cols = c * 2;
const m_data = Array.from({ length: rows }, () => Array(cols));
for (let k = 0; k < 4; ++k) {
const [startRow, endRow, startCol, endCol, offsetRow, offsetCol] = p[k];
for (let i = startRow; i < endRow; ++i) {
for (let j = startCol; j < endCol; ++j) {
// Adjust indices to read from the correct quadrant
m_data[i][j] = q[k].data[i - offsetRow][j - offsetCol];
}
}
}
return new Matrix(m_data);
}
/**
* Multiplies this matrix by another using Strassen's algorithm.
* Assumes both matrices are square and their size is a power of two.
* @param {Matrix} other The matrix to multiply by.
* @returns {Matrix} The resulting matrix product.
*/
strassen(other) {
this.validateSquarePowerOfTwo();
other.validateSquarePowerOfTwo();
if (this.getRows() !== other.getRows()) { // Columns already checked by validateSquarePowerOfTwo
throw new Error("Matrices must be square and of equal size for Strassen multiplication.");
}
// Base case: If the matrix is 1x1
if (this.getRows() === 1) {
// Use standard multiplication for the 1x1 case
return this.multiply(other);
}
// Split matrices into quarters
const qa = this.toQuarters(); // [a11, a12, a21, a22]
const qb = other.toQuarters(); // [b11, b12, b21, b22]
// Calculate the 7 products recursively (P1 to P7)
const p1 = (qa[1].subtract(qa[3])).strassen(qb[2].add(qb[3])); // p1 = (a12 - a22) * (b21 + b22)
const p2 = (qa[0].add(qa[3])).strassen(qb[0].add(qb[3])); // p2 = (a11 + a22) * (b11 + b22)
const p3 = (qa[0].subtract(qa[2])).strassen(qb[0].add(qb[1])); // p3 = (a11 - a21) * (b11 + b12)
const p4 = (qa[0].add(qa[1])).strassen(qb[3]); // p4 = (a11 + a12) * b22
const p5 = qa[0].strassen(qb[1].subtract(qb[3])); // p5 = a11 * (b12 - b22)
const p6 = qa[3].strassen(qb[2].subtract(qb[0])); // p6 = a22 * (b21 - b11)
const p7 = (qa[2].add(qa[3])).strassen(qb[0]); // p7 = (a21 + a22) * b11
// Calculate the result quarters (C11, C12, C21, C22)
const c11 = p1.add(p2).subtract(p4).add(p6);
const c12 = p4.add(p5);
const c21 = p6.add(p7);
const c22 = p2.subtract(p3).add(p5).subtract(p7);
// Combine the quarters into the result matrix
return Matrix.fromQuarters([c11, c12, c21, c22]);
}
}
// --- Main execution (equivalent to C++ main) ---
function main() {
const a = new Matrix([[1.0, 2.0], [3.0, 4.0]]);
const b = new Matrix([[5.0, 6.0], [7.0, 8.0]]);
const c = new Matrix([[1.0, 1.0, 1.0, 1.0], [2.0, 4.0, 8.0, 16.0], [3.0, 9.0, 27.0, 81.0], [4.0, 16.0, 64.0, 256.0]]);
const d = new Matrix([[4.0, -3.0, 4.0 / 3.0, -1.0 / 4.0], [-13.0 / 3.0, 19.0 / 4.0, -7.0 / 3.0, 11.0 / 24.0], [3.0 / 2.0, -2.0, 7.0 / 6.0, -1.0 / 4.0], [-1.0 / 6.0, 1.0 / 4.0, -1.0 / 6.0, 1.0 / 24.0]]);
const e = new Matrix([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0], [9.0, 10.0, 11.0, 12.0], [13.0, 14.0, 15.0, 16.0]]);
const f = new Matrix([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]); // Identity Matrix
console.log("Using 'normal' matrix multiplication:");
console.log(` a * b = \n${a.multiply(b).toString()}`);
console.log(`\n c * d = \n${c.multiply(d).toStringWithPrecision(6)}`); // Should be close to identity
console.log(`\n e * f = \n${e.multiply(f).toString()}`); // Should be e
console.log("\nUsing 'Strassen' matrix multiplication:");
try {
console.log(` a * b = \n${a.strassen(b).toString()}`);
console.log(`\n c * d = \n${c.strassen(d).toStringWithPrecision(6)}`); // Should be close to identity
console.log(`\n e * f = \n${e.strassen(f).toString()}`); // Should be e
} catch (error) {
console.error("Strassen multiplication failed:", error.message);
}
}
// Run the main function
main();