From a408567a5dce429710784b50deec884413031a1c Mon Sep 17 00:00:00 2001 From: Philipp Holl Date: Fri, 10 Jan 2025 19:06:55 +0100 Subject: [PATCH] Fix jitted scipy-lsqr --- phiml/backend/_linalg.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/phiml/backend/_linalg.py b/phiml/backend/_linalg.py index 80d63f45..12028c24 100644 --- a/phiml/backend/_linalg.py +++ b/phiml/backend/_linalg.py @@ -338,7 +338,10 @@ def scipy_solve(np_y, np_x0, np_rtol, np_atol, *np_pre_tensors): fp = b.float_type i = INT32 bo = BOOL - x, residual, iterations, function_evaluations, converged, diverged = b.numpy_call(scipy_solve, (x0.shape, x0.shape, x0.shape[:1], x0.shape[:1], x0.shape[:1], x0.shape[:1]), (fp, fp, i, i, bo, bo), y, x0, rtol, atol, *lin_tensors, *pre_tensors) + rsd_shape = list(x0.shape) + if was_row_added: + rsd_shape[1] += 1 + x, residual, iterations, function_evaluations, converged, diverged = b.numpy_call(scipy_solve, (x0.shape, rsd_shape, x0.shape[:1], x0.shape[:1], x0.shape[:1], x0.shape[:1]), (fp, fp, i, i, bo, bo), y, x0, rtol, atol, *lin_tensors, *pre_tensors) if was_row_added: residual = residual[:, :-1] return SolveResult(method_name, x, residual, iterations, function_evaluations, converged, diverged, [""] * batch_size)