Skip to content

Commit

Permalink
Check convergence with scipy-direct
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Feb 2, 2025
1 parent 40d3a2b commit b714f8b
Showing 1 changed file with 7 additions and 4 deletions.
11 changes: 7 additions & 4 deletions phiml/backend/_linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ def scipy_solve(np_y, np_x0, np_rtol, np_atol, *np_tensors):
np_lin = assemble_lin(NUMPY, *np_tensors[:n_lin_tensors])
np_pre = assemble_pre(NUMPY, *np_tensors[n_lin_tensors:])
if method == 'direct':
npr = scipy_direct_linear_solve(NUMPY, np_lin, np_y)
npr = scipy_direct_linear_solve(NUMPY, np_lin, np_y, np_rtol, np_atol)
else:
npr = scipy_iterative_sparse_solve(NUMPY, np_lin, np_y, np_x0, np_rtol, np_atol, max_iter, np_pre, function)
return npr.x, npr.residual, npr.iterations, npr.function_evaluations, npr.converged, npr.diverged
Expand All @@ -347,7 +347,7 @@ def scipy_solve(np_y, np_x0, np_rtol, np_atol, *np_pre_tensors):
return SolveResult(method_name, x, residual, iterations, function_evaluations, converged, diverged, [""] * batch_size)


def scipy_direct_linear_solve(b: Backend, lin, y) -> SolveResult:
def scipy_direct_linear_solve(b: Backend, lin, y, rtol, atol) -> SolveResult:
batch_size = b.staticshape(y)[0]
xs = []
converged = []
Expand All @@ -361,9 +361,12 @@ def scipy_direct_linear_solve(b: Backend, lin, y) -> SolveResult:
for batch in range(batch_size):
# use_umfpack=self.precision == 64
x = spsolve(lin[batch], y[batch]) # returns nan when diverges
residual = lin[batch] @ x - y[batch]
residual_norm = np.linalg.norm(residual)
y_norm = np.linalg.norm(y[batch])
xs.append(x)
converged.append(np.all(np.isfinite(x)))
residuals.append(lin[batch] @ x - y[batch])
converged.append(residual_norm <= np.maximum(atol[batch], rtol[batch] * y_norm))
residuals.append(residual)
x = np.stack(xs)
converged = np.stack(converged)
residual = np.stack(residuals)
Expand Down

0 comments on commit b714f8b

Please sign in to comment.