RosettaCodeData/Task/Strassens-algorithm/JavaScript/strassens-algorithm-1.js

319 lines
10 KiB
JavaScript

/**
* Represents the dimensions of a matrix.
* @typedef {object} Shape
* @property {number} rows - Number of rows.
* @property {number} cols - Number of columns.
*/
/**
 * A numeric matrix implemented as a thin wrapper around a 2D array.
 *
 * The array passed to the constructor is stored by reference (it is not
 * copied). Every arithmetic method returns a new Matrix and leaves its
 * operands unmodified.
 */
class Matrix {
  /**
   * Creates a Matrix instance.
   * @param {number[][]} data - A 2D array representing the matrix data.
   * @throws {Error} If data is not an array, if any row is not an array,
   *   or if the rows do not all have the same length.
   */
  constructor(data = []) {
    if (!Array.isArray(data)) {
      throw new Error("Matrix data must be a 2D array.");
    }
    // Index loop on purpose: Array.prototype.every skips holes in sparse
    // arrays, which would let a missing row slip through validation.
    for (let i = 0; i < data.length; i++) {
      const row = data[i];
      if (!Array.isArray(row)) {
        throw new Error("Matrix data must be a 2D array.");
      }
      if (row.length !== data[0].length) {
        throw new Error("Matrix rows must have consistent lengths.");
      }
    }
    this.data = data;
  }

  /**
   * Gets the dimensions (shape) of the matrix.
   * @returns {Shape} An object with rows and cols properties; an empty
   *   matrix reports 0 rows and 0 cols.
   */
  get shape() {
    const rows = this.data.length;
    const cols = rows > 0 ? this.data[0].length : 0;
    return { rows, cols };
  }

  /**
   * Throws unless the given value is a Matrix.
   * @param {*} value - The value to check.
   * @throws {Error} If value is not a Matrix instance.
   * @private
   */
  static _assertMatrix(value) {
    if (!(value instanceof Matrix)) {
      throw new Error("Argument must be a Matrix instance.");
    }
  }

  /**
   * Throws unless both matrices have identical dimensions.
   * @param {Matrix} a - First matrix.
   * @param {Matrix} b - Second matrix.
   * @param {string} operation - Operation name used in the error message.
   * @throws {Error} If the shapes differ.
   * @private
   */
  static _assertSameShape(a, b, operation) {
    const aShape = a.shape;
    const bShape = b.shape;
    if (aShape.rows !== bShape.rows || aShape.cols !== bShape.cols) {
      throw new Error(`Matrices must have the same shape for ${operation}.`);
    }
  }

  /**
   * Creates a new Matrix assembled from nested blocks of matrices.
   * Each inner array is one horizontal band: its matrices are placed side
   * by side, and the bands are stacked top to bottom. Empty or missing
   * bands are skipped.
   * @param {Matrix[][]} blocks - A 2D array of Matrix objects.
   * @returns {Matrix} A new Matrix assembled from the blocks.
   * @throws {Error} If the matrices within one band do not all have the
   *   same number of rows, or if the assembled rows differ in length.
   * @static
   */
  static block(blocks) {
    const newMatrixData = [];
    for (const hblock of blocks) {
      if (!hblock || hblock.length === 0) continue;
      const numRowsInBlock = hblock[0].shape.rows;

      // Validate the whole band up front so that a taller block is never
      // silently truncated and a shorter one never yields a ragged row.
      for (const matrix of hblock) {
        if (matrix.data.length !== numRowsInBlock) {
          throw new Error(
            `Inconsistent row count during block assembly: expected ${numRowsInBlock} rows, got ${matrix.data.length}.`
          );
        }
      }

      for (let i = 0; i < numRowsInBlock; i++) {
        const newRow = [];
        for (const matrix of hblock) {
          // Append in place instead of re-allocating the row via concat.
          for (const value of matrix.data[i]) {
            newRow.push(value);
          }
        }
        newMatrixData.push(newRow);
      }
    }
    return new Matrix(newMatrixData);
  }

  /**
   * Performs naive (triple-loop) matrix multiplication.
   * @param {Matrix} b - The matrix to multiply with.
   * @returns {Matrix} The resulting matrix product.
   * @throws {Error} If b is not a Matrix, or if this matrix's column count
   *   differs from b's row count.
   */
  dot(b) {
    Matrix._assertMatrix(b);
    const aShape = this.shape;
    const bShape = b.shape;
    if (aShape.cols !== bShape.rows) {
      throw new Error(`Matrices incompatible for multiplication: ${aShape.cols} cols != ${bShape.rows} rows`);
    }
    const resultData = [];
    for (let i = 0; i < aShape.rows; i++) {
      const aRow = this.data[i];
      const outRow = [];
      for (let j = 0; j < bShape.cols; j++) {
        let sum = 0;
        for (let k = 0; k < aShape.cols; k++) {
          sum += aRow[k] * b.data[k][j];
        }
        outRow.push(sum);
      }
      resultData.push(outRow);
    }
    return new Matrix(resultData);
  }

  /**
   * Multiplies this matrix by another matrix (using naive multiplication).
   * Equivalent to Python's __matmul__.
   * @param {Matrix} b - The matrix to multiply with.
   * @returns {Matrix} The resulting matrix product.
   * @throws {Error} Under the same conditions as dot().
   */
  multiply(b) {
    return this.dot(b);
  }

  /**
   * Adds another matrix to this matrix, element by element.
   * Equivalent to Python's __add__.
   * @param {Matrix} b - The matrix to add.
   * @returns {Matrix} The resulting matrix sum.
   * @throws {Error} If b is not a Matrix or the shapes differ.
   */
  add(b) {
    Matrix._assertMatrix(b);
    Matrix._assertSameShape(this, b, "addition");
    const resultData = this.data.map((row, i) =>
      row.map((val, j) => val + b.data[i][j])
    );
    return new Matrix(resultData);
  }

  /**
   * Subtracts another matrix from this matrix, element by element.
   * Equivalent to Python's __sub__.
   * @param {Matrix} b - The matrix to subtract.
   * @returns {Matrix} The resulting matrix difference.
   * @throws {Error} If b is not a Matrix or the shapes differ.
   */
  subtract(b) {
    Matrix._assertMatrix(b);
    Matrix._assertSameShape(this, b, "subtraction");
    const resultData = this.data.map((row, i) =>
      row.map((val, j) => val - b.data[i][j])
    );
    return new Matrix(resultData);
  }

  /**
   * Helper function to copy a rectangular region of the matrix.
   * @param {number} rowStart - Starting row index (inclusive).
   * @param {number} rowEnd - Ending row index (exclusive).
   * @param {number} colStart - Starting column index (inclusive).
   * @param {number} colEnd - Ending column index (exclusive).
   * @returns {Matrix} A new Matrix containing the sliced data.
   * @private // Indicates intended internal use
   */
  _slice(rowStart, rowEnd, colStart, colEnd) {
    const slicedData = this.data
      .slice(rowStart, rowEnd)
      .map(row => row.slice(colStart, colEnd));
    return new Matrix(slicedData);
  }

  /**
   * Performs matrix multiplication using Strassen's algorithm.
   * Requires square matrices of equal size whose dimension is a power of 2.
   * Recurses on quadrants down to 1x1 matrices, which are multiplied with dot().
   * @param {Matrix} b - The matrix to multiply with.
   * @returns {Matrix} The resulting matrix product.
   * @throws {Error} If b is not a Matrix, if this matrix is not square, if
   *   the shapes differ, or if the dimension is 0 or not a power of 2.
   */
  strassen(b) {
    Matrix._assertMatrix(b);
    const n = this.shape.rows;
    if (n !== this.shape.cols) {
      throw new Error("Matrix must be square for Strassen's algorithm.");
    }
    Matrix._assertSameShape(this, b, "Strassen's algorithm");
    // A power of two has exactly one bit set, so n & (n - 1) is 0 only then.
    if (n === 0 || (n & (n - 1)) !== 0) {
      throw new Error("Matrix dimension must be a power of 2 for Strassen's algorithm.");
    }
    if (n === 1) {
      return this.dot(b); // Base case
    }

    const p = n / 2; // Partition size

    // Partition both operands into four p x p quadrants.
    const a11 = this._slice(0, p, 0, p);
    const a12 = this._slice(0, p, p, n);
    const a21 = this._slice(p, n, 0, p);
    const a22 = this._slice(p, n, p, n);
    const b11 = b._slice(0, p, 0, p);
    const b12 = b._slice(0, p, p, n);
    const b21 = b._slice(p, n, 0, p);
    const b22 = b._slice(p, n, p, n);

    // Strassen's 7 recursive products (instead of the naive 8).
    const m1 = a11.add(a22).strassen(b11.add(b22));
    const m2 = a21.add(a22).strassen(b11);
    const m3 = a11.strassen(b12.subtract(b22));
    const m4 = a22.strassen(b21.subtract(b11));
    const m5 = a11.add(a12).strassen(b22);
    const m6 = a21.subtract(a11).strassen(b11.add(b12));
    const m7 = a12.subtract(a22).strassen(b21.add(b22));

    // Combine the products into the quadrants of the result.
    const c11 = m1.add(m4).subtract(m5).add(m7);
    const c12 = m3.add(m5);
    const c21 = m2.add(m4);
    const c22 = m1.subtract(m2).add(m3).add(m6);

    // Assemble the final matrix from blocks
    return Matrix.block([[c11, c12], [c21, c22]]);
  }

  /**
   * Rounds the elements of the matrix to a specified number of decimal places.
   * @param {number} [ndigits=0] - Number of decimal places to round to. Any
   *   value that is not greater than 0 rounds to the nearest integer.
   * @returns {Matrix} A new Matrix with rounded elements.
   */
  round(ndigits = 0) {
    let roundFn;
    if (ndigits > 0) {
      const factor = Math.pow(10, ndigits);
      // Number.EPSILON nudges values sitting just below a .5 boundary due to
      // binary floating-point representation.
      roundFn = (num) => Math.round((num + Number.EPSILON) * factor) / factor;
    } else {
      roundFn = (num) => Math.round(num);
    }
    const roundedData = this.data.map(row => row.map(val => roundFn(val)));
    return new Matrix(roundedData);
  }

  /**
   * Provides a string representation of the matrix, one row per line.
   * @returns {string} The string representation.
   */
  toString() {
    const rowsStr = this.data.map(row => ` [${row.join(', ')}]`);
    return `Matrix([\n${rowsStr.join(',\n')}\n])`;
  }
}
// --- Examples ---
/**
 * Demonstrates the Matrix class by printing sample results to the console:
 * naive and Strassen products of three matrix pairs, an element-wise sum
 * and difference, and a matrix assembled from blocks.
 */
function examples() {
  const a = new Matrix([[1, 2], [3, 4]]);
  const b = new Matrix([[5, 6], [7, 8]]);
  const c = new Matrix([
    [1, 1, 1, 1],
    [2, 4, 8, 16],
    [3, 9, 27, 81],
    [4, 16, 64, 256],
  ]);
  const d = new Matrix([
    [4, -3, 4 / 3, -1 / 4],
    [-13 / 3, 19 / 4, -7 / 3, 11 / 24],
    [3 / 2, -2, 7 / 6, -1 / 4],
    [-1 / 6, 1 / 4, -1 / 6, 1 / 24],
  ]);
  const e = new Matrix([
    [1, 2, 3, 4],
    [5, 6, 7, 8],
    [9, 10, 11, 12],
    [13, 14, 15, 16],
  ]);
  // Identity matrix
  const f = new Matrix([
    [1, 0, 0, 0],
    [0, 1, 0, 0],
    [0, 0, 1, 0],
    [0, 0, 0, 1],
  ]);

  // Operand pairs shared by both multiplication demos. `digits` is set only
  // where the product contains floating-point noise that should be rounded.
  const productCases = [
    { label: "a * b", lhs: a, rhs: b },
    { label: "c * d", lhs: c, rhs: d, digits: 2 },
    { label: "e * f", lhs: e, rhs: f },
  ];

  // Prints a heading followed by one line per operand pair.
  const showProducts = (heading, multiplyFn) => {
    console.log(heading);
    for (const { label, lhs, rhs, digits } of productCases) {
      const product = multiplyFn(lhs, rhs);
      const shown = digits === undefined ? product : product.round(digits);
      console.log(` ${label} = ${shown}`); // toString is applied implicitly
    }
  };

  showProducts("Naive matrix multiplication:", (x, y) => x.multiply(y));
  showProducts("\nStrassen's matrix multiplication:", (x, y) => x.strassen(y));

  // Element-wise addition and subtraction
  console.log("\nAddition/Subtraction:");
  const total = a.add(b);
  console.log(` a + b = ${total}`);
  const difference = b.subtract(a);
  console.log(` b - a = ${difference}`);

  // Block creation: a 4x4 matrix built from four 2x2 matrices
  console.log("\nBlock Creation:");
  const assembled = Matrix.block([[a, b], [b, a]]);
  console.log(` Blocked [a,b],[b,a] = ${assembled}`);
}
// Run the demonstration as soon as the script is evaluated; examples()
// writes all of its results to the console.
examples();