diff --git a/pyproject.toml b/pyproject.toml index 10465de..196bb4e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -54,7 +54,7 @@ known-first-party = ["sortedcontainers_pydantic"] force-sort-within-sections = true [tool.mypy] -files = ["*.py"] +files = ["sortedcontainers_pydantic.py"] ignore_missing_imports = true [tool.pytest.ini_options] diff --git a/sortedcontainers_pydantic.py b/sortedcontainers_pydantic.py index a23ccb8..de00895 100644 --- a/sortedcontainers_pydantic.py +++ b/sortedcontainers_pydantic.py @@ -1,192 +1,174 @@ from collections.abc import Mapping -from typing import TYPE_CHECKING, Any, Iterable, Tuple +from typing import Any, Iterable, Set, Tuple, get_args from pydantic import ( GetCoreSchemaHandler, - GetJsonSchemaHandler, ) -from pydantic.json_schema import JsonSchemaValue from pydantic_core import core_schema import sortedcontainers -if TYPE_CHECKING: - from sortedcontainers.sorteddict import _VT, _OrderT # pragma: no cover - __version__ = "1.0.0" class SortedDict(sortedcontainers.SortedDict): @classmethod def __get_pydantic_core_schema__( - cls, - _source_type: Any, - _handler: GetCoreSchemaHandler, + cls, source_type: Any, handler: GetCoreSchemaHandler ) -> core_schema.CoreSchema: """ - Returns pydantic_core.CoreSchema that defines how Pydantic should validate or + Returns pydantic_core.CoreSchema that defines how Pydantic should validate and serialize this class. - - Validating from JSON: Parse as a dict and then pass to SortedDict constructor + - Validating from JSON: Validate as an iterable and pass to SortedList + constructor - Validating from Python: - - If it's already a SortedDict, do nothing - - If it can be parsed as a dict, do so and then pass to SortedDict - constructor - - If it can be parsed as a list of two-tuples, do so and then pass to - SortedDict constructor - - Serialization: Convert to a dict + - If it's already a SortedList, do nothing + - If it's an iterable, pass to SortedList constructor + - Serialization: Convert to a list """ - - def validate_from_mapping(value: Mapping) -> SortedDict: - return SortedDict(value) - - from_mapping_schema = core_schema.chain_schema( - [ - core_schema.dict_schema(), - core_schema.no_info_plain_validator_function(validate_from_mapping), - ] + # Schema for when the input is already an instance of this class + instance_schema = core_schema.is_instance_schema(cls) + + # Get schema for Iterable type based on source type has arguments + args = get_args(source_type) + if args: + mapping_t_schema = handler.generate_schema(Mapping[*args]) # type: ignore + iterable_of_pairs_t_schema = handler.generate_schema(Iterable[Tuple[*args]]) # type: ignore + else: + mapping_t_schema = handler.generate_schema(Mapping) + iterable_of_pairs_t_schema = handler.generate_schema(Iterable[Tuple[Any, Any]]) + + # Schema for when the input is a mapping + from_mapping_schema = core_schema.no_info_after_validator_function( + function=cls, schema=mapping_t_schema ) - def validate_from_iterable_of_pairs( - value: Iterable[Tuple["_OrderT", "_VT"]], - ) -> SortedDict: - return SortedDict(value) + # Schema for when the input is an iterable of pairs + from_iterable_of_pairs_schema = core_schema.no_info_after_validator_function( + function=cls, schema=iterable_of_pairs_t_schema + ) - from_iterable_of_pairs_schema = core_schema.chain_schema( + # Union of the two schemas + python_schema = core_schema.union_schema( [ - core_schema.list_schema( - items_schema=core_schema.tuple_schema( - [core_schema.any_schema(), core_schema.any_schema()], - ) - ), - core_schema.no_info_plain_validator_function(validate_from_iterable_of_pairs), + instance_schema, + from_mapping_schema, + from_iterable_of_pairs_schema, ] ) + # Serializer that converts an instance to a dict + as_dict_serializer = core_schema.plain_serializer_function_ser_schema(dict) + return core_schema.json_or_python_schema( json_schema=from_mapping_schema, - python_schema=core_schema.union_schema( - [ - # check if it's an instance first before doing any further work - core_schema.is_instance_schema(SortedDict), - from_mapping_schema, - from_iterable_of_pairs_schema, - ] - ), - serialization=core_schema.plain_serializer_function_ser_schema( - lambda instance: dict(instance) - ), + python_schema=python_schema, + serialization=as_dict_serializer, ) - @classmethod - def __get_pydantic_json_schema__( - cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler - ) -> JsonSchemaValue: - """Returns the JSON schema for this class. Uses the same schema as a normal - dict. - """ - return handler(core_schema.dict_schema()) - class SortedList(sortedcontainers.SortedList): @classmethod def __get_pydantic_core_schema__( - cls, - _source_type: Any, - _handler: GetCoreSchemaHandler, + cls, source_type: Any, handler: GetCoreSchemaHandler ) -> core_schema.CoreSchema: """ - Returns pydantic_core.CoreSchema that defines how Pydantic should validate or + Returns pydantic_core.CoreSchema that defines how Pydantic should validate and serialize this class. - - Validating from JSON: Parse as a list and then pass to SortedList constructor + - Validating from JSON: Validate as an iterable and pass to SortedList + constructor - Validating from Python: - If it's already a SortedList, do nothing - - If it can be parsed as a list, do so and then pass to SortedList - constructor + - If it's an iterable, pass to SortedList constructor - Serialization: Convert to a list """ + # Schema for when the input is already an instance of this class + instance_schema = core_schema.is_instance_schema(cls) + + # Get schema for Iterable type based on source type has arguments + args = get_args(source_type) + if args: + iterable_t_schema = handler.generate_schema(Iterable[*args]) # type: ignore + else: + iterable_t_schema = handler.generate_schema(Iterable) + + # Schema for when the input is an iterable + from_iterable_schema = core_schema.no_info_after_validator_function( + function=cls, schema=iterable_t_schema + ) - def validate_from_iterable(value: Iterable) -> SortedList: - return SortedList(value) - - from_iterable_schema = core_schema.chain_schema( + # Union of the two schemas + python_schema = core_schema.union_schema( [ - core_schema.list_schema(), - core_schema.no_info_plain_validator_function(validate_from_iterable), + instance_schema, + from_iterable_schema, ] ) + # Serializer that converts an instance to a list + as_list_serializer = core_schema.plain_serializer_function_ser_schema(list) + return core_schema.json_or_python_schema( json_schema=from_iterable_schema, - python_schema=core_schema.union_schema( - [ - # check if it's an instance first before doing any further work - core_schema.is_instance_schema(SortedList), - from_iterable_schema, - ] - ), - serialization=core_schema.plain_serializer_function_ser_schema( - lambda instance: list(instance) - ), + python_schema=python_schema, + serialization=as_list_serializer, ) - @classmethod - def __get_pydantic_json_schema__( - cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler - ) -> JsonSchemaValue: - """Returns the JSON schema for this class. Uses the same schema as a normal - list. - """ - return handler(core_schema.list_schema()) - class SortedSet(sortedcontainers.SortedSet): @classmethod def __get_pydantic_core_schema__( - cls, - _source_type: Any, - _handler: GetCoreSchemaHandler, + cls, source_type: Any, handler: GetCoreSchemaHandler ) -> core_schema.CoreSchema: """ - Returns pydantic_core.CoreSchema that defines how Pydantic should validate or + Returns pydantic_core.CoreSchema that defines how Pydantic should validate and serialize this class. - - Validating from JSON: Parse as a set and then pass to SortedSet constructor + - Validating from JSON: Validate as an iterable and pass to SortedSet + constructor - Validating from Python: - If it's already a SortedSet, do nothing - - If it can be parsed as a set, do so and then pass to SortedSet constructor - - Serialization: Convert to a set + - If it's a set, parse as a set and pass to SortedSet constructor + - If it's an iterable, pass to SortedSet constructor + - Serialization: Convert to a list """ + # Schema for when the input is already an instance of this class + instance_schema = core_schema.is_instance_schema(cls) + + # Get schema for Iterable type based on source type has arguments + args = get_args(source_type) + if args: + set_t_schema = handler.generate_schema(Set[*args]) # type: ignore + iterable_t_schema = handler.generate_schema(Iterable[*args]) # type: ignore + else: + set_t_schema = handler.generate_schema(Set) + iterable_t_schema = handler.generate_schema(Iterable) + + # Schema for when the input is a set + from_set_schema = core_schema.no_info_after_validator_function( + function=cls, schema=set_t_schema + ) - def validate_from_iterable(value: Iterable) -> SortedSet: - return SortedSet(value) + # Schema for when the input is an iterable + from_iterable_schema = core_schema.no_info_after_validator_function( + function=cls, schema=iterable_t_schema + ) - from_iterable_schema = core_schema.chain_schema( + # Union of the two schemas + python_schema = core_schema.union_schema( [ - core_schema.set_schema(), - core_schema.no_info_plain_validator_function(validate_from_iterable), + instance_schema, + from_set_schema, + from_iterable_schema, ] ) + # Serializer that converts an instance to a set + as_set_serializer = core_schema.plain_serializer_function_ser_schema(set) + return core_schema.json_or_python_schema( - json_schema=from_iterable_schema, - python_schema=core_schema.union_schema( - [ - # check if it's an instance first before doing any further work - core_schema.is_instance_schema(SortedSet), - from_iterable_schema, - ] - ), - serialization=core_schema.plain_serializer_function_ser_schema( - lambda instance: set(instance) - ), + json_schema=from_set_schema, + python_schema=python_schema, + serialization=as_set_serializer, ) - - @classmethod - def __get_pydantic_json_schema__( - cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler - ) -> JsonSchemaValue: - """Returns the JSON schema for this class. Uses the same schema as a normal - set. - """ - return handler(core_schema.set_schema()) diff --git a/tests.py b/tests.py index 0a2d4ce..da94fa5 100644 --- a/tests.py +++ b/tests.py @@ -1,49 +1,166 @@ -from pydantic import TypeAdapter +from typing import Callable, Dict, Iterable, List, Set + +from pydantic import BaseModel, TypeAdapter import sortedcontainers import sortedcontainers_pydantic +class ReusableIterable: + """Utility iterable class that can be used multiple times without exhausting.""" + + def __init__(self, iterable_factory: Callable[[], Iterable]): + self.iterable_factory = iterable_factory + + def __iter__(self): + return iter(self.iterable_factory()) + + def test_sorted_dict(): ta = TypeAdapter(sortedcontainers_pydantic.SortedDict) expected = sortedcontainers.SortedDict({"c": 1, "a": 2, "b": 3}) - assert ta.validate_python(expected) == expected - assert ta.validate_python({"c": 1, "a": 2, "b": 3}) == expected - assert ta.validate_python({"b": 3, "c": 1, "a": 2}) == expected - assert ta.validate_python([("c", 1), ("a", 2), ("b", 3)]) == expected - assert ta.validate_python([("b", 3), ("c", 1), ("a", 2)]) == expected - assert ta.validate_python(pair for pair in [("c", 1), ("a", 2), ("b", 3)]) == expected - - assert ta.json_schema() == TypeAdapter(dict).json_schema() + cases = [ + expected, + {"c": 1, "a": 2, "b": 3}, + [("c", 1), ("a", 2), ("b", 3)], + [("b", 3), ("c", 1), ("a", 2)], + (("c", 1), ("a", 2), ("b", 3)), + {("c", 1), ("a", 2), ("b", 3)}, + ReusableIterable(lambda: [("c", 1), ("a", 2), ("b", 3)]), + ] + + for annotation in ( + sortedcontainers_pydantic.SortedDict, + sortedcontainers_pydantic.SortedDict[str, int], + ): + ta = TypeAdapter(annotation) + for case in cases: + print(f"annotation: {annotation}, case: {case}") + actual = ta.validate_python(case) + assert isinstance(actual, sortedcontainers.SortedDict) + assert actual == expected + + assert ta.dump_json(expected).decode() == '{"a":2,"b":3,"c":1}' + + assert ( + TypeAdapter(sortedcontainers_pydantic.SortedDict).json_schema() + == TypeAdapter(dict).json_schema() + ) + assert ( + TypeAdapter(sortedcontainers_pydantic.SortedDict[str, int]).json_schema() + == TypeAdapter(Dict[str, int]).json_schema() + ) + + class MyModel(BaseModel): + sorted_dict: sortedcontainers_pydantic.SortedDict + + class MyModelWithArg(BaseModel): + sorted_dict: sortedcontainers_pydantic.SortedDict[str, int] + + for model in (MyModel, MyModelWithArg): + for case in cases: + print(f"model: {model}, case: {case}") + instance = MyModel(sorted_dict=case) + assert isinstance(instance.sorted_dict, sortedcontainers.SortedDict) + assert instance.sorted_dict == expected + assert instance.model_dump_json() == '{"sorted_dict":{"a":2,"b":3,"c":1}}' def test_sorted_list(): - ta = TypeAdapter(sortedcontainers_pydantic.SortedList) - expected = sortedcontainers.SortedList([3, 2, 1]) - assert ta.validate_python(expected) == expected - assert ta.validate_python([3, 2, 1]) == expected - assert ta.validate_python((3, 2, 1)) == expected - assert ta.validate_python({3, 2, 1}) == expected - assert ta.validate_python(i for i in (3, 2, 1)) == expected - assert ta.validate_python([2, 3, 1]) == expected - - assert ta.json_schema() == TypeAdapter(list).json_schema() + cases = [ + expected, + [3, 2, 1], + (3, 2, 1), + {3, 2, 1}, + range(3, 0, -1), + [2, 3, 1], + ] + + for annotation in ( + sortedcontainers_pydantic.SortedList, + sortedcontainers_pydantic.SortedList[int], + ): + ta = TypeAdapter(annotation) + for case in cases: + print(f"annotation: {annotation}, case: {case}") + actual = ta.validate_python(case) + assert isinstance(actual, sortedcontainers.SortedList) + assert actual == expected + + assert ta.dump_json(expected).decode() == "[1,2,3]" + + assert ( + TypeAdapter(sortedcontainers_pydantic.SortedList).json_schema() + == TypeAdapter(list).json_schema() + ) + assert ( + TypeAdapter(sortedcontainers_pydantic.SortedList[int]).json_schema() + == TypeAdapter(List[int]).json_schema() + ) + + class MyModel(BaseModel): + sorted_list: sortedcontainers_pydantic.SortedList + + class MyModelWithArg(BaseModel): + sorted_list: sortedcontainers_pydantic.SortedList[int] + + for model in (MyModel, MyModelWithArg): + for case in cases: + print(f"model: {model}, case: {case}") + instance = MyModel(sorted_list=case) + assert isinstance(instance.sorted_list, sortedcontainers.SortedList) + assert instance.sorted_list == expected + assert instance.model_dump_json() == '{"sorted_list":[1,2,3]}' def test_sorted_set(): - ta = TypeAdapter(sortedcontainers_pydantic.SortedSet) - expected = sortedcontainers.SortedSet([3, 2, 1]) - assert ta.validate_python(expected) == expected - assert ta.validate_python([3, 2, 1]) == expected - assert ta.validate_python((3, 2, 1)) == expected - assert ta.validate_python({3, 2, 1}) == expected - assert ta.validate_python(i for i in (3, 2, 1)) == expected - assert ta.validate_python([2, 3, 1]) == expected - - assert ta.json_schema() == TypeAdapter(set).json_schema() + cases = [ + expected, + [3, 2, 1], + (3, 2, 1), + {3, 2, 1}, + range(3, 0, -1), + [2, 3, 1], + ] + + for annotation in ( + sortedcontainers_pydantic.SortedSet, + sortedcontainers_pydantic.SortedSet[int], + ): + ta = TypeAdapter(annotation) + for case in cases: + print(f"annotation: {annotation}, case: {case}") + actual = ta.validate_python(case) + assert isinstance(actual, sortedcontainers.SortedSet) + assert actual == expected + + assert ta.dump_json(expected).decode() == "[1,2,3]" + + assert ( + TypeAdapter(sortedcontainers_pydantic.SortedSet).json_schema() + == TypeAdapter(set).json_schema() + ) + assert ( + TypeAdapter(sortedcontainers_pydantic.SortedSet[int]).json_schema() + == TypeAdapter(Set[int]).json_schema() + ) + + class MyModel(BaseModel): + sorted_set: sortedcontainers_pydantic.SortedSet + + class MyModelWithArg(BaseModel): + sorted_set: sortedcontainers_pydantic.SortedSet[int] + + for model in (MyModel, MyModelWithArg): + for case in cases: + print(f"model: {model}, case: {case}") + instance = MyModel(sorted_set=case) + assert isinstance(instance.sorted_set, sortedcontainers.SortedSet) + assert instance.sorted_set == expected + assert instance.model_dump_json() == '{"sorted_set":[1,2,3]}'