Skip to content

Commit

Permalink
remove type vars from Problem
Browse files Browse the repository at this point in the history
  • Loading branch information
ImogenBits committed Sep 30, 2023
1 parent e3a6a08 commit 6678af3
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 26 deletions.
4 changes: 2 additions & 2 deletions algobattle/battle.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
ProgramUi,
Solver,
)
from algobattle.problem import AnyProblem
from algobattle.problem import Problem
from algobattle.util import Encodable, ExceptionInfo, BaseModel


Expand Down Expand Up @@ -106,7 +106,7 @@ async def inner(self: "FightHandler", *args: P.args, **kwargs: P.kwargs) -> Figh
class FightHandler:
"""Helper class to run fights of a given battle."""

problem: AnyProblem
problem: Problem
generator: Generator
solver: Solver
battle: "Battle"
Expand Down
6 changes: 3 additions & 3 deletions algobattle/match.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

from algobattle.battle import Battle, FightHandler, FightUi, BattleUi, Iterated
from algobattle.program import ProgramConfigView, ProgramUi, Matchup, TeamHandler, BuildUi
from algobattle.problem import InstanceT, Problem, SolutionT
from algobattle.problem import Problem
from algobattle.util import (
ExceptionInfo,
Role,
Expand Down Expand Up @@ -74,7 +74,7 @@ async def _run_battle(
battle: Battle,
matchup: Matchup,
config: "AlgobattleConfig",
problem: Problem[InstanceT, SolutionT],
problem: Problem,
cpus: list[str | None],
ui: "Ui",
limiter: CapacityLimiter,
Expand Down Expand Up @@ -634,7 +634,7 @@ def check_problem_defined(self) -> Self:
return self

@cached_property
def problem(self) -> Problem[Any, Any]:
def problem(self) -> Problem:
"""The problem this config uses."""
return Problem.load(self.match.problem, self.problems)

Expand Down
31 changes: 21 additions & 10 deletions algobattle/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,7 @@ class DynamicProblemInfo(Protocol):
location: Path


class Problem(Generic[InstanceT, SolutionT]):
class Problem:
"""The definition of a problem."""

@overload
Expand Down Expand Up @@ -291,40 +291,52 @@ def __init__(
self._problems[name] = self

__slots__ = ("name", "instance_cls", "solution_cls", "min_size", "with_solution", "score_function", "test_instance")
_problems: "ClassVar[dict[str, AnyProblem]]" = {}
_problems: ClassVar[dict[str, Self]] = {}

@overload
def score(self, instance: InstanceT, *, solution: SolutionT) -> float:
def score(self, instance: InstanceT, *, solution: Solution[InstanceT]) -> float:
...

@overload
def score(self, instance: InstanceT, *, generator_solution: SolutionT, solver_solution: SolutionT) -> float:
def score(
self, instance: InstanceT, *, generator_solution: Solution[InstanceT], solver_solution: Solution[InstanceT]
) -> float:
...

def score(
self,
instance: InstanceT,
instance: Instance,
*,
solution: SolutionT | None = None,
generator_solution: SolutionT | None = None,
solver_solution: SolutionT | None = None,
) -> float:
"""Helper function to call self.score_function with easier to use overloads."""
if self.with_solution:
if solution is not None or generator_solution is None or solver_solution is None:
if not (
isinstance(instance, self.instance_cls)
and isinstance(generator_solution, self.solution_cls)
and isinstance(solver_solution, self.solution_cls)
and solution is None
):
raise TypeError
if TYPE_CHECKING:
assert isinstance(self.score_function, ScoreFunctionWithSol)
return self.score_function(instance, generator_solution=generator_solution, solver_solution=solver_solution)
else:
if solution is None or generator_solution is not None or solver_solution is not None:
if not (
isinstance(instance, self.instance_cls)
and isinstance(solution, self.solution_cls)
and generator_solution is None
and solver_solution is None
):
raise TypeError
if TYPE_CHECKING:
assert isinstance(self.score_function, ScoreFunctionNoSol)
return self.score_function(instance, solution=solution)

@classmethod
def load_file(cls, name: str, file: Path) -> "AnyProblem":
def load_file(cls, name: str, file: Path) -> Self:
"""Loads the problem from the specified file."""
existing_problems = cls._problems.copy()
import_file_as_module(file, "__algobattle_problem__")
Expand All @@ -335,7 +347,7 @@ def load_file(cls, name: str, file: Path) -> "AnyProblem":
return cls._problems[name]

@classmethod
def load(cls, name: str, dynamic: Mapping[str, DynamicProblemInfo]) -> "AnyProblem":
def load(cls, name: str, dynamic: Mapping[str, DynamicProblemInfo]) -> Self:
"""Loads the problem with the given name.
Args:
Expand Down Expand Up @@ -373,7 +385,6 @@ def available(cls) -> set[str]:
return set(chain(cls._problems.keys(), (e.name for e in entry_points(group="algobattle.problem"))))


AnyProblem = Problem[Any, Any]
ModelType = Literal["instance", "solution"]
ModelReference = ModelType | Literal["self"]

Expand Down
21 changes: 10 additions & 11 deletions algobattle/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,15 @@
Role,
BaseModel,
)
from algobattle.problem import AnyProblem, Instance, Solution


AnySolution = Solution[Instance]
from algobattle.problem import Problem, Instance, Solution


_client_var: DockerClient | None = None


T = TypeVar("T")
_I = TypeVar("_I")
_S = TypeVar("_S")


def client() -> DockerClient:
Expand Down Expand Up @@ -186,14 +185,14 @@ class GeneratorResult(ProgramResult):
"""Result of a single generator execution."""

instance: Instance | None = None
solution: AnySolution | None = None
solution: Solution[Instance] | None = None


@dataclass
class SolverResult(ProgramResult):
"""Result of a single solver execution."""

solution: AnySolution | None = None
solution: Solution[Instance] | None = None


@dataclass
Expand All @@ -210,7 +209,7 @@ class Program(ABC):

id: str
"""The id of the Docker image."""
problem: AnyProblem
problem: Problem
"""The problem this program generates/solves."""
config: ProgramConfigView
"""Config settings used for this program."""
Expand Down Expand Up @@ -248,7 +247,7 @@ async def build(
cls,
path: Path,
*,
problem: AnyProblem,
problem: Problem,
config: ProgramConfigView,
team_name: str | None = None,
) -> Self:
Expand Down Expand Up @@ -601,7 +600,7 @@ def _encode_input(self, input: Path, max_size: int, instance: Instance | None) -
assert instance is not None
instance.encode(input / "instance", self.role)

def _parse_output(self, output: Path, max_size: int, instance: Instance | None) -> AnySolution:
def _parse_output(self, output: Path, max_size: int, instance: Instance | None) -> Solution[Instance]:
assert instance is not None
try:
solution = self.problem.solution_cls.decode(output / "solution", max_size, self.role, instance)
Expand Down Expand Up @@ -729,7 +728,7 @@ async def build(
cls,
name: str,
info: _TeamInfo,
problem: AnyProblem,
problem: Problem,
config: ProgramConfigView,
ui: BuildUi,
) -> "Team":
Expand Down Expand Up @@ -825,7 +824,7 @@ class TeamHandler:
async def build(
cls,
infos: Mapping[str, _TeamInfo],
problem: AnyProblem,
problem: Problem,
config: ProgramConfigView,
ui: BuildUi,
) -> Self:
Expand Down

0 comments on commit 6678af3

Please sign in to comment.