mirror of
https://github.com/aselimov/cea-rs.git
synced 2026-04-21 01:14:20 +00:00
Add basic matrix math
This commit is contained in:
parent
93bd1f4d38
commit
df0e93735b
1 changed files with 99 additions and 3 deletions
102
src/matrix.rs
102
src/matrix.rs
|
|
@ -11,8 +11,9 @@ pub trait Numeric:
|
||||||
+ Mul<Output = Self>
|
+ Mul<Output = Self>
|
||||||
+ Div<Output = Self>
|
+ Div<Output = Self>
|
||||||
+ Sized
|
+ Sized
|
||||||
+ Clone
|
+ Copy
|
||||||
+ Display
|
+ Display
|
||||||
|
+ Default
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -25,6 +26,8 @@ impl<T> Numeric for T where
|
||||||
+ Sized
|
+ Sized
|
||||||
+ Clone
|
+ Clone
|
||||||
+ Display
|
+ Display
|
||||||
|
+ Default
|
||||||
|
+ Copy
|
||||||
{
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -32,6 +35,9 @@ impl<T> Numeric for T where
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub enum MatrixError {
|
pub enum MatrixError {
|
||||||
IndexError(usize, usize, usize, usize),
|
IndexError(usize, usize, usize, usize),
|
||||||
|
AddError(usize, usize, usize, usize),
|
||||||
|
MultiplicationError(usize, usize, usize, usize),
|
||||||
|
FromVecError(usize, usize, usize),
|
||||||
}
|
}
|
||||||
|
|
||||||
impl fmt::Display for MatrixError {
|
impl fmt::Display for MatrixError {
|
||||||
|
|
@ -42,6 +48,19 @@ impl fmt::Display for MatrixError {
|
||||||
"Error accessing index [{i},{j}] for matrix with dimensions [{}, {}]",
|
"Error accessing index [{i},{j}] for matrix with dimensions [{}, {}]",
|
||||||
m, n
|
m, n
|
||||||
),
|
),
|
||||||
|
|
||||||
|
MatrixError::MultiplicationError(i, j, m, n) => write!(
|
||||||
|
f,
|
||||||
|
"Matrices with dimensions [{i},{j}] and [{m},{n}] cannot be multiplied",
|
||||||
|
),
|
||||||
|
MatrixError::AddError(i, j, m, n) => write!(
|
||||||
|
f,
|
||||||
|
"Matrices with dimensions [{i},{j}] and [{m},{n}] cannot be added",
|
||||||
|
),
|
||||||
|
MatrixError::FromVecError(i, j, len) => write!(
|
||||||
|
f,
|
||||||
|
"Matrices with dimensions [{i},{j}] cannot be created from vec with len={len}",
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -62,13 +81,58 @@ impl<T: Numeric> Matrix<T> {
|
||||||
Matrix { data, m, n }
|
Matrix { data, m, n }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn from_vec(m: usize, n: usize, data: Vec<T>) -> Result<Self, MatrixError> {
|
||||||
|
if m * n != data.len() {
|
||||||
|
return Err(MatrixError::FromVecError(m, n, data.len()));
|
||||||
|
}
|
||||||
|
Ok(Self { m, n, data })
|
||||||
|
}
|
||||||
|
|
||||||
fn index(&self, i: usize, j: usize) -> usize {
|
fn index(&self, i: usize, j: usize) -> usize {
|
||||||
i * self.n + j
|
i * self.n + j
|
||||||
}
|
}
|
||||||
pub fn get(&self, i: usize, j: usize) -> Result<&T, MatrixError> {
|
pub fn get(&self, i: usize, j: usize) -> Result<&T, MatrixError> {
|
||||||
self.data
|
self.data
|
||||||
.get(self.index(i, j))
|
.get(self.index(i, j))
|
||||||
.ok_or(make_index_error(i, j, &self))
|
.ok_or(make_index_error(i, j, self))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn add(&self, other: &Matrix<T>) -> Result<Matrix<T>, MatrixError> {
|
||||||
|
// Compatibility check
|
||||||
|
if self.m != other.m || self.n != other.n {
|
||||||
|
return Err(MatrixError::AddError(self.m, self.n, other.m, other.n));
|
||||||
|
}
|
||||||
|
|
||||||
|
let data = self
|
||||||
|
.data
|
||||||
|
.iter()
|
||||||
|
.zip(other.data.iter())
|
||||||
|
.map(|(a, b)| *a + *b)
|
||||||
|
.collect();
|
||||||
|
Self::from_vec(self.m, self.n, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn mul(&self, other: &Matrix<T>) -> Result<Matrix<T>, MatrixError> {
|
||||||
|
let mut c = Matrix::new(self.m, other.n, T::default());
|
||||||
|
// Compatibility check
|
||||||
|
if self.n != other.m {
|
||||||
|
return Err(MatrixError::MultiplicationError(
|
||||||
|
self.m, self.n, other.m, other.n,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
for i in 0..self.m {
|
||||||
|
for k in 0..self.n {
|
||||||
|
for j in 0..other.n {
|
||||||
|
c.set(
|
||||||
|
i,
|
||||||
|
j,
|
||||||
|
*c.get(i, j)? + (*self.get(i, k)?) * (*other.get(k, j)?),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn set(&mut self, i: usize, j: usize, x: T) {
|
pub fn set(&mut self, i: usize, j: usize, x: T) {
|
||||||
|
|
@ -94,7 +158,7 @@ impl<T: Numeric> Display for Matrix<T> {
|
||||||
})
|
})
|
||||||
.collect::<Vec<String>>()
|
.collect::<Vec<String>>()
|
||||||
.join(" ");
|
.join(" ");
|
||||||
write!(f, "\n{msg}")
|
write!(f, "\n {msg}")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -131,5 +195,37 @@ mod test {
|
||||||
assert_delta!(m.data[7], 2.0, 1e-12);
|
assert_delta!(m.data[7], 2.0, 1e-12);
|
||||||
assert_delta!(m.data[10], 1.0, 1e-12);
|
assert_delta!(m.data[10], 1.0, 1e-12);
|
||||||
assert_delta!(m.data[11], 2.0, 1e-12);
|
assert_delta!(m.data[11], 2.0, 1e-12);
|
||||||
|
|
||||||
|
// Test from_vec
|
||||||
|
let data = vec![1, 2, 3, 4, 5, 6];
|
||||||
|
let m = Matrix::from_vec(3, 2, data).unwrap();
|
||||||
|
assert_eq!(*m.get(0, 0).unwrap(), 1);
|
||||||
|
assert_eq!(*m.get(0, 1).unwrap(), 2);
|
||||||
|
assert_eq!(*m.get(1, 0).unwrap(), 3);
|
||||||
|
assert_eq!(*m.get(1, 1).unwrap(), 4);
|
||||||
|
assert_eq!(*m.get(2, 0).unwrap(), 5);
|
||||||
|
assert_eq!(*m.get(2, 1).unwrap(), 6);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_matrix_math() {
|
||||||
|
let a = Matrix::from_vec(3, 2, vec![1, 2, 3, 4, 5, 6]).unwrap();
|
||||||
|
let b = Matrix::from_vec(3, 2, vec![0, 1, 0, 1, 0, 1]).unwrap();
|
||||||
|
|
||||||
|
let c = a.add(&b).unwrap();
|
||||||
|
assert_eq!(*c.get(0, 0).unwrap(), 1);
|
||||||
|
assert_eq!(*c.get(0, 1).unwrap(), 3);
|
||||||
|
assert_eq!(*c.get(1, 0).unwrap(), 3);
|
||||||
|
assert_eq!(*c.get(1, 1).unwrap(), 5);
|
||||||
|
assert_eq!(*c.get(2, 0).unwrap(), 5);
|
||||||
|
assert_eq!(*c.get(2, 1).unwrap(), 7);
|
||||||
|
|
||||||
|
assert!(a.mul(&b).is_err());
|
||||||
|
|
||||||
|
let a = Matrix::from_vec(3, 2, vec![1, 2, 3, 4, 5, 6]).unwrap();
|
||||||
|
let b = Matrix::from_vec(2, 3, vec![0, 1, 0, 1, 0, 1]).unwrap();
|
||||||
|
let c = a.mul(&b).unwrap();
|
||||||
|
|
||||||
|
assert_eq!(c.data, vec![2, 1, 2, 4, 3, 4, 6, 5, 6]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue