diff --git a/phiml/backend/_linalg.py b/phiml/backend/_linalg.py
index 3b6c33ee..e107c9b8 100644
--- a/phiml/backend/_linalg.py
+++ b/phiml/backend/_linalg.py
@@ -323,7 +323,7 @@ def scipy_solve(np_y, np_x0, np_rtol, np_atol, *np_tensors):
         np_lin = assemble_lin(NUMPY, *np_tensors[:n_lin_tensors])
         np_pre = assemble_pre(NUMPY, *np_tensors[n_lin_tensors:])
         if method == 'direct':
-            npr = scipy_direct_linear_solve(NUMPY, np_lin, np_y)
+            npr = scipy_direct_linear_solve(NUMPY, np_lin, np_y, np_rtol, np_atol)
         else:
             npr = scipy_iterative_sparse_solve(NUMPY, np_lin, np_y, np_x0, np_rtol, np_atol, max_iter, np_pre, function)
         return npr.x, npr.residual, npr.iterations, npr.function_evaluations, npr.converged, npr.diverged
@@ -347,7 +347,7 @@ def scipy_solve(np_y, np_x0, np_rtol, np_atol, *np_pre_tensors):
     return SolveResult(method_name, x, residual, iterations, function_evaluations, converged, diverged, [""] * batch_size)
 
 
-def scipy_direct_linear_solve(b: Backend, lin, y) -> SolveResult:
+def scipy_direct_linear_solve(b: Backend, lin, y, rtol, atol) -> SolveResult:
     batch_size = b.staticshape(y)[0]
     xs = []
     converged = []
@@ -361,9 +361,12 @@ def scipy_direct_linear_solve(b: Backend, lin, y) -> SolveResult:
     for batch in range(batch_size):
         # use_umfpack=self.precision == 64
         x = spsolve(lin[batch], y[batch])  # returns nan when diverges
+        residual = lin[batch] @ x - y[batch]
+        residual_norm = np.linalg.norm(residual)
+        y_norm = np.linalg.norm(y[batch])
         xs.append(x)
-        converged.append(np.all(np.isfinite(x)))
-        residuals.append(lin[batch] @ x - y[batch])
+        converged.append(residual_norm <= np.maximum(atol[batch], rtol[batch] * y_norm))
+        residuals.append(residual)
     x = np.stack(xs)
     converged = np.stack(converged)
     residual = np.stack(residuals)
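
For context: with this change, the direct solve no longer reports convergence merely because the solution is finite. Instead, each batch entry counts as converged when its absolute residual norm satisfies ||lin @ x - y|| <= max(atol, rtol * ||y||), using the rtol/atol tolerances that were previously only forwarded to the iterative path. Below is a minimal standalone sketch of that check in plain NumPy/SciPy; the is_converged helper and the 2x2 example system are illustrative assumptions, not part of the patch.

import numpy as np
from scipy.sparse import csc_matrix
from scipy.sparse.linalg import spsolve

def is_converged(lin, x, y, rtol, atol):
    # Converged when the absolute residual norm is within
    # max(atol, rtol * ||y||), matching the patched convergence test.
    residual = lin @ x - y
    return np.linalg.norm(residual) <= max(atol, rtol * np.linalg.norm(y))

lin = csc_matrix(np.array([[4.0, 1.0], [1.0, 3.0]]))  # small, well-conditioned example system
y = np.array([1.0, 2.0])
x = spsolve(lin, y)
print(is_converged(lin, x, y, rtol=1e-5, atol=1e-12))  # True

Note that this criterion still flags NaN solutions as not converged: a non-finite x yields a non-finite residual norm, and the comparison then evaluates to False.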