diff --git a/pycave/clustering/kmeans/metrics.py b/pycave/clustering/kmeans/metrics.py index 1146f25..7c2eaa8 100644 --- a/pycave/clustering/kmeans/metrics.py +++ b/pycave/clustering/kmeans/metrics.py @@ -148,6 +148,7 @@ def update(self, data: torch.Tensor, shortest_distances: torch.Tensor) -> None: # Then, we sample from the data `num_choices` times and replace if needed choices = (squared_distances + eps).multinomial(self.num_choices, replacement=True) + self.choices = self.choices.to(data.dtype) self.choices.masked_scatter_( use_choice_from_data.unsqueeze(1), data[choices[use_choice_from_data]] )