From 124ef9569714953ced81dd2e2d2a2512d8a842d2 Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Sat, 2 Dec 2023 15:27:12 -0500 Subject: [PATCH 1/2] equal changes --- docs/changelog.md | 5 ++++ docs/reference/egglog-translation.md | 6 ++-- docs/reference/python-integration.md | 3 +- python/egglog/declarations.py | 3 -- python/egglog/egraph.py | 42 ++++++++++++++++++---------- python/egglog/examples/bool.py | 2 +- python/egglog/examples/lambda_.py | 8 +++--- python/egglog/examples/resolution.py | 2 +- python/tests/test_high_level.py | 26 ++++++++++++++++- 9 files changed, 68 insertions(+), 29 deletions(-) diff --git a/docs/changelog.md b/docs/changelog.md index 48c8467a..de70319c 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -4,6 +4,11 @@ _This project uses semantic versioning_ ## UNRELEASED +## 5.0.0 (UNRELEASED) + +- Move egglog `!=` function to be called with `ne(x).to(y)` instead of `x != y` so that user defined expressions + can + ## 4.0.1 (2023-11-27) - Fix keyword args for `__init__` methods (#96)[https://github.com/metadsl/egglog-python/pull/96]. diff --git a/docs/reference/egglog-translation.md b/docs/reference/egglog-translation.md index 927fb4c9..81e2be12 100644 --- a/docs/reference/egglog-translation.md +++ b/docs/reference/egglog-translation.md @@ -52,14 +52,14 @@ Rational(1, 2) / Rational(2, 1) ### `!=` Operator -The `!=` function in egglog works on any two types with the same sort. In Python, this is supported by overloading the `__ne__` operator, which is done by default in all builtin and custom sorts: +The `!=` function in egglog works on any two types with the same sort. In Python, this is mapped to the `ne` function: ```{code-cell} python # egg: (!= 10 2) -i64(10) != i64(2) +ne(i64(10)).to(i64(2)) ``` -This is checked statically, based on the `__ne__` definition in `Expr`, so that only sorts that have the same sort can be compared. +This is a two part function so that we can statically check both sides are the same type. ## Declaring Sorts diff --git a/docs/reference/python-integration.md b/docs/reference/python-integration.md index 7ec02440..66e90b43 100644 --- a/docs/reference/python-integration.md +++ b/docs/reference/python-integration.md @@ -208,6 +208,7 @@ Most of the Python special dunder (= "double under") methods are supported as we - `__le__` - `__eq__` - `__ne__` +- `__ne__` - `__gt__` - `__ge__` - `__add__` @@ -231,7 +232,7 @@ Most of the Python special dunder (= "double under") methods are supported as we - `__setitem__` - `__delitem__` -Currently `__divmod__` is not supported, since it returns multiple results and `__ne__` will shadow the builtin `!=` egglog operator. +Currently `__divmod__` is not supported, since it returns multiple results. Also these methods are currently used in the runtime class and cannot be overridden currently, although we could change this if the need arises: diff --git a/python/egglog/declarations.py b/python/egglog/declarations.py index 13a26fe0..ea96bae4 100644 --- a/python/egglog/declarations.py +++ b/python/egglog/declarations.py @@ -770,9 +770,6 @@ def pretty(self, context: PrettyContext, parens: bool = True, **kwargs) -> str: if self in context.names: return context.names[self] ref, args = self.callable, [a.expr for a in self.args] - # Special case != since it doesn't have a decl - if isinstance(ref, MethodRef) and ref.method_name == "__ne__": - return f"{args[0].pretty(context)} != {args[1].pretty(context)}" function_decl = context.mod_decls.get_function_decl(ref) # Determine how many of the last arguments are defaults, by iterating from the end and comparing the arg with the default n_defaults = 0 diff --git a/python/egglog/egraph.py b/python/egglog/egraph.py index d7b5c214..9bc28cdb 100644 --- a/python/egglog/egraph.py +++ b/python/egglog/egraph.py @@ -54,6 +54,8 @@ "rewrite", "birewrite", "eq", + "ne", + "_ne", "panic", "let", "delete", @@ -258,8 +260,6 @@ def _class( # noqa: PLR0912 unextractable=unextractable, ) - # Register != as a method so we can print it as a string - self._mod_decls._decl.register_callable_ref(MethodRef(cls_name, "__ne__"), "!=") return RuntimeClass(self._mod_decls, cls_name) # We seperate the function and method overloads to make it simpler to know if we are modifying a function or method, @@ -1152,22 +1152,10 @@ class Expr(metaclass=_ExprMetaclass): Expression base class, which adds suport for != to all expression types. """ - def __ne__(self: EXPR, other_expr: EXPR) -> Unit: # type: ignore[override, empty-body] - """ - Compare whether to expressions are not equal. - - :param self: The expression to compare. - :param other_expr: The other expression to compare to, which must be of the same type. - :meta public: - """ + def __ne__(self, other: NoReturn) -> NoReturn: # type: ignore[override, empty-body] ... def __eq__(self, other: NoReturn) -> NoReturn: # type: ignore[override, empty-body] - """ - Equality is currently not supported. - - We only add this method so that if you try to use it MyPy will warn you. - """ ... @@ -1184,6 +1172,14 @@ def __init__(self) -> None: ... +@BUILTINS.function(egg_fn="!=") +def _ne(l: Expr, r: Expr) -> Unit: # type: ignore[empty-body] + """ + Translates to the != egg function. + """ + ... + + @dataclass(frozen=True) class Ruleset: name: str @@ -1492,6 +1488,11 @@ def eq(expr: EXPR) -> _EqBuilder[EXPR]: return _EqBuilder(expr) +def ne(expr: EXPR) -> _NeBuilder[EXPR]: + """Check if the given expression is not equal to the given value.""" + return _NeBuilder(expr) + + def panic(message: str) -> Action: """Raise an error with the given message.""" return Panic(message) @@ -1596,6 +1597,17 @@ def __str__(self) -> str: return f"eq({self.expr})" +@dataclass +class _NeBuilder(Generic[EXPR]): + expr: EXPR + + def to(self, expr: EXPR) -> Unit: + return _ne(self.expr, expr) + + def __str__(self) -> str: + return f"ne({self.expr})" + + @dataclass class _SetBuilder(Generic[EXPR]): lhs: Expr diff --git a/python/egglog/examples/bool.py b/python/egglog/examples/bool.py index e628d393..c8bb6a40 100644 --- a/python/egglog/examples/bool.py +++ b/python/egglog/examples/bool.py @@ -13,7 +13,7 @@ egraph.check(eq(T & T).to(T)) egraph.check(eq(T & F).to(F)) egraph.check(eq(T | F).to(T)) -egraph.check((T | F) != F) +egraph.check(ne(T | F).to(F)) egraph.check(eq(i64(1).bool_lt(2)).to(T)) egraph.check(eq(i64(2).bool_lt(1)).to(F)) diff --git a/python/egglog/examples/lambda_.py b/python/egglog/examples/lambda_.py index 8d610768..7d2916bd 100644 --- a/python/egglog/examples/lambda_.py +++ b/python/egglog/examples/lambda_.py @@ -117,7 +117,7 @@ def freer(t: Term) -> StringSet: union(t.eval()).with_(Val(i1 + i2)) ), rule(eq(t).to(t1 == t2), eq(t1.eval()).to(t2.eval())).then(union(t.eval()).with_(Val.TRUE)), - rule(eq(t).to(t1 == t2), eq(t1.eval()).to(v1), eq(t2.eval()).to(v2), v1 != v2).then( + rule(eq(t).to(t1 == t2), eq(t1.eval()).to(v1), eq(t2.eval()).to(v2), ne(v1).to(v2)).then( union(t.eval()).with_(Val.FALSE) ), rule(eq(v).to(t.eval())).then(union(t).with_(Term.val(v))), @@ -154,12 +154,12 @@ def freer(t: Term) -> StringSet: # let-var-same rewrite(let_(x, t, Term.var(x))).to(t), # let-var-diff - rewrite(let_(x, t, Term.var(y))).to(Term.var(y), x != y), + rewrite(let_(x, t, Term.var(y))).to(Term.var(y), ne(x).to(y)), # let-lam-same rewrite(let_(x, t, lam(x, t1))).to(lam(x, t1)), # let-lam-diff - rewrite(let_(x, t, lam(y, t1))).to(lam(y, let_(x, t, t1)), x != y, eq(fv).to(freer(t)), fv.not_contains(y)), - rule(eq(t).to(let_(x, t1, lam(y, t2))), x != y, eq(fv).to(freer(t1)), fv.contains(y)).then( + rewrite(let_(x, t, lam(y, t1))).to(lam(y, let_(x, t, t1)), ne(x).to(y), eq(fv).to(freer(t)), fv.not_contains(y)), + rule(eq(t).to(let_(x, t1, lam(y, t2))), ne(x).to(y), eq(fv).to(freer(t1)), fv.contains(y)).then( union(t).with_(lam(t.v(), let_(x, t1, let_(y, Term.var(t.v()), t2)))) ), ) diff --git a/python/egglog/examples/resolution.py b/python/egglog/examples/resolution.py index baf91e36..ba6a1031 100644 --- a/python/egglog/examples/resolution.py +++ b/python/egglog/examples/resolution.py @@ -78,7 +78,7 @@ def pred(x: i64Like) -> Boolean: # type: ignore[empty-body] union(~p0 | (~p1 | (p2 | F))).with_(T), ) egraph.run(10) -egraph.check(T != F) +egraph.check(ne(T).to(F)) egraph.check(eq(p0).to(F)) egraph.check(eq(p2).to(F)) egraph diff --git a/python/tests/test_high_level.py b/python/tests/test_high_level.py index eccaf32e..081aaa3c 100644 --- a/python/tests/test_high_level.py +++ b/python/tests/test_high_level.py @@ -359,9 +359,33 @@ def test_f64_negation() -> None: def test_not_equals(): egraph = EGraph() - egraph.check(i64(10) != i64(2)) + egraph.check(ne(i64(10)).to(i64(2))) +def test_custom_equality(): + egraph = EGraph() + + @egraph.class_ + class Boolean(Expr): + def __init__(self, value: BoolLike) -> None: + ... + + def __eq__(self, other: Boolean) -> Boolean: # type: ignore[override] + ... + + def __ne__(self, other: Boolean) -> Boolean: # type: ignore[override] + ... + + egraph.register(rewrite(Boolean(True) == Boolean(True)).to(Boolean(False))) + egraph.register(rewrite(Boolean(True) != Boolean(True)).to(Boolean(True))) + + should_be_true = Boolean(True) == Boolean(True) + should_be_false = Boolean(True) != Boolean(True) + egraph.register(should_be_true, should_be_false) + egraph.run(10) + egraph.check(eq(should_be_true).to(Boolean(False))) + egraph.check(eq(should_be_false).to(Boolean(True))) + class TestMutate: def test_setitem_defaults(self): egraph = EGraph() From b3ba8e0eea1c14a6e9a52cff7d6ea60e1f8defe9 Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Wed, 27 Dec 2023 11:41:38 -0500 Subject: [PATCH 2/2] Fix equality changes --- python/egglog/declarations.py | 2 +- python/egglog/egraph.py | 26 +++++++++++++--------- python/egglog/exp/array_api.py | 6 ++--- python/egglog/exp/array_api_program_gen.py | 3 ++- python/egglog/exp/program_gen.py | 18 +++++++-------- python/egglog/runtime.py | 16 ++++++------- python/tests/test_array_api.py | 1 - test-data/unit/check-high-level.test | 4 ++-- 8 files changed, 40 insertions(+), 36 deletions(-) diff --git a/python/egglog/declarations.py b/python/egglog/declarations.py index ea96bae4..fc950b9c 100644 --- a/python/egglog/declarations.py +++ b/python/egglog/declarations.py @@ -269,7 +269,7 @@ def get_egg_fn(self, ref: CallableRef) -> str: return decls.get_egg_fn(ref) except KeyError: pass - raise KeyError(f"Callable ref {ref} not found") + raise KeyError(f"Callable ref {ref!r} not found") def get_egg_sort(self, ref: JustTypeRef) -> str: for decls in self.all_decls: diff --git a/python/egglog/egraph.py b/python/egglog/egraph.py index 9bc28cdb..66b45bf6 100644 --- a/python/egglog/egraph.py +++ b/python/egglog/egraph.py @@ -55,7 +55,6 @@ "birewrite", "eq", "ne", - "_ne", "panic", "let", "delete", @@ -663,6 +662,8 @@ def __post_init__(self, modules: list[Module]) -> None: msg = "Builtins already initialized" raise RuntimeError(msg) _BUILTIN_DECLS = self._mod_decls._decl + # Register != operator + _BUILTIN_DECLS.register_callable_ref(FunctionRef("!="), "!=") def _process_commands(self, cmds: Iterable[bindings._Command]) -> None: """ @@ -1172,14 +1173,6 @@ def __init__(self) -> None: ... -@BUILTINS.function(egg_fn="!=") -def _ne(l: Expr, r: Expr) -> Unit: # type: ignore[empty-body] - """ - Translates to the != egg function. - """ - ... - - @dataclass(frozen=True) class Ruleset: name: str @@ -1602,7 +1595,20 @@ class _NeBuilder(Generic[EXPR]): expr: EXPR def to(self, expr: EXPR) -> Unit: - return _ne(self.expr, expr) + l_expr = cast(RuntimeExpr, self.expr) + return cast( + Unit, + RuntimeExpr( + BUILTINS._mod_decls, + TypedExprDecl( + JustTypeRef("Unit"), + CallDecl( + FunctionRef("!="), + (l_expr.__egg_typed_expr__, convert_to_same_type(expr, l_expr).__egg_typed_expr__), + ), + ), + ), + ) def __str__(self) -> str: return f"ne({self.expr})" diff --git a/python/egglog/exp/array_api.py b/python/egglog/exp/array_api.py index 7c8351b0..a8e933a1 100644 --- a/python/egglog/exp/array_api.py +++ b/python/egglog/exp/array_api.py @@ -278,7 +278,7 @@ def __bool__(self) -> bool: @array_api_module.register def _int(i: i64, j: i64, r: Boolean, o: Int): yield rewrite(Int(i) == Int(i)).to(TRUE) - yield rule(eq(r).to(Int(i) == Int(j)), i != j).then(union(r).with_(FALSE)) + yield rule(eq(r).to(Int(i) == Int(j)), ne(i).to(j)).then(union(r).with_(FALSE)) yield rewrite(Int(i) >= Int(i)).to(TRUE) yield rule(eq(r).to(Int(i) >= Int(j)), i > j).then(union(r).with_(TRUE)) @@ -666,7 +666,7 @@ def _tuple_value( # Includes rewrite(TupleValue.EMPTY.includes(v)).to(FALSE), rewrite(TupleValue(v).includes(v)).to(TRUE), - rewrite(TupleValue(v).includes(v2)).to(FALSE, v != v2), + rewrite(TupleValue(v).includes(v2)).to(FALSE, ne(v).to(v2)), rewrite((ti + ti2).includes(v)).to(ti.includes(v) | ti2.includes(v)), ] @@ -1503,7 +1503,7 @@ def _assume_value_one_of(x: NDArray, v: Value, vs: TupleValue, idx: TupleInt): def _ndarray_value_isfinite(arr: NDArray, x: Value, xs: TupleValue, i: Int, f: f64, b: Boolean): yield rewrite(Value.int(i).isfinite()).to(TRUE) yield rewrite(Value.bool(b).isfinite()).to(TRUE) - yield rewrite(Value.float(Float(f)).isfinite()).to(TRUE, f != f64(math.nan)) + yield rewrite(Value.float(Float(f)).isfinite()).to(TRUE, ne(f).to(f64(math.nan))) # a sum of an array is finite if all the values are finite yield rewrite(isfinite(sum(arr))).to(NDArray.scalar(Value.bool(arr.index(ALL_INDICES).isfinite()))) diff --git a/python/egglog/exp/array_api_program_gen.py b/python/egglog/exp/array_api_program_gen.py index 58f661aa..56467c5d 100644 --- a/python/egglog/exp/array_api_program_gen.py +++ b/python/egglog/exp/array_api_program_gen.py @@ -126,7 +126,8 @@ def _float_program(f: Float, g: Float, f64_: f64, i: Int, r: Rational): yield rewrite(float_program(f * g)).to(Program("(") + float_program(f) + " * " + float_program(g) + ")") yield rewrite(float_program(f / g)).to(Program("(") + float_program(f) + " / " + float_program(g) + ")") yield rewrite(float_program(Float.rational(r))).to( - Program("float(") + Program(r.numer.to_string()) + " / " + Program(r.denom.to_string()) + ")", r.denom != i64(1) + Program("float(") + Program(r.numer.to_string()) + " / " + Program(r.denom.to_string()) + ")", + ne(r.denom).to(i64(1)), ) yield rewrite(float_program(Float.rational(r))).to( Program("float(") + Program(r.numer.to_string()) + ")", eq(r.denom).to(i64(1)) diff --git a/python/egglog/exp/program_gen.py b/python/egglog/exp/program_gen.py index 0a8f88a9..19f5a9a7 100644 --- a/python/egglog/exp/program_gen.py +++ b/python/egglog/exp/program_gen.py @@ -182,7 +182,7 @@ def _compile( yield rule( stmt, p.compile(i), - p1.parent != p, + ne(p1.parent).to(p), eq(s1).to(p1.expr), ).then( set_(p.statements).to(join(s1, "\n")), @@ -214,7 +214,7 @@ def _compile( # Otherwise, if its not equal to either input, its not an identifier yield rule(program_add, eq(p.expr).to(p1.expr), eq(b).to(p1.is_identifer)).then(set_(p.is_identifer).to(b)) yield rule(program_add, eq(p.expr).to(p2.expr), eq(b).to(p2.is_identifer)).then(set_(p.is_identifer).to(b)) - yield rule(program_add, p.expr != p1.expr, p.expr != p2.expr).then(set_(p.is_identifer).to(Bool(False))) + yield rule(program_add, ne(p.expr).to(p1.expr), ne(p.expr).to(p2.expr)).then(set_(p.is_identifer).to(Bool(False))) # Set parent of p1 yield rule(program_add, p.compile(i)).then( @@ -228,7 +228,7 @@ def _compile( yield rule(program_add, p.compile(i), p1.next_sym).then(set_(p2.parent).to(p)) # Compile p2, if p1 parent not equal, but p2 parent equal - yield rule(program_add, p.compile(i), p1.parent != p, eq(p2.parent).to(p)).then(p2.compile(i)) + yield rule(program_add, p.compile(i), ne(p1.parent).to(p), eq(p2.parent).to(p)).then(p2.compile(i)) # Compile p2, if p1 parent eqal yield rule(program_add, p.compile(i2), eq(p1.parent).to(p), eq(i).to(p1.next_sym), eq(p2.parent).to(p)).then( @@ -259,8 +259,8 @@ def _compile( yield rule( program_add, p.compile(i), - p1.parent != p, - p2.parent != p, + ne(p1.parent).to(p), + ne(p2.parent).to(p), ).then( set_(p.statements).to(String("")), set_(p.next_sym).to(i), @@ -269,7 +269,7 @@ def _compile( yield rule( program_add, eq(p1.parent).to(p), - p2.parent != p, + ne(p2.parent).to(p), eq(s1).to(p1.statements), eq(i).to(p1.next_sym), ).then( @@ -280,7 +280,7 @@ def _compile( yield rule( program_add, eq(p2.parent).to(p), - p1.parent != p, + ne(p1.parent).to(p), eq(s2).to(p2.statements), eq(i).to(p2.next_sym), ).then( @@ -319,7 +319,7 @@ def _compile( # 1. b. If p1 parent is not p, then just use assign as statement, next sym of i yield rule( program_assign, - p1.parent != p, + ne(p1.parent).to(p), p.compile(i), eq(s2).to(p1.expr), eq(p1.is_identifer).to(Bool(False)), @@ -347,7 +347,7 @@ def _compile( # 1. b. If p1 parent is not p, then just use assign as statement, next sym of i yield rule( program_assign, - p1.parent != p, + ne(p1.parent).to(p), p.compile(i), eq(s2).to(p1.expr), eq(p1.is_identifer).to(Bool(True)), diff --git a/python/egglog/runtime.py b/python/egglog/runtime.py index 355e2c38..c0b0bd51 100644 --- a/python/egglog/runtime.py +++ b/python/egglog/runtime.py @@ -435,15 +435,13 @@ def __post_init__(self) -> None: self.__egg_callable_ref__ = PropertyRef(self.class_name, self.__egg_method_name__) else: self.__egg_callable_ref__ = MethodRef(self.class_name, self.__egg_method_name__) - # Special case for __ne__ which does not have a normal function defintion since - # it relies of type parameters - if self.__egg_method_name__ == "__ne__": - self.__egg_fn_decl__ = None - else: - try: - self.__egg_fn_decl__ = self.__egg_self__.__egg_decls__.get_function_decl(self.__egg_callable_ref__) - except KeyError as e: - raise AttributeError(f"Class {self.class_name} does not have method {self.__egg_method_name__}") from e + try: + self.__egg_fn_decl__ = self.__egg_self__.__egg_decls__.get_function_decl(self.__egg_callable_ref__) + except KeyError as e: + msg = f"Class {self.class_name} does not have method {self.__egg_method_name__}" + if self.__egg_method_name__ == "__ne__": + msg += ". Did you mean to use the ne(...).to(...)?" + raise AttributeError(msg) from e def __call__(self, *args: object, **kwargs) -> RuntimeExpr | None: args = (self.__egg_self__, *args) diff --git a/python/tests/test_array_api.py b/python/tests/test_array_api.py index c87630ce..2af2d762 100644 --- a/python/tests/test_array_api.py +++ b/python/tests/test_array_api.py @@ -1,4 +1,3 @@ -# import pytest from egglog.exp.array_api import * from egglog.exp.array_api_numba import array_api_numba_module from egglog.exp.array_api_program_gen import * diff --git a/test-data/unit/check-high-level.test b/test-data/unit/check-high-level.test index 8f051263..c52e57ad 100644 --- a/test-data/unit/check-high-level.test +++ b/test-data/unit/check-high-level.test @@ -2,9 +2,9 @@ from egglog import * _ = i64(0) == i64(0) # E: Unsupported operand types for == ("i64" and "i64") -[case notEqAllowed] +[case notEqNotAllowed] from egglog import * -_ = i64(0) != i64(0) # type: Unit +_ = i64(0) != i64(0) # E: Unsupported operand types for != ("i64" and "i64") [case eqToAllowed] from egglog import *