Skip to content

Commit

Permalink
Merge branch 'feature/really-hash'
Browse files Browse the repository at this point in the history
  • Loading branch information
mdekstrand committed Aug 9, 2024
2 parents fe17ebd + 1b50b56 commit 37c9995
Show file tree
Hide file tree
Showing 5 changed files with 204 additions and 50 deletions.
2 changes: 1 addition & 1 deletion docs/pipeline.rst
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ method takes two types of inputs:
altered scores).

If no components are specified, it is the same as specifying the last
component added to the pipeline.
component that was added to the pipeline.

* Keyword arguments specifying the values for the pipeline's inputs, as defined by
calls to :meth:`create_input`.
Expand Down
134 changes: 93 additions & 41 deletions lenskit/lenskit/pipeline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import warnings
from types import FunctionType
from typing import Literal, cast
from uuid import uuid4
from uuid import NAMESPACE_URL, uuid4, uuid5

from typing_extensions import Any, Self, TypeAlias, TypeVar, overload

Expand All @@ -29,7 +29,14 @@
TrainableComponent,
instantiate_component,
)
from .config import PipelineComponent, PipelineConfig, PipelineInput, PipelineMeta, hash_config
from .config import (
PipelineComponent,
PipelineConfig,
PipelineInput,
PipelineLiteral,
PipelineMeta,
hash_config,
)
from .nodes import ND, ComponentNode, FallbackNode, InputNode, LiteralNode, Node
from .state import PipelineState

Expand All @@ -54,7 +61,9 @@
T3 = TypeVar("T3")
T4 = TypeVar("T4")
T5 = TypeVar("T5")
CloneMethod: TypeAlias = Literal["config"]
CloneMethod: TypeAlias = Literal["config", "pipeline-config"]

NAMESPACE_LITERAL_DATA = uuid5(NAMESPACE_URL, "https://ns.lenskit.org/literal-data/")


class PipelineError(Exception):
Expand Down Expand Up @@ -108,6 +117,9 @@ class Pipeline:
_defaults: dict[str, Node[Any]]
_components: dict[str, Component[Any]]
_hash: str | None = None
_last: Node[Any] | None = None
_anon_nodes: set[str]
"Track generated node names."

def __init__(self, name: str | None = None, version: str | None = None):
self.name = name
Expand All @@ -116,22 +128,19 @@ def __init__(self, name: str | None = None, version: str | None = None):
self._aliases = {}
self._defaults = {}
self._components = {}
self._anon_nodes = set()
self._clear_caches()

def meta(self, *, include_hash: bool | None = None) -> PipelineMeta:
def meta(self, *, include_hash: bool = True) -> PipelineMeta:
"""
Get the metadata (name, version, hash, etc.) for this pipeline without
returning the whole config.
Args:
include_hash:
Whether to include a configuration hash in the metadata. If
``None`` (the default), the metadata includes a hash if there
are no :meth:`literal` nodes in the pipeline.
Whether to include a configuration hash in the metadata.
"""
meta = PipelineMeta(name=self.name, version=self.version)
if include_hash is None:
include_hash = not any(isinstance(n, LiteralNode) for n in self.nodes)
if include_hash:
meta.hash = self.config_hash()
return meta
Expand Down Expand Up @@ -208,15 +217,17 @@ def create_input(self, name: str, *types: type[T] | None) -> Node[T]:
self._clear_caches()
return node

def literal(self, value: T) -> LiteralNode[T]:
def literal(self, value: T, *, name: str | None = None) -> LiteralNode[T]:
"""
Create a literal node (a node with a fixed value).
.. note::
Literal nodes cannot be serialized with :meth:`get_config` or
:meth:`save_config`.
"""
name = str(uuid4())
if name is None:
name = str(uuid4())
self._anon_nodes.add(name)
node = LiteralNode(name, value, types=set([type(value)]))
self._nodes[name] = node
self._clear_caches()
Expand Down Expand Up @@ -296,6 +307,7 @@ def add_component(
self.connect(node, **inputs)

self._clear_caches()
self._last = node
return node

def replace_component(
Expand Down Expand Up @@ -427,11 +439,16 @@ def clone(self, how: CloneMethod = "config") -> Pipeline:
Clone the pipeline, optionally including trained parameters.
The ``how`` parameter controls how the pipeline is cloned, and what is
available in the clone pipeline. Currently only ``"config"`` is
supported, which creates fresh component instances using the
configurations of the components in this pipeline. When applied to a
trained pipeline, the clone does **not** have the original's learned
parameters.
available in the clone pipeline. It can be one of the following values:
``"config"``
Create fresh component instances using the configurations of the
components in this pipeline. When applied to a trained pipeline,
the clone does **not** have the original's learned parameters. This
is the default clone method.
``"pipeline-config"``
Round-trip the entire pipeline through :meth:`get_config` and
:meth:`from_config`.
Args:
how:
Expand All @@ -441,10 +458,14 @@ def clone(self, how: CloneMethod = "config") -> Pipeline:
A new pipeline with the same components and wiring, but fresh
instances created by round-tripping the configuration.
"""
if how != "config": # pragma: nocover
if how == "pipeline-config":
cfg = self.get_config()
return self.from_config(cfg)
elif how != "config": # pragma: nocover
raise NotImplementedError("only 'config' cloning is currently supported")

clone = Pipeline()

for node in self.nodes:
match node:
case InputNode(name, types=types):
Expand Down Expand Up @@ -495,17 +516,44 @@ def get_config(self, *, include_hash: bool = True) -> PipelineConfig:
"""
meta = self.meta(include_hash=False)
config = PipelineConfig(meta=meta)

# We map anonymous nodes to hash-based names for stability. If we ever
# allow anonymous components, this will need to be adjusted to maintain
# component ordering, but it works for now since only literals can be
# anonymous. First handle the anonymous nodes, so we have that mapping:
remapped: dict[str, str] = {}
for an in self._anon_nodes:
node = self._nodes.get(an, None)
match node:
case None:
# skip nodes that no longer exist
continue
case LiteralNode(name, value):
cfg = PipelineLiteral.represent(value)
sname = str(uuid5(NAMESPACE_LITERAL_DATA, cfg.model_dump_json()))
_log.debug("renamed anonymous node %s to %s", name, sname)
remapped[name] = sname
config.literals[sname] = cfg
case _:
# the pipeline only generates anonymous literal nodes right now
raise RuntimeError(f"unexpected anonymous node {node}")

# Now we go over all named nodes and add them to the config:
for node in self.nodes:
if node.name in remapped:
continue

match node:
case InputNode():
config.inputs.append(PipelineInput.from_node(node))
case LiteralNode():
raise RuntimeError("literal nodes cannot be serialized to config")
case LiteralNode(name, value):
config.literals[name] = PipelineLiteral.represent(value)
case ComponentNode(name):
config.components[name] = PipelineComponent.from_node(node)
config.components[name] = PipelineComponent.from_node(node, remapped)
case FallbackNode(name, alternatives):
config.components[name] = PipelineComponent(
code="@use-first-of", inputs=[n.name for n in alternatives]
code="@use-first-of",
inputs=[remapped.get(n.name, n.name) for n in alternatives],
)
case _: # pragma: nocover
raise RuntimeError(f"invalid node {node}")
Expand All @@ -523,12 +571,16 @@ def config_hash(self) -> str:
Get a hash of the pipeline's configuration to uniquely identify it for
logging, version control, or other purposes.
The precise algorithm to compute the hash is not guaranteed, except that
the same configuration with the same version of LensKit and its
dependencies will produce the same hash. In LensKit 2024.1, the
configuration hash is computed by computing the JSON serialization of
the pipeline configuration *without* a hash returning the hex-encoded
SHA256 hash of that configuration.
The hash format and algorithm are not guaranteed, but are stable within a
LensKit version. For the same version of LensKit and component code,
the same configuration will produce the same hash, so long as there are
no literal nodes. Literal nodes will *usually* hash consistently, but
since literals other than basic JSON values are hashed by pickling, hash
stability depends on the stability of the pickle bytestream.
In LensKit 2024.1, the configuration hash is computed by computing the
JSON serialization of the pipeline configuration *without* a hash and
returning the hex-encoded SHA256 hash of that configuration.
"""
if self._hash is None:
# get the config *without* a hash
Expand All @@ -550,7 +602,11 @@ def from_config(cls, config: object) -> Self:
# that nodes are available before they are wired (since `connect` can
# introduce out-of-order dependencies).

# pass 1: add components
# pass 1: add literals
for name, data in cfg.literals.items():
pipe.literal(data.decode(), name=name)

# pass 2: add components
to_wire: list[PipelineComponent] = []
for name, comp in cfg.components.items():
if comp.code.startswith("@"):
Expand All @@ -561,7 +617,7 @@ def from_config(cls, config: object) -> Self:
pipe.add_component(name, obj)
to_wire.append(comp)

# pass 2: add meta nodes
# pass 3: add meta nodes
for name, comp in cfg.components.items():
if comp.code == "@use-first-of":
if not isinstance(comp.inputs, list):
Expand All @@ -570,19 +626,19 @@ def from_config(cls, config: object) -> Self:
elif comp.code.startswith("@"):
raise PipelineError(f"unsupported meta-component {comp.code}")

# pass 3: wiring
# pass 4: wiring
for name, comp in cfg.components.items():
if isinstance(comp.inputs, dict):
inputs = {n: pipe.node(t) for (n, t) in comp.inputs.items()}
pipe.connect(name, **inputs)
elif not comp.code.startswith("@"):
raise PipelineError(f"component {name} inputs must be dict, not list")

# pass 4: aliases
# pass 5: aliases
for n, t in cfg.aliases.items():
pipe.alias(n, t)

# pass 5: defaults
# pass 6: defaults
for n, t in cfg.defaults.items():
pipe.set_default(n, pipe.node(t))

Expand Down Expand Up @@ -641,9 +697,6 @@ def run(self, *nodes: str | Node[Any], **kwargs: object) -> object:
components. See :ref:`pipeline-execution` for details of the pipeline
execution model.
.. todo::
Add cycle detection.
Args:
nodes:
The component(s) to run.
Expand All @@ -656,6 +709,8 @@ def run(self, *nodes: str | Node[Any], **kwargs: object) -> object:
are returned in a tuple.
Raises:
PipelineError:
when there is a pipeline configuration error (e.g. a cycle).
ValueError:
when one or more required inputs are missing.
TypeError:
Expand All @@ -664,7 +719,9 @@ def run(self, *nodes: str | Node[Any], **kwargs: object) -> object:
exceptions thrown by components are passed through.
"""
if not nodes:
nodes = (self._last_node(),)
if self._last is None: # pragma: nocover
raise PipelineError("pipeline has no components")
nodes = (self._last,)
state = self.run_all(*nodes, **kwargs)
results = [state[self.node(n).name] for n in nodes]

Expand Down Expand Up @@ -719,11 +776,6 @@ def run_all(self, *nodes: str | Node[Any], **kwargs: object) -> PipelineState:
meta=self.meta(),
)

def _last_node(self) -> Node[object]:
if not self._nodes:
raise RuntimeError("pipeline is empty")
return list(self._nodes.values())[-1]

def _check_available_name(self, name: str) -> None:
if name in self._nodes or name in self._aliases:
raise ValueError(f"pipeline already has node {name}")
Expand Down
45 changes: 42 additions & 3 deletions lenskit/lenskit/pipeline/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,14 @@
# pyright: strict
from __future__ import annotations

import base64
import pickle
from collections import OrderedDict
from hashlib import sha256
from types import FunctionType
from typing import Literal

from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, JsonValue, ValidationError
from typing_extensions import Any, Optional, Self

from .components import ConfigurableComponent
Expand All @@ -34,6 +37,8 @@ class PipelineConfig(BaseModel):
"Pipeline components, with their configurations and wiring."
aliases: dict[str, str] = Field(default_factory=dict)
"Pipeline node aliases."
literals: dict[str, PipelineLiteral] = Field(default_factory=dict)
"Literals"


class PipelineMeta(BaseModel):
Expand Down Expand Up @@ -92,7 +97,10 @@ class PipelineComponent(BaseModel):
"""

@classmethod
def from_node(cls, node: ComponentNode[Any]) -> Self:
def from_node(cls, node: ComponentNode[Any], mapping: dict[str, str] | None = None) -> Self:
if mapping is None:
mapping = {}

comp = node.component
if isinstance(comp, FunctionType):
ctype = comp
Expand All @@ -103,7 +111,38 @@ def from_node(cls, node: ComponentNode[Any]) -> Self:

config = comp.get_config() if isinstance(comp, ConfigurableComponent) else None

return cls(code=code, config=config, inputs=node.connections)
return cls(
code=code,
config=config,
inputs={n: mapping.get(t, t) for (n, t) in node.connections.items()},
)


class PipelineLiteral(BaseModel):
"""
Literal nodes represented in the pipeline.
"""

encoding: Literal["json", "base85"]
value: JsonValue

@classmethod
def represent(cls, data: Any) -> Self:
try:
return cls(encoding="json", value=data)
except ValidationError:
# data is not basic JSON values, so let's pickle it
dbytes = pickle.dumps(data)
return cls(encoding="base85", value=base64.b85encode(dbytes).decode("ascii"))

def decode(self) -> Any:
"Decode the represented literal."
match self.encoding:
case "json":
return self.value
case "base85":
assert isinstance(self.value, str)
return pickle.loads(base64.b85decode(self.value))


def hash_config(config: BaseModel) -> str:
Expand Down
Loading

0 comments on commit 37c9995

Please sign in to comment.