Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add negative sampling to interaction matrix #631

Merged
merged 6 commits into from
Feb 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions lenskit/lenskit/data/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,8 @@
raise ValueError(f"class name “{name}” already in use")

check_name(name)
if len(entities) != 2:
raise NotImplementedError("more than 2 entities not yet supported")

Check warning on line 151 in lenskit/lenskit/data/builder.py

View check run for this annotation

Codecov / codecov/patch

lenskit/lenskit/data/builder.py#L151

Added line #L151 was not covered by tests

self._log.debug("adding relationship class", class_name=name)
e_dict: dict[str, str | None]
Expand Down
110 changes: 105 additions & 5 deletions lenskit/lenskit/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from __future__ import annotations

import functools
import warnings
from abc import abstractmethod
from collections.abc import Callable, Mapping
from os import PathLike
Expand All @@ -25,8 +26,9 @@
from numpy.typing import NDArray
from typing_extensions import Any, Literal, TypeAlias, TypeVar, overload, override

from lenskit.diagnostics import DataError
from lenskit.diagnostics import DataError, DataWarning
from lenskit.logging import get_logger
from lenskit.random import random_generator

from .attributes import AttributeSet, attr_set
from .container import DataContainer
Expand Down Expand Up @@ -788,6 +790,8 @@
col_type: str
_col_stats: pd.DataFrame | None = None

rc_index: pd.Index

def __init__(
self,
ds: Dataset,
Expand All @@ -799,6 +803,9 @@
# order the table to compute the sparse matrix
entities = list(schema.entities.keys())
row, col = entities
row_col_name = num_col_name(row)
col_col_name = num_col_name(col)

self.row_type = row
self.row_vocabulary = ds.entities(row).vocabulary
self.col_type = col
Expand All @@ -817,6 +824,22 @@
self._row_ptrs = np.cumsum(row_sizes, dtype=np.int32)
self._table = table

# make the index
self.rc_index = pd.Index(
self._rc_combined_nums(
self._table.column(row_col_name).to_numpy(),
self._table.column(col_col_name).to_numpy(),
)
)

@property
def n_rows(self):
return len(self.row_vocabulary)

@property
def n_cols(self) -> int:
return len(self.col_vocabulary)

@override
def matrix(
self, *, combine: MAT_AGG | dict[str, MAT_AGG] | None = None
Expand Down Expand Up @@ -878,8 +901,8 @@
Returns:
The sparse matrix.
"""
n_rows = len(self.row_vocabulary)
n_cols = len(self.col_vocabulary)
n_rows = self.n_rows
n_cols = self.n_cols
nnz = self._table.num_rows

colinds = self._table.column(num_col_name(self.col_type)).to_numpy()
Expand Down Expand Up @@ -920,8 +943,8 @@
Returns:
The sparse matrix.
"""
n_rows = len(self.row_vocabulary)
n_cols = len(self.col_vocabulary)
n_rows = self.n_rows
n_cols = self.n_cols
nnz = self._table.num_rows

colinds = self._table.column(num_col_name(self.col_type)).to_numpy()
Expand Down Expand Up @@ -949,6 +972,83 @@
indices=indices, values=values, size=(n_rows, n_cols)
).coalesce()

def sample_negatives(
self,
rows: np.ndarray[int, np.dtype[np.int32]],
*,
weighting: Literal["uniform", "popularity"] = "uniform",
verify: bool = True,
max_attempts: int = 10,
rng: np.random.Generator | None = None,
) -> NDArray[np.int32]:
"""
Sample negative columns (columns with no observation recorded) for an
array of rows. On a normal interaction matrix, this samples negative
items for users.

Args:
rows:
The row numbers. Duplicates are allowed, and negative columns
are sampled independently for each row. Must be a 1D array or
tensor.
weighting:
The weighting for sampled negatives; ``uniform`` samples them
uniformly at random, while ``popularity`` samples them
proportional to their popularity (number of occurrences).
verify:
Whether to verify that the negative items are actually negative.
Unverified sampling is much faster but can return false
negatives.
max_attempts:
When verification is on, the maximum attempts before giving up
and returning a possible false negative.
rng:
A random number generator to use.
"""
rng = random_generator(rng)

_log.debug("samping negatives", nrows=len(rows))
match weighting:
case "uniform":
columns = rng.choice(self.n_cols, size=len(rows), replace=True)
case "popularity":
ccol = self._table.column(num_col_name(self.col_type)).to_numpy()
trows = rng.choice(self._table.num_rows, size=len(rows), replace=True)
columns = ccol[trows]
columns = np.require(columns, "i4")

if verify:
non_neg = self._check_negatives(rows, columns)
_log.debug("checking negatives", nrows=len(rows), npos=np.sum(non_neg).item())
if np.any(non_neg):
if max_attempts > 0:
columns[non_neg] = self.sample_negatives(
rows[non_neg],
verify=True,
rng=rng,
max_attempts=max_attempts - 1,
weighting=weighting,
)
else:
warnings.warn(

Check warning on line 1033 in lenskit/lenskit/data/dataset.py

View check run for this annotation

Codecov / codecov/patch

lenskit/lenskit/data/dataset.py#L1033

Added line #L1033 was not covered by tests
"failed to find verified negatives for {} users".format(np.sum(non_neg)),
DataWarning,
)

return columns

def _check_negatives(
self, rows: NDArray[np.int32], columns: NDArray[np.int32]
) -> NDArray[np.bool]:
nums = self._rc_combined_nums(rows, columns)
locs = self.rc_index.get_indexer_for(nums)
return locs >= 0

def _rc_combined_nums(self, rows: NDArray[np.int32], columns: NDArray[np.int32]):
rnums = rows.astype(np.uint64)
cnums = columns.astype(np.uint64)
return (rnums << 32) + cnums

def row_table(self, id: ID | None = None, *, number: int | None = None) -> pa.Table | None:
"""
Get a single row of this interaction matrix as a table.
Expand Down
6 changes: 5 additions & 1 deletion lenskit/lenskit/random.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,17 +9,21 @@
"""

# pyright: strict
from __future__ import annotations

import os
from abc import abstractmethod
from hashlib import md5
from pathlib import Path
from typing import TYPE_CHECKING
from uuid import UUID

import numpy as np
from numpy.random import Generator, SeedSequence, default_rng
from typing_extensions import Any, Literal, Protocol, Sequence, TypeAlias, override

from lenskit.data import RecQuery
if TYPE_CHECKING: # avoid circular import
from lenskit.data import RecQuery

Check warning on line 26 in lenskit/lenskit/random.py

View check run for this annotation

Codecov / codecov/patch

lenskit/lenskit/random.py#L26

Added line #L26 was not covered by tests

SeedLike: TypeAlias = int | Sequence[int] | np.random.SeedSequence
"""
Expand Down
76 changes: 76 additions & 0 deletions lenskit/tests/data/test_negative_sampling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import numpy as np

from pytest import mark

from lenskit.data import Dataset
from lenskit.logging import get_logger

# module-level structured logger used by the tests below for diagnostics
_log = get_logger(__name__)


@mark.parametrize("weighting", ["uniform", "popularity"])
def test_negative(rng: np.random.Generator, ml_ds: Dataset, weighting):
    """Sampled negatives are in range and never items the user interacted with."""
    log = _log.bind()
    matrix = ml_ds.interactions().matrix()

    users = rng.choice(ml_ds.user_count, 100, replace=True)
    users = np.require(users, "i4")

    negatives = matrix.sample_negatives(users, rng=rng, weighting=weighting)

    log.info("checking basic item results")
    assert np.all(negatives >= 0)
    assert np.all(negatives < ml_ds.item_count)
    log.info("checking negative items")
    for u, i in zip(users, negatives):
        ulog = log.bind(
            user_num=u.item(),
            user_id=int(ml_ds.users.id(u)),  # type: ignore
            item_num=i.item(),
            item_id=int(ml_ds.items.id(i)),  # type: ignore
        )
        row = ml_ds.user_row(user_num=u)
        ulog = ulog.bind(u_nitems=len(row))
        ulog.debug("checking if item is negative")
        # rc_index stores (row << 32) | col packed into a uint64, so testing a
        # tuple for membership is vacuously true; pack the pair the same way.
        packed = (np.uint64(u) << np.uint64(32)) + np.uint64(i)
        assert packed not in matrix.rc_index
        assert i not in row.numbers()


@mark.parametrize("weighting", ["uniform", "popularity"])
def test_negative_unverified(rng: np.random.Generator, ml_ds: Dataset, weighting):
    """Unverified sampling still returns in-range item numbers."""
    matrix = ml_ds.interactions().matrix()

    user_nums = np.require(rng.choice(ml_ds.user_count, 500, replace=True), "i4")

    sampled = matrix.sample_negatives(user_nums, verify=False, rng=rng, weighting=weighting)

    assert np.all(sampled >= 0)
    assert np.all(sampled < ml_ds.item_count)


@mark.benchmark()
def test_negative_unverified_bench(rng: np.random.Generator, ml_ds: Dataset, benchmark):
    """Benchmark negative sampling with verification disabled."""
    matrix = ml_ds.interactions().matrix()

    user_nums = np.require(rng.choice(ml_ds.user_count, 500, replace=True), "i4")

    def draw():
        matrix.sample_negatives(user_nums, verify=False, rng=rng)

    benchmark(draw)


@mark.benchmark()
def test_negative_verified_bench(rng: np.random.Generator, ml_ds: Dataset, benchmark):
    """Benchmark negative sampling with verification enabled (the default)."""
    matrix = ml_ds.interactions().matrix()

    user_nums = np.require(rng.choice(ml_ds.user_count, 500, replace=True), "i4")

    def draw():
        matrix.sample_negatives(user_nums, rng=rng)

    benchmark(draw)