From 59d616da00a3b99427bacef9a07786bc3c98bc96 Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Tue, 14 Jan 2025 16:17:45 -0500 Subject: [PATCH 01/37] Bring back default docs --- docs/guide/pipeline.rst | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/docs/guide/pipeline.rst b/docs/guide/pipeline.rst index a9106a21e..b0ec4ccfc 100644 --- a/docs/guide/pipeline.rst +++ b/docs/guide/pipeline.rst @@ -134,16 +134,16 @@ These input connections are specified via keyword arguments to the component's input name(s) and the node or data to which each input should be wired. -.. - You can also use :meth:`Pipeline.add_default` to specify default connections. For example, - you can specify a default for ``user``:: - pipe.add_default('user', user_history) +You can also use :meth:`Pipeline.set_default` to specify default connections. +For example, you can specify a default for inputs named ``user``:: + + pipe.set_default('user', user_history) - With this default in place, if a component has an input named ``user`` and that - input is not explicitly connected to a node, then the ``user_history`` node will - be used to supply its value. Judicious use of defaults can reduce the amount of - code overhead needed to wire common pipelines. +With this default in place, if a component has an input named ``user`` and that +input is not explicitly connected to a node, then the ``user_history`` node will +be used to supply its value. Judicious use of defaults can reduce the amount of +code overhead needed to wire common pipelines. .. note:: From a088799ef2e62fb0fccf162f8fc20eff43df165c Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Tue, 14 Jan 2025 16:41:23 -0500 Subject: [PATCH 02/37] copy the pipeline to be a builder class --- lenskit/lenskit/pipeline/__init__.py | 2 + lenskit/lenskit/pipeline/builder.py | 684 +++++++++++++++++++++++++++ 2 files changed, 686 insertions(+) create mode 100644 lenskit/lenskit/pipeline/builder.py diff --git a/lenskit/lenskit/pipeline/__init__.py b/lenskit/lenskit/pipeline/__init__.py index 6902a95a7..8985de205 100644 --- a/lenskit/lenskit/pipeline/__init__.py +++ b/lenskit/lenskit/pipeline/__init__.py @@ -13,6 +13,7 @@ from lenskit.diagnostics import PipelineError, PipelineWarning from ._impl import CloneMethod, Pipeline +from .builder import PipelineBuilder from .common import RecPipelineBuilder, topn_pipeline from .components import ( Component, @@ -25,6 +26,7 @@ __all__ = [ "Pipeline", + "PipelineBuilder", "CloneMethod", "PipelineError", "PipelineWarning", diff --git a/lenskit/lenskit/pipeline/builder.py b/lenskit/lenskit/pipeline/builder.py new file mode 100644 index 000000000..183bb667d --- /dev/null +++ b/lenskit/lenskit/pipeline/builder.py @@ -0,0 +1,684 @@ +""" +LensKit pipeline builder. +""" + +# pyright: strict +from __future__ import annotations + +import typing +import warnings +from dataclasses import replace +from types import FunctionType, UnionType +from uuid import NAMESPACE_URL, uuid4, uuid5 + +from numpy.random import BitGenerator, Generator, SeedSequence +from typing_extensions import Any, Literal, Self, TypeAlias, TypeVar, cast, overload + +from lenskit.data import Dataset +from lenskit.diagnostics import PipelineError, PipelineWarning +from lenskit.logging import get_logger +from lenskit.training import Trainable, TrainingOptions + +from . import config +from ._impl import Pipeline +from .components import ( # type: ignore # noqa: F401 + Component, + PipelineFunction, + fallback_on_none, + instantiate_component, +) +from .config import PipelineConfig +from .nodes import ND, ComponentNode, InputNode, LiteralNode, Node +from .types import parse_type_string + +_log = get_logger(__name__) + +# common type var for quick use +T = TypeVar("T") +T1 = TypeVar("T1") +T2 = TypeVar("T2") +T3 = TypeVar("T3") +T4 = TypeVar("T4") +T5 = TypeVar("T5") + +CloneMethod: TypeAlias = Literal["config", "pipeline-config"] +NAMESPACE_LITERAL_DATA = uuid5(NAMESPACE_URL, "https://ns.lenskit.org/literal-data/") + + +class PipelineBuilder: + """ + Builder for LensKit recommendation pipelines. :ref:`Pipelines ` + are the core abstraction for using LensKit models and other components to + produce recommendations in a useful way. They allow you to wire together + components in (mostly) abitrary graphs, train them on data, and serialize + the resulting pipelines to disk for use elsewhere. + + The builder configures and builds pipelines that can then be run. If you + have a scoring model and just want to generate recommenations with a default + setup and minimal configuration, see :func:`~lenskit.pipeline.topn_pipeline` + or :class:`~lenskit.pipeline.RecPipelineBuilder`. + + Args: + name: + A name for the pipeline. + version: + A numeric version for the pipeline. + + Stability: + Caller + """ + + name: str | None = None + """ + The pipeline name. + """ + version: str | None = None + """ + The pipeline version string. + """ + + _nodes: dict[str, Node[Any]] + _aliases: dict[str, Node[Any]] + _defaults: dict[str, Node[Any]] + _components: dict[str, PipelineFunction[Any] | Component[Any]] + _hash: str | None = None + _last: Node[Any] | None = None + _anon_nodes: set[str] + "Track generated node names." + + def __init__(self, name: str | None = None, version: str | None = None): + self.name = name + self.version = version + self._nodes = {} + self._aliases = {} + self._defaults = {} + self._components = {} + self._anon_nodes = set() + self._clear_caches() + + def meta(self, *, include_hash: bool = True) -> config.PipelineMeta: + """ + Get the metadata (name, version, hash, etc.) for this pipeline without + returning the whole config. + + Args: + include_hash: + Whether to include a configuration hash in the metadata. + """ + meta = config.PipelineMeta(name=self.name, version=self.version) + if include_hash: + meta.hash = self.config_hash() + return meta + + @property + def nodes(self) -> list[Node[object]]: + """ + Get the nodes in the pipeline graph. + """ + return list(self._nodes.values()) + + @overload + def node(self, node: str, *, missing: Literal["error"] = "error") -> Node[object]: ... + @overload + def node(self, node: str, *, missing: Literal["none"] | None) -> Node[object] | None: ... + @overload + def node(self, node: Node[T]) -> Node[T]: ... + def node( + self, node: str | Node[Any], *, missing: Literal["error", "none"] | None = "error" + ) -> Node[object] | None: + """ + Get the pipeline node with the specified name. If passed a node, it + returns the node or fails if the node is not a member of the pipeline. + + Args: + node: + The name of the pipeline node to look up, or a node to check for + membership. + + Returns: + The pipeline node, if it exists. + + Raises: + KeyError: + The specified node does not exist. + """ + if isinstance(node, Node): + self._check_member_node(node) + return node + elif node in self._aliases: + return self._aliases[node] + elif node in self._nodes: + return self._nodes[node] + elif missing == "none" or missing is None: + return None + else: + raise KeyError(f"node {node}") + + def create_input(self, name: str, *types: type[T] | None) -> Node[T]: + """ + Create an input node for the pipeline. Pipelines expect their inputs to + be provided when they are run. + + Args: + name: + The name of the input. The name must be unique in the pipeline + (among both components and inputs). + types: + The allowable types of the input; input data can be of any + specified type. If ``None`` is among the allowed types, the + input can be omitted. + + Returns: + A pipeline node representing this input. + + Raises: + ValueError: + a node with the specified ``name`` already exists. + """ + self._check_available_name(name) + + rts: set[type[T | None]] = set() + for t in types: + if t is None: + rts.add(type(None)) + elif isinstance(t, UnionType): + rts |= set(typing.get_args(t)) + else: + rts.add(t) + + node = InputNode[Any](name, types=rts) + self._nodes[name] = node + self._clear_caches() + return node + + def literal(self, value: T, *, name: str | None = None) -> LiteralNode[T]: + """ + Create a literal node (a node with a fixed value). + + .. note:: + Literal nodes cannot be serialized witih :meth:`get_config` or + :meth:`save_config`. + """ + if name is None: + name = str(uuid4()) + self._anon_nodes.add(name) + node = LiteralNode(name, value, types=set([type(value)])) + self._nodes[name] = node + self._clear_caches() + return node + + def set_default(self, name: str, node: Node[Any] | object) -> None: + """ + Set the default wiring for a component input. Components that declare + an input parameter with the specified ``name`` but no configured input + will be wired to this node. + + This is intended to be used for things like wiring up `user` parameters + to semi-automatically receive the target user's identity and history. + + Args: + name: + The name of the parameter to set a default for. + node: + The node or literal value to wire to this parameter. + """ + if not isinstance(node, Node): + node = self.literal(node) + self._defaults[name] = node + self._clear_caches() + + def get_default(self, name: str) -> Node[Any] | None: + """ + Get the default wiring for an input name. + """ + return self._defaults.get(name, None) + + def alias(self, alias: str, node: Node[Any] | str) -> None: + """ + Create an alias for a node. After aliasing, the node can be retrieved + from :meth:`node` using either its original name or its alias. + + Args: + alias: + The alias to add to the node. + node: + The node (or node name) to alias. + + Raises: + ValueError: + if the alias is already used as an alias or node name. + """ + node = self.node(node) + self._check_available_name(alias) + self._aliases[alias] = node + self._clear_caches() + + def add_component( + self, name: str, obj: Component[ND] | PipelineFunction[ND], **inputs: Node[Any] | object + ) -> Node[ND]: + """ + Add a component and connect it into the graph. + + Args: + name: + The name of the component in the pipeline. The name must be + unique in the pipeline (among both components and inputs). + obj: + The component itself. + inputs: + The component's input wiring. See :ref:`pipeline-connections` + for details. + + Returns: + The node representing this component in the pipeline. + """ + self._check_available_name(name) + + node = ComponentNode(name, obj) + self._nodes[name] = node + self._components[name] = obj + + self.connect(node, **inputs) + + self._clear_caches() + self._last = node + return node + + def replace_component( + self, + name: str | Node[ND], + obj: Component[ND] | PipelineFunction[ND], + **inputs: Node[Any] | object, + ) -> Node[ND]: + """ + Replace a component in the graph. The new component must have a type + that is compatible with the old component. The old component's input + connections will be replaced (as the new component may have different + inputs), but any connections that use the old component to supply an + input will use the new component instead. + """ + if isinstance(name, Node): + name = name.name + + node = ComponentNode(name, obj) + self._nodes[name] = node + self._components[name] = obj + + self.connect(node, **inputs) + + self._clear_caches() + return node + + def connect(self, obj: str | Node[Any], **inputs: Node[Any] | str | object): + """ + Provide additional input connections for a component that has already + been added. See :ref:`pipeline-connections` for details. + + Args: + obj: + The name or node of the component to wire. + inputs: + The component's input wiring. For each keyword argument in the + component's function signature, that argument can be provided + here with an input that the pipeline will provide to that + argument of the component when the pipeline is run. + """ + if isinstance(obj, Node): + node = obj + else: + node = self.node(obj) + if not isinstance(node, ComponentNode): + raise TypeError(f"only component nodes can be wired, not {node}") + + for k, n in inputs.items(): + if isinstance(n, Node): + n = cast(Node[Any], n) + self._check_member_node(n) + node.connections[k] = n.name + else: + lit = self.literal(n) + node.connections[k] = lit.name + + self._clear_caches() + + def component_configs(self) -> dict[str, dict[str, Any]]: + """ + Get the configurations for the components. This is the configurations + only, it does not include pipeline inputs or wiring. + """ + return { + name: comp.dump_config() + for (name, comp) in self._components.items() + if isinstance(comp, Component) + } + + def clone(self, how: CloneMethod = "config") -> Pipeline: + """ + Clone the pipeline, optionally including trained parameters. + + The ``how`` parameter controls how the pipeline is cloned, and what is + available in the clone pipeline. It can be one of the following values: + + ``"config"`` + Create fresh component instances using the configurations of the + components in this pipeline. When applied to a trained pipeline, + the clone does **not** have the original's learned parameters. This + is the default clone method. + ``"pipeline-config"`` + Round-trip the entire pipeline through :meth:`get_config` and + :meth:`from_config`. + + Args: + how: + The mechanism to use for cloning the pipeline. + + Returns: + A new pipeline with the same components and wiring, but fresh + instances created by round-tripping the configuration. + """ + if how == "pipeline-config": + cfg = self.get_config() + return self.from_config(cfg) + elif how != "config": # pragma: nocover + raise NotImplementedError("only 'config' cloning is currently supported") + + clone = PipelineBuilder() + + for node in self.nodes: + match node: + case InputNode(name, types=types): + if types is None: + types = set[type]() + clone.create_input(name, *types) + case LiteralNode(name, value): + clone._nodes[name] = LiteralNode(name, value) + case ComponentNode(name, comp, _inputs, wiring): + if isinstance(comp, FunctionType): + comp = comp + elif isinstance(comp, Component): + comp = comp.__class__(comp.config) # type: ignore + else: + comp = comp.__class__() # type: ignore + cn = clone.add_component(node.name, comp) # type: ignore + for wn, wt in wiring.items(): + clone.connect(cn, **{wn: clone.node(wt)}) + case _: # pragma: nocover + raise RuntimeError(f"invalid node {node}") + + for n, t in self._aliases.items(): + clone.alias(n, t.name) + + for n, t in self._defaults.items(): + clone.set_default(n, clone.node(t.name)) + + return clone + + def get_config(self, *, include_hash: bool = True) -> PipelineConfig: + """ + Get this pipeline's configuration for serialization. The configuration + consists of all inputs and components along with their configurations + and input connections. It can be serialized to disk (in JSON, YAML, or + a similar format) to save a pipeline. + + The configuration does **not** include any trained parameter values, + although the configuration may include things such as paths to + checkpoints to load such parameters, depending on the design of the + components in the pipeline. + + .. note:: + Literal nodes (from :meth:`literal`, or literal values wired to + inputs) cannot be serialized, and this method will fail if they + are present in the pipeline. + """ + meta = self.meta(include_hash=False) + cfg = PipelineConfig(meta=meta) + + # We map anonymous nodes to hash-based names for stability. If we ever + # allow anonymous components, this will need to be adjusted to maintain + # component ordering, but it works for now since only literals can be + # anonymous. First handle the anonymous nodes, so we have that mapping: + remapped: dict[str, str] = {} + for an in self._anon_nodes: + node = self._nodes.get(an, None) + match node: + case None: + # skip nodes that no longer exist + continue + case LiteralNode(name, value): + lit = config.PipelineLiteral.represent(value) + sname = str(uuid5(NAMESPACE_LITERAL_DATA, lit.model_dump_json())) + _log.debug("renamed anonymous node %s to %s", name, sname) + remapped[name] = sname + cfg.literals[sname] = lit + case _: + # the pipeline only generates anonymous literal nodes right now + raise RuntimeError(f"unexpected anonymous node {node}") + + # Now we go over all named nodes and add them to the config: + for node in self.nodes: + if node.name in remapped: + continue + + match node: + case InputNode(): + cfg.inputs.append(config.PipelineInput.from_node(node)) + case LiteralNode(name, value): + cfg.literals[name] = config.PipelineLiteral.represent(value) + case ComponentNode(name): + cfg.components[name] = config.PipelineComponent.from_node(node, remapped) + case _: # pragma: nocover + raise RuntimeError(f"invalid node {node}") + + cfg.aliases = {a: t.name for (a, t) in self._aliases.items()} + cfg.defaults = {n: t.name for (n, t) in self._defaults.items()} + + if include_hash: + cfg.meta.hash = config.hash_config(cfg) + + return cfg + + def config_hash(self) -> str: + """ + Get a hash of the pipeline's configuration to uniquely identify it for + logging, version control, or other purposes. + + The hash format and algorithm are not guaranteed, but is stable within a + LensKit version. For the same version of LensKit and component code, + the same configuration will produce the same hash, so long as there are + no literal nodes. Literal nodes will *usually* hash consistently, but + since literals other than basic JSON values are hashed by pickling, hash + stability depends on the stability of the pickle bytestream. + + In LensKit 2025.1, the configuration hash is computed by computing the + JSON serialization of the pipeline configuration *without* a hash and + returning the hex-encoded SHA256 hash of that configuration. + """ + if self._hash is None: + # get the config *without* a hash + cfg = self.get_config(include_hash=False) + self._hash = config.hash_config(cfg) + return self._hash + + @classmethod + def from_config(cls, config: object) -> Self: + """ + Reconstruct a pipeline from a serialized configuration. + + Args: + config: + The configuration object, as loaded from JSON, TOML, YAML, or + similar. Will be validated into a :class:`PipelineConfig`. + Returns: + The configured (but not trained) pipeline. + Raises: + PipelineError: + If there is a configuration error reconstructing the pipeline. + Warns: + PipelineWarning: + If the configuration is funny but usable; for example, the + configuration includes a hash but the constructed pipeline does + not have a matching hash. + """ + cfg = PipelineConfig.model_validate(config) + pipe = cls() + for inpt in cfg.inputs: + types: list[type[Any] | None] = [] + if inpt.types is not None: + types += [parse_type_string(t) for t in inpt.types] + pipe.create_input(inpt.name, *types) + + # we now add the components and other nodes in multiple passes to ensure + # that nodes are available before they are wired (since `connect` can + # introduce out-of-order dependencies). + + # pass 1: add literals + for name, data in cfg.literals.items(): + pipe.literal(data.decode(), name=name) + + # pass 2: add components + to_wire: list[config.PipelineComponent] = [] + for name, comp in cfg.components.items(): + if comp.code.startswith("@"): + # ignore special nodes in first pass + continue + + obj = instantiate_component(comp.code, comp.config) + pipe.add_component(name, obj) + to_wire.append(comp) + + # pass 3: wiring + for name, comp in cfg.components.items(): + if isinstance(comp.inputs, dict): + inputs = {n: pipe.node(t) for (n, t) in comp.inputs.items()} + pipe.connect(name, **inputs) + elif not comp.code.startswith("@"): + raise PipelineError(f"component {name} inputs must be dict, not list") + + # pass 4: aliases + for n, t in cfg.aliases.items(): + pipe.alias(n, t) + + # pass 5: defaults + for n, t in cfg.defaults.items(): + pipe.set_default(n, pipe.node(t)) + + if cfg.meta.hash is not None: + h2 = pipe.config_hash() + if h2 != cfg.meta.hash: + _log.warning("loaded pipeline does not match hash") + warnings.warn("loaded pipeline config does not match hash", PipelineWarning) + + return pipe + + def train(self, data: Dataset, options: TrainingOptions | None = None) -> None: + """ + Trains the pipeline's trainable components (those implementing the + :class:`TrainableComponent` interface) on some training data. + + .. admonition:: Random Number Generation + :class: note + + If :attr:`TrainingOptions.rng` is set and is not a generator or bit + generator (i.e. it is a seed), then this method wraps the seed in a + :class:`~numpy.random.SeedSequence` and calls + :class:`~numpy.random.SeedSequence.spawn()` to generate a distinct + seed for each component in the pipeline. + + Args: + data: + The dataset to train on. + options: + The training options. If ``None``, default options are used. + """ + log = _log.bind(pipeline=self.name) + if options is None: + options = TrainingOptions() + + if isinstance(options.rng, SeedSequence): + seed = options.rng + elif options.rng is None or isinstance(options.rng, (Generator, BitGenerator)): + seed = None + else: + seed = SeedSequence(options.rng) + + log.info("training pipeline components") + for name, comp in self._components.items(): + clog = log.bind(name=name, component=comp) + if isinstance(comp, Trainable): + # spawn new seed if needed + c_opts = options if seed is None else replace(options, rng=seed.spawn(1)[0]) + clog.info("training component") + comp.train(data, c_opts) + else: + clog.debug("training not required") + + def use_first_of(self, name: str, primary: Node[T | None], fallback: Node[T]) -> Node[T]: + """ + Ergonomic method to create a new node that returns the result of its + ``input`` if it is provided and not ``None``, and otherwise returns the + result of ``fallback``. This method is used for things like filling in + optional pipeline inputs. For example, if you want the pipeline to take + candidate items through an ``items`` input, but look them up from the + user's history and the training data if ``items`` is not supplied, you + would do: + + .. code:: python + + pipe = Pipeline() + # allow candidate items to be optionally specified + items = pipe.create_input('items', list[EntityId], None) + # find candidates from the training data (optional) + lookup_candidates = pipe.add_component( + 'select-candidates', UnratedTrainingItemsCandidateSelector(), + user=history, + ) + # if the client provided items as a pipeline input, use those; otherwise + # use the candidate selector we just configured. + candidates = pipe.use_first_of('candidates', items, lookup_candidates) + + .. note:: + + This method does not distinguish between an input being unspecified + and explicitly specified as ``None``. + + .. note:: + + This method does *not* implement item-level fallbacks, only + fallbacks at the level of entire results. For item-level score + fallbacks, see :class:`~lenskit.basic.FallbackScorer`. + + .. note:: + If one of the fallback elements is a component ``A`` that depends on + another component or input ``B``, and ``B`` is missing or returns + ``None`` such that ``A`` would usually fail, then ``A`` will be + skipped and the fallback will move on to the next node. This works + with arbitrarily-deep transitive chains. + + Args: + name: + The name of the node. + primary: + The node to use as the primary input, if it is available. + fallback: + The node to use if the primary input does not provide a value. + """ + return self.add_component(name, fallback_on_none, primary=primary, fallback=fallback) + + def build(self) -> Pipeline: + """ + Build the pipeline. + """ + return self # type: ignore + + def _check_available_name(self, name: str) -> None: + if name in self._nodes or name in self._aliases: + raise ValueError(f"pipeline already has node {name}") + + def _check_member_node(self, node: Node[Any]) -> None: + nw = self._nodes.get(node.name) + if nw is not node: + raise PipelineError(f"node {node} not in pipeline") + + def _clear_caches(self): + if "_hash" in self.__dict__: + del self._hash From 34a1549ea8aa8d533b2c581e273c69910d6a7c55 Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Tue, 14 Jan 2025 16:41:32 -0500 Subject: [PATCH 03/37] change tests to use builder --- .../tests/pipeline/test_component_config.py | 13 +-- lenskit/tests/pipeline/test_fallback.py | 20 +++-- lenskit/tests/pipeline/test_lazy.py | 20 +++-- lenskit/tests/pipeline/test_pipeline.py | 84 ++++++++++++------- lenskit/tests/pipeline/test_pipeline_clone.py | 22 +++-- lenskit/tests/pipeline/test_pipeline_state.py | 1 + lenskit/tests/pipeline/test_save_load.py | 66 +++++++++------ 7 files changed, 140 insertions(+), 86 deletions(-) diff --git a/lenskit/tests/pipeline/test_component_config.py b/lenskit/tests/pipeline/test_component_config.py index 3a13dae21..a105414ed 100644 --- a/lenskit/tests/pipeline/test_component_config.py +++ b/lenskit/tests/pipeline/test_component_config.py @@ -14,7 +14,7 @@ from pytest import mark -from lenskit.pipeline import Pipeline +from lenskit.pipeline import PipelineBuilder from lenskit.pipeline.components import Component @@ -88,10 +88,11 @@ def test_auto_config_roundtrip(prefixer: type[Component]): def test_pipeline_config(prefixer: type[Component]): comp = prefixer(prefix="scroll named ") - pipe = Pipeline() + pipe = PipelineBuilder() msg = pipe.create_input("msg", str) pipe.add_component("prefix", comp, msg=msg) + pipe = pipe.build() assert pipe.run(msg="FOOBIE BLETCH") == "scroll named FOOBIE BLETCH" config = pipe.component_configs() @@ -105,15 +106,15 @@ def test_pipeline_config(prefixer: type[Component]): def test_pipeline_config_roundtrip(prefixer: type[Component]): comp = prefixer(prefix="scroll named ") - pipe = Pipeline() + pipe = PipelineBuilder() msg = pipe.create_input("msg", str) pipe.add_component("prefix", comp, msg=msg) - assert pipe.run(msg="FOOBIE BLETCH") == "scroll named FOOBIE BLETCH" + assert pipe.build().run(msg="FOOBIE BLETCH") == "scroll named FOOBIE BLETCH" config = pipe.get_config() print(config.model_dump_json(indent=2)) - p2 = Pipeline.from_config(config) + p2 = PipelineBuilder.from_config(config) assert p2.node("prefix", missing="none") is not None - assert p2.run(msg="READ ME") == "scroll named READ ME" + assert p2.build().run(msg="READ ME") == "scroll named READ ME" diff --git a/lenskit/tests/pipeline/test_fallback.py b/lenskit/tests/pipeline/test_fallback.py index de5822ddb..9aca88262 100644 --- a/lenskit/tests/pipeline/test_fallback.py +++ b/lenskit/tests/pipeline/test_fallback.py @@ -6,11 +6,11 @@ from pytest import fail, raises -from lenskit.pipeline import Pipeline +from lenskit.pipeline import PipelineBuilder def test_fallback_input(): - pipe = Pipeline() + pipe = PipelineBuilder() a = pipe.create_input("a", int) b = pipe.create_input("b", int) @@ -28,12 +28,13 @@ def add(x: int, y: int) -> int: fb = pipe.use_first_of("fill-operand", b, nn) na = pipe.add_component("add", add, x=nd, y=fb) + pipe = pipe.build() # 3 * 2 + -3 = 3 assert pipe.run(na, a=3) == 3 def test_fallback_only_run_if_needed(): - pipe = Pipeline() + pipe = PipelineBuilder() a = pipe.create_input("a", int) b = pipe.create_input("b", int) @@ -51,11 +52,12 @@ def add(x: int, y: int) -> int: fb = pipe.use_first_of("fill-operand", b, nn) na = pipe.add_component("add", add, x=nd, y=fb) + pipe = pipe.build() assert pipe.run(na, a=3, b=8) == 14 def test_fallback_fail_with_missing_options(): - pipe = Pipeline() + pipe = PipelineBuilder() a = pipe.create_input("a", int) b = pipe.create_input("b", int) @@ -73,13 +75,14 @@ def add(x: int, y: int) -> int: fb = pipe.use_first_of("fill-operand", b, nn) na = pipe.add_component("add", add, x=nd, y=fb) + pipe = pipe.build() with raises(TypeError, match="no data available"): pipe.run(na, a=3) def test_fallback_transitive(): "test that a fallback works if a dependency's dependency fails" - pipe = Pipeline() + pipe = PipelineBuilder() ia = pipe.create_input("a", int) ib = pipe.create_input("b", int) @@ -92,13 +95,14 @@ def double(x: int) -> int: # use the first that succeeds c = pipe.use_first_of("result", c1, c2) + pipe = pipe.build() # omitting the first input should result in the second component assert pipe.run(c, b=17) == 34 def test_fallback_transitive_deeper(): "deeper transitive fallback test" - pipe = Pipeline() + pipe = PipelineBuilder() a = pipe.create_input("a", int) b = pipe.create_input("b", int) @@ -112,12 +116,13 @@ def double(x: int) -> int: nn = pipe.add_component("negate", negative, x=nd) nr = pipe.use_first_of("fill-operand", nn, b) + pipe = pipe.build() assert pipe.run(nr, b=8) == 8 def test_fallback_transitive_nodefail(): "deeper transitive fallback test" - pipe = Pipeline() + pipe = PipelineBuilder() a = pipe.create_input("a", int) b = pipe.create_input("b", int) @@ -135,5 +140,6 @@ def double(x: int) -> int: nn = pipe.add_component("negate", negative, x=nd) nr = pipe.use_first_of("fill-operand", nn, b) + pipe = pipe.build() assert pipe.run(nr, a=2, b=8) == -4 assert pipe.run(nr, a=-7, b=8) == 8 diff --git a/lenskit/tests/pipeline/test_lazy.py b/lenskit/tests/pipeline/test_lazy.py index 59fde5923..7f962acf7 100644 --- a/lenskit/tests/pipeline/test_lazy.py +++ b/lenskit/tests/pipeline/test_lazy.py @@ -7,7 +7,7 @@ # pyright: strict from pytest import fail, raises -from lenskit.pipeline import Lazy, Pipeline +from lenskit.pipeline import Lazy, PipelineBuilder def fallback(first: int | None, second: Lazy[int]) -> int: @@ -18,7 +18,7 @@ def fallback(first: int | None, second: Lazy[int]) -> int: def test_lazy_input(): - pipe = Pipeline() + pipe = PipelineBuilder() a = pipe.create_input("a", int) b = pipe.create_input("b", int) @@ -36,12 +36,13 @@ def add(x: int, y: int) -> int: fb = pipe.add_component("fill-operand", fallback, first=b, second=nn) na = pipe.add_component("add", add, x=nd, y=fb) + pipe = pipe.build() # 3 * 2 + -3 = 3 assert pipe.run(na, a=3) == 3 def test_lazy_only_run_if_needed(): - pipe = Pipeline() + pipe = PipelineBuilder() a = pipe.create_input("a", int) b = pipe.create_input("b", int) @@ -59,11 +60,12 @@ def add(x: int, y: int) -> int: fb = pipe.add_component("fill-operand", fallback, first=b, second=nn) na = pipe.add_component("add", add, x=nd, y=fb) + pipe = pipe.build() assert pipe.run(na, a=3, b=8) == 14 def test_lazy_fail_with_missing_options(): - pipe = Pipeline() + pipe = PipelineBuilder() a = pipe.create_input("a", int) b = pipe.create_input("b", int) @@ -81,13 +83,14 @@ def add(x: int, y: int) -> int: fb = pipe.add_component("fill-operand", fallback, first=b, second=nn) na = pipe.add_component("add", add, x=nd, y=fb) + pipe = pipe.build() with raises(TypeError): pipe.run(na, a=3) def test_lazy_transitive(): "test that a fallback works if a dependency's dependency fails" - pipe = Pipeline() + pipe = PipelineBuilder() ia = pipe.create_input("a", int) ib = pipe.create_input("b", int) @@ -100,13 +103,14 @@ def double(x: int) -> int: # use the first that succeeds c = pipe.add_component("fill-operand", fallback, first=c1, second=c2) + pipe = pipe.build() # omitting the first input should result in the second component assert pipe.run(c, b=17) == 34 def test_lazy_transitive_deeper(): "deeper transitive fallback test" - pipe = Pipeline() + pipe = PipelineBuilder() a = pipe.create_input("a", int) b = pipe.create_input("b", int) @@ -120,12 +124,13 @@ def double(x: int) -> int: nn = pipe.add_component("negate", negative, x=nd) nr = pipe.add_component("fill-operand", fallback, first=nn, second=b) + pipe = pipe.build() assert pipe.run(nr, b=8) == 8 def test_lazy_transitive_nodefail(): "deeper transitive fallback test" - pipe = Pipeline() + pipe = PipelineBuilder() a = pipe.create_input("a", int) b = pipe.create_input("b", int) @@ -143,5 +148,6 @@ def double(x: int) -> int: nn = pipe.add_component("negate", negative, x=nd) nr = pipe.add_component("fill-operand", fallback, first=nn, second=b) + pipe = pipe.build() assert pipe.run(nr, a=2, b=8) == -4 assert pipe.run(nr, a=-7, b=8) == 8 diff --git a/lenskit/tests/pipeline/test_pipeline.py b/lenskit/tests/pipeline/test_pipeline.py index c2948221c..f42b92def 100644 --- a/lenskit/tests/pipeline/test_pipeline.py +++ b/lenskit/tests/pipeline/test_pipeline.py @@ -12,19 +12,19 @@ from pytest import raises, warns -from lenskit.pipeline import Pipeline, PipelineError +from lenskit.pipeline import PipelineBuilder, PipelineError from lenskit.pipeline.nodes import InputNode, Node from lenskit.pipeline.types import TypecheckWarning def test_init_empty(): - pipe = Pipeline() + pipe = PipelineBuilder() assert len(pipe.nodes) == 0 def test_create_input(): "create an input node" - pipe = Pipeline() + pipe = PipelineBuilder() src = pipe.create_input("user", int, str) assert_type(src, Node[int | str]) assert isinstance(src, InputNode) @@ -37,7 +37,7 @@ def test_create_input(): def test_lookup_optional(): "lookup a node without failing" - pipe = Pipeline() + pipe = PipelineBuilder() pipe.create_input("user", int, str) assert pipe.node("item", missing="none") is None @@ -45,7 +45,7 @@ def test_lookup_optional(): def test_lookup_missing(): "lookup a node without failing" - pipe = Pipeline() + pipe = PipelineBuilder() pipe.create_input("user", int, str) with raises(KeyError): @@ -54,7 +54,7 @@ def test_lookup_missing(): def test_dup_input_fails(): "create an input node" - pipe = Pipeline() + pipe = PipelineBuilder() pipe.create_input("user", int, str) with raises(ValueError, match="has node"): @@ -63,7 +63,7 @@ def test_dup_input_fails(): def test_dup_component_fails(): "create an input node" - pipe = Pipeline() + pipe = PipelineBuilder() pipe.create_input("user", int, str) with raises(ValueError, match="has node"): @@ -72,7 +72,7 @@ def test_dup_component_fails(): def test_dup_alias_fails(): "create an input node" - pipe = Pipeline() + pipe = PipelineBuilder() n = pipe.create_input("user", int, str) with raises(ValueError, match="has node"): @@ -81,7 +81,7 @@ def test_dup_alias_fails(): def test_alias(): "alias a node" - pipe = Pipeline() + pipe = PipelineBuilder() user = pipe.create_input("user", int, str) pipe.alias("person", user) @@ -94,7 +94,7 @@ def test_alias(): def test_component_type(): - pipe = Pipeline() + pipe = PipelineBuilder() msg = pipe.create_input("msg", str) def incr(msg: str) -> str: @@ -106,7 +106,7 @@ def incr(msg: str) -> str: def test_single_input(): - pipe = Pipeline() + pipe = PipelineBuilder() msg = pipe.create_input("msg", str) def incr(msg: str) -> str: @@ -114,6 +114,8 @@ def incr(msg: str) -> str: node = pipe.add_component("return", incr, msg=msg) + pipe = pipe.build() + ret = pipe.run(node, msg="hello") assert ret == "hello" @@ -122,7 +124,7 @@ def incr(msg: str) -> str: def test_single_input_required(): - pipe = Pipeline() + pipe = PipelineBuilder() msg = pipe.create_input("msg", str) def incr(msg: str) -> str: @@ -130,12 +132,13 @@ def incr(msg: str) -> str: node = pipe.add_component("return", incr, msg=msg) + pipe = pipe.build() with raises(PipelineError, match="not specified"): pipe.run(node) def test_single_optional_input(): - pipe = Pipeline() + pipe = PipelineBuilder() msg = pipe.create_input("msg", str, None) def fill(msg: str | None) -> str: @@ -143,12 +146,13 @@ def fill(msg: str | None) -> str: node = pipe.add_component("return", fill, msg=msg) + pipe = pipe.build() assert pipe.run(node) == "undefined" assert pipe.run(node, msg="hello") == "hello" def test_single_input_typecheck(): - pipe = Pipeline() + pipe = PipelineBuilder() msg = pipe.create_input("msg", str) def incr(msg: str) -> str: @@ -156,23 +160,25 @@ def incr(msg: str) -> str: node = pipe.add_component("return", incr, msg=msg) + pipe = pipe.build() with raises(TypeError): pipe.run(node, msg=47) def test_component_type_mismatch(): - pipe = Pipeline() + pipe = PipelineBuilder() def incr(msg: str) -> str: return msg node = pipe.add_component("return", incr, msg=47) + pipe = pipe.build() with raises(TypeError): pipe.run(node) def test_component_unwired_input(): - pipe = Pipeline() + pipe = PipelineBuilder() msg = pipe.create_input("msg", str) def ident(msg: str, m2: str | None) -> str: @@ -182,11 +188,12 @@ def ident(msg: str, m2: str | None) -> str: return msg node = pipe.add_component("return", ident, msg=msg) + pipe = pipe.build() assert pipe.run(node, msg="hello") == "hello" def test_chain(): - pipe = Pipeline() + pipe = PipelineBuilder() x = pipe.create_input("x", int) def incr(x: int) -> int: @@ -198,6 +205,7 @@ def triple(x: int) -> int: ni = pipe.add_component("incr", incr, x=x) nt = pipe.add_component("triple", triple, x=ni) + pipe = pipe.build() # run default pipe ret = pipe.run(x=1) assert ret == 6 @@ -210,7 +218,7 @@ def triple(x: int) -> int: def test_simple_graph(): - pipe = Pipeline() + pipe = PipelineBuilder() a = pipe.create_input("a", int) b = pipe.create_input("b", int) @@ -223,13 +231,14 @@ def add(x: int, y: int) -> int: nd = pipe.add_component("double", double, x=a) na = pipe.add_component("add", add, x=nd, y=b) + pipe = pipe.build() assert pipe.run(a=1, b=7) == 9 assert pipe.run(na, a=3, b=7) == 13 assert pipe.run(nd, a=3, b=7) == 6 def test_cycle(): - pipe = Pipeline() + pipe = PipelineBuilder() b = pipe.create_input("b", int) def double(x: int) -> int: @@ -243,11 +252,11 @@ def add(x: int, y: int) -> int: pipe.connect(nd, x=na) with raises(PipelineError, match="cycle"): - pipe.run(a=1, b=7) + pipe.build() def test_replace_component(): - pipe = Pipeline() + pipe = PipelineBuilder() a = pipe.create_input("a", int) b = pipe.create_input("b", int) @@ -265,6 +274,8 @@ def add(x: int, y: int) -> int: nt = pipe.replace_component("double", triple, x=a) + pipe = pipe.build() + # run through the end assert pipe.run(a=1, b=7) == 10 assert pipe.run(na, a=3, b=7) == 16 @@ -277,7 +288,7 @@ def add(x: int, y: int) -> int: def test_default_wiring(): - pipe = Pipeline() + pipe = PipelineBuilder() a = pipe.create_input("a", int) b = pipe.create_input("b", int) @@ -292,12 +303,13 @@ def add(x: int, y: int) -> int: nd = pipe.add_component("double", double, x=a) na = pipe.add_component("add", add, x=nd) + pipe = pipe.build() assert pipe.run(a=1, b=7) == 9 assert pipe.run(na, a=3, b=7) == 13 def test_run_by_name(): - pipe = Pipeline() + pipe = PipelineBuilder() a = pipe.create_input("a", int) b = pipe.create_input("b", int) @@ -310,11 +322,12 @@ def add(x: int, y: int) -> int: nd = pipe.add_component("double", double, x=a) pipe.add_component("add", add, x=nd, y=b) + pipe = pipe.build() assert pipe.run("double", a=1, b=7) == 2 def test_invalid_type(): - pipe = Pipeline() + pipe = PipelineBuilder() a = pipe.create_input("a", int) b = pipe.create_input("b", int) @@ -327,12 +340,13 @@ def add(x: int, y: int) -> int: nd = pipe.add_component("double", double, x=a) pipe.add_component("add", add, x=nd, y=b) + pipe = pipe.build() with raises(TypeError): pipe.run(a=1, b="seven") def test_run_by_alias(): - pipe = Pipeline() + pipe = PipelineBuilder() a = pipe.create_input("a", int) b = pipe.create_input("b", int) @@ -347,11 +361,12 @@ def add(x: int, y: int) -> int: pipe.alias("result", na) + pipe = pipe.build() assert pipe.run("result", a=1, b=7) == 9 def test_run_all(): - pipe = Pipeline("test", "7.2") + pipe = PipelineBuilder("test", "7.2") a = pipe.create_input("a", int) b = pipe.create_input("b", int) @@ -366,6 +381,7 @@ def add(x: int, y: int) -> int: pipe.alias("result", na) + pipe = pipe.build() state = pipe.run_all(a=1, b=7) assert state["double"] == 2 assert state["add"] == 9 @@ -378,7 +394,7 @@ def add(x: int, y: int) -> int: def test_run_all_limit(): - pipe = Pipeline() + pipe = PipelineBuilder() a = pipe.create_input("a", int) b = pipe.create_input("b", int) @@ -393,6 +409,7 @@ def add(x: int, y: int) -> int: pipe.alias("result", na) + pipe = pipe.build() state = pipe.run_all("double", a=1, b=7) assert state["double"] == 2 assert "add" not in state @@ -400,7 +417,7 @@ def add(x: int, y: int) -> int: def test_connect_literal(): - pipe = Pipeline() + pipe = PipelineBuilder() a = pipe.create_input("a", int) def double(x: int) -> int: @@ -412,11 +429,12 @@ def add(x: int, y: int) -> int: nd = pipe.add_component("double", double, x=a) na = pipe.add_component("add", add, x=nd, y=2) + pipe = pipe.build() assert pipe.run(na, a=3) == 8 def test_connect_literal_explicit(): - pipe = Pipeline() + pipe = PipelineBuilder() a = pipe.create_input("a", int) def double(x: int) -> int: @@ -428,11 +446,12 @@ def add(x: int, y: int) -> int: nd = pipe.add_component("double", double, x=a) na = pipe.add_component("add", add, x=nd, y=pipe.literal(2)) + pipe = pipe.build() assert pipe.run(na, a=3) == 8 def test_fail_missing_input(): - pipe = Pipeline() + pipe = PipelineBuilder() a = pipe.create_input("a", int) b = pipe.create_input("b", int) @@ -445,6 +464,8 @@ def add(x: int, y: int) -> int: nd = pipe.add_component("double", double, x=a) na = pipe.add_component("add", add, x=nd, y=b) + pipe = pipe.build() + with raises(PipelineError, match=r"input.*not specified"): pipe.run(na, a=3) @@ -457,7 +478,7 @@ def test_pipeline_component_default(): Test that the last *component* is last. It also exercises the warning logic for missing component types. """ - pipe = Pipeline() + pipe = PipelineBuilder() a = pipe.create_input("a", int) def add(x, y): # type: ignore @@ -466,6 +487,7 @@ def add(x, y): # type: ignore with warns(TypecheckWarning): pipe.add_component("add", add, x=np.arange(10), y=a) # type: ignore + pipe = pipe.build() # the component runs assert np.all(pipe.run("add", a=5) == np.arange(5, 15)) diff --git a/lenskit/tests/pipeline/test_pipeline_clone.py b/lenskit/tests/pipeline/test_pipeline_clone.py index 08c307bad..01378f348 100644 --- a/lenskit/tests/pipeline/test_pipeline_clone.py +++ b/lenskit/tests/pipeline/test_pipeline_clone.py @@ -4,10 +4,10 @@ # Licensed under the MIT license, see LICENSE.md for details. # SPDX-License-Identifier: MIT -import json +# pyright: strict from dataclasses import dataclass -from lenskit.pipeline import Pipeline +from lenskit.pipeline import PipelineBuilder from lenskit.pipeline.components import Component from lenskit.pipeline.nodes import ComponentNode @@ -38,10 +38,11 @@ def exclaim(msg: str) -> str: def test_pipeline_clone(): comp = Prefixer(PrefixConfig("scroll named ")) - pipe = Pipeline() + pipe = PipelineBuilder() msg = pipe.create_input("msg", str) pipe.add_component("prefix", comp, msg=msg) + pipe = pipe.build() assert pipe.run(msg="FOOBIE BLETCH") == "scroll named FOOBIE BLETCH" p2 = pipe.clone() @@ -57,11 +58,12 @@ def test_pipeline_clone(): def test_pipeline_clone_with_function(): comp = Prefixer(prefix="scroll named ") - pipe = Pipeline() + pipe = PipelineBuilder() msg = pipe.create_input("msg", str) pfx = pipe.add_component("prefix", comp, msg=msg) pipe.add_component("exclaim", exclaim, msg=pfx) + pipe = pipe.build() assert pipe.run(msg="FOOBIE BLETCH") == "scroll named FOOBIE BLETCH!" p2 = pipe.clone() @@ -72,11 +74,12 @@ def test_pipeline_clone_with_function(): def test_pipeline_clone_with_nonconfig_class(): comp = Prefixer(prefix="scroll named ") - pipe = Pipeline() + pipe = PipelineBuilder() msg = pipe.create_input("msg", str) pfx = pipe.add_component("prefix", comp, msg=msg) pipe.add_component("question", Question(), msg=pfx) + pipe = pipe.build() assert pipe.run(msg="FOOBIE BLETCH") == "scroll named FOOBIE BLETCH?" p2 = pipe.clone() @@ -85,11 +88,12 @@ def test_pipeline_clone_with_nonconfig_class(): def test_clone_defaults(): - pipe = Pipeline() + pipe = PipelineBuilder() msg = pipe.create_input("msg", str) pipe.set_default("msg", msg) pipe.add_component("return", exclaim) + pipe = pipe.build() assert pipe.run(msg="hello") == "hello!" p2 = pipe.clone() @@ -98,11 +102,12 @@ def test_clone_defaults(): def test_clone_alias(): - pipe = Pipeline() + pipe = PipelineBuilder() msg = pipe.create_input("msg", str) excl = pipe.add_component("exclaim", exclaim, msg=msg) pipe.alias("return", excl) + pipe = pipe.build() assert pipe.run("return", msg="hello") == "hello!" p2 = pipe.clone() @@ -111,12 +116,13 @@ def test_clone_alias(): def test_clone_hash(): - pipe = Pipeline() + pipe = PipelineBuilder() msg = pipe.create_input("msg", str) pipe.set_default("msg", msg) excl = pipe.add_component("exclaim", exclaim) pipe.alias("return", excl) + pipe = pipe.build() assert pipe.run("return", msg="hello") == "hello!" p2 = pipe.clone() diff --git a/lenskit/tests/pipeline/test_pipeline_state.py b/lenskit/tests/pipeline/test_pipeline_state.py index 4ed5a14e2..423b06d10 100644 --- a/lenskit/tests/pipeline/test_pipeline_state.py +++ b/lenskit/tests/pipeline/test_pipeline_state.py @@ -4,6 +4,7 @@ # Licensed under the MIT license, see LICENSE.md for details. # SPDX-License-Identifier: MIT +# pyright: strict from pytest import raises from lenskit.pipeline import PipelineState diff --git a/lenskit/tests/pipeline/test_save_load.py b/lenskit/tests/pipeline/test_save_load.py index 5dc3195e4..99704256e 100644 --- a/lenskit/tests/pipeline/test_save_load.py +++ b/lenskit/tests/pipeline/test_save_load.py @@ -14,7 +14,7 @@ from pytest import fail, warns -from lenskit.pipeline import Pipeline, PipelineWarning +from lenskit.pipeline import PipelineBuilder, PipelineWarning from lenskit.pipeline.components import Component from lenskit.pipeline.config import PipelineConfig from lenskit.pipeline.nodes import ComponentNode, InputNode @@ -60,7 +60,7 @@ def msg_prefix(prefix: str, msg: str) -> str: def test_serialize_input(): "serialize with one input node" - pipe = Pipeline("test") + pipe = PipelineBuilder("test") pipe.create_input("user", int, str) cfg = pipe.get_config() @@ -73,13 +73,13 @@ def test_serialize_input(): def test_round_trip_input(): "serialize with one input node" - pipe = Pipeline() + pipe = PipelineBuilder() pipe.create_input("user", int, str) cfg = pipe.get_config() print(cfg) - p2 = Pipeline.from_config(cfg) + p2 = PipelineBuilder.from_config(cfg) i2 = p2.node("user") assert isinstance(i2, InputNode) assert i2.name == "user" @@ -88,13 +88,13 @@ def test_round_trip_input(): def test_round_trip_optional_input(): "serialize with one input node" - pipe = Pipeline() + pipe = PipelineBuilder() pipe.create_input("user", int, str, None) cfg = pipe.get_config() assert cfg.inputs[0].types == {"int", "str", "None"} - p2 = Pipeline.from_config(cfg) + p2 = PipelineBuilder.from_config(cfg) i2 = p2.node("user") assert isinstance(i2, InputNode) assert i2.name == "user" @@ -102,7 +102,7 @@ def test_round_trip_optional_input(): def test_config_single_node(): - pipe = Pipeline() + pipe = PipelineBuilder() msg = pipe.create_input("msg", str) pipe.add_component("return", msg_ident, msg=msg) @@ -119,25 +119,26 @@ def test_config_single_node(): def test_round_trip_single_node(): - pipe = Pipeline() + pipe = PipelineBuilder() msg = pipe.create_input("msg", str) pipe.add_component("return", msg_ident, msg=msg) cfg = pipe.get_config() - p2 = Pipeline.from_config(cfg) + p2 = PipelineBuilder.from_config(cfg) assert len(p2.nodes) == 2 r2 = p2.node("return") assert isinstance(r2, ComponentNode) assert r2.component is msg_ident assert r2.connections == {"msg": "msg"} + p2 = p2.build() assert p2.run("return", msg="foo") == "foo" def test_configurable_component(): - pipe = Pipeline() + pipe = PipelineBuilder() msg = pipe.create_input("msg", str) pfx = Prefixer(prefix="scroll named ") @@ -146,7 +147,7 @@ def test_configurable_component(): cfg = pipe.get_config() assert cfg.components["prefix"].config == {"prefix": "scroll named "} - p2 = Pipeline.from_config(cfg) + p2 = PipelineBuilder.from_config(cfg) assert len(p2.nodes) == 2 r2 = p2.node("prefix") assert isinstance(r2, ComponentNode) @@ -154,6 +155,7 @@ def test_configurable_component(): assert r2.component is not pfx assert r2.connections == {"msg": "msg"} + p2 = p2.build() assert p2.run("prefix", msg="HACKEM MUCHE") == "scroll named HACKEM MUCHE" print("hash:", pipe.config_hash()) @@ -162,22 +164,25 @@ def test_configurable_component(): def test_save_defaults(): - pipe = Pipeline() + pipe = PipelineBuilder() msg = pipe.create_input("msg", str) pipe.set_default("msg", msg) pipe.add_component("return", msg_ident) + cfg = pipe.get_config() + + pipe = pipe.build() assert pipe.run(msg="hello") == "hello" - cfg = pipe.get_config() - p2 = Pipeline.from_config(cfg) + p2 = PipelineBuilder.from_config(cfg) + p2 = p2.build() assert p2.run(msg="hello") == "hello" def test_hashes_different(): - p1 = Pipeline() - p2 = Pipeline() + p1 = PipelineBuilder() + p2 = PipelineBuilder() a1 = p1.create_input("a", int) a2 = p2.create_input("a", int) @@ -194,10 +199,11 @@ def test_hashes_different(): _log.info("p1 stage 2 hash: %s", p1.config_hash()) _log.info("p2 stage 2 hash: %s", p2.config_hash()) assert p1.config_hash() != p2.config_hash() + assert p1.build().config_hash() != p2.build().config_hash() def test_save_with_fallback(): - pipe = Pipeline() + pipe = PipelineBuilder() a = pipe.create_input("a", int) b = pipe.create_input("b", int) @@ -211,14 +217,15 @@ def test_save_with_fallback(): print(json) c2 = PipelineConfig.model_validate_json(json) - p2 = Pipeline.from_config(c2) + p2 = PipelineBuilder.from_config(c2) + p2 = p2.build() # 3 * 2 + -3 = 3 assert p2.run("fill-operand", "add", a=3) == (-3, 3) def test_hash_validate(): - pipe = Pipeline() + pipe = PipelineBuilder() msg = pipe.create_input("msg", str) pfx = Prefixer(prefix="scroll named ") @@ -231,24 +238,25 @@ def test_hash_validate(): print("modified config:", cfg.model_dump_json(indent=2)) with warns(PipelineWarning): - Pipeline.from_config(cfg) + PipelineBuilder.from_config(cfg) def test_alias_input(): "just an input node and an alias" - pipe = Pipeline() + pipe = PipelineBuilder() user = pipe.create_input("user", int, str) pipe.alias("person", user) cfg = pipe.get_config() - p2 = Pipeline.from_config(cfg) + p2 = PipelineBuilder.from_config(cfg) + p2 = p2.build() assert p2.run("person", user=32) == 32 def test_alias_node(): - pipe = Pipeline() + pipe = PipelineBuilder() a = pipe.create_input("a", int) b = pipe.create_input("b", int) @@ -256,6 +264,7 @@ def test_alias_node(): na = pipe.add_component("add", add, x=nd, y=b) pipe.alias("result", na) + pipe = pipe.build() assert pipe.run("result", a=5, b=7) == 17 p2 = pipe.clone("pipeline-config") @@ -263,11 +272,12 @@ def test_alias_node(): def test_literal(): - pipe = Pipeline("literal-prefix") + pipe = PipelineBuilder("literal-prefix") msg = pipe.create_input("msg", str) pipe.add_component("prefix", msg_prefix, prefix=pipe.literal("hello, "), msg=msg) + pipe = pipe.build() assert pipe.run(msg="HACKEM MUCHE") == "hello, HACKEM MUCHE" print(pipe.get_config().model_dump_json(indent=2)) @@ -276,11 +286,12 @@ def test_literal(): def test_literal_array(): - pipe = Pipeline("literal-add-array") + pipe = PipelineBuilder("literal-add-array") a = pipe.create_input("a", int) pipe.add_component("add", add, x=np.arange(10), y=a) + pipe = pipe.build() res = pipe.run(a=5) assert np.all(res == np.arange(5, 15)) @@ -291,12 +302,13 @@ def test_literal_array(): def test_stable_with_literals(): "test that two identical pipelines have the same hash, even with literals" - p1 = Pipeline("literal-add-array") + p1 = PipelineBuilder("literal-add-array") a = p1.create_input("a", int) p1.add_component("add", add, x=np.arange(10), y=a) - p2 = Pipeline("literal-add-array") + p2 = PipelineBuilder("literal-add-array") a = p2.create_input("a", int) p2.add_component("add", add, x=np.arange(10), y=a) assert p1.config_hash() == p2.config_hash() + assert p1.build().config_hash() == p2.build().config_hash() From 74b84f707c80754976ba831b00e7bf4d8550d4fd Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Wed, 15 Jan 2025 16:21:22 -0500 Subject: [PATCH 04/37] update test_train --- lenskit/tests/pipeline/test_train.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lenskit/tests/pipeline/test_train.py b/lenskit/tests/pipeline/test_train.py index 07cb304a3..04fa7bdf4 100644 --- a/lenskit/tests/pipeline/test_train.py +++ b/lenskit/tests/pipeline/test_train.py @@ -8,17 +8,18 @@ from lenskit.data.dataset import Dataset from lenskit.data.vocab import Vocabulary -from lenskit.pipeline import Pipeline +from lenskit.pipeline import PipelineBuilder from lenskit.training import Trainable, TrainingOptions def test_train(ml_ds: Dataset): - pipe = Pipeline() + pipe = PipelineBuilder() item = pipe.create_input("item", int) tc: Trainable = TestComponent() pipe.add_component("test", tc, item=item) + pipe = pipe.build() pipe.train(ml_ds) # return true for an item that exists From b9de3389a464e92ba4b0f8e334fc3aa0062da42d Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Wed, 15 Jan 2025 16:21:54 -0500 Subject: [PATCH 05/37] remove builder methods from pipeline --- lenskit/lenskit/pipeline/_impl.py | 257 ++---------------------------- 1 file changed, 10 insertions(+), 247 deletions(-) diff --git a/lenskit/lenskit/pipeline/_impl.py b/lenskit/lenskit/pipeline/_impl.py index f2ec38a4f..0bfffe2ae 100644 --- a/lenskit/lenskit/pipeline/_impl.py +++ b/lenskit/lenskit/pipeline/_impl.py @@ -1,14 +1,13 @@ # pyright: strict from __future__ import annotations -import typing import warnings from dataclasses import replace -from types import FunctionType, UnionType -from uuid import NAMESPACE_URL, uuid4, uuid5 +from types import FunctionType +from uuid import NAMESPACE_URL, uuid5 from numpy.random import BitGenerator, Generator, SeedSequence -from typing_extensions import Any, Literal, Self, TypeAlias, TypeVar, cast, overload +from typing_extensions import Any, Literal, Self, TypeAlias, TypeVar, overload from lenskit.data import Dataset from lenskit.diagnostics import PipelineError, PipelineWarning @@ -19,11 +18,10 @@ from .components import ( # type: ignore # noqa: F401 Component, PipelineFunction, - fallback_on_none, instantiate_component, ) from .config import PipelineConfig -from .nodes import ND, ComponentNode, InputNode, LiteralNode, Node +from .nodes import ComponentNode, InputNode, LiteralNode, Node from .state import PipelineState from .types import parse_type_string @@ -48,8 +46,12 @@ class Pipeline: way. It allows you to wire together components in (mostly) abitrary graphs, train them on data, and serialize pipelines to disk for use elsewhere. - If you have a scoring model and just want to generate recommenations with a - default setup and minimal configuration, see :func:`topn_pipeline`. + Pipelines cannot be directly instantiated; they must be built with a + :class:`~lenskit.pipeline.PipelineBuilder` class, or loaded from a + configuration with :meth:`from_config`. If you have a scoring model and just + want to generate recommenations with a default setup and minimal + configuration, see :func:`~lenskit.pipeline.topn_pipeline` or + :class:`~lenskit.pipeline.RecPipelineBuilder`. Pipelines are also :class:`~lenskit.training.Trainable`, and train all trainable components. @@ -144,193 +146,6 @@ def node( else: raise KeyError(f"node {node}") - def create_input(self, name: str, *types: type[T] | None) -> Node[T]: - """ - Create an input node for the pipeline. Pipelines expect their inputs to - be provided when they are run. - - Args: - name: - The name of the input. The name must be unique in the pipeline - (among both components and inputs). - types: - The allowable types of the input; input data can be of any - specified type. If ``None`` is among the allowed types, the - input can be omitted. - - Returns: - A pipeline node representing this input. - - Raises: - ValueError: - a node with the specified ``name`` already exists. - """ - self._check_available_name(name) - - rts: set[type[T | None]] = set() - for t in types: - if t is None: - rts.add(type(None)) - elif isinstance(t, UnionType): - rts |= set(typing.get_args(t)) - else: - rts.add(t) - - node = InputNode[Any](name, types=rts) - self._nodes[name] = node - self._clear_caches() - return node - - def literal(self, value: T, *, name: str | None = None) -> LiteralNode[T]: - """ - Create a literal node (a node with a fixed value). - - .. note:: - Literal nodes cannot be serialized witih :meth:`get_config` or - :meth:`save_config`. - """ - if name is None: - name = str(uuid4()) - self._anon_nodes.add(name) - node = LiteralNode(name, value, types=set([type(value)])) - self._nodes[name] = node - self._clear_caches() - return node - - def set_default(self, name: str, node: Node[Any] | object) -> None: - """ - Set the default wiring for a component input. Components that declare - an input parameter with the specified ``name`` but no configured input - will be wired to this node. - - This is intended to be used for things like wiring up `user` parameters - to semi-automatically receive the target user's identity and history. - - Args: - name: - The name of the parameter to set a default for. - node: - The node or literal value to wire to this parameter. - """ - if not isinstance(node, Node): - node = self.literal(node) - self._defaults[name] = node - self._clear_caches() - - def get_default(self, name: str) -> Node[Any] | None: - """ - Get the default wiring for an input name. - """ - return self._defaults.get(name, None) - - def alias(self, alias: str, node: Node[Any] | str) -> None: - """ - Create an alias for a node. After aliasing, the node can be retrieved - from :meth:`node` using either its original name or its alias. - - Args: - alias: - The alias to add to the node. - node: - The node (or node name) to alias. - - Raises: - ValueError: - if the alias is already used as an alias or node name. - """ - node = self.node(node) - self._check_available_name(alias) - self._aliases[alias] = node - self._clear_caches() - - def add_component( - self, name: str, obj: Component[ND] | PipelineFunction[ND], **inputs: Node[Any] | object - ) -> Node[ND]: - """ - Add a component and connect it into the graph. - - Args: - name: - The name of the component in the pipeline. The name must be - unique in the pipeline (among both components and inputs). - obj: - The component itself. - inputs: - The component's input wiring. See :ref:`pipeline-connections` - for details. - - Returns: - The node representing this component in the pipeline. - """ - self._check_available_name(name) - - node = ComponentNode(name, obj) - self._nodes[name] = node - self._components[name] = obj - - self.connect(node, **inputs) - - self._clear_caches() - self._last = node - return node - - def replace_component( - self, - name: str | Node[ND], - obj: Component[ND] | PipelineFunction[ND], - **inputs: Node[Any] | object, - ) -> Node[ND]: - """ - Replace a component in the graph. The new component must have a type - that is compatible with the old component. The old component's input - connections will be replaced (as the new component may have different - inputs), but any connections that use the old component to supply an - input will use the new component instead. - """ - if isinstance(name, Node): - name = name.name - - node = ComponentNode(name, obj) - self._nodes[name] = node - self._components[name] = obj - - self.connect(node, **inputs) - - self._clear_caches() - return node - - def connect(self, obj: str | Node[Any], **inputs: Node[Any] | str | object): - """ - Provide additional input connections for a component that has already - been added. See :ref:`pipeline-connections` for details. - - Args: - obj: - The name or node of the component to wire. - inputs: - The component's input wiring. For each keyword argument in the - component's function signature, that argument can be provided - here with an input that the pipeline will provide to that - argument of the component when the pipeline is run. - """ - if isinstance(obj, Node): - node = obj - else: - node = self.node(obj) - if not isinstance(node, ComponentNode): - raise TypeError(f"only component nodes can be wired, not {node}") - - for k, n in inputs.items(): - if isinstance(n, Node): - n = cast(Node[Any], n) - self._check_member_node(n) - node.connections[k] = n.name - else: - lit = self.literal(n) - node.connections[k] = lit.name - - self._clear_caches() - def component_configs(self) -> dict[str, dict[str, Any]]: """ Get the configurations for the components. This is the configurations @@ -717,58 +532,6 @@ def run_all(self, *nodes: str | Node[Any], **kwargs: object) -> PipelineState: meta=self.meta(), ) - def use_first_of(self, name: str, primary: Node[T | None], fallback: Node[T]) -> Node[T]: - """ - Ergonomic method to create a new node that returns the result of its - ``input`` if it is provided and not ``None``, and otherwise returns the - result of ``fallback``. This method is used for things like filling in - optional pipeline inputs. For example, if you want the pipeline to take - candidate items through an ``items`` input, but look them up from the - user's history and the training data if ``items`` is not supplied, you - would do: - - .. code:: python - - pipe = Pipeline() - # allow candidate items to be optionally specified - items = pipe.create_input('items', list[EntityId], None) - # find candidates from the training data (optional) - lookup_candidates = pipe.add_component( - 'select-candidates', UnratedTrainingItemsCandidateSelector(), - user=history, - ) - # if the client provided items as a pipeline input, use those; otherwise - # use the candidate selector we just configured. - candidates = pipe.use_first_of('candidates', items, lookup_candidates) - - .. note:: - - This method does not distinguish between an input being unspecified - and explicitly specified as ``None``. - - .. note:: - - This method does *not* implement item-level fallbacks, only - fallbacks at the level of entire results. For item-level score - fallbacks, see :class:`~lenskit.basic.FallbackScorer`. - - .. note:: - If one of the fallback elements is a component ``A`` that depends on - another component or input ``B``, and ``B`` is missing or returns - ``None`` such that ``A`` would usually fail, then ``A`` will be - skipped and the fallback will move on to the next node. This works - with arbitrarily-deep transitive chains. - - Args: - name: - The name of the node. - primary: - The node to use as the primary input, if it is available. - fallback: - The node to use if the primary input does not provide a value. - """ - return self.add_component(name, fallback_on_none, primary=primary, fallback=fallback) - def _check_available_name(self, name: str) -> None: if name in self._nodes or name in self._aliases: raise ValueError(f"pipeline already has node {name}") From 34c92d135952ddf939445bbf8e12bbe5b20ce763 Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Wed, 15 Jan 2025 17:12:36 -0500 Subject: [PATCH 06/37] first pass at working separate builder --- lenskit/lenskit/pipeline/_impl.py | 50 ++--- lenskit/lenskit/pipeline/builder.py | 212 +++++++++--------- lenskit/lenskit/pipeline/components.py | 16 ++ lenskit/lenskit/pipeline/nodes.py | 43 +++- .../tests/pipeline/test_component_config.py | 9 +- 5 files changed, 176 insertions(+), 154 deletions(-) diff --git a/lenskit/lenskit/pipeline/_impl.py b/lenskit/lenskit/pipeline/_impl.py index 0bfffe2ae..0e8545f90 100644 --- a/lenskit/lenskit/pipeline/_impl.py +++ b/lenskit/lenskit/pipeline/_impl.py @@ -35,8 +35,8 @@ T4 = TypeVar("T4") T5 = TypeVar("T5") -CloneMethod: TypeAlias = Literal["config", "pipeline-config"] NAMESPACE_LITERAL_DATA = uuid5(NAMESPACE_URL, "https://ns.lenskit.org/literal-data/") +CloneMethod: TypeAlias = Literal["config", "pipeline-config"] class Pipeline: @@ -46,7 +46,7 @@ class Pipeline: way. It allows you to wire together components in (mostly) abitrary graphs, train them on data, and serialize pipelines to disk for use elsewhere. - Pipelines cannot be directly instantiated; they must be built with a + Pipelines should not be directly instantiated; they must be built with a :class:`~lenskit.pipeline.PipelineBuilder` class, or loaded from a configuration with :meth:`from_config`. If you have a scoring model and just want to generate recommenations with a default setup and minimal @@ -56,19 +56,11 @@ class Pipeline: Pipelines are also :class:`~lenskit.training.Trainable`, and train all trainable components. - Args: - name: - A name for the pipeline. - version: - A numeric version for the pipeline. - Stability: Caller """ - name: str | None = None - version: str | None = None - + _config: config.PipelineConfig _nodes: dict[str, Node[Any]] _aliases: dict[str, Node[Any]] _defaults: dict[str, Node[Any]] @@ -78,29 +70,27 @@ class Pipeline: _anon_nodes: set[str] "Track generated node names." - def __init__(self, name: str | None = None, version: str | None = None): - self.name = name - self.version = version - self._nodes = {} - self._aliases = {} - self._defaults = {} - self._components = {} - self._anon_nodes = set() - self._clear_caches() - - def meta(self, *, include_hash: bool = True) -> config.PipelineMeta: + def __init__(self, config: config.PipelineConfig, nodes: dict[str, Node[Any]]): + self._config = config + self._nodes = dict(nodes) + self._aliases = {a: self.node(t) for (a, t) in config.aliases.items()} + self._defaults = {n: self.node(t) for (n, t) in config.defaults.items()} + + @property + def name(self) -> str | None: + return self._config.meta.name + + @property + def version(self) -> str | None: + return self._config.meta.version + + @property + def meta(self) -> config.PipelineMeta: """ Get the metadata (name, version, hash, etc.) for this pipeline without returning the whole config. - - Args: - include_hash: - Whether to include a configuration hash in the metadata. """ - meta = config.PipelineMeta(name=self.name, version=self.version) - if include_hash: - meta.hash = self.config_hash() - return meta + return self._config.meta @property def nodes(self) -> list[Node[object]]: diff --git a/lenskit/lenskit/pipeline/builder.py b/lenskit/lenskit/pipeline/builder.py index 183bb667d..8a68fb121 100644 --- a/lenskit/lenskit/pipeline/builder.py +++ b/lenskit/lenskit/pipeline/builder.py @@ -7,32 +7,38 @@ import typing import warnings -from dataclasses import replace -from types import FunctionType, UnionType +from types import UnionType from uuid import NAMESPACE_URL, uuid4, uuid5 -from numpy.random import BitGenerator, Generator, SeedSequence -from typing_extensions import Any, Literal, Self, TypeAlias, TypeVar, cast, overload +from typing_extensions import Any, Literal, Self, TypeVar, cast, overload -from lenskit.data import Dataset from lenskit.diagnostics import PipelineError, PipelineWarning from lenskit.logging import get_logger -from lenskit.training import Trainable, TrainingOptions from . import config from ._impl import Pipeline from .components import ( # type: ignore # noqa: F401 Component, + ComponentConstructor, PipelineFunction, fallback_on_none, instantiate_component, ) from .config import PipelineConfig -from .nodes import ND, ComponentNode, InputNode, LiteralNode, Node +from .nodes import ( + ND, + ComponentConstructorNode, + ComponentInstanceNode, + ComponentNode, + InputNode, + LiteralNode, + Node, +) from .types import parse_type_string _log = get_logger(__name__) +CFG = TypeVar("CFG") # common type var for quick use T = TypeVar("T") T1 = TypeVar("T1") @@ -41,7 +47,6 @@ T4 = TypeVar("T4") T5 = TypeVar("T5") -CloneMethod: TypeAlias = Literal["config", "pipeline-config"] NAMESPACE_LITERAL_DATA = uuid5(NAMESPACE_URL, "https://ns.lenskit.org/literal-data/") @@ -81,7 +86,6 @@ class PipelineBuilder: _aliases: dict[str, Node[Any]] _defaults: dict[str, Node[Any]] _components: dict[str, PipelineFunction[Any] | Component[Any]] - _hash: str | None = None _last: Node[Any] | None = None _anon_nodes: set[str] "Track generated node names." @@ -94,7 +98,6 @@ def __init__(self, name: str | None = None, version: str | None = None): self._defaults = {} self._components = {} self._anon_nodes = set() - self._clear_caches() def meta(self, *, include_hash: bool = True) -> config.PipelineMeta: """ @@ -188,7 +191,6 @@ def create_input(self, name: str, *types: type[T] | None) -> Node[T]: node = InputNode[Any](name, types=rts) self._nodes[name] = node - self._clear_caches() return node def literal(self, value: T, *, name: str | None = None) -> LiteralNode[T]: @@ -204,7 +206,6 @@ def literal(self, value: T, *, name: str | None = None) -> LiteralNode[T]: self._anon_nodes.add(name) node = LiteralNode(name, value, types=set([type(value)])) self._nodes[name] = node - self._clear_caches() return node def set_default(self, name: str, node: Node[Any] | object) -> None: @@ -225,7 +226,6 @@ def set_default(self, name: str, node: Node[Any] | object) -> None: if not isinstance(node, Node): node = self.literal(node) self._defaults[name] = node - self._clear_caches() def get_default(self, name: str) -> Node[Any] | None: """ @@ -251,10 +251,31 @@ def alias(self, alias: str, node: Node[Any] | str) -> None: node = self.node(node) self._check_available_name(alias) self._aliases[alias] = node - self._clear_caches() + @overload def add_component( - self, name: str, obj: Component[ND] | PipelineFunction[ND], **inputs: Node[Any] | object + self, + name: str, + cls: ComponentConstructor[CFG, ND], + config: CFG = None, + /, + **inputs: Node[Any], + ) -> Node[ND]: ... + @overload + def add_component( + self, + name: str, + instance: Component[ND] | PipelineFunction[ND], + /, + **inputs: Node[Any] | object, + ) -> Node[ND]: ... + def add_component( + self, + name: str, + comp: ComponentConstructor[CFG, ND] | Component[ND] | PipelineFunction[ND], + config: CFG | None = None, + /, + **inputs: Node[Any] | object, ) -> Node[ND]: """ Add a component and connect it into the graph. @@ -263,8 +284,12 @@ def add_component( name: The name of the component in the pipeline. The name must be unique in the pipeline (among both components and inputs). - obj: - The component itself. + cls: + A component class. + config: + The configuration object for the component class. + instance: + A raw function or pre-instantiated component. inputs: The component's input wiring. See :ref:`pipeline-connections` for details. @@ -274,20 +299,41 @@ def add_component( """ self._check_available_name(name) - node = ComponentNode(name, obj) + if isinstance(comp, ComponentConstructor): + node = ComponentConstructorNode(name, comp, config) + else: + node = ComponentInstanceNode(name, comp) + self._nodes[name] = node - self._components[name] = obj self.connect(node, **inputs) - self._clear_caches() self._last = node return node + @overload + def replace_component( + self, + name: str | Node[ND], + cls: ComponentConstructor[CFG, ND], + config: CFG = None, + /, + **inputs: Node[Any], + ) -> Node[ND]: ... + @overload def replace_component( self, name: str | Node[ND], - obj: Component[ND] | PipelineFunction[ND], + instance: Component[ND] | PipelineFunction[ND], + /, + **inputs: Node[Any] | object, + ) -> Node[ND]: ... + def replace_component( + self, + name: str | Node[ND], + comp: ComponentConstructor[CFG, ND] | Component[ND] | PipelineFunction[ND], + config: CFG | None = None, + /, **inputs: Node[Any] | object, ) -> Node[ND]: """ @@ -300,13 +346,15 @@ def replace_component( if isinstance(name, Node): name = name.name - node = ComponentNode(name, obj) + if isinstance(comp, ComponentConstructor): + node = ComponentConstructorNode(name, comp, config) + else: + node = ComponentInstanceNode(name, comp) + self._nodes[name] = node - self._components[name] = obj self.connect(node, **inputs) - self._clear_caches() return node def connect(self, obj: str | Node[Any], **inputs: Node[Any] | str | object): @@ -339,8 +387,6 @@ def connect(self, obj: str | Node[Any], **inputs: Node[Any] | str | object): lit = self.literal(n) node.connections[k] = lit.name - self._clear_caches() - def component_configs(self) -> dict[str, dict[str, Any]]: """ Get the configurations for the components. This is the configurations @@ -352,35 +398,12 @@ def component_configs(self) -> dict[str, dict[str, Any]]: if isinstance(comp, Component) } - def clone(self, how: CloneMethod = "config") -> Pipeline: + def clone(self) -> PipelineBuilder: """ - Clone the pipeline, optionally including trained parameters. - - The ``how`` parameter controls how the pipeline is cloned, and what is - available in the clone pipeline. It can be one of the following values: - - ``"config"`` - Create fresh component instances using the configurations of the - components in this pipeline. When applied to a trained pipeline, - the clone does **not** have the original's learned parameters. This - is the default clone method. - ``"pipeline-config"`` - Round-trip the entire pipeline through :meth:`get_config` and - :meth:`from_config`. - - Args: - how: - The mechanism to use for cloning the pipeline. - - Returns: - A new pipeline with the same components and wiring, but fresh - instances created by round-tripping the configuration. + Clone the pipeline builder. The resulting builder starts as a copy of + this builder, and any subsequent modifications only the copy to which + they are applied. """ - if how == "pipeline-config": - cfg = self.get_config() - return self.from_config(cfg) - elif how != "config": # pragma: nocover - raise NotImplementedError("only 'config' cloning is currently supported") clone = PipelineBuilder() @@ -392,22 +415,25 @@ def clone(self, how: CloneMethod = "config") -> Pipeline: clone.create_input(name, *types) case LiteralNode(name, value): clone._nodes[name] = LiteralNode(name, value) - case ComponentNode(name, comp, _inputs, wiring): - if isinstance(comp, FunctionType): - comp = comp - elif isinstance(comp, Component): - comp = comp.__class__(comp.config) # type: ignore - else: - comp = comp.__class__() # type: ignore - cn = clone.add_component(node.name, comp) # type: ignore - for wn, wt in wiring.items(): - clone.connect(cn, **{wn: clone.node(wt)}) + case ComponentConstructorNode(name, comp, config): + cn = clone.add_component(name, comp, config) + case ComponentInstanceNode(name, comp): + cn = clone.add_component(name, comp) case _: # pragma: nocover raise RuntimeError(f"invalid node {node}") for n, t in self._aliases.items(): clone.alias(n, t.name) + for node in self.nodes: + match node: + case ComponentNode(name, connections=wiring): + cn = clone.node(name) + for wn, wt in wiring.items(): + clone.connect(cn, **{wn: clone.node(wt)}) + case _: + pass + for n, t in self._defaults.items(): clone.set_default(n, clone.node(t.name)) @@ -502,7 +528,7 @@ def config_hash(self) -> str: @classmethod def from_config(cls, config: object) -> Self: """ - Reconstruct a pipeline from a serialized configuration. + Reconstruct a pipeline builder from a serialized configuration. Args: config: @@ -570,48 +596,6 @@ def from_config(cls, config: object) -> Self: return pipe - def train(self, data: Dataset, options: TrainingOptions | None = None) -> None: - """ - Trains the pipeline's trainable components (those implementing the - :class:`TrainableComponent` interface) on some training data. - - .. admonition:: Random Number Generation - :class: note - - If :attr:`TrainingOptions.rng` is set and is not a generator or bit - generator (i.e. it is a seed), then this method wraps the seed in a - :class:`~numpy.random.SeedSequence` and calls - :class:`~numpy.random.SeedSequence.spawn()` to generate a distinct - seed for each component in the pipeline. - - Args: - data: - The dataset to train on. - options: - The training options. If ``None``, default options are used. - """ - log = _log.bind(pipeline=self.name) - if options is None: - options = TrainingOptions() - - if isinstance(options.rng, SeedSequence): - seed = options.rng - elif options.rng is None or isinstance(options.rng, (Generator, BitGenerator)): - seed = None - else: - seed = SeedSequence(options.rng) - - log.info("training pipeline components") - for name, comp in self._components.items(): - clog = log.bind(name=name, component=comp) - if isinstance(comp, Trainable): - # spawn new seed if needed - c_opts = options if seed is None else replace(options, rng=seed.spawn(1)[0]) - clog.info("training component") - comp.train(data, c_opts) - else: - clog.debug("training not required") - def use_first_of(self, name: str, primary: Node[T | None], fallback: Node[T]) -> Node[T]: """ Ergonomic method to create a new node that returns the result of its @@ -668,7 +652,17 @@ def build(self) -> Pipeline: """ Build the pipeline. """ - return self # type: ignore + config = self.get_config() + return Pipeline(config, self._nodes) + + def _instantiate(self, node: Node[ND]) -> Node[ND]: + match node: + case ComponentConstructorNode(name, constructor, config, connections=cxns): + _log.debug("instantiating component", component=constructor) + instance = constructor(config) + return ComponentInstanceNode(name, instance, cxns) + case _: + return node def _check_available_name(self, name: str) -> None: if name in self._nodes or name in self._aliases: @@ -678,7 +672,3 @@ def _check_member_node(self, node: Node[Any]) -> None: nw = self._nodes.get(node.name) if nw is not node: raise PipelineError(f"node {node} not in pipeline") - - def _clear_caches(self): - if "_hash" in self.__dict__: - del self._hash diff --git a/lenskit/lenskit/pipeline/components.py b/lenskit/lenskit/pipeline/components.py index de7954703..ab53d67bd 100644 --- a/lenskit/lenskit/pipeline/components.py +++ b/lenskit/lenskit/pipeline/components.py @@ -35,6 +35,7 @@ P = ParamSpec("P") T = TypeVar("T") +CFG = TypeVar("CFG", contravariant=True) CArgs = ParamSpec("CArgs", default=...) """ Argument type for a component. It is difficult to actually specify this, but @@ -48,6 +49,21 @@ Return type for a component. """ PipelineFunction: TypeAlias = Callable[..., COut] +""" +Pure-function interface for pipeline functions. +""" + + +class ComponentConstructor(ABC, Generic[CFG, COut]): + """ + Protocol for component constructors. + """ + + def __call__(self, config: CFG | None = None) -> Component[COut]: ... + + def __isinstance__(self, obj: Any) -> bool: + # FIXME: implement a more rigorous check for this + return isinstance(obj, type) and issubclass(obj, Component) @runtime_checkable diff --git a/lenskit/lenskit/pipeline/nodes.py b/lenskit/lenskit/pipeline/nodes.py index 249665505..91e7e5971 100644 --- a/lenskit/lenskit/pipeline/nodes.py +++ b/lenskit/lenskit/pipeline/nodes.py @@ -8,15 +8,17 @@ import warnings from inspect import Signature, signature +from typing import Any, Callable from typing_extensions import Generic, TypeVar -from .components import PipelineFunction +from .components import Component, ComponentConstructor, PipelineFunction from .types import TypecheckWarning # Nodes are (conceptually) immutable data containers, so Node[U] can be assigned # to Node[T] if U ≼ T. ND = TypeVar("ND", covariant=True) +CFG = TypeVar("CFG", contravariant=True, bound=object) class Node(Generic[ND]): @@ -70,16 +72,14 @@ def __init__(self, name: str, value: ND, *, types: set[type] | None = None): class ComponentNode(Node[ND], Generic[ND]): """ - A node storing a component. + A node storing a component. This is an abstract node class; see subclasses + :class:`ComponentConstructorNode` and `ComponentInstanceNode`. Stability: Internal """ - __match_args__ = ("name", "component", "inputs", "connections") - - component: PipelineFunction[ND] - "The component associated with this node" + __match_args__ = ("name", "inputs", "connections") inputs: dict[str, type | None] "The component's inputs." @@ -87,11 +87,11 @@ class ComponentNode(Node[ND], Generic[ND]): connections: dict[str, str] "The component's input connections." - def __init__(self, name: str, component: PipelineFunction[ND]): + def __init__(self, name: str): super().__init__(name) - self.component = component self.connections = {} + def _setup_signature(self, component: Callable[..., ND]): sig = signature(component, eval_str=True) if sig.return_annotation == Signature.empty: warnings.warn( @@ -111,3 +111,30 @@ def __init__(self, name: str, component: PipelineFunction[ND]): self.inputs[param.name] = None else: self.inputs[param.name] = param.annotation + + +class ComponentConstructorNode(ComponentNode[ND], Generic[ND]): + __match_args__ = ("name", "constructor", "config", "inputs", "connections") + constructor: ComponentConstructor[Any, ND] + config: object | None + + def __init__(self, name: str, constructor: ComponentConstructor[CFG, ND], config: CFG): + self.constructor = constructor + self.config = config + + +class ComponentInstanceNode(ComponentNode[ND], Generic[ND]): + __match_args__ = ("name", "component", "inputs", "connections") + + component: Component[ND] | PipelineFunction[ND] + + def __init__( + self, + name: str, + component: Component[ND] | PipelineFunction[ND], + connections: dict[str, str] | None = None, + ): + super().__init__(name) + self.component = component + self._setup_signature(component) + self.connections = connections or {} diff --git a/lenskit/tests/pipeline/test_component_config.py b/lenskit/tests/pipeline/test_component_config.py index a105414ed..230847ef5 100644 --- a/lenskit/tests/pipeline/test_component_config.py +++ b/lenskit/tests/pipeline/test_component_config.py @@ -8,6 +8,7 @@ import json from dataclasses import dataclass +from typing import Any from pydantic import BaseModel from pydantic.dataclasses import dataclass as pydantic_dataclass @@ -15,7 +16,7 @@ from pytest import mark from lenskit.pipeline import PipelineBuilder -from lenskit.pipeline.components import Component +from lenskit.pipeline.components import Component, ComponentConstructor @dataclass @@ -85,12 +86,10 @@ def test_auto_config_roundtrip(prefixer: type[Component]): @mark.parametrize("prefixer", [PrefixerDC, PrefixerM, PrefixerPYDC]) -def test_pipeline_config(prefixer: type[Component]): - comp = prefixer(prefix="scroll named ") - +def test_pipeline_config(prefixer: ComponentConstructor[Any, str]): pipe = PipelineBuilder() msg = pipe.create_input("msg", str) - pipe.add_component("prefix", comp, msg=msg) + pipe.add_component("prefix", prefixer, {"prefix": "scroll named "}, msg=msg) pipe = pipe.build() assert pipe.run(msg="FOOBIE BLETCH") == "scroll named FOOBIE BLETCH" From e0da460fb11f1842ccca778fd77b86403db04f13 Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Wed, 15 Jan 2025 18:28:17 -0500 Subject: [PATCH 07/37] A lot more pieces, including fixing the run api --- docs/guide/pipeline.rst | 3 - lenskit/lenskit/pipeline/_impl.py | 315 ++++++------------ lenskit/lenskit/pipeline/builder.py | 95 +++--- lenskit/lenskit/pipeline/components.py | 14 +- lenskit/lenskit/pipeline/config.py | 28 +- lenskit/lenskit/pipeline/nodes.py | 17 +- lenskit/lenskit/pipeline/runner.py | 6 +- lenskit/lenskit/testing/_components.py | 2 +- .../tests/pipeline/test_component_config.py | 13 +- lenskit/tests/pipeline/test_pipeline.py | 56 +++- lenskit/tests/pipeline/test_pipeline_clone.py | 10 +- lenskit/tests/pipeline/test_save_load.py | 38 +-- lenskit/tests/pipeline/test_train.py | 1 + 13 files changed, 270 insertions(+), 328 deletions(-) diff --git a/docs/guide/pipeline.rst b/docs/guide/pipeline.rst index b0ec4ccfc..7e9337b8e 100644 --- a/docs/guide/pipeline.rst +++ b/docs/guide/pipeline.rst @@ -169,9 +169,6 @@ The :meth:`~Pipeline.run` method takes two types of inputs: obtained (e.g. initial item scores and final rankings, which may have altered scores). - If no components are specified, it is the same as specifying the last - component that was added to the pipeline. - * Keyword arguments specifying the values for the pipeline's inputs, as defined by calls to :meth:`Pipeline.create_input`. diff --git a/lenskit/lenskit/pipeline/_impl.py b/lenskit/lenskit/pipeline/_impl.py index 0e8545f90..e8849fbf8 100644 --- a/lenskit/lenskit/pipeline/_impl.py +++ b/lenskit/lenskit/pipeline/_impl.py @@ -1,29 +1,22 @@ # pyright: strict from __future__ import annotations -import warnings from dataclasses import replace -from types import FunctionType from uuid import NAMESPACE_URL, uuid5 from numpy.random import BitGenerator, Generator, SeedSequence -from typing_extensions import Any, Literal, Self, TypeAlias, TypeVar, overload +from typing_extensions import Any, Literal, TypeAlias, TypeVar, overload from lenskit.data import Dataset -from lenskit.diagnostics import PipelineError, PipelineWarning +from lenskit.diagnostics import PipelineError from lenskit.logging import get_logger from lenskit.training import Trainable, TrainingOptions from . import config -from .components import ( # type: ignore # noqa: F401 - Component, - PipelineFunction, - instantiate_component, -) +from .components import Component from .config import PipelineConfig -from .nodes import ComponentNode, InputNode, LiteralNode, Node +from .nodes import ComponentInstanceNode, ComponentNode, InputNode, LiteralNode, Node from .state import PipelineState -from .types import parse_type_string _log = get_logger(__name__) @@ -63,28 +56,42 @@ class Pipeline: _config: config.PipelineConfig _nodes: dict[str, Node[Any]] _aliases: dict[str, Node[Any]] - _defaults: dict[str, Node[Any]] - _components: dict[str, PipelineFunction[Any] | Component[Any]] + _default: Node[Any] | None = None _hash: str | None = None - _last: Node[Any] | None = None - _anon_nodes: set[str] - "Track generated node names." def __init__(self, config: config.PipelineConfig, nodes: dict[str, Node[Any]]): self._config = config self._nodes = dict(nodes) self._aliases = {a: self.node(t) for (a, t) in config.aliases.items()} - self._defaults = {n: self.node(t) for (n, t) in config.defaults.items()} + if config.default: + self._default = self.node(config.default) + + @property + def config(self) -> PipelineConfig: + """ + Get the pipline configuration. + + .. important:: + + Do not modify the configuration returned, or it will become + out-of-sync with the pipeline and likely not behave correctly. + """ + return self._config @property def name(self) -> str | None: + """ + Get the pipeline name (if configured). + """ return self._config.meta.name @property def version(self) -> str | None: + """ + Get the pipeline version (if configured). + """ return self._config.meta.version - @property def meta(self) -> config.PipelineMeta: """ Get the metadata (name, version, hash, etc.) for this pipeline without @@ -92,7 +99,6 @@ def meta(self) -> config.PipelineMeta: """ return self._config.meta - @property def nodes(self) -> list[Node[object]]: """ Get the nodes in the pipeline graph. @@ -136,17 +142,6 @@ def node( else: raise KeyError(f"node {node}") - def component_configs(self) -> dict[str, dict[str, Any]]: - """ - Get the configurations for the components. This is the configurations - only, it does not include pipeline inputs or wiring. - """ - return { - name: comp.dump_config() - for (name, comp) in self._components.items() - if isinstance(comp, Component) - } - def clone(self, how: CloneMethod = "config") -> Pipeline: """ Clone the pipeline, optionally including trained parameters. @@ -171,107 +166,46 @@ def clone(self, how: CloneMethod = "config") -> Pipeline: A new pipeline with the same components and wiring, but fresh instances created by round-tripping the configuration. """ + from .builder import PipelineBuilder + if how == "pipeline-config": - cfg = self.get_config() - return self.from_config(cfg) + return self.from_config(self._config) elif how != "config": # pragma: nocover raise NotImplementedError("only 'config' cloning is currently supported") - clone = Pipeline() + clone = PipelineBuilder() - for node in self.nodes: + for node in self.nodes(): match node: case InputNode(name, types=types): if types is None: types = set[type]() clone.create_input(name, *types) case LiteralNode(name, value): - clone._nodes[name] = LiteralNode(name, value) - case ComponentNode(name, comp, _inputs, wiring): - if isinstance(comp, FunctionType): - comp = comp - elif isinstance(comp, Component): - comp = comp.__class__(comp.config) # type: ignore - else: - comp = comp.__class__() # type: ignore - cn = clone.add_component(node.name, comp) # type: ignore - for wn, wt in wiring.items(): - clone.connect(cn, **{wn: clone.node(wt)}) + clone.literal(value, name=name) + case ComponentInstanceNode(name, comp): + config = None + if isinstance(comp, Component): + config = comp.config + comp = comp.__class__ # type: ignore + clone.add_component(name, comp, config) # type: ignore case _: # pragma: nocover raise RuntimeError(f"invalid node {node}") for n, t in self._aliases.items(): clone.alias(n, t.name) - for n, t in self._defaults.items(): - clone.set_default(n, clone.node(t.name)) - - return clone - - def get_config(self, *, include_hash: bool = True) -> PipelineConfig: - """ - Get this pipeline's configuration for serialization. The configuration - consists of all inputs and components along with their configurations - and input connections. It can be serialized to disk (in JSON, YAML, or - a similar format) to save a pipeline. - - The configuration does **not** include any trained parameter values, - although the configuration may include things such as paths to - checkpoints to load such parameters, depending on the design of the - components in the pipeline. - - .. note:: - Literal nodes (from :meth:`literal`, or literal values wired to - inputs) cannot be serialized, and this method will fail if they - are present in the pipeline. - """ - meta = self.meta(include_hash=False) - cfg = PipelineConfig(meta=meta) - - # We map anonymous nodes to hash-based names for stability. If we ever - # allow anonymous components, this will need to be adjusted to maintain - # component ordering, but it works for now since only literals can be - # anonymous. First handle the anonymous nodes, so we have that mapping: - remapped: dict[str, str] = {} - for an in self._anon_nodes: - node = self._nodes.get(an, None) + for node in self.nodes(): match node: - case None: - # skip nodes that no longer exist - continue - case LiteralNode(name, value): - lit = config.PipelineLiteral.represent(value) - sname = str(uuid5(NAMESPACE_LITERAL_DATA, lit.model_dump_json())) - _log.debug("renamed anonymous node %s to %s", name, sname) - remapped[name] = sname - cfg.literals[sname] = lit + case ComponentNode(name, connections=cxns): + cn = clone.node(name) + clone.connect(cn, **{wt: clone.node(wn) for (wt, wn) in cxns.items()}) case _: - # the pipeline only generates anonymous literal nodes right now - raise RuntimeError(f"unexpected anonymous node {node}") + pass - # Now we go over all named nodes and add them to the config: - for node in self.nodes: - if node.name in remapped: - continue - - match node: - case InputNode(): - cfg.inputs.append(config.PipelineInput.from_node(node)) - case LiteralNode(name, value): - cfg.literals[name] = config.PipelineLiteral.represent(value) - case ComponentNode(name): - cfg.components[name] = config.PipelineComponent.from_node(node, remapped) - case _: # pragma: nocover - raise RuntimeError(f"invalid node {node}") - - cfg.aliases = {a: t.name for (a, t) in self._aliases.items()} - cfg.defaults = {n: t.name for (n, t) in self._defaults.items()} - - if include_hash: - cfg.meta.hash = config.hash_config(cfg) - - return cfg + return clone.build() + @property def config_hash(self) -> str: """ Get a hash of the pipeline's configuration to uniquely identify it for @@ -288,14 +222,11 @@ def config_hash(self) -> str: JSON serialization of the pipeline configuration *without* a hash and returning the hex-encoded SHA256 hash of that configuration. """ - if self._hash is None: - # get the config *without* a hash - cfg = self.get_config(include_hash=False) - self._hash = config.hash_config(cfg) - return self._hash - - @classmethod - def from_config(cls, config: object) -> Self: + assert self._config.meta.hash, "pipeline configuration has no hash" + return self._config.meta.hash + + @staticmethod + def from_config(config: object) -> Pipeline: """ Reconstruct a pipeline from a serialized configuration. @@ -314,56 +245,11 @@ def from_config(cls, config: object) -> Self: configuration includes a hash but the constructed pipeline does not have a matching hash. """ - cfg = PipelineConfig.model_validate(config) - pipe = cls() - for inpt in cfg.inputs: - types: list[type[Any] | None] = [] - if inpt.types is not None: - types += [parse_type_string(t) for t in inpt.types] - pipe.create_input(inpt.name, *types) - - # we now add the components and other nodes in multiple passes to ensure - # that nodes are available before they are wired (since `connect` can - # introduce out-of-order dependencies). - - # pass 1: add literals - for name, data in cfg.literals.items(): - pipe.literal(data.decode(), name=name) - - # pass 2: add components - to_wire: list[config.PipelineComponent] = [] - for name, comp in cfg.components.items(): - if comp.code.startswith("@"): - # ignore special nodes in first pass - continue - - obj = instantiate_component(comp.code, comp.config) - pipe.add_component(name, obj) - to_wire.append(comp) - - # pass 3: wiring - for name, comp in cfg.components.items(): - if isinstance(comp.inputs, dict): - inputs = {n: pipe.node(t) for (n, t) in comp.inputs.items()} - pipe.connect(name, **inputs) - elif not comp.code.startswith("@"): - raise PipelineError(f"component {name} inputs must be dict, not list") - - # pass 4: aliases - for n, t in cfg.aliases.items(): - pipe.alias(n, t) - - # pass 5: defaults - for n, t in cfg.defaults.items(): - pipe.set_default(n, pipe.node(t)) - - if cfg.meta.hash is not None: - h2 = pipe.config_hash() - if h2 != cfg.meta.hash: - _log.warning("loaded pipeline does not match hash") - warnings.warn("loaded pipeline config does not match hash", PipelineWarning) - - return pipe + from .builder import PipelineBuilder + + config = PipelineConfig.model_validate(config) + builder = PipelineBuilder.from_config(config) + return builder.build() def train(self, data: Dataset, options: TrainingOptions | None = None) -> None: """ @@ -397,46 +283,51 @@ def train(self, data: Dataset, options: TrainingOptions | None = None) -> None: seed = SeedSequence(options.rng) log.info("training pipeline components") - for name, comp in self._components.items(): - clog = log.bind(name=name, component=comp) - if isinstance(comp, Trainable): - # spawn new seed if needed - c_opts = options if seed is None else replace(options, rng=seed.spawn(1)[0]) - clog.info("training component") - comp.train(data, c_opts) - else: - clog.debug("training not required") + for node in self.nodes(): + match node: + case ComponentInstanceNode(name, comp): + clog = log.bind(name=name, component=comp) + if isinstance(comp, Trainable): + # spawn new seed if needed + c_opts = options if seed is None else replace(options, rng=seed.spawn(1)[0]) + clog.info("training component") + comp.train(data, c_opts) + else: + clog.debug("training not required") + case _: + pass @overload def run(self, /, **kwargs: object) -> object: ... @overload def run(self, node: str, /, **kwargs: object) -> object: ... @overload - def run(self, n1: str, n2: str, /, *nrest: str, **kwargs: object) -> tuple[object]: ... + def run(self, nodes: tuple[str, ...], /, **kwargs: object) -> tuple[object, ...]: ... @overload def run(self, node: Node[T], /, **kwargs: object) -> T: ... @overload - def run(self, n1: Node[T1], n2: Node[T2], /, **kwargs: object) -> tuple[T1, T2]: ... + def run(self, nodes: tuple[Node[T1], Node[T2]], /, **kwargs: object) -> tuple[T1, T2]: ... @overload def run( - self, n1: Node[T1], n2: Node[T2], n3: Node[T3], /, **kwargs: object + self, nodes: tuple[Node[T1], Node[T2], Node[T3]], /, **kwargs: object ) -> tuple[T1, T2, T3]: ... @overload def run( - self, n1: Node[T1], n2: Node[T2], n3: Node[T3], n4: Node[T4], /, **kwargs: object + self, nodes: tuple[Node[T1], Node[T2], Node[T3], Node[T4]], /, **kwargs: object ) -> tuple[T1, T2, T3, T4]: ... @overload def run( self, - n1: Node[T1], - n2: Node[T2], - n3: Node[T3], - n4: Node[T4], - n5: Node[T5], + nodes: tuple[Node[T1], Node[T2], Node[T3], Node[T4], Node[T5]], /, **kwargs: object, ) -> tuple[T1, T2, T3, T4, T5]: ... - def run(self, *nodes: str | Node[Any], **kwargs: object) -> object: + def run( + self, + nodes: str | Node[Any] | tuple[str, ...] | tuple[Node[Any], ...] | None = None, + /, + **kwargs: object, + ) -> object: """ Run the pipeline and obtain the return value(s) of one or more of its components. See :ref:`pipeline-execution` for details of the pipeline @@ -449,9 +340,10 @@ def run(self, *nodes: str | Node[Any], **kwargs: object) -> object: The pipeline's inputs, as defined with :meth:`create_input`. Returns: - The pipeline result. If zero or one nodes are specified, the result - is returned as-is. If multiple nodes are specified, their results - are returned in a tuple. + The pipeline result. If no nodes are supplied, this is the result + of the default node. If a single node is supplied, it is the result + of that node. If a tuple of nodes is supplied, it is a tuple of + their results. Raises: PipelineError: @@ -463,14 +355,20 @@ def run(self, *nodes: str | Node[Any], **kwargs: object) -> object: other: exceptions thrown by components are passed through. """ - if not nodes: - if self._last is None: # pragma: nocover - raise PipelineError("pipeline has no components") - nodes = (self._last,) - state = self.run_all(*nodes, **kwargs) - results = [state[self.node(n).name] for n in nodes] - - if len(results) > 1: + if nodes is None: + if self._default: + node_list = [self._default] + else: + raise RuntimeError("no node specified and pipeline has no default") + elif isinstance(nodes, str) or isinstance(nodes, Node): + node_list = [nodes] + else: + node_list = nodes + + state = self.run_all(*node_list, **kwargs) + results = [state[self.node(n).name] for n in node_list] + + if node_list is nodes: return tuple(results) else: return results[0] @@ -485,10 +383,10 @@ def run_all(self, *nodes: str | Node[Any], **kwargs: object) -> PipelineState: 1. It returns the data from all nodes as a mapping (dictionary-like object), not just the specified nodes as a tuple. - 2. If no nodes are specified, it runs *all* nodes instead of only the - last node. This has the consequence of running nodes that are not - required to fulfill the last node (such scenarios typically result - from using :meth:`use_first_of`). + 2. If no nodes are specified, it runs *all* nodes. This has the + consequence of running nodes that are not required to fulfill the + last node (such scenarios typically result from using + :meth:`use_first_of`). Args: nodes: @@ -499,8 +397,7 @@ def run_all(self, *nodes: str | Node[Any], **kwargs: object) -> PipelineState: Returns: The full pipeline state, with :attr:`~PipelineState.default` set to - the last node specified (either the last node in `nodes`, or the - last node added to the pipeline). + the last node specified. """ from .runner import PipelineRunner @@ -508,7 +405,7 @@ def run_all(self, *nodes: str | Node[Any], **kwargs: object) -> PipelineState: node_list = [self.node(n) for n in nodes] _log.debug("running pipeline", name=self.name, nodes=[n.name for n in node_list]) if not node_list: - node_list = self.nodes + node_list = self.nodes() last = None for node in node_list: @@ -522,15 +419,7 @@ def run_all(self, *nodes: str | Node[Any], **kwargs: object) -> PipelineState: meta=self.meta(), ) - def _check_available_name(self, name: str) -> None: - if name in self._nodes or name in self._aliases: - raise ValueError(f"pipeline already has node {name}") - def _check_member_node(self, node: Node[Any]) -> None: nw = self._nodes.get(node.name) if nw is not node: raise PipelineError(f"node {node} not in pipeline") - - def _clear_caches(self): - if "_hash" in self.__dict__: - del self._hash diff --git a/lenskit/lenskit/pipeline/builder.py b/lenskit/lenskit/pipeline/builder.py index 8a68fb121..d2368603f 100644 --- a/lenskit/lenskit/pipeline/builder.py +++ b/lenskit/lenskit/pipeline/builder.py @@ -84,9 +84,9 @@ class PipelineBuilder: _nodes: dict[str, Node[Any]] _aliases: dict[str, Node[Any]] - _defaults: dict[str, Node[Any]] + _default_connections: dict[str, Node[Any]] _components: dict[str, PipelineFunction[Any] | Component[Any]] - _last: Node[Any] | None = None + _default: str | None = None _anon_nodes: set[str] "Track generated node names." @@ -95,7 +95,7 @@ def __init__(self, name: str | None = None, version: str | None = None): self.version = version self._nodes = {} self._aliases = {} - self._defaults = {} + self._default_connections = {} self._components = {} self._anon_nodes = set() @@ -113,7 +113,6 @@ def meta(self, *, include_hash: bool = True) -> config.PipelineMeta: meta.hash = self.config_hash() return meta - @property def nodes(self) -> list[Node[object]]: """ Get the nodes in the pipeline graph. @@ -208,7 +207,7 @@ def literal(self, value: T, *, name: str | None = None) -> LiteralNode[T]: self._nodes[name] = node return node - def set_default(self, name: str, node: Node[Any] | object) -> None: + def default_connection(self, name: str, node: Node[Any] | object) -> None: """ Set the default wiring for a component input. Components that declare an input parameter with the specified ``name`` but no configured input @@ -217,6 +216,12 @@ def set_default(self, name: str, node: Node[Any] | object) -> None: This is intended to be used for things like wiring up `user` parameters to semi-automatically receive the target user's identity and history. + .. important:: + + Defaults are a feature of the builder only, and are resolved in + :meth:`build`. They are not included in serialized configuration or + resulting pipeline. + Args: name: The name of the parameter to set a default for. @@ -225,13 +230,15 @@ def set_default(self, name: str, node: Node[Any] | object) -> None: """ if not isinstance(node, Node): node = self.literal(node) - self._defaults[name] = node + self._default_connections[name] = node - def get_default(self, name: str) -> Node[Any] | None: + def default_component(self, node: str | Node[Any]) -> None: """ - Get the default wiring for an input name. + Set the default node for the pipeline. If :meth:`Pipeline.run` is + called without a node, then it will run this node (and all of its + dependencies). """ - return self._defaults.get(name, None) + self._default = node.name if isinstance(node, Node) else node def alias(self, alias: str, node: Node[Any] | str) -> None: """ @@ -299,16 +306,11 @@ def add_component( """ self._check_available_name(name) - if isinstance(comp, ComponentConstructor): - node = ComponentConstructorNode(name, comp, config) - else: - node = ComponentInstanceNode(name, comp) - + node = ComponentNode[ND].create(name, comp, config) self._nodes[name] = node self.connect(node, **inputs) - self._last = node return node @overload @@ -346,11 +348,7 @@ def replace_component( if isinstance(name, Node): name = name.name - if isinstance(comp, ComponentConstructor): - node = ComponentConstructorNode(name, comp, config) - else: - node = ComponentInstanceNode(name, comp) - + node = ComponentNode[ND].create(name, comp, config) self._nodes[name] = node self.connect(node, **inputs) @@ -387,17 +385,6 @@ def connect(self, obj: str | Node[Any], **inputs: Node[Any] | str | object): lit = self.literal(n) node.connections[k] = lit.name - def component_configs(self) -> dict[str, dict[str, Any]]: - """ - Get the configurations for the components. This is the configurations - only, it does not include pipeline inputs or wiring. - """ - return { - name: comp.dump_config() - for (name, comp) in self._components.items() - if isinstance(comp, Component) - } - def clone(self) -> PipelineBuilder: """ Clone the pipeline builder. The resulting builder starts as a copy of @@ -407,7 +394,7 @@ def clone(self) -> PipelineBuilder: clone = PipelineBuilder() - for node in self.nodes: + for node in self.nodes(): match node: case InputNode(name, types=types): if types is None: @@ -425,21 +412,20 @@ def clone(self) -> PipelineBuilder: for n, t in self._aliases.items(): clone.alias(n, t.name) - for node in self.nodes: + for node in self.nodes(): match node: case ComponentNode(name, connections=wiring): cn = clone.node(name) - for wn, wt in wiring.items(): - clone.connect(cn, **{wn: clone.node(wt)}) + clone.connect(cn, **{wn: clone.node(wt) for (wn, wt) in wiring.items()}) case _: pass - for n, t in self._defaults.items(): - clone.set_default(n, clone.node(t.name)) + for n, t in self._default_connections.items(): + clone.default_connection(n, clone.node(t.name)) return clone - def get_config(self, *, include_hash: bool = True) -> PipelineConfig: + def build_config(self, *, include_hash: bool = True) -> PipelineConfig: """ Get this pipeline's configuration for serialization. The configuration consists of all inputs and components along with their configurations @@ -481,7 +467,7 @@ def get_config(self, *, include_hash: bool = True) -> PipelineConfig: raise RuntimeError(f"unexpected anonymous node {node}") # Now we go over all named nodes and add them to the config: - for node in self.nodes: + for node in self.nodes(): if node.name in remapped: continue @@ -496,7 +482,9 @@ def get_config(self, *, include_hash: bool = True) -> PipelineConfig: raise RuntimeError(f"invalid node {node}") cfg.aliases = {a: t.name for (a, t) in self._aliases.items()} - cfg.defaults = {n: t.name for (n, t) in self._defaults.items()} + + if self._default: + cfg.default = self._default if include_hash: cfg.meta.hash = config.hash_config(cfg) @@ -508,22 +496,21 @@ def config_hash(self) -> str: Get a hash of the pipeline's configuration to uniquely identify it for logging, version control, or other purposes. - The hash format and algorithm are not guaranteed, but is stable within a - LensKit version. For the same version of LensKit and component code, - the same configuration will produce the same hash, so long as there are - no literal nodes. Literal nodes will *usually* hash consistently, but - since literals other than basic JSON values are hashed by pickling, hash - stability depends on the stability of the pickle bytestream. + The hash format and algorithm are not guaranteed, but hashes are stable + within a LensKit version. For the same version of LensKit and component + code, the same configuration will produce the same hash, so long as + there are no literal nodes. Literal nodes will *usually* hash + consistently, but since literals other than basic JSON values are hashed + by pickling, hash stability depends on the stability of the pickle + bytestream. In LensKit 2025.1, the configuration hash is computed by computing the JSON serialization of the pipeline configuration *without* a hash and returning the hex-encoded SHA256 hash of that configuration. """ - if self._hash is None: - # get the config *without* a hash - cfg = self.get_config(include_hash=False) - self._hash = config.hash_config(cfg) - return self._hash + + cfg = self.build_config(include_hash=False) + return config.hash_config(cfg) @classmethod def from_config(cls, config: object) -> Self: @@ -584,10 +571,6 @@ def from_config(cls, config: object) -> Self: for n, t in cfg.aliases.items(): pipe.alias(n, t) - # pass 5: defaults - for n, t in cfg.defaults.items(): - pipe.set_default(n, pipe.node(t)) - if cfg.meta.hash is not None: h2 = pipe.config_hash() if h2 != cfg.meta.hash: @@ -652,7 +635,7 @@ def build(self) -> Pipeline: """ Build the pipeline. """ - config = self.get_config() + config = self.build_config() return Pipeline(config, self._nodes) def _instantiate(self, node: Node[ND]) -> Node[ND]: diff --git a/lenskit/lenskit/pipeline/components.py b/lenskit/lenskit/pipeline/components.py index ab53d67bd..5ea4c46a5 100644 --- a/lenskit/lenskit/pipeline/components.py +++ b/lenskit/lenskit/pipeline/components.py @@ -35,7 +35,7 @@ P = ParamSpec("P") T = TypeVar("T") -CFG = TypeVar("CFG", contravariant=True) +CFG = TypeVar("CFG") CArgs = ParamSpec("CArgs", default=...) """ Argument type for a component. It is difficult to actually specify this, but @@ -61,6 +61,8 @@ class ComponentConstructor(ABC, Generic[CFG, COut]): def __call__(self, config: CFG | None = None) -> Component[COut]: ... + def config_class(self) -> type[CFG] | None: ... + def __isinstance__(self, obj: Any) -> bool: # FIXME: implement a more rigorous check for this return isinstance(obj, type) and issubclass(obj, Component) @@ -160,7 +162,7 @@ class MyComponent(Component): def __init_subclass__(cls, **kwargs: Any): super().__init_subclass__(**kwargs) if not isabstract(cls): - ct = cls._config_class(return_any=True) + ct = cls.config_class(return_any=True) if ct == Any: warnings.warn( "component class {} does not define a config attribute".format( @@ -175,14 +177,14 @@ def __init__(self, config: object | None = None, **kwargs: Any): elif kwargs: raise RuntimeError("cannot supply both a configuration object and kwargs") - cfg_cls = self._config_class() + cfg_cls = self.config_class() if cfg_cls and not isinstance(config, cfg_cls): raise TypeError(f"invalid configuration type {type(config)}") self.config = config @classmethod - def _config_class(cls, return_any: bool = False) -> type | None: + def config_class(cls, return_any: bool = False) -> type | None: hints = get_type_hints(cls) ct = hints.get("config", None) if ct == NoneType: @@ -202,7 +204,7 @@ def dump_config(self) -> dict[str, JsonValue]: """ Dump the configuration to JSON-serializable format. """ - cfg_cls = self._config_class() + cfg_cls = self.config_class() if cfg_cls: return TypeAdapter(cfg_cls).dump_python(self.config, mode="json") # type: ignore else: @@ -215,7 +217,7 @@ def validate_config(cls, data: Mapping[str, JsonValue] | None = None) -> object """ if data is None: data = {} - cfg_cls = cls._config_class() + cfg_cls = cls.config_class() if cfg_cls: return TypeAdapter(cfg_cls).validate_python(data) # type: ignore elif data: # pragma: nocover diff --git a/lenskit/lenskit/pipeline/config.py b/lenskit/lenskit/pipeline/config.py index 7b255da2e..187c9d8a3 100644 --- a/lenskit/lenskit/pipeline/config.py +++ b/lenskit/lenskit/pipeline/config.py @@ -18,11 +18,11 @@ from types import FunctionType from typing import Literal, Mapping -from pydantic import BaseModel, Field, JsonValue, ValidationError +from pydantic import BaseModel, Field, JsonValue, TypeAdapter, ValidationError from typing_extensions import Any, Optional, Self from .components import Component -from .nodes import ComponentNode, InputNode +from .nodes import ComponentConstructorNode, ComponentInstanceNode, ComponentNode, InputNode from .types import type_string @@ -40,12 +40,12 @@ class PipelineConfig(BaseModel): "Pipeline metadata." inputs: list[PipelineInput] = Field(default_factory=list) "Pipeline inputs." - defaults: dict[str, str] = Field(default_factory=dict) - "Default pipeline wirings." components: OrderedDict[str, PipelineComponent] = Field(default_factory=OrderedDict) "Pipeline components, with their configurations and wiring." aliases: dict[str, str] = Field(default_factory=dict) "Pipeline node aliases." + default: str | None = None + "The default node for running this pipeline." literals: dict[str, PipelineLiteral] = Field(default_factory=dict) "Literals" @@ -120,16 +120,22 @@ def from_node(cls, node: ComponentNode[Any], mapping: dict[str, str] | None = No if mapping is None: mapping = {} - comp = node.component - if isinstance(comp, FunctionType): - ctype = comp - else: - ctype = comp.__class__ + match node: + case ComponentInstanceNode(_name, comp): + config = None + if isinstance(comp, FunctionType): + ctype = comp + else: + ctype = comp.__class__ + if isinstance(comp, Component): + config = comp.dump_config() + case ComponentConstructorNode(_name, ctype, config): + config = TypeAdapter[Any](ctype.config_class()).dump_python(config, mode="json") + case _: + raise TypeError("unexpected node type") code = f"{ctype.__module__}:{ctype.__qualname__}" - config = comp.dump_config() if isinstance(comp, Component) else None - return cls( code=code, config=config, diff --git a/lenskit/lenskit/pipeline/nodes.py b/lenskit/lenskit/pipeline/nodes.py index 91e7e5971..251158c40 100644 --- a/lenskit/lenskit/pipeline/nodes.py +++ b/lenskit/lenskit/pipeline/nodes.py @@ -5,10 +5,11 @@ # SPDX-License-Identifier: MIT # pyright: strict +from __future__ import annotations import warnings from inspect import Signature, signature -from typing import Any, Callable +from typing import Any, Callable, cast from typing_extensions import Generic, TypeVar @@ -91,6 +92,18 @@ def __init__(self, name: str): super().__init__(name) self.connections = {} + @staticmethod + def create( + name: str, + comp: ComponentConstructor[CFG, ND] | Component[ND] | PipelineFunction[ND], + config: CFG | None = None, + ) -> ComponentNode[ND]: + if isinstance(comp, ComponentConstructor): + comp = cast(ComponentConstructor[CFG, ND], comp) + return ComponentConstructorNode(name, comp, config) + else: + return ComponentInstanceNode(name, comp) + def _setup_signature(self, component: Callable[..., ND]): sig = signature(component, eval_str=True) if sig.return_annotation == Signature.empty: @@ -118,7 +131,7 @@ class ComponentConstructorNode(ComponentNode[ND], Generic[ND]): constructor: ComponentConstructor[Any, ND] config: object | None - def __init__(self, name: str, constructor: ComponentConstructor[CFG, ND], config: CFG): + def __init__(self, name: str, constructor: ComponentConstructor[CFG, ND], config: CFG | None): self.constructor = constructor self.config = config diff --git a/lenskit/lenskit/pipeline/runner.py b/lenskit/lenskit/pipeline/runner.py index b371ffbb3..e12fac779 100644 --- a/lenskit/lenskit/pipeline/runner.py +++ b/lenskit/lenskit/pipeline/runner.py @@ -19,7 +19,7 @@ from ._impl import Pipeline from .components import PipelineFunction -from .nodes import ComponentNode, InputNode, LiteralNode, Node +from .nodes import ComponentInstanceNode, InputNode, LiteralNode, Node from .types import Lazy, is_compatible_data _log = get_logger(__name__) @@ -48,7 +48,7 @@ def __init__(self, pipe: Pipeline, inputs: dict[str, Any]): self.log = _log.bind(pipeline=pipe.name) self.pipe = pipe self.inputs = inputs - self.status = {n.name: "pending" for n in pipe.nodes} + self.status = {n.name: "pending" for n in pipe.nodes()} self.state = {} def run(self, node: Node[Any], *, required: bool = True) -> Any: @@ -88,7 +88,7 @@ def _run_node(self, node: Node[Any], required: bool) -> None: self.state[name] = value case InputNode(name, types=types): self._inject_input(name, types, required) - case ComponentNode(name, comp, inputs, wiring): + case ComponentInstanceNode(name, comp, inputs, wiring): self._run_component(name, comp, inputs, wiring, required) case _: # pragma: nocover raise PipelineError(f"invalid node {node}") diff --git a/lenskit/lenskit/testing/_components.py b/lenskit/lenskit/testing/_components.py index 546c2535f..43107ca4c 100644 --- a/lenskit/lenskit/testing/_components.py +++ b/lenskit/lenskit/testing/_components.py @@ -25,7 +25,7 @@ def test_instantiate_default(self): inst = self.component() assert inst is not None - if self.component._config_class() is not None: + if self.component.config_class() is not None: assert inst.config is not None else: assert inst.config is None diff --git a/lenskit/tests/pipeline/test_component_config.py b/lenskit/tests/pipeline/test_component_config.py index 230847ef5..3c85af4fa 100644 --- a/lenskit/tests/pipeline/test_component_config.py +++ b/lenskit/tests/pipeline/test_component_config.py @@ -61,7 +61,7 @@ def __call__(self, msg: str) -> str: @mark.parametrize("prefixer", [PrefixerDC, PrefixerM, PrefixerPYDC, PrefixerM2]) def test_config_setup(prefixer: type[Component]): - ccls = prefixer._config_class() # type: ignore + ccls = prefixer.config_class() # type: ignore assert ccls is not None comp = prefixer() @@ -92,13 +92,14 @@ def test_pipeline_config(prefixer: ComponentConstructor[Any, str]): pipe.add_component("prefix", prefixer, {"prefix": "scroll named "}, msg=msg) pipe = pipe.build() - assert pipe.run(msg="FOOBIE BLETCH") == "scroll named FOOBIE BLETCH" + assert pipe.run("prefix", msg="FOOBIE BLETCH") == "scroll named FOOBIE BLETCH" - config = pipe.component_configs() + config = pipe.config.components print(json.dumps(config, indent=2)) assert "prefix" in config - assert config["prefix"]["prefix"] == "scroll named " + assert config["prefix"].config + assert config["prefix"].config["prefix"] == "scroll named " @mark.parametrize("prefixer", [PrefixerDC, PrefixerM, PrefixerPYDC]) @@ -109,9 +110,9 @@ def test_pipeline_config_roundtrip(prefixer: type[Component]): msg = pipe.create_input("msg", str) pipe.add_component("prefix", comp, msg=msg) - assert pipe.build().run(msg="FOOBIE BLETCH") == "scroll named FOOBIE BLETCH" + assert pipe.build().run("prefix", msg="FOOBIE BLETCH") == "scroll named FOOBIE BLETCH" - config = pipe.get_config() + config = pipe.build_config() print(config.model_dump_json(indent=2)) p2 = PipelineBuilder.from_config(config) diff --git a/lenskit/tests/pipeline/test_pipeline.py b/lenskit/tests/pipeline/test_pipeline.py index f42b92def..bac071212 100644 --- a/lenskit/tests/pipeline/test_pipeline.py +++ b/lenskit/tests/pipeline/test_pipeline.py @@ -19,7 +19,7 @@ def test_init_empty(): pipe = PipelineBuilder() - assert len(pipe.nodes) == 0 + assert len(pipe.nodes()) == 0 def test_create_input(): @@ -31,7 +31,7 @@ def test_create_input(): assert src.name == "user" assert src.types == set([int, str]) - assert len(pipe.nodes) == 1 + assert len(pipe.nodes()) == 1 assert pipe.node("user") is src @@ -271,6 +271,7 @@ def add(x: int, y: int) -> int: nd = pipe.add_component("double", double, x=a) na = pipe.add_component("add", add, x=nd, y=b) + pipe.default_component(na) nt = pipe.replace_component("double", triple, x=a) @@ -298,7 +299,7 @@ def double(x: int) -> int: def add(x: int, y: int) -> int: return x + y - pipe.set_default("y", b) + pipe.default_connection("y", b) nd = pipe.add_component("double", double, x=a) na = pipe.add_component("add", add, x=nd) @@ -326,6 +327,48 @@ def add(x: int, y: int) -> int: assert pipe.run("double", a=1, b=7) == 2 +def test_run_tuple_name(): + pipe = PipelineBuilder() + a = pipe.create_input("a", int) + b = pipe.create_input("b", int) + + def double(x: int) -> int: + return x * 2 + + def add(x: int, y: int) -> int: + return x + y + + nd = pipe.add_component("double", double, x=a) + pipe.add_component("add", add, x=nd, y=b) + + pipe = pipe.build() + res = pipe.run(("double",), a=1, b=7) + assert isinstance(res, tuple) + assert res[0] == 2 + + +def test_run_tuple_pair(): + pipe = PipelineBuilder() + a = pipe.create_input("a", int) + b = pipe.create_input("b", int) + + def double(x: int) -> int: + return x * 2 + + def add(x: int, y: int) -> int: + return x + y + + nd = pipe.add_component("double", double, x=a) + pipe.add_component("add", add, x=nd, y=b) + + pipe = pipe.build() + res = pipe.run(("double", "add"), a=1, b=7) + assert isinstance(res, tuple) + d, a = res + assert d == 2 + assert a == 9 + + def test_invalid_type(): pipe = PipelineBuilder() a = pipe.create_input("a", int) @@ -475,8 +518,7 @@ def add(x: int, y: int) -> int: def test_pipeline_component_default(): """ - Test that the last *component* is last. It also exercises the warning logic - for missing component types. + Test that the default component is run correctly. """ pipe = PipelineBuilder() a = pipe.create_input("a", int) @@ -486,6 +528,10 @@ def add(x, y): # type: ignore with warns(TypecheckWarning): pipe.add_component("add", add, x=np.arange(10), y=a) # type: ignore + pipe.default_component("add") + + cfg = pipe.build_config() + assert cfg.default == "add" pipe = pipe.build() # the component runs diff --git a/lenskit/tests/pipeline/test_pipeline_clone.py b/lenskit/tests/pipeline/test_pipeline_clone.py index 01378f348..8e810cdba 100644 --- a/lenskit/tests/pipeline/test_pipeline_clone.py +++ b/lenskit/tests/pipeline/test_pipeline_clone.py @@ -41,6 +41,7 @@ def test_pipeline_clone(): pipe = PipelineBuilder() msg = pipe.create_input("msg", str) pipe.add_component("prefix", comp, msg=msg) + pipe.default_component("prefix") pipe = pipe.build() assert pipe.run(msg="FOOBIE BLETCH") == "scroll named FOOBIE BLETCH" @@ -62,6 +63,7 @@ def test_pipeline_clone_with_function(): msg = pipe.create_input("msg", str) pfx = pipe.add_component("prefix", comp, msg=msg) pipe.add_component("exclaim", exclaim, msg=pfx) + pipe.default_component("prefix") pipe = pipe.build() assert pipe.run(msg="FOOBIE BLETCH") == "scroll named FOOBIE BLETCH!" @@ -78,6 +80,7 @@ def test_pipeline_clone_with_nonconfig_class(): msg = pipe.create_input("msg", str) pfx = pipe.add_component("prefix", comp, msg=msg) pipe.add_component("question", Question(), msg=pfx) + pipe.default_component("prefix") pipe = pipe.build() assert pipe.run(msg="FOOBIE BLETCH") == "scroll named FOOBIE BLETCH?" @@ -90,8 +93,9 @@ def test_pipeline_clone_with_nonconfig_class(): def test_clone_defaults(): pipe = PipelineBuilder() msg = pipe.create_input("msg", str) - pipe.set_default("msg", msg) + pipe.default_connection("msg", msg) pipe.add_component("return", exclaim) + pipe.default_component("prefix") pipe = pipe.build() assert pipe.run(msg="hello") == "hello!" @@ -118,7 +122,7 @@ def test_clone_alias(): def test_clone_hash(): pipe = PipelineBuilder() msg = pipe.create_input("msg", str) - pipe.set_default("msg", msg) + pipe.default_connection("msg", msg) excl = pipe.add_component("exclaim", exclaim) pipe.alias("return", excl) @@ -128,4 +132,4 @@ def test_clone_hash(): p2 = pipe.clone() assert p2.run("return", msg="hello") == "hello!" - assert p2.config_hash() == pipe.config_hash() + assert p2.config_hash == pipe.config_hash diff --git a/lenskit/tests/pipeline/test_save_load.py b/lenskit/tests/pipeline/test_save_load.py index 99704256e..9355e2225 100644 --- a/lenskit/tests/pipeline/test_save_load.py +++ b/lenskit/tests/pipeline/test_save_load.py @@ -63,7 +63,7 @@ def test_serialize_input(): pipe = PipelineBuilder("test") pipe.create_input("user", int, str) - cfg = pipe.get_config() + cfg = pipe.build_config() print(cfg) assert cfg.meta.name == "test" assert len(cfg.inputs) == 1 @@ -76,7 +76,7 @@ def test_round_trip_input(): pipe = PipelineBuilder() pipe.create_input("user", int, str) - cfg = pipe.get_config() + cfg = pipe.build_config() print(cfg) p2 = PipelineBuilder.from_config(cfg) @@ -91,7 +91,7 @@ def test_round_trip_optional_input(): pipe = PipelineBuilder() pipe.create_input("user", int, str, None) - cfg = pipe.get_config() + cfg = pipe.build_config() assert cfg.inputs[0].types == {"int", "str", "None"} p2 = PipelineBuilder.from_config(cfg) @@ -107,7 +107,7 @@ def test_config_single_node(): pipe.add_component("return", msg_ident, msg=msg) - cfg = pipe.get_config() + cfg = pipe.build_config() assert len(cfg.inputs) == 1 assert len(cfg.components) == 1 @@ -124,10 +124,10 @@ def test_round_trip_single_node(): pipe.add_component("return", msg_ident, msg=msg) - cfg = pipe.get_config() + cfg = pipe.build_config() p2 = PipelineBuilder.from_config(cfg) - assert len(p2.nodes) == 2 + assert len(p2.nodes()) == 2 r2 = p2.node("return") assert isinstance(r2, ComponentNode) assert r2.component is msg_ident @@ -144,11 +144,11 @@ def test_configurable_component(): pfx = Prefixer(prefix="scroll named ") pipe.add_component("prefix", pfx, msg=msg) - cfg = pipe.get_config() + cfg = pipe.build_config() assert cfg.components["prefix"].config == {"prefix": "scroll named "} p2 = PipelineBuilder.from_config(cfg) - assert len(p2.nodes) == 2 + assert len(p2.nodes()) == 2 r2 = p2.node("prefix") assert isinstance(r2, ComponentNode) assert isinstance(r2.component, Prefixer) @@ -160,16 +160,16 @@ def test_configurable_component(): print("hash:", pipe.config_hash()) assert pipe.config_hash() is not None - assert p2.config_hash() == pipe.config_hash() + assert p2.config_hash == pipe.config_hash def test_save_defaults(): pipe = PipelineBuilder() msg = pipe.create_input("msg", str) - pipe.set_default("msg", msg) + pipe.default_connection("msg", msg) pipe.add_component("return", msg_ident) - cfg = pipe.get_config() + cfg = pipe.build_config() pipe = pipe.build() assert pipe.run(msg="hello") == "hello" @@ -199,7 +199,7 @@ def test_hashes_different(): _log.info("p1 stage 2 hash: %s", p1.config_hash()) _log.info("p2 stage 2 hash: %s", p2.config_hash()) assert p1.config_hash() != p2.config_hash() - assert p1.build().config_hash() != p2.build().config_hash() + assert p1.build().config_hash != p2.build().config_hash def test_save_with_fallback(): @@ -212,7 +212,7 @@ def test_save_with_fallback(): fb = pipe.use_first_of("fill-operand", b, nn) pipe.add_component("add", add, x=nd, y=fb) - cfg = pipe.get_config() + cfg = pipe.build_config() json = cfg.model_dump_json(exclude_none=True) print(json) c2 = PipelineConfig.model_validate_json(json) @@ -221,7 +221,7 @@ def test_save_with_fallback(): p2 = p2.build() # 3 * 2 + -3 = 3 - assert p2.run("fill-operand", "add", a=3) == (-3, 3) + assert p2.run(("fill-operand", "add"), a=3) == (-3, 3) def test_hash_validate(): @@ -231,7 +231,7 @@ def test_hash_validate(): pfx = Prefixer(prefix="scroll named ") pipe.add_component("prefix", pfx, msg=msg) - cfg = pipe.get_config() + cfg = pipe.build_config() print("initial config:", cfg.model_dump_json(indent=2)) assert cfg.meta.hash is not None cfg.components["prefix"].config["prefix"] = "scroll called " # type: ignore @@ -248,7 +248,7 @@ def test_alias_input(): pipe.alias("person", user) - cfg = pipe.get_config() + cfg = pipe.build_config() p2 = PipelineBuilder.from_config(cfg) p2 = p2.build() @@ -280,7 +280,7 @@ def test_literal(): pipe = pipe.build() assert pipe.run(msg="HACKEM MUCHE") == "hello, HACKEM MUCHE" - print(pipe.get_config().model_dump_json(indent=2)) + print(pipe.config.model_dump_json(indent=2)) p2 = pipe.clone("pipeline-config") assert p2.run(msg="FOOBIE BLETCH") == "hello, FOOBIE BLETCH" @@ -295,7 +295,7 @@ def test_literal_array(): res = pipe.run(a=5) assert np.all(res == np.arange(5, 15)) - print(pipe.get_config().model_dump_json(indent=2)) + print(pipe.config.model_dump_json(indent=2)) p2 = pipe.clone("pipeline-config") assert np.all(p2.run(a=5) == np.arange(5, 15)) @@ -311,4 +311,4 @@ def test_stable_with_literals(): p2.add_component("add", add, x=np.arange(10), y=a) assert p1.config_hash() == p2.config_hash() - assert p1.build().config_hash() == p2.build().config_hash() + assert p1.build().config_hash == p2.build().config_hash diff --git a/lenskit/tests/pipeline/test_train.py b/lenskit/tests/pipeline/test_train.py index 04fa7bdf4..5ee4e0fc8 100644 --- a/lenskit/tests/pipeline/test_train.py +++ b/lenskit/tests/pipeline/test_train.py @@ -18,6 +18,7 @@ def test_train(ml_ds: Dataset): tc: Trainable = TestComponent() pipe.add_component("test", tc, item=item) + pipe.default_component("test") pipe = pipe.build() pipe.train(ml_ds) From 8b1e70adfa957430bfba548e79524acfaca200f4 Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Wed, 15 Jan 2025 18:32:44 -0500 Subject: [PATCH 08/37] deal with several default component problems --- lenskit/lenskit/pipeline/_impl.py | 3 +++ lenskit/tests/pipeline/test_component_config.py | 1 + lenskit/tests/pipeline/test_pipeline.py | 4 +++- lenskit/tests/pipeline/test_save_load.py | 6 ++++-- 4 files changed, 11 insertions(+), 3 deletions(-) diff --git a/lenskit/lenskit/pipeline/_impl.py b/lenskit/lenskit/pipeline/_impl.py index e8849fbf8..608f97ae8 100644 --- a/lenskit/lenskit/pipeline/_impl.py +++ b/lenskit/lenskit/pipeline/_impl.py @@ -203,6 +203,9 @@ def clone(self, how: CloneMethod = "config") -> Pipeline: case _: pass + if self._default: + clone.default_component(self._default.name) + return clone.build() @property diff --git a/lenskit/tests/pipeline/test_component_config.py b/lenskit/tests/pipeline/test_component_config.py index 3c85af4fa..a275ca456 100644 --- a/lenskit/tests/pipeline/test_component_config.py +++ b/lenskit/tests/pipeline/test_component_config.py @@ -109,6 +109,7 @@ def test_pipeline_config_roundtrip(prefixer: type[Component]): pipe = PipelineBuilder() msg = pipe.create_input("msg", str) pipe.add_component("prefix", comp, msg=msg) + pipe.default_component("prefix") assert pipe.build().run("prefix", msg="FOOBIE BLETCH") == "scroll named FOOBIE BLETCH" diff --git a/lenskit/tests/pipeline/test_pipeline.py b/lenskit/tests/pipeline/test_pipeline.py index bac071212..166bdfb5d 100644 --- a/lenskit/tests/pipeline/test_pipeline.py +++ b/lenskit/tests/pipeline/test_pipeline.py @@ -221,6 +221,7 @@ def test_simple_graph(): pipe = PipelineBuilder() a = pipe.create_input("a", int) b = pipe.create_input("b", int) + pipe.default_component("b") def double(x: int) -> int: return x * 2 @@ -303,6 +304,7 @@ def add(x: int, y: int) -> int: nd = pipe.add_component("double", double, x=a) na = pipe.add_component("add", add, x=nd) + pipe.default_component(na) pipe = pipe.build() assert pipe.run(a=1, b=7) == 9 @@ -385,7 +387,7 @@ def add(x: int, y: int) -> int: pipe = pipe.build() with raises(TypeError): - pipe.run(a=1, b="seven") + pipe.run("add", a=1, b="seven") def test_run_by_alias(): diff --git a/lenskit/tests/pipeline/test_save_load.py b/lenskit/tests/pipeline/test_save_load.py index 9355e2225..ae7a5c61b 100644 --- a/lenskit/tests/pipeline/test_save_load.py +++ b/lenskit/tests/pipeline/test_save_load.py @@ -158,8 +158,8 @@ def test_configurable_component(): p2 = p2.build() assert p2.run("prefix", msg="HACKEM MUCHE") == "scroll named HACKEM MUCHE" - print("hash:", pipe.config_hash()) - assert pipe.config_hash() is not None + print("hash:", pipe.config_hash) + assert pipe.config_hash is not None assert p2.config_hash == pipe.config_hash @@ -276,6 +276,7 @@ def test_literal(): msg = pipe.create_input("msg", str) pipe.add_component("prefix", msg_prefix, prefix=pipe.literal("hello, "), msg=msg) + pipe.default_component("prefix") pipe = pipe.build() assert pipe.run(msg="HACKEM MUCHE") == "hello, HACKEM MUCHE" @@ -290,6 +291,7 @@ def test_literal_array(): a = pipe.create_input("a", int) pipe.add_component("add", add, x=np.arange(10), y=a) + pipe.default_component("prefix") pipe = pipe.build() res = pipe.run(a=5) From 3164bcc3e89a7e00cd9353c94f919b428d17d95f Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Wed, 15 Jan 2025 18:33:48 -0500 Subject: [PATCH 09/37] fix config load default in builder --- lenskit/lenskit/pipeline/builder.py | 20 +++++++++++--------- 1 file changed, 11 insertions(+), 9 deletions(-) diff --git a/lenskit/lenskit/pipeline/builder.py b/lenskit/lenskit/pipeline/builder.py index d2368603f..182e634a5 100644 --- a/lenskit/lenskit/pipeline/builder.py +++ b/lenskit/lenskit/pipeline/builder.py @@ -533,12 +533,12 @@ def from_config(cls, config: object) -> Self: not have a matching hash. """ cfg = PipelineConfig.model_validate(config) - pipe = cls() + builder = cls() for inpt in cfg.inputs: types: list[type[Any] | None] = [] if inpt.types is not None: types += [parse_type_string(t) for t in inpt.types] - pipe.create_input(inpt.name, *types) + builder.create_input(inpt.name, *types) # we now add the components and other nodes in multiple passes to ensure # that nodes are available before they are wired (since `connect` can @@ -546,7 +546,7 @@ def from_config(cls, config: object) -> Self: # pass 1: add literals for name, data in cfg.literals.items(): - pipe.literal(data.decode(), name=name) + builder.literal(data.decode(), name=name) # pass 2: add components to_wire: list[config.PipelineComponent] = [] @@ -556,28 +556,30 @@ def from_config(cls, config: object) -> Self: continue obj = instantiate_component(comp.code, comp.config) - pipe.add_component(name, obj) + builder.add_component(name, obj) to_wire.append(comp) # pass 3: wiring for name, comp in cfg.components.items(): if isinstance(comp.inputs, dict): - inputs = {n: pipe.node(t) for (n, t) in comp.inputs.items()} - pipe.connect(name, **inputs) + inputs = {n: builder.node(t) for (n, t) in comp.inputs.items()} + builder.connect(name, **inputs) elif not comp.code.startswith("@"): raise PipelineError(f"component {name} inputs must be dict, not list") # pass 4: aliases for n, t in cfg.aliases.items(): - pipe.alias(n, t) + builder.alias(n, t) + + builder._default = cfg.default if cfg.meta.hash is not None: - h2 = pipe.config_hash() + h2 = builder.config_hash() if h2 != cfg.meta.hash: _log.warning("loaded pipeline does not match hash") warnings.warn("loaded pipeline config does not match hash", PipelineWarning) - return pipe + return builder def use_first_of(self, name: str, primary: Node[T | None], fallback: Node[T]) -> Node[T]: """ From 62b5e4aab4e162faf624d83ac066dc9e95056bd5 Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Wed, 15 Jan 2025 18:34:37 -0500 Subject: [PATCH 10/37] stray defaults in save/load --- lenskit/tests/pipeline/test_pipeline.py | 1 + lenskit/tests/pipeline/test_save_load.py | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/lenskit/tests/pipeline/test_pipeline.py b/lenskit/tests/pipeline/test_pipeline.py index 166bdfb5d..e6b19d7f8 100644 --- a/lenskit/tests/pipeline/test_pipeline.py +++ b/lenskit/tests/pipeline/test_pipeline.py @@ -204,6 +204,7 @@ def triple(x: int) -> int: ni = pipe.add_component("incr", incr, x=x) nt = pipe.add_component("triple", triple, x=ni) + pipe.default_component(nt) pipe = pipe.build() # run default pipe diff --git a/lenskit/tests/pipeline/test_save_load.py b/lenskit/tests/pipeline/test_save_load.py index ae7a5c61b..d6ba7eda4 100644 --- a/lenskit/tests/pipeline/test_save_load.py +++ b/lenskit/tests/pipeline/test_save_load.py @@ -163,11 +163,12 @@ def test_configurable_component(): assert p2.config_hash == pipe.config_hash -def test_save_defaults(): +def test_save_with_defaults(): pipe = PipelineBuilder() msg = pipe.create_input("msg", str) pipe.default_connection("msg", msg) pipe.add_component("return", msg_ident) + pipe.default_component("return") cfg = pipe.build_config() From 8e5f40ca24e3f1df20d824b4dbee10936a9748ba Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Wed, 15 Jan 2025 18:38:01 -0500 Subject: [PATCH 11/37] fix several type errors and missing defaults --- lenskit/lenskit/pipeline/_impl.py | 10 ++++++++-- lenskit/lenskit/pipeline/builder.py | 4 ++-- lenskit/tests/pipeline/test_pipeline.py | 2 +- lenskit/tests/pipeline/test_save_load.py | 6 +++--- 4 files changed, 14 insertions(+), 8 deletions(-) diff --git a/lenskit/lenskit/pipeline/_impl.py b/lenskit/lenskit/pipeline/_impl.py index 608f97ae8..431b5e5cb 100644 --- a/lenskit/lenskit/pipeline/_impl.py +++ b/lenskit/lenskit/pipeline/_impl.py @@ -1,6 +1,7 @@ # pyright: strict from __future__ import annotations +from collections.abc import Iterable from dataclasses import replace from uuid import NAMESPACE_URL, uuid5 @@ -59,9 +60,14 @@ class Pipeline: _default: Node[Any] | None = None _hash: str | None = None - def __init__(self, config: config.PipelineConfig, nodes: dict[str, Node[Any]]): + def __init__(self, config: config.PipelineConfig, nodes: Iterable[Node[Any]]): + self._nodes = {} + for node in nodes: + if isinstance(node, ComponentInstanceNode): + raise RuntimeError("pipeline is not fully instantiated") + self._nodes[node.name] = node + self._config = config - self._nodes = dict(nodes) self._aliases = {a: self.node(t) for (a, t) in config.aliases.items()} if config.default: self._default = self.node(config.default) diff --git a/lenskit/lenskit/pipeline/builder.py b/lenskit/lenskit/pipeline/builder.py index 182e634a5..b2f321aa6 100644 --- a/lenskit/lenskit/pipeline/builder.py +++ b/lenskit/lenskit/pipeline/builder.py @@ -156,7 +156,7 @@ def node( else: raise KeyError(f"node {node}") - def create_input(self, name: str, *types: type[T] | None) -> Node[T]: + def create_input(self, name: str, *types: type[T] | UnionType | None) -> Node[T]: """ Create an input node for the pipeline. Pipelines expect their inputs to be provided when they are run. @@ -638,7 +638,7 @@ def build(self) -> Pipeline: Build the pipeline. """ config = self.build_config() - return Pipeline(config, self._nodes) + return Pipeline(config, self._nodes.values()) def _instantiate(self, node: Node[ND]) -> Node[ND]: match node: diff --git a/lenskit/tests/pipeline/test_pipeline.py b/lenskit/tests/pipeline/test_pipeline.py index e6b19d7f8..9204a9867 100644 --- a/lenskit/tests/pipeline/test_pipeline.py +++ b/lenskit/tests/pipeline/test_pipeline.py @@ -436,7 +436,7 @@ def add(x: int, y: int) -> int: assert state.meta is not None assert state.meta.name == "test" assert state.meta.version == "7.2" - assert state.meta.hash == pipe.config_hash() + assert state.meta.hash == pipe.config_hash def test_run_all_limit(): diff --git a/lenskit/tests/pipeline/test_save_load.py b/lenskit/tests/pipeline/test_save_load.py index d6ba7eda4..2a2aadcae 100644 --- a/lenskit/tests/pipeline/test_save_load.py +++ b/lenskit/tests/pipeline/test_save_load.py @@ -17,7 +17,7 @@ from lenskit.pipeline import PipelineBuilder, PipelineWarning from lenskit.pipeline.components import Component from lenskit.pipeline.config import PipelineConfig -from lenskit.pipeline.nodes import ComponentNode, InputNode +from lenskit.pipeline.nodes import ComponentInstanceNode, ComponentNode, InputNode _log = logging.getLogger(__name__) @@ -129,7 +129,7 @@ def test_round_trip_single_node(): p2 = PipelineBuilder.from_config(cfg) assert len(p2.nodes()) == 2 r2 = p2.node("return") - assert isinstance(r2, ComponentNode) + assert isinstance(r2, ComponentInstanceNode) assert r2.component is msg_ident assert r2.connections == {"msg": "msg"} @@ -150,7 +150,7 @@ def test_configurable_component(): p2 = PipelineBuilder.from_config(cfg) assert len(p2.nodes()) == 2 r2 = p2.node("prefix") - assert isinstance(r2, ComponentNode) + assert isinstance(r2, ComponentInstanceNode) assert isinstance(r2.component, Prefixer) assert r2.component is not pfx assert r2.connections == {"msg": "msg"} From c51eb7ba5311296cf0b0b221f1e2dca2d67f1ab7 Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Wed, 15 Jan 2025 18:40:49 -0500 Subject: [PATCH 12/37] fix incorrect name --- lenskit/tests/pipeline/test_save_load.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lenskit/tests/pipeline/test_save_load.py b/lenskit/tests/pipeline/test_save_load.py index 2a2aadcae..eaf495b48 100644 --- a/lenskit/tests/pipeline/test_save_load.py +++ b/lenskit/tests/pipeline/test_save_load.py @@ -292,7 +292,7 @@ def test_literal_array(): a = pipe.create_input("a", int) pipe.add_component("add", add, x=np.arange(10), y=a) - pipe.default_component("prefix") + pipe.default_component("add") pipe = pipe.build() res = pipe.run(a=5) From 72be51a885b549f7ac2632923b710b88bc27c11d Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Wed, 15 Jan 2025 18:42:03 -0500 Subject: [PATCH 13/37] fix pipeline simple graph test --- lenskit/tests/pipeline/test_pipeline.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lenskit/tests/pipeline/test_pipeline.py b/lenskit/tests/pipeline/test_pipeline.py index 9204a9867..e81d9d072 100644 --- a/lenskit/tests/pipeline/test_pipeline.py +++ b/lenskit/tests/pipeline/test_pipeline.py @@ -222,7 +222,6 @@ def test_simple_graph(): pipe = PipelineBuilder() a = pipe.create_input("a", int) b = pipe.create_input("b", int) - pipe.default_component("b") def double(x: int) -> int: return x * 2 @@ -232,6 +231,7 @@ def add(x: int, y: int) -> int: nd = pipe.add_component("double", double, x=a) na = pipe.add_component("add", add, x=nd, y=b) + pipe.default_component("add") pipe = pipe.build() assert pipe.run(a=1, b=7) == 9 From 316c3246e28ac6a20d1a166f96a38b4df10e058b Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Wed, 15 Jan 2025 19:06:57 -0500 Subject: [PATCH 14/37] implement utility functions for component input/output --- lenskit/lenskit/pipeline/_impl.py | 17 +++- lenskit/lenskit/pipeline/components.py | 44 +++++++++- lenskit/lenskit/pipeline/nodes.py | 50 +++++------- lenskit/lenskit/pipeline/runner.py | 14 ++-- lenskit/tests/pipeline/test_component_util.py | 80 +++++++++++++++++++ 5 files changed, 159 insertions(+), 46 deletions(-) create mode 100644 lenskit/tests/pipeline/test_component_util.py diff --git a/lenskit/lenskit/pipeline/_impl.py b/lenskit/lenskit/pipeline/_impl.py index 431b5e5cb..f86ab119a 100644 --- a/lenskit/lenskit/pipeline/_impl.py +++ b/lenskit/lenskit/pipeline/_impl.py @@ -16,7 +16,14 @@ from . import config from .components import Component from .config import PipelineConfig -from .nodes import ComponentInstanceNode, ComponentNode, InputNode, LiteralNode, Node +from .nodes import ( + ComponentConstructorNode, + ComponentInstanceNode, + ComponentNode, + InputNode, + LiteralNode, + Node, +) from .state import PipelineState _log = get_logger(__name__) @@ -63,12 +70,14 @@ class Pipeline: def __init__(self, config: config.PipelineConfig, nodes: Iterable[Node[Any]]): self._nodes = {} for node in nodes: - if isinstance(node, ComponentInstanceNode): + if isinstance(node, ComponentConstructorNode): raise RuntimeError("pipeline is not fully instantiated") self._nodes[node.name] = node self._config = config - self._aliases = {a: self.node(t) for (a, t) in config.aliases.items()} + self._aliases = {} + for a, t in config.aliases.items(): + self._aliases[a] = self.node(t) if config.default: self._default = self.node(config.default) @@ -146,7 +155,7 @@ def node( elif missing == "none" or missing is None: return None else: - raise KeyError(f"node {node}") + raise KeyError(node) def clone(self, how: CloneMethod = "config") -> Pipeline: """ diff --git a/lenskit/lenskit/pipeline/components.py b/lenskit/lenskit/pipeline/components.py index 5ea4c46a5..948f25536 100644 --- a/lenskit/lenskit/pipeline/components.py +++ b/lenskit/lenskit/pipeline/components.py @@ -13,7 +13,7 @@ import warnings from abc import ABC, abstractmethod from importlib import import_module -from inspect import isabstract +from inspect import isabstract, signature from types import FunctionType, NoneType from pydantic import JsonValue, TypeAdapter @@ -31,7 +31,7 @@ runtime_checkable, ) -from .types import Lazy +from .types import Lazy, TypecheckWarning P = ParamSpec("P") T = TypeVar("T") @@ -267,6 +267,46 @@ def instantiate_component( return comp() # type: ignore +def component_inputs( + component: Component[COut] | ComponentConstructor[Any, COut] | PipelineFunction[COut], +) -> dict[str, type | None]: + if isinstance(component, (Component, type)): + function = component.__call__ + else: + function = component + + types = get_type_hints(function) + sig = signature(function) + + inputs: dict[str, type | None] = {} + for param in sig.parameters.values(): + if param.name == "self": + continue + + if pt := types.get(param.name, None): + inputs[param.name] = pt + else: + warnings.warn( + f"parameter {param.name} of component {component} has no type annotation", + TypecheckWarning, + 2, + ) + inputs[param.name] = None + + return inputs + + +def component_return_type( + component: Component[COut] | ComponentConstructor[Any, COut] | PipelineFunction[COut], +) -> type | None: + if isinstance(component, (Component, type)): + types = get_type_hints(component.__call__) + else: + types = get_type_hints(component) + print(types) + return types.get("return", None) + + def fallback_on_none(primary: T | None, fallback: Lazy[T]) -> T: """ Fallback to a second component if the primary input is `None`. diff --git a/lenskit/lenskit/pipeline/nodes.py b/lenskit/lenskit/pipeline/nodes.py index 251158c40..3ac42b11e 100644 --- a/lenskit/lenskit/pipeline/nodes.py +++ b/lenskit/lenskit/pipeline/nodes.py @@ -7,14 +7,12 @@ # pyright: strict from __future__ import annotations -import warnings -from inspect import Signature, signature -from typing import Any, Callable, cast +from abc import abstractmethod +from typing import Any, cast from typing_extensions import Generic, TypeVar -from .components import Component, ComponentConstructor, PipelineFunction -from .types import TypecheckWarning +from .components import Component, ComponentConstructor, PipelineFunction, component_inputs # Nodes are (conceptually) immutable data containers, so Node[U] can be assigned # to Node[T] if U ≼ T. @@ -80,10 +78,7 @@ class ComponentNode(Node[ND], Generic[ND]): Internal """ - __match_args__ = ("name", "inputs", "connections") - - inputs: dict[str, type | None] - "The component's inputs." + __match_args__ = ("name", "connections") connections: dict[str, str] "The component's input connections." @@ -104,30 +99,14 @@ def create( else: return ComponentInstanceNode(name, comp) - def _setup_signature(self, component: Callable[..., ND]): - sig = signature(component, eval_str=True) - if sig.return_annotation == Signature.empty: - warnings.warn( - f"component {component} has no return type annotation", TypecheckWarning, 2 - ) - else: - self.types = set([sig.return_annotation]) - - self.inputs = {} - for param in sig.parameters.values(): - if param.annotation == Signature.empty: - warnings.warn( - f"parameter {param.name} of component {component} has no type annotation", - TypecheckWarning, - 2, - ) - self.inputs[param.name] = None - else: - self.inputs[param.name] = param.annotation + @property + @abstractmethod + def inputs(self) -> dict[str, type | None]: # pragma: nocover + raise NotImplementedError() class ComponentConstructorNode(ComponentNode[ND], Generic[ND]): - __match_args__ = ("name", "constructor", "config", "inputs", "connections") + __match_args__ = ("name", "constructor", "config", "connections") constructor: ComponentConstructor[Any, ND] config: object | None @@ -135,9 +114,13 @@ def __init__(self, name: str, constructor: ComponentConstructor[CFG, ND], config self.constructor = constructor self.config = config + @property + def inputs(self): + return component_inputs(self.constructor) + class ComponentInstanceNode(ComponentNode[ND], Generic[ND]): - __match_args__ = ("name", "component", "inputs", "connections") + __match_args__ = ("name", "component", "connections") component: Component[ND] | PipelineFunction[ND] @@ -149,5 +132,8 @@ def __init__( ): super().__init__(name) self.component = component - self._setup_signature(component) self.connections = connections or {} + + @property + def inputs(self): + return component_inputs(self.component) diff --git a/lenskit/lenskit/pipeline/runner.py b/lenskit/lenskit/pipeline/runner.py index e12fac779..03f8432b2 100644 --- a/lenskit/lenskit/pipeline/runner.py +++ b/lenskit/lenskit/pipeline/runner.py @@ -18,7 +18,7 @@ from lenskit.logging import get_logger, trace from ._impl import Pipeline -from .components import PipelineFunction +from .components import PipelineFunction, component_inputs from .nodes import ComponentInstanceNode, InputNode, LiteralNode, Node from .types import Lazy, is_compatible_data @@ -88,8 +88,8 @@ def _run_node(self, node: Node[Any], required: bool) -> None: self.state[name] = value case InputNode(name, types=types): self._inject_input(name, types, required) - case ComponentInstanceNode(name, comp, inputs, wiring): - self._run_component(name, comp, inputs, wiring, required) + case ComponentInstanceNode(name, comp, wiring): + self._run_component(name, comp, wiring, required) case _: # pragma: nocover raise PipelineError(f"invalid node {node}") @@ -108,20 +108,18 @@ def _run_component( self, name: str, comp: PipelineFunction[Any], - inputs: dict[str, type | None], wiring: dict[str, str], required: bool, ) -> None: in_data = {} log = self.log.bind(node=name) trace(log, "processing inputs") + inputs = component_inputs(comp) for iname, itype in inputs.items(): # look up the input wiring for this parameter input - src = wiring.get(iname, None) - if src is not None: + snode = None + if src := wiring.get(iname, None): snode = self.pipe.node(src) - else: - snode = self.pipe.get_default(iname) # check if this is a lazy node lazy = False diff --git a/lenskit/tests/pipeline/test_component_util.py b/lenskit/tests/pipeline/test_component_util.py new file mode 100644 index 000000000..cf6ad3c37 --- /dev/null +++ b/lenskit/tests/pipeline/test_component_util.py @@ -0,0 +1,80 @@ +# pyright: strict +from dataclasses import dataclass + +from lenskit.pipeline.components import Component, component_inputs, component_return_type + + +@dataclass +class TestConfig: + suffix: str = "" + + +class TestComp(Component): + config: TestConfig + + def __call__(self, msg: str) -> str: + return msg + self.config.suffix + + +def test_empty_input(): + def func() -> int: + return 9 + + inputs = component_inputs(func) + assert not inputs + + +def test_single_function_input(): + def func(x: int) -> int: + return 9 + x + + inputs = component_inputs(func) + assert len(inputs) == 1 + assert inputs["x"] is int + + +def test_component_class_input(): + inputs = component_inputs(TestComp) + assert len(inputs) == 1 + assert inputs["msg"] is str + + +def test_component_object_input(): + inputs = component_inputs(TestComp()) + assert len(inputs) == 1 + assert inputs["msg"] is str + + +def test_component_unknown_input(): + def func(x) -> int: # type: ignore + return x + 5 # type: ignore + + inputs = component_inputs(func) # type: ignore + assert len(inputs) == 1 + assert inputs["x"] is None + + +def test_function_return(): + def func(x: int) -> int: + return x + 5 + + rt = component_return_type(func) + assert rt is int + + +def test_class_return(): + rt = component_return_type(TestComp) + assert rt is str + + +def test_instance_return(): + rt = component_return_type(TestComp()) + assert rt is str + + +def test_unknown_return(): + def func(): + pass + + rt = component_return_type(func) + assert rt is None From 3303001e1afd52d646170d72c4b1a9ba15b2630a Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Wed, 15 Jan 2025 19:17:34 -0500 Subject: [PATCH 15/37] name fix --- lenskit/tests/pipeline/test_pipeline_clone.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lenskit/tests/pipeline/test_pipeline_clone.py b/lenskit/tests/pipeline/test_pipeline_clone.py index 8e810cdba..c6023efdd 100644 --- a/lenskit/tests/pipeline/test_pipeline_clone.py +++ b/lenskit/tests/pipeline/test_pipeline_clone.py @@ -95,7 +95,7 @@ def test_clone_defaults(): msg = pipe.create_input("msg", str) pipe.default_connection("msg", msg) pipe.add_component("return", exclaim) - pipe.default_component("prefix") + pipe.default_component("return") pipe = pipe.build() assert pipe.run(msg="hello") == "hello!" From cb0ec1a1a61a0d378153b2d1ed9e8e3cd122757d Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Wed, 15 Jan 2025 19:17:39 -0500 Subject: [PATCH 16/37] ugly input resolution --- lenskit/lenskit/pipeline/builder.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/lenskit/lenskit/pipeline/builder.py b/lenskit/lenskit/pipeline/builder.py index b2f321aa6..fefd95798 100644 --- a/lenskit/lenskit/pipeline/builder.py +++ b/lenskit/lenskit/pipeline/builder.py @@ -84,7 +84,7 @@ class PipelineBuilder: _nodes: dict[str, Node[Any]] _aliases: dict[str, Node[Any]] - _default_connections: dict[str, Node[Any]] + _default_connections: dict[str, str] _components: dict[str, PipelineFunction[Any] | Component[Any]] _default: str | None = None _anon_nodes: set[str] @@ -230,7 +230,7 @@ def default_connection(self, name: str, node: Node[Any] | object) -> None: """ if not isinstance(node, Node): node = self.literal(node) - self._default_connections[name] = node + self._default_connections[name] = node.name def default_component(self, node: str | Node[Any]) -> None: """ @@ -421,7 +421,7 @@ def clone(self) -> PipelineBuilder: pass for n, t in self._default_connections.items(): - clone.default_connection(n, clone.node(t.name)) + clone.default_connection(n, clone.node(t)) return clone @@ -445,6 +445,13 @@ def build_config(self, *, include_hash: bool = True) -> PipelineConfig: meta = self.meta(include_hash=False) cfg = PipelineConfig(meta=meta) + # FIXME: don't mutate + for node in self._nodes.values(): + if isinstance(node, ComponentNode): + for iname in node.inputs.keys(): + if iname not in node.connections and iname in self._default_connections: + node.connections[iname] = self._default_connections[iname] + # We map anonymous nodes to hash-based names for stability. If we ever # allow anonymous components, this will need to be adjusted to maintain # component ordering, but it works for now since only literals can be From 4b4eb920eb32a0384dc604d0693a49f87b4c3adc Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Wed, 15 Jan 2025 19:17:49 -0500 Subject: [PATCH 17/37] more input tracing --- lenskit/lenskit/pipeline/runner.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/lenskit/lenskit/pipeline/runner.py b/lenskit/lenskit/pipeline/runner.py index 03f8432b2..f87c33757 100644 --- a/lenskit/lenskit/pipeline/runner.py +++ b/lenskit/lenskit/pipeline/runner.py @@ -116,9 +116,12 @@ def _run_component( trace(log, "processing inputs") inputs = component_inputs(comp) for iname, itype in inputs.items(): + ilog = log.bind(input_name=iname, input_type=itype) + trace(ilog, "resolving input") # look up the input wiring for this parameter input snode = None if src := wiring.get(iname, None): + trace(ilog, "resolving from wiring") snode = self.pipe.node(src) # check if this is a lazy node From 6cd8764106b0e2d12b5e4c785d622e46df81165d Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Wed, 15 Jan 2025 19:18:01 -0500 Subject: [PATCH 18/37] fix config_hash invocation --- lenskit/tests/pipeline/test_save_load.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lenskit/tests/pipeline/test_save_load.py b/lenskit/tests/pipeline/test_save_load.py index eaf495b48..a1f1a2d21 100644 --- a/lenskit/tests/pipeline/test_save_load.py +++ b/lenskit/tests/pipeline/test_save_load.py @@ -160,7 +160,7 @@ def test_configurable_component(): print("hash:", pipe.config_hash) assert pipe.config_hash is not None - assert p2.config_hash == pipe.config_hash + assert p2.config_hash == pipe.config_hash() def test_save_with_defaults(): From 9f21a6818627fdb8851c717127809f6100412bbb Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Wed, 15 Jan 2025 19:24:18 -0500 Subject: [PATCH 19/37] handle callable objects (not recommended) --- lenskit/lenskit/pipeline/components.py | 19 ++++++++----- lenskit/tests/pipeline/test_component_util.py | 27 +++++++++++++++++++ 2 files changed, 39 insertions(+), 7 deletions(-) diff --git a/lenskit/lenskit/pipeline/components.py b/lenskit/lenskit/pipeline/components.py index 948f25536..e25f8ae6b 100644 --- a/lenskit/lenskit/pipeline/components.py +++ b/lenskit/lenskit/pipeline/components.py @@ -270,10 +270,12 @@ def instantiate_component( def component_inputs( component: Component[COut] | ComponentConstructor[Any, COut] | PipelineFunction[COut], ) -> dict[str, type | None]: - if isinstance(component, (Component, type)): - function = component.__call__ - else: + if isinstance(component, FunctionType): function = component + elif hasattr(component, "__call__"): + function = getattr(component, "__call__") + else: + raise TypeError("invalid component " + repr(component)) types = get_type_hints(function) sig = signature(function) @@ -299,11 +301,14 @@ def component_inputs( def component_return_type( component: Component[COut] | ComponentConstructor[Any, COut] | PipelineFunction[COut], ) -> type | None: - if isinstance(component, (Component, type)): - types = get_type_hints(component.__call__) + if isinstance(component, FunctionType): + function = component + elif hasattr(component, "__call__"): + function = getattr(component, "__call__") else: - types = get_type_hints(component) - print(types) + raise TypeError("invalid component " + repr(component)) + + types = get_type_hints(function) return types.get("return", None) diff --git a/lenskit/tests/pipeline/test_component_util.py b/lenskit/tests/pipeline/test_component_util.py index cf6ad3c37..65be731e0 100644 --- a/lenskit/tests/pipeline/test_component_util.py +++ b/lenskit/tests/pipeline/test_component_util.py @@ -16,6 +16,11 @@ def __call__(self, msg: str) -> str: return msg + self.config.suffix +class CallObj: + def __call__(self, q: str) -> bytes: + return q.encode() + + def test_empty_input(): def func() -> int: return 9 @@ -54,6 +59,18 @@ def func(x) -> int: # type: ignore assert inputs["x"] is None +def test_callable_object_input(): + inputs = component_inputs(CallObj()) + assert len(inputs) == 1 + assert inputs["q"] is str + + +def test_callable_class_input(): + inputs = component_inputs(CallObj) + assert len(inputs) == 1 + assert inputs["q"] is str + + def test_function_return(): def func(x: int) -> int: return x + 5 @@ -78,3 +95,13 @@ def func(): rt = component_return_type(func) assert rt is None + + +def test_callable_object_return(): + rt = component_return_type(CallObj()) + assert rt is bytes + + +def test_callable_class_return(): + rt = component_return_type(CallObj) + assert rt is bytes From 0e0314ec082a44817353520e5b07573bccfb7dda Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Wed, 15 Jan 2025 19:44:43 -0500 Subject: [PATCH 20/37] fix a lot of type and build errors --- lenskit/lenskit/pipeline/builder.py | 2 +- lenskit/lenskit/pipeline/components.py | 7 +++--- lenskit/lenskit/pipeline/nodes.py | 25 ++++++++++++++----- .../tests/pipeline/test_component_config.py | 10 +++++--- lenskit/tests/pipeline/test_pipeline_clone.py | 9 +++---- 5 files changed, 34 insertions(+), 19 deletions(-) diff --git a/lenskit/lenskit/pipeline/builder.py b/lenskit/lenskit/pipeline/builder.py index fefd95798..36a139f3e 100644 --- a/lenskit/lenskit/pipeline/builder.py +++ b/lenskit/lenskit/pipeline/builder.py @@ -645,7 +645,7 @@ def build(self) -> Pipeline: Build the pipeline. """ config = self.build_config() - return Pipeline(config, self._nodes.values()) + return Pipeline(config, [self._instantiate(n) for n in self._nodes.values()]) def _instantiate(self, node: Node[ND]) -> Node[ND]: match node: diff --git a/lenskit/lenskit/pipeline/components.py b/lenskit/lenskit/pipeline/components.py index e25f8ae6b..842ebd15a 100644 --- a/lenskit/lenskit/pipeline/components.py +++ b/lenskit/lenskit/pipeline/components.py @@ -54,7 +54,8 @@ """ -class ComponentConstructor(ABC, Generic[CFG, COut]): +@runtime_checkable +class ComponentConstructor(Protocol, Generic[CFG, COut]): """ Protocol for component constructors. """ @@ -63,9 +64,7 @@ def __call__(self, config: CFG | None = None) -> Component[COut]: ... def config_class(self) -> type[CFG] | None: ... - def __isinstance__(self, obj: Any) -> bool: - # FIXME: implement a more rigorous check for this - return isinstance(obj, type) and issubclass(obj, Component) + def validate_config(self, data: Any = None) -> CFG | None: ... @runtime_checkable diff --git a/lenskit/lenskit/pipeline/nodes.py b/lenskit/lenskit/pipeline/nodes.py index 3ac42b11e..5f9eda59e 100644 --- a/lenskit/lenskit/pipeline/nodes.py +++ b/lenskit/lenskit/pipeline/nodes.py @@ -8,11 +8,19 @@ from __future__ import annotations from abc import abstractmethod +from collections.abc import Mapping from typing import Any, cast +from pydantic import JsonValue from typing_extensions import Generic, TypeVar -from .components import Component, ComponentConstructor, PipelineFunction, component_inputs +from .components import ( + Component, + ComponentConstructor, + PipelineFunction, + component_inputs, + component_return_type, +) # Nodes are (conceptually) immutable data containers, so Node[U] can be assigned # to Node[T] if U ≼ T. @@ -91,13 +99,13 @@ def __init__(self, name: str): def create( name: str, comp: ComponentConstructor[CFG, ND] | Component[ND] | PipelineFunction[ND], - config: CFG | None = None, + config: CFG | Mapping[str, JsonValue] | None = None, ) -> ComponentNode[ND]: - if isinstance(comp, ComponentConstructor): - comp = cast(ComponentConstructor[CFG, ND], comp) - return ComponentConstructorNode(name, comp, config) + if isinstance(comp, Component) or not isinstance(comp, ComponentConstructor): + return ComponentInstanceNode(name, comp) # type: ignore else: - return ComponentInstanceNode(name, comp) + comp = cast(ComponentConstructor[CFG, ND], comp) + return ComponentConstructorNode(name, comp, comp.validate_config(config)) @property @abstractmethod @@ -111,8 +119,11 @@ class ComponentConstructorNode(ComponentNode[ND], Generic[ND]): config: object | None def __init__(self, name: str, constructor: ComponentConstructor[CFG, ND], config: CFG | None): + super().__init__(name) self.constructor = constructor self.config = config + if rt := component_return_type(constructor): + self.types = {rt} @property def inputs(self): @@ -133,6 +144,8 @@ def __init__( super().__init__(name) self.component = component self.connections = connections or {} + if rt := component_return_type(component): + self.types = {rt} @property def inputs(self): diff --git a/lenskit/tests/pipeline/test_component_config.py b/lenskit/tests/pipeline/test_component_config.py index a275ca456..a2b6824a3 100644 --- a/lenskit/tests/pipeline/test_component_config.py +++ b/lenskit/tests/pipeline/test_component_config.py @@ -10,13 +10,14 @@ from dataclasses import dataclass from typing import Any -from pydantic import BaseModel +from pydantic import BaseModel, TypeAdapter from pydantic.dataclasses import dataclass as pydantic_dataclass from pytest import mark from lenskit.pipeline import PipelineBuilder from lenskit.pipeline.components import Component, ComponentConstructor +from lenskit.pipeline.nodes import ComponentConstructorNode @dataclass @@ -89,13 +90,16 @@ def test_auto_config_roundtrip(prefixer: type[Component]): def test_pipeline_config(prefixer: ComponentConstructor[Any, str]): pipe = PipelineBuilder() msg = pipe.create_input("msg", str) - pipe.add_component("prefix", prefixer, {"prefix": "scroll named "}, msg=msg) + pn = pipe.add_component("prefix", prefixer, {"prefix": "scroll named "}, msg=msg) + assert isinstance(pn, ComponentConstructorNode) + assert pn.constructor == prefixer + assert getattr(pn.config, "prefix") == "scroll named " pipe = pipe.build() assert pipe.run("prefix", msg="FOOBIE BLETCH") == "scroll named FOOBIE BLETCH" config = pipe.config.components - print(json.dumps(config, indent=2)) + print(TypeAdapter(dict).dump_json(config, indent=2)) assert "prefix" in config assert config["prefix"].config diff --git a/lenskit/tests/pipeline/test_pipeline_clone.py b/lenskit/tests/pipeline/test_pipeline_clone.py index c6023efdd..66a932fe9 100644 --- a/lenskit/tests/pipeline/test_pipeline_clone.py +++ b/lenskit/tests/pipeline/test_pipeline_clone.py @@ -9,7 +9,7 @@ from lenskit.pipeline import PipelineBuilder from lenskit.pipeline.components import Component -from lenskit.pipeline.nodes import ComponentNode +from lenskit.pipeline.nodes import ComponentInstanceNode, ComponentNode @dataclass @@ -36,19 +36,18 @@ def exclaim(msg: str) -> str: def test_pipeline_clone(): - comp = Prefixer(PrefixConfig("scroll named ")) - pipe = PipelineBuilder() msg = pipe.create_input("msg", str) - pipe.add_component("prefix", comp, msg=msg) + pipe.add_component("prefix", Prefixer, PrefixConfig(prefix="scroll named "), msg=msg) pipe.default_component("prefix") pipe = pipe.build() assert pipe.run(msg="FOOBIE BLETCH") == "scroll named FOOBIE BLETCH" + comp = pipe.node("prefix").component # type: ignore p2 = pipe.clone() n2 = p2.node("prefix") - assert isinstance(n2, ComponentNode) + assert isinstance(n2, ComponentInstanceNode) assert isinstance(n2.component, Prefixer) assert n2.component is not comp assert n2.component.config.prefix == comp.config.prefix From 3c958e2d78967f309870bdc31edba5873827422e Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Wed, 15 Jan 2025 19:47:39 -0500 Subject: [PATCH 21/37] fix bad clone tests --- lenskit/lenskit/pipeline/nodes.py | 10 +++++++--- lenskit/tests/pipeline/test_pipeline_clone.py | 8 ++++---- 2 files changed, 11 insertions(+), 7 deletions(-) diff --git a/lenskit/lenskit/pipeline/nodes.py b/lenskit/lenskit/pipeline/nodes.py index 5f9eda59e..23c4ce21c 100644 --- a/lenskit/lenskit/pipeline/nodes.py +++ b/lenskit/lenskit/pipeline/nodes.py @@ -101,11 +101,15 @@ def create( comp: ComponentConstructor[CFG, ND] | Component[ND] | PipelineFunction[ND], config: CFG | Mapping[str, JsonValue] | None = None, ) -> ComponentNode[ND]: - if isinstance(comp, Component) or not isinstance(comp, ComponentConstructor): - return ComponentInstanceNode(name, comp) # type: ignore - else: + if isinstance(comp, Component): + return ComponentInstanceNode(name, cast(Component[ND], comp)) + elif isinstance(comp, ComponentConstructor): comp = cast(ComponentConstructor[CFG, ND], comp) return ComponentConstructorNode(name, comp, comp.validate_config(config)) + elif isinstance(comp, type): + return ComponentConstructorNode(name, comp, None) # type: ignore + else: + return ComponentInstanceNode(name, comp) @property @abstractmethod diff --git a/lenskit/tests/pipeline/test_pipeline_clone.py b/lenskit/tests/pipeline/test_pipeline_clone.py index 66a932fe9..c66d0ec3e 100644 --- a/lenskit/tests/pipeline/test_pipeline_clone.py +++ b/lenskit/tests/pipeline/test_pipeline_clone.py @@ -9,7 +9,7 @@ from lenskit.pipeline import PipelineBuilder from lenskit.pipeline.components import Component -from lenskit.pipeline.nodes import ComponentInstanceNode, ComponentNode +from lenskit.pipeline.nodes import ComponentInstanceNode @dataclass @@ -50,7 +50,7 @@ def test_pipeline_clone(): assert isinstance(n2, ComponentInstanceNode) assert isinstance(n2.component, Prefixer) assert n2.component is not comp - assert n2.component.config.prefix == comp.config.prefix + assert n2.component.config.prefix == comp.config.prefix # type: ignore assert p2.run(msg="HACKEM MUCHE") == "scroll named HACKEM MUCHE" @@ -62,7 +62,7 @@ def test_pipeline_clone_with_function(): msg = pipe.create_input("msg", str) pfx = pipe.add_component("prefix", comp, msg=msg) pipe.add_component("exclaim", exclaim, msg=pfx) - pipe.default_component("prefix") + pipe.default_component("exclaim") pipe = pipe.build() assert pipe.run(msg="FOOBIE BLETCH") == "scroll named FOOBIE BLETCH!" @@ -79,7 +79,7 @@ def test_pipeline_clone_with_nonconfig_class(): msg = pipe.create_input("msg", str) pfx = pipe.add_component("prefix", comp, msg=msg) pipe.add_component("question", Question(), msg=pfx) - pipe.default_component("prefix") + pipe.default_component("question") pipe = pipe.build() assert pipe.run(msg="FOOBIE BLETCH") == "scroll named FOOBIE BLETCH?" From 7be94f52d3080f60ecde817dad4417f03f50f9ad Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Wed, 15 Jan 2025 19:50:11 -0500 Subject: [PATCH 22/37] re-introduce typecheck warning --- lenskit/lenskit/pipeline/builder.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lenskit/lenskit/pipeline/builder.py b/lenskit/lenskit/pipeline/builder.py index 36a139f3e..f601d9880 100644 --- a/lenskit/lenskit/pipeline/builder.py +++ b/lenskit/lenskit/pipeline/builder.py @@ -34,7 +34,7 @@ LiteralNode, Node, ) -from .types import parse_type_string +from .types import TypecheckWarning, parse_type_string _log = get_logger(__name__) @@ -307,6 +307,8 @@ def add_component( self._check_available_name(name) node = ComponentNode[ND].create(name, comp, config) + if node.types is None: + warnings.warn(f"cannot determine return type of component {comp}", TypecheckWarning) self._nodes[name] = node self.connect(node, **inputs) From fd63b8d17b4907df700af155729a81b9a3d8018c Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Wed, 15 Jan 2025 19:52:12 -0500 Subject: [PATCH 23/37] silence a few warnings --- lenskit/lenskit/pipeline/components.py | 13 ++++++++----- lenskit/lenskit/pipeline/runner.py | 2 +- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/lenskit/lenskit/pipeline/components.py b/lenskit/lenskit/pipeline/components.py index 842ebd15a..f6469e73f 100644 --- a/lenskit/lenskit/pipeline/components.py +++ b/lenskit/lenskit/pipeline/components.py @@ -268,6 +268,8 @@ def instantiate_component( def component_inputs( component: Component[COut] | ComponentConstructor[Any, COut] | PipelineFunction[COut], + *, + warn_on_missing: bool = True, ) -> dict[str, type | None]: if isinstance(component, FunctionType): function = component @@ -287,11 +289,12 @@ def component_inputs( if pt := types.get(param.name, None): inputs[param.name] = pt else: - warnings.warn( - f"parameter {param.name} of component {component} has no type annotation", - TypecheckWarning, - 2, - ) + if warn_on_missing: + warnings.warn( + f"parameter {param.name} of component {component} has no type annotation", + TypecheckWarning, + 2, + ) inputs[param.name] = None return inputs diff --git a/lenskit/lenskit/pipeline/runner.py b/lenskit/lenskit/pipeline/runner.py index f87c33757..bc9381645 100644 --- a/lenskit/lenskit/pipeline/runner.py +++ b/lenskit/lenskit/pipeline/runner.py @@ -114,7 +114,7 @@ def _run_component( in_data = {} log = self.log.bind(node=name) trace(log, "processing inputs") - inputs = component_inputs(comp) + inputs = component_inputs(comp, warn_on_missing=False) for iname, itype in inputs.items(): ilog = log.bind(input_name=iname, input_type=itype) trace(ilog, "resolving input") From 938caf7e82ca8b78d071a2d49227dc63a333af86 Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Wed, 15 Jan 2025 20:59:14 -0500 Subject: [PATCH 24/37] xfail the cycle detector --- lenskit/tests/pipeline/test_pipeline.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lenskit/tests/pipeline/test_pipeline.py b/lenskit/tests/pipeline/test_pipeline.py index e81d9d072..f17ece2ef 100644 --- a/lenskit/tests/pipeline/test_pipeline.py +++ b/lenskit/tests/pipeline/test_pipeline.py @@ -10,7 +10,7 @@ import numpy as np from typing_extensions import assert_type -from pytest import raises, warns +from pytest import mark, raises, warns from lenskit.pipeline import PipelineBuilder, PipelineError from lenskit.pipeline.nodes import InputNode, Node @@ -239,6 +239,7 @@ def add(x: int, y: int) -> int: assert pipe.run(nd, a=3, b=7) == 6 +@mark.xfail(reason="cycle detection not yet implemented") def test_cycle(): pipe = PipelineBuilder() b = pipe.create_input("b", int) From 3a3dec2d06bd4353642088e4a0761823f907b9b4 Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Wed, 15 Jan 2025 21:00:16 -0500 Subject: [PATCH 25/37] rename thing to stop confusing pytest --- lenskit/tests/pipeline/test_component_util.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/lenskit/tests/pipeline/test_component_util.py b/lenskit/tests/pipeline/test_component_util.py index 65be731e0..64a510997 100644 --- a/lenskit/tests/pipeline/test_component_util.py +++ b/lenskit/tests/pipeline/test_component_util.py @@ -5,12 +5,12 @@ @dataclass -class TestConfig: +class XConfig: suffix: str = "" -class TestComp(Component): - config: TestConfig +class XComp(Component): + config: XConfig def __call__(self, msg: str) -> str: return msg + self.config.suffix @@ -39,13 +39,13 @@ def func(x: int) -> int: def test_component_class_input(): - inputs = component_inputs(TestComp) + inputs = component_inputs(XComp) assert len(inputs) == 1 assert inputs["msg"] is str def test_component_object_input(): - inputs = component_inputs(TestComp()) + inputs = component_inputs(XComp()) assert len(inputs) == 1 assert inputs["msg"] is str @@ -80,12 +80,12 @@ def func(x: int) -> int: def test_class_return(): - rt = component_return_type(TestComp) + rt = component_return_type(XComp) assert rt is str def test_instance_return(): - rt = component_return_type(TestComp()) + rt = component_return_type(XComp()) assert rt is str From fdf9574a99c94994de179b9afc59db9283b5b3a2 Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Wed, 15 Jan 2025 21:32:17 -0500 Subject: [PATCH 26/37] update common for the new builder --- lenskit/lenskit/pipeline/common.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/lenskit/lenskit/pipeline/common.py b/lenskit/lenskit/pipeline/common.py index fdbffda9f..faf502bb4 100644 --- a/lenskit/lenskit/pipeline/common.py +++ b/lenskit/lenskit/pipeline/common.py @@ -8,6 +8,7 @@ from lenskit.data import ID, ItemList, RecQuery from ._impl import Pipeline +from .builder import PipelineBuilder from .components import Component @@ -89,7 +90,7 @@ def build(self, name: str | None = None) -> Pipeline: from lenskit.basic.composite import FallbackScorer from lenskit.basic.history import UserTrainingHistoryLookup - pipe = Pipeline(name=name) + pipe = PipelineBuilder(name=name) query = pipe.create_input("query", RecQuery, ID, ItemList) @@ -121,8 +122,9 @@ def build(self, name: str | None = None) -> Pipeline: rank = pipe.add_component("ranker", self._ranker, items=n_score, n=n_n) pipe.alias("recommender", rank) + pipe.default_component("recommender") - return pipe + return pipe.build() def topn_pipeline( @@ -196,7 +198,7 @@ def predict_pipeline( from lenskit.basic.composite import FallbackScorer from lenskit.basic.history import UserTrainingHistoryLookup - pipe = Pipeline(name=name) + pipe = PipelineBuilder(name=name) query = pipe.create_input("query", RecQuery, ID, ItemList) items = pipe.create_input("items", ItemList) @@ -214,4 +216,6 @@ def predict_pipeline( backup = pipe.add_component("fallback-predictor", fallback, query=lookup, items=items) pipe.add_component("rating-predictor", FallbackScorer(), primary=score, fallback=backup) - return pipe + pipe.default_component("rating-predictor") + + return pipe.build() From d88ecb78a88a6591d83f3d45fd0735fb8f440a70 Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Wed, 15 Jan 2025 21:38:04 -0500 Subject: [PATCH 27/37] document the builder --- docs/api/pipeline.rst | 1 + docs/guide/pipeline.rst | 51 +++++++++++++++++++++++++++++------------ 2 files changed, 37 insertions(+), 15 deletions(-) diff --git a/docs/api/pipeline.rst b/docs/api/pipeline.rst index c18e5530b..25a7202ca 100644 --- a/docs/api/pipeline.rst +++ b/docs/api/pipeline.rst @@ -15,6 +15,7 @@ Pipeline Classes :caption: Data Sets ~lenskit.pipeline.Pipeline + ~lenskit.pipeline.PipelineBuilder ~lenskit.pipeline.PipelineState ~lenskit.pipeline.Node ~lenskit.pipeline.Lazy diff --git a/docs/guide/pipeline.rst b/docs/guide/pipeline.rst index 7e9337b8e..0555299fd 100644 --- a/docs/guide/pipeline.rst +++ b/docs/guide/pipeline.rst @@ -65,7 +65,10 @@ that class, do: pipe = builder.build('ALS') For maximum flexibility, you can directly construct and wire the pipeline -yourself; this is described in :ref:`standard-pipelines`. +yourself; this is described in :ref:`standard-pipelines`. Pipelines are built +with a :class:`PipelineBuilder`, which sets up the nodes and connections, checks +for things like cycles, and instantiates the components to make a pipeline that +can be trained and used. After any of these methods, you can run the pipeline to produce recommendations with: @@ -107,8 +110,8 @@ These are arranged in a directed acyclic graph, consisting of: :ref:`pipeline-connections` for details. Each node has a name that can be used to look up the node with -:meth:`Pipeline.node` and appears in serialization and logging situations. Names -must be unique within a pipeline. +:meth:`Pipeline.node` (or :meth:`PipelineBuilder.node`) and appears in +serialization and logging situations. Names must be unique within a pipeline. .. _pipeline-connections: @@ -124,21 +127,21 @@ the following types: * A :class:`Node`, in which case the input will be provided from the corresponding pipeline input or component return value. Nodes are returned by - :meth:`~Pipeline.create_input` or :meth:`~Pipeline.add_component`, and can be - looked up after creation with :meth:`~Pipeline.node`. + :meth:`~PipelineBuilder.create_input` or :meth:`~PipelineBuilder.add_component`, and can be + looked up after creation with :meth:`~PipelineBuilder.node`. * A Python object, in which case that value will be provided directly to the component input argument. These input connections are specified via keyword arguments to the -:meth:`Pipeline.add_component` or :meth:`Pipeline.connect` methods — specify the -component's input name(s) and the node or data to which each input should be -wired. +:meth:`PipelineBuilder.add_component` or :meth:`PipelineBuilder.connect` methods +— specify the component's input name(s) and the node or data to which each input +should be wired. -You can also use :meth:`Pipeline.set_default` to specify default connections. -For example, you can specify a default for inputs named ``user``:: +You can also use :meth:`PipelineBuilder.default_conection` to specify default +connections. For example, you can specify a default for inputs named ``user``:: - pipe.set_default('user', user_history) + pipe.default_connection('user', user_history) With this default in place, if a component has an input named ``user`` and that input is not explicitly connected to a node, then the ``user_history`` node will @@ -148,9 +151,25 @@ code overhead needed to wire common pipelines. .. note:: You cannot directly wire an input another component using only that - component's name; if you only have a name, pass it to :meth:`Pipeline.node` - to obtain the node. This is because it would be impossible to distinguish - between a string component name and a string data value. + component's name; if you only have a name, pass it to + :meth:`PipelineBuilder.node` to obtain the node. This is because it would + be impossible to distinguish between a string component name and a string + data value. + +.. _pipeline-building: + +Building the Pipeline +--------------------- + +Once you have set up the pipeline with the various methods to :class:`PipelineBuilder`, +you can do a couple of things: + +- Call :class:`PipelineBuilder.build` to build a usable :class:`Pipeline`. + The pipeline can then be trained, run, etc. + +- Call :class:`PipelineBuilder.build_config` to build a + :class:`PipelineConfig` that can be serialized and reloaded from JSON, YAML, + or similar formats. .. _pipeline-execution: @@ -298,7 +317,7 @@ The convenience methods are equivalent to the following pipeline code: .. code:: python - pipe = Pipeline() + pipe = PipelineBuilder() # define an input parameter for the user ID (the 'query') query = pipe.create_input('query', ID) # allow candidate items to be optionally specified @@ -319,6 +338,8 @@ The convenience methods are equivalent to the following pipeline code: # rank the items by score recommend = pipe.add_component('ranker', TopNRanker(50), items=score) pipe.alias('recommender', recommend) + pipe.default_component('recommender') + pipe = pipe.build() If we want to also emit rating predictions, with fallback to a baseline model to From 3c147b5cc1b76515eac8fa1a6e6dd1accafd0e12 Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Wed, 15 Jan 2025 21:49:11 -0500 Subject: [PATCH 28/37] fix remaining bits for the pipeline builder --- lenskit/lenskit/batch/_runner.py | 2 +- lenskit/tests/basic/test_bias.py | 10 +++++----- lenskit/tests/basic/test_composite.py | 8 ++++---- 3 files changed, 10 insertions(+), 10 deletions(-) diff --git a/lenskit/lenskit/batch/_runner.py b/lenskit/lenskit/batch/_runner.py index b356e7900..bad81e2b7 100644 --- a/lenskit/lenskit/batch/_runner.py +++ b/lenskit/lenskit/batch/_runner.py @@ -147,7 +147,7 @@ def run( n_users = len(test_data) log = _log.bind( - name=pipeline.name, hash=pipeline.config_hash(), n_queries=n_users, n_jobs=self.n_jobs + name=pipeline.name, hash=pipeline.config_hash, n_queries=n_users, n_jobs=self.n_jobs ) log.info("beginning batch run") diff --git a/lenskit/tests/basic/test_bias.py b/lenskit/tests/basic/test_bias.py index d04facd96..bbfed8f7a 100644 --- a/lenskit/tests/basic/test_bias.py +++ b/lenskit/tests/basic/test_bias.py @@ -18,8 +18,7 @@ from lenskit.data import Dataset, from_interactions_df from lenskit.data.items import ItemList from lenskit.operations import predict, recommend -from lenskit.pipeline import Pipeline -from lenskit.pipeline.common import topn_pipeline +from lenskit.pipeline import Pipeline, PipelineBuilder, topn_pipeline from lenskit.testing import BasicComponentTests, ScorerTests _log = logging.getLogger(__name__) @@ -303,7 +302,7 @@ def test_bias_save(): def test_bias_pipeline(ml_ds: Dataset): - pipe = Pipeline() + pipe = PipelineBuilder() user = pipe.create_input("user", int) items = pipe.create_input("items") @@ -311,6 +310,7 @@ def test_bias_pipeline(ml_ds: Dataset): bias.train(ml_ds) out = pipe.add_component("bias", bias, query=user, items=items) + pipe = pipe.build() res = pipe.run(out, user=2, items=ItemList(item_ids=[10, 11, -1])) assert len(res) == 3 @@ -323,7 +323,7 @@ def test_bias_pipeline(ml_ds: Dataset): def test_bias_topn(ml_ds: Dataset): pipe = topn_pipeline(BiasScorer(), predicts_ratings=True, n=10) - print(pipe.get_config()) + print(pipe.config) pipe.train(ml_ds) res = predict(pipe, 2, ItemList(item_ids=[10, 11, -1])) @@ -338,7 +338,7 @@ def test_bias_topn(ml_ds: Dataset): def test_bias_topn_run_length(ml_ds: Dataset): pipe = topn_pipeline(BiasScorer(), predicts_ratings=True, n=100) - print(pipe.get_config()) + print(pipe.config) pipe.train(ml_ds) res = predict(pipe, 2, items=ItemList(item_ids=[10, 11, -1])) diff --git a/lenskit/tests/basic/test_composite.py b/lenskit/tests/basic/test_composite.py index c19fbb8d6..d7f5ab178 100644 --- a/lenskit/tests/basic/test_composite.py +++ b/lenskit/tests/basic/test_composite.py @@ -20,8 +20,7 @@ from lenskit.data.items import ItemList from lenskit.data.types import ID from lenskit.operations import predict, score -from lenskit.pipeline import Pipeline -from lenskit.pipeline.common import RecPipelineBuilder +from lenskit.pipeline import Pipeline, PipelineBuilder, RecPipelineBuilder from lenskit.testing import BasicComponentTests _log = logging.getLogger(__name__) @@ -32,7 +31,7 @@ class TestFallbackScorer(BasicComponentTests): def test_fallback_fill_missing(ml_ds: Dataset): - pipe = Pipeline() + pipe = PipelineBuilder() user = pipe.create_input("user", int) items = pipe.create_input("items") @@ -44,6 +43,7 @@ def test_fallback_fill_missing(ml_ds: Dataset): fallback = FallbackScorer() score = pipe.add_component("mix", fallback, scores=s1, backup=s2) + pipe = pipe.build() pipe.train(ml_ds) # the first 2 of these are rated, the 3rd does not exist, and the other 2 are not rated @@ -66,7 +66,7 @@ def test_fallback_double_bias(rng: np.random.Generator, ml_ds: Dataset): builder.predicts_ratings(fallback=BiasScorer(damping=0)) pipe = builder.build("double-bias") - _log.info("pipeline configuration: %s", pipe.get_config().model_dump_json(indent=2)) + _log.info("pipeline configuration: %s", pipe.config.model_dump_json(indent=2)) pipe.train(ml_ds) From 20a9efaf14f81ddbec3d46b3a8bc24d622fb2cb0 Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Wed, 15 Jan 2025 21:56:03 -0500 Subject: [PATCH 29/37] doc note --- docs/guide/pipeline.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/docs/guide/pipeline.rst b/docs/guide/pipeline.rst index 0555299fd..c35dd52a7 100644 --- a/docs/guide/pipeline.rst +++ b/docs/guide/pipeline.rst @@ -171,6 +171,9 @@ you can do a couple of things: :class:`PipelineConfig` that can be serialized and reloaded from JSON, YAML, or similar formats. +Building a pipeline resolves default connections, instantiates components from their +configurations, and checks for cycles. + .. _pipeline-execution: Execution From 709ea7d430bd5ee27acda71d7f1f6f4192d0b8d3 Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Wed, 15 Jan 2025 22:00:35 -0500 Subject: [PATCH 30/37] document 2 ways to invoke pipelines --- docs/guide/pipeline.rst | 32 +++++++++++++++++++++++++++++--- 1 file changed, 29 insertions(+), 3 deletions(-) diff --git a/docs/guide/pipeline.rst b/docs/guide/pipeline.rst index c35dd52a7..14d8bd90a 100644 --- a/docs/guide/pipeline.rst +++ b/docs/guide/pipeline.rst @@ -326,11 +326,11 @@ The convenience methods are equivalent to the following pipeline code: # allow candidate items to be optionally specified items = pipe.create_input('items', ItemList, None) # look up a user's history in the training data - history = pipe.add_component('history-lookup', LookupTrainingHistory(), query=query) + history = pipe.add_component('history-lookup', LookupTrainingHistory, query=query) # find candidates from the training data default_candidates = pipe.add_component( 'candidate-selector', - UnratedTrainingItemsCandidateSelector(), + UnratedTrainingItemsCandidateSelector, query=history, ) # if the client provided items as a pipeline input, use those; otherwise @@ -339,7 +339,7 @@ The convenience methods are equivalent to the following pipeline code: # score the candidate items using the specified scorer score = pipe.add_component('scorer', scorer, query=query, items=candidates) # rank the items by score - recommend = pipe.add_component('ranker', TopNRanker(50), items=score) + recommend = pipe.add_component('ranker', TopNRanker, {'n': 50}, items=score) pipe.alias('recommender', recommend) pipe.default_component('recommender') pipe = pipe.build() @@ -464,6 +464,32 @@ Finally, you can directly pass configuration parameters to the component constru See :ref:`conventions` for more conventions for component design. +Adding Components to the Pipeline +--------------------------------- + +You can add components to the pipeline in two ways: + +* Instantiate the component with its configuration options and pass it to + :meth:`PipelineBuilder.add_component`:: + + builder.add_component('component-name', MyComponent(option='value')) + + When you convert the pipeline into + a configuration or clone it, the component will be re-instantiated from its + configuration. + +* Pass the component class and configuration separately to + :meth:`PipelineBuilder.add_component`:: + + builder.add_component('component-name', MyComponent, MyConfig(option='value')) + + Alternatively:: + + builder.add_component('component-name', MyComponent, {'option': 'value'})) + +When you use the second approach, :meth:`PipelineBuilder.build` instantiates the +component from the provided configuration. + POPROX and Other Integrators ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ From e3c645fe9eff35d031ba74f02978cd66374dfbc8 Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Thu, 16 Jan 2025 12:07:00 -0500 Subject: [PATCH 31/37] separate edges from nodes --- lenskit/lenskit/pipeline/_impl.py | 81 ++++++------------------ lenskit/lenskit/pipeline/builder.py | 69 ++++++++------------ lenskit/lenskit/pipeline/config.py | 13 +--- lenskit/lenskit/pipeline/nodes.py | 12 +--- lenskit/lenskit/pipeline/runner.py | 6 +- lenskit/tests/pipeline/test_save_load.py | 12 ++-- 6 files changed, 57 insertions(+), 136 deletions(-) diff --git a/lenskit/lenskit/pipeline/_impl.py b/lenskit/lenskit/pipeline/_impl.py index f86ab119a..84ca3b8a4 100644 --- a/lenskit/lenskit/pipeline/_impl.py +++ b/lenskit/lenskit/pipeline/_impl.py @@ -3,6 +3,7 @@ from collections.abc import Iterable from dataclasses import replace +from typing import Mapping from uuid import NAMESPACE_URL, uuid5 from numpy.random import BitGenerator, Generator, SeedSequence @@ -14,14 +15,10 @@ from lenskit.training import Trainable, TrainingOptions from . import config -from .components import Component from .config import PipelineConfig from .nodes import ( ComponentConstructorNode, ComponentInstanceNode, - ComponentNode, - InputNode, - LiteralNode, Node, ) from .state import PipelineState @@ -63,16 +60,22 @@ class Pipeline: _config: config.PipelineConfig _nodes: dict[str, Node[Any]] + _edges: dict[str, dict[str, str]] _aliases: dict[str, Node[Any]] _default: Node[Any] | None = None _hash: str | None = None - def __init__(self, config: config.PipelineConfig, nodes: Iterable[Node[Any]]): + def __init__( + self, + config: config.PipelineConfig, + nodes: Iterable[Node[Any]], + ): self._nodes = {} for node in nodes: if isinstance(node, ComponentConstructorNode): raise RuntimeError("pipeline is not fully instantiated") self._nodes[node.name] = node + self._edges = {name: cc.inputs for (name, cc) in config.components.items()} self._config = config self._aliases = {} @@ -157,71 +160,23 @@ def node( else: raise KeyError(node) - def clone(self, how: CloneMethod = "config") -> Pipeline: + def node_input_connections(self, node: str | Node[Any]) -> Mapping[str, Node[Any]]: """ - Clone the pipeline, optionally including trained parameters. - - The ``how`` parameter controls how the pipeline is cloned, and what is - available in the clone pipeline. It can be one of the following values: - - ``"config"`` - Create fresh component instances using the configurations of the - components in this pipeline. When applied to a trained pipeline, - the clone does **not** have the original's learned parameters. This - is the default clone method. - ``"pipeline-config"`` - Round-trip the entire pipeline through :meth:`get_config` and - :meth:`from_config`. + Get the input wirings for a node. + """ + node = self.node(node) + edges = self._edges.get(node.name, {}) + return {name: self.node(src) for (name, src) in edges.items()} - Args: - how: - The mechanism to use for cloning the pipeline. + def clone(self) -> Pipeline: + """ + Clone the pipeline, **without** its trained parameters. Returns: A new pipeline with the same components and wiring, but fresh instances created by round-tripping the configuration. """ - from .builder import PipelineBuilder - - if how == "pipeline-config": - return self.from_config(self._config) - elif how != "config": # pragma: nocover - raise NotImplementedError("only 'config' cloning is currently supported") - - clone = PipelineBuilder() - - for node in self.nodes(): - match node: - case InputNode(name, types=types): - if types is None: - types = set[type]() - clone.create_input(name, *types) - case LiteralNode(name, value): - clone.literal(value, name=name) - case ComponentInstanceNode(name, comp): - config = None - if isinstance(comp, Component): - config = comp.config - comp = comp.__class__ # type: ignore - clone.add_component(name, comp, config) # type: ignore - case _: # pragma: nocover - raise RuntimeError(f"invalid node {node}") - - for n, t in self._aliases.items(): - clone.alias(n, t.name) - - for node in self.nodes(): - match node: - case ComponentNode(name, connections=cxns): - cn = clone.node(name) - clone.connect(cn, **{wt: clone.node(wn) for (wt, wn) in cxns.items()}) - case _: - pass - - if self._default: - clone.default_component(self._default.name) - - return clone.build() + return self.from_config(self._config) @property def config_hash(self) -> str: diff --git a/lenskit/lenskit/pipeline/builder.py b/lenskit/lenskit/pipeline/builder.py index f601d9880..8e07f183a 100644 --- a/lenskit/lenskit/pipeline/builder.py +++ b/lenskit/lenskit/pipeline/builder.py @@ -7,8 +7,9 @@ import typing import warnings +from copy import deepcopy from types import UnionType -from uuid import NAMESPACE_URL, uuid4, uuid5 +from uuid import NAMESPACE_URL, uuid5 from typing_extensions import Any, Literal, Self, TypeVar, cast, overload @@ -83,21 +84,20 @@ class PipelineBuilder: """ _nodes: dict[str, Node[Any]] + _edges: dict[str, dict[str, str]] _aliases: dict[str, Node[Any]] _default_connections: dict[str, str] _components: dict[str, PipelineFunction[Any] | Component[Any]] _default: str | None = None - _anon_nodes: set[str] - "Track generated node names." def __init__(self, name: str | None = None, version: str | None = None): self.name = name self.version = version self._nodes = {} + self._edges = {} self._aliases = {} self._default_connections = {} self._components = {} - self._anon_nodes = set() def meta(self, *, include_hash: bool = True) -> config.PipelineMeta: """ @@ -201,8 +201,8 @@ def literal(self, value: T, *, name: str | None = None) -> LiteralNode[T]: :meth:`save_config`. """ if name is None: - name = str(uuid4()) - self._anon_nodes.add(name) + lit = config.PipelineLiteral.represent(value) + name = str(uuid5(NAMESPACE_LITERAL_DATA, lit.model_dump_json())) node = LiteralNode(name, value, types=set([type(value)])) self._nodes[name] = node return node @@ -378,14 +378,18 @@ def connect(self, obj: str | Node[Any], **inputs: Node[Any] | str | object): if not isinstance(node, ComponentNode): raise TypeError(f"only component nodes can be wired, not {node}") + edges = self._edges.get(node.name, None) + if edges is None: + self._edges[node.name] = edges = {} + for k, n in inputs.items(): if isinstance(n, Node): n = cast(Node[Any], n) self._check_member_node(n) - node.connections[k] = n.name + edges[k] = n.name else: lit = self.literal(n) - node.connections[k] = lit.name + edges[k] = lit.name def clone(self) -> PipelineBuilder: """ @@ -416,7 +420,8 @@ def clone(self) -> PipelineBuilder: for node in self.nodes(): match node: - case ComponentNode(name, connections=wiring): + case ComponentNode(name): + wiring = self._edges.get(name, {}) cn = clone.node(name) clone.connect(cn, **{wn: clone.node(wt) for (wn, wt) in wiring.items()}) case _: @@ -447,46 +452,25 @@ def build_config(self, *, include_hash: bool = True) -> PipelineConfig: meta = self.meta(include_hash=False) cfg = PipelineConfig(meta=meta) - # FIXME: don't mutate + edges = deepcopy(self._edges) for node in self._nodes.values(): if isinstance(node, ComponentNode): + c_ins = edges[node.name] for iname in node.inputs.keys(): - if iname not in node.connections and iname in self._default_connections: - node.connections[iname] = self._default_connections[iname] - - # We map anonymous nodes to hash-based names for stability. If we ever - # allow anonymous components, this will need to be adjusted to maintain - # component ordering, but it works for now since only literals can be - # anonymous. First handle the anonymous nodes, so we have that mapping: - remapped: dict[str, str] = {} - for an in self._anon_nodes: - node = self._nodes.get(an, None) - match node: - case None: - # skip nodes that no longer exist - continue - case LiteralNode(name, value): - lit = config.PipelineLiteral.represent(value) - sname = str(uuid5(NAMESPACE_LITERAL_DATA, lit.model_dump_json())) - _log.debug("renamed anonymous node %s to %s", name, sname) - remapped[name] = sname - cfg.literals[sname] = lit - case _: - # the pipeline only generates anonymous literal nodes right now - raise RuntimeError(f"unexpected anonymous node {node}") + if iname not in c_ins and iname in self._default_connections: + c_ins[iname] = self._default_connections[iname] # Now we go over all named nodes and add them to the config: for node in self.nodes(): - if node.name in remapped: - continue - match node: case InputNode(): cfg.inputs.append(config.PipelineInput.from_node(node)) case LiteralNode(name, value): cfg.literals[name] = config.PipelineLiteral.represent(value) case ComponentNode(name): - cfg.components[name] = config.PipelineComponent.from_node(node, remapped) + c_cfg = config.PipelineComponent.from_node(node) + c_cfg.inputs = edges.get(name, {}).copy() + cfg.components[name] = c_cfg case _: # pragma: nocover raise RuntimeError(f"invalid node {node}") @@ -570,11 +554,8 @@ def from_config(cls, config: object) -> Self: # pass 3: wiring for name, comp in cfg.components.items(): - if isinstance(comp.inputs, dict): - inputs = {n: builder.node(t) for (n, t) in comp.inputs.items()} - builder.connect(name, **inputs) - elif not comp.code.startswith("@"): - raise PipelineError(f"component {name} inputs must be dict, not list") + inputs = {n: builder.node(t) for (n, t) in comp.inputs.items()} + builder.connect(name, **inputs) # pass 4: aliases for n, t in cfg.aliases.items(): @@ -651,10 +632,10 @@ def build(self) -> Pipeline: def _instantiate(self, node: Node[ND]) -> Node[ND]: match node: - case ComponentConstructorNode(name, constructor, config, connections=cxns): + case ComponentConstructorNode(name, constructor, config): _log.debug("instantiating component", component=constructor) instance = constructor(config) - return ComponentInstanceNode(name, instance, cxns) + return ComponentInstanceNode(name, instance) case _: return node diff --git a/lenskit/lenskit/pipeline/config.py b/lenskit/lenskit/pipeline/config.py index 187c9d8a3..03f7bf7ef 100644 --- a/lenskit/lenskit/pipeline/config.py +++ b/lenskit/lenskit/pipeline/config.py @@ -109,17 +109,14 @@ class PipelineComponent(BaseModel): with its default constructor parameters. """ - inputs: dict[str, str] | list[str] = Field(default_factory=dict) + inputs: dict[str, str] = Field(default_factory=dict) """ The component's input wirings, mapping input names to node names. For certain meta-nodes, it is specified as a list instead of a dict. """ @classmethod - def from_node(cls, node: ComponentNode[Any], mapping: dict[str, str] | None = None) -> Self: - if mapping is None: - mapping = {} - + def from_node(cls, node: ComponentNode[Any]) -> Self: match node: case ComponentInstanceNode(_name, comp): config = None @@ -136,11 +133,7 @@ def from_node(cls, node: ComponentNode[Any], mapping: dict[str, str] | None = No code = f"{ctype.__module__}:{ctype.__qualname__}" - return cls( - code=code, - config=config, - inputs={n: mapping.get(t, t) for (n, t) in node.connections.items()}, - ) + return cls(code=code, config=config) class PipelineLiteral(BaseModel): diff --git a/lenskit/lenskit/pipeline/nodes.py b/lenskit/lenskit/pipeline/nodes.py index 23c4ce21c..7d51f5b54 100644 --- a/lenskit/lenskit/pipeline/nodes.py +++ b/lenskit/lenskit/pipeline/nodes.py @@ -86,14 +86,8 @@ class ComponentNode(Node[ND], Generic[ND]): Internal """ - __match_args__ = ("name", "connections") - - connections: dict[str, str] - "The component's input connections." - def __init__(self, name: str): super().__init__(name) - self.connections = {} @staticmethod def create( @@ -118,7 +112,7 @@ def inputs(self) -> dict[str, type | None]: # pragma: nocover class ComponentConstructorNode(ComponentNode[ND], Generic[ND]): - __match_args__ = ("name", "constructor", "config", "connections") + __match_args__ = ("name", "constructor", "config") constructor: ComponentConstructor[Any, ND] config: object | None @@ -135,7 +129,7 @@ def inputs(self): class ComponentInstanceNode(ComponentNode[ND], Generic[ND]): - __match_args__ = ("name", "component", "connections") + __match_args__ = ("name", "component") component: Component[ND] | PipelineFunction[ND] @@ -143,11 +137,9 @@ def __init__( self, name: str, component: Component[ND] | PipelineFunction[ND], - connections: dict[str, str] | None = None, ): super().__init__(name) self.component = component - self.connections = connections or {} if rt := component_return_type(component): self.types = {rt} diff --git a/lenskit/lenskit/pipeline/runner.py b/lenskit/lenskit/pipeline/runner.py index bc9381645..f4465cc4e 100644 --- a/lenskit/lenskit/pipeline/runner.py +++ b/lenskit/lenskit/pipeline/runner.py @@ -88,8 +88,8 @@ def _run_node(self, node: Node[Any], required: bool) -> None: self.state[name] = value case InputNode(name, types=types): self._inject_input(name, types, required) - case ComponentInstanceNode(name, comp, wiring): - self._run_component(name, comp, wiring, required) + case ComponentInstanceNode(name, comp): + self._run_component(name, comp, required) case _: # pragma: nocover raise PipelineError(f"invalid node {node}") @@ -108,13 +108,13 @@ def _run_component( self, name: str, comp: PipelineFunction[Any], - wiring: dict[str, str], required: bool, ) -> None: in_data = {} log = self.log.bind(node=name) trace(log, "processing inputs") inputs = component_inputs(comp, warn_on_missing=False) + wiring = self.pipe.node_input_connections(name) for iname, itype in inputs.items(): ilog = log.bind(input_name=iname, input_type=itype) trace(ilog, "resolving input") diff --git a/lenskit/tests/pipeline/test_save_load.py b/lenskit/tests/pipeline/test_save_load.py index a1f1a2d21..328b2e588 100644 --- a/lenskit/tests/pipeline/test_save_load.py +++ b/lenskit/tests/pipeline/test_save_load.py @@ -131,7 +131,7 @@ def test_round_trip_single_node(): r2 = p2.node("return") assert isinstance(r2, ComponentInstanceNode) assert r2.component is msg_ident - assert r2.connections == {"msg": "msg"} + assert p2._edges["return"] == {"msg": "msg"} p2 = p2.build() assert p2.run("return", msg="foo") == "foo" @@ -153,7 +153,7 @@ def test_configurable_component(): assert isinstance(r2, ComponentInstanceNode) assert isinstance(r2.component, Prefixer) assert r2.component is not pfx - assert r2.connections == {"msg": "msg"} + assert p2._edges["prefix"] == {"msg": "msg"} p2 = p2.build() assert p2.run("prefix", msg="HACKEM MUCHE") == "scroll named HACKEM MUCHE" @@ -268,7 +268,7 @@ def test_alias_node(): pipe = pipe.build() assert pipe.run("result", a=5, b=7) == 17 - p2 = pipe.clone("pipeline-config") + p2 = pipe.clone() assert p2.run("result", a=5, b=7) == 17 @@ -283,7 +283,7 @@ def test_literal(): assert pipe.run(msg="HACKEM MUCHE") == "hello, HACKEM MUCHE" print(pipe.config.model_dump_json(indent=2)) - p2 = pipe.clone("pipeline-config") + p2 = pipe.clone() assert p2.run(msg="FOOBIE BLETCH") == "hello, FOOBIE BLETCH" @@ -294,12 +294,12 @@ def test_literal_array(): pipe.add_component("add", add, x=np.arange(10), y=a) pipe.default_component("add") + print("pipeline:", pipe.build_config().model_dump_json(indent=2)) pipe = pipe.build() res = pipe.run(a=5) assert np.all(res == np.arange(5, 15)) - print(pipe.config.model_dump_json(indent=2)) - p2 = pipe.clone("pipeline-config") + p2 = pipe.clone() assert np.all(p2.run(a=5) == np.arange(5, 15)) From 0c4e2d8bdf021b98b2116e6207572e26451893be Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Thu, 16 Jan 2025 12:19:42 -0500 Subject: [PATCH 32/37] test for cycles at build time --- lenskit/lenskit/pipeline/builder.py | 17 +++++++++++++++++ lenskit/tests/pipeline/test_pipeline.py | 1 - 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/lenskit/lenskit/pipeline/builder.py b/lenskit/lenskit/pipeline/builder.py index 8e07f183a..308aa5ca0 100644 --- a/lenskit/lenskit/pipeline/builder.py +++ b/lenskit/lenskit/pipeline/builder.py @@ -8,6 +8,7 @@ import typing import warnings from copy import deepcopy +from graphlib import CycleError, TopologicalSorter from types import UnionType from uuid import NAMESPACE_URL, uuid5 @@ -391,6 +392,20 @@ def connect(self, obj: str | Node[Any], **inputs: Node[Any] | str | object): lit = self.literal(n) edges[k] = lit.name + def validate(self): + """ + Check the built pipeline for errors. + """ + + # Check for cycles + graph = {n: set(w.values()) for (n, w) in self._edges.items()} + print(graph) + ts = TopologicalSorter(graph) + try: + ts.prepare() + except CycleError as e: + raise PipelineError("pipeline has cycles") from e + def clone(self) -> PipelineBuilder: """ Clone the pipeline builder. The resulting builder starts as a copy of @@ -449,6 +464,8 @@ def build_config(self, *, include_hash: bool = True) -> PipelineConfig: inputs) cannot be serialized, and this method will fail if they are present in the pipeline. """ + self.validate() + meta = self.meta(include_hash=False) cfg = PipelineConfig(meta=meta) diff --git a/lenskit/tests/pipeline/test_pipeline.py b/lenskit/tests/pipeline/test_pipeline.py index f17ece2ef..0f0aa9394 100644 --- a/lenskit/tests/pipeline/test_pipeline.py +++ b/lenskit/tests/pipeline/test_pipeline.py @@ -239,7 +239,6 @@ def add(x: int, y: int) -> int: assert pipe.run(nd, a=3, b=7) == 6 -@mark.xfail(reason="cycle detection not yet implemented") def test_cycle(): pipe = PipelineBuilder() b = pipe.create_input("b", int) From a1b62816a495de72030ac867d9574aa31edf3a3b Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Thu, 16 Jan 2025 12:20:23 -0500 Subject: [PATCH 33/37] clean up component warnings --- lenskit/lenskit/basic/history.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/lenskit/lenskit/basic/history.py b/lenskit/lenskit/basic/history.py index 962e190d7..0b8e3b4ff 100644 --- a/lenskit/lenskit/basic/history.py +++ b/lenskit/lenskit/basic/history.py @@ -29,6 +29,7 @@ class UserTrainingHistoryLookup(Component[ItemList], Trainable): Caller """ + config: None training_data_: Dataset @override @@ -72,6 +73,7 @@ class KnownRatingScorer(Component[ItemList], Trainable): in the query as the source of score data. """ + config: None score: Literal["rating", "indicator"] | None source: Literal["training", "query"] From a943644e1ca88ccf03e7848c619a28ffc6cac44b Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Thu, 16 Jan 2025 12:23:35 -0500 Subject: [PATCH 34/37] sort aliases and inputs for deterministic configuration --- lenskit/lenskit/pipeline/builder.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/lenskit/lenskit/pipeline/builder.py b/lenskit/lenskit/pipeline/builder.py index 308aa5ca0..8648c0acf 100644 --- a/lenskit/lenskit/pipeline/builder.py +++ b/lenskit/lenskit/pipeline/builder.py @@ -486,12 +486,12 @@ def build_config(self, *, include_hash: bool = True) -> PipelineConfig: cfg.literals[name] = config.PipelineLiteral.represent(value) case ComponentNode(name): c_cfg = config.PipelineComponent.from_node(node) - c_cfg.inputs = edges.get(name, {}).copy() + c_cfg.inputs = dict(sorted(edges.get(name, {}).items(), key=lambda kv: kv[0])) cfg.components[name] = c_cfg case _: # pragma: nocover raise RuntimeError(f"invalid node {node}") - cfg.aliases = {a: t.name for (a, t) in self._aliases.items()} + cfg.aliases = {a: t.name for (a, t) in sorted(self._aliases.items(), key=lambda kv: kv[0])} if self._default: cfg.default = self._default @@ -584,7 +584,9 @@ def from_config(cls, config: object) -> Self: h2 = builder.config_hash() if h2 != cfg.meta.hash: _log.warning("loaded pipeline does not match hash") - warnings.warn("loaded pipeline config does not match hash", PipelineWarning) + warnings.warn( + "loaded pipeline config does not match hash", PipelineWarning, stacklevel=2 + ) return builder From 735c745f602af7495f8e75e44584dfc842844973 Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Thu, 16 Jan 2025 12:25:40 -0500 Subject: [PATCH 35/37] edge case on missing edgs --- lenskit/lenskit/pipeline/builder.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/lenskit/lenskit/pipeline/builder.py b/lenskit/lenskit/pipeline/builder.py index 8648c0acf..39c07c9e8 100644 --- a/lenskit/lenskit/pipeline/builder.py +++ b/lenskit/lenskit/pipeline/builder.py @@ -472,7 +472,9 @@ def build_config(self, *, include_hash: bool = True) -> PipelineConfig: edges = deepcopy(self._edges) for node in self._nodes.values(): if isinstance(node, ComponentNode): - c_ins = edges[node.name] + c_ins = edges.get(node.name, None) + if c_ins is None: + edges[node.name] = c_ins = {} for iname in node.inputs.keys(): if iname not in c_ins and iname in self._default_connections: c_ins[iname] = self._default_connections[iname] From 84a20fec3f026d20eeb80cb886c9bde2a77e7ccb Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Thu, 16 Jan 2025 12:32:18 -0500 Subject: [PATCH 36/37] adsd builder.clone tests --- lenskit/lenskit/pipeline/builder.py | 2 + lenskit/tests/pipeline/test_pipeline_clone.py | 95 +++++++++++-------- 2 files changed, 57 insertions(+), 40 deletions(-) diff --git a/lenskit/lenskit/pipeline/builder.py b/lenskit/lenskit/pipeline/builder.py index 39c07c9e8..28d623fec 100644 --- a/lenskit/lenskit/pipeline/builder.py +++ b/lenskit/lenskit/pipeline/builder.py @@ -445,6 +445,8 @@ def clone(self) -> PipelineBuilder: for n, t in self._default_connections.items(): clone.default_connection(n, clone.node(t)) + clone._default = self._default + return clone def build_config(self, *, include_hash: bool = True) -> PipelineConfig: diff --git a/lenskit/tests/pipeline/test_pipeline_clone.py b/lenskit/tests/pipeline/test_pipeline_clone.py index c66d0ec3e..e34b877aa 100644 --- a/lenskit/tests/pipeline/test_pipeline_clone.py +++ b/lenskit/tests/pipeline/test_pipeline_clone.py @@ -6,9 +6,11 @@ # pyright: strict from dataclasses import dataclass +from typing import Literal -from lenskit.pipeline import PipelineBuilder -from lenskit.pipeline.components import Component +from pytest import mark + +from lenskit.pipeline import Component, Pipeline, PipelineBuilder from lenskit.pipeline.nodes import ComponentInstanceNode @@ -35,17 +37,26 @@ def exclaim(msg: str) -> str: return msg + "!" -def test_pipeline_clone(): - pipe = PipelineBuilder() - msg = pipe.create_input("msg", str) - pipe.add_component("prefix", Prefixer, PrefixConfig(prefix="scroll named "), msg=msg) - pipe.default_component("prefix") +def _clone(builder: PipelineBuilder, pipe: Pipeline, what: Literal["pipe", "builder"]) -> Pipeline: + match what: + case "pipe": + return pipe.clone() + case "builder": + return builder.clone().build() - pipe = pipe.build() + +@mark.parametrize("what", ["pipe", "builder"]) +def test_pipeline_clone(what: Literal["pipe", "builder"]): + builder = PipelineBuilder() + msg = builder.create_input("msg", str) + builder.add_component("prefix", Prefixer, PrefixConfig(prefix="scroll named "), msg=msg) + builder.default_component("prefix") + + pipe = builder.build() assert pipe.run(msg="FOOBIE BLETCH") == "scroll named FOOBIE BLETCH" comp = pipe.node("prefix").component # type: ignore - p2 = pipe.clone() + p2 = _clone(builder, pipe, what) n2 = p2.node("prefix") assert isinstance(n2, ComponentInstanceNode) assert isinstance(n2.component, Prefixer) @@ -55,65 +66,69 @@ def test_pipeline_clone(): assert p2.run(msg="HACKEM MUCHE") == "scroll named HACKEM MUCHE" -def test_pipeline_clone_with_function(): +@mark.parametrize("what", ["pipe", "builder"]) +def test_pipeline_clone_with_function(what: Literal["pipe", "builder"]): comp = Prefixer(prefix="scroll named ") - pipe = PipelineBuilder() - msg = pipe.create_input("msg", str) - pfx = pipe.add_component("prefix", comp, msg=msg) - pipe.add_component("exclaim", exclaim, msg=pfx) - pipe.default_component("exclaim") + builder = PipelineBuilder() + msg = builder.create_input("msg", str) + pfx = builder.add_component("prefix", comp, msg=msg) + builder.add_component("exclaim", exclaim, msg=pfx) + builder.default_component("exclaim") - pipe = pipe.build() + pipe = builder.build() assert pipe.run(msg="FOOBIE BLETCH") == "scroll named FOOBIE BLETCH!" - p2 = pipe.clone() + p2 = _clone(builder, pipe, what) assert p2.run(msg="HACKEM MUCHE") == "scroll named HACKEM MUCHE!" -def test_pipeline_clone_with_nonconfig_class(): +@mark.parametrize("what", ["pipe", "builder"]) +def test_pipeline_clone_with_nonconfig_class(what: Literal["pipe", "builder"]): comp = Prefixer(prefix="scroll named ") - pipe = PipelineBuilder() - msg = pipe.create_input("msg", str) - pfx = pipe.add_component("prefix", comp, msg=msg) - pipe.add_component("question", Question(), msg=pfx) - pipe.default_component("question") + builder = PipelineBuilder() + msg = builder.create_input("msg", str) + pfx = builder.add_component("prefix", comp, msg=msg) + builder.add_component("question", Question(), msg=pfx) + builder.default_component("question") - pipe = pipe.build() + pipe = builder.build() assert pipe.run(msg="FOOBIE BLETCH") == "scroll named FOOBIE BLETCH?" - p2 = pipe.clone() + p2 = _clone(builder, pipe, what) assert p2.run(msg="HACKEM MUCHE") == "scroll named HACKEM MUCHE?" -def test_clone_defaults(): - pipe = PipelineBuilder() - msg = pipe.create_input("msg", str) - pipe.default_connection("msg", msg) - pipe.add_component("return", exclaim) - pipe.default_component("return") +@mark.parametrize("what", ["pipe", "builder"]) +def test_clone_defaults(what: Literal["pipe", "builder"]): + builder = PipelineBuilder() + msg = builder.create_input("msg", str) + builder.default_connection("msg", msg) + builder.add_component("return", exclaim) + builder.default_component("return") - pipe = pipe.build() + pipe = builder.build() assert pipe.run(msg="hello") == "hello!" - p2 = pipe.clone() + p2 = _clone(builder, pipe, what) assert p2.run(msg="hello") == "hello!" -def test_clone_alias(): - pipe = PipelineBuilder() - msg = pipe.create_input("msg", str) - excl = pipe.add_component("exclaim", exclaim, msg=msg) - pipe.alias("return", excl) +@mark.parametrize("what", ["pipe", "builder"]) +def test_clone_alias(what: Literal["pipe", "builder"]): + builder = PipelineBuilder() + msg = builder.create_input("msg", str) + excl = builder.add_component("exclaim", exclaim, msg=msg) + builder.alias("return", excl) - pipe = pipe.build() + pipe = builder.build() assert pipe.run("return", msg="hello") == "hello!" - p2 = pipe.clone() + p2 = _clone(builder, pipe, what) assert p2.run("return", msg="hello") == "hello!" From 6eab439afb01817efc3c7b8da2b697ac96764540 Mon Sep 17 00:00:00 2001 From: Michael Ekstrand Date: Thu, 16 Jan 2025 12:35:22 -0500 Subject: [PATCH 37/37] remove stray print --- lenskit/lenskit/pipeline/builder.py | 1 - 1 file changed, 1 deletion(-) diff --git a/lenskit/lenskit/pipeline/builder.py b/lenskit/lenskit/pipeline/builder.py index 28d623fec..253f53174 100644 --- a/lenskit/lenskit/pipeline/builder.py +++ b/lenskit/lenskit/pipeline/builder.py @@ -399,7 +399,6 @@ def validate(self): # Check for cycles graph = {n: set(w.values()) for (n, w) in self._edges.items()} - print(graph) ts = TopologicalSorter(graph) try: ts.prepare()