diff --git a/algobattle/battle.py b/algobattle/battle.py index eb380b14..a167d06a 100644 --- a/algobattle/battle.py +++ b/algobattle/battle.py @@ -42,7 +42,7 @@ ProgramUi, Solver, ) -from algobattle.problem import AnyProblem +from algobattle.problem import Problem from algobattle.util import Encodable, ExceptionInfo, BaseModel @@ -106,7 +106,7 @@ async def inner(self: "FightHandler", *args: P.args, **kwargs: P.kwargs) -> Figh class FightHandler: """Helper class to run fights of a given battle.""" - problem: AnyProblem + problem: Problem generator: Generator solver: Solver battle: "Battle" diff --git a/algobattle/match.py b/algobattle/match.py index 3f9a50b9..ef3b5aba 100644 --- a/algobattle/match.py +++ b/algobattle/match.py @@ -29,7 +29,7 @@ from algobattle.battle import Battle, FightHandler, FightUi, BattleUi, Iterated from algobattle.program import ProgramConfigView, ProgramUi, Matchup, TeamHandler, BuildUi -from algobattle.problem import InstanceT, Problem, SolutionT +from algobattle.problem import Problem from algobattle.util import ( ExceptionInfo, Role, @@ -74,7 +74,7 @@ async def _run_battle( battle: Battle, matchup: Matchup, config: "AlgobattleConfig", - problem: Problem[InstanceT, SolutionT], + problem: Problem, cpus: list[str | None], ui: "Ui", limiter: CapacityLimiter, @@ -634,7 +634,7 @@ def check_problem_defined(self) -> Self: return self @cached_property - def problem(self) -> Problem[Any, Any]: + def problem(self) -> Problem: """The problem this config uses.""" return Problem.load(self.match.problem, self.problems) diff --git a/algobattle/problem.py b/algobattle/problem.py index 8074cd8c..e0f60386 100644 --- a/algobattle/problem.py +++ b/algobattle/problem.py @@ -221,7 +221,7 @@ class DynamicProblemInfo(Protocol): location: Path -class Problem(Generic[InstanceT, SolutionT]): +class Problem: """The definition of a problem.""" @overload @@ -291,19 +291,21 @@ def __init__( self._problems[name] = self __slots__ = ("name", "instance_cls", "solution_cls", "min_size", "with_solution", "score_function", "test_instance") - _problems: "ClassVar[dict[str, AnyProblem]]" = {} + _problems: ClassVar[dict[str, Self]] = {} @overload - def score(self, instance: InstanceT, *, solution: SolutionT) -> float: + def score(self, instance: InstanceT, *, solution: Solution[InstanceT]) -> float: ... @overload - def score(self, instance: InstanceT, *, generator_solution: SolutionT, solver_solution: SolutionT) -> float: + def score( + self, instance: InstanceT, *, generator_solution: Solution[InstanceT], solver_solution: Solution[InstanceT] + ) -> float: ... def score( self, - instance: InstanceT, + instance: Instance, *, solution: SolutionT | None = None, generator_solution: SolutionT | None = None, @@ -311,20 +313,30 @@ def score( ) -> float: """Helper function to call self.score_function with easier to use overloads.""" if self.with_solution: - if solution is not None or generator_solution is None or solver_solution is None: + if not ( + isinstance(instance, self.instance_cls) + and isinstance(generator_solution, self.solution_cls) + and isinstance(solver_solution, self.solution_cls) + and solution is None + ): raise TypeError if TYPE_CHECKING: assert isinstance(self.score_function, ScoreFunctionWithSol) return self.score_function(instance, generator_solution=generator_solution, solver_solution=solver_solution) else: - if solution is None or generator_solution is not None or solver_solution is not None: + if not ( + isinstance(instance, self.instance_cls) + and isinstance(solution, self.solution_cls) + and generator_solution is None + and solver_solution is None + ): raise TypeError if TYPE_CHECKING: assert isinstance(self.score_function, ScoreFunctionNoSol) return self.score_function(instance, solution=solution) @classmethod - def load_file(cls, name: str, file: Path) -> "AnyProblem": + def load_file(cls, name: str, file: Path) -> Self: """Loads the problem from the specified file.""" existing_problems = cls._problems.copy() import_file_as_module(file, "__algobattle_problem__") @@ -335,7 +347,7 @@ def load_file(cls, name: str, file: Path) -> "AnyProblem": return cls._problems[name] @classmethod - def load(cls, name: str, dynamic: Mapping[str, DynamicProblemInfo]) -> "AnyProblem": + def load(cls, name: str, dynamic: Mapping[str, DynamicProblemInfo]) -> Self: """Loads the problem with the given name. Args: @@ -373,7 +385,6 @@ def available(cls) -> set[str]: return set(chain(cls._problems.keys(), (e.name for e in entry_points(group="algobattle.problem")))) -AnyProblem = Problem[Any, Any] ModelType = Literal["instance", "solution"] ModelReference = ModelType | Literal["self"] diff --git a/algobattle/program.py b/algobattle/program.py index 964d0d1d..395e36a9 100644 --- a/algobattle/program.py +++ b/algobattle/program.py @@ -39,16 +39,15 @@ Role, BaseModel, ) -from algobattle.problem import AnyProblem, Instance, Solution - - -AnySolution = Solution[Instance] +from algobattle.problem import Problem, Instance, Solution _client_var: DockerClient | None = None T = TypeVar("T") +_I = TypeVar("_I") +_S = TypeVar("_S") def client() -> DockerClient: @@ -186,14 +185,14 @@ class GeneratorResult(ProgramResult): """Result of a single generator execution.""" instance: Instance | None = None - solution: AnySolution | None = None + solution: Solution[Instance] | None = None @dataclass class SolverResult(ProgramResult): """Result of a single solver execution.""" - solution: AnySolution | None = None + solution: Solution[Instance] | None = None @dataclass @@ -210,7 +209,7 @@ class Program(ABC): id: str """The id of the Docker image.""" - problem: AnyProblem + problem: Problem """The problem this program generates/solves.""" config: ProgramConfigView """Config settings used for this program.""" @@ -248,7 +247,7 @@ async def build( cls, path: Path, *, - problem: AnyProblem, + problem: Problem, config: ProgramConfigView, team_name: str | None = None, ) -> Self: @@ -601,7 +600,7 @@ def _encode_input(self, input: Path, max_size: int, instance: Instance | None) - assert instance is not None instance.encode(input / "instance", self.role) - def _parse_output(self, output: Path, max_size: int, instance: Instance | None) -> AnySolution: + def _parse_output(self, output: Path, max_size: int, instance: Instance | None) -> Solution[Instance]: assert instance is not None try: solution = self.problem.solution_cls.decode(output / "solution", max_size, self.role, instance) @@ -729,7 +728,7 @@ async def build( cls, name: str, info: _TeamInfo, - problem: AnyProblem, + problem: Problem, config: ProgramConfigView, ui: BuildUi, ) -> "Team": @@ -825,7 +824,7 @@ class TeamHandler: async def build( cls, infos: Mapping[str, _TeamInfo], - problem: AnyProblem, + problem: Problem, config: ProgramConfigView, ui: BuildUi, ) -> Self: