Merge pull request #600 from mdekstrand/feature/config-cleanups
Migrate HPF to configuration object and fix BiasConfig serialization
mdekstrand authored Jan 11, 2025
2 parents d67d1e0 + 09e05f7 commit 275b43b
Showing 3 changed files with 17 additions and 17 deletions.
lenskit-hpf/lenskit/hpf.py (22 changes: 10 additions & 12 deletions)

@@ -8,14 +8,20 @@

 import hpfrec
 import numpy as np
-from typing_extensions import Any, override
+from pydantic import BaseModel, JsonValue
+from typing_extensions import override

 from lenskit.data import Dataset, ItemList, QueryInput, RecQuery, Vocabulary
 from lenskit.pipeline import Component, Trainable

 _logger = logging.getLogger(__name__)


+class HPFConfig(BaseModel, extra="allow"):
+    __pydantic_extra__: dict[str, JsonValue]
+    features: int = 50
+
+
 class HPFScorer(Component[ItemList], Trainable):
     """
     Hierarchical Poisson factorization, provided by
@@ -34,21 +40,13 @@ class HPFScorer(Component[ItemList], Trainable):
         additional arguments to pass to :class:`hpfrec.HPF`.
     """

-    features: int
-    _kwargs: dict[str, Any]
+    config: HPFConfig

     users_: Vocabulary
     user_features_: np.ndarray[tuple[int, int], np.dtype[np.float64]]
     items_: Vocabulary
     item_features_: np.ndarray[tuple[int, int], np.dtype[np.float64]]

-    def __init__(self, features: int = 50, **kwargs):
-        self.features = features
-        self._kwargs = kwargs
-
-    def get_config(self):
-        return {"features": self.features} | self._kwargs
-
     @property
     def is_trained(self) -> bool:
         return hasattr(self, "item_features_")
@@ -64,9 +62,9 @@ def train(self, data: Dataset):
             }
         )

-        hpf = hpfrec.HPF(self.features, reindex=False, **self._kwargs)
+        hpf = hpfrec.HPF(self.config.features, reindex=False, **self.config.__pydantic_extra__)  # type: ignore

-        _logger.info("fitting HPF model with %d features", self.features)
+        _logger.info("fitting HPF model with %d features", self.config.features)
         hpf.fit(log)

         self.users_ = data.users
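For context, a minimal sketch (not part of the diff) of how the new HPFConfig captures extra options: with pydantic's extra="allow", unknown keyword arguments are validated as JSON values and stored in __pydantic_extra__, replacing the hand-rolled **kwargs and get_config plumbing this commit deletes. The maxiter keyword below is an hpfrec.HPF option used purely for illustration.

from pydantic import BaseModel, JsonValue


class HPFConfig(BaseModel, extra="allow"):
    # Extra fields are validated as JSON values, so the whole config
    # remains serializable.
    __pydantic_extra__: dict[str, JsonValue]
    features: int = 50


cfg = HPFConfig(features=20, maxiter=150)
print(cfg.features)            # 20
print(cfg.__pydantic_extra__)  # {'maxiter': 150}
print(cfg.model_dump())        # {'features': 20, 'maxiter': 150}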
lenskit-hpf/tests/test_hpf.py (4 changes: 2 additions & 2 deletions)

@@ -28,7 +28,7 @@ class TestHPF(BasicComponentTests, ScorerTests):

 @mark.slow
 def test_hpf_train_large(tmp_path, ml_ratings):
-    algo = hpf.HPFScorer(20)
+    algo = hpf.HPFScorer(features=20)
     ratings = ml_ratings.assign(rating=ml_ratings.rating + 0.5)
     ds = from_interactions_df(ratings)
     algo.train(ds)
@@ -57,7 +57,7 @@ def test_hpf_train_large(tmp_path, ml_ratings):

 @mark.slow
 def test_hpf_train_binary(tmp_path, ml_ratings):
-    algo = hpf.HPFScorer(20)
+    algo = hpf.HPFScorer(features=20)
     ratings = ml_ratings.drop(columns=["timestamp", "rating"])
     ds = from_interactions_df(ratings)
     algo.train(ds)
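The test updates follow from the constructor change: features is now a field on the config object rather than a positional __init__ parameter. A hypothetical sketch of the two construction forms, assuming the Component base class accepts a config instance as well as config fields as keyword arguments:

from lenskit.hpf import HPFConfig, HPFScorer

algo = HPFScorer(features=20)             # keyword form, as used in the tests
algo = HPFScorer(HPFConfig(features=20))  # explicit config object (assumed equivalent)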
lenskit/lenskit/basic/bias.py (8 changes: 5 additions & 3 deletions)

@@ -12,11 +12,11 @@
 import logging
 from collections.abc import Container
 from dataclasses import dataclass
-from typing import Literal
+from typing import Annotated, Literal

 import numpy as np
 import torch
-from pydantic import BaseModel, NonNegativeFloat
+from pydantic import BaseModel, NonNegativeFloat, PlainSerializer
 from typing_extensions import Self, TypeAlias, overload

 from lenskit.data import ID, Dataset, ItemList, QueryInput, RecQuery, Vocabulary
@@ -256,7 +256,9 @@ class BiasConfig(BaseModel, extra="forbid"):
     Configuration for :class:`BiasScorer`.
     """

-    entities: set[Literal["user", "item"]] = {"user", "item"}
+    entities: Annotated[
+        set[Literal["user", "item"]], PlainSerializer(lambda s: sorted(s), return_type=list[str])
+    ] = {"user", "item"}
     """
     The entities to compute biases for, in addition to global bias. Defaults to
     users and items.
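To see what the PlainSerializer fixes: pydantic dumps a plain set field in arbitrary iteration order, so two equal configs could serialize to different JSON. Sorting the set on the way out makes serialization deterministic. A minimal self-contained sketch (the real BiasConfig has further fields, such as damping, elided here):

from typing import Annotated, Literal

from pydantic import BaseModel, PlainSerializer


class BiasConfig(BaseModel, extra="forbid"):
    entities: Annotated[
        set[Literal["user", "item"]],
        PlainSerializer(lambda s: sorted(s), return_type=list[str]),
    ] = {"user", "item"}


cfg = BiasConfig()
print(cfg.model_dump_json())  # {"entities":["item","user"]}, always in sorted order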
