remove sparse_row_stats function
mdekstrand committed Jul 25, 2024
1 parent 070b35e commit 6502152
Showing 2 changed files with 19 additions and 57 deletions.
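For context, sparse_row_stats computed per-row statistics of a sparse CSR tensor (entry counts, sums, and means); after this commit, _nsr_mean_center computes those quantities inline. A minimal sketch of that computation on a small made-up CSR tensor (the tensor values are illustrative, not from the repository):

import torch

# A small illustrative 3x4 CSR matrix; row 2 is intentionally left empty.
matrix = torch.sparse_csr_tensor(
    crow_indices=torch.tensor([0, 2, 3, 3]),
    col_indices=torch.tensor([0, 2, 1]),
    values=torch.tensor([4.0, 2.0, 6.0]),
    size=(3, 4),
)

# Entries stored per row: consecutive differences of the CSR row pointer.
counts = matrix.crow_indices().diff()                         # tensor([2, 1, 0])
# Per-row sums, densified into a flat vector.
sums = matrix.sum(dim=1, keepdim=True).to_dense().reshape(3)  # tensor([6., 6., 0.])
# Per-row means; the empty row divides 0 by 0, so the NaN is replaced by 0,
# matching the torch.nan_to_num guard in the new inline code.
means = torch.nan_to_num(sums / counts, 0)                    # tensor([3., 6., 0.])

The counts come straight from the CSR row pointer, so only the sums need densification; an empty row produces a 0/0 division, which the nan_to_num guard maps to a mean of 0.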
lenskit/lenskit/data/matrix.py (41 changes: 7 additions & 34 deletions)
@@ -113,37 +113,6 @@ def shape(self) -> tuple[int, int]:
         return (self.n_users, self.n_items)
 
 
-class DimStats(NamedTuple):
-    """
-    The statistics for a matrix along a dimension (e.g. rows or columns).
-    """
-
-    "The size along this dimension."
-    n: int
-    "The other dimension of the matrix."
-    n_other: int
-    "The number of stored entries for each element."
-    counts: t.Tensor
-    "The sum of entries for each element."
-    sums: t.Tensor
-    "The mean of stored entries for each element."
-    means: t.Tensor
-
-
-def sparse_row_stats(matrix: t.Tensor) -> DimStats:
-    if not matrix.is_sparse_csr:
-        raise TypeError("only sparse CSR matrice supported")
-
-    n, n_other = matrix.shape
-    counts = matrix.crow_indices().diff()
-    assert counts.shape == (n,), f"count shape {counts.shape} != {n}"
-    sums = matrix.sum(dim=1, keepdim=True).to_dense().reshape(n)
-    assert sums.shape == (n,), f"sum shape {sums.shape} != {n}"
-    means = sums / counts
-
-    return DimStats(n, n_other, counts, sums, means)
-
-
 
 
 @overload
 def normalize_sparse_rows(
     matrix: t.Tensor, method: Literal["center"], inplace: bool = False
@@ -168,13 +137,17 @@ def normalize_sparse_rows(
 
 
 def _nsr_mean_center(matrix: t.Tensor) -> tuple[t.Tensor, t.Tensor]:
-    stats = sparse_row_stats(matrix)
+    nr, _nc = matrix.shape
+    sums = matrix.sum(dim=1, keepdim=True).to_dense().reshape(nr)
+    counts = torch.diff(matrix.crow_indices())
+    assert sums.shape == counts.shape
+    means = torch.nan_to_num(sums / counts, 0)
     return t.sparse_csr_tensor(
         crow_indices=matrix.crow_indices(),
         col_indices=matrix.col_indices(),
-        values=matrix.values() - t.repeat_interleave(stats.means, stats.counts),
+        values=matrix.values() - t.repeat_interleave(means, counts),
         size=matrix.shape,
-    ), stats.means
+    ), means
 
 
 def _nsr_unit(matrix: t.Tensor) -> tuple[t.Tensor, t.Tensor]:
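A short usage sketch of the public function that now carries this logic; the ratings matrix here is hypothetical, but the import path, the "center" method, and the (matrix, means) return pair are the ones exercised by the test file below:

import torch
from lenskit.data.matrix import normalize_sparse_rows

# Hypothetical 2x3 CSR ratings matrix, for illustration only.
ratings = torch.sparse_csr_tensor(
    crow_indices=torch.tensor([0, 2, 3]),
    col_indices=torch.tensor([0, 2, 1]),
    values=torch.tensor([3.0, 5.0, 4.0]),
    size=(2, 3),
)

centered, means = normalize_sparse_rows(ratings, "center")
print(means)              # per-row means: tensor([4., 4.])
print(centered.values())  # stored values minus their row mean: tensor([-1., 1., 0.])

Rows with no stored entries receive a mean of 0 rather than NaN, following the torch.nan_to_num call in _nsr_mean_center above.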
lenskit/tests/test_matrix_rows.py (35 changes: 12 additions & 23 deletions)
@@ -16,40 +16,29 @@
 from hypothesis import HealthCheck, given, settings
 from pytest import approx
 
-from lenskit.data.matrix import normalize_sparse_rows, sparse_row_stats
+from lenskit.data.matrix import normalize_sparse_rows
 from lenskit.util.test import sparse_tensors
 
 _log = logging.getLogger(__name__)
 
 
-@settings(suppress_health_check=[HealthCheck.too_slow])
-@given(sparse_tensors())
-def test_sparse_stats(tensor):
-    nr, nc = tensor.shape
-    _log.debug("tensor: %d x %d", nr, nc)
-
-    stats = sparse_row_stats(tensor)
-    assert stats.means.shape == (nr,)
-    assert stats.counts.shape == (nr,)
-
-    assert np.sum(stats.counts.numpy()) == tensor.values().shape[0]
-
-    sums = tensor.sum(dim=1, keepdim=True)
-    sums = sums.to_dense().reshape(-1)
-    tots = stats.means * stats.counts
-    mask = stats.counts.numpy() > 0
-    assert tots.numpy()[mask] == approx(sums.numpy()[mask])
-
-
 @settings(deadline=1000, suppress_health_check=[HealthCheck.too_slow])
 @given(sparse_tensors())
-def test_sparse_mean_center(tensor):
+def test_sparse_mean_center(tensor: torch.Tensor):
     nr, nc = tensor.shape
 
-    stats = sparse_row_stats(tensor)
+    coo = tensor.to_sparse_coo()
+    rows = coo.indices()[0, :].numpy()
+    counts = np.zeros(nr, dtype=np.int32)
+    sums = np.zeros(nr, dtype=np.float32)
+    np.add.at(counts, rows, 1)
+    np.add.at(sums, rows, coo.values().numpy())
+    tgt_means = sums / counts
+    tgt_means = np.nan_to_num(tgt_means, nan=0)
     nt, means = normalize_sparse_rows(tensor, "center")
     assert means.shape == torch.Size([nr])
 
-    assert means.numpy() == approx(stats.means.numpy(), nan_ok=True)
+    assert means.numpy() == approx(tgt_means, nan_ok=True)
 
     for i in range(nr):
         tr = tensor[i].values().numpy()
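The rewritten test derives its reference means with np.add.at rather than the removed helper. A small illustration (values are made up) of why np.add.at is used instead of fancy-indexed +=, which collapses duplicate row indices:

import numpy as np

rows = np.array([0, 0, 1])   # two entries fall in row 0
vals = np.array([3.0, 5.0, 4.0])

sums = np.zeros(2)
sums[rows] += vals           # buffered fancy indexing: only one duplicate update lands -> [5., 4.]
print(sums)

sums = np.zeros(2)
np.add.at(sums, rows, vals)  # unbuffered accumulation over every entry -> [8., 4.]
print(sums)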
