From e3a6a08e3528d8a104d815e99b28e02fd53177d2 Mon Sep 17 00:00:00 2001 From: Imogen Date: Sat, 30 Sep 2023 17:23:47 +0200 Subject: [PATCH] move reference validation code to problem module --- algobattle/battle.py | 2 +- algobattle/problem.py | 175 ++++++++++++++++++++++++++++++++++++++++-- algobattle/types.py | 6 +- algobattle/util.py | 170 +--------------------------------------- tests/test_types.py | 4 +- 5 files changed, 174 insertions(+), 183 deletions(-) diff --git a/algobattle/battle.py b/algobattle/battle.py index 2cd86a3e..eb380b14 100644 --- a/algobattle/battle.py +++ b/algobattle/battle.py @@ -237,7 +237,7 @@ class Config(BaseModel): :meth:`Battle.run` method with its fields set accordingly. """ - type: str + type: Any """Type of battle that will be used.""" @classmethod diff --git a/algobattle/problem.py b/algobattle/problem.py index 94543cfc..8074cd8c 100644 --- a/algobattle/problem.py +++ b/algobattle/problem.py @@ -1,7 +1,9 @@ """Module defining the Problem and Solution base classes and related objects.""" from abc import ABC, abstractmethod +from dataclasses import dataclass from functools import wraps from importlib.metadata import entry_points +from inspect import Parameter, Signature, signature from itertools import chain from pathlib import Path from typing import ( @@ -17,12 +19,21 @@ Generic, TypeVar, overload, + cast, + get_args, ) from math import inf, isnan +from annotated_types import GroupedMetadata + +from pydantic import ( + GetCoreSchemaHandler, + ValidationInfo, +) +from pydantic_core import CoreSchema +from pydantic_core.core_schema import with_info_after_validator_function from algobattle.util import ( EncodableModel, - InstanceSolutionModel, Role, Encodable, import_file_as_module, @@ -363,19 +374,171 @@ def available(cls) -> set[str]: AnyProblem = Problem[Any, Any] +ModelType = Literal["instance", "solution"] +ModelReference = ModelType | Literal["self"] + + +@dataclass(frozen=True, slots=True) +class AttributeReference: + """Creates a reference to the attribute of a model to be used in validaton schemas.""" + + model: ModelReference + attribute: str + + def get_value(self, info: ValidationInfo) -> Any | None: + """Returns the referenced value from the correct object in the info context. + + If the correct object is not in the context or doesn't have the referenced attribute it returns None. + """ + if info.context is None or self.model not in info.context: + return None + model = info.context[self.model] + if hasattr(model, self.attribute): + return getattr(model, self.attribute) + else: + return None + + def __str__(self) -> str: + return f"{self.model}.{self.attribute}" + + def needs_self(self, model_type: Literal["instance", "solution"]) -> bool: + """Checks if an attribute reference needs a reference to the current model in order to be resolved.""" + if self.model == "self": + return True + else: + return self.model == model_type + + +NoInfoAttrValidatorFunction = Callable[[Any, Any], Any] +GeneralAttrValidatorFunction = Callable[[Any, Any, ValidationInfo], Any] +AttrValidatorFunction = NoInfoAttrValidatorFunction | GeneralAttrValidatorFunction + + +def count_positional_params(sig: Signature) -> int: + """Counts the number of positional parameters in a signature.""" + return sum(1 for param in sig.parameters.values() if can_be_positional(param)) + + +def can_be_positional(param: Parameter) -> bool: + """Checks whether a parameter is positional.""" + return param.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD) + + +def is_info_validator(validator: AttrValidatorFunction) -> bool: + """Helper method to discriminate the union.""" + match count_positional_params(signature(validator)): + case 2: + return False + case 3: + return True + case _: + raise TypeError + + +@dataclass(frozen=True, slots=True) +class AttributeReferenceValidator: + """An AfterValidator that can resolve a reference to a model attribute and pass it to the validator function. + + Using this with a reference to an attribute in the model it is defined may significantly impact performance. + """ + func: AttrValidatorFunction + attribute: AttributeReference -class InstanceModel(Instance, EncodableModel, InstanceSolutionModel, ABC): + def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema: + schema = handler(source_type) + info_arg = is_info_validator(self.func) + if info_arg: + func = cast(GeneralAttrValidatorFunction, self.func) + + def wrapper(value: Any, info: ValidationInfo) -> Any: + attribute_val = self.attribute.get_value(info) + if attribute_val is None: + return value + return func(value, attribute_val, info) + + else: + func = cast(NoInfoAttrValidatorFunction, self.func) + + def wrapper(value: Any, info: ValidationInfo) -> Any: + attribute_val = self.attribute.get_value(info) + if attribute_val is None: + return value + return func(value, attribute_val) + + return with_info_after_validator_function(wrapper, schema=schema) + + def needs_self(self, model_type: ModelType) -> bool: + """Checks if the validator needs a reference to the current model in order to work fully.""" + if self.attribute.model == "self": + return True + else: + return self.attribute.model == model_type + + +@dataclass +class AttributeReferenceMaker: + """Helper class to easily create attribute references.""" + + _attr_ref_maker_model: ModelReference + + def __getattr__(self, __name: str) -> AttributeReference: + return AttributeReference(self._attr_ref_maker_model, __name) + + +SelfRef = AttributeReferenceMaker("self") +InstanceRef = AttributeReferenceMaker("instance") +SolutionRef = AttributeReferenceMaker("solution") + + +class InstanceSolutionModel(EncodableModel): + """Base class for Instance and solution models.""" + + @classmethod + def model_validate( # noqa: D102 + cls, + obj: Any, + *, + strict: bool | None = None, + from_attributes: bool | None = None, + context: dict[str, Any] | None = None, + ) -> Self: + model = super().model_validate(obj, strict=strict, from_attributes=from_attributes, context=context) + model_type = "instance" if issubclass(cls, InstanceModel) else "solution" + if cls._validate_with_self(model_type): + context = (context or {}) | {"self": model, model_type: model} + model = super().model_validate(obj, context=context) + return model + + @classmethod + def _annotation_needs_self(cls, annotation: object, model_type: ModelType) -> bool: + if isinstance(annotation, AttributeReferenceValidator): + return annotation.needs_self(model_type) + if isinstance(annotation, GroupedMetadata): + return any(cls._annotation_needs_self(e, model_type) for e in annotation) + return any(cls._annotation_needs_self(e, model_type) for e in get_args(annotation)) + + @classmethod + def _validate_with_self(cls, model_type: ModelType) -> bool: + # info.annotation contains the type and any nested metadata, info.metadata the top level metadata + # we can use _annotation_needs_self for all of them, so we iterate over all fields and see if any of them + # either have an annotation or metadata we need to parse with a self reference + for info in cls.model_fields.values(): + values = chain((info.annotation,), info.metadata) + if any(cls._annotation_needs_self(value, model_type) for value in values): + return True + return False + + +class InstanceModel(Instance, InstanceSolutionModel, ABC): """An instance that can easily be parsed to/from a json file.""" - _algobattle_model_type: ClassVar[Literal["instance"]] = "instance" + pass -class SolutionModel(Solution[InstanceT], EncodableModel, InstanceSolutionModel, ABC): +class SolutionModel(Solution[InstanceT], InstanceSolutionModel, ABC): """A solution that can easily be parsed to/from a json file.""" - _algobattle_model_type: ClassVar[Literal["solution"]] = "solution" - @classmethod def decode(cls, source: Path, max_size: int, role: Role, instance: InstanceT | None = None) -> Self: """Uses pydantic to create a python object from a `.json` file.""" diff --git a/algobattle/types.py b/algobattle/types.py index e402ca55..02605a92 100644 --- a/algobattle/types.py +++ b/algobattle/types.py @@ -34,15 +34,11 @@ from algobattle.problem import ( InstanceModel, SolutionModel, -) -from algobattle.util import ( - BaseModel, - Role, AttributeReference, AttributeReferenceValidator, InstanceRef, - ValidationError, ) +from algobattle.util import BaseModel, Role, ValidationError __all__ = ( "u64", diff --git a/algobattle/util.py b/algobattle/util.py index 0e0a625f..6e7c2c39 100644 --- a/algobattle/util.py +++ b/algobattle/util.py @@ -7,26 +7,19 @@ from datetime import datetime from enum import StrEnum from importlib.util import module_from_spec, spec_from_file_location -from inspect import Parameter, Signature, signature -from itertools import chain import json from pathlib import Path import sys from tempfile import TemporaryDirectory from traceback import format_exception from types import ModuleType -from typing import Any, Callable, ClassVar, Iterable, Literal, LiteralString, TypeVar, Self, cast, get_args -from annotated_types import GroupedMetadata +from typing import Any, Iterable, LiteralString, TypeVar, Self from pydantic import ( ConfigDict, BaseModel as PydandticBaseModel, - GetCoreSchemaHandler, ValidationError as PydanticValidationError, - ValidationInfo, ) -from pydantic_core import CoreSchema -from pydantic_core.core_schema import with_info_after_validator_function class Role(StrEnum): @@ -39,163 +32,12 @@ class Role(StrEnum): T = TypeVar("T") -ModelType = Literal["instance", "solution", "other"] -ModelReference = ModelType | Literal["self"] - - class BaseModel(PydandticBaseModel): """Base class for all pydantic models.""" model_config = ConfigDict(extra="forbid", from_attributes=True) -class InstanceSolutionModel(BaseModel): - """Base class for Instance and solution models.""" - - _algobattle_model_type: ClassVar[ModelType] = "other" - - @classmethod - def model_validate( # noqa: D102 - cls, - obj: Any, - *, - strict: bool | None = None, - from_attributes: bool | None = None, - context: dict[str, Any] | None = None, - ) -> Self: - model = super().model_validate(obj, strict=strict, from_attributes=from_attributes, context=context) - if cls._validate_with_self(cls._algobattle_model_type): - context = (context or {}) | {"self": model, cls._algobattle_model_type: model} - model = super().model_validate(obj, context=context) - return model - - @classmethod - def _annotation_needs_self(cls, annotation: object, model_type: ModelType) -> bool: - if isinstance(annotation, AttributeReferenceValidator): - return annotation.needs_self(model_type) - if isinstance(annotation, GroupedMetadata): - return any(cls._annotation_needs_self(e, model_type) for e in annotation) - return any(cls._annotation_needs_self(e, model_type) for e in get_args(annotation)) - - @classmethod - def _validate_with_self(cls, model_type: ModelType) -> bool: - # info.annotation contains the type and any nested metadata, info.metadata the top level metadata - # we can use _annotation_needs_self for all of them, so we iterate over all fields and see if any of them - # either have an annotation or metadata we need to parse with a self reference - for info in cls.model_fields.values(): - values = chain((info.annotation,), info.metadata) - if any(cls._annotation_needs_self(value, model_type) for value in values): - return True - return False - - -@dataclass(frozen=True, slots=True) -class AttributeReference: - """Creates a reference to the attribute of a model to be used in validaton schemas.""" - - model: ModelReference - attribute: str - - def get_value(self, info: ValidationInfo) -> Any | None: - """Returns the referenced value from the correct object in the info context. - - If the correct object is not in the context or doesn't have the referenced attribute it returns None. - """ - if info.context is None or self.model not in info.context: - return None - model = info.context[self.model] - if hasattr(model, self.attribute): - return getattr(model, self.attribute) - else: - return None - - def __str__(self) -> str: - return f"{self.model}.{self.attribute}" - - def needs_self(self, model_type: Literal["instance", "solution"]) -> bool: - """Checks if an attribute reference needs a reference to the current model in order to be resolved.""" - if self.model == "self": - return True - else: - return self.model == model_type - - -NoInfoAttrValidatorFunction = Callable[[Any, Any], Any] -GeneralAttrValidatorFunction = Callable[[Any, Any, ValidationInfo], Any] -AttrValidatorFunction = NoInfoAttrValidatorFunction | GeneralAttrValidatorFunction - - -def is_info_validator(validator: AttrValidatorFunction) -> bool: - """Helper method to discriminate the union.""" - match count_positional_params(signature(validator)): - case 2: - return False - case 3: - return True - case _: - raise TypeError - - -@dataclass(frozen=True, slots=True) -class AttributeReferenceValidator: - """An AfterValidator that can resolve a reference to a model attribute and pass it to the validator function. - - Using this with a reference to an attribute in the model it is defined may significantly impact performance. - """ - - func: AttrValidatorFunction - attribute: AttributeReference - - def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema: - schema = handler(source_type) - info_arg = is_info_validator(self.func) - if info_arg: - func = cast(GeneralAttrValidatorFunction, self.func) - - def wrapper(value: Any, info: ValidationInfo) -> Any: - attribute_val = self.attribute.get_value(info) - if attribute_val is None: - return value - return func(value, attribute_val, info) - - else: - func = cast(NoInfoAttrValidatorFunction, self.func) - - def wrapper(value: Any, info: ValidationInfo) -> Any: - attribute_val = self.attribute.get_value(info) - if attribute_val is None: - return value - return func(value, attribute_val) - - return with_info_after_validator_function(wrapper, schema=schema) - - def needs_self(self, model_type: ModelType) -> bool: - """Checks if the validator needs a reference to the current model in order to work fully.""" - if self.attribute.model == "self": - return True - else: - return self.attribute.model == model_type - - -@dataclass -class AttributeReferenceMaker: - """Helper class to easily create attribute references.""" - - _attr_ref_maker_model: ModelReference - - def __getattr__(self, __name: str) -> AttributeReference: - return AttributeReference(self._attr_ref_maker_model, __name) - - -SelfRef = AttributeReferenceMaker("self") - - -InstanceRef = AttributeReferenceMaker("instance") - - -SolutionRef = AttributeReferenceMaker("solution") - - class Encodable(ABC): """Represents data that docker containers can interact with.""" @@ -373,16 +215,6 @@ def from_exception(cls, error: Exception) -> Self: ) -def count_positional_params(sig: Signature) -> int: - """Counts the number of positional parameters in a signature.""" - return sum(1 for param in sig.parameters.values() if can_be_positional(param)) - - -def can_be_positional(param: Parameter) -> bool: - """Checks whether a parameter is positional.""" - return param.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD) - - class TempDir(TemporaryDirectory): """Python's `TemporaryDirectory` but with a contextmanager returning a Path.""" diff --git a/tests/test_types.py b/tests/test_types.py index 0716d598..26507683 100644 --- a/tests/test_types.py +++ b/tests/test_types.py @@ -5,8 +5,8 @@ from pydantic import ValidationError -from algobattle.problem import InstanceModel -from algobattle.util import AttributeReference, Role, SelfRef +from algobattle.problem import InstanceModel, AttributeReference, SelfRef +from algobattle.util import Role from algobattle.types import Ge, Interval, LaxComp, SizeIndex, UniqueItems