Source code for fame.FVM.solver

import numpy as np
import jax
import jax.numpy as jnp
import scipy.sparse as sp
import matplotlib.pyplot as plt

from petsc4py import PETSc
from jax.experimental.sparse import BCOO

[docs] class Solver: def __init__(self, A, b, backend="scipy"): """ Initialize the solver with the matrix A, vector b, and backend. Parameters: A: scipy.sparse matrix (A in Ax = b) b: numpy array (b in Ax = b) backend: str, one of ["scipy", "jax", "petsc"] """ if not sp.isspmatrix(A): raise TypeError("A must be a scipy sparse matrix.") if not isinstance(b, np.ndarray): raise TypeError("b must be a numpy array.") self.A = A self.b = b self.solution = None self.backend = backend.lower() if self.backend not in ["scipy", "jax", "petsc"]: raise ValueError("Unsupported backend. Choose from 'scipy', 'jax', or 'petsc'.")
[docs] def solve(self, method="bicgstab", preconditioner="none"): """ Solve the system Ax = b using the selected backend and method. Parameters: method: str, optional (default="bicgstab") The solver method to use (e.g., "bicgstab", "cg", "gmres"). preconditioner: str, optional (default="none") Preconditioner type (e.g., "jacobi") or "none" for no preconditioning. For PETSc, passed directly to pc.setType(). Returns: solution: numpy array The solution vector. """ if self.backend == "scipy": self.solution = self._solve_scipy(method, preconditioner) elif self.backend == "jax": self.solution = self._solve_jax(method, preconditioner) elif self.backend == "petsc": self.solution = self._solve_petsc(method, preconditioner) return self.solution
def _solve_scipy(self, method, preconditioner): """ Solve using Scipy's iterative solvers with optional Jacobi preconditioning. """ solverMethods = { "bicgstab": sp.linalg.bicgstab, "cg": sp.linalg.cg, "gmres": sp.linalg.gmres } if method not in solverMethods: raise ValueError(f"Unsupported method '{method}' for scipy backend.") if preconditioner == "jacobi": # Construct the Jacobi preconditioner jacobi_diag = self.A.diagonal() if np.any(jacobi_diag == 0): raise ValueError("Jacobi preconditioner cannot be constructed: zero diagonal entries.") preconditioner_fn = sp.linalg.LinearOperator( dtype=self.A.dtype, shape=self.A.shape, matvec=lambda x: x / jacobi_diag, ) elif preconditioner == "none": preconditioner_fn = None else: raise ValueError(f"Unsupported preconditioner '{preconditioner}' for scipy backend.") solution, info = solverMethods[method](self.A, self.b, rtol=1e-10, atol=1e-10, maxiter=None, M=preconditioner_fn) err = np.linalg.norm(self.A @ solution - self.b) print(f"Scipy {method} solver residual: {err}") return solution, err, info def _solve_jax(self, method, preconditioner): """ Solve using JAX's iterative solvers with optional Jacobi preconditioning. """ solver_methods = { "bicgstab": jax.scipy.sparse.linalg.bicgstab, "cg": jax.scipy.sparse.linalg.cg, "gmres": jax.scipy.sparse.linalg.gmres } if method not in solver_methods: raise ValueError(f"Unsupported method '{method}' for JAX backend. Supported methods: {list(solver_methods.keys())}.") # Convert scipy sparse matrix to JAX sparse matrix A_jax = BCOO.from_scipy_sparse(self.A).sort_indices() # Prepare the right-hand side vector b_jax = jnp.array(self.b) # Create preconditioner (Jacobi diagonal scaling) if preconditioner == "jacobi": jacobi_diag = jnp.array(self.A.diagonal()) if jnp.any(jacobi_diag == 0): raise ValueError("Jacobi preconditioner cannot be constructed: zero diagonal entries.") preconditioner_fn = lambda x: x / jacobi_diag elif preconditioner == "none": preconditioner_fn = None else: raise ValueError(f"Unsupported preconditioner '{preconditioner}' for JAX backend.") # Solve using the selected JAX method x0_jax = jnp.zeros_like(b_jax) solution, info = solver_methods[method](A_jax, b_jax, tol=1e-10, atol=1e-10, maxiter=None, M=preconditioner_fn, x0=x0_jax, ) # Flatten the solution if necessary and convert to NumPy if isinstance(solution, (tuple, list)): solution = solution[0] # Use the first element of the tuple if applicable solution = np.array(solution) # Verify convergence residual = jnp.linalg.norm(A_jax @ solution - b_jax) print(f"JAX {method} solver residual: {residual}") if info is not None and info != 0: raise RuntimeError(f"JAX solver failed to converge: info={info}") return solution, residual, info def _solve_petsc(self, method, preconditioner): """ Solve using PETSc solver. Available Preconditioners: - jacobi: Diagonal scaling preconditioner. - ilu: Incomplete LU factorization. - sor: Successive over-relaxation. - none: No preconditioning. - asm: Additive Schwarz method. - bjacobi: Block Jacobi preconditioner. """ petscMethods = { "bicgstab": "bcgs", # Correct PETSc name for BiCGSTAB "cg": "cg", # PETSc name for Conjugate Gradient "gmres": "gmres" # PETSc name for GMRES } if method not in petscMethods: raise ValueError(f"Unsupported method '{method}' for petsc backend.") if not isinstance(self.A, sp.csr_matrix): self.A = self.A.tocsr() mat = PETSc.Mat().createAIJ(size=self.A.shape, csr=(self.A.indptr, self.A.indices, self.A.data)) vec_b = PETSc.Vec().createWithArray(self.b) vec_x = PETSc.Vec().createWithArray(np.zeros_like(self.b)) ksp = PETSc.KSP().create() ksp.setOperators(mat) ksp.setType(petscMethods[method]) # Use corrected PETSc solver type pc = ksp.getPC() pc.setType(preconditioner) ksp.solve(vec_b, vec_x) solution = vec_x.getArray() err = np.linalg.norm(self.A @ solution - self.b) iteration_number = ksp.getIterationNumber() print(f"PETSc {method} solver residual: {err}, Iterations: {iteration_number}") return solution, err, iteration_number # Utility method to visualize the matrix
[docs] def plotSparseMatrix(self, matrix, filename="matrix.jpeg"): """ Saves a visualization of the sparse matrix as a .jpeg image. Parameters: matrix (sp.spmatrix): Sparse matrix to visualize. filename (str): Output file name for the image. """ if not sp.isspmatrix(matrix): raise TypeError("Matrix A must be a scipy sparse matrix.") plt.figure(figsize=(8, 8)) try: plt.spy(matrix, markersize=1) plt.title("Sparse Matrix Visualization") plt.xlabel("Columns") plt.ylabel("Rows") plt.savefig(filename, format="jpeg") plt.close() print(f"Sparse matrix plot saved to {filename}") except Exception as e: plt.close() raise RuntimeError(f"Failed to generate the sparse matrix plot: {e}")