From bdfb75138e2afab5f5d270eb91bcd6ae196ad9e0 Mon Sep 17 00:00:00 2001 From: tenk-9 Date: Tue, 2 Jul 2024 17:24:06 +0900 Subject: [PATCH] Reflects the changes in #59 --- pycave/clustering/kmeans/metrics.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pycave/clustering/kmeans/metrics.py b/pycave/clustering/kmeans/metrics.py index 1146f25..7c2eaa8 100644 --- a/pycave/clustering/kmeans/metrics.py +++ b/pycave/clustering/kmeans/metrics.py @@ -148,6 +148,7 @@ def update(self, data: torch.Tensor, shortest_distances: torch.Tensor) -> None: # Then, we sample from the data `num_choices` times and replace if needed choices = (squared_distances + eps).multinomial(self.num_choices, replacement=True) + self.choices = self.choices.to(data.dtype) self.choices.masked_scatter_( use_choice_from_data.unsqueeze(1), data[choices[use_choice_from_data]] )