Merge pull request #612 from mdekstrand/tweak/history-no-dataset
Simplify training history components with new dataset
mdekstrand authored Jan 21, 2025
2 parents 016e580 + fb728a6 commit 3d0ea7c
Showing 3 changed files with 102 additions and 82 deletions.
118 changes: 59 additions & 59 deletions lenskit/lenskit/basic/history.py
@@ -5,22 +5,31 @@
 from __future__ import annotations

 import logging
+from dataclasses import dataclass
 from typing import Literal

 import numpy as np
 import pandas as pd
-from scipy.sparse import csr_array
 from typing_extensions import override

 from lenskit.data import Dataset, ItemList, QueryInput, RecQuery
-from lenskit.data.matrix import CSRStructure
-from lenskit.data.vocab import Vocabulary
+from lenskit.data.dataset import MatrixRelationshipSet
+from lenskit.diagnostics import DataError
 from lenskit.pipeline import Component
 from lenskit.training import Trainable, TrainingOptions

 _logger = logging.getLogger(__name__)


+@dataclass
+class LookupConfig:
+    interaction_class: str | None = None
+    """
+    The name of the interaction class to use. Leave ``None`` to use the
+    dataset's default interaction class.
+    """
+
+
 class UserTrainingHistoryLookup(Component[ItemList], Trainable):
     """
     Look up a user's history from the training data.
@@ -29,16 +38,21 @@ class UserTrainingHistoryLookup(Component[ItemList], Trainable):
         Caller
     """

-    config: None
-    training_data_: Dataset
+    config: LookupConfig
+
+    interactions: MatrixRelationshipSet

     @override
     def train(self, data: Dataset, options: TrainingOptions = TrainingOptions()):
-        # TODO: find a better data structure for this
-        if hasattr(self, "training_data_") and not options.retrain:
+        if hasattr(self, "interactions") and not options.retrain:
             return

-        self.training_data_ = data
+        self.interactions = data.interactions(self.config.interaction_class).matrix()
+        if self.interactions.row_type != "user":  # pragma: nocover
+            raise DataError("interactions must have user rows")
+        if self.interactions.col_type != "item":  # pragma: nocover
+            raise DataError("interactions must have item columns")

     def __call__(self, query: QueryInput) -> RecQuery:
         """
@@ -50,89 +64,75 @@ def __call__(self, query: QueryInput) -> RecQuery:
             return query

         if query.user_items is None:
-            query.user_items = self.training_data_.user_row(query.user_id)
+            query.user_items = self.interactions.row_items(query.user_id)

         return query


+@dataclass
+class KnownRatingConfig(LookupConfig):
+    score: Literal["rating", "indicator"] | None = None
+    """
+    The field name to use to score items, or ``"indicator"`` to score with 0/1
+    based on presence in the training data. The default, ``None``, uses ratings
+    if available, and otherwise scores with 1 for interacted items and leaves
+    un-interacted items unscored.
+    """
+    source: Literal["training", "query"] = "training"
+    """
+    Whether to get the known ratings from the training data or from the query.
+    """
+
+
 class KnownRatingScorer(Component[ItemList], Trainable):
     """
     Score items by returning their values from the training data.

     Stability:
         Caller
-
-    Args:
-        score:
-            Whether to score items with their rating values, or a 0/1 indicator
-            of their presence in the training data. The default (``None``) uses
-            ratings if available, and otherwise scores with 1 for interacted
-            items and leaves non-interacted items unscored.
-        source:
-            Whether to use the training data or the user's history represented
-            in the query as the source of score data.
     """

-    config: None
-    score: Literal["rating", "indicator"] | None
-    source: Literal["training", "query"]
-
-    users_: Vocabulary
-    items_: Vocabulary
-    matrix_: csr_array | CSRStructure
-
-    def __init__(
-        self,
-        score: Literal["rating", "indicator"] | None = None,
-        source: Literal["training", "query"] = "training",
-    ):
-        self.score = score
-        self.source = source
+    config: KnownRatingConfig
+    interactions: MatrixRelationshipSet

     @override
     def train(self, data: Dataset, options: TrainingOptions = TrainingOptions()):
-        if hasattr(self, "matrix_") and not options.retrain:
+        if hasattr(self, "interactions") and not options.retrain:
             return

-        if self.source == "query":
+        if self.config.source == "query":
             return

-        self.users_ = data.users
-        self.items_ = data.items
-        if self.score == "indicator":
-            self.matrix_ = data.interaction_matrix(format="structure")
-        else:
-            self.matrix_ = data.interaction_matrix(format="scipy", field="rating")
+        self.interactions = data.interactions(self.config.interaction_class).matrix()

     def __call__(self, query: QueryInput, items: ItemList) -> ItemList:
         query = RecQuery.create(query)

         # figure out what scores we start with
-        scores = None
-        if self.source == "query" and query.user_items is not None:
-            if self.score != "indicator":
-                scores = query.user_items.field("rating", "pandas", index="ids")
-            if scores is None:
-                scores = pd.Series(1.0, index=query.user_items.ids())
+        ilist = None
+        if self.config.source == "query" and query.user_items is not None:
+            ilist = query.user_items

         elif (
-            self.source == "training" and query.user_id is not None and query.user_id in self.users_
+            self.config.source == "training"
+            and query.user_id is not None
+            and query.user_id in self.interactions.row_vocabulary
         ):
-            urow = self.users_.number(query.user_id)
-            if isinstance(self.matrix_, csr_array):
-                assert self.score != "indicator"
-                # get the user's row as a sparse array
-                uarr = self.matrix_[[urow]]
-                assert isinstance(uarr, csr_array)
-                # create a series
-                scores = pd.Series(uarr.data, index=self.items_.ids(uarr.indices))
-            elif isinstance(self.matrix_, CSRStructure):
-                scores = pd.Series(1.0, index=self.items_.ids(self.matrix_.row_cs(urow)))
+            ilist = self.interactions.row_items(query.user_id)

+        if ilist is None:
+            scores = None
+        elif self.config.score == "indicator":
+            scores = pd.Series(1.0, index=ilist.ids())
+        else:
+            scores = ilist.field("rating", format="pandas", index="ids")
+            if scores is None and self.config.score is None:
+                scores = pd.Series(1.0, index=ilist.ids())

         if scores is None:
             scores = pd.Series(np.nan, index=items.ids())

         scores = scores.reindex(
-            items.ids(), fill_value=0.0 if self.score == "indicator" else np.nan
+            items.ids(), fill_value=0.0 if self.config.score == "indicator" else np.nan
         )
         return ItemList(items, scores=scores.values)  # type: ignore
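
The diff above replaces the stored Dataset with a MatrixRelationshipSet and moves constructor arguments into config dataclasses. A minimal usage sketch (not part of the commit; it assumes a loaded lenskit.data.Dataset bound to ml_ds and the usual Component convention of passing a config instance to the constructor):

# Sketch only: `ml_ds` is an assumed, already-loaded Dataset.
from lenskit.basic.history import (
    KnownRatingConfig,
    KnownRatingScorer,
    UserTrainingHistoryLookup,
)

# Resolve a bare user ID into a RecQuery carrying that user's history.
lookup = UserTrainingHistoryLookup()
lookup.train(ml_ds)
query = lookup(100)  # query.user_items now holds user 100's interactions

# Score candidates with a 0/1 indicator of presence in the training data.
scorer = KnownRatingScorer(KnownRatingConfig(score="indicator"))
scorer.train(ml_ds)
scored = scorer(query, query.user_items)  # seen items score 1.0; unseen would score 0.0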
45 changes: 22 additions & 23 deletions lenskit/lenskit/data/dataset.py
@@ -631,11 +631,10 @@ class RelationshipSet:
     For two-entity relationships without duplicates (including relationships
     formed by coalescing repeated relationships or interactions),
     :class:`MatrixRelationshipSet` extends this with additional capabilities.
-    """

-    dataset: Dataset
-    """
-    The dataset for these relationships.
+    Relationship sets can be pickled or serialized, and will not save the entire
+    dataset with them. They are therefore safe to save as component elements
+    during training processes.
     """

     name: str
@@ -649,6 +648,7 @@ class RelationshipSet:
     The Arrow table of relationship information.
     """

+    _vocabularies: dict[str, Vocabulary]
     _link_cols: list[str]

     def __init__(
@@ -658,10 +658,11 @@ def __init__(
         schema: RelationshipSchema,
         table: pa.Table,
     ):
-        self.dataset = ds
         self.name = name
         self.schema = schema
         self._table = table

+        self._vocabularies = {e: ds.entities(e).vocabulary for e in schema.entities}
         self._link_cols = [num_col_name(e) for e in schema.entities]

     @property
@@ -699,9 +700,7 @@ def arrow(self, *, attributes: str | list[str] | None = None, ids=False) -> pa.T
         id_cols = {}
         for e in self.schema.entity_class_names:
             id_cols[id_col_name(e)] = pa.array(
-                self.dataset.entities(e).vocabulary.ids(
-                    table.column(num_col_name(e)).to_numpy()
-                )
+                self._vocabularies[e].ids(table.column(num_col_name(e)).to_numpy())
             )
         id_tbl = pa.table(id_cols)
         cols = id_tbl.column_names
@@ -760,11 +759,11 @@ class MatrixRelationshipSet(RelationshipSet):
     """

     _row_ptrs: np.ndarray[int, np.dtype[np.int32]]
-    _row_vocab: Vocabulary
+    row_vocabulary: Vocabulary
     row_type: str
     _row_stats: pd.DataFrame | None = None

-    _col_vocab: Vocabulary
+    col_vocabulary: Vocabulary
     col_type: str
     _col_stats: pd.DataFrame | None = None

@@ -780,15 +779,15 @@ def __init__(
         entities = list(schema.entities.keys())
         row, col = entities
         self.row_type = row
-        self._row_vocab = ds.entities(row).vocabulary
+        self.row_vocabulary = ds.entities(row).vocabulary
         self.col_type = col
-        self._col_vocab = ds.entities(col).vocabulary
+        self.col_vocabulary = ds.entities(col).vocabulary

         e_cols = [num_col_name(e) for e in entities]
         table = table.sort_by([(c, "ascending") for c in e_cols])

         # compute the row pointers
-        n_rows = len(self._row_vocab)
+        n_rows = len(self.row_vocabulary)
         row_sizes = np.zeros(n_rows + 1, dtype=np.int32())
         rsz_struct = pc.value_counts(table.column(e_cols[0]))
         rsz_nums = rsz_struct.field("values")
@@ -808,8 +807,8 @@ def csr_structure(self) -> CSRStructure:
         """
         Get the compressed sparse row structure of this relationship matrix.
         """
-        n_rows = len(self._row_vocab)
-        n_cols = len(self._col_vocab)
+        n_rows = len(self.row_vocabulary)
+        n_cols = len(self.col_vocabulary)

         colinds = self._table.column(num_col_name(self.col_type)).to_numpy()
         return CSRStructure(self._row_ptrs, colinds, (n_rows, n_cols))
@@ -858,8 +857,8 @@ def scipy(
         Returns:
             The sparse matrix.
         """
-        n_rows = len(self._row_vocab)
-        n_cols = len(self._col_vocab)
+        n_rows = len(self.row_vocabulary)
+        n_cols = len(self.col_vocabulary)
         nnz = self._table.num_rows

         colinds = self._table.column(num_col_name(self.col_type)).to_numpy()
@@ -900,8 +899,8 @@ def torch(self, attribute: str | None = None, *, layout: LAYOUT = "csr") -> torc
         Returns:
             The sparse matrix.
         """
-        n_rows = len(self._row_vocab)
-        n_cols = len(self._col_vocab)
+        n_rows = len(self.row_vocabulary)
+        n_cols = len(self.col_vocabulary)
         nnz = self._table.num_rows

         colinds = self._table.column(num_col_name(self.col_type)).to_numpy()
@@ -933,7 +932,7 @@ def row_table(self, id: ID | None = None, *, number: int | None = None) -> pa.Ta
             raise ValueError("must provide one of id and number")

         if number is None:
-            number = self._row_vocab.number(id, "none")
+            number = self.row_vocabulary.number(id, "none")
             if number is None:
                 return None

@@ -956,16 +955,16 @@ def row_items(self, id: ID | None = None, *, number: int | None = None) -> ItemL
         if tbl is None:
             return None

-        return ItemList.from_arrow(tbl, vocabulary=self._col_vocab)
+        return ItemList.from_arrow(tbl, vocabulary=self.col_vocabulary)

     def row_stats(self):
         if self._row_stats is None:
-            self._row_stats = self._compute_stats(self.row_type, self.col_type, self._row_vocab)
+            self._row_stats = self._compute_stats(self.row_type, self.col_type, self.row_vocabulary)
         return self._row_stats

     def col_stats(self):
         if self._col_stats is None:
-            self._col_stats = self._compute_stats(self.col_type, self.row_type, self._col_vocab)
+            self._col_stats = self._compute_stats(self.col_type, self.row_type, self.col_vocabulary)
         return self._col_stats

     def _compute_stats(
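
The renames above make the row and column vocabularies part of the public surface, and dropping the dataset back-reference is what lets trained components pickle cleanly. A hedged sketch of the resulting workflow (assuming data is a loaded lenskit.data.Dataset):

# Sketch only: `data` is an assumed, already-loaded Dataset.
import pickle

interactions = data.interactions(None).matrix()  # None selects the default interaction class
assert interactions.row_type == "user" and interactions.col_type == "item"

# Row lookups go through the now-public vocabulary attributes:
if 100 in interactions.row_vocabulary:
    row = interactions.row_items(100)  # ItemList of user 100's items

# With no Dataset reference inside, pickling the set does not drag the
# whole dataset along:
restored = pickle.loads(pickle.dumps(interactions))
assert restored.count() == interactions.count()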
21 changes: 21 additions & 0 deletions lenskit/tests/basic/test_history.py
@@ -71,6 +71,27 @@ def test_lookup_items_only(ml_ds: Dataset):
     assert np.all(query.user_items.ids() == ds_row[:-5].ids())


+def test_lookup_pickle(ml_ds: Dataset):
+    "ensure we can correctly pickle a history component"
+    lookup = UserTrainingHistoryLookup()
+    lookup.train(ml_ds)
+
+    blob = pickle.dumps(lookup)
+    l2 = pickle.loads(blob)
+    assert isinstance(l2, UserTrainingHistoryLookup)
+
+    assert l2.interactions.count() == lookup.interactions.count()
+
+    ds_row = ml_ds.user_row(user_id=100)
+    l_row = lookup(100)
+    l2_row = l2(100)
+
+    assert l_row.user_id == 100
+    assert np.all(l_row.user_items.ids() == ds_row.ids())
+    assert l2_row.user_id == 100
+    assert np.all(l2_row.user_items.ids() == ds_row.ids())
+
+
 def test_known_rating_defaults(ml_ds: Dataset):
     algo = KnownRatingScorer()
     algo.train(ml_ds)
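
For context on the existing test that follows the new one, a hedged sketch of the behavior test_known_rating_defaults exercises (assuming user 100 has ratings in ml_ds and that item order is preserved through scoring):

# Sketch only: with the default KnownRatingConfig, a user's own items
# should come back scored with their training ratings.
import numpy as np

algo = KnownRatingScorer()
algo.train(ml_ds)

row = ml_ds.user_row(user_id=100)
scored = algo(100, row)  # query by user ID, score the user's own items
assert np.allclose(scored.scores(), row.field("rating"))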