Skip to content

Commit

Permalink
move reference validation code to problem module
Browse files Browse the repository at this point in the history
  • Loading branch information
ImogenBits committed Sep 30, 2023
1 parent 7e0bbec commit e3a6a08
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 183 deletions.
2 changes: 1 addition & 1 deletion algobattle/battle.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,7 @@ class Config(BaseModel):
:meth:`Battle.run` method with its fields set accordingly.
"""

type: str
type: Any
"""Type of battle that will be used."""

@classmethod
Expand Down
175 changes: 169 additions & 6 deletions algobattle/problem.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Module defining the Problem and Solution base classes and related objects."""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import wraps
from importlib.metadata import entry_points
from inspect import Parameter, Signature, signature
from itertools import chain
from pathlib import Path
from typing import (
Expand All @@ -17,12 +19,21 @@
Generic,
TypeVar,
overload,
cast,
get_args,
)
from math import inf, isnan
from annotated_types import GroupedMetadata

from pydantic import (
GetCoreSchemaHandler,
ValidationInfo,
)
from pydantic_core import CoreSchema
from pydantic_core.core_schema import with_info_after_validator_function

from algobattle.util import (
EncodableModel,
InstanceSolutionModel,
Role,
Encodable,
import_file_as_module,
Expand Down Expand Up @@ -363,19 +374,171 @@ def available(cls) -> set[str]:


AnyProblem = Problem[Any, Any]
ModelType = Literal["instance", "solution"]
ModelReference = ModelType | Literal["self"]


@dataclass(frozen=True, slots=True)
class AttributeReference:
"""Creates a reference to the attribute of a model to be used in validaton schemas."""

model: ModelReference
attribute: str

def get_value(self, info: ValidationInfo) -> Any | None:
"""Returns the referenced value from the correct object in the info context.
If the correct object is not in the context or doesn't have the referenced attribute it returns None.
"""
if info.context is None or self.model not in info.context:
return None
model = info.context[self.model]
if hasattr(model, self.attribute):
return getattr(model, self.attribute)
else:
return None

def __str__(self) -> str:
return f"{self.model}.{self.attribute}"

def needs_self(self, model_type: Literal["instance", "solution"]) -> bool:
"""Checks if an attribute reference needs a reference to the current model in order to be resolved."""
if self.model == "self":
return True
else:
return self.model == model_type


NoInfoAttrValidatorFunction = Callable[[Any, Any], Any]
GeneralAttrValidatorFunction = Callable[[Any, Any, ValidationInfo], Any]
AttrValidatorFunction = NoInfoAttrValidatorFunction | GeneralAttrValidatorFunction


def count_positional_params(sig: Signature) -> int:
"""Counts the number of positional parameters in a signature."""
return sum(1 for param in sig.parameters.values() if can_be_positional(param))


def can_be_positional(param: Parameter) -> bool:
"""Checks whether a parameter is positional."""
return param.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)


def is_info_validator(validator: AttrValidatorFunction) -> bool:
"""Helper method to discriminate the union."""
match count_positional_params(signature(validator)):
case 2:
return False
case 3:
return True
case _:
raise TypeError


@dataclass(frozen=True, slots=True)
class AttributeReferenceValidator:
"""An AfterValidator that can resolve a reference to a model attribute and pass it to the validator function.
Using this with a reference to an attribute in the model it is defined may significantly impact performance.
"""

func: AttrValidatorFunction
attribute: AttributeReference

class InstanceModel(Instance, EncodableModel, InstanceSolutionModel, ABC):
def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema:
schema = handler(source_type)
info_arg = is_info_validator(self.func)
if info_arg:
func = cast(GeneralAttrValidatorFunction, self.func)

def wrapper(value: Any, info: ValidationInfo) -> Any:
attribute_val = self.attribute.get_value(info)
if attribute_val is None:
return value
return func(value, attribute_val, info)

else:
func = cast(NoInfoAttrValidatorFunction, self.func)

def wrapper(value: Any, info: ValidationInfo) -> Any:
attribute_val = self.attribute.get_value(info)
if attribute_val is None:
return value
return func(value, attribute_val)

return with_info_after_validator_function(wrapper, schema=schema)

def needs_self(self, model_type: ModelType) -> bool:
"""Checks if the validator needs a reference to the current model in order to work fully."""
if self.attribute.model == "self":
return True
else:
return self.attribute.model == model_type


@dataclass
class AttributeReferenceMaker:
"""Helper class to easily create attribute references."""

_attr_ref_maker_model: ModelReference

def __getattr__(self, __name: str) -> AttributeReference:
return AttributeReference(self._attr_ref_maker_model, __name)


SelfRef = AttributeReferenceMaker("self")
InstanceRef = AttributeReferenceMaker("instance")
SolutionRef = AttributeReferenceMaker("solution")


class InstanceSolutionModel(EncodableModel):
"""Base class for Instance and solution models."""

@classmethod
def model_validate( # noqa: D102
cls,
obj: Any,
*,
strict: bool | None = None,
from_attributes: bool | None = None,
context: dict[str, Any] | None = None,
) -> Self:
model = super().model_validate(obj, strict=strict, from_attributes=from_attributes, context=context)
model_type = "instance" if issubclass(cls, InstanceModel) else "solution"
if cls._validate_with_self(model_type):
context = (context or {}) | {"self": model, model_type: model}
model = super().model_validate(obj, context=context)
return model

@classmethod
def _annotation_needs_self(cls, annotation: object, model_type: ModelType) -> bool:
if isinstance(annotation, AttributeReferenceValidator):
return annotation.needs_self(model_type)
if isinstance(annotation, GroupedMetadata):
return any(cls._annotation_needs_self(e, model_type) for e in annotation)
return any(cls._annotation_needs_self(e, model_type) for e in get_args(annotation))

@classmethod
def _validate_with_self(cls, model_type: ModelType) -> bool:
# info.annotation contains the type and any nested metadata, info.metadata the top level metadata
# we can use _annotation_needs_self for all of them, so we iterate over all fields and see if any of them
# either have an annotation or metadata we need to parse with a self reference
for info in cls.model_fields.values():
values = chain((info.annotation,), info.metadata)
if any(cls._annotation_needs_self(value, model_type) for value in values):
return True
return False


class InstanceModel(Instance, InstanceSolutionModel, ABC):
"""An instance that can easily be parsed to/from a json file."""

_algobattle_model_type: ClassVar[Literal["instance"]] = "instance"
pass


class SolutionModel(Solution[InstanceT], EncodableModel, InstanceSolutionModel, ABC):
class SolutionModel(Solution[InstanceT], InstanceSolutionModel, ABC):
"""A solution that can easily be parsed to/from a json file."""

_algobattle_model_type: ClassVar[Literal["solution"]] = "solution"

@classmethod
def decode(cls, source: Path, max_size: int, role: Role, instance: InstanceT | None = None) -> Self:
"""Uses pydantic to create a python object from a `.json` file."""
Expand Down
6 changes: 1 addition & 5 deletions algobattle/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,11 @@
from algobattle.problem import (
InstanceModel,
SolutionModel,
)
from algobattle.util import (
BaseModel,
Role,
AttributeReference,
AttributeReferenceValidator,
InstanceRef,
ValidationError,
)
from algobattle.util import BaseModel, Role, ValidationError

__all__ = (
"u64",
Expand Down
Loading

0 comments on commit e3a6a08

Please sign in to comment.