Skip to content

Commit

Permalink
Merge pull request #129 from ImogenBits/cleanup
Browse files Browse the repository at this point in the history
Cleanup
  • Loading branch information
Benezivas authored Oct 1, 2023
2 parents 7e0bbec + 4c64437 commit 766f837
Show file tree
Hide file tree
Showing 7 changed files with 207 additions and 237 deletions.
6 changes: 3 additions & 3 deletions algobattle/battle.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
ProgramUi,
Solver,
)
from algobattle.problem import AnyProblem
from algobattle.problem import Problem
from algobattle.util import Encodable, ExceptionInfo, BaseModel


Expand Down Expand Up @@ -106,7 +106,7 @@ async def inner(self: "FightHandler", *args: P.args, **kwargs: P.kwargs) -> Figh
class FightHandler:
"""Helper class to run fights of a given battle."""

problem: AnyProblem
problem: Problem
generator: Generator
solver: Solver
battle: "Battle"
Expand Down Expand Up @@ -237,7 +237,7 @@ class Config(BaseModel):
:meth:`Battle.run` method with its fields set accordingly.
"""

type: str
type: Any
"""Type of battle that will be used."""

@classmethod
Expand Down
6 changes: 3 additions & 3 deletions algobattle/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

from algobattle.battle import Battle, FightHandler, FightUi, BattleUi, Iterated
from algobattle.program import ProgramConfigView, ProgramUi, Matchup, TeamHandler, BuildUi
from algobattle.problem import InstanceT, Problem, SolutionT
from algobattle.problem import Problem
from algobattle.util import (
ExceptionInfo,
Role,
Expand Down Expand Up @@ -74,7 +74,7 @@ async def _run_battle(
battle: Battle,
matchup: Matchup,
config: "AlgobattleConfig",
problem: Problem[InstanceT, SolutionT],
problem: Problem,
cpus: list[str | None],
ui: "Ui",
limiter: CapacityLimiter,
Expand Down Expand Up @@ -634,7 +634,7 @@ def check_problem_defined(self) -> Self:
return self

@cached_property
def problem(self) -> Problem[Any, Any]:
def problem(self) -> Problem:
"""The problem this config uses."""
return Problem.load(self.match.problem, self.problems)

Expand Down
206 changes: 190 additions & 16 deletions algobattle/problem.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""Module defining the Problem and Solution base classes and related objects."""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from functools import wraps
from importlib.metadata import entry_points
from inspect import Parameter, Signature, signature
from itertools import chain
from pathlib import Path
from typing import (
Expand All @@ -17,12 +19,21 @@
Generic,
TypeVar,
overload,
cast,
get_args,
)
from math import inf, isnan
from annotated_types import GroupedMetadata

from pydantic import (
GetCoreSchemaHandler,
ValidationInfo,
)
from pydantic_core import CoreSchema
from pydantic_core.core_schema import with_info_after_validator_function

from algobattle.util import (
EncodableModel,
InstanceSolutionModel,
Role,
Encodable,
import_file_as_module,
Expand Down Expand Up @@ -210,7 +221,7 @@ class DynamicProblemInfo(Protocol):
location: Path


class Problem(Generic[InstanceT, SolutionT]):
class Problem:
"""The definition of a problem."""

@overload
Expand Down Expand Up @@ -280,40 +291,52 @@ def __init__(
self._problems[name] = self

__slots__ = ("name", "instance_cls", "solution_cls", "min_size", "with_solution", "score_function", "test_instance")
_problems: "ClassVar[dict[str, AnyProblem]]" = {}
_problems: ClassVar[dict[str, Self]] = {}

@overload
def score(self, instance: InstanceT, *, solution: SolutionT) -> float:
def score(self, instance: InstanceT, *, solution: Solution[InstanceT]) -> float:
...

@overload
def score(self, instance: InstanceT, *, generator_solution: SolutionT, solver_solution: SolutionT) -> float:
def score(
self, instance: InstanceT, *, generator_solution: Solution[InstanceT], solver_solution: Solution[InstanceT]
) -> float:
...

def score(
self,
instance: InstanceT,
instance: Instance,
*,
solution: SolutionT | None = None,
generator_solution: SolutionT | None = None,
solver_solution: SolutionT | None = None,
) -> float:
"""Helper function to call self.score_function with easier to use overloads."""
if self.with_solution:
if solution is not None or generator_solution is None or solver_solution is None:
if not (
isinstance(instance, self.instance_cls)
and isinstance(generator_solution, self.solution_cls)
and isinstance(solver_solution, self.solution_cls)
and solution is None
):
raise TypeError
if TYPE_CHECKING:
assert isinstance(self.score_function, ScoreFunctionWithSol)
return self.score_function(instance, generator_solution=generator_solution, solver_solution=solver_solution)
else:
if solution is None or generator_solution is not None or solver_solution is not None:
if not (
isinstance(instance, self.instance_cls)
and isinstance(solution, self.solution_cls)
and generator_solution is None
and solver_solution is None
):
raise TypeError
if TYPE_CHECKING:
assert isinstance(self.score_function, ScoreFunctionNoSol)
return self.score_function(instance, solution=solution)

@classmethod
def load_file(cls, name: str, file: Path) -> "AnyProblem":
def load_file(cls, name: str, file: Path) -> Self:
"""Loads the problem from the specified file."""
existing_problems = cls._problems.copy()
import_file_as_module(file, "__algobattle_problem__")
Expand All @@ -324,7 +347,7 @@ def load_file(cls, name: str, file: Path) -> "AnyProblem":
return cls._problems[name]

@classmethod
def load(cls, name: str, dynamic: Mapping[str, DynamicProblemInfo]) -> "AnyProblem":
def load(cls, name: str, dynamic: Mapping[str, DynamicProblemInfo]) -> Self:
"""Loads the problem with the given name.
Args:
Expand Down Expand Up @@ -362,20 +385,171 @@ def available(cls) -> set[str]:
return set(chain(cls._problems.keys(), (e.name for e in entry_points(group="algobattle.problem"))))


AnyProblem = Problem[Any, Any]
ModelType = Literal["instance", "solution"]
ModelReference = ModelType | Literal["self"]


@dataclass(frozen=True, slots=True)
class AttributeReference:
"""Creates a reference to the attribute of a model to be used in validaton schemas."""

model: ModelReference
attribute: str

def get_value(self, info: ValidationInfo) -> Any | None:
"""Returns the referenced value from the correct object in the info context.
If the correct object is not in the context or doesn't have the referenced attribute it returns None.
"""
if info.context is None or self.model not in info.context:
return None
model = info.context[self.model]
if hasattr(model, self.attribute):
return getattr(model, self.attribute)
else:
return None

def __str__(self) -> str:
return f"{self.model}.{self.attribute}"

def needs_self(self, model_type: Literal["instance", "solution"]) -> bool:
"""Checks if an attribute reference needs a reference to the current model in order to be resolved."""
if self.model == "self":
return True
else:
return self.model == model_type


NoInfoAttrValidatorFunction = Callable[[Any, Any], Any]
GeneralAttrValidatorFunction = Callable[[Any, Any, ValidationInfo], Any]
AttrValidatorFunction = NoInfoAttrValidatorFunction | GeneralAttrValidatorFunction


def count_positional_params(sig: Signature) -> int:
"""Counts the number of positional parameters in a signature."""
return sum(1 for param in sig.parameters.values() if can_be_positional(param))


def can_be_positional(param: Parameter) -> bool:
"""Checks whether a parameter is positional."""
return param.kind in (Parameter.POSITIONAL_ONLY, Parameter.POSITIONAL_OR_KEYWORD)


def is_info_validator(validator: AttrValidatorFunction) -> bool:
"""Helper method to discriminate the union."""
match count_positional_params(signature(validator)):
case 2:
return False
case 3:
return True
case _:
raise TypeError


@dataclass(frozen=True, slots=True)
class AttributeReferenceValidator:
"""An AfterValidator that can resolve a reference to a model attribute and pass it to the validator function.
Using this with a reference to an attribute in the model it is defined may significantly impact performance.
"""

class InstanceModel(Instance, EncodableModel, InstanceSolutionModel, ABC):
func: AttrValidatorFunction
attribute: AttributeReference

def __get_pydantic_core_schema__(self, source_type: Any, handler: GetCoreSchemaHandler) -> CoreSchema:
schema = handler(source_type)
info_arg = is_info_validator(self.func)
if info_arg:
func = cast(GeneralAttrValidatorFunction, self.func)

def wrapper(value: Any, info: ValidationInfo) -> Any:
attribute_val = self.attribute.get_value(info)
if attribute_val is None:
return value
return func(value, attribute_val, info)

else:
func = cast(NoInfoAttrValidatorFunction, self.func)

def wrapper(value: Any, info: ValidationInfo) -> Any:
attribute_val = self.attribute.get_value(info)
if attribute_val is None:
return value
return func(value, attribute_val)

return with_info_after_validator_function(wrapper, schema=schema)

def needs_self(self, model_type: ModelType) -> bool:
"""Checks if the validator needs a reference to the current model in order to work fully."""
if self.attribute.model == "self":
return True
else:
return self.attribute.model == model_type


@dataclass
class AttributeReferenceMaker:
"""Helper class to easily create attribute references."""

_attr_ref_maker_model: ModelReference

def __getattr__(self, __name: str) -> AttributeReference:
return AttributeReference(self._attr_ref_maker_model, __name)


SelfRef = AttributeReferenceMaker("self")
InstanceRef = AttributeReferenceMaker("instance")
SolutionRef = AttributeReferenceMaker("solution")


class InstanceSolutionModel(EncodableModel):
"""Base class for Instance and solution models."""

@classmethod
def model_validate( # noqa: D102
cls,
obj: Any,
*,
strict: bool | None = None,
from_attributes: bool | None = None,
context: dict[str, Any] | None = None,
) -> Self:
model = super().model_validate(obj, strict=strict, from_attributes=from_attributes, context=context)
model_type = "instance" if issubclass(cls, InstanceModel) else "solution"
if cls._validate_with_self(model_type):
context = (context or {}) | {"self": model, model_type: model}
model = super().model_validate(obj, context=context)
return model

@classmethod
def _annotation_needs_self(cls, annotation: object, model_type: ModelType) -> bool:
if isinstance(annotation, AttributeReferenceValidator):
return annotation.needs_self(model_type)
if isinstance(annotation, GroupedMetadata):
return any(cls._annotation_needs_self(e, model_type) for e in annotation)
return any(cls._annotation_needs_self(e, model_type) for e in get_args(annotation))

@classmethod
def _validate_with_self(cls, model_type: ModelType) -> bool:
# info.annotation contains the type and any nested metadata, info.metadata the top level metadata
# we can use _annotation_needs_self for all of them, so we iterate over all fields and see if any of them
# either have an annotation or metadata we need to parse with a self reference
for info in cls.model_fields.values():
values = chain((info.annotation,), info.metadata)
if any(cls._annotation_needs_self(value, model_type) for value in values):
return True
return False


class InstanceModel(Instance, InstanceSolutionModel, ABC):
"""An instance that can easily be parsed to/from a json file."""

_algobattle_model_type: ClassVar[Literal["instance"]] = "instance"
pass


class SolutionModel(Solution[InstanceT], EncodableModel, InstanceSolutionModel, ABC):
class SolutionModel(Solution[InstanceT], InstanceSolutionModel, ABC):
"""A solution that can easily be parsed to/from a json file."""

_algobattle_model_type: ClassVar[Literal["solution"]] = "solution"

@classmethod
def decode(cls, source: Path, max_size: int, role: Role, instance: InstanceT | None = None) -> Self:
"""Uses pydantic to create a python object from a `.json` file."""
Expand Down
Loading

0 comments on commit 766f837

Please sign in to comment.