diff --git a/pypesto/visualize/observable_mapping.py b/pypesto/visualize/observable_mapping.py index ec378fd1c..250a82dd5 100644 --- a/pypesto/visualize/observable_mapping.py +++ b/pypesto/visualize/observable_mapping.py @@ -129,7 +129,7 @@ def visualize_estimated_observable_mapping( n_axes = n_relative_observables + n_semiquant_observables n_rows = int(np.ceil(np.sqrt(n_axes))) n_cols = int(np.ceil(n_axes / n_rows)) - _, axes = plt.subplots(n_rows, n_cols, **kwargs) + _, axes = plt.subplots(n_rows, n_cols, squeeze=False, **kwargs) axes = axes.flatten() # Plot the estimated observable mapping for relative observables. @@ -246,8 +246,7 @@ def plot_linear_observable_mappings_from_pypesto_result( n_cols = int(np.ceil(n_relative_observables / n_rows)) # Make as many subplots as there are relative observables - _, axes = plt.subplots(n_rows, n_cols, **kwargs) - + _, axes = plt.subplots(n_rows, n_cols, squeeze=False, **kwargs) # Flatten the axes array axes = axes.flatten() @@ -590,8 +589,7 @@ def plot_splines_from_inner_result( n_cols = int(np.ceil(n_groups / n_rows)) # Make as many subplots as there are groups - _, axes = plt.subplots(n_rows, n_cols, **kwargs) - + _, axes = plt.subplots(n_rows, n_cols, squeeze=False, **kwargs) # Flatten the axes array axes = axes.flatten() diff --git a/pypesto/visualize/ordinal_categories.py b/pypesto/visualize/ordinal_categories.py index b93e97f22..c8e81def5 100644 --- a/pypesto/visualize/ordinal_categories.py +++ b/pypesto/visualize/ordinal_categories.py @@ -612,26 +612,18 @@ def _get_data_for_plotting( def _get_default_axes(n_groups, **kwargs): """Return a list of axes with the default layout.""" - # If there is only one group, make a figure with only one plot - if n_groups == 1: - # Make figure with only one plot - fig, ax = plt.subplots(1, 1, **kwargs) + # Choose number of rows and columns to be used for the subplots + n_rows = int(np.ceil(np.sqrt(n_groups))) + n_cols = int(np.ceil(n_groups / n_rows)) - axes = [ax] - # If there are multiple groups, make a figure with multiple plots - else: - # Choose number of rows and columns to be used for the subplots - n_rows = int(np.ceil(np.sqrt(n_groups))) - n_cols = int(np.ceil(n_groups / n_rows)) - - # Make as many subplots as there are groups - fig, axes = plt.subplots(n_rows, n_cols, **kwargs) + # Make as many subplots as there are groups + fig, axes = plt.subplots(n_rows, n_cols, squeeze=False, **kwargs) - # Increase the spacing between the subplots - fig.subplots_adjust(hspace=0.35, wspace=0.25) + # Increase the spacing between the subplots + fig.subplots_adjust(hspace=0.35, wspace=0.25) - # Flatten the axes array - axes = axes.flatten() + # Flatten the axes array + axes = axes.flatten() return axes diff --git a/test/visualize/test_visualize.py b/test/visualize/test_visualize.py index 149c65719..ecdc38eab 100644 --- a/test/visualize/test_visualize.py +++ b/test/visualize/test_visualize.py @@ -3,6 +3,7 @@ import os from collections.abc import Sequence from functools import wraps +from pathlib import Path import matplotlib.pyplot as plt import numpy as np @@ -1240,3 +1241,59 @@ def test_parameters_correlation_matrix(result_creation): result = result_creation() visualize.parameters_correlation_matrix(result) + + +@close_fig +def test_plot_ordinal_categories(): + example_ordinal_yaml = ( + Path(__file__).parent + / ".." + / ".." + / "doc" + / "example" + / "example_ordinal" + / "example_ordinal.yaml" + ) + petab_problem = petab.Problem.from_yaml(example_ordinal_yaml) + # Set seed for reproducibility. + np.random.seed(0) + optimizer = pypesto.optimize.ScipyOptimizer( + method="L-BFGS-B", options={"maxiter": 1} + ) + importer = pypesto.petab.PetabImporter(petab_problem, hierarchical=True) + problem = importer.create_problem() + result = pypesto.optimize.minimize( + problem=problem, n_starts=1, optimizer=optimizer + ) + visualize.plot_categories_from_pypesto_result(result) + + +@close_fig +def test_visualize_estimated_observable_mapping(): + example_semiquantitative_yaml = ( + Path(__file__).parent + / ".." + / ".." + / "doc" + / "example" + / "example_semiquantitative" + / "example_semiquantitative_linear.yaml" + ) + petab_problem = petab.Problem.from_yaml(example_semiquantitative_yaml) + # Set seed for reproducibility. + np.random.seed(0) + optimizer = pypesto.optimize.ScipyOptimizer( + method="L-BFGS-B", + options={ + "disp": None, + "ftol": 2.220446049250313e-09, + "gtol": 1e-5, + "maxiter": 1, + }, + ) + importer = pypesto.petab.PetabImporter(petab_problem, hierarchical=True) + problem = importer.create_problem() + result = pypesto.optimize.minimize( + problem=problem, n_starts=1, optimizer=optimizer + ) + visualize.visualize_estimated_observable_mapping(result, problem)