Skip to content

Commit

Permalink
exercise more lookup code paths
Browse files Browse the repository at this point in the history
  • Loading branch information
mdekstrand committed Jun 18, 2024
1 parent fcbea77 commit 653b992
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 1 deletion.
2 changes: 2 additions & 0 deletions lenskit/lenskit/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,8 @@ def _lookup_num(
ids: EntityId | ArrayLike,
missing: Literal["error", "negative", "omit"] = "negative",
) -> int | np.ndarray[int, np.dtype[np.int32]] | pd.Series[int]:
if missing not in ["error", "negative", "omit"]: # pragma nocover
raise ValueError(f"invalid missing mode {missing}")
if np.isscalar(ids):
try:
return index.get_loc(cast(EntityId, ids))
Expand Down
38 changes: 37 additions & 1 deletion lenskit/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import pandas as pd
from pyprojroot import here

from pytest import fixture
from pytest import fixture, raises

from lenskit.data import Dataset, from_interactions_df

Expand Down Expand Up @@ -84,6 +84,37 @@ def test_user_num_many(ml_ds: Dataset):
assert np.all(ml_ds.user_num(list(users[[1, 5, 23]])) == [1, 5, 23])


def test_user_num_missing_error(ml_ds: Dataset):
with raises(KeyError):
ml_ds.user_num(-402, missing="error")


def test_user_num_missing_negative(ml_ds: Dataset):
assert ml_ds.user_num(-402, missing="negative") == -1


def test_user_num_missing_omit(ml_ds: Dataset):
user = ml_ds.user_vocab[5]
series = ml_ds.user_num([user, -402], missing="omit")
assert len(series) == 1
assert series.loc[user] == 5


def test_user_num_missing_vector_negative(ml_ds: Dataset):
u1 = ml_ds.user_vocab[5]
u2 = ml_ds.user_vocab[100]
res = ml_ds.user_num([u1, -402, u2], missing="negative")
assert len(res) == 3
assert np.all(res == [5, -1, 100])


def test_user_num_missing_vector_error(ml_ds: Dataset):
u1 = ml_ds.user_vocab[5]
u2 = ml_ds.user_vocab[100]
with raises(KeyError):
ml_ds.user_num([u1, -402, u2], missing="error")


def test_item_num_single(ml_ds: Dataset):
items = ml_ds.item_vocab
assert ml_ds.item_num(items[0]) == 0
Expand All @@ -94,3 +125,8 @@ def test_item_num_many(ml_ds: Dataset):
items = ml_ds.item_vocab
assert np.all(ml_ds.item_num(items[[1, 5, 23]]) == [1, 5, 23])
assert np.all(ml_ds.item_num(list(items[[1, 5, 23]])) == [1, 5, 23])


def test_item_num_missing_error(ml_ds: Dataset):
with raises(KeyError):
ml_ds.item_num(-402, missing="error")

0 comments on commit 653b992

Please sign in to comment.