diff --git a/lenskit/lenskit/data/dataset.py b/lenskit/lenskit/data/dataset.py index 629f26f42..e5d945859 100644 --- a/lenskit/lenskit/data/dataset.py +++ b/lenskit/lenskit/data/dataset.py @@ -15,6 +15,8 @@ import torch from numpy.typing import ArrayLike, NDArray +from lenskit.data.matrix import InteractionMatrix + from .tables import NumpyUserItemTable, TorchUserItemTable DF_FORMAT: TypeAlias = Literal["numpy", "pandas", "torch"] @@ -42,7 +44,7 @@ class Dataset: _user_counts: pd.Series[np.dtype[np.int64]] _item_counts: pd.Series[np.dtype[np.int64]] - _interactions: pd.DataFrame + _matrix: InteractionMatrix def __init__(self, users: pd.Series, items: pd.Series, interact_df: pd.DataFrame): """ @@ -54,31 +56,29 @@ def __init__(self, users: pd.Series, items: pd.Series, interact_df: pd.DataFrame """ self._user_counts = users self._item_counts = items - self._interactions = interact_df - self._init_structures() - - def _init_structures(self): - self._interactions = pd.concat( - [ - pd.DataFrame( - { - "user_num": self._user_counts.index.get_indexer( - self._interactions["user_id"].values - ), - "item_num": self._item_counts.index.get_indexer( - self._interactions["item_id"].values - ), - }, - ), - self._interactions, - ], - axis=1, - ) + self._init_structures(interact_df) + + def _init_structures(self, df: pd.DataFrame): + uno = self._user_counts.index.get_indexer(df["user_id"].values) + ino = self._item_counts.index.get_indexer(df["item_id"].values) + assert np.all(uno >= 0) + assert np.all(ino >= 0) + + df = df.assign(user_num=uno, item_num=ino) + _log.debug("sorting interaction table") - self._interactions.sort_values(["user_num", "item_num"], ignore_index=True, inplace=True) - _log.debug("rating data frame:\n%s", self._interactions) - if np.any(np.diff(self._interactions["item_num"]) == 0): + df.sort_values(["user_num", "item_num"], ignore_index=True, inplace=True) + _log.debug("rating data frame:\n%s", df) + if np.any(np.diff(df["item_num"]) == 0): # pragma nocover raise RuntimeError("repeated ratings not yet supported") + self._matrix = InteractionMatrix( + uno, + ino, + df.get("rating", None), + df.get("timestamp", None), + self._user_counts, + self.item_count, + ) @property def item_vocab(self) -> pd.Index: @@ -443,14 +443,14 @@ def normalize_interactions_df( df.columns, ["user_id", "user", "USER", "userId", "UserId"], ) - if user_col is None: + if user_col is None: # pragma nocover raise ValueError("no user column found") if item_col is None: item_col = _find_column( df.columns, ["item_id", "item", "ITEM", "itemId", "ItemId"], ) - if item_col is None: + if item_col is None: # pragma nocover raise ValueError("no item column found") if rating_col is None: rating_col = _find_column( diff --git a/lenskit/lenskit/data/matrix.py b/lenskit/lenskit/data/matrix.py index ad26bd7f9..609574811 100644 --- a/lenskit/lenskit/data/matrix.py +++ b/lenskit/lenskit/data/matrix.py @@ -18,6 +18,7 @@ import pandas as pd import scipy.sparse as sps import torch +from numpy.typing import ArrayLike from typing_extensions import Any, Generic, Literal, NamedTuple, Optional, TypeVar, overload _log = logging.getLogger(__name__) @@ -58,10 +59,14 @@ def row_cs(self, row: int) -> np.ndarray: class InteractionMatrix: """ - Internal helper class used by :class:`lenskit.data.Datset` to store - interactions in matrix format. + Internal helper class used by :class:`lenskit.data.Dataset` to store the + user-item interaction matrix. The data is stored simultaneously in CSR and + COO format. """ + n_users: int + n_items: int + user_nums: np.ndarray[int, np.dtype[np.int32]] "User (row) numbers." user_ptrs: np.ndarray[int, np.dtype[np.int32]] @@ -73,6 +78,30 @@ class InteractionMatrix: timestamps: Optional[np.ndarray[int, np.dtype[np.int64]]] "Timestamps as 64-bit Unix timestamps." + def __init__( + self, + users: ArrayLike, + items: ArrayLike, + ratings: Optional[ArrayLike], + timestamps: Optional[ArrayLike], + user_counts: pd.Series, + n_items: int, + ): + self.user_nums = np.asarray(users, np.int32) + self.item_nums = np.asarray(items, np.int32) + if ratings is not None: + self.ratings = np.asarray(ratings, np.float32) + if timestamps is not None: + self.timestamps = np.asarray(timestamps, np.int64) + + self.n_items = n_items + self.n_users = len(user_counts) + cp1 = np.zeros(self.n_users + 1, np.int32) + cp1[1:] = user_counts + self.user_ptrs = cp1.cumsum() + if self.user_ptrs[-1] != len(self.user_nums): + raise ValueError("mismatched counts and array sizes") + class RatingMatrix(NamedTuple, Generic[M]): """