From df0fd37364787b7f8f5ec01c4a563a1e9357083c Mon Sep 17 00:00:00 2001 From: Imogen Date: Fri, 5 Jan 2024 15:22:53 +0100 Subject: [PATCH] split Encodable defn into seperate specializations for Instances/BattleData and Solutions --- algobattle/problem.py | 20 ++++++------ algobattle/program.py | 4 ++- algobattle/util.py | 71 ++++++++++++++++++++++++------------------- tests/test_battles.py | 8 +++-- 4 files changed, 59 insertions(+), 44 deletions(-) diff --git a/algobattle/problem.py b/algobattle/problem.py index 28ee5dad..c121ff6b 100644 --- a/algobattle/problem.py +++ b/algobattle/problem.py @@ -38,7 +38,9 @@ ) from algobattle.util import ( + EncodableBase, EncodableModel, + EncodableModelBase, Role, Encodable, import_file_as_module, @@ -70,12 +72,12 @@ def validate_instance(self) -> None: P = ParamSpec("P") -class Solution(Encodable, Generic[InstanceT], ABC): +class Solution(EncodableBase, Generic[InstanceT], ABC): """A proposed solution for an instance of this problem.""" @classmethod @abstractmethod - def decode(cls, source: Path, max_size: int, role: Role, instance: InstanceT | None = None) -> Self: # noqa: D102 + def decode(cls, source: Path, max_size: int, role: Role, instance: InstanceT) -> Self: # noqa: D102 raise NotImplementedError def validate_solution(self, instance: InstanceT, role: Role) -> None: @@ -507,7 +509,7 @@ def __getattr__(self, __name: str) -> AttributeReference: SolutionRef = AttributeReferenceMaker("solution") -class InstanceSolutionModel(EncodableModel): +class InstanceSolutionModel(EncodableModelBase): """Base class for Instance and solution models.""" @classmethod @@ -550,19 +552,17 @@ def _validate_with_self(cls, model_type: ModelType) -> bool: return False -class InstanceModel(Instance, InstanceSolutionModel, ABC): +class InstanceModel(InstanceSolutionModel, EncodableModel, Instance, ABC): """An instance that can easily be parsed to/from a json file.""" pass -class SolutionModel(Solution[InstanceT], InstanceSolutionModel, ABC): +class SolutionModel(InstanceSolutionModel, Solution[InstanceT], ABC): """A solution that can easily be parsed to/from a json file.""" @classmethod - def decode(cls, source: Path, max_size: int, role: Role, instance: InstanceT | None = None) -> Self: + def decode(cls, source: Path, max_size: int, role: Role, instance: InstanceT) -> Self: """Uses pydantic to create a python object from a `.json` file.""" - context: dict[str, Any] = {"max_size": max_size, "role": role} - if instance is not None: - context["instance"] = instance - return cls._decode(source, **context) + context: dict[str, Any] = {"max_size": max_size, "role": role, "instance": instance} + return cls._decode(cls, source, **context) diff --git a/algobattle/program.py b/algobattle/program.py index 9788dae1..5d3b3e67 100644 --- a/algobattle/program.py +++ b/algobattle/program.py @@ -544,7 +544,9 @@ async def run( ) if self.problem.with_solution: try: - solution = self.problem.solution_cls.decode(io.output / "solution", max_size, self.role) + solution = self.problem.solution_cls.decode( + io.output / "solution", max_size, self.role, instance + ) except EncodingError: raise except Exception as e: diff --git a/algobattle/util.py b/algobattle/util.py index f093972d..dd45d221 100644 --- a/algobattle/util.py +++ b/algobattle/util.py @@ -29,36 +29,18 @@ class Role(StrEnum): solver = "solver" -T = TypeVar("T") - - class BaseModel(PydandticBaseModel): """Base class for all pydantic models.""" model_config = ConfigDict(extra="forbid", from_attributes=True, hide_input_in_errors=True) -class Encodable(ABC): - """Represents data that docker containers can interact with.""" - - @classmethod - @abstractmethod - def decode(cls, source: Path, max_size: int, role: Role) -> Self: - """Decodes the data found at the given path into a python object. - - Args: - source: Path to data that can be used to construct an instance of this class. May either point to a folder - or a single file. The expected type of path should be consistent with the result of :meth:`.encode`. - max_size: Maximum size the current battle allows. - role: Role of the program that generated this data. +T = TypeVar("T") +ModelT = TypeVar("ModelT", bound=BaseModel) - Raises: - EncodingError: If the data cannot be decoded into an instance. - Returns: - The decoded object. - """ - raise NotImplementedError +class EncodableBase(ABC): + """Base for Encodable and Solution.""" @abstractmethod def encode(self, target: Path, role: Role) -> None: @@ -89,27 +71,45 @@ def io_schema(cls) -> str | None: return None -class EncodableModel(BaseModel, Encodable, ABC): - """Problem data that can easily be encoded into and decoded from json files.""" +class Encodable(EncodableBase, ABC): + """Represents data that docker containers can interact with.""" @classmethod - def _decode(cls, source: Path, **context: Any) -> Self: + @abstractmethod + def decode(cls, source: Path, max_size: int, role: Role) -> Self: + """Decodes the data found at the given path into a python object. + + Args: + source: Path to data that can be used to construct an instance of this class. May either point to a folder + or a single file. The expected type of path should be consistent with the result of :meth:`.encode`. + max_size: Maximum size the current battle allows. + role: Role of the program that generated this data. + + Raises: + EncodingError: If the data cannot be decoded into an instance. + + Returns: + The decoded object. + """ + raise NotImplementedError + + +class EncodableModelBase(BaseModel, EncodableBase, ABC): + """Base class for EncodableModel and SolutionModel.""" + + @staticmethod + def _decode(model_cls: type[ModelT], source: Path, **context: Any) -> ModelT: """Internal method used by .decode to let Solutions also accept the corresponding instance.""" if not source.with_suffix(".json").is_file(): raise EncodingError("The json file does not exist.") try: with open(source.with_suffix(".json"), "r") as f: - return cls.model_validate_json(f.read(), context=context) + return model_cls.model_validate_json(f.read(), context=context) except PydanticValidationError as e: raise EncodingError("Json data does not fit the schema.", detail=str(e)) except Exception as e: raise EncodingError("Unknown error while decoding the data.", detail=str(e)) - @classmethod - def decode(cls, source: Path, max_size: int, role: Role) -> Self: - """Uses pydantic to create a python object from a `.json` file.""" - return cls._decode(source, max_size=max_size, role=role) - def encode(self, target: Path, role: Role) -> None: """Uses pydantic to create a json representation of the object at the targeted file.""" try: @@ -124,6 +124,15 @@ def io_schema(cls) -> str: return json.dumps(cls.model_json_schema(), indent=4) +class EncodableModel(EncodableModelBase, ABC): + """Problem data that can easily be encoded into and decoded from json files.""" + + @classmethod + def decode(cls, source: Path, max_size: int, role: Role) -> Self: + """Uses pydantic to create a python object from a `.json` file.""" + return cls._decode(cls, source, max_size=max_size, role=role) + + @dataclass class RunningTimer: """Basic data holding info on a currently running timer.""" diff --git a/tests/test_battles.py b/tests/test_battles.py index 4e60c2fe..93115ffd 100644 --- a/tests/test_battles.py +++ b/tests/test_battles.py @@ -353,23 +353,27 @@ def test_encode_instance(self) -> None: self.assertEqual(decoded, first if num == 0 else second) def test_encode_witness(self) -> None: + instance = self.history.history[0].generator.instance + assert isinstance(instance, TestInstance) first = self.history.history[0].generator.solution second = self.history.history[1].generator.solution with TempDir() as target: for num, folder, should_exist in self._encode_attr(target, "gen_sols"): self.assertEqual(folder.joinpath("generator_solution.json").exists(), should_exist) if should_exist: - decoded = TestSolution.decode(folder / "generator_solution.json", 25, Role.generator) + decoded = TestSolution.decode(folder / "generator_solution.json", 25, Role.generator, instance) self.assertEqual(decoded, first if num == 0 else second) def test_encode_solution(self) -> None: + instance = self.history.history[0].generator.instance + assert isinstance(instance, TestInstance) first = cast(SolverResult, self.history.history[0].solver).solution second = cast(SolverResult, self.history.history[1].solver).solution with TempDir() as target: for num, folder, should_exist in self._encode_attr(target, "sol_sols"): self.assertEqual(folder.joinpath("solver_solution.json").exists(), should_exist) if should_exist: - decoded = TestSolution.decode(folder / "solver_solution.json", 25, Role.generator) + decoded = TestSolution.decode(folder / "solver_solution.json", 25, Role.generator, instance) self.assertEqual(decoded, first if num == 0 else second)