Skip to content

Commit

Permalink
clean up component replacement tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mdekstrand committed Aug 1, 2024
1 parent 6655b8b commit 2767d7b
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 5 deletions.
18 changes: 15 additions & 3 deletions lenskit/lenskit/pipeline/_impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,17 @@ def replace_component(
inputs), but any connections that use the old component to supply an
input will use the new component instead.
"""
raise NotImplementedError()
if isinstance(name, Node):
name = name.name

node = ComponentNode(name, obj)
self._nodes[name] = node
self._components[name] = obj

self.connect(node, **inputs)

self._clear_caches()
return node

def use_first_of(self, name: str, *nodes: Node[T | None]) -> Node[T]:
"""
Expand Down Expand Up @@ -424,6 +434,7 @@ def run(self, *nodes: str | Node[Any], **kwargs: object) -> object:
# we traverse the graph with this
needed = list(reversed(ret))

# the main loop — keep resolving pipeline nodes until we're done
while needed:
node = needed[-1]
if node.name in state:
Expand Down Expand Up @@ -456,7 +467,7 @@ def run(self, *nodes: str | Node[Any], **kwargs: object) -> object:
needed.append(wired)

if ready:
_log.debug("running %s", node)
_log.debug("running %s (%s)", node, comp)
# if the node is ready to compute (all inputs in state), we run it.
args = {}
for n in inputs.keys():
Expand Down Expand Up @@ -485,6 +496,7 @@ def run(self, *nodes: str | Node[Any], **kwargs: object) -> object:
)
state[name] = val
needed.pop()

case _:
raise RuntimeError(f"invalid node {node}")

Expand All @@ -505,7 +517,7 @@ def _check_available_name(self, name: str) -> None:
def _check_member_node(self, node: Node[Any]) -> None:
nw = self._nodes.get(node.name)
if nw is not node:
raise ValueError(f"node {node} not in graph")
raise RuntimeError(f"node {node} not in pipeline")

def _clear_caches(self):
pass
8 changes: 6 additions & 2 deletions lenskit/tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,9 +176,13 @@ def add(x: int, y: int) -> int:

nt = pipe.replace_component("double", triple, x=a)

assert pipe.run(a=1, b=7) == 9
assert pipe.run(na, a=3, b=7) == 13
# run through the end
assert pipe.run(a=1, b=7) == 10
assert pipe.run(na, a=3, b=7) == 16
# run only the first component
assert pipe.run(nt, a=3, b=7) == 9

# old node should be missing!
with raises(RuntimeError, match="not in pipeline"):
pipe.run(nd, a=3, b=7)

Expand Down

0 comments on commit 2767d7b

Please sign in to comment.