-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Follow generics example to handle type arguments
- Loading branch information
Showing
3 changed files
with
250 additions
and
151 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,192 +1,174 @@ | ||
from collections.abc import Mapping | ||
from typing import TYPE_CHECKING, Any, Iterable, Tuple | ||
from typing import Any, Iterable, Set, Tuple, get_args | ||
|
||
from pydantic import ( | ||
GetCoreSchemaHandler, | ||
GetJsonSchemaHandler, | ||
) | ||
from pydantic.json_schema import JsonSchemaValue | ||
from pydantic_core import core_schema | ||
import sortedcontainers | ||
|
||
if TYPE_CHECKING: | ||
from sortedcontainers.sorteddict import _VT, _OrderT # pragma: no cover | ||
|
||
__version__ = "1.0.0" | ||
|
||
|
||
class SortedDict(sortedcontainers.SortedDict): | ||
@classmethod | ||
def __get_pydantic_core_schema__( | ||
cls, | ||
_source_type: Any, | ||
_handler: GetCoreSchemaHandler, | ||
cls, source_type: Any, handler: GetCoreSchemaHandler | ||
) -> core_schema.CoreSchema: | ||
""" | ||
Returns pydantic_core.CoreSchema that defines how Pydantic should validate or | ||
Returns pydantic_core.CoreSchema that defines how Pydantic should validate and | ||
serialize this class. | ||
- Validating from JSON: Parse as a dict and then pass to SortedDict constructor | ||
- Validating from JSON: Validate as an iterable and pass to SortedList | ||
constructor | ||
- Validating from Python: | ||
- If it's already a SortedDict, do nothing | ||
- If it can be parsed as a dict, do so and then pass to SortedDict | ||
constructor | ||
- If it can be parsed as a list of two-tuples, do so and then pass to | ||
SortedDict constructor | ||
- Serialization: Convert to a dict | ||
- If it's already a SortedList, do nothing | ||
- If it's an iterable, pass to SortedList constructor | ||
- Serialization: Convert to a list | ||
""" | ||
|
||
def validate_from_mapping(value: Mapping) -> SortedDict: | ||
return SortedDict(value) | ||
|
||
from_mapping_schema = core_schema.chain_schema( | ||
[ | ||
core_schema.dict_schema(), | ||
core_schema.no_info_plain_validator_function(validate_from_mapping), | ||
] | ||
# Schema for when the input is already an instance of this class | ||
instance_schema = core_schema.is_instance_schema(cls) | ||
|
||
# Get schema for Iterable type based on source type has arguments | ||
args = get_args(source_type) | ||
if args: | ||
mapping_t_schema = handler.generate_schema(Mapping[*args]) # type: ignore | ||
iterable_of_pairs_t_schema = handler.generate_schema(Iterable[Tuple[*args]]) # type: ignore | ||
else: | ||
mapping_t_schema = handler.generate_schema(Mapping) | ||
iterable_of_pairs_t_schema = handler.generate_schema(Iterable[Tuple[Any, Any]]) | ||
|
||
# Schema for when the input is a mapping | ||
from_mapping_schema = core_schema.no_info_after_validator_function( | ||
function=cls, schema=mapping_t_schema | ||
) | ||
|
||
def validate_from_iterable_of_pairs( | ||
value: Iterable[Tuple["_OrderT", "_VT"]], | ||
) -> SortedDict: | ||
return SortedDict(value) | ||
# Schema for when the input is an iterable of pairs | ||
from_iterable_of_pairs_schema = core_schema.no_info_after_validator_function( | ||
function=cls, schema=iterable_of_pairs_t_schema | ||
) | ||
|
||
from_iterable_of_pairs_schema = core_schema.chain_schema( | ||
# Union of the two schemas | ||
python_schema = core_schema.union_schema( | ||
[ | ||
core_schema.list_schema( | ||
items_schema=core_schema.tuple_schema( | ||
[core_schema.any_schema(), core_schema.any_schema()], | ||
) | ||
), | ||
core_schema.no_info_plain_validator_function(validate_from_iterable_of_pairs), | ||
instance_schema, | ||
from_mapping_schema, | ||
from_iterable_of_pairs_schema, | ||
] | ||
) | ||
|
||
# Serializer that converts an instance to a dict | ||
as_dict_serializer = core_schema.plain_serializer_function_ser_schema(dict) | ||
|
||
return core_schema.json_or_python_schema( | ||
json_schema=from_mapping_schema, | ||
python_schema=core_schema.union_schema( | ||
[ | ||
# check if it's an instance first before doing any further work | ||
core_schema.is_instance_schema(SortedDict), | ||
from_mapping_schema, | ||
from_iterable_of_pairs_schema, | ||
] | ||
), | ||
serialization=core_schema.plain_serializer_function_ser_schema( | ||
lambda instance: dict(instance) | ||
), | ||
python_schema=python_schema, | ||
serialization=as_dict_serializer, | ||
) | ||
|
||
@classmethod | ||
def __get_pydantic_json_schema__( | ||
cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler | ||
) -> JsonSchemaValue: | ||
"""Returns the JSON schema for this class. Uses the same schema as a normal | ||
dict. | ||
""" | ||
return handler(core_schema.dict_schema()) | ||
|
||
|
||
class SortedList(sortedcontainers.SortedList): | ||
@classmethod | ||
def __get_pydantic_core_schema__( | ||
cls, | ||
_source_type: Any, | ||
_handler: GetCoreSchemaHandler, | ||
cls, source_type: Any, handler: GetCoreSchemaHandler | ||
) -> core_schema.CoreSchema: | ||
""" | ||
Returns pydantic_core.CoreSchema that defines how Pydantic should validate or | ||
Returns pydantic_core.CoreSchema that defines how Pydantic should validate and | ||
serialize this class. | ||
- Validating from JSON: Parse as a list and then pass to SortedList constructor | ||
- Validating from JSON: Validate as an iterable and pass to SortedList | ||
constructor | ||
- Validating from Python: | ||
- If it's already a SortedList, do nothing | ||
- If it can be parsed as a list, do so and then pass to SortedList | ||
constructor | ||
- If it's an iterable, pass to SortedList constructor | ||
- Serialization: Convert to a list | ||
""" | ||
# Schema for when the input is already an instance of this class | ||
instance_schema = core_schema.is_instance_schema(cls) | ||
|
||
# Get schema for Iterable type based on source type has arguments | ||
args = get_args(source_type) | ||
if args: | ||
iterable_t_schema = handler.generate_schema(Iterable[*args]) # type: ignore | ||
else: | ||
iterable_t_schema = handler.generate_schema(Iterable) | ||
|
||
# Schema for when the input is an iterable | ||
from_iterable_schema = core_schema.no_info_after_validator_function( | ||
function=cls, schema=iterable_t_schema | ||
) | ||
|
||
def validate_from_iterable(value: Iterable) -> SortedList: | ||
return SortedList(value) | ||
|
||
from_iterable_schema = core_schema.chain_schema( | ||
# Union of the two schemas | ||
python_schema = core_schema.union_schema( | ||
[ | ||
core_schema.list_schema(), | ||
core_schema.no_info_plain_validator_function(validate_from_iterable), | ||
instance_schema, | ||
from_iterable_schema, | ||
] | ||
) | ||
|
||
# Serializer that converts an instance to a list | ||
as_list_serializer = core_schema.plain_serializer_function_ser_schema(list) | ||
|
||
return core_schema.json_or_python_schema( | ||
json_schema=from_iterable_schema, | ||
python_schema=core_schema.union_schema( | ||
[ | ||
# check if it's an instance first before doing any further work | ||
core_schema.is_instance_schema(SortedList), | ||
from_iterable_schema, | ||
] | ||
), | ||
serialization=core_schema.plain_serializer_function_ser_schema( | ||
lambda instance: list(instance) | ||
), | ||
python_schema=python_schema, | ||
serialization=as_list_serializer, | ||
) | ||
|
||
@classmethod | ||
def __get_pydantic_json_schema__( | ||
cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler | ||
) -> JsonSchemaValue: | ||
"""Returns the JSON schema for this class. Uses the same schema as a normal | ||
list. | ||
""" | ||
return handler(core_schema.list_schema()) | ||
|
||
|
||
class SortedSet(sortedcontainers.SortedSet): | ||
@classmethod | ||
def __get_pydantic_core_schema__( | ||
cls, | ||
_source_type: Any, | ||
_handler: GetCoreSchemaHandler, | ||
cls, source_type: Any, handler: GetCoreSchemaHandler | ||
) -> core_schema.CoreSchema: | ||
""" | ||
Returns pydantic_core.CoreSchema that defines how Pydantic should validate or | ||
Returns pydantic_core.CoreSchema that defines how Pydantic should validate and | ||
serialize this class. | ||
- Validating from JSON: Parse as a set and then pass to SortedSet constructor | ||
- Validating from JSON: Validate as an iterable and pass to SortedSet | ||
constructor | ||
- Validating from Python: | ||
- If it's already a SortedSet, do nothing | ||
- If it can be parsed as a set, do so and then pass to SortedSet constructor | ||
- Serialization: Convert to a set | ||
- If it's a set, parse as a set and pass to SortedSet constructor | ||
- If it's an iterable, pass to SortedSet constructor | ||
- Serialization: Convert to a list | ||
""" | ||
# Schema for when the input is already an instance of this class | ||
instance_schema = core_schema.is_instance_schema(cls) | ||
|
||
# Get schema for Iterable type based on source type has arguments | ||
args = get_args(source_type) | ||
if args: | ||
set_t_schema = handler.generate_schema(Set[*args]) # type: ignore | ||
iterable_t_schema = handler.generate_schema(Iterable[*args]) # type: ignore | ||
else: | ||
set_t_schema = handler.generate_schema(Set) | ||
iterable_t_schema = handler.generate_schema(Iterable) | ||
|
||
# Schema for when the input is a set | ||
from_set_schema = core_schema.no_info_after_validator_function( | ||
function=cls, schema=set_t_schema | ||
) | ||
|
||
def validate_from_iterable(value: Iterable) -> SortedSet: | ||
return SortedSet(value) | ||
# Schema for when the input is an iterable | ||
from_iterable_schema = core_schema.no_info_after_validator_function( | ||
function=cls, schema=iterable_t_schema | ||
) | ||
|
||
from_iterable_schema = core_schema.chain_schema( | ||
# Union of the two schemas | ||
python_schema = core_schema.union_schema( | ||
[ | ||
core_schema.set_schema(), | ||
core_schema.no_info_plain_validator_function(validate_from_iterable), | ||
instance_schema, | ||
from_set_schema, | ||
from_iterable_schema, | ||
] | ||
) | ||
|
||
# Serializer that converts an instance to a set | ||
as_set_serializer = core_schema.plain_serializer_function_ser_schema(set) | ||
|
||
return core_schema.json_or_python_schema( | ||
json_schema=from_iterable_schema, | ||
python_schema=core_schema.union_schema( | ||
[ | ||
# check if it's an instance first before doing any further work | ||
core_schema.is_instance_schema(SortedSet), | ||
from_iterable_schema, | ||
] | ||
), | ||
serialization=core_schema.plain_serializer_function_ser_schema( | ||
lambda instance: set(instance) | ||
), | ||
json_schema=from_set_schema, | ||
python_schema=python_schema, | ||
serialization=as_set_serializer, | ||
) | ||
|
||
@classmethod | ||
def __get_pydantic_json_schema__( | ||
cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler | ||
) -> JsonSchemaValue: | ||
"""Returns the JSON schema for this class. Uses the same schema as a normal | ||
set. | ||
""" | ||
return handler(core_schema.set_schema()) |
Oops, something went wrong.