diff --git a/lenskit/lenskit/data/dataset.py b/lenskit/lenskit/data/dataset.py index e5d945859..327d59644 100644 --- a/lenskit/lenskit/data/dataset.py +++ b/lenskit/lenskit/data/dataset.py @@ -510,6 +510,8 @@ def _lookup_num( ids: EntityId | ArrayLike, missing: Literal["error", "negative", "omit"] = "negative", ) -> int | np.ndarray[int, np.dtype[np.int32]] | pd.Series[int]: + if missing not in ["error", "negative", "omit"]: # pragma nocover + raise ValueError(f"invalid missing mode {missing}") if np.isscalar(ids): try: return index.get_loc(cast(EntityId, ids)) diff --git a/lenskit/tests/test_dataset.py b/lenskit/tests/test_dataset.py index 0a38af6e1..38a1915da 100644 --- a/lenskit/tests/test_dataset.py +++ b/lenskit/tests/test_dataset.py @@ -6,7 +6,7 @@ import pandas as pd from pyprojroot import here -from pytest import fixture +from pytest import fixture, raises from lenskit.data import Dataset, from_interactions_df @@ -84,6 +84,37 @@ def test_user_num_many(ml_ds: Dataset): assert np.all(ml_ds.user_num(list(users[[1, 5, 23]])) == [1, 5, 23]) +def test_user_num_missing_error(ml_ds: Dataset): + with raises(KeyError): + ml_ds.user_num(-402, missing="error") + + +def test_user_num_missing_negative(ml_ds: Dataset): + assert ml_ds.user_num(-402, missing="negative") == -1 + + +def test_user_num_missing_omit(ml_ds: Dataset): + user = ml_ds.user_vocab[5] + series = ml_ds.user_num([user, -402], missing="omit") + assert len(series) == 1 + assert series.loc[user] == 5 + + +def test_user_num_missing_vector_negative(ml_ds: Dataset): + u1 = ml_ds.user_vocab[5] + u2 = ml_ds.user_vocab[100] + res = ml_ds.user_num([u1, -402, u2], missing="negative") + assert len(res) == 3 + assert np.all(res == [5, -1, 100]) + + +def test_user_num_missing_vector_error(ml_ds: Dataset): + u1 = ml_ds.user_vocab[5] + u2 = ml_ds.user_vocab[100] + with raises(KeyError): + ml_ds.user_num([u1, -402, u2], missing="error") + + def test_item_num_single(ml_ds: Dataset): items = ml_ds.item_vocab assert ml_ds.item_num(items[0]) == 0 @@ -94,3 +125,8 @@ def test_item_num_many(ml_ds: Dataset): items = ml_ds.item_vocab assert np.all(ml_ds.item_num(items[[1, 5, 23]]) == [1, 5, 23]) assert np.all(ml_ds.item_num(list(items[[1, 5, 23]])) == [1, 5, 23]) + + +def test_item_num_missing_error(ml_ds: Dataset): + with raises(KeyError): + ml_ds.item_num(-402, missing="error")