From 9d5c13b1c3752c9c19ca6bf96fedebaf9a64e61b Mon Sep 17 00:00:00 2001 From: Paul Jonas Jost <70631928+PaulJonasJost@users.noreply.github.com> Date: Tue, 14 May 2024 09:53:03 +0200 Subject: [PATCH] More detailed defaults for `problem.get_full_vector` (#1393) * Made get_full_vector more intuitive * Corrected he usage of get_full vector. Removed it for lb and ub in favor of lb_full and ub_full. * Apply suggestions from code review Co-authored-by: Maren Philipps <55318391+m-philipps@users.noreply.github.com> --------- Co-authored-by: Maren Philipps <55318391+m-philipps@users.noreply.github.com> --- pypesto/problem/base.py | 14 ++++++++++---- pypesto/result/optimize.py | 6 +++--- pypesto/visualize/parameters.py | 4 ++-- test/base/test_history.py | 4 +--- test/base/test_x_fixed.py | 3 +-- 5 files changed, 17 insertions(+), 14 deletions(-) diff --git a/pypesto/problem/base.py b/pypesto/problem/base.py index d8d29847b..ad4b87142 100644 --- a/pypesto/problem/base.py +++ b/pypesto/problem/base.py @@ -332,7 +332,10 @@ def unfix_parameters( self.normalize() def get_full_vector( - self, x: Union[np.ndarray, None], x_fixed_vals: Iterable[float] = None + self, + x: Union[np.ndarray, None], + x_fixed_vals: Iterable[float] = None, + x_is_grad: bool = False, ) -> Union[np.ndarray, None]: """ Map vector from dim to dim_full. Usually used for x, grad. @@ -342,9 +345,9 @@ def get_full_vector( x: array_like, shape=(dim,) The vector in dimension dim. x_fixed_vals: array_like, ndim=1, optional - The values to be used for the fixed indices. If None, then nans are - inserted. Usually, None will be used for grad and - problem.x_fixed_vals for x. + The values to be used for the fixed indices. If None and x_is_grad=False, problem.x_fixed_vals is used; for x_is_grad=True, nans are inserted. + x_is_grad: bool + If true, x is treated as gradients. """ if x is None: return None @@ -362,6 +365,9 @@ def get_full_vector( x_full[..., self.x_free_indices] = x if x_fixed_vals is not None: x_full[..., self.x_fixed_indices] = x_fixed_vals + return x_full + if not x_is_grad: + x_full[..., self.x_fixed_indices] = self.x_fixed_vals return x_full def get_full_matrix( diff --git a/pypesto/result/optimize.py b/pypesto/result/optimize.py index 4fb2883fa..6f1d69964 100644 --- a/pypesto/result/optimize.py +++ b/pypesto/result/optimize.py @@ -191,10 +191,10 @@ def update_to_full(self, problem: Problem) -> None: problem which contains info about how to convert to full vectors or matrices """ - self.x = problem.get_full_vector(self.x, problem.x_fixed_vals) - self.grad = problem.get_full_vector(self.grad) + self.x = problem.get_full_vector(self.x) + self.grad = problem.get_full_vector(self.grad, x_is_grad=True) self.hess = problem.get_full_matrix(self.hess) - self.x0 = problem.get_full_vector(self.x0, problem.x_fixed_vals) + self.x0 = problem.get_full_vector(self.x0) self.free_indices = np.array(problem.x_free_indices) diff --git a/pypesto/visualize/parameters.py b/pypesto/visualize/parameters.py index 81b10e1e2..c50f4fdeb 100644 --- a/pypesto/visualize/parameters.py +++ b/pypesto/visualize/parameters.py @@ -417,8 +417,8 @@ def handle_inputs( ub = result.problem.get_reduced_vector(ub, parameter_indices) x_labels = [x_labels[int(i)] for i in parameter_indices] else: - lb = result.problem.get_full_vector(lb) - ub = result.problem.get_full_vector(ub) + lb = result.problem.lb_full + ub = result.problem.ub_full if inner_xs is not None and plot_inner_parameters: lb = np.concatenate([lb, inner_lb]) diff --git a/test/base/test_history.py b/test/base/test_history.py index 57d01a3f2..9056fd7a2 100644 --- a/test/base/test_history.py +++ b/test/base/test_history.py @@ -236,9 +236,7 @@ def check_reconstruct_history( def check_history_consistency(self, start: pypesto.OptimizerResult): def xfull(x_trace): - return self.problem.get_full_vector( - x_trace, self.problem.x_fixed_vals - ) + return self.problem.get_full_vector(x_trace) if isinstance(start.history, (CsvHistory, Hdf5History)): # get index of optimal parameter diff --git a/test/base/test_x_fixed.py b/test/base/test_x_fixed.py index fc7ea2e00..320cd8d08 100644 --- a/test/base/test_x_fixed.py +++ b/test/base/test_x_fixed.py @@ -41,8 +41,7 @@ def test_optimize(): # fixed values written into parameter vector assert optimizer_result.x[1] == 1 - lb_full = problem.get_full_vector(problem.lb) - assert len(lb_full) == 5 + assert len(problem.lb_full) == 5 def create_problem():