478 lines
15 KiB
Zig
478 lines
15 KiB
Zig
const std = @import("std");
|
|
const fmt = std.fmt;
|
|
const ArrayList = std.ArrayList;
|
|
const Allocator = std.mem.Allocator;
|
|
|
|
/// A dense matrix of `f64` values stored row by row.
///
/// A `Matrix` owns `data` and every row list inside it; release them with
/// `deinit`. Every method that returns a new `Matrix` (or a slice) allocates it
/// and transfers ownership to the caller. Operations only read their operands,
/// so operands never need to be cloned before use.
const Matrix = struct {
    /// Row lists; after `init` every inner list holds exactly `cols` entries.
    data: ArrayList(ArrayList(f64)),
    /// Number of rows (`data.items.len`).
    rows: usize,
    /// Number of entries in every row.
    cols: usize,
    /// Allocator used for matrices produced by this matrix's methods.
    allocator: Allocator,

    /// Errors reported by matrix operations. Declared explicitly because
    /// `strassen` is recursive, and inferred error sets do not support recursion.
    pub const Error = Allocator.Error || error{
        /// `init` was given rows of different lengths.
        RaggedRows,
        /// Element-wise operands (or quarters) do not share the same shape.
        DimensionMismatch,
        /// The left operand's column count differs from the right's row count.
        CannotMultiply,
        /// A Strassen operand is not square.
        NotSquare,
        /// A Strassen operand's size is not a positive power of two.
        NotPowerOfTwo,
        /// Strassen operands are individually valid but have different sizes.
        InvalidDimensions,
        /// `toQuarters` needs an even number of rows and columns.
        OddDimensions,
    };

    /// Wraps `data` in a `Matrix`, taking ownership of it on success.
    ///
    /// The column count comes from the first row, and every other row must have
    /// the same length; otherwise `error.RaggedRows` is returned. On error,
    /// ownership of `data` stays with the caller.
    pub fn init(allocator: Allocator, data: ArrayList(ArrayList(f64))) Error!Matrix {
        const rows = data.items.len;
        const cols = if (rows > 0) data.items[0].items.len else 0;
        for (data.items) |row| {
            if (row.items.len != cols) return error.RaggedRows;
        }
        return .{
            .data = data,
            .rows = rows,
            .cols = cols,
            .allocator = allocator,
        };
    }

    /// Frees every row and the outer list. `self` is unusable afterwards.
    pub fn deinit(self: *Matrix) void {
        freeRows(&self.data);
        self.* = undefined;
    }

    /// Frees each inner row list of `list`, then `list` itself.
    fn freeRows(list: *ArrayList(ArrayList(f64))) void {
        for (list.items) |*row| row.deinit();
        list.deinit();
    }

    /// Returns a deep copy of `self` allocated with `self.allocator`.
    pub fn clone(self: Matrix) Error!Matrix {
        var data = ArrayList(ArrayList(f64)).init(self.allocator);
        errdefer freeRows(&data);
        try data.ensureTotalCapacity(self.rows);

        for (self.data.items) |src| {
            var row = ArrayList(f64).init(self.allocator);
            errdefer row.deinit();
            try row.ensureTotalCapacityPrecise(src.items.len);
            row.appendSliceAssumeCapacity(src.items);
            data.appendAssumeCapacity(row);
        }
        return Matrix.init(self.allocator, data);
    }

    /// Fails with `error.DimensionMismatch` unless both matrices have the same shape.
    pub fn validateDimensions(self: Matrix, other: Matrix) Error!void {
        if (self.rows != other.rows or self.cols != other.cols) {
            return error.DimensionMismatch;
        }
    }

    /// Fails with `error.CannotMultiply` unless `self * other` is defined.
    pub fn validateMultiplication(self: Matrix, other: Matrix) Error!void {
        if (self.cols != other.rows) {
            return error.CannotMultiply;
        }
    }

    /// Fails unless `self` is square (`error.NotSquare`) with a size that is a
    /// positive power of two (`error.NotPowerOfTwo`), as `strassen` requires.
    pub fn validateSquarePowerOfTwo(self: Matrix) Error!void {
        if (self.rows != self.cols) {
            return error.NotSquare;
        }
        // n & (n - 1) clears the lowest set bit; it is zero only for powers of two.
        if (self.rows == 0 or (self.rows & (self.rows - 1)) != 0) {
            return error.NotPowerOfTwo;
        }
    }

    /// Element-wise operations shared by `add` and `sub`.
    const ElementwiseOp = enum { add, sub };

    /// Applies `op` entry by entry to two same-shaped matrices.
    fn elementwise(self: Matrix, other: Matrix, comptime op: ElementwiseOp) Error!Matrix {
        try self.validateDimensions(other);

        var data = ArrayList(ArrayList(f64)).init(self.allocator);
        errdefer freeRows(&data);
        try data.ensureTotalCapacity(self.rows);

        for (self.data.items, other.data.items) |lhs_row, rhs_row| {
            var row = ArrayList(f64).init(self.allocator);
            errdefer row.deinit();
            try row.ensureTotalCapacityPrecise(self.cols);
            for (lhs_row.items, rhs_row.items) |x, y| {
                row.appendAssumeCapacity(switch (op) {
                    .add => x + y,
                    .sub => x - y,
                });
            }
            data.appendAssumeCapacity(row);
        }
        return Matrix.init(self.allocator, data);
    }

    /// Returns `self + other`; fails with `error.DimensionMismatch` on shape mismatch.
    pub fn add(self: Matrix, other: Matrix) Error!Matrix {
        return self.elementwise(other, .add);
    }

    /// Returns `self - other`; fails with `error.DimensionMismatch` on shape mismatch.
    pub fn sub(self: Matrix, other: Matrix) Error!Matrix {
        return self.elementwise(other, .sub);
    }

    /// Returns the product `self * other` computed with the schoolbook O(n^3) method.
    pub fn mul(self: Matrix, other: Matrix) Error!Matrix {
        try self.validateMultiplication(other);

        var data = ArrayList(ArrayList(f64)).init(self.allocator);
        errdefer freeRows(&data);
        try data.ensureTotalCapacity(self.rows);

        for (self.data.items) |lhs_row| {
            var row = ArrayList(f64).init(self.allocator);
            errdefer row.deinit();
            try row.ensureTotalCapacityPrecise(other.cols);
            for (0..other.cols) |j| {
                // Dot product of this row with column j of `other`.
                var sum: f64 = 0.0;
                for (lhs_row.items, other.data.items) |x, rhs_row| {
                    sum += x * rhs_row.items[j];
                }
                row.appendAssumeCapacity(sum);
            }
            data.appendAssumeCapacity(row);
        }
        return Matrix.init(self.allocator, data);
    }

    /// `std.fmt` hook: prints each row with `{any}` followed by a newline.
    pub fn format(self: Matrix, comptime _: []const u8, _: fmt.FormatOptions, writer: anytype) !void {
        for (self.data.items) |row| {
            try writer.print("{any}\n", .{row.items});
        }
    }

    /// Renders the matrix as `[[a, b], [c, d]]`, with every entry rounded to
    /// `p` decimal places first.
    ///
    /// Entries are printed with `{d}`, and values that round to negative zero
    /// are printed as `0`. The caller owns the returned slice and must free it
    /// with `allocator`.
    pub fn toStringWithPrecision(self: Matrix, p: usize, allocator: Allocator) Error![]u8 {
        var output = ArrayList(u8).init(allocator);
        errdefer output.deinit();
        const writer = output.writer();

        const scale = std.math.pow(f64, 10.0, @as(f64, @floatFromInt(p)));

        try output.append('[');
        for (self.data.items, 0..) |row, i| {
            if (i > 0) try output.appendSlice(", ");
            try output.append('[');
            for (row.items, 0..) |val, j| {
                if (j > 0) try output.appendSlice(", ");
                const rounded = @round(val * scale) / scale;
                // -0.0 == 0.0, so this folds negative zero to +0 and avoids "-0".
                const shown: f64 = if (rounded == 0.0) 0.0 else rounded;
                try writer.print("{d}", .{shown});
            }
            try output.append(']');
        }
        try output.append(']');

        return output.toOwnedSlice();
    }

    /// Splits an even-sized matrix into four equally sized quarters ordered
    /// top-left, top-right, bottom-left, bottom-right.
    ///
    /// Fails with `error.OddDimensions` if either dimension is odd, rather
    /// than silently dropping a row or column. The caller owns all four results.
    pub fn toQuarters(self: Matrix) Error![4]Matrix {
        if (self.rows % 2 != 0 or self.cols % 2 != 0) return error.OddDimensions;
        const half_rows = self.rows / 2;
        const half_cols = self.cols / 2;

        var quarters: [4]Matrix = undefined;
        // Number of quarters fully built so far; on error only these are freed.
        var built: usize = 0;
        errdefer for (quarters[0..built]) |*quarter| quarter.deinit();

        for (&quarters, 0..) |*quarter, k| {
            // k = 0..3 maps to (row half, column half) = (k / 2, k % 2).
            const row_start = (k / 2) * half_rows;
            const col_start = (k % 2) * half_cols;

            var data = ArrayList(ArrayList(f64)).init(self.allocator);
            errdefer freeRows(&data);
            try data.ensureTotalCapacity(half_rows);

            for (self.data.items[row_start .. row_start + half_rows]) |src| {
                var row = ArrayList(f64).init(self.allocator);
                errdefer row.deinit();
                try row.ensureTotalCapacityPrecise(half_cols);
                row.appendSliceAssumeCapacity(src.items[col_start .. col_start + half_cols]);
                data.appendAssumeCapacity(row);
            }

            quarter.* = try Matrix.init(self.allocator, data);
            built += 1;
        }
        return quarters;
    }

    /// Reassembles four equally sized quarters, in `toQuarters` order, into one
    /// matrix allocated with `allocator`.
    ///
    /// The quarters are only read; the caller still owns them. Fails with
    /// `error.DimensionMismatch` if the quarters differ in shape.
    pub fn fromQuarters(q: [4]Matrix, allocator: Allocator) Error!Matrix {
        const half_rows = q[0].rows;
        const half_cols = q[0].cols;
        for (q[1..]) |quarter| {
            if (quarter.rows != half_rows or quarter.cols != half_cols) {
                return error.DimensionMismatch;
            }
        }
        const cols = 2 * half_cols;

        var data = ArrayList(ArrayList(f64)).init(allocator);
        errdefer freeRows(&data);
        try data.ensureTotalCapacity(2 * half_rows);

        // The top half concatenates q[0] | q[1]; the bottom half q[2] | q[3].
        const halves = [2][2]Matrix{ .{ q[0], q[1] }, .{ q[2], q[3] } };
        for (halves) |pair| {
            for (pair[0].data.items, pair[1].data.items) |left, right| {
                var row = ArrayList(f64).init(allocator);
                errdefer row.deinit();
                try row.ensureTotalCapacityPrecise(cols);
                row.appendSliceAssumeCapacity(left.items);
                row.appendSliceAssumeCapacity(right.items);
                data.appendAssumeCapacity(row);
            }
        }
        return Matrix.init(allocator, data);
    }

    /// Returns `self * other` computed with Strassen's recursive algorithm.
    ///
    /// Both operands must be square, of the same size, and that size must be a
    /// power of two. Every intermediate matrix is freed before returning,
    /// including on error paths.
    pub fn strassen(self: Matrix, other: Matrix) Error!Matrix {
        try self.validateSquarePowerOfTwo();
        try other.validateSquarePowerOfTwo();
        if (self.rows != other.rows or self.cols != other.cols) {
            return error.InvalidDimensions;
        }

        // Base case: a 1x1 product is a single multiplication.
        if (self.rows == 1) return self.mul(other);

        // qa = [A11, A12, A21, A22], qb = [B11, B12, B21, B22].
        var qa = try self.toQuarters();
        defer for (&qa) |*m| m.deinit();
        var qb = try other.toQuarters();
        defer for (&qb) |*m| m.deinit();

        // P1 = (A12 - A22)(B21 + B22)
        var s1 = try qa[1].sub(qa[3]);
        defer s1.deinit();
        var s2 = try qb[2].add(qb[3]);
        defer s2.deinit();
        var p1 = try s1.strassen(s2);
        defer p1.deinit();

        // P2 = (A11 + A22)(B11 + B22)
        var s3 = try qa[0].add(qa[3]);
        defer s3.deinit();
        var s4 = try qb[0].add(qb[3]);
        defer s4.deinit();
        var p2 = try s3.strassen(s4);
        defer p2.deinit();

        // P3 = (A11 - A21)(B11 + B12)
        var s5 = try qa[0].sub(qa[2]);
        defer s5.deinit();
        var s6 = try qb[0].add(qb[1]);
        defer s6.deinit();
        var p3 = try s5.strassen(s6);
        defer p3.deinit();

        // P4 = (A11 + A12) B22
        var s7 = try qa[0].add(qa[1]);
        defer s7.deinit();
        var p4 = try s7.strassen(qb[3]);
        defer p4.deinit();

        // P5 = A11 (B12 - B22)
        var s8 = try qb[1].sub(qb[3]);
        defer s8.deinit();
        var p5 = try qa[0].strassen(s8);
        defer p5.deinit();

        // P6 = A22 (B21 - B11)
        var s9 = try qb[2].sub(qb[0]);
        defer s9.deinit();
        var p6 = try qa[3].strassen(s9);
        defer p6.deinit();

        // P7 = (A21 + A22) B11
        var s10 = try qa[2].add(qa[3]);
        defer s10.deinit();
        var p7 = try s10.strassen(qb[0]);
        defer p7.deinit();

        // C11 = P1 + P2 - P4 + P6
        var c11_sum = try p1.add(p2);
        defer c11_sum.deinit();
        var c11_diff = try c11_sum.sub(p4);
        defer c11_diff.deinit();
        var c11 = try c11_diff.add(p6);
        defer c11.deinit();

        // C12 = P4 + P5
        var c12 = try p4.add(p5);
        defer c12.deinit();

        // C21 = P6 + P7
        var c21 = try p6.add(p7);
        defer c21.deinit();

        // C22 = P2 - P3 + P5 - P7
        var c22_diff = try p2.sub(p3);
        defer c22_diff.deinit();
        var c22_sum = try c22_diff.add(p5);
        defer c22_sum.deinit();
        var c22 = try c22_sum.sub(p7);
        defer c22.deinit();

        // fromQuarters copies the entries, so the deferred deinits above stay valid.
        return Matrix.fromQuarters(.{ c11, c12, c21, c22 }, self.allocator);
    }
};
|
|
|
|
/// Builds a `rows` x `cols` matrix from a literal 2-D array.
///
/// The caller owns the result and must `deinit` it. On error, every row built
/// so far is freed before returning.
fn matrixFromRows(
    allocator: Allocator,
    comptime rows: usize,
    comptime cols: usize,
    values: [rows][cols]f64,
) !Matrix {
    var data = ArrayList(ArrayList(f64)).init(allocator);
    errdefer {
        for (data.items) |*row| row.deinit();
        data.deinit();
    }
    try data.ensureTotalCapacity(rows);

    for (values) |src| {
        var row = ArrayList(f64).init(allocator);
        errdefer row.deinit();
        try row.appendSlice(&src);
        data.appendAssumeCapacity(row);
    }
    return Matrix.init(allocator, data);
}

/// Multiplies `x` by `y` using Strassen's algorithm or the schoolbook method.
/// The caller owns the returned matrix.
fn multiply(x: Matrix, y: Matrix, use_strassen: bool) !Matrix {
    if (use_strassen) return x.strassen(y);
    return x.mul(y);
}

/// Demo: multiplies three pairs of matrices with both algorithms and prints
/// the products to stdout.
pub fn main() !void {
    var gpa = std.heap.GeneralPurposeAllocator(.{}){};
    defer _ = gpa.deinit();

    const allocator = gpa.allocator();

    // Matrix A - [1 2; 3 4]
    var a = try matrixFromRows(allocator, 2, 2, .{
        .{ 1.0, 2.0 },
        .{ 3.0, 4.0 },
    });
    defer a.deinit();

    // Matrix B - [5 6; 7 8]
    var b = try matrixFromRows(allocator, 2, 2, .{
        .{ 5.0, 6.0 },
        .{ 7.0, 8.0 },
    });
    defer b.deinit();

    // Matrix C - 4x4
    var c = try matrixFromRows(allocator, 4, 4, .{
        .{ 1.0, 1.0, 1.0, 1.0 },
        .{ 2.0, 4.0, 8.0, 16.0 },
        .{ 3.0, 9.0, 27.0, 81.0 },
        .{ 4.0, 16.0, 64.0, 256.0 },
    });
    defer c.deinit();

    // Matrix D - 4x4
    var d = try matrixFromRows(allocator, 4, 4, .{
        .{ 4.0, -3.0, 4.0 / 3.0, -1.0 / 4.0 },
        .{ -13.0 / 3.0, 19.0 / 4.0, -7.0 / 3.0, 11.0 / 24.0 },
        .{ 3.0 / 2.0, -2.0, 7.0 / 6.0, -1.0 / 4.0 },
        .{ -1.0 / 6.0, 1.0 / 4.0, -1.0 / 6.0, 1.0 / 24.0 },
    });
    defer d.deinit();

    // Matrix E - 4x4
    var e = try matrixFromRows(allocator, 4, 4, .{
        .{ 1.0, 2.0, 3.0, 4.0 },
        .{ 5.0, 6.0, 7.0, 8.0 },
        .{ 9.0, 10.0, 11.0, 12.0 },
        .{ 13.0, 14.0, 15.0, 16.0 },
    });
    defer e.deinit();

    // Matrix F - Identity 4x4
    var f = try matrixFromRows(allocator, 4, 4, .{
        .{ 1.0, 0.0, 0.0, 0.0 },
        .{ 0.0, 1.0, 0.0, 0.0 },
        .{ 0.0, 0.0, 1.0, 0.0 },
        .{ 0.0, 0.0, 0.0, 1.0 },
    });
    defer f.deinit();

    const stdout = std.io.getStdOut().writer();

    // The same three products, first with schoolbook multiplication, then with
    // Strassen. The operands are only read, so no clones are needed.
    const headings = [_][]const u8{
        "Using 'normal' matrix multiplication:\n",
        "\nUsing 'Strassen' matrix multiplication:\n",
    };
    for (headings, [_]bool{ false, true }) |heading, use_strassen| {
        try stdout.writeAll(heading);

        var ab = try multiply(a, b, use_strassen);
        defer ab.deinit();
        try stdout.print(" a * b = {}\n", .{ab});

        var cd = try multiply(c, d, use_strassen);
        defer cd.deinit();
        const cd_str = try cd.toStringWithPrecision(6, allocator);
        defer allocator.free(cd_str);
        try stdout.print(" c * d = {s}\n", .{cd_str});

        var ef = try multiply(e, f, use_strassen);
        defer ef.deinit();
        try stdout.print(" e * f = {}\n", .{ef});
    }
}
|