From 8805b76535757c64d99c6c218ac3175a87b143c6 Mon Sep 17 00:00:00 2001 From: Alex Selimov Date: Tue, 14 Apr 2026 22:33:24 -0400 Subject: [PATCH] Simple gauss-jordan solver --- src/mixtures/gas_mixture.rs | 1 - src/solvers/equations.rs | 209 ++++++++++++++++++++++++++++++++++++ src/solvers/mod.rs | 3 + 3 files changed, 212 insertions(+), 1 deletion(-) create mode 100644 src/solvers/equations.rs create mode 100644 src/solvers/mod.rs diff --git a/src/mixtures/gas_mixture.rs b/src/mixtures/gas_mixture.rs index 651bef5..f0df99f 100644 --- a/src/mixtures/gas_mixture.rs +++ b/src/mixtures/gas_mixture.rs @@ -15,7 +15,6 @@ pub struct GasMixture { impl GasMixture { // Calculate the normalized chemical potential (μ/RT) for each component in the mixture. - // // Equations 2.11 from reference paper pub fn gas_chem_potentials_over_rt(&self, temp: f64, pressure: f64) -> Vec { self.ns diff --git a/src/solvers/equations.rs b/src/solvers/equations.rs new file mode 100644 index 0000000..0829845 --- /dev/null +++ b/src/solvers/equations.rs @@ -0,0 +1,209 @@ +use std::fmt; + +use crate::{error::CEAError, matrix::Matrix}; + +// Basic Error handling for solver +#[derive(Debug)] +pub enum SolverError { + NotSquareMatrix(usize, usize), + NotColumnVector(usize, usize), + DimensionMismatch(usize, usize), + SingularMatrix, +} + +impl fmt::Display for SolverError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + SolverError::NotSquareMatrix(m, n) => { + write!(f, "Matrix with dimensions [{m}, {n}] is not square") + } + SolverError::NotColumnVector(m, n) => write!( + f, + "Matrix with dimensions [{m}, {n}] is not a column vector" + ), + SolverError::DimensionMismatch(a_m, b_n) => write!( + f, + "Square matrix dimensions [{a_m}, {a_m}]\ + do not match column vector dimensions [{b_n}, 1]" + ), + SolverError::SingularMatrix => write!(f, "Matrix is singular, solution does not exist"), + } + } +} +impl std::error::Error for SolverError {} + +// Solve Ax=b for x using Gauss-Jordan elimination +// A must be a square n x n matrix and b must be a column vector, i.e. n x 1 +// Algorithm taken from Numerical Methods for Engineers by Ayyub and McCuen +pub fn gauss_jordan_elimination(a: &Matrix, b: &Matrix) -> Result, CEAError> { + validate_ax_eq_b(a, b)?; + + let mut a = a.clone(); + let mut b = b.clone(); + + for i in 0..a.shape().0 { + // First pivot rows + let mut max_and_index = (*a.get(i, i)?, i); + for j in i + 1..a.shape().0 { + max_and_index = if a.get(j, i)?.abs() > max_and_index.0.abs() { + (*a.get(j, i)?, j) + } else { + max_and_index + } + } + + if max_and_index.0.abs() < 1e-12 { + return Err(SolverError::SingularMatrix.into()); + } + a.swap_rows(i, max_and_index.1); + b.swap_rows(i, max_and_index.1); + + // Normalization + for j in i + 1..a.shape().0 { + a.set(i, j, a.get(i, j)? / dbg!(a.get(i, i)?)); + } + b.set(i, 0, b.get(i, 0)? / a.get(i, i)?); + a.set(i, i, 1.0); + + //forward pass (seems wrong) + for k in i + 1..a.shape().0 { + let aki = *a.get(k, i)?; + for j in i..a.shape().0 { + a.set(k, j, a.get(k, j)? - aki * a.get(i, j)?); + } + b.set(k, 0, b.get(k, 0)? - aki * b.get(i, 0)?); + } + } + //backward pass + let mut xs = b.clone(); + for i in (0..=a.shape().0 - 1).rev() { + for j in i + 1..a.shape().0 { + xs.set(i, 0, xs.get(i, 0)? - a.get(i, j)? * xs.get(j, 0)?); + } + } + + Ok(xs) +} + +#[cfg(test)] +mod tests { + use crate::{assert_delta, matrix::Matrix, solvers::equations::gauss_jordan_elimination}; + + fn col_vec(data: Vec) -> Matrix { + let n = data.len(); + Matrix::from_vec(n, 1, data).unwrap() + } + + // 2x2: 2x + y = 5, x + 3y = 10 => x=1, y=3 + #[test] + fn test_2x2_simple() { + let a = Matrix::from_vec(2, 2, vec![2.0, 1.0, 1.0, 3.0]).unwrap(); + let b = col_vec(vec![5.0, 10.0]); + let x = gauss_jordan_elimination(&a, &b).unwrap(); + assert_delta!(x.get(0, 0).unwrap(), 1.0, 1e-10); + assert_delta!(x.get(1, 0).unwrap(), 3.0, 1e-10); + } + + // 3x3 identity: Ix = b => x = b + #[test] + fn test_3x3_identity() { + let a = Matrix::from_vec(3, 3, vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]).unwrap(); + let b = col_vec(vec![4.0, 7.0, 2.0]); + let x = gauss_jordan_elimination(&a, &b).unwrap(); + assert_delta!(x.get(0, 0).unwrap(), 4.0, 1e-10); + assert_delta!(x.get(1, 0).unwrap(), 7.0, 1e-10); + assert_delta!(x.get(2, 0).unwrap(), 2.0, 1e-10); + } + + // 3x3 general system with singularity + // A = [[1,1,1],[0,2,1],[2,0,1]], b = [6, 7, 5] + #[test] + fn test_3x3_singular() { + let a = Matrix::from_vec(3, 3, vec![1.0, 1.0, 1.0, 0.0, 2.0, 1.0, 2.0, 0.0, 1.0]).unwrap(); + let b = col_vec(vec![6.0, 7.0, 5.0]); + assert!(gauss_jordan_elimination(&a, &b).is_err()); + } + + // 3x3 example. Example 5-6 in the reference book + #[test] + fn test_3x3_example_5_6() { + let a = Matrix::from_vec(3, 3, vec![1.0, 3.0, 2.0, 2.0, 4.0, 3.0, 3.0, 4.0, 7.0]).unwrap(); + let b = col_vec(vec![15.0, 22.0, 39.0]); + let x = gauss_jordan_elimination(&a, &b).unwrap(); + + assert_delta!(x.get(0, 0).unwrap(), 1.0, 1e-12); + assert_delta!(x.get(1, 0).unwrap(), 2.0, 1e-12); + assert_delta!(x.get(2, 0).unwrap(), 4.0, 1e-12); + } + + #[test] + fn test_6x6_complex() { + let a = Matrix::from_vec( + 6, + 6, + vec![ + 4.0, 3.0, 2.0, 1.0, 1.0, 2.0, 1.0, 5.0, 2.0, 3.0, 1.0, 1.0, 2.0, 1.0, 6.0, 2.0, + 3.0, 1.0, 1.0, 2.0, 1.0, 7.0, 2.0, 3.0, 3.0, 1.0, 2.0, 1.0, 5.0, 2.0, 2.0, 3.0, + 1.0, 2.0, 1.0, 6.0, + ], + ) + .unwrap(); + let b = col_vec(vec![37.0, 40.0, 51.0, 64.0, 52.0, 60.0]); + let x = gauss_jordan_elimination(&a, &b).unwrap(); + assert_delta!(x.get(0, 0).unwrap(), 1.0, 1e-12); + assert_delta!(x.get(1, 0).unwrap(), 2.0, 1e-12); + assert_delta!(x.get(2, 0).unwrap(), 3.0, 1e-12); + assert_delta!(x.get(3, 0).unwrap(), 4.0, 1e-12); + assert_delta!(x.get(4, 0).unwrap(), 5.0, 1e-12); + assert_delta!(x.get(5, 0).unwrap(), 6.0, 1e-12); + } + + // Singular matrix should return an error + #[test] + fn test_singular_matrix() { + let a = Matrix::from_vec(2, 2, vec![1.0, 2.0, 2.0, 4.0]).unwrap(); + let b = col_vec(vec![1.0, 2.0]); + assert!(gauss_jordan_elimination(&a, &b).is_err()); + } + + // Non-square A should return an error + #[test] + fn test_non_square_matrix() { + let a = Matrix::from_vec(2, 3, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]).unwrap(); + let b = col_vec(vec![1.0, 2.0]); + assert!(gauss_jordan_elimination(&a, &b).is_err()); + } + + // b with more than one column should return an error + #[test] + fn test_b_not_column_vector() { + let a = Matrix::from_vec(2, 2, vec![1.0, 0.0, 0.0, 1.0]).unwrap(); + let b = Matrix::from_vec(2, 2, vec![1.0, 0.0, 0.0, 1.0]).unwrap(); + assert!(gauss_jordan_elimination(&a, &b).is_err()); + } + + // Dimension mismatch between A and b should return an error + #[test] + fn test_dimension_mismatch() { + let a = Matrix::from_vec(3, 3, vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0]).unwrap(); + let b = col_vec(vec![1.0, 2.0]); + assert!(gauss_jordan_elimination(&a, &b).is_err()); + } +} + +fn validate_ax_eq_b(a: &Matrix, b: &Matrix) -> Result<(), CEAError> { + let (a_m, a_n) = a.shape(); + if a_m != a_n { + return Err(SolverError::NotSquareMatrix(a_m, a_n).into()); + } + + let (b_m, b_n) = b.shape(); + if b_n != 1 { + return Err(SolverError::NotColumnVector(b_m, b_n).into()); + } + + if b_m != a_n { + return Err(SolverError::DimensionMismatch(a_n, b_m).into()); + } + Ok(()) +} diff --git a/src/solvers/mod.rs b/src/solvers/mod.rs new file mode 100644 index 0000000..1812bb7 --- /dev/null +++ b/src/solvers/mod.rs @@ -0,0 +1,3 @@ +pub mod equations; + +pub use equations::SolverError;