Skip to content

Commit

Permalink
Merge pull request #122 from ImogenBits/util_scripts
Browse files Browse the repository at this point in the history
Modernize command line interaction
  • Loading branch information
Benezivas authored Sep 20, 2023
2 parents cf6ebb4 + 7a3a901 commit baf89ba
Show file tree
Hide file tree
Showing 79 changed files with 3,357 additions and 2,052 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,5 @@ dist
site
docs/src/pairsum_solver/target
docs/src/pairsum_solver/Cargo.lock
.results
.project
61 changes: 0 additions & 61 deletions algobattle.ps1

This file was deleted.

75 changes: 64 additions & 11 deletions algobattle/battle.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from abc import abstractmethod
from inspect import isclass
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Expand All @@ -22,9 +23,16 @@
TypeVar,
)

from pydantic import Field, GetCoreSchemaHandler
from pydantic import (
ConfigDict,
Field,
GetCoreSchemaHandler,
ValidationError,
ValidationInfo,
ValidatorFunctionWrapHandler,
)
from pydantic_core import CoreSchema
from pydantic_core.core_schema import tagged_union_schema
from pydantic_core.core_schema import tagged_union_schema, general_wrap_validator_function

from algobattle.program import (
Generator,
Expand Down Expand Up @@ -242,20 +250,63 @@ def __get_pydantic_core_schema__(cls, source: Type, handler: GetCoreSchemaHandle
return handler(source)
except NameError:
return handler(source)

match len(Battle._battle_types):
case 0:
return handler(source)
subclass_schema = handler(source)
case 1:
return handler(next(iter(Battle._battle_types.values())))
subclass_schema = handler(next(iter(Battle._battle_types.values())))
case _:
return tagged_union_schema(
subclass_schema = tagged_union_schema(
choices={
battle.Config.model_fields["type"].default: battle.Config.__pydantic_core_schema__
for battle in Battle._battle_types.values()
},
discriminator="type",
)

# we want to validate into the actual battle type's config, so we need to treat them as a tagged union
# but if we're initializing a project the type might not be installed yet, so we want to also parse
# into an unspecified dummy object. This wrap validator will efficiently and transparently act as a tagged
# union when ignore_uninstalled is not set. If it is set it catches only the error of a missing tag, other
# errors are passed through
def check_installed(val: object, handler: ValidatorFunctionWrapHandler, info: ValidationInfo) -> object:
try:
return handler(val)
except ValidationError as e:
union_err = next(filter(lambda err: err["type"] == "union_tag_invalid", e.errors()), None)
if union_err is None:
raise
if info.context is not None and info.context.get("ignore_uninstalled", False):
if info.config is not None:
settings: dict[str, Any] = {
"strict": info.config.get("strict", None),
"from_attributes": info.config.get("from_attributes"),
}
else:
settings = {}
return Battle.FallbackConfig.model_validate(val, context=info.context, **settings)
else:
passed = union_err["input"]["type"]
installed = ", ".join(b.name() for b in Battle._battle_types.values())
raise ValueError(
f"The specified battle type '{passed}' is not installed. Installed types are: {installed}"
)

return general_wrap_validator_function(check_installed, subclass_schema)

class FallbackConfig(Config):
"""Fallback config object to parse into if the proper battle typ isn't installed and we're ignoring installs."""

type: str

model_config = ConfigDict(extra="allow")

if TYPE_CHECKING:
# to hint that we're gonna fill this with arbitrary data belonging to some supposed battle type
def __getattr__(self, __attr: str) -> Any:
...

class UiData(BaseModel):
"""Object containing custom diplay data.
Expand All @@ -280,11 +331,12 @@ def load_entrypoints(cls) -> None:
if not (isclass(battle) and issubclass(battle, Battle)):
raise ValueError(f"Entrypoint {entrypoint.name} targets something other than a Battle type")

def __init_subclass__(cls) -> None:
@classmethod
def __pydantic_init_subclass__(cls, **kwargs: Any) -> None:
if cls.name() not in Battle._battle_types:
Battle._battle_types[cls.name()] = cls
Battle.Config.model_rebuild(force=True)
return super().__init_subclass__()
return super().__pydantic_init_subclass__(**kwargs)

@abstractmethod
def score(self) -> float:
Expand Down Expand Up @@ -367,10 +419,11 @@ async def run_battle(self, fight: FightHandler, config: Config, min_size: int, u
base_increment = 0
alive = True
reached = 0
self.results.append(0)
cap = config.maximum_size
current = min_size
while alive:
ui.update_battle_data(self.UiData(reached=self.results + [reached], cap=cap))
ui.update_battle_data(self.UiData(reached=self.results, cap=cap))
result = await fight.run(current)
score = result.score
if score < config.minimum_score:
Expand All @@ -384,7 +437,7 @@ async def run_battle(self, fight: FightHandler, config: Config, min_size: int, u
alive = True
elif current > reached and alive:
# We solved an instance of bigger size than before
reached = current
self.results[-1] = reached = current

if current + 1 > cap:
alive = False
Expand All @@ -396,7 +449,7 @@ async def run_battle(self, fight: FightHandler, config: Config, min_size: int, u
# We have failed at this value of n already, reset the step size!
current -= base_increment**config.exponent - 1
base_increment = 1
self.results.append(reached)
self.results[-1] = reached

def score(self) -> float:
"""Averages the highest instance size reached in each round."""
Expand All @@ -416,7 +469,7 @@ class Config(Battle.Config):

type: Literal["Averaged"] = "Averaged"

instance_size: int = 10
instance_size: int = 25
"""Instance size that will be fought at."""
num_fights: int = 10
"""Number of iterations in each round."""
Expand Down
Loading

0 comments on commit baf89ba

Please sign in to comment.