Skip to content

Commit

Permalink
Merge pull request #161 from ImogenBits/bugs
Browse files Browse the repository at this point in the history
Small bug fixes
  • Loading branch information
Benezivas authored Jan 7, 2024
2 parents 068d85a + df0fd37 commit aefbb6b
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 47 deletions.
2 changes: 1 addition & 1 deletion algobattle/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ def run_match(
result = Match(config=config)
try:
with CliUi(result, config) if ui else EmptyUi() as ui_obj:
run_async_fn(result.run, config, ui_obj)
run_async_fn(result.run, ui_obj)
except DockerNotRunning:
console.print("[error]Could not connect to the Docker Daemon.[/] Is Docker running?")
save = False
Expand Down
22 changes: 11 additions & 11 deletions algobattle/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@
)

from algobattle.util import (
EncodableBase,
EncodableModel,
EncodableModelBase,
Role,
Encodable,
import_file_as_module,
Expand Down Expand Up @@ -70,12 +72,12 @@ def validate_instance(self) -> None:
P = ParamSpec("P")


class Solution(Encodable, Generic[InstanceT], ABC):
class Solution(EncodableBase, Generic[InstanceT], ABC):
"""A proposed solution for an instance of this problem."""

@classmethod
@abstractmethod
def decode(cls, source: Path, max_size: int, role: Role, instance: InstanceT | None = None) -> Self: # noqa: D102
def decode(cls, source: Path, max_size: int, role: Role, instance: InstanceT) -> Self: # noqa: D102
raise NotImplementedError

def validate_solution(self, instance: InstanceT, role: Role) -> None:
Expand Down Expand Up @@ -174,7 +176,7 @@ def __call__(self, instance: _I, *, solution: _S) -> float:


@overload
def default_score(instance: Instance, *, solution: Solution[Instance]) -> float:
def default_score(instance: Instance, *, solution: Solution[Any]) -> float:
...


Expand Down Expand Up @@ -507,7 +509,7 @@ def __getattr__(self, __name: str) -> AttributeReference:
SolutionRef = AttributeReferenceMaker("solution")


class InstanceSolutionModel(EncodableModel):
class InstanceSolutionModel(EncodableModelBase):
"""Base class for Instance and solution models."""

@classmethod
Expand Down Expand Up @@ -550,19 +552,17 @@ def _validate_with_self(cls, model_type: ModelType) -> bool:
return False


class InstanceModel(Instance, InstanceSolutionModel, ABC):
class InstanceModel(InstanceSolutionModel, EncodableModel, Instance, ABC):
"""An instance that can easily be parsed to/from a json file."""

pass


class SolutionModel(Solution[InstanceT], InstanceSolutionModel, ABC):
class SolutionModel(InstanceSolutionModel, Solution[InstanceT], ABC):
"""A solution that can easily be parsed to/from a json file."""

@classmethod
def decode(cls, source: Path, max_size: int, role: Role, instance: InstanceT | None = None) -> Self:
def decode(cls, source: Path, max_size: int, role: Role, instance: InstanceT) -> Self:
"""Uses pydantic to create a python object from a `.json` file."""
context: dict[str, Any] = {"max_size": max_size, "role": role}
if instance is not None:
context["instance"] = instance
return cls._decode(source, **context)
context: dict[str, Any] = {"max_size": max_size, "role": role, "instance": instance}
return cls._decode(cls, source, **context)
4 changes: 3 additions & 1 deletion algobattle/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -544,7 +544,9 @@ async def run(
)
if self.problem.with_solution:
try:
solution = self.problem.solution_cls.decode(io.output / "solution", max_size, self.role)
solution = self.problem.solution_cls.decode(
io.output / "solution", max_size, self.role, instance
)
except EncodingError:
raise
except Exception as e:
Expand Down
2 changes: 1 addition & 1 deletion algobattle/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,7 +346,7 @@ def In(attribute: AttributeReference) -> AttributeReferenceValidator:
"""Specifies that the value should be `in` some collection."""

def validator(val: Any, attr: Any) -> Any:
if not (val in attr):
if val not in attr:
raise ValueError(f"Value is not contained in collection {attribute}.")
return val

Expand Down
71 changes: 40 additions & 31 deletions algobattle/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,36 +29,18 @@ class Role(StrEnum):
solver = "solver"


T = TypeVar("T")


class BaseModel(PydandticBaseModel):
"""Base class for all pydantic models."""

model_config = ConfigDict(extra="forbid", from_attributes=True, hide_input_in_errors=True)


class Encodable(ABC):
"""Represents data that docker containers can interact with."""

@classmethod
@abstractmethod
def decode(cls, source: Path, max_size: int, role: Role) -> Self:
"""Decodes the data found at the given path into a python object.
Args:
source: Path to data that can be used to construct an instance of this class. May either point to a folder
or a single file. The expected type of path should be consistent with the result of :meth:`.encode`.
max_size: Maximum size the current battle allows.
role: Role of the program that generated this data.
T = TypeVar("T")
ModelT = TypeVar("ModelT", bound=BaseModel)

Raises:
EncodingError: If the data cannot be decoded into an instance.

Returns:
The decoded object.
"""
raise NotImplementedError
class EncodableBase(ABC):
"""Base for Encodable and Solution."""

@abstractmethod
def encode(self, target: Path, role: Role) -> None:
Expand Down Expand Up @@ -89,27 +71,45 @@ def io_schema(cls) -> str | None:
return None


class EncodableModel(BaseModel, Encodable, ABC):
"""Problem data that can easily be encoded into and decoded from json files."""
class Encodable(EncodableBase, ABC):
"""Represents data that docker containers can interact with."""

@classmethod
def _decode(cls, source: Path, **context: Any) -> Self:
@abstractmethod
def decode(cls, source: Path, max_size: int, role: Role) -> Self:
"""Decodes the data found at the given path into a python object.
Args:
source: Path to data that can be used to construct an instance of this class. May either point to a folder
or a single file. The expected type of path should be consistent with the result of :meth:`.encode`.
max_size: Maximum size the current battle allows.
role: Role of the program that generated this data.
Raises:
EncodingError: If the data cannot be decoded into an instance.
Returns:
The decoded object.
"""
raise NotImplementedError


class EncodableModelBase(BaseModel, EncodableBase, ABC):
"""Base class for EncodableModel and SolutionModel."""

@staticmethod
def _decode(model_cls: type[ModelT], source: Path, **context: Any) -> ModelT:
"""Internal method used by .decode to let Solutions also accept the corresponding instance."""
if not source.with_suffix(".json").is_file():
raise EncodingError("The json file does not exist.")
try:
with open(source.with_suffix(".json"), "r") as f:
return cls.model_validate_json(f.read(), context=context)
return model_cls.model_validate_json(f.read(), context=context)
except PydanticValidationError as e:
raise EncodingError("Json data does not fit the schema.", detail=str(e))
except Exception as e:
raise EncodingError("Unknown error while decoding the data.", detail=str(e))

@classmethod
def decode(cls, source: Path, max_size: int, role: Role) -> Self:
"""Uses pydantic to create a python object from a `.json` file."""
return cls._decode(source, max_size=max_size, role=role)

def encode(self, target: Path, role: Role) -> None:
"""Uses pydantic to create a json representation of the object at the targeted file."""
try:
Expand All @@ -124,6 +124,15 @@ def io_schema(cls) -> str:
return json.dumps(cls.model_json_schema(), indent=4)


class EncodableModel(EncodableModelBase, ABC):
"""Problem data that can easily be encoded into and decoded from json files."""

@classmethod
def decode(cls, source: Path, max_size: int, role: Role) -> Self:
"""Uses pydantic to create a python object from a `.json` file."""
return cls._decode(cls, source, max_size=max_size, role=role)


@dataclass
class RunningTimer:
"""Basic data holding info on a currently running timer."""
Expand Down
8 changes: 6 additions & 2 deletions tests/test_battles.py
Original file line number Diff line number Diff line change
Expand Up @@ -353,23 +353,27 @@ def test_encode_instance(self) -> None:
self.assertEqual(decoded, first if num == 0 else second)

def test_encode_witness(self) -> None:
instance = self.history.history[0].generator.instance
assert isinstance(instance, TestInstance)
first = self.history.history[0].generator.solution
second = self.history.history[1].generator.solution
with TempDir() as target:
for num, folder, should_exist in self._encode_attr(target, "gen_sols"):
self.assertEqual(folder.joinpath("generator_solution.json").exists(), should_exist)
if should_exist:
decoded = TestSolution.decode(folder / "generator_solution.json", 25, Role.generator)
decoded = TestSolution.decode(folder / "generator_solution.json", 25, Role.generator, instance)
self.assertEqual(decoded, first if num == 0 else second)

def test_encode_solution(self) -> None:
instance = self.history.history[0].generator.instance
assert isinstance(instance, TestInstance)
first = cast(SolverResult, self.history.history[0].solver).solution
second = cast(SolverResult, self.history.history[1].solver).solution
with TempDir() as target:
for num, folder, should_exist in self._encode_attr(target, "sol_sols"):
self.assertEqual(folder.joinpath("solver_solution.json").exists(), should_exist)
if should_exist:
decoded = TestSolution.decode(folder / "solver_solution.json", 25, Role.generator)
decoded = TestSolution.decode(folder / "solver_solution.json", 25, Role.generator, instance)
self.assertEqual(decoded, first if num == 0 else second)


Expand Down

0 comments on commit aefbb6b

Please sign in to comment.