Skip to content

Commit

Permalink
update to parameter correlation (#1009)
Browse files Browse the repository at this point in the history
* added a save and show parameter. changed default colors to blue-white-red, added the option to return the list of parameters for other inspections.

* removing suggestions of autosave again.
  • Loading branch information
PaulJonasJost authored Feb 3, 2023
1 parent 3dfd88d commit a1ee29d
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 3 deletions.
2 changes: 1 addition & 1 deletion pypesto/visualize/model_fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,7 +289,7 @@ def _time_trajectory_model_without_states(
amici.plotting.plotObservableTrajectories(
rdata=rdata,
observable_indices=observable_indices,
ax=axes[i_cond],
ax=axes[i_cond] if len(rdatas) > 1 else axes,
model=model,
)
return axes
16 changes: 14 additions & 2 deletions pypesto/visualize/parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from matplotlib.colors import Colormap
from matplotlib.ticker import MaxNLocator

from pypesto.util import delete_nan_inf
Expand Down Expand Up @@ -409,6 +410,8 @@ def parameters_correlation_matrix(
start_indices: Optional[Union[int, Iterable[int]]] = None,
method: Union[str, Callable] = 'pearson',
cluster: bool = True,
cmap: Union[Colormap, str] = 'bwr',
return_table: bool = False,
) -> matplotlib.axes.Axes:
"""
Plot correlation of optimized parameters.
Expand All @@ -427,6 +430,11 @@ def parameters_correlation_matrix(
spearman` or a callable function.
cluster:
Whether to cluster the correlation matrix.
cmap:
Colormap to use for the heatmap. Defaults to 'bwr'.
return_table:
Whether to return the parameter table additionally for further
inspection.
Returns
-------
Expand Down Expand Up @@ -454,8 +462,12 @@ def parameters_correlation_matrix(
corr_matrix = df.corr(method=method)
if cluster:
ax = sns.clustermap(
data=corr_matrix, yticklabels=True, vmin=-1, vmax=1
data=corr_matrix, yticklabels=True, vmin=-1, vmax=1, cmap=cmap
)
else:
ax = sns.heatmap(data=corr_matrix, yticklabels=True, vmin=-1, vmax=1)
ax = sns.heatmap(
data=corr_matrix, yticklabels=True, vmin=-1, vmax=1, cmap=cmap
)
if return_table:
return ax, df
return ax

0 comments on commit a1ee29d

Please sign in to comment.