Skip to content

Commit

Permalink
smoke-tested dataset import
Browse files Browse the repository at this point in the history
  • Loading branch information
mdekstrand committed Jun 17, 2024
1 parent 32ecbdc commit fcbea77
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 28 deletions.
52 changes: 26 additions & 26 deletions lenskit/lenskit/data/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import torch
from numpy.typing import ArrayLike, NDArray

from lenskit.data.matrix import InteractionMatrix

from .tables import NumpyUserItemTable, TorchUserItemTable

DF_FORMAT: TypeAlias = Literal["numpy", "pandas", "torch"]
Expand Down Expand Up @@ -42,7 +44,7 @@ class Dataset:

_user_counts: pd.Series[np.dtype[np.int64]]
_item_counts: pd.Series[np.dtype[np.int64]]
_interactions: pd.DataFrame
_matrix: InteractionMatrix

def __init__(self, users: pd.Series, items: pd.Series, interact_df: pd.DataFrame):
"""
Expand All @@ -54,31 +56,29 @@ def __init__(self, users: pd.Series, items: pd.Series, interact_df: pd.DataFrame
"""
self._user_counts = users
self._item_counts = items
self._interactions = interact_df
self._init_structures()

def _init_structures(self):
self._interactions = pd.concat(
[
pd.DataFrame(
{
"user_num": self._user_counts.index.get_indexer(
self._interactions["user_id"].values
),
"item_num": self._item_counts.index.get_indexer(
self._interactions["item_id"].values
),
},
),
self._interactions,
],
axis=1,
)
self._init_structures(interact_df)

def _init_structures(self, df: pd.DataFrame):
uno = self._user_counts.index.get_indexer(df["user_id"].values)
ino = self._item_counts.index.get_indexer(df["item_id"].values)
assert np.all(uno >= 0)
assert np.all(ino >= 0)

df = df.assign(user_num=uno, item_num=ino)

_log.debug("sorting interaction table")
self._interactions.sort_values(["user_num", "item_num"], ignore_index=True, inplace=True)
_log.debug("rating data frame:\n%s", self._interactions)
if np.any(np.diff(self._interactions["item_num"]) == 0):
df.sort_values(["user_num", "item_num"], ignore_index=True, inplace=True)
_log.debug("rating data frame:\n%s", df)
if np.any(np.diff(df["item_num"]) == 0): # pragma nocover
raise RuntimeError("repeated ratings not yet supported")
self._matrix = InteractionMatrix(
uno,
ino,
df.get("rating", None),
df.get("timestamp", None),
self._user_counts,
self.item_count,
)

@property
def item_vocab(self) -> pd.Index:
Expand Down Expand Up @@ -443,14 +443,14 @@ def normalize_interactions_df(
df.columns,
["user_id", "user", "USER", "userId", "UserId"],
)
if user_col is None:
if user_col is None: # pragma nocover
raise ValueError("no user column found")
if item_col is None:
item_col = _find_column(
df.columns,
["item_id", "item", "ITEM", "itemId", "ItemId"],
)
if item_col is None:
if item_col is None: # pragma nocover
raise ValueError("no item column found")
if rating_col is None:
rating_col = _find_column(
Expand Down
33 changes: 31 additions & 2 deletions lenskit/lenskit/data/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import pandas as pd
import scipy.sparse as sps
import torch
from numpy.typing import ArrayLike
from typing_extensions import Any, Generic, Literal, NamedTuple, Optional, TypeVar, overload

_log = logging.getLogger(__name__)
Expand Down Expand Up @@ -58,10 +59,14 @@ def row_cs(self, row: int) -> np.ndarray:

class InteractionMatrix:
"""
Internal helper class used by :class:`lenskit.data.Datset` to store
interactions in matrix format.
Internal helper class used by :class:`lenskit.data.Dataset` to store the
user-item interaction matrix. The data is stored simultaneously in CSR and
COO format.
"""

n_users: int
n_items: int

user_nums: np.ndarray[int, np.dtype[np.int32]]
"User (row) numbers."
user_ptrs: np.ndarray[int, np.dtype[np.int32]]
Expand All @@ -73,6 +78,30 @@ class InteractionMatrix:
timestamps: Optional[np.ndarray[int, np.dtype[np.int64]]]
"Timestamps as 64-bit Unix timestamps."

def __init__(
self,
users: ArrayLike,
items: ArrayLike,
ratings: Optional[ArrayLike],
timestamps: Optional[ArrayLike],
user_counts: pd.Series,
n_items: int,
):
self.user_nums = np.asarray(users, np.int32)
self.item_nums = np.asarray(items, np.int32)
if ratings is not None:
self.ratings = np.asarray(ratings, np.float32)
if timestamps is not None:
self.timestamps = np.asarray(timestamps, np.int64)

self.n_items = n_items
self.n_users = len(user_counts)
cp1 = np.zeros(self.n_users + 1, np.int32)
cp1[1:] = user_counts
self.user_ptrs = cp1.cumsum()
if self.user_ptrs[-1] != len(self.user_nums):
raise ValueError("mismatched counts and array sizes")

Check warning on line 103 in lenskit/lenskit/data/matrix.py

View check run for this annotation

Codecov / codecov/patch

lenskit/lenskit/data/matrix.py#L103

Added line #L103 was not covered by tests


class RatingMatrix(NamedTuple, Generic[M]):
"""
Expand Down

0 comments on commit fcbea77

Please sign in to comment.