RosettaCodeData/Task/Strassens-algorithm/Rust/strassens-algorithm.rs

270 lines
7.9 KiB
Rust

use std::fmt;
use std::ops::{Add, Mul, Sub};
#[derive(Debug, Clone)]
struct Matrix {
data: Vec<Vec<f64>>,
rows: usize,
cols: usize,
}
impl Matrix {
fn new(data: Vec<Vec<f64>>) -> Self {
let rows = data.len();
let cols = if rows > 0 { data[0].len() } else { 0 };
Matrix { data, rows, cols }
}
fn rows(&self) -> usize {
self.rows
}
fn cols(&self) -> usize {
self.cols
}
fn validate_dimensions(&self, other: &Matrix) {
if self.rows() != other.rows() || self.cols() != other.cols() {
panic!("Matrices must have the same dimensions.");
}
}
fn validate_multiplication(&self, other: &Matrix) {
if self.cols() != other.rows() {
panic!("Cannot multiply these matrices.");
}
}
fn validate_square_power_of_two(&self) {
if self.rows() != self.cols() {
panic!("Matrix must be square.");
}
if self.rows() == 0 || (self.rows() & (self.rows() - 1)) != 0 {
panic!("Size of matrix must be a power of two.");
}
}
}
impl Add for Matrix {
type Output = Self;
fn add(self, other: Self) -> Self {
self.validate_dimensions(&other);
let mut result_data = Vec::with_capacity(self.rows());
for i in 0..self.rows() {
let mut row = Vec::with_capacity(self.cols());
for j in 0..self.cols() {
row.push(self.data[i][j] + other.data[i][j]);
}
result_data.push(row);
}
Matrix::new(result_data)
}
}
impl Sub for Matrix {
type Output = Self;
fn sub(self, other: Self) -> Self {
self.validate_dimensions(&other);
let mut result_data = Vec::with_capacity(self.rows());
for i in 0..self.rows() {
let mut row = Vec::with_capacity(self.cols());
for j in 0..self.cols() {
row.push(self.data[i][j] - other.data[i][j]);
}
result_data.push(row);
}
Matrix::new(result_data)
}
}
impl Mul for Matrix {
type Output = Self;
fn mul(self, other: Self) -> Self {
self.validate_multiplication(&other);
let mut result_data = Vec::with_capacity(self.rows());
for i in 0..self.rows() {
let mut row = Vec::with_capacity(other.cols());
for j in 0..other.cols() {
let mut sum = 0.0;
for k in 0..other.rows() {
sum += self.data[i][k] * other.data[k][j];
}
row.push(sum);
}
result_data.push(row);
}
Matrix::new(result_data)
}
}
impl fmt::Display for Matrix {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let mut s = String::new();
for row in &self.data {
s.push_str(&format!("{:?}\n", row));
}
write!(f, "{}", s)
}
}
impl Matrix {
fn to_string_with_precision(&self, p: usize) -> String {
let mut s = String::new();
let pow = 10.0_f64.powi(p as i32);
for row in &self.data {
let mut t = Vec::new();
for &val in row {
let r = (val * pow).round() / pow;
let formatted = format!("{}", r);
if formatted == "-0" {
t.push("0".to_string());
} else {
t.push(formatted);
}
}
s.push_str(&format!("{:?}\n", t));
}
s
}
fn params(r: usize, c: usize) -> [[usize; 6]; 4] {
[
[0, r, 0, c, 0, 0],
[0, r, c, 2 * c, 0, c],
[r, 2 * r, 0, c, r, 0],
[r, 2 * r, c, 2 * c, r, c],
]
}
fn to_quarters(&self) -> [Matrix; 4] {
let r = self.rows() / 2;
let c = self.cols() / 2;
let p = Matrix::params(r, c);
let mut quarters: [Matrix; 4] = [
Matrix::new(vec![vec![0.0; c]; r]),
Matrix::new(vec![vec![0.0; c]; r]),
Matrix::new(vec![vec![0.0; c]; r]),
Matrix::new(vec![vec![0.0; c]; r]),
];
for k in 0..4 {
let mut q_data = Vec::with_capacity(r);
for i in p[k][0]..p[k][1] {
let mut row = Vec::with_capacity(c);
for j in p[k][2]..p[k][3] {
row.push(self.data[i][j]);
}
q_data.push(row);
}
quarters[k] = Matrix::new(q_data);
}
quarters
}
fn from_quarters(q: [Matrix; 4]) -> Matrix {
let r = q[0].rows();
let c = q[0].cols();
let p = Matrix::params(r, c);
let rows = r * 2;
let cols = c * 2;
let mut m_data = vec![vec![0.0; cols]; rows];
for k in 0..4 {
for i in p[k][0]..p[k][1] {
for j in p[k][2]..p[k][3] {
m_data[i][j] = q[k].data[i - p[k][4]][j - p[k][5]];
}
}
}
Matrix::new(m_data)
}
fn strassen(&self, other: Matrix) -> Matrix {
self.validate_square_power_of_two();
other.validate_square_power_of_two();
if self.rows() != other.rows() || self.cols() != other.cols() {
panic!("Matrices must be square and of equal size for Strassen multiplication.");
}
if self.rows() == 1 {
return self.clone() * other;
}
let qa = self.to_quarters();
let qb = other.to_quarters();
let p1 = (qa[1].clone() - qa[3].clone()).strassen(qb[2].clone() + qb[3].clone());
let p2 = (qa[0].clone() + qa[3].clone()).strassen(qb[0].clone() + qb[3].clone());
let p3 = (qa[0].clone() - qa[2].clone()).strassen(qb[0].clone() + qb[1].clone());
let p4 = (qa[0].clone() + qa[1].clone()).strassen(qb[3].clone());
let p5 = qa[0].clone().strassen(qb[1].clone() - qb[3].clone());
let p6 = qa[3].clone().strassen(qb[2].clone() - qb[0].clone());
let p7 = (qa[2].clone() + qa[3].clone()).strassen(qb[0].clone());
let mut q: [Matrix; 4] = [
Matrix::new(vec![vec![0.0; qa[0].cols()]; qa[0].rows()]),
Matrix::new(vec![vec![0.0; qa[0].cols()]; qa[0].rows()]),
Matrix::new(vec![vec![0.0; qa[0].cols()]; qa[0].rows()]),
Matrix::new(vec![vec![0.0; qa[0].cols()]; qa[0].rows()]),
];
q[0] = p1.clone() + p2.clone() - p4.clone() + p6.clone();
q[1] = p4 + p5.clone();
q[2] = p6 + p7.clone();
q[3] = p2 - p3.clone() + p5 - p7;
Matrix::from_quarters(q)
}
}
fn main() {
let a = Matrix::new(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
let b = Matrix::new(vec![vec![5.0, 6.0], vec![7.0, 8.0]]);
let c = Matrix::new(vec![
vec![1.0, 1.0, 1.0, 1.0],
vec![2.0, 4.0, 8.0, 16.0],
vec![3.0, 9.0, 27.0, 81.0],
vec![4.0, 16.0, 64.0, 256.0],
]);
let d = Matrix::new(vec![
vec![4.0, -3.0, 4.0 / 3.0, -1.0 / 4.0],
vec![-13.0 / 3.0, 19.0 / 4.0, -7.0 / 3.0, 11.0 / 24.0],
vec![3.0 / 2.0, -2.0, 7.0 / 6.0, -1.0 / 4.0],
vec![-1.0 / 6.0, 1.0 / 4.0, -1.0 / 6.0, 1.0 / 24.0],
]);
let e = Matrix::new(vec![
vec![1.0, 2.0, 3.0, 4.0],
vec![5.0, 6.0, 7.0, 8.0],
vec![9.0, 10.0, 11.0, 12.0],
vec![13.0, 14.0, 15.0, 16.0],
]);
let f = Matrix::new(vec![
vec![1.0, 0.0, 0.0, 0.0],
vec![0.0, 1.0, 0.0, 0.0],
vec![0.0, 0.0, 1.0, 0.0],
vec![0.0, 0.0, 0.0, 1.0],
]);
println!("Using 'normal' matrix multiplication:");
println!(" a * b = {}", a.clone() * b.clone());
println!(" c * d = {}", (c.clone() * d.clone()).to_string_with_precision(6));
println!(" e * f = {}", e.clone() * f.clone());
println!("\nUsing 'Strassen' matrix multiplication:");
println!(" a * b = {}", a.strassen(b));
println!(" c * d = {}", c.strassen(d).to_string_with_precision(6));
println!(" e * f = {}", e.strassen(f));
}