Add basic matrix math

This commit is contained in:
Alex Selimov 2026-04-04 21:50:06 -04:00
parent 93bd1f4d38
commit df0e93735b

View file

@ -11,8 +11,9 @@ pub trait Numeric:
+ Mul<Output = Self>
+ Div<Output = Self>
+ Sized
+ Clone
+ Copy
+ Display
+ Default
{
}
@ -25,6 +26,8 @@ impl<T> Numeric for T where
+ Sized
+ Clone
+ Display
+ Default
+ Copy
{
}
@ -32,6 +35,9 @@ impl<T> Numeric for T where
#[derive(Debug)]
pub enum MatrixError {
IndexError(usize, usize, usize, usize),
AddError(usize, usize, usize, usize),
MultiplicationError(usize, usize, usize, usize),
FromVecError(usize, usize, usize),
}
impl fmt::Display for MatrixError {
@ -42,6 +48,19 @@ impl fmt::Display for MatrixError {
"Error accessing index [{i},{j}] for matrix with dimensions [{}, {}]",
m, n
),
MatrixError::MultiplicationError(i, j, m, n) => write!(
f,
"Matrices with dimensions [{i},{j}] and [{m},{n}] cannot be multiplied",
),
MatrixError::AddError(i, j, m, n) => write!(
f,
"Matrices with dimensions [{i},{j}] and [{m},{n}] cannot be added",
),
MatrixError::FromVecError(i, j, len) => write!(
f,
"Matrices with dimensions [{i},{j}] cannot be created from vec with len={len}",
),
}
}
}
@ -62,13 +81,58 @@ impl<T: Numeric> Matrix<T> {
Matrix { data, m, n }
}
pub fn from_vec(m: usize, n: usize, data: Vec<T>) -> Result<Self, MatrixError> {
if m * n != data.len() {
return Err(MatrixError::FromVecError(m, n, data.len()));
}
Ok(Self { m, n, data })
}
fn index(&self, i: usize, j: usize) -> usize {
i * self.n + j
}
pub fn get(&self, i: usize, j: usize) -> Result<&T, MatrixError> {
self.data
.get(self.index(i, j))
.ok_or(make_index_error(i, j, &self))
.ok_or(make_index_error(i, j, self))
}
pub fn add(&self, other: &Matrix<T>) -> Result<Matrix<T>, MatrixError> {
// Compatibility check
if self.m != other.m || self.n != other.n {
return Err(MatrixError::AddError(self.m, self.n, other.m, other.n));
}
let data = self
.data
.iter()
.zip(other.data.iter())
.map(|(a, b)| *a + *b)
.collect();
Self::from_vec(self.m, self.n, data)
}
pub fn mul(&self, other: &Matrix<T>) -> Result<Matrix<T>, MatrixError> {
let mut c = Matrix::new(self.m, other.n, T::default());
// Compatibility check
if self.n != other.m {
return Err(MatrixError::MultiplicationError(
self.m, self.n, other.m, other.n,
));
}
for i in 0..self.m {
for k in 0..self.n {
for j in 0..other.n {
c.set(
i,
j,
*c.get(i, j)? + (*self.get(i, k)?) * (*other.get(k, j)?),
);
}
}
}
Ok(c)
}
pub fn set(&mut self, i: usize, j: usize, x: T) {
@ -131,5 +195,37 @@ mod test {
assert_delta!(m.data[7], 2.0, 1e-12);
assert_delta!(m.data[10], 1.0, 1e-12);
assert_delta!(m.data[11], 2.0, 1e-12);
// Test from_vec
let data = vec![1, 2, 3, 4, 5, 6];
let m = Matrix::from_vec(3, 2, data).unwrap();
assert_eq!(*m.get(0, 0).unwrap(), 1);
assert_eq!(*m.get(0, 1).unwrap(), 2);
assert_eq!(*m.get(1, 0).unwrap(), 3);
assert_eq!(*m.get(1, 1).unwrap(), 4);
assert_eq!(*m.get(2, 0).unwrap(), 5);
assert_eq!(*m.get(2, 1).unwrap(), 6);
}
#[test]
fn test_matrix_math() {
let a = Matrix::from_vec(3, 2, vec![1, 2, 3, 4, 5, 6]).unwrap();
let b = Matrix::from_vec(3, 2, vec![0, 1, 0, 1, 0, 1]).unwrap();
let c = a.add(&b).unwrap();
assert_eq!(*c.get(0, 0).unwrap(), 1);
assert_eq!(*c.get(0, 1).unwrap(), 3);
assert_eq!(*c.get(1, 0).unwrap(), 3);
assert_eq!(*c.get(1, 1).unwrap(), 5);
assert_eq!(*c.get(2, 0).unwrap(), 5);
assert_eq!(*c.get(2, 1).unwrap(), 7);
assert!(a.mul(&b).is_err());
let a = Matrix::from_vec(3, 2, vec![1, 2, 3, 4, 5, 6]).unwrap();
let b = Matrix::from_vec(2, 3, vec![0, 1, 0, 1, 0, 1]).unwrap();
let c = a.mul(&b).unwrap();
assert_eq!(c.data, vec![2, 1, 2, 4, 3, 4, 6, 5, 6]);
}
}