diff --git a/src/matrix.rs b/src/matrix.rs index b196e3b..8d39ff5 100644 --- a/src/matrix.rs +++ b/src/matrix.rs @@ -92,11 +92,19 @@ impl Matrix { i * self.n + j } pub fn get(&self, i: usize, j: usize) -> Result<&T, MatrixError> { + if i >= self.m || j >= self.n { + return Err(make_index_error(i, j, self)); + } self.data .get(self.index(i, j)) .ok_or(make_index_error(i, j, self)) } + pub fn set(&mut self, i: usize, j: usize, x: T) { + let index = self.index(i, j); + self.data[index] = x; + } + pub fn add(&self, other: &Matrix) -> Result, MatrixError> { // Compatibility check if self.m != other.m || self.n != other.n { @@ -135,9 +143,15 @@ impl Matrix { Ok(c) } - pub fn set(&mut self, i: usize, j: usize, x: T) { - let index = self.index(i, j); - self.data[index] = x; + pub fn transpose(&self) -> Result { + let mut c = Self::new(self.n, self.m, T::default()); + + for i in 0..self.m { + for j in 0..self.n { + c.set(j, i, *self.get(i, j)?); + } + } + Ok(c) } } @@ -227,5 +241,22 @@ mod test { let c = a.mul(&b).unwrap(); assert_eq!(c.data, vec![2, 1, 2, 4, 3, 4, 6, 5, 6]); + + let a = Matrix::from_vec(3, 4, vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]).unwrap(); + let atranspose = a.transpose().unwrap(); + + assert_eq!(*atranspose.get(0, 0).unwrap(), 1); + assert_eq!(*atranspose.get(0, 1).unwrap(), 5); + assert_eq!(*atranspose.get(0, 2).unwrap(), 9); + assert!(atranspose.get(0, 3).is_err()); + assert_eq!(*atranspose.get(1, 0).unwrap(), 2); + assert_eq!(*atranspose.get(1, 1).unwrap(), 6); + assert_eq!(*atranspose.get(1, 2).unwrap(), 10); + assert_eq!(*atranspose.get(2, 0).unwrap(), 3); + assert_eq!(*atranspose.get(2, 1).unwrap(), 7); + assert_eq!(*atranspose.get(2, 2).unwrap(), 11); + assert_eq!(*atranspose.get(3, 0).unwrap(), 4); + assert_eq!(*atranspose.get(3, 1).unwrap(), 8); + assert_eq!(*atranspose.get(3, 2).unwrap(), 12); } }