diff --git a/phiml/math/_optimize.py b/phiml/math/_optimize.py index e39182d6..bf168257 100644 --- a/phiml/math/_optimize.py +++ b/phiml/math/_optimize.py @@ -727,7 +727,7 @@ def _linear_solve_forward(y: Tensor, assert isinstance(ret, SolveResult) converged = reshaped_tensor(ret.converged, [*trj_dims, batch_dims]) diverged = reshaped_tensor(ret.diverged, [*trj_dims, batch_dims]) - x = assemble_tree(x0_nest, [reshaped_tensor(ret.x, [*trj_dims, batch_dims, pattern_dims_out])], attr_type=variable_attributes) + x = assemble_tree(x0_nest, [reshaped_tensor(ret.x, [*trj_dims, batch_dims, pattern_dims_in])], attr_type=variable_attributes) final_x = x if not trj_dims else assemble_tree(x0_nest, [reshaped_tensor(ret.x[-1, ...], [batch_dims, pattern_dims_out])], attr_type=variable_attributes) iterations = reshaped_tensor(ret.iterations, [*trj_dims, batch_dims]) function_evaluations = reshaped_tensor(ret.function_evaluations, [*trj_dims, batch_dims])