create iterative training class
mdekstrand committed Jan 12, 2025
1 parent 1954efb commit 66a407a
Showing 3 changed files with 145 additions and 9 deletions.
65 changes: 61 additions & 4 deletions docs/guide/impl-tips.rst
@@ -1,5 +1,5 @@
Algorithm Implementation Tips
=============================
Model Implementation Tips
=========================

Implementing algorithms is fun, but there are a few things that are good to keep in mind.

@@ -9,5 +9,62 @@ In general, development follows the following:
2. Clear
3. Fast

In that order. Further, we always want LensKit to be *usable* in an easy fashion. Code
implementing algorithms, however, may be quite complex in order to achieve good performance.
In that order. Further, we always want LensKit to be *usable* in an easy
fashion. Code implementing commonly-used models, however, may be quite complex
in order to achieve good performance.

.. _iterative-training:

Iterative Training
~~~~~~~~~~~~~~~~~~

The :class:`lenskit.training.IterativeTraining` class provides a standardized
interface and training loop support for training models with iterative methods
that pass through the training data in multiple *epochs*. Models that use this
support extend :class:`~lenskit.training.IterativeTraining` in addition to
:class:`~lenskit.pipeline.Component`, and implement the
:meth:`~lenskit.training.IterativeTraining.training_loop` method instead of
:meth:`~lenskit.training.Trainable.train`. Iteratively-trainable components
should also have an ``epochs`` setting on their configuration class that
specifies the number of training epochs to run.
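
A minimal component skeleton might look like the following sketch. The
``MyScorer`` class and its configuration are hypothetical, and the exact
:class:`~lenskit.pipeline.Component` configuration and scoring mechanics are
elided:

.. code:: python

    from dataclasses import dataclass

    from lenskit.data.dataset import Dataset
    from lenskit.pipeline import Component
    from lenskit.training import IterativeTraining, TrainingOptions

    @dataclass
    class MyScorerConfig:
        # the recommended ``epochs`` setting, used to determine the
        # expected number of training epochs
        epochs: int = 20

    class MyScorer(IterativeTraining, Component):
        config: MyScorerConfig

        def training_loop(self, data: Dataset, options: TrainingOptions):
            # set up and delegate to an inner generator, as described below
            ...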

The :meth:`~lenskit.training.IterativeTraining.training_loop` method does three things:

1. Set up initial data structures, preparation, etc. needed for model training.
2. Train the model, yielding after each training epoch. It can optionally
yield a set of metrics, such as training loss or update magnitudes.
3. Perform any final steps and training data cleanup.

The model should be usable after each epoch, to support things like measuring
performance on validation data.
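
For example, client code driving the loop directly can measure validation
performance after each epoch. A sketch, with a hypothetical
``measure_validation`` helper and the hypothetical ``MyScorer`` from above:

.. code:: python

    model = MyScorer()
    loop = model.training_loop(train_data, TrainingOptions())
    for epoch, metrics in enumerate(loop, 1):
        score = measure_validation(model, valid_data)
        print(f"epoch {epoch}: metrics={metrics}, validation={score:.4f}")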

The training loop itself is represented as a Python iterator, so that a ``for``
loop steps through the training epochs. While the interface definition
specifies the ``Iterator`` type in order to minimize restrictions on component
implementers, we recommend that it actually be a ``Generator``, which allows the
caller to request early termination (through the
:meth:`~collections.abc.Generator.close` method). We also recommend that the
``training_loop()`` method only return the generator after initial data
preparation is complete, so that setup time is not included in the time taken
for the first loop iteration. The easiest way to implement this is to delegate
to an inner loop function written as a Python generator:

.. code:: python

    def training_loop(self, data: Dataset, options: TrainingOptions):
        # do initial data setup/prep for training
        context = ...
        # pass off to inner generator
        return self._training_loop_impl(context)

    def _training_loop_impl(self, context):
        for i in range(self.config.epochs):
            # do the model training for this epoch
            # compute the epoch's metrics
            loss = ...
            try:
                yield {'loss': loss}
            except GeneratorExit:
                # client code has requested early termination
                break

        # any final cleanup steps
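
Calling code can then use :meth:`~collections.abc.Generator.close` to stop
training early, for example when validation performance stops improving. A
sketch, with a hypothetical ``measure_validation`` helper and validation data:

.. code:: python

    loop = model.training_loop(train_data, TrainingOptions())
    best = None
    for metrics in loop:
        score = measure_validation(model, valid_data)
        if best is not None and score <= best:
            # no improvement (higher is better here); stop training early
            loop.close()  # raises GeneratorExit inside the generator
            break
        best = score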
4 changes: 3 additions & 1 deletion lenskit/lenskit/logging/progress/_dispatch.py
@@ -36,7 +36,9 @@ def set_progress_impl(name: str | None, *options: Any):
raise ValueError(f"unknown progress backend {name}")


def item_progress(label: str, total: int, fields: dict[str, str | None] | None = None) -> Progress:
def item_progress(
    label: str, total: int | None = None, fields: dict[str, str | None] | None = None
) -> Progress:
    """
    Create a progress bar for distinct, counted items.
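With this change, ``total`` may be omitted (or ``None``) to show progress over
an indeterminate number of items. A minimal usage sketch, with a hypothetical
``stream_items`` helper:

.. code:: python

    from lenskit.logging import item_progress

    # total defaults to None, so the bar counts items with no known endpoint
    with item_progress("Processing items") as pb:
        for item in stream_items():
            pb.update()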
85 changes: 81 additions & 4 deletions lenskit/lenskit/training.py
@@ -10,17 +10,19 @@
# pyright: strict
from __future__ import annotations

from abc import ABC, abstractmethod
from collections.abc import Iterator
from dataclasses import dataclass
from typing import (
    Protocol,
    runtime_checkable,
)
from typing import Protocol, runtime_checkable

import numpy as np

from lenskit.data.dataset import Dataset
from lenskit.logging import get_logger, item_progress
from lenskit.random import RNGInput, random_generator

_log = get_logger(__name__)


@dataclass(frozen=True)
class TrainingOptions:
@@ -94,3 +96,78 @@ def train(self, data: Dataset, options: TrainingOptions) -> None:
            The training options.
        """
        raise NotImplementedError()


class IterativeTraining(ABC, Trainable):
    """
    Base class for components that support iterative training. This both
    automates the :meth:`Trainable.train` method for iterative training in
    terms of initialization, epoch, and finalization methods, and exposes
    those methods to client code that may wish to directly control the
    iterative training process.

    Stability:
        Full
    """

    trained_epochs: int = 0
    """
    The number of epochs for which this model has been trained.
    """

    @property
    def expected_training_epochs(self) -> int | None:
        """
        Get the number of training epochs expected to run. The default
        implementation looks for an ``epochs`` attribute on the configuration
        object (``self.config``).
        """
        cfg = getattr(self, "config", None)
        if cfg is not None:
            return getattr(cfg, "epochs", None)
        return None

    def train(self, data: Dataset, options: TrainingOptions) -> None:
        """
        Implementation of :meth:`Trainable.train` that uses the training loop.
        It also uses the :attr:`trained_epochs` attribute to detect if the
        model has already been trained for the purposes of honoring
        :attr:`TrainingOptions.retrain`, and updates that attribute as model
        training progresses.
        """
        if self.trained_epochs > 0 and not options.retrain:
            return

        self.trained_epochs = 0
        log = _log.bind(
            model=f"{self.__class__.__module__}.{self.__class__.__qualname__}"
        )
        log.info("training model")
        n = self.expected_training_epochs
        log.debug("creating training loop")
        loop = self.training_loop(data, options)
        log.debug("beginning training iterations")
        with item_progress("Training iterations", total=n) as pb:
            for i, metrics in enumerate(loop, 1):
                metrics = metrics or {}
                log.info("finished epoch", epoch=i, **metrics)
                self.trained_epochs += 1
                pb.update()

    @abstractmethod
    def training_loop(
        self, data: Dataset, options: TrainingOptions
    ) -> Iterator[dict[str, float] | None]:
        """
        Training loop implementation, to be supplied by the derived class.
        This method should return an iterator that, when iterated, will
        perform each training epoch; when training is complete, it should
        finalize the model and signal iteration completion.

        Each epoch can yield metrics, such as training or validation loss,
        which are logged with structured logging and can be used by calling
        code for other analysis.

        See :ref:`iterative-training` for more details on writing iterative
        training loops.
        """
        raise NotImplementedError()
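
For illustration, the automated :meth:`Trainable.train` wrapper then behaves
like this sketch (reusing the hypothetical ``MyScorer`` from the docs above):

.. code:: python

    model = MyScorer()
    model.train(data, TrainingOptions())               # runs the training loop
    model.train(data, TrainingOptions(retrain=False))  # no-op: already trained
    model.train(data, TrainingOptions(retrain=True))   # trains again from scratch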
