mirror of
https://github.com/aselimov/cea-rs.git
synced 2026-04-19 00:24:20 +00:00
Add basic matrix math
This commit is contained in:
parent
93bd1f4d38
commit
df0e93735b
1 changed files with 99 additions and 3 deletions
100
src/matrix.rs
100
src/matrix.rs
|
|
@ -11,8 +11,9 @@ pub trait Numeric:
|
|||
+ Mul<Output = Self>
|
||||
+ Div<Output = Self>
|
||||
+ Sized
|
||||
+ Clone
|
||||
+ Copy
|
||||
+ Display
|
||||
+ Default
|
||||
{
|
||||
}
|
||||
|
||||
|
|
@ -25,6 +26,8 @@ impl<T> Numeric for T where
|
|||
+ Sized
|
||||
+ Clone
|
||||
+ Display
|
||||
+ Default
|
||||
+ Copy
|
||||
{
|
||||
}
|
||||
|
||||
|
|
@ -32,6 +35,9 @@ impl<T> Numeric for T where
|
|||
#[derive(Debug)]
|
||||
pub enum MatrixError {
|
||||
IndexError(usize, usize, usize, usize),
|
||||
AddError(usize, usize, usize, usize),
|
||||
MultiplicationError(usize, usize, usize, usize),
|
||||
FromVecError(usize, usize, usize),
|
||||
}
|
||||
|
||||
impl fmt::Display for MatrixError {
|
||||
|
|
@ -42,6 +48,19 @@ impl fmt::Display for MatrixError {
|
|||
"Error accessing index [{i},{j}] for matrix with dimensions [{}, {}]",
|
||||
m, n
|
||||
),
|
||||
|
||||
MatrixError::MultiplicationError(i, j, m, n) => write!(
|
||||
f,
|
||||
"Matrices with dimensions [{i},{j}] and [{m},{n}] cannot be multiplied",
|
||||
),
|
||||
MatrixError::AddError(i, j, m, n) => write!(
|
||||
f,
|
||||
"Matrices with dimensions [{i},{j}] and [{m},{n}] cannot be added",
|
||||
),
|
||||
MatrixError::FromVecError(i, j, len) => write!(
|
||||
f,
|
||||
"Matrices with dimensions [{i},{j}] cannot be created from vec with len={len}",
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -62,13 +81,58 @@ impl<T: Numeric> Matrix<T> {
|
|||
Matrix { data, m, n }
|
||||
}
|
||||
|
||||
pub fn from_vec(m: usize, n: usize, data: Vec<T>) -> Result<Self, MatrixError> {
|
||||
if m * n != data.len() {
|
||||
return Err(MatrixError::FromVecError(m, n, data.len()));
|
||||
}
|
||||
Ok(Self { m, n, data })
|
||||
}
|
||||
|
||||
fn index(&self, i: usize, j: usize) -> usize {
|
||||
i * self.n + j
|
||||
}
|
||||
pub fn get(&self, i: usize, j: usize) -> Result<&T, MatrixError> {
|
||||
self.data
|
||||
.get(self.index(i, j))
|
||||
.ok_or(make_index_error(i, j, &self))
|
||||
.ok_or(make_index_error(i, j, self))
|
||||
}
|
||||
|
||||
pub fn add(&self, other: &Matrix<T>) -> Result<Matrix<T>, MatrixError> {
|
||||
// Compatibility check
|
||||
if self.m != other.m || self.n != other.n {
|
||||
return Err(MatrixError::AddError(self.m, self.n, other.m, other.n));
|
||||
}
|
||||
|
||||
let data = self
|
||||
.data
|
||||
.iter()
|
||||
.zip(other.data.iter())
|
||||
.map(|(a, b)| *a + *b)
|
||||
.collect();
|
||||
Self::from_vec(self.m, self.n, data)
|
||||
}
|
||||
|
||||
pub fn mul(&self, other: &Matrix<T>) -> Result<Matrix<T>, MatrixError> {
|
||||
let mut c = Matrix::new(self.m, other.n, T::default());
|
||||
// Compatibility check
|
||||
if self.n != other.m {
|
||||
return Err(MatrixError::MultiplicationError(
|
||||
self.m, self.n, other.m, other.n,
|
||||
));
|
||||
}
|
||||
|
||||
for i in 0..self.m {
|
||||
for k in 0..self.n {
|
||||
for j in 0..other.n {
|
||||
c.set(
|
||||
i,
|
||||
j,
|
||||
*c.get(i, j)? + (*self.get(i, k)?) * (*other.get(k, j)?),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(c)
|
||||
}
|
||||
|
||||
pub fn set(&mut self, i: usize, j: usize, x: T) {
|
||||
|
|
@ -131,5 +195,37 @@ mod test {
|
|||
assert_delta!(m.data[7], 2.0, 1e-12);
|
||||
assert_delta!(m.data[10], 1.0, 1e-12);
|
||||
assert_delta!(m.data[11], 2.0, 1e-12);
|
||||
|
||||
// Test from_vec
|
||||
let data = vec![1, 2, 3, 4, 5, 6];
|
||||
let m = Matrix::from_vec(3, 2, data).unwrap();
|
||||
assert_eq!(*m.get(0, 0).unwrap(), 1);
|
||||
assert_eq!(*m.get(0, 1).unwrap(), 2);
|
||||
assert_eq!(*m.get(1, 0).unwrap(), 3);
|
||||
assert_eq!(*m.get(1, 1).unwrap(), 4);
|
||||
assert_eq!(*m.get(2, 0).unwrap(), 5);
|
||||
assert_eq!(*m.get(2, 1).unwrap(), 6);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_matrix_math() {
|
||||
let a = Matrix::from_vec(3, 2, vec![1, 2, 3, 4, 5, 6]).unwrap();
|
||||
let b = Matrix::from_vec(3, 2, vec![0, 1, 0, 1, 0, 1]).unwrap();
|
||||
|
||||
let c = a.add(&b).unwrap();
|
||||
assert_eq!(*c.get(0, 0).unwrap(), 1);
|
||||
assert_eq!(*c.get(0, 1).unwrap(), 3);
|
||||
assert_eq!(*c.get(1, 0).unwrap(), 3);
|
||||
assert_eq!(*c.get(1, 1).unwrap(), 5);
|
||||
assert_eq!(*c.get(2, 0).unwrap(), 5);
|
||||
assert_eq!(*c.get(2, 1).unwrap(), 7);
|
||||
|
||||
assert!(a.mul(&b).is_err());
|
||||
|
||||
let a = Matrix::from_vec(3, 2, vec![1, 2, 3, 4, 5, 6]).unwrap();
|
||||
let b = Matrix::from_vec(2, 3, vec![0, 1, 0, 1, 0, 1]).unwrap();
|
||||
let c = a.mul(&b).unwrap();
|
||||
|
||||
assert_eq!(c.data, vec![2, 1, 2, 4, 3, 4, 6, 5, 6]);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue