317 lines
12 KiB
JavaScript
317 lines
12 KiB
JavaScript
/**
 * Dense, row-major matrix of numbers offering element-wise addition and
 * subtraction, the standard triple-loop product, and Strassen's recursive
 * product for square matrices whose size is a power of two.
 *
 * Every arithmetic method returns a new Matrix; none of the methods defined
 * here write to the receiver or to their arguments.
 */
class Matrix {
  /** @type {number[][]} Row-by-row copy of the values given to the constructor. */
  data;

  /** @type {number} Number of rows. */
  rows;

  /** @type {number} Number of columns (0 when the matrix has no rows). */
  cols;

  /**
   * Builds a matrix from a 2D array. The rows are copied, so later changes to
   * the caller's arrays do not affect the matrix.
   *
   * @param {number[][]} data The matrix data as a 2D array.
   * @throws {Error} If `data` is not an array, if any row is not an array, or
   *   if the rows do not all have the same length.
   */
  constructor(data) {
    if (!Array.isArray(data)) {
      throw new Error("Input data must be a 2D array.");
    }
    // Check every row, not only the first: a string row would otherwise be
    // spread into single characters and a null/undefined row would surface
    // as an unrelated TypeError. for...of also visits holes (as undefined).
    for (const row of data) {
      if (!Array.isArray(row)) {
        throw new Error("Input data must be a 2D array.");
      }
    }

    const width = data.length > 0 ? data[0].length : 0;
    for (const row of data) {
      if (row.length !== width) {
        throw new Error("All rows in the matrix must have the same length.");
      }
    }

    // Copy row by row so the matrix does not alias the caller's arrays.
    this.data = data.map((row) => [...row]);
    this.rows = data.length;
    this.cols = width;
  }

  /** @returns {number} The number of rows. */
  getRows() {
    return this.rows;
  }

  /** @returns {number} The number of columns. */
  getCols() {
    return this.cols;
  }

  /**
   * Ensures `other` has exactly the same shape as this matrix.
   *
   * @param {Matrix} other The matrix to compare against.
   * @throws {Error} If the row or column counts differ.
   */
  validateDimensions(other) {
    const sameShape =
      this.getRows() === other.getRows() && this.getCols() === other.getCols();
    if (!sameShape) {
      throw new Error("Matrices must have the same dimensions.");
    }
  }

  /**
   * Ensures `this * other` is defined, i.e. the column count of this matrix
   * equals the row count of `other`.
   *
   * @param {Matrix} other The right-hand operand.
   * @throws {Error} If the inner dimensions do not match.
   */
  validateMultiplication(other) {
    if (this.getCols() !== other.getRows()) {
      throw new Error(`Cannot multiply matrices: (${this.getRows()}x${this.getCols()}) * (${other.getRows()}x${other.getCols()})`);
    }
  }

  /**
   * Ensures this matrix is square and that its size is a power of two, as
   * required by {@link Matrix#strassen}.
   *
   * @throws {Error} If the matrix is not square, is empty, or its size is not
   *   a power of two.
   */
  validateSquarePowerOfTwo() {
    if (this.getRows() !== this.getCols()) {
      throw new Error("Matrix must be square for this operation.");
    }
    const n = this.getRows();
    // n & (n - 1) clears the lowest set bit, so it is 0 exactly when n has at
    // most one bit set; n === 0 is rejected separately.
    if (n === 0 || (n & (n - 1)) !== 0) {
      throw new Error("Size of matrix must be a power of two for Strassen.");
    }
  }

  /**
   * Adds another matrix to this matrix.
   *
   * @param {Matrix} other The matrix to add; must have the same dimensions.
   * @returns {Matrix} A new matrix representing the sum.
   * @throws {Error} If the dimensions differ.
   */
  add(other) {
    this.validateDimensions(other);
    const sum = this.data.map((row, i) =>
      row.map((value, j) => value + other.data[i][j])
    );
    return new Matrix(sum);
  }

  /**
   * Subtracts another matrix from this matrix.
   *
   * @param {Matrix} other The matrix to subtract; must have the same dimensions.
   * @returns {Matrix} A new matrix representing the difference.
   * @throws {Error} If the dimensions differ.
   */
  subtract(other) {
    this.validateDimensions(other);
    const difference = this.data.map((row, i) =>
      row.map((value, j) => value - other.data[i][j])
    );
    return new Matrix(difference);
  }

  /**
   * Multiplies this matrix by another matrix (standard algorithm).
   *
   * @param {Matrix} other The matrix to multiply by; its row count must equal
   *   this matrix's column count.
   * @returns {Matrix} A new matrix representing the product.
   * @throws {Error} If the inner dimensions do not match.
   */
  multiply(other) {
    this.validateMultiplication(other);

    const inner = this.cols;
    const product = this.data.map((row) =>
      Array.from({ length: other.cols }, (_, j) => {
        // k walks the columns of `this` and the rows of `other`.
        let sum = 0.0;
        for (let k = 0; k < inner; ++k) {
          sum += row[k] * other.data[k][j];
        }
        return sum;
      })
    );

    return new Matrix(product);
  }

  /**
   * Returns a string representation of the matrix: one bracketed,
   * comma-separated row per line.
   *
   * @returns {string}
   */
  toString() {
    return this.data.map((row) => `[${row.join(', ')}]`).join('\n');
  }

  /**
   * Returns a string representation with a fixed number of decimal places.
   * Values are rounded half away from zero and a rounded zero is never
   * printed with a minus sign.
   *
   * @param {number} p Precision (number of decimal places).
   * @returns {string} One bracketed row per line, without a trailing newline.
   */
  toStringWithPrecision(p) {
    const scale = Math.pow(10, p);
    const zeroString = (0).toFixed(p);
    const negZeroString = `-${zeroString}`;

    const formatValue = (val) => {
      // Math.round rounds ties toward +Infinity (-12.5 -> -12). Rounding the
      // magnitude and re-applying the sign gives round-half-away-from-zero,
      // so negative ties mirror positive ones.
      const rounded = (Math.sign(val) * Math.round(Math.abs(val) * scale)) / scale;
      const text = rounded.toFixed(p);
      // Guard against a "-0.00..." rendering of a value that rounded to zero.
      return text === negZeroString ? zeroString : text;
    };

    return this.data
      .map((row) => `[${row.map((val) => formatValue(val)).join(", ")}]`)
      .join("\n");
  }

  /**
   * Helper function to get quadrant slicing parameters for a matrix of
   * `2 * r` rows and `2 * c` columns.
   *
   * @param {number} r Half rows
   * @param {number} c Half columns
   * @returns {number[][]} Four entries, ordered top-left, top-right,
   *   bottom-left, bottom-right, each of the form
   *   [startRow, endRow, startCol, endCol, offsetRow, offsetCol] with
   *   exclusive end indices.
   */
  static params(r, c) {
    return [
      [0, r, 0, c, 0, 0], // Top-left quadrant (0)
      [0, r, c, 2 * c, 0, c], // Top-right quadrant (1)
      [r, 2 * r, 0, c, r, 0], // Bottom-left quadrant (2)
      [r, 2 * r, c, 2 * c, r, c], // Bottom-right quadrant (3)
    ];
  }

  /**
   * Splits the matrix into four equally sized quadrants.
   *
   * @returns {Matrix[]} An array of four matrices
   *   [TopLeft, TopRight, BottomLeft, BottomRight].
   * @throws {Error} If the row count or the column count is odd.
   */
  toQuarters() {
    const r = this.getRows() / 2;
    const c = this.getCols() / 2;
    if (!Number.isInteger(r) || !Number.isInteger(c)) {
      throw new Error("Matrix dimensions must be even for splitting into quarters.");
    }

    return Matrix.params(r, c).map(
      ([startRow, endRow, startCol, endCol]) =>
        new Matrix(
          this.data
            .slice(startRow, endRow)
            .map((row) => row.slice(startCol, endCol))
        )
    );
  }

  /**
   * Creates a new matrix by combining four quadrant matrices.
   *
   * @param {Matrix[]} q An array of four matrices
   *   [TopLeft, TopRight, BottomLeft, BottomRight].
   * @returns {Matrix} The combined matrix, twice as tall and twice as wide as
   *   one quadrant.
   * @throws {Error} If `q` does not hold exactly four matrices or the
   *   quadrants differ in shape.
   */
  static fromQuarters(q) {
    if (q.length !== 4) throw new Error("Requires exactly four quadrant matrices.");

    const [topLeft, topRight, bottomLeft, bottomRight] = q;
    const r = topLeft.getRows();
    const c = topLeft.getCols();
    const mismatched = [topRight, bottomLeft, bottomRight].some(
      (quadrant) => quadrant.getRows() !== r || quadrant.getCols() !== c
    );
    if (mismatched) {
      throw new Error("Quadrant matrices must have the same dimensions.");
    }

    // Glue each left quadrant row to the matching right quadrant row, then
    // stack the top band above the bottom band.
    const joinBand = (left, right) =>
      left.data.map((row, i) => [...row, ...right.data[i]]);

    return new Matrix([
      ...joinBand(topLeft, topRight),
      ...joinBand(bottomLeft, bottomRight),
    ]);
  }

  /**
   * Multiplies this matrix by another using Strassen's algorithm: seven
   * recursive half-size products are combined into the four result quadrants.
   *
   * @param {Matrix} other The matrix to multiply by.
   * @returns {Matrix} The resulting matrix product.
   * @throws {Error} If either matrix is not square with a power-of-two size,
   *   or if the two sizes differ.
   */
  strassen(other) {
    this.validateSquarePowerOfTwo();
    other.validateSquarePowerOfTwo();
    // Both operands are square at this point, so comparing rows is enough.
    if (this.getRows() !== other.getRows()) {
      throw new Error("Matrices must be square and of equal size for Strassen multiplication.");
    }

    // Base case: a 1x1 product is handled by the standard algorithm.
    if (this.getRows() === 1) {
      return this.multiply(other);
    }

    const [a11, a12, a21, a22] = this.toQuarters();
    const [b11, b12, b21, b22] = other.toQuarters();

    // The seven recursive products.
    const p1 = a12.subtract(a22).strassen(b21.add(b22)); // (a12 - a22) * (b21 + b22)
    const p2 = a11.add(a22).strassen(b11.add(b22)); // (a11 + a22) * (b11 + b22)
    const p3 = a11.subtract(a21).strassen(b11.add(b12)); // (a11 - a21) * (b11 + b12)
    const p4 = a11.add(a12).strassen(b22); // (a11 + a12) * b22
    const p5 = a11.strassen(b12.subtract(b22)); // a11 * (b12 - b22)
    const p6 = a22.strassen(b21.subtract(b11)); // a22 * (b21 - b11)
    const p7 = a21.add(a22).strassen(b11); // (a21 + a22) * b11

    // Result quadrants.
    const c11 = p1.add(p2).subtract(p4).add(p6);
    const c12 = p4.add(p5);
    const c21 = p6.add(p7);
    const c22 = p2.subtract(p3).add(p5).subtract(p7);

    return Matrix.fromQuarters([c11, c12, c21, c22]);
  }
}
|
|
|
|
// --- Main execution (equivalent to C++ main) ---
|
|
/**
 * Demo driver: multiplies three fixed pairs of matrices, first with the
 * standard algorithm and then with Strassen's algorithm, and logs each
 * product to the console.
 *
 * Only the Strassen section is wrapped in a try/catch; an error raised there
 * is reported via console.error instead of propagating.
 */
function main() {
  const a = new Matrix([
    [1.0, 2.0],
    [3.0, 4.0],
  ]);
  const b = new Matrix([
    [5.0, 6.0],
    [7.0, 8.0],
  ]);
  const c = new Matrix([
    [1.0, 1.0, 1.0, 1.0],
    [2.0, 4.0, 8.0, 16.0],
    [3.0, 9.0, 27.0, 81.0],
    [4.0, 16.0, 64.0, 256.0],
  ]);
  const d = new Matrix([
    [4.0, -3.0, 4.0 / 3.0, -1.0 / 4.0],
    [-13.0 / 3.0, 19.0 / 4.0, -7.0 / 3.0, 11.0 / 24.0],
    [3.0 / 2.0, -2.0, 7.0 / 6.0, -1.0 / 4.0],
    [-1.0 / 6.0, 1.0 / 4.0, -1.0 / 6.0, 1.0 / 24.0],
  ]);
  const e = new Matrix([
    [1.0, 2.0, 3.0, 4.0],
    [5.0, 6.0, 7.0, 8.0],
    [9.0, 10.0, 11.0, 12.0],
    [13.0, 14.0, 15.0, 16.0],
  ]);
  // 4x4 identity matrix.
  const f = new Matrix([
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0],
    [0.0, 0.0, 1.0, 0.0],
    [0.0, 0.0, 0.0, 1.0],
  ]);

  console.log("Using 'normal' matrix multiplication:");

  const abPlain = a.multiply(b);
  console.log(` a * b = \n${abPlain.toString()}`);

  // Expected to come out close to the identity matrix.
  const cdPlain = c.multiply(d);
  console.log(`\n c * d = \n${cdPlain.toStringWithPrecision(6)}`);

  // Multiplying by the identity should reproduce e.
  const efPlain = e.multiply(f);
  console.log(`\n e * f = \n${efPlain.toString()}`);

  console.log("\nUsing 'Strassen' matrix multiplication:");
  try {
    // Each product is logged before the next one is computed, so anything
    // printed ahead of a failure stays on the console.
    const abFast = a.strassen(b);
    console.log(` a * b = \n${abFast.toString()}`);

    const cdFast = c.strassen(d);
    console.log(`\n c * d = \n${cdFast.toStringWithPrecision(6)}`);

    const efFast = e.strassen(f);
    console.log(`\n e * f = \n${efFast.toString()}`);
  } catch (error) {
    console.error("Strassen multiplication failed:", error.message);
  }
}
|
|
|
|
// Run the main function
|
|
// Called unconditionally at the top level, so the demo runs (and prints its
// output) every time this file is executed.
main();
|