Skip to content

Commit

Permalink
add interaction log tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mdekstrand committed Jul 11, 2024
1 parent 4657a29 commit 03ae65a
Showing 1 changed file with 75 additions and 0 deletions.
75 changes: 75 additions & 0 deletions lenskit/tests/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pytest import fixture, raises

from lenskit.data import Dataset, from_interactions_df
from lenskit.data.tables import NumpyUserItemTable, TorchUserItemTable


@fixture(scope="module")
Expand Down Expand Up @@ -130,3 +131,77 @@ def test_item_num_many(ml_ds: Dataset):
def test_item_num_missing_error(ml_ds: Dataset):
    """
    Looking up an unknown item with ``missing="error"`` raises ``KeyError``.
    """
    # the test expects this identifier to be absent from the dataset
    unknown_item = -402
    with raises(KeyError):
        ml_ds.item_num(unknown_item, missing="error")


def test_pandas_log_defaults(ml_ratings: pd.DataFrame, ml_ds: Dataset):
    """
    Check the default pandas interaction log against the source ratings.

    The log is requested with ``format="pandas"`` and no other options.  It
    must be a data frame with exactly the ``user_num``, ``item_num``,
    ``rating``, and ``timestamp`` columns, its numbers must translate back to
    the identifiers in ``ml_ratings``, and it must have one row per rating.
    """
    int_df = ml_ds.interaction_log(format="pandas")
    assert isinstance(int_df, pd.DataFrame)
    # we should have exactly the 4 expected columns
    assert len(int_df.columns) == 4
    assert "user_num" in int_df.columns
    assert "item_num" in int_df.columns
    assert "rating" in int_df.columns
    assert "timestamp" in int_df.columns

    # the interactions should match the source ratings once both are sorted
    # NOTE(review): this presumes sorting by user/item number yields the same
    # row order as sorting by user/item identifier — confirm against Dataset.
    int_df = int_df.sort_values(["user_num", "item_num"])
    uids = ml_ds.user_id(int_df["user_num"])
    # item numbers must go through the item lookup, not the user lookup
    iids = ml_ds.item_id(int_df["item_num"])

    ml_df = ml_ratings.sort_values(["userId", "movieId"])
    assert np.all(uids == ml_df["userId"])
    assert np.all(iids == ml_df["movieId"])
    assert np.all(int_df["rating"] == ml_df["rating"])
    assert np.all(int_df["timestamp"] == ml_df["timestamp"])

    # and the total length
    assert len(int_df) == len(ml_ratings)


def test_pandas_log_ids(ml_ratings: pd.DataFrame, ml_ds: Dataset):
    """
    Check the pandas interaction log when original identifiers are requested.

    With ``original_ids=True`` the log must be a data frame with exactly the
    ``user_id``, ``item_id``, ``rating``, and ``timestamp`` columns, whose
    values equal the corresponding ``ml_ratings`` columns after sorting, and
    it must have one row per rating.
    """
    log_df = ml_ds.interaction_log(format="pandas", original_ids=True)
    assert isinstance(log_df, pd.DataFrame)

    # exactly four columns, and each expected name is among them
    columns = log_df.columns
    assert len(columns) == 4
    for name in ("user_id", "item_id", "rating", "timestamp"):
        assert name in columns

    # order both frames by (user, item) so their columns can be compared
    log_df = log_df.sort_values(["user_id", "item_id"])
    expected = ml_ratings.sort_values(["userId", "movieId"])

    column_pairs = (
        ("user_id", "userId"),
        ("item_id", "movieId"),
        ("rating", "rating"),
        ("timestamp", "timestamp"),
    )
    for log_col, src_col in column_pairs:
        assert np.all(log_df[log_col] == expected[src_col])

    # and the total length
    assert len(log_df) == len(ml_ratings)


def test_numpy_log_defaults(ml_ratings: pd.DataFrame, ml_ds: Dataset):
    """
    Check the default NumPy interaction log.

    The result must be a ``NumpyUserItemTable`` with ratings and timestamps
    present, and every field must have one entry per row of ``ml_ratings``.
    """
    table = ml_ds.interaction_log(format="numpy")
    assert isinstance(table, NumpyUserItemTable)
    assert table.ratings is not None
    assert table.timestamps is not None

    # every field should have one entry per source rating
    n_ratings = len(ml_ratings)
    for field in ("user_nums", "item_nums", "ratings", "timestamps"):
        assert len(getattr(table, field)) == n_ratings


def test_torch_log_defaults(ml_ratings: pd.DataFrame, ml_ds: Dataset):
    """
    Check the default PyTorch interaction log.

    The result must be a ``TorchUserItemTable`` with ratings and timestamps
    present, and every field must have one entry per row of ``ml_ratings``.
    """
    table = ml_ds.interaction_log(format="torch")
    assert isinstance(table, TorchUserItemTable)
    assert table.ratings is not None
    assert table.timestamps is not None

    # every field should have one entry per source rating
    n_ratings = len(ml_ratings)
    for field in ("user_nums", "item_nums", "ratings", "timestamps"):
        assert len(getattr(table, field)) == n_ratings

0 comments on commit 03ae65a

Please sign in to comment.