diff --git a/cca_zoo/deep/callbacks.py b/cca_zoo/deep/callbacks.py index 9310da89..69d9771e 100644 --- a/cca_zoo/deep/callbacks.py +++ b/cca_zoo/deep/callbacks.py @@ -23,9 +23,19 @@ def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> No train_corr, ) + class MinibatchTrainCorrelationCallback(Callback): - mcca=MCCA() - def on_train_batch_end(self, trainer: Trainer, pl_module: LightningModule, outputs, batch, batch_idx, dataloader_idx): + mcca = MCCA() + + def on_train_batch_end( + self, + trainer: Trainer, + pl_module: LightningModule, + outputs, + batch, + batch_idx, + dataloader_idx, + ): with torch.no_grad(): train_corr = self.mcca.loss(pl_module(batch["views"])).sum() pl_module.log( @@ -33,12 +43,22 @@ def on_train_batch_end(self, trainer: Trainer, pl_module: LightningModule, outpu train_corr, ) + class MinibatchValidationCorrelationCallback(Callback): - mcca=MCCA() - def on_validation_batch_end(self, trainer: Trainer, pl_module: LightningModule, outputs, batch, batch_idx, dataloader_idx): + mcca = MCCA() + + def on_validation_batch_end( + self, + trainer: Trainer, + pl_module: LightningModule, + outputs, + batch, + batch_idx, + dataloader_idx, + ): with torch.no_grad(): val_corr = self.mcca.loss(pl_module(batch["views"])).sum() pl_module.log( "val/corr", val_corr, - ) \ No newline at end of file + ) diff --git a/cca_zoo/deep/objectives.py b/cca_zoo/deep/objectives.py index 3b4a6ea1..3ac5fccf 100644 --- a/cca_zoo/deep/objectives.py +++ b/cca_zoo/deep/objectives.py @@ -140,8 +140,8 @@ def correlation(self, views): views = _demean(views) SigmaHat12 = torch.cov(torch.hstack((views[0], views[1])).T)[ - : latent_dims, latent_dims : - ] + :latent_dims, latent_dims: + ] SigmaHat11 = torch.cov(views[0].T) + self.r * torch.eye( o1, device=views[0].device )