From f733d0485fe92f5d0fd6473a45d96315ac148aca Mon Sep 17 00:00:00 2001
From: Michael Ekstrand
Date: Sat, 11 Jan 2025 20:05:17 -0500
Subject: [PATCH] use iterative trainer for ALS

---
 lenskit/lenskit/als/_common.py   | 83 ++++++++++++++------------------
 lenskit/lenskit/als/_implicit.py | 15 +++---
 lenskit/lenskit/training.py      |  4 +-
 3 files changed, 44 insertions(+), 58 deletions(-)

diff --git a/lenskit/lenskit/als/_common.py b/lenskit/lenskit/als/_common.py
index 954adf181..4123cdc0d 100644
--- a/lenskit/lenskit/als/_common.py
+++ b/lenskit/lenskit/als/_common.py
@@ -7,21 +7,20 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
+from collections.abc import Generator
 from typing import Literal, TypeAlias
 
 import numpy as np
 import structlog
 import torch
 from pydantic import BaseModel
-from typing_extensions import Iterator, NamedTuple, Self, override
+from typing_extensions import NamedTuple, override
 
-from lenskit import util
 from lenskit.data import Dataset, ItemList, QueryInput, RecQuery, Vocabulary
 from lenskit.data.types import UIPair
-from lenskit.logging import item_progress
 from lenskit.parallel.config import ensure_parallel_init
 from lenskit.pipeline import Component
-from lenskit.training import Trainable, TrainingOptions
+from lenskit.training import IterativeTraining, TrainingOptions
 
 EntityClass: TypeAlias = Literal["user", "item"]
 
@@ -126,7 +125,7 @@ def to(self, device):
         return self._replace(ui_rates=self.ui_rates.to(device), iu_rates=self.iu_rates.to(device))
 
 
-class ALSBase(ABC, Component[ItemList], Trainable):
+class ALSBase(IterativeTraining, Component[ItemList], ABC):
     """
     Base class for ALS models.
 
@@ -144,7 +143,9 @@ class ALSBase(ABC, Component[ItemList], Trainable):
     logger: structlog.stdlib.BoundLogger
 
     @override
-    def train(self, data: Dataset, options: TrainingOptions = TrainingOptions()) -> bool:
+    def training_loop(
+        self, data: Dataset, options: TrainingOptions
+    ) -> Generator[dict[str, float], None, None]:
         """
         Run ALS to train a model.
 
@@ -154,49 +155,33 @@ def train(self, data: Dataset, options: TrainingOptions = TrainingOptions()) -> bool:
         Returns:
-            ``True`` if the model was trained.
+            A generator that runs one training epoch per iteration, yielding a metrics dict.
         """
-        if hasattr(self, "item_features_") and not options.retrain:
-            return False
-
         ensure_parallel_init()
-        timer = util.Stopwatch()
 
-        for algo in self.fit_iters(data, options):
-            pass  # we just need to do the iterations
+        rng = options.random_generator()
+
+        train = self.prepare_data(data)
+        self.users_ = train.users
+        self.items_ = train.items
+
+        self.initialize_params(train, rng)
 
-        if self.user_features_ is not None:
-            self.logger.info(
-                "trained model in %s (|P|=%f, |Q|=%f)",
-                timer,
-                torch.norm(self.user_features_, "fro"),
-                torch.norm(self.item_features_, "fro"),
-                features=self.config.features,
-            )
-        else:
-            self.logger.info(
-                "trained model in %s (|Q|=%f)",
-                timer,
-                torch.norm(self.item_features_, "fro"),
-                features=self.config.features,
-            )
-
-        return True
+        return self._training_loop_generator(train)
 
-    def fit_iters(self, data: Dataset, options: TrainingOptions) -> Iterator[Self]:
+    def _training_loop_generator(
+        self, train: TrainingData
+    ) -> Generator[dict[str, float], None, None]:
         """
         Run ALS to train a model, yielding after each iteration.
 
         Args:
-            ratings: the ratings data frame.
+            train: the prepared ALS training data.
         """
         log = self.logger = self.logger.bind(features=self.config.features)
-        rng = options.random_generator()
-
-        train = self.prepare_data(data)
-        self.users_ = train.users
-        self.items_ = train.items
-
-        self.initialize_params(train, rng)
 
         assert self.user_features_ is not None
         assert self.item_features_ is not None
 
@@ -207,27 +192,26 @@ def fit_iters(self, data: Dataset, options: TrainingOptions) -> Iterator[Self]:
             "item", train.iu_rates, self.item_features_, self.user_features_, self.config.item_reg
         )
 
-        log.info("beginning ALS model training")
-
-        with item_progress("Training ALS", self.config.epochs) as epb:
-            for epoch in range(self.config.epochs):
-                log = log.bind(epoch=epoch)
-                epoch = epoch + 1
+        for epoch in range(self.config.epochs):
+            log = log.bind(epoch=epoch)
+            epoch = epoch + 1
 
-                du = self.als_half_epoch(epoch, u_ctx)
-                log.debug("finished user epoch")
+            du = self.als_half_epoch(epoch, u_ctx)
+            log.debug("finished user epoch")
 
-                di = self.als_half_epoch(epoch, i_ctx)
-                log.debug("finished item epoch")
+            di = self.als_half_epoch(epoch, i_ctx)
+            log.debug("finished item epoch")
 
-                log.info("finished epoch (|ΔP|=%.3f, |ΔQ|=%.3f)", du, di)
-                epb.update()
-                yield self
+            log.debug("finished epoch (|ΔP|=%.3f, |ΔQ|=%.3f)", du, di)
+            yield {"deltaP": du, "deltaQ": di}
 
         if not self.config.save_user_features:
             self.user_features_ = None
             self.user_ = None
 
+        log.debug("finalizing model training")
+        self.finalize_training()
+
     @abstractmethod
     def prepare_data(self, data: Dataset) -> TrainingData:  # pragma: no cover
         """
@@ -270,6 +254,9 @@ def als_half_epoch(self, epoch: int, context: TrainContext) -> float:  # pragma:
         """
         ...
 
+    def finalize_training(self):
+        pass
+
     @override
     def __call__(self, query: QueryInput, items: ItemList) -> ItemList:
         query = RecQuery.create(query)
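
Note (illustration, not part of the patch): the net effect of the _common.py
changes is that ALS training is now driven through a generator. A minimal
sketch of consuming that generator directly; the toy DataFrame, its values,
and the BiasedMFScorer settings below are invented for the example:

    import pandas as pd

    from lenskit.als import BiasedMFScorer
    from lenskit.data import from_interactions_df
    from lenskit.training import TrainingOptions

    # toy interaction data; any LensKit Dataset would work here
    ratings = pd.DataFrame({
        "user_id": [1, 1, 2, 2, 3],
        "item_id": [10, 11, 10, 12, 11],
        "rating": [4.0, 3.5, 5.0, 2.0, 4.5],
    })
    data = from_interactions_df(ratings)

    model = BiasedMFScorer(features=8, epochs=5)
    for metrics in model.training_loop(data, TrainingOptions()):
        # one full ALS epoch has run; metrics holds the change norms
        print(metrics["deltaP"], metrics["deltaQ"])

Calling train() instead consumes the same generator to completion, with
epoch counting and progress reporting handled by the IterativeTraining base
class changed in training.py below.
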
""" - log = self.logger = self.logger.bind(features=self.config.features) - rng = options.random_generator() - - train = self.prepare_data(data) - self.users_ = train.users - self.items_ = train.items - - self.initialize_params(train, rng) assert self.user_features_ is not None assert self.item_features_ is not None @@ -207,27 +192,26 @@ def fit_iters(self, data: Dataset, options: TrainingOptions) -> Iterator[Self]: "item", train.iu_rates, self.item_features_, self.user_features_, self.config.item_reg ) - log.info("beginning ALS model training") - - with item_progress("Training ALS", self.config.epochs) as epb: - for epoch in range(self.config.epochs): - log = log.bind(epoch=epoch) - epoch = epoch + 1 + for epoch in range(self.config.epochs): + log = log.bind(epoch=epoch) + epoch = epoch + 1 - du = self.als_half_epoch(epoch, u_ctx) - log.debug("finished user epoch") + du = self.als_half_epoch(epoch, u_ctx) + log.debug("finished user epoch") - di = self.als_half_epoch(epoch, i_ctx) - log.debug("finished item epoch") + di = self.als_half_epoch(epoch, i_ctx) + log.debug("finished item epoch") - log.info("finished epoch (|ΔP|=%.3f, |ΔQ|=%.3f)", du, di) - epb.update() - yield self + log.debug("finished epoch (|ΔP|=%.3f, |ΔQ|=%.3f)", du, di) + yield {"deltaP": du, "deltaQ": di} if not self.config.save_user_features: self.user_features_ = None self.user_ = None + log.debug("finalizing model training") + self.finalize_training() + @abstractmethod def prepare_data(self, data: Dataset) -> TrainingData: # pragma: no cover """ @@ -270,6 +254,9 @@ def als_half_epoch(self, epoch: int, context: TrainContext) -> float: # pragma: """ ... + def finalize_training(self): + pass + @override def __call__(self, query: QueryInput, items: ItemList) -> ItemList: query = RecQuery.create(query) diff --git a/lenskit/lenskit/als/_implicit.py b/lenskit/lenskit/als/_implicit.py index c183618f4..78d02f69f 100644 --- a/lenskit/lenskit/als/_implicit.py +++ b/lenskit/lenskit/als/_implicit.py @@ -17,7 +17,6 @@ from lenskit.logging.progress import item_progress_handle, pbh_update from lenskit.math.solve import solve_cholesky from lenskit.parallel.chunking import WorkChunks -from lenskit.training import TrainingOptions from ._common import ALSBase, ALSConfig, TrainContext, TrainingData @@ -71,14 +70,6 @@ class ImplicitMFScorer(ALSBase): OtOr_: torch.Tensor - @override - def train(self, data: Dataset, options: TrainingOptions = TrainingOptions()): - if super().train(data, options): - # compute OtOr and save it on the model - reg = self.config.user_reg - self.OtOr_ = _implicit_otor(self.item_features_, reg) - return True - @override def prepare_data(self, data: Dataset) -> TrainingData: if self.config.use_ratings: @@ -109,6 +100,12 @@ def als_half_epoch(self, epoch: int, context: TrainContext) -> float: with item_progress_handle(f"epoch {epoch} {context.label}s", total=context.nrows) as pbh: return _train_implicit_cholesky_fanout(context, OtOr, chunks, pbh) + def finalize_training(self): + # compute OtOr and save it on the model + reg = self.config.user_reg + self.OtOr_ = _implicit_otor(self.item_features_, reg) + return True + @override def new_user_embedding( self, user_num: int | None, user_items: ItemList diff --git a/lenskit/lenskit/training.py b/lenskit/lenskit/training.py index b73cd9a61..ab5d2d2c4 100644 --- a/lenskit/lenskit/training.py +++ b/lenskit/lenskit/training.py @@ -126,7 +126,7 @@ def expected_training_epochs(self) -> int | None: if cfg: return getattr(cfg, "epochs", None) - def train(self, data: Dataset, 
diff --git a/lenskit/lenskit/training.py b/lenskit/lenskit/training.py
index b73cd9a61..ab5d2d2c4 100644
--- a/lenskit/lenskit/training.py
+++ b/lenskit/lenskit/training.py
@@ -126,7 +126,7 @@ def expected_training_epochs(self) -> int | None:
         if cfg:
             return getattr(cfg, "epochs", None)
 
-    def train(self, data: Dataset, options: TrainingOptions) -> None:
+    def train(self, data: Dataset, options: TrainingOptions = TrainingOptions()) -> None:
         """
         Implementation of :meth:`Trainable.train` that uses the training loop.
         It also uses the :attr:`trained_epochs` attribute to detect if the model
@@ -153,6 +153,8 @@ def train(self, data: Dataset, options: TrainingOptions) -> None:
                 self.trained_epochs += 1
                 pb.update()
 
+        log.info("model training finished", epochs=self.trained_epochs)
+
     @abstractmethod
     def training_loop(
         self, data: Dataset, options: TrainingOptions
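
Note (illustration, not part of the patch): the training.py side defines the
contract the ALS classes now fulfill: subclass IterativeTraining, implement
training_loop() as a generator of per-epoch metric dicts, and let the
inherited train() drive it. A hypothetical minimal component; ToyTrainer,
ToyConfig, and its "loss" metric are invented for the sketch:

    from collections.abc import Generator
    from dataclasses import dataclass

    from lenskit.data import Dataset
    from lenskit.training import IterativeTraining, TrainingOptions

    @dataclass
    class ToyConfig:
        epochs: int = 5  # read by expected_training_epochs()

    class ToyTrainer(IterativeTraining):
        config: ToyConfig = ToyConfig()

        def training_loop(
            self, data: Dataset, options: TrainingOptions
        ) -> Generator[dict[str, float], None, None]:
            for epoch in range(1, self.config.epochs + 1):
                # a real model would update its parameters here
                yield {"loss": 1.0 / epoch}
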