Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor assertion chains to be valid assertions #1580

Merged
merged 1 commit into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 0 additions & 21 deletions betty/functools.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,27 +78,6 @@ def slice_to_range(indices: slice, iterable: Sized) -> Iterable[int]:
return range(start, stop, step)


class _Result(Generic[T]):
def __init__(self, value: T | None, _error: BaseException | None = None):
assert not _error or value is None
self._value = value
self._error = _error

@property
def value(self) -> T:
if self._error:
raise self._error
return cast(T, self._value)

def map(self, f: Callable[[T], U]) -> _Result[U]:
if self._error:
return cast(_Result[U], self)
try:
return _Result(f(self.value))
except Exception as e:
return _Result(None, e)


def filter_suppress(
raising_filter: Callable[[T], Any],
exception_type: type[BaseException],
Expand Down
147 changes: 18 additions & 129 deletions betty/serde/load.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from __future__ import annotations

from dataclasses import dataclass
from pathlib import Path
from typing import (
Iterator,
Expand All @@ -19,7 +20,6 @@
TypeAlias,
)

from betty.functools import _Result
from betty.locale import LocaleNotFoundError, get_data, Str
from betty.model import (
Entity,
Expand All @@ -30,7 +30,6 @@
)
from betty.serde.dump import DumpType, DumpTypeT, Void
from betty.serde.error import SerdeError, SerdeErrorCollection
from betty.warnings import deprecated

if TYPE_CHECKING:
from betty.app.extension import Extension
Expand Down Expand Up @@ -87,6 +86,9 @@ class AssertionChain(Generic[_AssertionValueT, _AssertionReturnT]):
value and, if the assertions pass, return an output value. Each chain may be (re)used as many
times as needed.

Assertion chains are assertions themselves: you can use a chain wherever you can use a 'plain'
assertion.

Assertions chains are `monads <https://en.wikipedia.org/wiki/Monad_(functional_programming)>`_.
While uncommon in Python, this allows us to create these chains in a type-safe way, and tools
like mypy can confirm that all assertions in any given chain are compatible with each other.
Expand All @@ -96,108 +98,34 @@ def __init__(self, _assertion: Assertion[_AssertionValueT, _AssertionReturnT]):
self._assertion = _assertion

def extend(
self, _assertion: Assertion[_AssertionReturnT, _AssertionsExtendReturnT]
self, assertion: Assertion[_AssertionReturnT, _AssertionsExtendReturnT]
) -> AssertionChain[_AssertionValueT, _AssertionsExtendReturnT]:
"""
Extend the chain with the given assertion.
"""
return _AssertionChainExtension(_assertion, self)
return AssertionChain(lambda value: assertion(self(value)))

def __or__(
self, _assertion: Assertion[_AssertionReturnT, _AssertionsExtendReturnT]
) -> AssertionChain[_AssertionValueT, _AssertionsExtendReturnT]:
return self.extend(_assertion)

def __call__(self, value: _AssertionValueT) -> _Result[_AssertionReturnT]:
def __call__(self, value: _AssertionValueT) -> _AssertionReturnT:
"""
Invoke the chain with a value.

This method may be called more than once.
"""
return _Result(value).map(self._assertion)

@property
def assertion(self) -> Assertion[_AssertionValueT, _AssertionReturnT]:
"""
The assertion for this chain.
"""
return lambda value: self(value).value


class _AssertionChainExtension(
AssertionChain[_AssertionValueT, _AssertionReturnT],
Generic[_AssertionValueT, _AssertionReturnT],
):
def __init__(
self,
assertion_extension: Assertion[
_AssertionsIntermediateValueReturnT, _AssertionReturnT
],
assertion_chain: AssertionChain[
_AssertionValueT, _AssertionsIntermediateValueReturnT
],
):
super().__init__(
lambda value: assertion_chain(value).map(assertion_extension).value
)


@deprecated(
"This class is deprecated as of Betty 0.3.8, and will be removed in Betty 0.4.x. Instead, use :py:class:`betty.serde.load.AssertionChain`."
)
class Assertions( # noqa: D101
AssertionChain[_AssertionValueT, _AssertionReturnT],
Generic[_AssertionValueT, _AssertionReturnT],
):
pass


AssertionType: TypeAlias = (
AssertionChain[_AssertionValueT, _AssertionReturnT]
| Assertion[_AssertionValueT, _AssertionReturnT]
)
return self._assertion(value)


@dataclass(frozen=True)
class _Field(Generic[_AssertionValueT, _AssertionReturnT]):
@overload
def __init__(
self,
name: str,
assertion: AssertionChain[_AssertionValueT, _AssertionReturnT] | None = None,
):
pass

@overload
def __init__(
self,
name: str,
assertion: Assertion[_AssertionValueT, _AssertionReturnT] | None = None,
):
pass

def __init__(
self,
name: str,
assertion: AssertionChain[_AssertionValueT, _AssertionReturnT]
| Assertion[_AssertionValueT, _AssertionReturnT]
| None = None,
):
self._name = name
self._assertion = (
assertion
if assertion is None or isinstance(assertion, AssertionChain)
else AssertionChain(assertion)
)

@property
def name(self) -> str:
return self._name

@property
def assertion(self) -> AssertionChain[_AssertionValueT, _AssertionReturnT] | None:
return self._assertion
name: str
assertion: Assertion[_AssertionValueT, _AssertionReturnT] | None = None


@dataclass(frozen=True)
class RequiredField(
Generic[_AssertionValueT, _AssertionReturnT],
_Field[_AssertionValueT, _AssertionReturnT],
Expand All @@ -209,6 +137,7 @@ class RequiredField(
pass # pragma: no cover


@dataclass(frozen=True)
class OptionalField(
Generic[_AssertionValueT, _AssertionReturnT],
_Field[_AssertionValueT, _AssertionReturnT],
Expand Down Expand Up @@ -387,34 +316,8 @@ def _assert_dict(value: Any) -> dict[str, Any]:

return _assert_dict

def assert_assertions(
self, assertions: AssertionType[_AssertionValueT, _AssertionReturnT]
) -> Assertion[_AssertionValueT, _AssertionReturnT]:
"""
Assert that an assertions chain passes, and return the chain's output.
"""

def _assert_assertions(value: _AssertionValueT) -> _AssertionReturnT:
if isinstance(assertions, AssertionChain):
return assertions(value).value
return assertions(value)

return _assert_assertions

@overload
def assert_sequence(
self, item_assertion: AssertionChain[Any, _AssertionReturnT]
) -> Assertion[Any, MutableSequence[_AssertionReturnT]]:
pass

@overload
def assert_sequence(
self, item_assertion: Assertion[Any, _AssertionReturnT]
) -> Assertion[Any, MutableSequence[_AssertionReturnT]]:
pass

def assert_sequence(
self, item_assertion: AssertionType[Any, _AssertionReturnT]
) -> Assertion[Any, MutableSequence[_AssertionReturnT]]:
"""
Assert that a value is a sequence and that all item values are of the given type.
Expand All @@ -426,26 +329,14 @@ def _assert_sequence(value: Any) -> MutableSequence[_AssertionReturnT]:
with SerdeErrorCollection().assert_valid() as errors:
for value_item_index, value_item_value in enumerate(list_value):
with errors.catch(Str.plain(value_item_index)):
sequence.append(
self.assert_assertions(item_assertion)(value_item_value)
)
sequence.append(item_assertion(value_item_value))
return sequence

return _assert_sequence

@overload
def assert_mapping(
self, item_assertion: AssertionChain[Any, _AssertionReturnT]
) -> Assertion[Any, MutableMapping[str, _AssertionReturnT]]:
pass

@overload
def assert_mapping(
self, item_assertion: Assertion[Any, _AssertionReturnT]
) -> Assertion[Any, MutableMapping[str, _AssertionReturnT]]:
pass

def assert_mapping(self, item_assertion):
"""
Assert that a value is a key-value mapping and assert that all item values are of the given type.
"""
Expand All @@ -456,9 +347,7 @@ def _assert_mapping(value: Any) -> MutableMapping[str, _AssertionReturnT]:
with SerdeErrorCollection().assert_valid() as errors:
for value_item_key, value_item_value in dict_value.items():
with errors.catch(Str.plain(value_item_key)):
mapping[value_item_key] = self.assert_assertions(
item_assertion
)(value_item_value)
mapping[value_item_key] = item_assertion(value_item_value)
return mapping

return _assert_mapping
Expand All @@ -476,9 +365,9 @@ def _assert_fields(value: Any) -> MutableMapping[str, Any]:
with errors.catch(Str.plain(field.name)):
if field.name in value_dict:
if field.assertion:
mapping[field.name] = self.assert_assertions(
field.assertion
)(value_dict[field.name])
mapping[field.name] = field.assertion(
value_dict[field.name]
)
elif isinstance(field, RequiredField):
raise AssertionFailed(Str._("This field is required."))
return mapping
Expand Down
8 changes: 4 additions & 4 deletions betty/tests/serde/test_load.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,21 +26,21 @@
class TestAssertionChain:
async def test___call__(self) -> None:
sut = AssertionChain[int, int](lambda value: value)
assert sut(123).value == 123
assert sut(123) == 123

async def test___or__(self) -> None:
sut = AssertionChain[int, int](lambda value: value)
sut |= lambda value: 2 * value
assert sut.assertion(123) == 246
assert sut(123) == 246

async def test_assertion(self) -> None:
sut = AssertionChain[int, int](lambda value: value)
assert sut.assertion(123) == 123
assert sut(123) == 123

async def test_extend(self) -> None:
sut = AssertionChain[int, int](lambda value: value)
sut = sut.extend(lambda value: 2 * value)
assert sut.assertion(123) == 246
assert sut(123) == 246


def _always_valid(value: _T) -> _T:
Expand Down
Loading