Skip to content

Commit

Permalink
Bump version of Python, torch and pytorch-lightning
Browse files Browse the repository at this point in the history
  • Loading branch information
marcovarrone committed Jan 24, 2024
1 parent 3d25c64 commit 603b9a7
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 7 deletions.
4 changes: 2 additions & 2 deletions pycave/clustering/kmeans/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def __init__(

def update(self, data: torch.Tensor, assignments: torch.Tensor) -> None:
indices = assignments.unsqueeze(1).expand(-1, self.num_features)
self.centroids.scatter_add_(0, indices, data)
self.centroids.scatter_add_(0, indices, data.float())

counts = assignments.bincount(minlength=self.num_clusters).float()
self.cluster_counts.add_(counts)
Expand Down Expand Up @@ -149,7 +149,7 @@ def update(self, data: torch.Tensor, shortest_distances: torch.Tensor) -> None:
# Then, we sample from the data `num_choices` times and replace if needed
choices = (squared_distances + eps).multinomial(self.num_choices, replacement=True)
self.choices.masked_scatter_(
use_choice_from_data.unsqueeze(1), data[choices[use_choice_from_data]]
use_choice_from_data.unsqueeze(1), data[choices[use_choice_from_data]].float()
)

# In any case, the cumulative distances are updated
Expand Down
2 changes: 1 addition & 1 deletion pycave/utils/lightning_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def configure_optimizers(self) -> None:
def training_step(self, batch: torch.Tensor, batch_idx: int) -> None:
self.nonparametric_training_step(batch, batch_idx)

def training_epoch_end(self, outputs: List[torch.Tensor]) -> None:
def on_training_epoch_end(self, outputs: List[torch.Tensor]) -> None:
self.nonparametric_training_epoch_end()

@abstractmethod
Expand Down
8 changes: 4 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,10 @@ version = "0.0.0"
[tool.poetry.dependencies]
lightkit = "^0.5.0"
numpy = "^1.20.3"
python = ">=3.8,<3.11"
pytorch-lightning = "^1.6.0"
torch = "^1.8.0"
torchmetrics = ">=0.6,<0.12"
python = ">=3.8,<=3.11"
pytorch-lightning = "<2.2"
torch = "<2.2.0"
torchmetrics = ">=0.6,<1.4.0"

[tool.poetry.group.pre-commit.dependencies]
black = "^22.12.0"
Expand Down

0 comments on commit 603b9a7

Please sign in to comment.