Skip to content

Commit

Permalink
Visualize: fix flatten of observable mapping with one observable (#1515)
Browse files Browse the repository at this point in the history
* Fix flatten of axes if only one obs

* Same for other, implement tests of both

* Decrease maxiter of visualize tests
  • Loading branch information
Doresic authored Dec 2, 2024
1 parent 010e93a commit 9a20cd8
Show file tree
Hide file tree
Showing 3 changed files with 69 additions and 22 deletions.
8 changes: 3 additions & 5 deletions pypesto/visualize/observable_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def visualize_estimated_observable_mapping(
n_axes = n_relative_observables + n_semiquant_observables
n_rows = int(np.ceil(np.sqrt(n_axes)))
n_cols = int(np.ceil(n_axes / n_rows))
_, axes = plt.subplots(n_rows, n_cols, **kwargs)
_, axes = plt.subplots(n_rows, n_cols, squeeze=False, **kwargs)
axes = axes.flatten()

# Plot the estimated observable mapping for relative observables.
Expand Down Expand Up @@ -246,8 +246,7 @@ def plot_linear_observable_mappings_from_pypesto_result(
n_cols = int(np.ceil(n_relative_observables / n_rows))

# Make as many subplots as there are relative observables
_, axes = plt.subplots(n_rows, n_cols, **kwargs)

_, axes = plt.subplots(n_rows, n_cols, squeeze=False, **kwargs)
# Flatten the axes array
axes = axes.flatten()

Expand Down Expand Up @@ -590,8 +589,7 @@ def plot_splines_from_inner_result(
n_cols = int(np.ceil(n_groups / n_rows))

# Make as many subplots as there are groups
_, axes = plt.subplots(n_rows, n_cols, **kwargs)

_, axes = plt.subplots(n_rows, n_cols, squeeze=False, **kwargs)
# Flatten the axes array
axes = axes.flatten()

Expand Down
26 changes: 9 additions & 17 deletions pypesto/visualize/ordinal_categories.py
Original file line number Diff line number Diff line change
Expand Up @@ -612,26 +612,18 @@ def _get_data_for_plotting(

def _get_default_axes(n_groups, **kwargs):
"""Return a list of axes with the default layout."""
# If there is only one group, make a figure with only one plot
if n_groups == 1:
# Make figure with only one plot
fig, ax = plt.subplots(1, 1, **kwargs)
# Choose number of rows and columns to be used for the subplots
n_rows = int(np.ceil(np.sqrt(n_groups)))
n_cols = int(np.ceil(n_groups / n_rows))

axes = [ax]
# If there are multiple groups, make a figure with multiple plots
else:
# Choose number of rows and columns to be used for the subplots
n_rows = int(np.ceil(np.sqrt(n_groups)))
n_cols = int(np.ceil(n_groups / n_rows))

# Make as many subplots as there are groups
fig, axes = plt.subplots(n_rows, n_cols, **kwargs)
# Make as many subplots as there are groups
fig, axes = plt.subplots(n_rows, n_cols, squeeze=False, **kwargs)

# Increase the spacing between the subplots
fig.subplots_adjust(hspace=0.35, wspace=0.25)
# Increase the spacing between the subplots
fig.subplots_adjust(hspace=0.35, wspace=0.25)

# Flatten the axes array
axes = axes.flatten()
# Flatten the axes array
axes = axes.flatten()
return axes


Expand Down
57 changes: 57 additions & 0 deletions test/visualize/test_visualize.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
from collections.abc import Sequence
from functools import wraps
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
Expand Down Expand Up @@ -1240,3 +1241,59 @@ def test_parameters_correlation_matrix(result_creation):
result = result_creation()

visualize.parameters_correlation_matrix(result)


@close_fig
def test_plot_ordinal_categories():
example_ordinal_yaml = (
Path(__file__).parent
/ ".."
/ ".."
/ "doc"
/ "example"
/ "example_ordinal"
/ "example_ordinal.yaml"
)
petab_problem = petab.Problem.from_yaml(example_ordinal_yaml)
# Set seed for reproducibility.
np.random.seed(0)
optimizer = pypesto.optimize.ScipyOptimizer(
method="L-BFGS-B", options={"maxiter": 1}
)
importer = pypesto.petab.PetabImporter(petab_problem, hierarchical=True)
problem = importer.create_problem()
result = pypesto.optimize.minimize(
problem=problem, n_starts=1, optimizer=optimizer
)
visualize.plot_categories_from_pypesto_result(result)


@close_fig
def test_visualize_estimated_observable_mapping():
example_semiquantitative_yaml = (
Path(__file__).parent
/ ".."
/ ".."
/ "doc"
/ "example"
/ "example_semiquantitative"
/ "example_semiquantitative_linear.yaml"
)
petab_problem = petab.Problem.from_yaml(example_semiquantitative_yaml)
# Set seed for reproducibility.
np.random.seed(0)
optimizer = pypesto.optimize.ScipyOptimizer(
method="L-BFGS-B",
options={
"disp": None,
"ftol": 2.220446049250313e-09,
"gtol": 1e-5,
"maxiter": 1,
},
)
importer = pypesto.petab.PetabImporter(petab_problem, hierarchical=True)
problem = importer.create_problem()
result = pypesto.optimize.minimize(
problem=problem, n_starts=1, optimizer=optimizer
)
visualize.visualize_estimated_observable_mapping(result, problem)

0 comments on commit 9a20cd8

Please sign in to comment.