Skip to content

Commit

Permalink
Follow generics example to handle type arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
jayqi committed Mar 5, 2024
1 parent 765263f commit f52f25a
Show file tree
Hide file tree
Showing 3 changed files with 250 additions and 151 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ known-first-party = ["sortedcontainers_pydantic"]
force-sort-within-sections = true

[tool.mypy]
files = ["*.py"]
files = ["sortedcontainers_pydantic.py"]
ignore_missing_imports = true

[tool.pytest.ini_options]
Expand Down
224 changes: 103 additions & 121 deletions sortedcontainers_pydantic.py
Original file line number Diff line number Diff line change
@@ -1,192 +1,174 @@
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, Iterable, Tuple
from typing import Any, Iterable, Set, Tuple, get_args

from pydantic import (
GetCoreSchemaHandler,
GetJsonSchemaHandler,
)
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import core_schema
import sortedcontainers

if TYPE_CHECKING:
from sortedcontainers.sorteddict import _VT, _OrderT # pragma: no cover

__version__ = "1.0.0"


class SortedDict(sortedcontainers.SortedDict):
@classmethod
def __get_pydantic_core_schema__(
cls,
_source_type: Any,
_handler: GetCoreSchemaHandler,
cls, source_type: Any, handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
"""
Returns pydantic_core.CoreSchema that defines how Pydantic should validate or
Returns pydantic_core.CoreSchema that defines how Pydantic should validate and
serialize this class.
- Validating from JSON: Parse as a dict and then pass to SortedDict constructor
- Validating from JSON: Validate as an iterable and pass to SortedList
constructor
- Validating from Python:
- If it's already a SortedDict, do nothing
- If it can be parsed as a dict, do so and then pass to SortedDict
constructor
- If it can be parsed as a list of two-tuples, do so and then pass to
SortedDict constructor
- Serialization: Convert to a dict
- If it's already a SortedList, do nothing
- If it's an iterable, pass to SortedList constructor
- Serialization: Convert to a list
"""

def validate_from_mapping(value: Mapping) -> SortedDict:
return SortedDict(value)

from_mapping_schema = core_schema.chain_schema(
[
core_schema.dict_schema(),
core_schema.no_info_plain_validator_function(validate_from_mapping),
]
# Schema for when the input is already an instance of this class
instance_schema = core_schema.is_instance_schema(cls)

# Get schema for Iterable type based on source type has arguments
args = get_args(source_type)
if args:
mapping_t_schema = handler.generate_schema(Mapping[*args]) # type: ignore
iterable_of_pairs_t_schema = handler.generate_schema(Iterable[Tuple[*args]]) # type: ignore
else:
mapping_t_schema = handler.generate_schema(Mapping)
iterable_of_pairs_t_schema = handler.generate_schema(Iterable[Tuple[Any, Any]])

# Schema for when the input is a mapping
from_mapping_schema = core_schema.no_info_after_validator_function(
function=cls, schema=mapping_t_schema
)

def validate_from_iterable_of_pairs(
value: Iterable[Tuple["_OrderT", "_VT"]],
) -> SortedDict:
return SortedDict(value)
# Schema for when the input is an iterable of pairs
from_iterable_of_pairs_schema = core_schema.no_info_after_validator_function(
function=cls, schema=iterable_of_pairs_t_schema
)

from_iterable_of_pairs_schema = core_schema.chain_schema(
# Union of the two schemas
python_schema = core_schema.union_schema(
[
core_schema.list_schema(
items_schema=core_schema.tuple_schema(
[core_schema.any_schema(), core_schema.any_schema()],
)
),
core_schema.no_info_plain_validator_function(validate_from_iterable_of_pairs),
instance_schema,
from_mapping_schema,
from_iterable_of_pairs_schema,
]
)

# Serializer that converts an instance to a dict
as_dict_serializer = core_schema.plain_serializer_function_ser_schema(dict)

return core_schema.json_or_python_schema(
json_schema=from_mapping_schema,
python_schema=core_schema.union_schema(
[
# check if it's an instance first before doing any further work
core_schema.is_instance_schema(SortedDict),
from_mapping_schema,
from_iterable_of_pairs_schema,
]
),
serialization=core_schema.plain_serializer_function_ser_schema(
lambda instance: dict(instance)
),
python_schema=python_schema,
serialization=as_dict_serializer,
)

@classmethod
def __get_pydantic_json_schema__(
cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
"""Returns the JSON schema for this class. Uses the same schema as a normal
dict.
"""
return handler(core_schema.dict_schema())


class SortedList(sortedcontainers.SortedList):
@classmethod
def __get_pydantic_core_schema__(
cls,
_source_type: Any,
_handler: GetCoreSchemaHandler,
cls, source_type: Any, handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
"""
Returns pydantic_core.CoreSchema that defines how Pydantic should validate or
Returns pydantic_core.CoreSchema that defines how Pydantic should validate and
serialize this class.
- Validating from JSON: Parse as a list and then pass to SortedList constructor
- Validating from JSON: Validate as an iterable and pass to SortedList
constructor
- Validating from Python:
- If it's already a SortedList, do nothing
- If it can be parsed as a list, do so and then pass to SortedList
constructor
- If it's an iterable, pass to SortedList constructor
- Serialization: Convert to a list
"""
# Schema for when the input is already an instance of this class
instance_schema = core_schema.is_instance_schema(cls)

# Get schema for Iterable type based on source type has arguments
args = get_args(source_type)
if args:
iterable_t_schema = handler.generate_schema(Iterable[*args]) # type: ignore
else:
iterable_t_schema = handler.generate_schema(Iterable)

# Schema for when the input is an iterable
from_iterable_schema = core_schema.no_info_after_validator_function(
function=cls, schema=iterable_t_schema
)

def validate_from_iterable(value: Iterable) -> SortedList:
return SortedList(value)

from_iterable_schema = core_schema.chain_schema(
# Union of the two schemas
python_schema = core_schema.union_schema(
[
core_schema.list_schema(),
core_schema.no_info_plain_validator_function(validate_from_iterable),
instance_schema,
from_iterable_schema,
]
)

# Serializer that converts an instance to a list
as_list_serializer = core_schema.plain_serializer_function_ser_schema(list)

return core_schema.json_or_python_schema(
json_schema=from_iterable_schema,
python_schema=core_schema.union_schema(
[
# check if it's an instance first before doing any further work
core_schema.is_instance_schema(SortedList),
from_iterable_schema,
]
),
serialization=core_schema.plain_serializer_function_ser_schema(
lambda instance: list(instance)
),
python_schema=python_schema,
serialization=as_list_serializer,
)

@classmethod
def __get_pydantic_json_schema__(
cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
"""Returns the JSON schema for this class. Uses the same schema as a normal
list.
"""
return handler(core_schema.list_schema())


class SortedSet(sortedcontainers.SortedSet):
@classmethod
def __get_pydantic_core_schema__(
cls,
_source_type: Any,
_handler: GetCoreSchemaHandler,
cls, source_type: Any, handler: GetCoreSchemaHandler
) -> core_schema.CoreSchema:
"""
Returns pydantic_core.CoreSchema that defines how Pydantic should validate or
Returns pydantic_core.CoreSchema that defines how Pydantic should validate and
serialize this class.
- Validating from JSON: Parse as a set and then pass to SortedSet constructor
- Validating from JSON: Validate as an iterable and pass to SortedSet
constructor
- Validating from Python:
- If it's already a SortedSet, do nothing
- If it can be parsed as a set, do so and then pass to SortedSet constructor
- Serialization: Convert to a set
- If it's a set, parse as a set and pass to SortedSet constructor
- If it's an iterable, pass to SortedSet constructor
- Serialization: Convert to a list
"""
# Schema for when the input is already an instance of this class
instance_schema = core_schema.is_instance_schema(cls)

# Get schema for Iterable type based on source type has arguments
args = get_args(source_type)
if args:
set_t_schema = handler.generate_schema(Set[*args]) # type: ignore
iterable_t_schema = handler.generate_schema(Iterable[*args]) # type: ignore
else:
set_t_schema = handler.generate_schema(Set)
iterable_t_schema = handler.generate_schema(Iterable)

# Schema for when the input is a set
from_set_schema = core_schema.no_info_after_validator_function(
function=cls, schema=set_t_schema
)

def validate_from_iterable(value: Iterable) -> SortedSet:
return SortedSet(value)
# Schema for when the input is an iterable
from_iterable_schema = core_schema.no_info_after_validator_function(
function=cls, schema=iterable_t_schema
)

from_iterable_schema = core_schema.chain_schema(
# Union of the two schemas
python_schema = core_schema.union_schema(
[
core_schema.set_schema(),
core_schema.no_info_plain_validator_function(validate_from_iterable),
instance_schema,
from_set_schema,
from_iterable_schema,
]
)

# Serializer that converts an instance to a set
as_set_serializer = core_schema.plain_serializer_function_ser_schema(set)

return core_schema.json_or_python_schema(
json_schema=from_iterable_schema,
python_schema=core_schema.union_schema(
[
# check if it's an instance first before doing any further work
core_schema.is_instance_schema(SortedSet),
from_iterable_schema,
]
),
serialization=core_schema.plain_serializer_function_ser_schema(
lambda instance: set(instance)
),
json_schema=from_set_schema,
python_schema=python_schema,
serialization=as_set_serializer,
)

@classmethod
def __get_pydantic_json_schema__(
cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> JsonSchemaValue:
"""Returns the JSON schema for this class. Uses the same schema as a normal
set.
"""
return handler(core_schema.set_schema())
Loading

0 comments on commit f52f25a

Please sign in to comment.