Skip to content

Commit

Permalink
fix itemlist test for user rows
Browse files Browse the repository at this point in the history
  • Loading branch information
mdekstrand committed Aug 7, 2024
1 parent ee3248f commit e1f249a
Showing 1 changed file with 26 additions and 13 deletions.
39 changes: 26 additions & 13 deletions lenskit/tests/test_itemlist.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from pytest import raises

from lenskit.data import ItemList
from lenskit.data.dataset import Dataset
from lenskit.data.vocab import Vocabulary

ITEMS = ["a", "b", "c", "d", "e"]
Expand Down Expand Up @@ -204,7 +205,7 @@ def test_pandas_df_ordered():
assert np.all(df["rank"] == np.arange(1, 6))


def test_item_list_pickle_compact(ml_ds):
def test_item_list_pickle_compact(ml_ds: Dataset):
nums = [1, 0, 308, 24, 72]
il = ItemList(item_nums=nums, vocabulary=ml_ds.items)
assert len(il) == 5
Expand All @@ -221,8 +222,9 @@ def test_item_list_pickle_compact(ml_ds):
assert np.all(il2.numbers() == il.numbers())


def test_item_list_pickle_fields(ml_ds):
row = ml_ds.user_row(user_num=400)
def test_item_list_pickle_fields(ml_ds: Dataset):
row = ml_ds.user_row(user_num=400).item_list()
assert row is not None
data = pickle.dumps(row)
r2 = pickle.loads(data)

Expand All @@ -235,8 +237,9 @@ def test_item_list_pickle_fields(ml_ds):
assert np.all(r2.field("timestamp") == row.field("timestamp"))


def test_subset_mask(ml_ds):
row = ml_ds.user_row(user_num=400)
def test_subset_mask(ml_ds: Dataset):
row = ml_ds.user_row(user_num=400).item_list()
assert row is not None
ratings = row.field("rating")
assert ratings is not None

Expand All @@ -246,12 +249,17 @@ def test_subset_mask(ml_ds):
assert len(pos) == np.sum(mask)
assert np.all(pos.ids() == row.ids()[mask])
assert np.all(pos.numbers() == row.numbers()[mask])
assert np.all(pos.field("rating") == row.field("rating")[mask])
assert np.all(pos.field("rating") > 3.0)
rf = row.field("rating")
assert rf is not None
prf = pos.field("rating")
assert prf is not None
assert np.all(prf == rf[mask])
assert np.all(prf > 3.0)


def test_subset_idx(ml_ds):
row = ml_ds.user_row(user_num=400)
def test_subset_idx(ml_ds: Dataset):
row = ml_ds.user_row(user_num=400).item_list()
assert row is not None
ratings = row.field("rating")
assert ratings is not None

Expand All @@ -261,11 +269,14 @@ def test_subset_idx(ml_ds):
assert len(pos) == 3
assert np.all(pos.ids() == row.ids()[ks])
assert np.all(pos.numbers() == row.numbers()[ks])
assert np.all(pos.field("rating") == row.field("rating")[ks])
rf = row.field("rating")
assert rf is not None
assert np.all(pos.field("rating") == rf[ks])


def test_subset_slice(ml_ds):
row = ml_ds.user_row(user_num=400)
def test_subset_slice(ml_ds: Dataset):
row = ml_ds.user_row(user_num=400).item_list()
assert row is not None
ratings = row.field("rating")
assert ratings is not None

Expand All @@ -274,7 +285,9 @@ def test_subset_slice(ml_ds):
assert len(pos) == 5
assert np.all(pos.ids() == row.ids()[5:10])
assert np.all(pos.numbers() == row.numbers()[5:10])
assert np.all(pos.field("rating") == row.field("rating")[5:10])
rf = row.field("rating")
assert rf is not None
assert np.all(pos.field("rating") == rf[5:10])


def test_from_df():
Expand Down

0 comments on commit e1f249a

Please sign in to comment.