-
Notifications
You must be signed in to change notification settings - Fork 64
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add PipelineState to get entire state
- Loading branch information
1 parent
967655d
commit d058a3d
Showing
4 changed files
with
215 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
# pyright: strict | ||
from collections.abc import Mapping | ||
from typing import Any, Iterator | ||
|
||
|
||
class PipelineState(Mapping[str, Any]): | ||
""" | ||
Full results of running a pipeline. A pipeline state is a dictionary | ||
mapping node names to their results; it is implemented as a separate class | ||
instead of directly using a dictionary to allow data to be looked up by node | ||
aliases in addition to original node names (and to be read-only). | ||
Client code will generally not construct this class directly. | ||
Args: | ||
state: | ||
The pipeline state to wrap. The state object stores a reference to | ||
this dictionary. | ||
aliases: | ||
Dictionary of node aliases. | ||
default: | ||
The name of the default node (whose data should be returned by | ||
:attr:`default` ). | ||
""" | ||
|
||
_state: dict[str, Any] | ||
_aliases: dict[str, str] | ||
_default: str | None = None | ||
|
||
def __init__( | ||
self, | ||
state: dict[str, Any] | None = None, | ||
aliases: dict[str, str] | None = None, | ||
default: str | None = None, | ||
) -> None: | ||
self._state = state if state is not None else {} | ||
self._aliases = aliases if aliases is not None else {} | ||
self._default = default | ||
if default is not None and default not in self: | ||
raise ValueError("default node is not in state or aliases") | ||
|
||
@property | ||
def default(self) -> Any: | ||
""" | ||
Return the data from of the default node (typically the last node run). | ||
Returns: | ||
The data associated with the default node. | ||
Raises: | ||
ValueError: if there is no specified default node. | ||
""" | ||
if self._default is not None: | ||
return self[self._default] | ||
else: | ||
raise ValueError("pipeline state has no default value") | ||
|
||
@property | ||
def default_node(self) -> str | None: | ||
"Return the name of the default node (typically the last node run)." | ||
return self._default | ||
|
||
def __len__(self): | ||
return len(self._state) | ||
|
||
def __contains__(self, key: object) -> bool: | ||
if key in self._state: | ||
return True | ||
if key in self._aliases: | ||
return self._aliases[key] in self | ||
else: | ||
return False | ||
|
||
def __getitem__(self, key: str) -> Any: | ||
if key in self._state: | ||
return self._state[key] | ||
elif key in self._aliases: | ||
return self[self._aliases[key]] | ||
else: | ||
raise KeyError(f"pipeline node <{key}>") | ||
|
||
def __iter__(self) -> Iterator[str]: | ||
return iter(self._state) | ||
|
||
def __str__(self) -> str: | ||
return f"<PipelineState with {len(self)} nodes>" | ||
|
||
def __repr__(self) -> str: | ||
return f"<PipelineState with nodes {set(self._state.keys())}>" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
from pytest import raises | ||
|
||
from lenskit.pipeline import PipelineState | ||
|
||
|
||
def test_empty(): | ||
state = PipelineState() | ||
assert len(state) == 0 | ||
assert not state | ||
assert "scroll" not in state | ||
|
||
with raises(KeyError): | ||
state["scroll"] | ||
|
||
|
||
def test_single_value(): | ||
state = PipelineState({"scroll": "HACKEM MUCHE"}) | ||
assert len(state) == 1 | ||
assert state | ||
assert "scroll" in state | ||
assert state["scroll"] == "HACKEM MUCHE" | ||
|
||
|
||
def test_alias(): | ||
state = PipelineState({"scroll": "HACKEM MUCHE"}, {"book": "scroll"}) | ||
assert len(state) == 1 | ||
assert state | ||
assert "scroll" in state | ||
assert "book" in state | ||
assert state["book"] == "HACKEM MUCHE" | ||
|
||
|
||
def test_alias_missing(): | ||
state = PipelineState({"scroll": "HACKEM MUCHE"}, {"book": "manuscript"}) | ||
assert len(state) == 1 | ||
assert state | ||
assert "scroll" in state | ||
assert "book" not in state |