const std = @import("std");
const fmt = std.fmt;
const ArrayList = std.ArrayList;
const Allocator = std.mem.Allocator;

/// Dense matrix of `f64` stored as a list of row lists.
///
/// `rows` and `cols` cache the dimensions of `data`. `allocator` is the
/// allocator used for every matrix produced by the operations below.
/// `deinit` frees every row list and the outer list.
const Matrix = struct {
    data: ArrayList(ArrayList(f64)),
    rows: usize,
    cols: usize,
    allocator: Allocator,

    const Row = ArrayList(f64);
    const RowList = ArrayList(Row);
    const ElementOp = enum { add, sub };

    /// Wraps an already-populated list of rows.
    ///
    /// The column count is taken from the first row (0 for an empty list).
    /// Returns `error.RaggedRows` when the rows differ in length; in that
    /// case `data` is left untouched and still belongs to the caller.
    /// On success the returned matrix frees `data` in `deinit`.
    pub fn init(allocator: Allocator, data: ArrayList(ArrayList(f64))) !Matrix {
        const rows = data.items.len;
        const cols = if (rows > 0) data.items[0].items.len else 0;
        for (data.items) |row| {
            if (row.items.len != cols) return error.RaggedRows;
        }
        return Matrix{
            .data = data,
            .rows = rows,
            .cols = cols,
            .allocator = allocator,
        };
    }

    /// Builds a matrix by copying `rows`; nothing is leaked when an
    /// allocation fails or the rows are ragged (`error.RaggedRows`).
    pub fn fromSlices(allocator: Allocator, rows: []const []const f64) !Matrix {
        var data = try RowList.initCapacity(allocator, rows.len);
        errdefer freeRows(&data);
        for (rows) |values| {
            var row = try Row.initCapacity(allocator, values.len);
            row.appendSliceAssumeCapacity(values);
            data.appendAssumeCapacity(row);
        }
        return Matrix.init(allocator, data);
    }

    /// Releases every row and then the outer list.
    pub fn deinit(self: *Matrix) void {
        freeRows(&self.data);
    }

    /// Frees each row list in `data` followed by `data` itself.
    fn freeRows(data: *RowList) void {
        for (data.items) |*row| row.deinit();
        data.deinit();
    }

    /// Allocates a `rows` x `cols` matrix filled with 0.0.
    /// Everything allocated so far is released if an allocation fails.
    fn zeros(allocator: Allocator, rows: usize, cols: usize) !Matrix {
        var data = try RowList.initCapacity(allocator, rows);
        errdefer freeRows(&data);
        for (0..rows) |_| {
            var row = try Row.initCapacity(allocator, cols);
            // Capacity was reserved above, so these cannot fail.
            row.appendNTimesAssumeCapacity(0.0, cols);
            data.appendAssumeCapacity(row);
        }
        return Matrix{
            .data = data,
            .rows = rows,
            // Same convention as `init`: a matrix without rows has no columns.
            .cols = if (rows > 0) cols else 0,
            .allocator = allocator,
        };
    }

    /// Returns a deep copy allocated with `self.allocator`.
    /// The caller must `deinit` the copy.
    pub fn clone(self: Matrix) !Matrix {
        var copy = try RowList.initCapacity(self.allocator, self.data.items.len);
        errdefer freeRows(&copy);
        for (self.data.items) |row| {
            var new_row = try Row.initCapacity(self.allocator, row.items.len);
            new_row.appendSliceAssumeCapacity(row.items);
            copy.appendAssumeCapacity(new_row);
        }
        return Matrix{
            .data = copy,
            .rows = self.rows,
            .cols = self.cols,
            .allocator = self.allocator,
        };
    }

    /// Returns `error.DimensionMismatch` unless both matrices have the
    /// same number of rows and columns.
    pub fn validateDimensions(self: Matrix, other: Matrix) !void {
        if (self.rows != other.rows or self.cols != other.cols) {
            return error.DimensionMismatch;
        }
    }

    /// Returns `error.CannotMultiply` unless `self.cols == other.rows`.
    pub fn validateMultiplication(self: Matrix, other: Matrix) !void {
        if (self.cols != other.rows) {
            return error.CannotMultiply;
        }
    }

    /// Returns `error.NotSquare` for a non-square matrix and
    /// `error.NotPowerOfTwo` when the side is 0 or not a power of two.
    pub fn validateSquarePowerOfTwo(self: Matrix) !void {
        if (self.rows != self.cols) {
            return error.NotSquare;
        }
        // n & (n - 1) clears the lowest set bit; it is 0 only for powers of two.
        if (self.rows == 0 or (self.rows & (self.rows - 1)) != 0) {
            return error.NotPowerOfTwo;
        }
    }

    /// Shared body of `add` and `sub`: applies `op` element by element.
    fn elementwise(self: Matrix, other: Matrix, comptime op: ElementOp) !Matrix {
        try self.validateDimensions(other);
        const result = try zeros(self.allocator, self.rows, self.cols);
        for (0..self.rows) |i| {
            const lhs = self.data.items[i].items;
            const rhs = other.data.items[i].items;
            const out = result.data.items[i].items;
            for (0..self.cols) |j| {
                out[j] = switch (op) {
                    .add => lhs[j] + rhs[j],
                    .sub => lhs[j] - rhs[j],
                };
            }
        }
        return result;
    }

    /// Element-wise sum as a new matrix the caller must `deinit`.
    /// Fails with `error.DimensionMismatch` or an allocation error.
    pub fn add(self: Matrix, other: Matrix) !Matrix {
        return self.elementwise(other, .add);
    }

    /// Element-wise difference `self - other` as a new matrix the caller
    /// must `deinit`. Fails with `error.DimensionMismatch` or an
    /// allocation error.
    pub fn sub(self: Matrix, other: Matrix) !Matrix {
        return self.elementwise(other, .sub);
    }

    /// Row-by-column product `self * other` as a new matrix the caller must
    /// `deinit`. Fails with `error.CannotMultiply` or an allocation error.
    pub fn mul(self: Matrix, other: Matrix) !Matrix {
        try self.validateMultiplication(other);
        const result = try zeros(self.allocator, self.rows, other.cols);
        for (0..self.rows) |i| {
            const lhs = self.data.items[i].items;
            const out = result.data.items[i].items;
            for (0..other.cols) |j| {
                var sum: f64 = 0.0;
                for (0..self.cols) |k| {
                    sum += lhs[k] * other.data.items[k].items[j];
                }
                out[j] = sum;
            }
        }
        return result;
    }

    /// `std.fmt` hook: writes each row with the `{any}` specifier followed
    /// by a newline.
    pub fn format(self: Matrix, comptime _: []const u8, _: fmt.FormatOptions, writer: anytype) !void {
        for (self.data.items) |row| {
            try writer.print("{any}\n", .{row.items});
        }
    }

    /// Renders the matrix as text, one `[v1 v2 ...]` line per row, with
    /// every value rounded to `p` decimal places.
    ///
    /// A value that rounds to negative zero is printed as `0`. The caller
    /// owns the returned slice and frees it with `allocator`.
    pub fn toStringWithPrecision(self: Matrix, p: usize, allocator: Allocator) ![]u8 {
        var output = ArrayList(u8).init(allocator);
        errdefer output.deinit();
        const writer = output.writer();
        const pow = std.math.pow(f64, 10.0, @as(f64, @floatFromInt(p)));
        for (self.data.items) |row| {
            try writer.writeByte('[');
            for (row.items, 0..) |val, idx| {
                if (idx > 0) try writer.writeByte(' ');
                const rounded = @round(val * pow) / pow;
                // -0.0 == 0.0 is true, so this replaces a negative zero
                // with a positive one before printing.
                const shown: f64 = if (rounded == 0.0) 0.0 else rounded;
                try writer.print("{d}", .{shown});
            }
            try writer.writeAll("]\n");
        }
        return output.toOwnedSlice();
    }

    /// Index table for the four quadrants of a 2r x 2c matrix, in the order
    /// top-left, top-right, bottom-left, bottom-right. Each entry is
    /// `{ row_start, row_end, col_start, col_end, row_offset, col_offset }`.
    fn params(r: usize, c: usize) [4][6]usize {
        return [4][6]usize{
            [_]usize{ 0, r, 0, c, 0, 0 },
            [_]usize{ 0, r, c, 2 * c, 0, c },
            [_]usize{ r, 2 * r, 0, c, r, 0 },
            [_]usize{ r, 2 * r, c, 2 * c, r, c },
        };
    }

    /// Splits the matrix into four `rows / 2` x `cols / 2` quadrants in the
    /// order used by `params`. The caller must `deinit` each quadrant; on
    /// failure the quadrants built so far are released here.
    pub fn toQuarters(self: Matrix) ![4]Matrix {
        const r = self.rows / 2;
        const c = self.cols / 2;
        const p = Matrix.params(r, c);
        var quarters: [4]Matrix = undefined;
        var built: usize = 0;
        errdefer {
            for (quarters[0..built]) |*quarter| quarter.deinit();
        }
        for (0..4) |k| {
            const row0 = p[k][4];
            const col0 = p[k][5];
            const quarter = try zeros(self.allocator, r, c);
            for (0..r) |i| {
                @memcpy(quarter.data.items[i].items, self.data.items[row0 + i].items[col0 .. col0 + c]);
            }
            quarters[k] = quarter;
            built += 1;
        }
        return quarters;
    }

    /// Reassembles four equally sized quadrants (ordered as in `params`)
    /// into one matrix allocated with `allocator`.
    ///
    /// Returns `error.DimensionMismatch` when the quadrants do not all have
    /// the dimensions of `q[0]`. The quadrants are only read.
    pub fn fromQuarters(q: [4]Matrix, allocator: Allocator) !Matrix {
        const r = q[0].rows;
        const c = q[0].cols;
        for (q[1..]) |quarter| {
            if (quarter.rows != r or quarter.cols != c) return error.DimensionMismatch;
        }
        const p = Matrix.params(r, c);
        const result = try zeros(allocator, r * 2, c * 2);
        for (0..4) |k| {
            const row0 = p[k][4];
            const col0 = p[k][5];
            for (0..r) |i| {
                @memcpy(result.data.items[row0 + i].items[col0 .. col0 + c], q[k].data.items[i].items[0..c]);
            }
        }
        return result;
    }

    /// Multiplies two square matrices whose side is a power of two by
    /// recursive quadrant decomposition with seven sub-products.
    ///
    /// Fails with `error.NotSquare`, `error.NotPowerOfTwo`,
    /// `error.InvalidDimensions` (operands of different size) or an
    /// allocation error. The caller must `deinit` the result. All
    /// intermediates are released on both the success and error paths.
    pub fn strassen(self: Matrix, other: Matrix) !Matrix {
        try self.validateSquarePowerOfTwo();
        try other.validateSquarePowerOfTwo();
        if (self.rows != other.rows or self.cols != other.cols) {
            return error.InvalidDimensions;
        }
        // 1x1 base case: fall back to the plain product.
        if (self.rows == 1) {
            return self.mul(other);
        }

        var qa = try self.toQuarters();
        defer for (&qa) |*quarter| quarter.deinit();
        var qb = try other.toQuarters();
        defer for (&qb) |*quarter| quarter.deinit();

        // p1 = (a1 - a3) * (b2 + b3)
        var t1 = try qa[1].sub(qa[3]);
        defer t1.deinit();
        var t2 = try qb[2].add(qb[3]);
        defer t2.deinit();
        var p1 = try t1.strassen(t2);
        defer p1.deinit();

        // p2 = (a0 + a3) * (b0 + b3)
        var t3 = try qa[0].add(qa[3]);
        defer t3.deinit();
        var t4 = try qb[0].add(qb[3]);
        defer t4.deinit();
        var p2 = try t3.strassen(t4);
        defer p2.deinit();

        // p3 = (a0 - a2) * (b0 + b1)
        var t5 = try qa[0].sub(qa[2]);
        defer t5.deinit();
        var t6 = try qb[0].add(qb[1]);
        defer t6.deinit();
        var p3 = try t5.strassen(t6);
        defer p3.deinit();

        // p4 = (a0 + a1) * b3
        var t7 = try qa[0].add(qa[1]);
        defer t7.deinit();
        var p4 = try t7.strassen(qb[3]);
        defer p4.deinit();

        // p5 = a0 * (b1 - b3)
        var t8 = try qb[1].sub(qb[3]);
        defer t8.deinit();
        var p5 = try qa[0].strassen(t8);
        defer p5.deinit();

        // p6 = a3 * (b2 - b0)
        var t9 = try qb[2].sub(qb[0]);
        defer t9.deinit();
        var p6 = try qa[3].strassen(t9);
        defer p6.deinit();

        // p7 = (a2 + a3) * b0
        var t10 = try qa[2].add(qa[3]);
        defer t10.deinit();
        var p7 = try t10.strassen(qb[0]);
        defer p7.deinit();

        // Only the first `built` entries are initialised; the defer reads
        // `built` at scope exit, so it frees exactly what was created even
        // when a later step fails.
        var result: [4]Matrix = undefined;
        var built: usize = 0;
        defer for (result[0..built]) |*quarter| quarter.deinit();

        // result[0] = p1 + p2 - p4 + p6
        var ta = try p1.add(p2);
        defer ta.deinit();
        var tb = try ta.sub(p4);
        defer tb.deinit();
        result[0] = try tb.add(p6);
        built = 1;

        // result[1] = p4 + p5
        result[1] = try p4.add(p5);
        built = 2;

        // result[2] = p6 + p7
        result[2] = try p6.add(p7);
        built = 3;

        // result[3] = p2 - p3 + p5 - p7
        var tc = try p2.sub(p3);
        defer tc.deinit();
        var td = try tc.add(p5);
        defer td.deinit();
        result[3] = try td.sub(p7);
        built = 4;

        // fromQuarters copies the values, so the quadrants can be freed
        // by the defer above once the return value has been built.
        return Matrix.fromQuarters(result, self.allocator);
    }
};

/// Multiplies three pairs of sample matrices with both the plain and the
/// Strassen algorithm and prints the products to stdout.
pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    defer _ = gpa.deinit();
    const allocator = gpa.allocator();

    // Matrix A - [1 2; 3 4]
    var a = try Matrix.fromSlices(allocator, &[_][]const f64{
        &[_]f64{ 1.0, 2.0 },
        &[_]f64{ 3.0, 4.0 },
    });
    defer a.deinit();

    // Matrix B - [5 6; 7 8]
    var b = try Matrix.fromSlices(allocator, &[_][]const f64{
        &[_]f64{ 5.0, 6.0 },
        &[_]f64{ 7.0, 8.0 },
    });
    defer b.deinit();

    // Matrix C - 4x4
    var c = try Matrix.fromSlices(allocator, &[_][]const f64{
        &[_]f64{ 1.0, 1.0, 1.0, 1.0 },
        &[_]f64{ 2.0, 4.0, 8.0, 16.0 },
        &[_]f64{ 3.0, 9.0, 27.0, 81.0 },
        &[_]f64{ 4.0, 16.0, 64.0, 256.0 },
    });
    defer c.deinit();

    // Matrix D - 4x4
    var d = try Matrix.fromSlices(allocator, &[_][]const f64{
        &[_]f64{ 4.0, -3.0, 4.0 / 3.0, -1.0 / 4.0 },
        &[_]f64{ -13.0 / 3.0, 19.0 / 4.0, -7.0 / 3.0, 11.0 / 24.0 },
        &[_]f64{ 3.0 / 2.0, -2.0, 7.0 / 6.0, -1.0 / 4.0 },
        &[_]f64{ -1.0 / 6.0, 1.0 / 4.0, -1.0 / 6.0, 1.0 / 24.0 },
    });
    defer d.deinit();

    // Matrix E - 4x4
    var e = try Matrix.fromSlices(allocator, &[_][]const f64{
        &[_]f64{ 1.0, 2.0, 3.0, 4.0 },
        &[_]f64{ 5.0, 6.0, 7.0, 8.0 },
        &[_]f64{ 9.0, 10.0, 11.0, 12.0 },
        &[_]f64{ 13.0, 14.0, 15.0, 16.0 },
    });
    defer e.deinit();

    // Matrix F - Identity 4x4
    var f = try Matrix.fromSlices(allocator, &[_][]const f64{
        &[_]f64{ 1.0, 0.0, 0.0, 0.0 },
        &[_]f64{ 0.0, 1.0, 0.0, 0.0 },
        &[_]f64{ 0.0, 0.0, 1.0, 0.0 },
        &[_]f64{ 0.0, 0.0, 0.0, 1.0 },
    });
    defer f.deinit();

    const stdout = std.io.getStdOut().writer();

    // The operations only read their operands, so no copies are needed.
    try stdout.print("Using 'normal' matrix multiplication:\n", .{});

    var ab = try a.mul(b);
    defer ab.deinit();
    try stdout.print(" a * b = {}\n", .{ab});

    var cd = try c.mul(d);
    defer cd.deinit();
    const cd_str = try cd.toStringWithPrecision(6, allocator);
    defer allocator.free(cd_str);
    try stdout.print(" c * d = {s}\n", .{cd_str});

    var ef = try e.mul(f);
    defer ef.deinit();
    try stdout.print(" e * f = {}\n", .{ef});

    try stdout.print("\nUsing 'Strassen' matrix multiplication:\n", .{});

    var ab_s = try a.strassen(b);
    defer ab_s.deinit();
    try stdout.print(" a * b = {}\n", .{ab_s});

    var cd_s = try c.strassen(d);
    defer cd_s.deinit();
    const cd_s_str = try cd_s.toStringWithPrecision(6, allocator);
    defer allocator.free(cd_s_str);
    try stdout.print(" c * d = {s}\n", .{cd_s_str});

    var ef_s = try e.strassen(f);
    defer ef_s.deinit();
    try stdout.print(" e * f = {}\n", .{ef_s});
}