Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make pipeline fallback work when nodes fail transitively #464

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 19 additions & 13 deletions lenskit/lenskit/pipeline/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,23 +257,21 @@ def use_first_of(self, name: str, *nodes: Node[T | None]) -> Node[T]:

.. code:: python

pipe = Pipeline()
# allow candidate items to be optionally specified
items = pipe.create_input('items', list[EntityId], None)
# find candidates from the training data (optional)
lookup_candidates = pipe.add_component(
'select-candidates',
UnratedTrainingItemsCandidateSelector(),
pipe = Pipeline() # allow candidate items to be optionally specified
items = pipe.create_input('items', list[EntityId], None) # find
candidates from the training data (optional) lookup_candidates =
pipe.add_component(
'select-candidates', UnratedTrainingItemsCandidateSelector(),
user=history,
)
# if the client provided items as a pipeline input, use those; otherwise
# use the candidate selector we just configured.
candidates = pipe.use_first_of('candidates', items, lookup_candidates)
) # if the client provided items as a pipeline input, use those;
otherwise # use the candidate selector we just configured.
candidates = pipe.use_first_of('candidates', items,
lookup_candidates)

.. note::

This method does not distinguish between an input being unspecified and
explicitly specified as ``None``.
This method does not distinguish between an input being unspecified
and explicitly specified as ``None``.

.. note::

Expand All @@ -284,6 +282,14 @@ def use_first_of(self, name: str, *nodes: Node[T | None]) -> Node[T]:
did not score. A specific itemwise fallback component is needed for
such an operation.

.. note::

If one of the fallback elements is a component ``A`` that depends on
another component or input ``B``, and ``B`` is missing or returns
``None`` such that ``A`` would usually fail, then ``A`` will be
skipped and the fallback will move on to the next node. This works
with arbitrarily-deep transitive chains.

Args:
name:
The name of the node.
Expand Down
23 changes: 17 additions & 6 deletions lenskit/lenskit/pipeline/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,13 @@ def run(self, node: Node[Any], *, required: bool = True) -> Any:
self.status[node.name] = "failed"
raise e

return self.state[node.name]
try:
return self.state[node.name]
except KeyError as e:
if required:
raise e
else:
return None

def _run_node(self, node: Node[Any], required: bool) -> None:
match node:
Expand All @@ -71,7 +77,7 @@ def _run_node(self, node: Node[Any], required: bool) -> None:
case InputNode(name, types=types):
self._inject_input(name, types, required)
case ComponentNode(name, comp, inputs, wiring):
self._run_component(name, comp, inputs, wiring)
self._run_component(name, comp, inputs, wiring, required)
case FallbackNode(name, alts):
self._run_fallback(name, alts)
case _: # pragma: nocover
Expand All @@ -93,6 +99,7 @@ def _run_component(
comp: Component[Any],
inputs: dict[str, type | None],
wiring: dict[str, str],
required: bool,
) -> None:
in_data = {}
_log.debug("processing inputs for component %s", name)
Expand All @@ -106,11 +113,15 @@ def _run_component(
if snode is None:
ival = None
else:
if itype:
required = not is_compatible_data(None, itype)
if required and itype:
ireq = not is_compatible_data(None, itype)
else:
required = False
ival = self.run(snode, required=required)
ireq = False
ival = self.run(snode, required=ireq)

# bail out if we're trying to satisfy a non-required dependency
if ival is None and itype and not is_compatible_data(None, itype) and not required:
return None

if itype and not is_compatible_data(ival, itype):
raise TypeError(
Expand Down
62 changes: 62 additions & 0 deletions lenskit/tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,68 @@ def add(x: int, y: int) -> int:
pipe.run(na, a=3)


def test_fallback_transitive():
"test that a fallback works if a dependency's dependency fails"
pipe = Pipeline()
ia = pipe.create_input("a", int)
ib = pipe.create_input("b", int)

def double(x: int) -> int:
return 2 * x

# two components, each with a different input
c1 = pipe.add_component("double-a", double, x=ia)
c2 = pipe.add_component("double-b", double, x=ib)
# use the first that succeeds
c = pipe.use_first_of("result", c1, c2)

# omitting the first input should result in the second component
assert pipe.run(c, b=17) == 34


def test_fallback_transitive_deeper():
"deeper transitive fallback test"
pipe = Pipeline()
a = pipe.create_input("a", int)
b = pipe.create_input("b", int)

def negative(x: int) -> int:
return -x

def double(x: int) -> int:
return x * 2

nd = pipe.add_component("double", double, x=a)
nn = pipe.add_component("negate", negative, x=nd)
nr = pipe.use_first_of("fill-operand", nn, b)

assert pipe.run(nr, b=8) == 8


def test_fallback_transitive_nodefail():
"deeper transitive fallback test"
pipe = Pipeline()
a = pipe.create_input("a", int)
b = pipe.create_input("b", int)

def negative(x: int) -> int | None:
# make this return None in some cases to trigger failure
if x >= 0:
return -x
else:
return None

def double(x: int) -> int:
return x * 2

nd = pipe.add_component("double", double, x=a)
nn = pipe.add_component("negate", negative, x=nd)
nr = pipe.use_first_of("fill-operand", nn, b)

assert pipe.run(nr, a=2, b=8) == -4
assert pipe.run(nr, a=-7, b=8) == 8


def test_train(ml_ds: Dataset):
pipe = Pipeline()
item = pipe.create_input("item", int)
Expand Down
Loading