From a7c4aab9c5c2f9c3ca239b0bb46c644fb4fd081b Mon Sep 17 00:00:00 2001
From: Michael Ekstrand <mdekstrand@drexel.edu>
Date: Tue, 21 Jan 2025 02:14:01 -0500
Subject: [PATCH] fix HPF

---
Reviewer notes (free text between the "---" separator and the first
"diff --git" line is not part of the commit message):

* lenskit-hpf/lenskit/hpf.py: train() now builds its log from
  data.interactions().pandas() instead of
  data.interaction_matrix(format="pandas", field="rating"), adds a
  "rating" column set to 1.0 when that column is absent, and selects
  only the UserId / ItemId / Count columns after the rename.
* lenskit/lenskit/data/dataset.py: adds the attribute_names property,
  which returns the names in self._table.column_names that are not in
  self._link_cols.

 lenskit-hpf/lenskit/hpf.py      | 7 +++++--
 lenskit/lenskit/data/dataset.py | 4 ++++
 2 files changed, 9 insertions(+), 2 deletions(-)

diff --git a/lenskit-hpf/lenskit/hpf.py b/lenskit-hpf/lenskit/hpf.py
index 1f1218869..9d4819d09 100644
--- a/lenskit-hpf/lenskit/hpf.py
+++ b/lenskit-hpf/lenskit/hpf.py
@@ -53,14 +53,17 @@ def train(self, data: Dataset, options: TrainingOptions = TrainingOptions()):
         if hasattr(self, "item_features_") and not options.retrain:
             return
 
-        log = data.interaction_matrix(format="pandas", field="rating")
+        log = data.interactions().pandas()
+        if "rating" not in log.columns:
+            log["rating"] = 1.0
+
         log = log.rename(
             columns={
                 "user_num": "UserId",
                 "item_num": "ItemId",
                 "rating": "Count",
             }
-        )
+        )[["UserId", "ItemId", "Count"]]
 
         hpf = hpfrec.HPF(self.config.features, reindex=False, **self.config.__pydantic_extra__)  # type: ignore
 
diff --git a/lenskit/lenskit/data/dataset.py b/lenskit/lenskit/data/dataset.py
index 20f914f1f..bb4231432 100644
--- a/lenskit/lenskit/data/dataset.py
+++ b/lenskit/lenskit/data/dataset.py
@@ -670,6 +670,10 @@ def is_interaction(self) -> bool:
         """
         return self.schema.interaction
 
+    @property
+    def attribute_names(self) -> list[str]:
+        return [c for c in self._table.column_names if c not in self._link_cols]
+
     def count(self):
         if "count" in self._table.column_names:
             raise NotImplementedError()