use iterative trainer for ALS
mdekstrand committed Jan 12, 2025
1 parent 66a407a · commit f733d04
Showing 3 changed files with 44 additions and 58 deletions.
83 changes: 35 additions & 48 deletions lenskit/lenskit/als/_common.py
@@ -7,21 +7,20 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
+from collections.abc import Generator
 from typing import Literal, TypeAlias
 
 import numpy as np
 import structlog
 import torch
 from pydantic import BaseModel
-from typing_extensions import Iterator, NamedTuple, Self, override
+from typing_extensions import NamedTuple, override
 
-from lenskit import util
 from lenskit.data import Dataset, ItemList, QueryInput, RecQuery, Vocabulary
 from lenskit.data.types import UIPair
-from lenskit.logging import item_progress
 from lenskit.parallel.config import ensure_parallel_init
 from lenskit.pipeline import Component
-from lenskit.training import Trainable, TrainingOptions
+from lenskit.training import IterativeTraining, TrainingOptions
 
 EntityClass: TypeAlias = Literal["user", "item"]

@@ -126,7 +125,7 @@ def to(self, device):
         return self._replace(ui_rates=self.ui_rates.to(device), iu_rates=self.iu_rates.to(device))
 
 
-class ALSBase(ABC, Component[ItemList], Trainable):
+class ALSBase(IterativeTraining, Component[ItemList], ABC):
     """
     Base class for ALS models.
@@ -144,7 +143,9 @@ class ALSBase(ABC, Component[ItemList], Trainable):
     logger: structlog.stdlib.BoundLogger
 
     @override
-    def train(self, data: Dataset, options: TrainingOptions = TrainingOptions()) -> bool:
+    def training_loop(
+        self, data: Dataset, options: TrainingOptions
+    ) -> Generator[dict[str, float], None, None]:
         """
         Run ALS to train a model.
@@ -154,49 +155,33 @@ def train(self, data: Dataset, options: TrainingOptions = TrainingOptions()) -> bool:
         Returns:
             ``True`` if the model was trained.
         """
-        if hasattr(self, "item_features_") and not options.retrain:
-            return False
-
         ensure_parallel_init()
-        timer = util.Stopwatch()
 
+        rng = options.random_generator()
+
+        train = self.prepare_data(data)
+        self.users_ = train.users
+        self.items_ = train.items
+
+        self.initialize_params(train, rng)
+
+        return self._training_loop_generator(train)
 
-        for algo in self.fit_iters(data, options):
-            pass  # we just need to do the iterations
-
-        if self.user_features_ is not None:
-            self.logger.info(
-                "trained model in %s (|P|=%f, |Q|=%f)",
-                timer,
-                torch.norm(self.user_features_, "fro"),
-                torch.norm(self.item_features_, "fro"),
-                features=self.config.features,
-            )
-        else:
-            self.logger.info(
-                "trained model in %s (|Q|=%f)",
-                timer,
-                torch.norm(self.item_features_, "fro"),
-                features=self.config.features,
-            )
-
-        return True
-
-    def fit_iters(self, data: Dataset, options: TrainingOptions) -> Iterator[Self]:
+    def _training_loop_generator(
+        self, train: TrainingData
+    ) -> Generator[dict[str, float], None, None]:
         """
         Run ALS to train a model, yielding after each iteration.
 
         Args:
             ratings: the ratings data frame.
         """
-
         log = self.logger = self.logger.bind(features=self.config.features)
-        rng = options.random_generator()
-
-        train = self.prepare_data(data)
-        self.users_ = train.users
-        self.items_ = train.items
-
-        self.initialize_params(train, rng)
 
         assert self.user_features_ is not None
         assert self.item_features_ is not None
@@ -207,27 +192,26 @@ def fit_iters(self, data: Dataset, options: TrainingOptions) -> Iterator[Self]:
             "item", train.iu_rates, self.item_features_, self.user_features_, self.config.item_reg
         )
 
-        log.info("beginning ALS model training")
-
-        with item_progress("Training ALS", self.config.epochs) as epb:
-            for epoch in range(self.config.epochs):
-                log = log.bind(epoch=epoch)
-                epoch = epoch + 1
+        for epoch in range(self.config.epochs):
+            log = log.bind(epoch=epoch)
+            epoch = epoch + 1
 
-                du = self.als_half_epoch(epoch, u_ctx)
-                log.debug("finished user epoch")
+            du = self.als_half_epoch(epoch, u_ctx)
+            log.debug("finished user epoch")
 
-                di = self.als_half_epoch(epoch, i_ctx)
-                log.debug("finished item epoch")
+            di = self.als_half_epoch(epoch, i_ctx)
+            log.debug("finished item epoch")
 
-                log.info("finished epoch (|ΔP|=%.3f, |ΔQ|=%.3f)", du, di)
-                epb.update()
-                yield self
+            log.debug("finished epoch (|ΔP|=%.3f, |ΔQ|=%.3f)", du, di)
+            yield {"deltaP": du, "deltaQ": di}
 
         if not self.config.save_user_features:
             self.user_features_ = None
             self.user_ = None
 
+        log.debug("finalizing model training")
+        self.finalize_training()
+
     @abstractmethod
     def prepare_data(self, data: Dataset) -> TrainingData:  # pragma: no cover
         """
@@ -270,6 +254,9 @@ def als_half_epoch(self, epoch: int, context: TrainContext) -> float:  # pragma: no cover
         """
         ...
 
+    def finalize_training(self):
+        pass
+
     @override
     def __call__(self, query: QueryInput, items: ItemList) -> ItemList:
         query = RecQuery.create(query)
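The core of the change is visible above: `train` no longer runs the whole fit and returns a bool; `training_loop` prepares the data once and hands back a generator that yields a metrics dict (`deltaP`, `deltaQ`) after each epoch, with `finalize_training` as a post-loop hook. A minimal, self-contained sketch of that shape (the `ToyALS` class and its fields are illustrative stand-ins, not lenskit's actual API):

```python
from collections.abc import Generator


class ToyALS:
    """Illustrative stand-in for a component using the iterative-training pattern."""

    def __init__(self, epochs: int = 5):
        self.epochs = epochs
        self.p = 1.0  # scalar stand-ins for the user/item factor matrices
        self.q = 1.0

    def training_loop(self) -> Generator[dict[str, float], None, None]:
        for epoch in range(1, self.epochs + 1):
            du = self.p * 0.5  # pretend half-epoch updates returning change norms
            di = self.q * 0.5
            self.p -= du
            self.q -= di
            yield {"deltaP": du, "deltaQ": di}  # per-epoch metrics, as in the diff
        self.finalize_training()  # post-loop hook, mirroring the new structure

    def finalize_training(self):
        pass


# The caller (in lenskit, the IterativeTraining.train driver) consumes the
# generator and decides how to log or display each epoch's metrics.
for metrics in ToyALS(epochs=3).training_loop():
    print(metrics)
```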
15 changes: 6 additions & 9 deletions lenskit/lenskit/als/_implicit.py
@@ -17,7 +17,6 @@
 from lenskit.logging.progress import item_progress_handle, pbh_update
 from lenskit.math.solve import solve_cholesky
 from lenskit.parallel.chunking import WorkChunks
-from lenskit.training import TrainingOptions
 
 from ._common import ALSBase, ALSConfig, TrainContext, TrainingData

@@ -71,14 +70,6 @@ class ImplicitMFScorer(ALSBase):
 
     OtOr_: torch.Tensor
 
-    @override
-    def train(self, data: Dataset, options: TrainingOptions = TrainingOptions()):
-        if super().train(data, options):
-            # compute OtOr and save it on the model
-            reg = self.config.user_reg
-            self.OtOr_ = _implicit_otor(self.item_features_, reg)
-            return True
-
     @override
     def prepare_data(self, data: Dataset) -> TrainingData:
         if self.config.use_ratings:
@@ -109,6 +100,12 @@ def als_half_epoch(self, epoch: int, context: TrainContext) -> float:
         with item_progress_handle(f"epoch {epoch} {context.label}s", total=context.nrows) as pbh:
             return _train_implicit_cholesky_fanout(context, OtOr, chunks, pbh)
 
+    def finalize_training(self):
+        # compute OtOr and save it on the model
+        reg = self.config.user_reg
+        self.OtOr_ = _implicit_otor(self.item_features_, reg)
+        return True
+
     @override
     def new_user_embedding(
         self, user_num: int | None, user_items: ItemList
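`ImplicitMFScorer` previously wrapped `super().train()` just to compute its `OtOr_` matrix after fitting; with the iterative trainer, that work moves into the `finalize_training` hook that the base class calls once the epoch loop finishes. For context, a hedged sketch of what an OtOr-style precomputation typically looks like in implicit-feedback ALS (the standard Hu–Koren–Volinsky term; the exact behavior of lenskit's `_implicit_otor` may differ):

```python
import torch


def implicit_otor(item_features: torch.Tensor, reg: float) -> torch.Tensor:
    """Gram matrix of the item factors plus regularization: Q^T Q + reg * I.

    This term is shared across every user's least-squares system, which is
    why it pays to compute it once when training ends.
    """
    n_features = item_features.shape[1]
    gram = item_features.T @ item_features
    return gram + reg * torch.eye(n_features)


Q = torch.randn(100, 16)  # hypothetical: 100 items, 16 latent features
otor = implicit_otor(Q, reg=0.1)
assert otor.shape == (16, 16)
```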
4 changes: 3 additions & 1 deletion lenskit/lenskit/training.py
@@ -126,7 +126,7 @@ def expected_training_epochs(self) -> int | None:
         if cfg:
             return getattr(cfg, "epochs", None)
 
-    def train(self, data: Dataset, options: TrainingOptions) -> None:
+    def train(self, data: Dataset, options: TrainingOptions = TrainingOptions()) -> None:
         """
         Implementation of :meth:`Trainable.train` that uses the training loop.
         It also uses the :attr:`trained_epochs` attribute to detect if the model
@@ -153,6 +153,8 @@ def train(self, data: Dataset, options: TrainingOptions) -> None:
                 self.trained_epochs += 1
                 pb.update()
 
+        log.info("model training finished", epochs=self.trained_epochs)
+
     @abstractmethod
     def training_loop(
         self, data: Dataset, options: TrainingOptions
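The `training.py` side shows the other half of the contract: `IterativeTraining.train` now accepts a default `TrainingOptions`, drives the generator returned by `training_loop`, counts epochs, and logs when training completes. A simplified sketch of that driver shape (the progress and logging details here are illustrative, not lenskit's exact code):

```python
from abc import ABC, abstractmethod
from collections.abc import Generator


class IterativeTrainerSketch(ABC):
    """Illustrative driver: iterate training_loop, tracking completed epochs."""

    trained_epochs: int = 0

    def train(self) -> None:
        for _metrics in self.training_loop():
            self.trained_epochs += 1
            # a real driver would also update a progress bar here,
            # as the pb.update() call in the diff does
        print("model training finished", {"epochs": self.trained_epochs})

    @abstractmethod
    def training_loop(self) -> Generator[dict[str, float], None, None]: ...


class Demo(IterativeTrainerSketch):
    def training_loop(self) -> Generator[dict[str, float], None, None]:
        for i in range(3):
            yield {"loss": 1.0 / (i + 1)}


Demo().train()  # prints: model training finished {'epochs': 3}
```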
