Skip to content

Commit

Permalink
Format code with black
Browse files Browse the repository at this point in the history
  • Loading branch information
jameschapman19 authored and github-actions[bot] committed Sep 19, 2023
1 parent 5d482ef commit 50f693d
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 7 deletions.
30 changes: 25 additions & 5 deletions cca_zoo/deep/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,42 @@ def on_train_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> No
train_corr,
)


class MinibatchTrainCorrelationCallback(Callback):
mcca=MCCA()
def on_train_batch_end(self, trainer: Trainer, pl_module: LightningModule, outputs, batch, batch_idx, dataloader_idx):
mcca = MCCA()

def on_train_batch_end(
self,
trainer: Trainer,
pl_module: LightningModule,
outputs,
batch,
batch_idx,
dataloader_idx,
):
with torch.no_grad():
train_corr = self.mcca.loss(pl_module(batch["views"])).sum()
pl_module.log(
"train/corr",
train_corr,
)


class MinibatchValidationCorrelationCallback(Callback):
mcca=MCCA()
def on_validation_batch_end(self, trainer: Trainer, pl_module: LightningModule, outputs, batch, batch_idx, dataloader_idx):
mcca = MCCA()

def on_validation_batch_end(
self,
trainer: Trainer,
pl_module: LightningModule,
outputs,
batch,
batch_idx,
dataloader_idx,
):
with torch.no_grad():
val_corr = self.mcca.loss(pl_module(batch["views"])).sum()
pl_module.log(
"val/corr",
val_corr,
)
)
4 changes: 2 additions & 2 deletions cca_zoo/deep/objectives.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,8 @@ def correlation(self, views):
views = _demean(views)

SigmaHat12 = torch.cov(torch.hstack((views[0], views[1])).T)[
: latent_dims, latent_dims :
]
:latent_dims, latent_dims:
]
SigmaHat11 = torch.cov(views[0].T) + self.r * torch.eye(
o1, device=views[0].device
)
Expand Down

0 comments on commit 50f693d

Please sign in to comment.