From df0e93735bce8eb94ed6ab4ca055b46c6ecd3568 Mon Sep 17 00:00:00 2001 From: Alex Selimov Date: Sat, 4 Apr 2026 21:50:06 -0400 Subject: [PATCH] Add basic matrix math --- src/matrix.rs | 102 ++++++++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 99 insertions(+), 3 deletions(-) diff --git a/src/matrix.rs b/src/matrix.rs index 0fbe4c9..b196e3b 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -11,8 +11,9 @@ pub trait Numeric: + Mul + Div + Sized - + Clone + + Copy + Display + + Default { } @@ -25,6 +26,8 @@ impl Numeric for T where + Sized + Clone + Display + + Default + + Copy { } @@ -32,6 +35,9 @@ impl Numeric for T where #[derive(Debug)] pub enum MatrixError { IndexError(usize, usize, usize, usize), + AddError(usize, usize, usize, usize), + MultiplicationError(usize, usize, usize, usize), + FromVecError(usize, usize, usize), } impl fmt::Display for MatrixError { @@ -42,6 +48,19 @@ impl fmt::Display for MatrixError { "Error accessing index [{i},{j}] for matrix with dimensions [{}, {}]", m, n ), + + MatrixError::MultiplicationError(i, j, m, n) => write!( + f, + "Matrices with dimensions [{i},{j}] and [{m},{n}] cannot be multiplied", + ), + MatrixError::AddError(i, j, m, n) => write!( + f, + "Matrices with dimensions [{i},{j}] and [{m},{n}] cannot be added", + ), + MatrixError::FromVecError(i, j, len) => write!( + f, + "Matrices with dimensions [{i},{j}] cannot be created from vec with len={len}", + ), } } } @@ -62,13 +81,58 @@ impl Matrix { Matrix { data, m, n } } + pub fn from_vec(m: usize, n: usize, data: Vec) -> Result { + if m * n != data.len() { + return Err(MatrixError::FromVecError(m, n, data.len())); + } + Ok(Self { m, n, data }) + } + fn index(&self, i: usize, j: usize) -> usize { i * self.n + j } pub fn get(&self, i: usize, j: usize) -> Result<&T, MatrixError> { self.data .get(self.index(i, j)) - .ok_or(make_index_error(i, j, &self)) + .ok_or(make_index_error(i, j, self)) + } + + pub fn add(&self, other: &Matrix) -> Result, MatrixError> { + // Compatibility check + if self.m != other.m || self.n != other.n { + return Err(MatrixError::AddError(self.m, self.n, other.m, other.n)); + } + + let data = self + .data + .iter() + .zip(other.data.iter()) + .map(|(a, b)| *a + *b) + .collect(); + Self::from_vec(self.m, self.n, data) + } + + pub fn mul(&self, other: &Matrix) -> Result, MatrixError> { + let mut c = Matrix::new(self.m, other.n, T::default()); + // Compatibility check + if self.n != other.m { + return Err(MatrixError::MultiplicationError( + self.m, self.n, other.m, other.n, + )); + } + + for i in 0..self.m { + for k in 0..self.n { + for j in 0..other.n { + c.set( + i, + j, + *c.get(i, j)? + (*self.get(i, k)?) * (*other.get(k, j)?), + ); + } + } + } + Ok(c) } pub fn set(&mut self, i: usize, j: usize, x: T) { @@ -94,7 +158,7 @@ impl Display for Matrix { }) .collect::>() .join(" "); - write!(f, "\n{msg}") + write!(f, "\n {msg}") } } @@ -131,5 +195,37 @@ mod test { assert_delta!(m.data[7], 2.0, 1e-12); assert_delta!(m.data[10], 1.0, 1e-12); assert_delta!(m.data[11], 2.0, 1e-12); + + // Test from_vec + let data = vec![1, 2, 3, 4, 5, 6]; + let m = Matrix::from_vec(3, 2, data).unwrap(); + assert_eq!(*m.get(0, 0).unwrap(), 1); + assert_eq!(*m.get(0, 1).unwrap(), 2); + assert_eq!(*m.get(1, 0).unwrap(), 3); + assert_eq!(*m.get(1, 1).unwrap(), 4); + assert_eq!(*m.get(2, 0).unwrap(), 5); + assert_eq!(*m.get(2, 1).unwrap(), 6); + } + + #[test] + fn test_matrix_math() { + let a = Matrix::from_vec(3, 2, vec![1, 2, 3, 4, 5, 6]).unwrap(); + let b = Matrix::from_vec(3, 2, vec![0, 1, 0, 1, 0, 1]).unwrap(); + + let c = a.add(&b).unwrap(); + assert_eq!(*c.get(0, 0).unwrap(), 1); + assert_eq!(*c.get(0, 1).unwrap(), 3); + assert_eq!(*c.get(1, 0).unwrap(), 3); + assert_eq!(*c.get(1, 1).unwrap(), 5); + assert_eq!(*c.get(2, 0).unwrap(), 5); + assert_eq!(*c.get(2, 1).unwrap(), 7); + + assert!(a.mul(&b).is_err()); + + let a = Matrix::from_vec(3, 2, vec![1, 2, 3, 4, 5, 6]).unwrap(); + let b = Matrix::from_vec(2, 3, vec![0, 1, 0, 1, 0, 1]).unwrap(); + let c = a.mul(&b).unwrap(); + + assert_eq!(c.data, vec![2, 1, 2, 4, 3, 4, 6, 5, 6]); } }