Skip to content

Commit

Permalink
Add PipelineState to get entire state
Browse files Browse the repository at this point in the history
  • Loading branch information
mdekstrand committed Aug 7, 2024
1 parent 967655d commit d058a3d
Show file tree
Hide file tree
Showing 4 changed files with 215 additions and 4 deletions.
48 changes: 44 additions & 4 deletions lenskit/lenskit/pipeline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
TrainableComponent,
)
from .nodes import ND, ComponentNode, FallbackNode, InputNode, LiteralNode, Node
from .state import PipelineState

__all__ = [
"Pipeline",
Expand Down Expand Up @@ -460,18 +461,57 @@ def run(self, *nodes: str | Node[Any], **kwargs: object) -> object:
other:
exceptions thrown by components are passed through.
"""
from .runner import PipelineRunner

runner = PipelineRunner(self, kwargs)
if not nodes:
nodes = (self._last_node(),)
results = [runner.run(self.node(n)) for n in nodes]
state = self.run_all(*nodes, **kwargs)
results = [state[self.node(n).name] for n in nodes]

if len(results) > 1:
return tuple(results)
else:
return results[0]

def run_all(self, *nodes: str | Node[Any], **kwargs: object) -> PipelineState:
"""
Run all nodes in the pipeline, or all nodes required to fulfill the
requested node, and return a mapping with the full pipeline state (the
data attached to each node). This is useful in cases where client code
needs to be able to inspect the data at arbitrary steps of the pipeline.
It differs from :meth:`run` in two ways:
1. It returns the data from all nodes as a mapping (dictionary-like
object), not just the specified nodes as a tuple.
2. If no nodes are specified, it runs *all* nodes instead of only the
last node. This has the consequence of running nodes that are not
required to fulfill the last node (such scenarios typically result
from using :meth:`use_first_of`).
Args:
nodes:
The nodes to run, as positional arguments (if no nodes are
specified, this method runs all nodes).
kwargs:
The inputs.
Returns:
The full pipeline state, with :attr:`~PipelineState.default` set to
the last node specified (either the last node in `nodes`, or the
last node added to the pipeline).
"""
from .runner import PipelineRunner

runner = PipelineRunner(self, kwargs)
node_list = [self.node(n) for n in nodes]
if not node_list:
node_list = self.nodes

last = None
for node in node_list:
runner.run(node)
last = node.name

return PipelineState(runner.state, {a: t.name for (a, t) in self._aliases.items()}, last)

def _last_node(self) -> Node[object]:
if not self._nodes:
raise RuntimeError("pipeline is empty")
Expand Down
89 changes: 89 additions & 0 deletions lenskit/lenskit/pipeline/state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
# pyright: strict
from collections.abc import Mapping
from typing import Any, Iterator


class PipelineState(Mapping[str, Any]):
"""
Full results of running a pipeline. A pipeline state is a dictionary
mapping node names to their results; it is implemented as a separate class
instead of directly using a dictionary to allow data to be looked up by node
aliases in addition to original node names (and to be read-only).
Client code will generally not construct this class directly.
Args:
state:
The pipeline state to wrap. The state object stores a reference to
this dictionary.
aliases:
Dictionary of node aliases.
default:
The name of the default node (whose data should be returned by
:attr:`default` ).
"""

_state: dict[str, Any]
_aliases: dict[str, str]
_default: str | None = None

def __init__(
self,
state: dict[str, Any] | None = None,
aliases: dict[str, str] | None = None,
default: str | None = None,
) -> None:
self._state = state if state is not None else {}
self._aliases = aliases if aliases is not None else {}
self._default = default
if default is not None and default not in self:
raise ValueError("default node is not in state or aliases")

@property
def default(self) -> Any:
"""
Return the data from of the default node (typically the last node run).
Returns:
The data associated with the default node.
Raises:
ValueError: if there is no specified default node.
"""
if self._default is not None:
return self[self._default]
else:
raise ValueError("pipeline state has no default value")

@property
def default_node(self) -> str | None:
"Return the name of the default node (typically the last node run)."
return self._default

def __len__(self):
return len(self._state)

def __contains__(self, key: object) -> bool:
if key in self._state:
return True
if key in self._aliases:
return self._aliases[key] in self
else:
return False

def __getitem__(self, key: str) -> Any:
if key in self._state:
return self._state[key]
elif key in self._aliases:
return self[self._aliases[key]]
else:
raise KeyError(f"pipeline node <{key}>")

def __iter__(self) -> Iterator[str]:
return iter(self._state)

def __str__(self) -> str:
return f"<PipelineState with {len(self)} nodes>"

def __repr__(self) -> str:
return f"<PipelineState with nodes {set(self._state.keys())}>"
44 changes: 44 additions & 0 deletions lenskit/tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,6 +353,50 @@ def add(x: int, y: int) -> int:
assert pipe.run("result", a=1, b=7) == 9


def test_run_all():
pipe = Pipeline()
a = pipe.create_input("a", int)
b = pipe.create_input("b", int)

def double(x: int) -> int:
return x * 2

def add(x: int, y: int) -> int:
return x + y

nd = pipe.add_component("double", double, x=a)
na = pipe.add_component("add", add, x=nd, y=b)

pipe.alias("result", na)

state = pipe.run_all(a=1, b=7)
assert state["double"] == 2
assert state["add"] == 9
assert state["result"] == 9


def test_run_all_limit():
pipe = Pipeline()
a = pipe.create_input("a", int)
b = pipe.create_input("b", int)

def double(x: int) -> int:
return x * 2

def add(x: int, y: int) -> int:
return x + y

nd = pipe.add_component("double", double, x=a)
na = pipe.add_component("add", add, x=nd, y=b)

pipe.alias("result", na)

state = pipe.run_all("double", a=1, b=7)
assert state["double"] == 2
assert "add" not in state
assert "result" not in state


def test_connect_literal():
pipe = Pipeline()
a = pipe.create_input("a", int)
Expand Down
38 changes: 38 additions & 0 deletions lenskit/tests/test_pipeline_state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from pytest import raises

from lenskit.pipeline import PipelineState


def test_empty():
state = PipelineState()
assert len(state) == 0
assert not state
assert "scroll" not in state

with raises(KeyError):
state["scroll"]


def test_single_value():
state = PipelineState({"scroll": "HACKEM MUCHE"})
assert len(state) == 1
assert state
assert "scroll" in state
assert state["scroll"] == "HACKEM MUCHE"


def test_alias():
state = PipelineState({"scroll": "HACKEM MUCHE"}, {"book": "scroll"})
assert len(state) == 1
assert state
assert "scroll" in state
assert "book" in state
assert state["book"] == "HACKEM MUCHE"


def test_alias_missing():
state = PipelineState({"scroll": "HACKEM MUCHE"}, {"book": "manuscript"})
assert len(state) == 1
assert state
assert "scroll" in state
assert "book" not in state

0 comments on commit d058a3d

Please sign in to comment.