Skip to content

Commit

Permalink
implement fallback nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
mdekstrand committed Aug 1, 2024
1 parent 2767d7b commit 17d1f5e
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 9 deletions.
62 changes: 54 additions & 8 deletions lenskit/lenskit/pipeline/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,21 @@ class InputNode(Node[ND], Generic[ND]):
"""


class FallbackNode(Node[ND], Generic[ND]):
"""
Node for trying several nodes in turn.
"""

__match_args__ = ("name", "options")

options: list[Node[ND | None]]
"The nodes that can possibly fulfil this node."

def __init__(self, name: str, options: list[Node[ND | None]]):
super().__init__(name)
self.options = options


class LiteralNode(Node[ND], Generic[ND]):
__match_args__ = ("name", "value")
value: ND
Expand Down Expand Up @@ -322,7 +337,10 @@ def use_first_of(self, name: str, *nodes: Node[T | None]) -> Node[T]:
in missing scores for individual items that A did not score. A specific
itemwise fallback component is needed for such an operation.
"""
raise NotImplementedError()
node = FallbackNode(name, list(nodes))
self._nodes[name] = node
self._clear_caches()
return node

def connect(self, obj: str | Node[Any], **inputs: Node[Any] | str | object):
"""
Expand Down Expand Up @@ -430,18 +448,20 @@ def run(self, *nodes: str | Node[Any], **kwargs: object) -> object:
if not ret:
ret = [self._last_node()]

# set up a stack of nodes to look at
# set up a stack of nodes to look at (with their required/optional status)
# we traverse the graph with this
needed = list(reversed(ret))
needed = [(r, True) for r in reversed(ret)]

# the main loop — keep resolving pipeline nodes until we're done
while needed:
node = needed[-1]
node, required = needed[-1]
if node.name in state:
# the node is computed, we're done
needed.pop()
continue

_log.debug("processing node %s (required=%s)", node, required)

match node:
case LiteralNode(name, value):
# literal nodes are ready to put on the state
Expand All @@ -450,7 +470,7 @@ def run(self, *nodes: str | Node[Any], **kwargs: object) -> object:
case ComponentNode(name, comp, inputs, wiring):
# check that (1) the node is fully wired, and (2) its inputs are all computed
ready = True
for k in inputs.keys():
for k, it in inputs.items():
if k in wiring:
wired = wiring[k]
elif k in self._defaults:
Expand All @@ -464,7 +484,12 @@ def run(self, *nodes: str | Node[Any], **kwargs: object) -> object:
ready = False
# it is fine to queue the same node twice — it will
# be quickly skipped the second time
needed.append(wired)
if it is None:
required = True
else:
required = not isinstance(None, it)
_log.debug("%s: queueing input %s (type %s)", node, k, it)
needed.append((wired, required))

if ready:
_log.debug("running %s (%s)", node, comp)
Expand All @@ -488,15 +513,36 @@ def run(self, *nodes: str | Node[Any], **kwargs: object) -> object:
try:
val = kwargs[name]
except KeyError:
raise RuntimeError(f"input {name} not specified")
if required:
raise RuntimeError(f"input {name} not specified")
else:
val = None

if types and not is_compatible_data(val, *types):
if required and types and not is_compatible_data(val, *types):
raise TypeError(
f"invalid data for input {name} (expected {types}, got {type(val)})"
)
state[name] = val
needed.pop()

case FallbackNode(name, options):
status = "failed"
for opt in options:
if opt.name not in state:
# try to get this item
needed.append((opt, False))
status = "pending"
break
elif state[opt.name] is not None:
# we have a value
state[name] = state[opt.name]
status = "fulfilled"
needed.pop()
break

if status == "failed":
raise RuntimeError(f"no alternative for {node} was fulfilled")

case _:
raise RuntimeError(f"invalid node {node}")

Expand Down
3 changes: 2 additions & 1 deletion lenskit/tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,7 +333,8 @@ def add(x: int, y: int) -> int:
fb = pipe.use_first_of("fill-operand", b, nn)
na = pipe.add_component("add", add, x=nd, y=fb)

assert pipe.run(na, a=3) == 0
# 3 * 2 + -3 = 3
assert pipe.run(na, a=3) == 3


def test_fallback_only_run_if_needed():
Expand Down

0 comments on commit 17d1f5e

Please sign in to comment.