diff --git a/docs/pipeline.rst b/docs/pipeline.rst index f519d977b..89604d5be 100644 --- a/docs/pipeline.rst +++ b/docs/pipeline.rst @@ -172,7 +172,7 @@ method takes two types of inputs: altered scores). If no components are specified, it is the same as specifying the last - component added to the pipeline. + component that was added to the pipeline. * Keyword arguments specifying the values for the pipeline's inputs, as defined by calls to :meth:`create_input`. diff --git a/lenskit/lenskit/pipeline/__init__.py b/lenskit/lenskit/pipeline/__init__.py index 15bdcd192..1f3a20373 100644 --- a/lenskit/lenskit/pipeline/__init__.py +++ b/lenskit/lenskit/pipeline/__init__.py @@ -15,7 +15,7 @@ import warnings from types import FunctionType from typing import Literal, cast -from uuid import uuid4 +from uuid import NAMESPACE_URL, uuid4, uuid5 from typing_extensions import Any, Self, TypeAlias, TypeVar, overload @@ -29,7 +29,14 @@ TrainableComponent, instantiate_component, ) -from .config import PipelineComponent, PipelineConfig, PipelineInput, PipelineMeta, hash_config +from .config import ( + PipelineComponent, + PipelineConfig, + PipelineInput, + PipelineLiteral, + PipelineMeta, + hash_config, +) from .nodes import ND, ComponentNode, FallbackNode, InputNode, LiteralNode, Node from .state import PipelineState @@ -54,7 +61,9 @@ T3 = TypeVar("T3") T4 = TypeVar("T4") T5 = TypeVar("T5") -CloneMethod: TypeAlias = Literal["config"] +CloneMethod: TypeAlias = Literal["config", "pipeline-config"] + +NAMESPACE_LITERAL_DATA = uuid5(NAMESPACE_URL, "https://ns.lenskit.org/literal-data/") class PipelineError(Exception): @@ -108,6 +117,9 @@ class Pipeline: _defaults: dict[str, Node[Any]] _components: dict[str, Component[Any]] _hash: str | None = None + _last: Node[Any] | None = None + _anon_nodes: set[str] + "Track generated node names." def __init__(self, name: str | None = None, version: str | None = None): self.name = name @@ -116,22 +128,19 @@ def __init__(self, name: str | None = None, version: str | None = None): self._aliases = {} self._defaults = {} self._components = {} + self._anon_nodes = set() self._clear_caches() - def meta(self, *, include_hash: bool | None = None) -> PipelineMeta: + def meta(self, *, include_hash: bool = True) -> PipelineMeta: """ Get the metadata (name, version, hash, etc.) for this pipeline without returning the whole config. Args: include_hash: - Whether to include a configuration hash in the metadata. If - ``None`` (the default), the metadata includes a hash if there - are no :meth:`literal` nodes in the pipeline. + Whether to include a configuration hash in the metadata. """ meta = PipelineMeta(name=self.name, version=self.version) - if include_hash is None: - include_hash = not any(isinstance(n, LiteralNode) for n in self.nodes) if include_hash: meta.hash = self.config_hash() return meta @@ -208,7 +217,7 @@ def create_input(self, name: str, *types: type[T] | None) -> Node[T]: self._clear_caches() return node - def literal(self, value: T) -> LiteralNode[T]: + def literal(self, value: T, *, name: str | None = None) -> LiteralNode[T]: """ Create a literal node (a node with a fixed value). @@ -216,7 +225,9 @@ def literal(self, value: T) -> LiteralNode[T]: Literal nodes cannot be serialized witih :meth:`get_config` or :meth:`save_config`. """ - name = str(uuid4()) + if name is None: + name = str(uuid4()) + self._anon_nodes.add(name) node = LiteralNode(name, value, types=set([type(value)])) self._nodes[name] = node self._clear_caches() @@ -296,6 +307,7 @@ def add_component( self.connect(node, **inputs) self._clear_caches() + self._last = node return node def replace_component( @@ -427,11 +439,16 @@ def clone(self, how: CloneMethod = "config") -> Pipeline: Clone the pipeline, optionally including trained parameters. The ``how`` parameter controls how the pipeline is cloned, and what is - available in the clone pipeline. Currently only ``"config"`` is - supported, which creates fresh component instances using the - configurations of the components in this pipeline. When applied to a - trained pipeline, the clone does **not** have the original's learned - parameters. + available in the clone pipeline. It can be one of the following values: + + ``"config"`` + Create fresh component instances using the configurations of the + components in this pipeline. When applied to a trained pipeline, + the clone does **not** have the original's learned parameters. This + is the default clone method. + ``"pipeline-config"`` + Round-trip the entire pipeline through :meth:`get_config` and + :meth:`from_config`. Args: how: @@ -441,10 +458,14 @@ def clone(self, how: CloneMethod = "config") -> Pipeline: A new pipeline with the same components and wiring, but fresh instances created by round-tripping the configuration. """ - if how != "config": # pragma: nocover + if how == "pipeline-config": + cfg = self.get_config() + return self.from_config(cfg) + elif how != "config": # pragma: nocover raise NotImplementedError("only 'config' cloning is currently supported") clone = Pipeline() + for node in self.nodes: match node: case InputNode(name, types=types): @@ -495,17 +516,44 @@ def get_config(self, *, include_hash: bool = True) -> PipelineConfig: """ meta = self.meta(include_hash=False) config = PipelineConfig(meta=meta) + + # We map anonymous nodes to hash-based names for stability. If we ever + # allow anonymous components, this will need to be adjusted to maintain + # component ordering, but it works for now since only literals can be + # anonymous. First handle the anonymous nodes, so we have that mapping: + remapped: dict[str, str] = {} + for an in self._anon_nodes: + node = self._nodes.get(an, None) + match node: + case None: + # skip nodes that no longer exist + continue + case LiteralNode(name, value): + cfg = PipelineLiteral.represent(value) + sname = str(uuid5(NAMESPACE_LITERAL_DATA, cfg.model_dump_json())) + _log.debug("renamed anonymous node %s to %s", name, sname) + remapped[name] = sname + config.literals[sname] = cfg + case _: + # the pipeline only generates anonymous literal nodes right now + raise RuntimeError(f"unexpected anonymous node {node}") + + # Now we go over all named nodes and add them to the config: for node in self.nodes: + if node.name in remapped: + continue + match node: case InputNode(): config.inputs.append(PipelineInput.from_node(node)) - case LiteralNode(): - raise RuntimeError("literal nodes cannot be serialized to config") + case LiteralNode(name, value): + config.literals[name] = PipelineLiteral.represent(value) case ComponentNode(name): - config.components[name] = PipelineComponent.from_node(node) + config.components[name] = PipelineComponent.from_node(node, remapped) case FallbackNode(name, alternatives): config.components[name] = PipelineComponent( - code="@use-first-of", inputs=[n.name for n in alternatives] + code="@use-first-of", + inputs=[remapped.get(n.name, n.name) for n in alternatives], ) case _: # pragma: nocover raise RuntimeError(f"invalid node {node}") @@ -523,12 +571,16 @@ def config_hash(self) -> str: Get a hash of the pipeline's configuration to uniquely identify it for logging, version control, or other purposes. - The precise algorithm to compute the hash is not guaranteed, except that - the same configuration with the same version of LensKit and its - dependencies will produce the same hash. In LensKit 2024.1, the - configuration hash is computed by computing the JSON serialization of - the pipeline configuration *without* a hash returning the hex-encoded - SHA256 hash of that configuration. + The hash format and algorithm are not guaranteed, but is stable within a + LensKit version. For the same version of LensKit and component code, + the same configuration will produce the same hash, so long as there are + no literal nodes. Literal nodes will *usually* hash consistently, but + since literals other than basic JSON values are hashed by pickling, hash + stability depends on the stability of the pickle bytestream. + + In LensKit 2024.1, the configuration hash is computed by computing the + JSON serialization of the pipeline configuration *without* a hash and + returning the hex-encoded SHA256 hash of that configuration. """ if self._hash is None: # get the config *without* a hash @@ -550,7 +602,11 @@ def from_config(cls, config: object) -> Self: # that nodes are available before they are wired (since `connect` can # introduce out-of-order dependencies). - # pass 1: add components + # pass 1: add literals + for name, data in cfg.literals.items(): + pipe.literal(data.decode(), name=name) + + # pass 2: add components to_wire: list[PipelineComponent] = [] for name, comp in cfg.components.items(): if comp.code.startswith("@"): @@ -561,7 +617,7 @@ def from_config(cls, config: object) -> Self: pipe.add_component(name, obj) to_wire.append(comp) - # pass 2: add meta nodes + # pass 3: add meta nodes for name, comp in cfg.components.items(): if comp.code == "@use-first-of": if not isinstance(comp.inputs, list): @@ -570,7 +626,7 @@ def from_config(cls, config: object) -> Self: elif comp.code.startswith("@"): raise PipelineError(f"unsupported meta-component {comp.code}") - # pass 3: wiring + # pass 4: wiring for name, comp in cfg.components.items(): if isinstance(comp.inputs, dict): inputs = {n: pipe.node(t) for (n, t) in comp.inputs.items()} @@ -578,11 +634,11 @@ def from_config(cls, config: object) -> Self: elif not comp.code.startswith("@"): raise PipelineError(f"component {name} inputs must be dict, not list") - # pass 4: aliases + # pass 5: aliases for n, t in cfg.aliases.items(): pipe.alias(n, t) - # pass 5: defaults + # pass 6: defaults for n, t in cfg.defaults.items(): pipe.set_default(n, pipe.node(t)) @@ -641,9 +697,6 @@ def run(self, *nodes: str | Node[Any], **kwargs: object) -> object: components. See :ref:`pipeline-execution` for details of the pipeline execution model. - .. todo:: - Add cycle detection. - Args: nodes: The component(s) to run. @@ -656,6 +709,8 @@ def run(self, *nodes: str | Node[Any], **kwargs: object) -> object: are returned in a tuple. Raises: + PipelineError: + when there is a pipeline configuration error (e.g. a cycle). ValueError: when one or more required inputs are missing. TypeError: @@ -664,7 +719,9 @@ def run(self, *nodes: str | Node[Any], **kwargs: object) -> object: exceptions thrown by components are passed through. """ if not nodes: - nodes = (self._last_node(),) + if self._last is None: # pragma: nocover + raise PipelineError("pipeline has no components") + nodes = (self._last,) state = self.run_all(*nodes, **kwargs) results = [state[self.node(n).name] for n in nodes] @@ -719,11 +776,6 @@ def run_all(self, *nodes: str | Node[Any], **kwargs: object) -> PipelineState: meta=self.meta(), ) - def _last_node(self) -> Node[object]: - if not self._nodes: - raise RuntimeError("pipeline is empty") - return list(self._nodes.values())[-1] - def _check_available_name(self, name: str) -> None: if name in self._nodes or name in self._aliases: raise ValueError(f"pipeline already has node {name}") diff --git a/lenskit/lenskit/pipeline/config.py b/lenskit/lenskit/pipeline/config.py index 5adaebedd..073375ddc 100644 --- a/lenskit/lenskit/pipeline/config.py +++ b/lenskit/lenskit/pipeline/config.py @@ -5,11 +5,14 @@ # pyright: strict from __future__ import annotations +import base64 +import pickle from collections import OrderedDict from hashlib import sha256 from types import FunctionType +from typing import Literal -from pydantic import BaseModel, Field +from pydantic import BaseModel, Field, JsonValue, ValidationError from typing_extensions import Any, Optional, Self from .components import ConfigurableComponent @@ -34,6 +37,8 @@ class PipelineConfig(BaseModel): "Pipeline components, with their configurations and wiring." aliases: dict[str, str] = Field(default_factory=dict) "Pipeline node aliases." + literals: dict[str, PipelineLiteral] = Field(default_factory=dict) + "Literals" class PipelineMeta(BaseModel): @@ -92,7 +97,10 @@ class PipelineComponent(BaseModel): """ @classmethod - def from_node(cls, node: ComponentNode[Any]) -> Self: + def from_node(cls, node: ComponentNode[Any], mapping: dict[str, str] | None = None) -> Self: + if mapping is None: + mapping = {} + comp = node.component if isinstance(comp, FunctionType): ctype = comp @@ -103,7 +111,38 @@ def from_node(cls, node: ComponentNode[Any]) -> Self: config = comp.get_config() if isinstance(comp, ConfigurableComponent) else None - return cls(code=code, config=config, inputs=node.connections) + return cls( + code=code, + config=config, + inputs={n: mapping.get(t, t) for (n, t) in node.connections.items()}, + ) + + +class PipelineLiteral(BaseModel): + """ + Literal nodes represented in the pipeline. + """ + + encoding: Literal["json", "base85"] + value: JsonValue + + @classmethod + def represent(cls, data: Any) -> Self: + try: + return cls(encoding="json", value=data) + except ValidationError: + # data is not basic JSON values, so let's pickle it + dbytes = pickle.dumps(data) + return cls(encoding="base85", value=base64.b85encode(dbytes).decode("ascii")) + + def decode(self) -> Any: + "Decode the represented literal." + match self.encoding: + case "json": + return self.value + case "base85": + assert isinstance(self.value, str) + return pickle.loads(base64.b85decode(self.value)) def hash_config(config: BaseModel) -> str: diff --git a/lenskit/tests/pipeline/test_pipeline.py b/lenskit/tests/pipeline/test_pipeline.py index dab8a8a46..5c439e78f 100644 --- a/lenskit/tests/pipeline/test_pipeline.py +++ b/lenskit/tests/pipeline/test_pipeline.py @@ -8,13 +8,15 @@ from typing import Any from uuid import UUID +import numpy as np from typing_extensions import assert_type -from pytest import fail, raises +from pytest import fail, raises, warns from lenskit.data import Dataset, Vocabulary from lenskit.pipeline import InputNode, Node, Pipeline, PipelineError from lenskit.pipeline.components import TrainableComponent +from lenskit.pipeline.types import TypecheckWarning def test_init_empty(): @@ -616,3 +618,24 @@ def get_params(self) -> dict[str, object]: def load_params(self, params: dict[str, Any]) -> None: self.items = params["items"] + + +def test_pipeline_component_default(): + """ + Test that the last *component* is last. It also exercises the warning logic + for missing component types. + """ + pipe = Pipeline() + a = pipe.create_input("a", int) + + def add(x, y): # type: ignore + return x + y # type: ignore + + with warns(TypecheckWarning): + pipe.add_component("add", add, x=np.arange(10), y=a) # type: ignore + + # the component runs + assert np.all(pipe.run("add", a=5) == np.arange(5, 15)) + + # the component is the default + assert np.all(pipe.run(a=5) == np.arange(5, 15)) diff --git a/lenskit/tests/pipeline/test_save_load.py b/lenskit/tests/pipeline/test_save_load.py index 18b8ba7c5..2ea2d2fcb 100644 --- a/lenskit/tests/pipeline/test_save_load.py +++ b/lenskit/tests/pipeline/test_save_load.py @@ -1,6 +1,7 @@ import logging from types import NoneType +import numpy as np from typing_extensions import assert_type from pytest import fail, warns @@ -32,7 +33,7 @@ def double(x: int) -> int: return x * 2 -def add(x: int, y: int) -> int: +def add(x: int | np.ndarray, y: int) -> int | np.ndarray: return x + y @@ -222,7 +223,7 @@ def test_hash_validate(): def test_alias_input(): - "alias an input node" + "just an input node and an alias" pipe = Pipeline() user = pipe.create_input("user", int, str) @@ -245,6 +246,45 @@ def test_alias_node(): assert pipe.run("result", a=5, b=7) == 17 - cfg = pipe.get_config() - p2 = Pipeline.from_config(cfg) + p2 = pipe.clone("pipeline-config") assert p2.run("result", a=5, b=7) == 17 + + +def test_literal(): + pipe = Pipeline("literal-prefix") + msg = pipe.create_input("msg", str) + + pipe.add_component("prefix", msg_prefix, prefix=pipe.literal("hello, "), msg=msg) + + assert pipe.run(msg="HACKEM MUCHE") == "hello, HACKEM MUCHE" + + print(pipe.get_config().model_dump_json(indent=2)) + p2 = pipe.clone("pipeline-config") + assert p2.run(msg="FOOBIE BLETCH") == "hello, FOOBIE BLETCH" + + +def test_literal_array(): + pipe = Pipeline("literal-add-array") + a = pipe.create_input("a", int) + + pipe.add_component("add", add, x=np.arange(10), y=a) + + res = pipe.run(a=5) + assert np.all(res == np.arange(5, 15)) + + print(pipe.get_config().model_dump_json(indent=2)) + p2 = pipe.clone("pipeline-config") + assert np.all(p2.run(a=5) == np.arange(5, 15)) + + +def test_stable_with_literals(): + "test that two identical pipelines have the same hash, even with literals" + p1 = Pipeline("literal-add-array") + a = p1.create_input("a", int) + p1.add_component("add", add, x=np.arange(10), y=a) + + p2 = Pipeline("literal-add-array") + a = p2.create_input("a", int) + p2.add_component("add", add, x=np.arange(10), y=a) + + assert p1.config_hash() == p2.config_hash()