Skip to content

Commit

Permalink
Merge pull request #128 from ImogenBits/float_types
Browse files Browse the repository at this point in the history
Simple, safe float comparisons
  • Loading branch information
Benezivas authored Sep 29, 2023
2 parents 95f0b40 + 61bbdc7 commit a559a52
Show file tree
Hide file tree
Showing 5 changed files with 249 additions and 11 deletions.
4 changes: 2 additions & 2 deletions algobattle/battle.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
ValidatorFunctionWrapHandler,
)
from pydantic_core import CoreSchema
from pydantic_core.core_schema import tagged_union_schema, general_wrap_validator_function
from pydantic_core.core_schema import tagged_union_schema, with_info_wrap_validator_function

from algobattle.program import (
Generator,
Expand Down Expand Up @@ -295,7 +295,7 @@ def check_installed(val: object, handler: ValidatorFunctionWrapHandler, info: Va
f"The specified battle type '{passed}' is not installed. Installed types are: {installed}"
)

return general_wrap_validator_function(check_installed, subclass_schema)
return with_info_wrap_validator_function(check_installed, subclass_schema)

class FallbackConfig(Config):
"""Fallback config object to parse into if the proper battle typ isn't installed and we're ignoring installs."""
Expand Down
81 changes: 80 additions & 1 deletion algobattle/types.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,18 @@
"""Utility types used to easily define Problems."""
from dataclasses import dataclass
from typing import Annotated, Any, Collection, Iterator, TypeVar, Generic, TypedDict, overload
from sys import float_info
from typing import (
Annotated,
Any,
ClassVar,
Collection,
Iterator,
Literal,
TypeVar,
Generic,
TypedDict,
overload,
)
import annotated_types as at
from annotated_types import (
BaseMetadata,
Expand Down Expand Up @@ -58,6 +70,8 @@
"EdgeWeights",
"VertexWeights",
"AlgobattleContext",
"LaxComp",
"lax_comp",
)


Expand Down Expand Up @@ -472,3 +486,68 @@ class VertexWeights(DirectedGraph, BaseModel, Generic[Weight]):
"""Mixin for graphs with weighted vertices."""

vertex_weights: Annotated[list[Weight], SizeLen]


@dataclass(frozen=True, slots=True)
class LaxComp:
"""Helper class to make forgiving float comparisons easy.
When comparing floats for equality there often are frustrating edge cases introduced by its imprecisions. This can
lead to matches not being decided by which team generates better instances, but by who can craft the most finnicky
floating point values. This class lets you easily sidestep these problems.
It implements comparison operations by adding a small epsilon that covers an allowable range of imprecision. The
solving team will receive twice the epsilon that the generating team was given. This means that the generator cannot
try to exploit imprecision issues since the solver has a bigger tolerance to play with.
!!! example "Usage"
```py
LaxComp(some_val ** 2, role) <= comparison_val
```
"""

value: float
"""The value that can be relaxed in the comparison."""
role: Role
"""Role of the program whose output is currently being validated."""

relative_epsilon: ClassVar[float] = 128 * float_info.epsilon
absolute_epsilon: ClassVar[float] = float_info.min

def __eq__(self, other: object, /) -> bool:
if isinstance(other, (float, int, bool)):
other = float(other)
diff = abs(self.value - other)
norm = min(abs(self.value) + abs(other), float_info.max)
factor = 1 if self.role == Role.generator else 2
return diff <= factor * max(self.absolute_epsilon, norm * self.relative_epsilon)
else:
return NotImplemented

def __le__(self, other: float, /) -> bool:
return self.value <= other or self == other

def __ge__(self, other: float, /) -> bool:
return self.value >= other or self == other


def lax_comp(value: float, cmp: Literal["<=", "==", ">="], other: float, role: Role) -> bool:
"""Helper function to explicitly use the `LaxComp` comparison mechanism.
Args:
value: First value to compare.
cmp: Comparison to perform, one of "<=", "==", or ">=".
other: Other value to compare.
role: Role of the program the values are being validated for.
Returns:
Result of the comparison.
"""
val = LaxComp(value, role)
match cmp:
case "<=":
return val <= other
case "==":
return val == other
case ">=":
return val >= other
7 changes: 3 additions & 4 deletions algobattle/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,12 @@
from pydantic import (
ConfigDict,
BaseModel as PydandticBaseModel,
Extra,
GetCoreSchemaHandler,
ValidationError as PydanticValidationError,
ValidationInfo,
)
from pydantic_core import CoreSchema
from pydantic_core.core_schema import general_after_validator_function
from pydantic_core.core_schema import with_info_after_validator_function


class Role(StrEnum):
Expand All @@ -47,7 +46,7 @@ class Role(StrEnum):
class BaseModel(PydandticBaseModel):
"""Base class for all pydantic models."""

model_config = ConfigDict(extra=Extra.forbid, from_attributes=True)
model_config = ConfigDict(extra="forbid", from_attributes=True)


class InstanceSolutionModel(BaseModel):
Expand Down Expand Up @@ -168,7 +167,7 @@ def wrapper(value: Any, info: ValidationInfo) -> Any:
return value
return func(value, attribute_val)

return general_after_validator_function(wrapper, schema=schema)
return with_info_after_validator_function(wrapper, schema=schema)

def needs_self(self, model_type: ModelType) -> bool:
"""Checks if the validator needs a reference to the current model in order to work fully."""
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ classifiers = [
]
dependencies = [
"docker~=6.1.3",
"pydantic~=2.3.0",
"pydantic~=2.4.0",
"anyio~=4.0.0",
"typer[all]~=0.9.0",
"typing-extensions~=4.8.0",
Expand Down
166 changes: 163 additions & 3 deletions tests/test_types.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
"""Tests for pydantic parsing types."""
from typing import Annotated
from typing import Annotated, Any
from unittest import TestCase, main
from unittest.util import safe_repr

from pydantic import ValidationError

from algobattle.problem import InstanceModel
from algobattle.util import AttributeReference, SelfRef
from algobattle.types import Ge, Interval, SizeIndex, UniqueItems
from algobattle.util import AttributeReference, Role, SelfRef
from algobattle.types import Ge, Interval, LaxComp, SizeIndex, UniqueItems


class ModelCreationTests(TestCase):
Expand Down Expand Up @@ -167,5 +168,164 @@ def test_schema(self):
self.assertIn("uniqueItems", schema["properties"]["array"])


class LaxCompTests(TestCase):
"""Tests for the LaxComp helper."""

def assertNotGreaterEqual(self, a: Any, b: Any, msg: str | None = None) -> None:
if msg is None:
msg = f"{safe_repr(a)} greater than or equal to {safe_repr(b)}"
self.assertFalse(a >= b, msg)

def assertNotLessEqual(self, a: Any, b: Any, msg: str | None = None) -> None:
if msg is None:
msg = f"{safe_repr(a)} less than or equal to {safe_repr(b)}"
self.assertFalse(a <= b, msg)

@classmethod
def setUpClass(cls) -> None:
LaxComp.absolute_epsilon = 1
LaxComp.relative_epsilon = 0.1

def test_equal_strict(self) -> None:
self.assertEqual(LaxComp(0, Role.generator), 0)
self.assertEqual(LaxComp(0, Role.solver), 0)

def test_equal_small(self) -> None:
self.assertEqual(LaxComp(0, Role.generator), 0.5)
self.assertEqual(LaxComp(0, Role.solver), 0.5)

def test_equal_medium(self) -> None:
self.assertNotEqual(LaxComp(0, Role.generator), 1.5)
self.assertEqual(LaxComp(0, Role.solver), 1.5)

def test_equal_big(self) -> None:
self.assertNotEqual(LaxComp(0, Role.generator), 2.5)
self.assertNotEqual(LaxComp(0, Role.solver), 2.5)

def test_equal_rel_strict(self) -> None:
self.assertEqual(LaxComp(100, Role.generator), 100)
self.assertEqual(LaxComp(100, Role.solver), 100)

def test_equal_rel_small(self) -> None:
self.assertEqual(LaxComp(100, Role.generator), 110)
self.assertEqual(LaxComp(100, Role.solver), 110)

def test_equal_rel_medium(self) -> None:
self.assertNotEqual(LaxComp(100, Role.generator), 130)
self.assertEqual(LaxComp(100, Role.solver), 130)

def test_equal_rel_big(self) -> None:
self.assertNotEqual(LaxComp(100, Role.generator), 160)
self.assertNotEqual(LaxComp(100, Role.solver), 160)

def test_greater_equal_greater(self) -> None:
self.assertGreaterEqual(LaxComp(1, Role.generator), 0)
self.assertGreaterEqual(LaxComp(1, Role.solver), 0)
self.assertGreaterEqual(1, LaxComp(0, Role.generator))
self.assertGreaterEqual(1, LaxComp(0, Role.solver))

def test_greater_equal_strict(self) -> None:
self.assertGreaterEqual(LaxComp(0, Role.generator), 0)
self.assertGreaterEqual(LaxComp(0, Role.solver), 0)
self.assertGreaterEqual(0, LaxComp(0, Role.generator))
self.assertGreaterEqual(0, LaxComp(0, Role.solver))

def test_greater_equal_small(self) -> None:
self.assertGreaterEqual(LaxComp(0, Role.generator), 0.5)
self.assertGreaterEqual(LaxComp(0, Role.solver), 0.5)
self.assertGreaterEqual(0, LaxComp(0.5, Role.generator))
self.assertGreaterEqual(0, LaxComp(0.5, Role.solver))

def test_greater_equal_medium(self) -> None:
self.assertNotGreaterEqual(LaxComp(0, Role.generator), 1.5)
self.assertGreaterEqual(LaxComp(0, Role.solver), 1.5)
self.assertNotGreaterEqual(0, LaxComp(1.5, Role.generator))
self.assertGreaterEqual(0, LaxComp(1.5, Role.solver))

def test_greater_equal_big(self) -> None:
self.assertNotGreaterEqual(LaxComp(0, Role.generator), 2.5)
self.assertNotGreaterEqual(LaxComp(0, Role.solver), 2.5)
self.assertNotGreaterEqual(0, LaxComp(2.5, Role.generator))
self.assertNotGreaterEqual(0, LaxComp(2.5, Role.solver))

def test_greater_equal_rel_strict(self) -> None:
self.assertGreaterEqual(LaxComp(100, Role.generator), 100)
self.assertGreaterEqual(LaxComp(100, Role.solver), 100)
self.assertGreaterEqual(100, LaxComp(100, Role.generator))
self.assertGreaterEqual(100, LaxComp(100, Role.solver))

def test_greater_equal_rel_small(self) -> None:
self.assertGreaterEqual(LaxComp(100, Role.generator), 110)
self.assertGreaterEqual(LaxComp(100, Role.solver), 110)
self.assertGreaterEqual(100, LaxComp(110, Role.generator))
self.assertGreaterEqual(100, LaxComp(110, Role.solver))

def test_greater_equal_rel_medium(self) -> None:
self.assertNotGreaterEqual(LaxComp(100, Role.generator), 130)
self.assertGreaterEqual(LaxComp(100, Role.solver), 130)
self.assertNotGreaterEqual(100, LaxComp(130, Role.generator))
self.assertGreaterEqual(100, LaxComp(130, Role.solver))

def test_greater_equal_rel_big(self) -> None:
self.assertNotGreaterEqual(LaxComp(100, Role.generator), 160)
self.assertNotGreaterEqual(LaxComp(100, Role.solver), 160)
self.assertNotGreaterEqual(100, LaxComp(160, Role.generator))
self.assertNotGreaterEqual(100, LaxComp(160, Role.solver))

def test_less_equal_less(self) -> None:
self.assertLessEqual(LaxComp(0, Role.generator), 1)
self.assertLessEqual(LaxComp(0, Role.solver), 1)
self.assertLessEqual(0, LaxComp(1, Role.generator))
self.assertLessEqual(0, LaxComp(1, Role.solver))

def test_less_equal_strict(self) -> None:
self.assertLessEqual(LaxComp(0, Role.generator), 0)
self.assertLessEqual(LaxComp(0, Role.solver), 0)
self.assertLessEqual(0, LaxComp(0, Role.generator))
self.assertLessEqual(0, LaxComp(0, Role.solver))

def test_less_equal_small(self) -> None:
self.assertLessEqual(LaxComp(0.5, Role.generator), 0)
self.assertLessEqual(LaxComp(0.5, Role.solver), 0)
self.assertLessEqual(0.5, LaxComp(0, Role.generator))
self.assertLessEqual(0.5, LaxComp(0, Role.solver))

def test_less_equal_medium(self) -> None:
self.assertNotLessEqual(LaxComp(1.5, Role.generator), 0)
self.assertLessEqual(LaxComp(1.5, Role.solver), 0)
self.assertNotLessEqual(1.5, LaxComp(0, Role.generator))
self.assertLessEqual(1.5, LaxComp(0, Role.solver))

def test_less_equal_big(self) -> None:
self.assertNotLessEqual(LaxComp(2.5, Role.generator), 0)
self.assertNotLessEqual(LaxComp(2.5, Role.solver), 0)
self.assertNotLessEqual(2.5, LaxComp(0, Role.generator))
self.assertNotLessEqual(2.5, LaxComp(0, Role.solver))

def test_less_equal_rel_strict(self) -> None:
self.assertLessEqual(LaxComp(100, Role.generator), 100)
self.assertLessEqual(LaxComp(100, Role.solver), 100)
self.assertLessEqual(100, LaxComp(100, Role.generator))
self.assertLessEqual(100, LaxComp(100, Role.solver))

def test_less_equal_rel_small(self) -> None:
self.assertLessEqual(LaxComp(110, Role.generator), 100)
self.assertLessEqual(LaxComp(110, Role.solver), 100)
self.assertLessEqual(110, LaxComp(100, Role.generator))
self.assertLessEqual(110, LaxComp(100, Role.solver))

def test_less_equal_rel_medium(self) -> None:
self.assertNotLessEqual(LaxComp(130, Role.generator), 100)
self.assertLessEqual(LaxComp(130, Role.solver), 100)
self.assertNotLessEqual(130, LaxComp(100, Role.generator))
self.assertLessEqual(130, LaxComp(100, Role.solver))

def test_less_equal_rel_big(self) -> None:
self.assertNotLessEqual(LaxComp(160, Role.generator), 100)
self.assertNotLessEqual(LaxComp(160, Role.solver), 100)
self.assertNotLessEqual(160, LaxComp(100, Role.generator))
self.assertNotLessEqual(160, LaxComp(100, Role.solver))


if __name__ == "__main__":
main()

0 comments on commit a559a52

Please sign in to comment.