Skip to content

Commit

Permalink
Optimization Parameter scatter plot (#1015)
Browse files Browse the repository at this point in the history
* Small update to parameter historgram

* Added parameter scatter plot for optimization.

* Renamed to optimization_scatter, added two tests.

* Added to an example doc. Fixed test

* Added id to optimization.

* Added correct id to optimization. Increased possible runtime

* small change

* fixed typo

* More informative doc

---------

Co-authored-by: Polina Lakrisenko <p.lakrisenko@gmail.com>
  • Loading branch information
PaulJonasJost and plakrisenko authored Feb 15, 2023
1 parent 5effaba commit b7a18e0
Show file tree
Hide file tree
Showing 5 changed files with 256 additions and 119 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ jobs:
run: .github/workflows/install_deps.sh amici

- name: Run tests
timeout-minutes: 15
timeout-minutes: 25
run: tox -e base

- name: Coverage
Expand Down
254 changes: 139 additions & 115 deletions doc/example/conversion_reaction.ipynb

Large diffs are not rendered by default.

7 changes: 6 additions & 1 deletion pypesto/visualize/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@
)
from .optimizer_convergence import optimizer_convergence
from .optimizer_history import optimizer_history, optimizer_history_lowlevel
from .parameters import parameter_hist, parameters, parameters_lowlevel
from .parameters import (
optimization_scatter,
parameter_hist,
parameters,
parameters_lowlevel,
)
from .profile_cis import profile_cis
from .profiles import profile_lowlevel, profiles, profiles_lowlevel
from .reference_points import ReferencePoint, create_references
Expand Down
96 changes: 94 additions & 2 deletions pypesto/visualize/parameters.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import logging
from typing import Callable, Iterable, List, Optional, Sequence, Tuple, Union

import matplotlib.axes
Expand All @@ -19,6 +20,8 @@
)
from .reference_points import ReferencePoint, create_references

logger = logging.getLogger(__name__)


def parameters(
results: Union[Result, Sequence[Result]],
Expand Down Expand Up @@ -207,10 +210,10 @@ def parameter_hist(
parameter_index = result.problem.x_names.index(parameter_name)
parameter_values = [x[parameter_index] for x in xs]

ax.hist(parameter_values, color=color, bins=bins)
ax.hist(parameter_values, color=color, bins=bins, label=parameter_name)
ax.set_xlabel(parameter_name)
ax.set_ylabel("counts")
ax.set_title(f"Parameter {parameter_name}")
ax.set_title(f"{parameter_name}")

return ax

Expand Down Expand Up @@ -471,3 +474,92 @@ def parameters_correlation_matrix(
if return_table:
return ax, df
return ax


def optimization_scatter(
result: Result,
parameter_indices: Union[str, Sequence[int]] = 'free_only',
start_indices: Optional[Union[int, Iterable[int]]] = None,
diag_kind: str = "kde",
suptitle: str = None,
size: Tuple[float, float] = None,
show_bounds: bool = False,
):
"""
Plot a scatter plot of all pairs of parameters for the given starts.
Parameters
----------
result:
Optimization result obtained by 'optimize.py'.
parameter_indices:
List of integers specifying the parameters to be considered.
start_indices:
List of integers specifying the multistarts to be plotted or
int specifying up to which start index should be plotted.
diag_kind:
Visualization mode for marginal densities {‘auto’, ‘hist’, ‘kde’,
None}.
suptitle:
Title of the plot.
size:
Size of the plot.
show_bounds:
Whether to show the parameter bounds.
Returns
-------
ax:
The plot axis.
"""
import seaborn as sns

start_indices = process_start_indices(
start_indices=start_indices, result=result
)
parameter_indices = process_parameter_indices(
parameter_indices=parameter_indices, result=result
)
# remove all start indices, that encounter an inf value at the start
# resulting in optimize_result[start]["x"] being None
start_indices_finite = start_indices[
[
result.optimize_result[i_start]['x'] is not None
for i_start in start_indices
]
]
# compare start_indices with start_indices_finite and log a warning
if not np.all(start_indices == start_indices_finite):
logger.warning(
'Some start indices were removed due to inf values at the start.'
)
# put all parameters into a dataframe, where columns are parameters
parameters = [
result.optimize_result[i_start]['x'][parameter_indices]
for i_start in start_indices_finite
]
x_labels = [
result.problem.x_names[parameter_index]
for parameter_index in parameter_indices
]
df = pd.DataFrame(parameters, columns=x_labels)

sns.set(style="ticks")

ax = sns.pairplot(
df,
diag_kind=diag_kind,
)

if size is not None:
ax.fig.set_size_inches(size)
if suptitle:
ax.fig.suptitle(suptitle)
if show_bounds:
# set bounds of plot to parameter bounds. Only use diagonal as
# sns.PairGrid has sharex,sharey = True by default.
for i_axis, axis in enumerate(np.diag(ax.axes)):
axis.set_xlim(result.problem.lb[i_axis], result.problem.ub[i_axis])
axis.set_ylim(result.problem.lb[i_axis], result.problem.ub[i_axis])

return ax
16 changes: 16 additions & 0 deletions test/visualize/test_visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,6 +453,22 @@ def test_parameters_hist():
visualize.parameter_hist(result_1, 'x1', start_indices=list(range(10)))


@close_fig
def test_optimization_scatter():
result = create_optimization_result()
visualize.optimization_scatter(result)


@close_fig
def test_optimization_scatter_with_x_None():
result = create_optimization_result()
# create an optimizerResult with x=None
optimizer_result = pypesto.OptimizerResult(x=None, fval=np.inf, id="inf")
result.optimize_result.append(optimize_result=optimizer_result)

visualize.optimization_scatter(result)


# @close_fig
def _test_ensemble_dimension_reduction():
# creates a test problem
Expand Down

0 comments on commit b7a18e0

Please sign in to comment.