Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main'
Browse files Browse the repository at this point in the history
  • Loading branch information
jameschapman19 committed Sep 19, 2023
2 parents 9cef889 + c58791a commit 5d482ef
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 4 deletions.
1 change: 1 addition & 0 deletions cca_zoo/deep/_discriminative/_dcca.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from cca_zoo.deep._base import BaseDeep
from cca_zoo.linear._mcca import MCCA


class DCCA(BaseDeep):
"""
A class used to fit a DCCA model.
Expand Down
6 changes: 5 additions & 1 deletion cca_zoo/utils/check_values.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def _check_batch_size(batch_size, latent_dimensions):
"Objective is unstable when batch size is less than the number of latent dimensions"
)


def check_tsne_support(caller_name):
try:
import openTSNE
Expand All @@ -54,6 +55,7 @@ def check_tsne_support(caller_name):
"Please install openTSNE using `pip install openTSNE`"
)


def check_umap_support(caller_name):
try:
import umap
Expand All @@ -63,6 +65,7 @@ def check_umap_support(caller_name):
"Please install umap using `pip install umap-learn`"
)


def check_seaborn_support(caller_name):
try:
import seaborn
Expand All @@ -72,11 +75,12 @@ def check_seaborn_support(caller_name):
"Please install seaborn using `pip install seaborn`"
)


def check_arviz_support(caller_name):
try:
import arviz as az
except ImportError:
raise ImportError(
f"{caller_name} requires arviz. "
"Please install arviz using `pip install arviz`"
)
)
2 changes: 1 addition & 1 deletion cca_zoo/visualisation/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,5 +17,5 @@
"CovarianceHeatmapDisplay",
"TSNEScoreDisplay",
"UMAPScoreDisplay",
"WeightInferenceDisplay"
"WeightInferenceDisplay",
]
3 changes: 2 additions & 1 deletion cca_zoo/visualisation/inference.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from cca_zoo.utils.check_values import check_arviz_support


class WeightInferenceDisplay:
"""
Class for displaying inference-related plots.
Expand Down Expand Up @@ -76,6 +77,7 @@ def from_mcmc(cls, mcmc, true_features=None):
An InferenceDisplay instance.
"""
import arviz as az

idata = az.from_numpyro(mcmc)
return cls(idata, 2, true_features)

Expand Down Expand Up @@ -109,4 +111,3 @@ def plot(self):
plt.tight_layout()

plt.show()

1 change: 0 additions & 1 deletion docs/source/examples/plot_probabilistic.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,4 +127,3 @@ def sample(self, n):

WeightInferenceDisplay.from_estimator(pcca).plot()
plt.show()

0 comments on commit 5d482ef

Please sign in to comment.