diff --git a/lenskit/tests/test_matrix_rows.py b/lenskit/tests/test_matrix_rows.py index 14ec04d10..b6438fef5 100644 --- a/lenskit/tests/test_matrix_rows.py +++ b/lenskit/tests/test_matrix_rows.py @@ -30,15 +30,20 @@ def test_sparse_mean_center(tensor: torch.Tensor): coo = tensor.to_sparse_coo() rows = coo.indices()[0, :].numpy() counts = np.zeros(nr, dtype=np.int32) - sums = np.zeros(nr, dtype=np.float64) + if tensor.dtype == torch.float64: + sums = np.zeros(nr, dtype=np.float64) + else: + sums = np.zeros(nr, dtype=np.float64) + np.add.at(counts, rows, 1) np.add.at(sums, rows, coo.values().numpy()) tgt_means = sums / counts tgt_means = np.nan_to_num(tgt_means, nan=0) + nt, means = normalize_sparse_rows(tensor, "center") assert means.shape == torch.Size([nr]) - assert means.numpy() == approx(tgt_means, nan_ok=True, rel=5.0e-5) + assert means.numpy() == approx(tgt_means, nan_ok=True, rel=1.0e-5) for i in range(nr): tr = tensor[i].values().numpy()