Skip to content

Commit

Permalink
rewrite pipeline run for readability
Browse files Browse the repository at this point in the history
  • Loading branch information
mdekstrand committed Aug 2, 2024
1 parent 6e52078 commit 89f75ed
Show file tree
Hide file tree
Showing 4 changed files with 210 additions and 128 deletions.
165 changes: 41 additions & 124 deletions lenskit/lenskit/pipeline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
from typing_extensions import Any, LiteralString, TypeVar, overload

from lenskit.data import Dataset
from lenskit.pipeline.types import is_compatible_data

from .components import Component, ConfigurableComponent, TrainableComponent
from .nodes import ND, ComponentNode, FallbackNode, InputNode, LiteralNode, Node
Expand Down Expand Up @@ -104,7 +103,7 @@ def node(self, node: str | Node[Any]) -> Node[object]:
else:
raise KeyError(f"node {node}")

def create_input(self, name: LiteralString, *types: type[T]) -> Node[T]:
def create_input(self, name: LiteralString, *types: type[T] | None) -> Node[T]:
"""
Create an input node for the pipeline. Pipelines expect their inputs to
be provided when they are run.
Expand All @@ -127,7 +126,7 @@ def create_input(self, name: LiteralString, *types: type[T]) -> Node[T]:
"""
self._check_available_name(name)

node = InputNode[Any](name, types=set(types))
node = InputNode[Any](name, types=set((t if t is not None else type[None]) for t in types))
self._nodes[name] = node
self._clear_caches()
return node
Expand Down Expand Up @@ -158,6 +157,12 @@ def set_default(self, name: LiteralString, node: Node[Any] | object) -> None:
self._defaults[name] = node
self._clear_caches()

def get_default(self, name: str) -> Node[Any] | None:
"""
Get the default wiring for an input name.
"""
return self._defaults.get(name, None)

def alias(self, alias: str, node: Node[Any] | str) -> None:
"""
Create an alias for a node. After aliasing, the node can be retrieved
Expand Down Expand Up @@ -236,10 +241,12 @@ def replace_component(
def use_first_of(self, name: str, *nodes: Node[T | None]) -> Node[T]:
"""
Create a new node whose value is the first defined (not ``None``) value
of the specified nodes. This is used for things like filling in optional
pipeline inputs. For example, if you want the pipeline to take candidate
items through an ``items`` input, but look them up from the user's history
and the training data if ``items`` is not supplied, you would do:
of the specified nodes. If a node is an input node and its value is not
supplied, it is treated as ``None`` in this case instead of failing the
run. This method is used for things like filling in optional pipeline
inputs. For example, if you want the pipeline to take candidate items
through an ``items`` input, but look them up from the user's history and
the training data if ``items`` is not supplied, you would do:
.. code:: python
Expand All @@ -258,11 +265,23 @@ def use_first_of(self, name: str, *nodes: Node[T | None]) -> Node[T]:
.. note::
This method does *not* implement item-level fallbacks, only fallbacks at
the level of entire results. That is, you can use it to use component A
as a fallback for B if B returns ``None``, but it will not use B to fill
in missing scores for individual items that A did not score. A specific
itemwise fallback component is needed for such an operation.
This method does not distinguish between an input being unspecified and
explicitly specified as ``None``.
.. note::
This method does *not* implement item-level fallbacks, only
fallbacks at the level of entire results. That is, you can use it
to use component A as a fallback for B if B returns ``None``, but it
will not use B to fill in missing scores for individual items that A
did not score. A specific itemwise fallback component is needed for
such an operation.
Args:
name:
The name of the node.
nodes:
The nodes to try, in order, to satisfy this node.
"""
node = FallbackNode(name, list(nodes))
self._nodes[name] = node
Expand Down Expand Up @@ -364,119 +383,17 @@ def run(self, *nodes: str | Node[Any], **kwargs: object) -> object:
other:
exceptions thrown by components are passed through.
"""
state: dict[str, Any] = {}

ret: list[Node[Any]] | Node[Any] = [self.node(n) for n in nodes]
_log.debug(
"starting run of pipeline with %d nodes, want %s",
len(self._nodes),
[n.name for n in ret],
)
if not ret:
ret = [self._last_node()]

# set up a stack of nodes to look at (with their required/optional status)
# we traverse the graph with this
needed = [(r, True) for r in reversed(ret)]

# the main loop — keep resolving pipeline nodes until we're done
while needed:
node, required = needed[-1]
if node.name in state:
# the node is computed, we're done
needed.pop()
continue

_log.debug("processing node %s (required=%s)", node, required)

match node:
case LiteralNode(name, value):
# literal nodes are ready to put on the state
state[name] = value
needed.pop()
case ComponentNode(name, comp, inputs, wiring):
# check that (1) the node is fully wired, and (2) its inputs are all computed
ready = True
for k, it in inputs.items():
if k in wiring:
wired = wiring[k]
elif k in self._defaults:
wired = self._defaults[k]
else:
raise RuntimeError(f"input {k} for {node} not connected")
wired = self.node(wired)

if wired.name not in state:
# input value not available, queue it up
ready = False
# it is fine to queue the same node twice — it will
# be quickly skipped the second time
if it is None:
required = True
else:
required = not isinstance(None, it)
_log.debug("%s: queueing input %s (type %s)", node, k, it)
needed.append((wired, required))

if ready:
_log.debug("running %s (%s)", node, comp)
# if the node is ready to compute (all inputs in state), we run it.
args = {}
for n in inputs.keys():
if n in wiring:
args[n] = state[wiring[n]]
elif n in self._defaults:
args[n] = state[self._defaults[n].name]
else: # pragma: nocover
raise AssertionError("missing input not caught earlier")
state[name] = comp(**args)
needed.pop()

# fallthrough: the node is not ready, and we have pushed its
# inputs onto the stack. The inputs may be re-pushed, so this
# will never be the last node on the stack at this point

case InputNode(name, types=types):
try:
val = kwargs[name]
except KeyError:
if required:
raise RuntimeError(f"input {name} not specified")
else:
val = None

if required and types and not is_compatible_data(val, *types):
raise TypeError(
f"invalid data for input {name} (expected {types}, got {type(val)})"
)
state[name] = val
needed.pop()

case FallbackNode(name, options):
status = "failed"
for opt in options:
if opt.name not in state:
# try to get this item
needed.append((opt, False))
status = "pending"
break
elif state[opt.name] is not None:
# we have a value
state[name] = state[opt.name]
status = "fulfilled"
needed.pop()
break

if status == "failed":
raise RuntimeError(f"no alternative for {node} was fulfilled")

case _:
raise RuntimeError(f"invalid node {node}")

if len(ret) > 1:
return tuple(state[r.name] for r in ret)
from .runner import PipelineRunner

runner = PipelineRunner(self, kwargs)
if not nodes:
nodes = (self._last_node(),)
results = [runner.run(self.node(n)) for n in nodes]

if len(results) > 1:
return tuple(results)
else:
return state[ret[0].name]
return results[0]

def _last_node(self) -> Node[object]:
if not self._nodes:
Expand Down
14 changes: 10 additions & 4 deletions lenskit/lenskit/pipeline/nodes.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,9 @@
# This file is part of LensKit.
# Copyright (C) 2018-2023 Boise State University
# Copyright (C) 2023-2024 Drexel University
# Licensed under the MIT license, see LICENSE.md for details.
# SPDX-License-Identifier: MIT

# pyright: strict

import warnings
Expand Down Expand Up @@ -45,14 +51,14 @@ class FallbackNode(Node[ND], Generic[ND]):
Node for trying several nodes in turn.
"""

__match_args__ = ("name", "options")
__match_args__ = ("name", "alternatives")

options: list[Node[ND | None]]
alternatives: list[Node[ND | None]]
"The nodes that can possibly fulfil this node."

def __init__(self, name: str, options: list[Node[ND | None]]):
def __init__(self, name: str, alternatives: list[Node[ND | None]]):
super().__init__(name)
self.options = options
self.alternatives = alternatives


class LiteralNode(Node[ND], Generic[ND]):
Expand Down
133 changes: 133 additions & 0 deletions lenskit/lenskit/pipeline/runner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# This file is part of LensKit.
# Copyright (C) 2018-2023 Boise State University
# Copyright (C) 2023-2024 Drexel University
# Licensed under the MIT license, see LICENSE.md for details.
# SPDX-License-Identifier: MIT

"""
Pipeline runner logic.
"""

# pyright: strict
import logging
from typing import Any, Literal, TypeAlias

from . import Pipeline
from .components import Component
from .nodes import ComponentNode, FallbackNode, InputNode, LiteralNode, Node
from .types import is_compatible_data

_log = logging.getLogger(__name__)
State: TypeAlias = Literal["pending", "in-progress", "finished", "failed"]


class PipelineRunner:
"""
Node status and results for a single pipeline run.
This class operates recursively; pipelines should never be so deep that
recursion fails.
"""

pipe: Pipeline
inputs: dict[str, Any]
status: dict[str, State]
state: dict[str, Any]

def __init__(self, pipe: Pipeline, inputs: dict[str, Any]):
self.pipe = pipe
self.inputs = inputs
self.status = {n.name: "pending" for n in pipe.nodes}
self.state = {}

def run(self, node: Node[Any], *, required: bool = True) -> Any:
"""
Run the pipleline to obtain the results of a node.
"""
status = self.status[node.name]
if status == "finished":
return self.state[node.name]
elif status == "in-progress":
raise RuntimeError(f"pipeline cycle encountered at {node}")
elif status == "failed":
raise RuntimeError(f"{node} previously failed")

_log.debug("processing node %s", node)
self.status[node.name] = "in-progress"
try:
self._run_node(node, required)
self.status[node.name] = "finished"
except Exception as e:
self.status[node.name] = "failed"
raise e

return self.state[node.name]

def _run_node(self, node: Node[Any], required: bool) -> None:
match node:
case LiteralNode(name, value):
self.state[name] = value
case InputNode(name, types=types):
self._inject_input(name, types, required)
case ComponentNode(name, comp, inputs, wiring):
self._run_component(name, comp, inputs, wiring)
case FallbackNode(name, alts):
self._run_fallback(name, alts)
case _:
raise RuntimeError(f"invalid node {node}")

def _inject_input(self, name: str, types: set[type] | None, required: bool) -> None:
val = self.inputs.get(name, None)
if val is None and required and types and not is_compatible_data(None, *types):
raise RuntimeError(f"input {name} not specified")

if val is not None and types and not is_compatible_data(val, *types):
raise TypeError(f"invalid data for input {name} (expected {types}, got {type(val)})")

self.state[name] = val

def _run_component(
self,
name: str,
comp: Component[Any],
inputs: dict[str, type | None],
wiring: dict[str, str],
) -> None:
in_data = {}
_log.debug("processing inputs for component %s", name)
for iname, itype in inputs.items():
src = wiring.get(iname, None)
if src is not None:
snode = self.pipe.node(src)
else:
snode = self.pipe.get_default(iname)

if snode is None:
ival = None
else:
if itype:
required = not is_compatible_data(None, itype)
else:
required = False
ival = self.run(snode, required=required)

if itype and not is_compatible_data(ival, itype):
raise TypeError(
f"input {iname} for component {name}"
f" has invalid type {type(ival)} (expected {itype})"
)

in_data[iname] = ival

_log.debug("running component %s", name)
self.state[name] = comp(**in_data)

def _run_fallback(self, name: str, alternatives: list[Node[Any]]) -> None:
for alt in alternatives:
val = self.run(alt, required=False)
if val is not None:
self.state[name] = val
return

# got this far, no alternatives
raise RuntimeError(f"no alternative for {name} returned data")
Loading

0 comments on commit 89f75ed

Please sign in to comment.