Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move != to function instead of method #106

Merged
merged 2 commits into from
Jan 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@ _This project uses semantic versioning_

## UNRELEASED

## 5.0.0 (UNRELEASED)

- Move egglog `!=` function to be called with `ne(x).to(y)` instead of `x != y` so that user defined expressions
can

## 4.0.1 (2023-11-27)

- Fix keyword args for `__init__` methods (#96)[https://github.com/metadsl/egglog-python/pull/96].
Expand Down
6 changes: 3 additions & 3 deletions docs/reference/egglog-translation.md
Original file line number Diff line number Diff line change
Expand Up @@ -52,14 +52,14 @@ Rational(1, 2) / Rational(2, 1)

### `!=` Operator

The `!=` function in egglog works on any two types with the same sort. In Python, this is supported by overloading the `__ne__` operator, which is done by default in all builtin and custom sorts:
The `!=` function in egglog works on any two types with the same sort. In Python, this is mapped to the `ne` function:

```{code-cell} python
# egg: (!= 10 2)
i64(10) != i64(2)
ne(i64(10)).to(i64(2))
```

This is checked statically, based on the `__ne__` definition in `Expr`, so that only sorts that have the same sort can be compared.
This is a two part function so that we can statically check both sides are the same type.

## Declaring Sorts

Expand Down
3 changes: 2 additions & 1 deletion docs/reference/python-integration.md
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,7 @@ Most of the Python special dunder (= "double under") methods are supported as we
- `__le__`
- `__eq__`
- `__ne__`
- `__ne__`
- `__gt__`
- `__ge__`
- `__add__`
Expand All @@ -231,7 +232,7 @@ Most of the Python special dunder (= "double under") methods are supported as we
- `__setitem__`
- `__delitem__`

Currently `__divmod__` is not supported, since it returns multiple results and `__ne__` will shadow the builtin `!=` egglog operator.
Currently `__divmod__` is not supported, since it returns multiple results.

Also these methods are currently used in the runtime class and cannot be overridden currently, although we could change this
if the need arises:
Expand Down
5 changes: 1 addition & 4 deletions python/egglog/declarations.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ def get_egg_fn(self, ref: CallableRef) -> str:
return decls.get_egg_fn(ref)
except KeyError:
pass
raise KeyError(f"Callable ref {ref} not found")
raise KeyError(f"Callable ref {ref!r} not found")

def get_egg_sort(self, ref: JustTypeRef) -> str:
for decls in self.all_decls:
Expand Down Expand Up @@ -770,9 +770,6 @@ def pretty(self, context: PrettyContext, parens: bool = True, **kwargs) -> str:
if self in context.names:
return context.names[self]
ref, args = self.callable, [a.expr for a in self.args]
# Special case != since it doesn't have a decl
if isinstance(ref, MethodRef) and ref.method_name == "__ne__":
return f"{args[0].pretty(context)} != {args[1].pretty(context)}"
function_decl = context.mod_decls.get_function_decl(ref)
# Determine how many of the last arguments are defaults, by iterating from the end and comparing the arg with the default
n_defaults = 0
Expand Down
48 changes: 33 additions & 15 deletions python/egglog/egraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
"rewrite",
"birewrite",
"eq",
"ne",
"panic",
"let",
"delete",
Expand Down Expand Up @@ -258,8 +259,6 @@ def _class( # noqa: PLR0912
unextractable=unextractable,
)

# Register != as a method so we can print it as a string
self._mod_decls._decl.register_callable_ref(MethodRef(cls_name, "__ne__"), "!=")
return RuntimeClass(self._mod_decls, cls_name)

# We seperate the function and method overloads to make it simpler to know if we are modifying a function or method,
Expand Down Expand Up @@ -663,6 +662,8 @@ def __post_init__(self, modules: list[Module]) -> None:
msg = "Builtins already initialized"
raise RuntimeError(msg)
_BUILTIN_DECLS = self._mod_decls._decl
# Register != operator
_BUILTIN_DECLS.register_callable_ref(FunctionRef("!="), "!=")

def _process_commands(self, cmds: Iterable[bindings._Command]) -> None:
"""
Expand Down Expand Up @@ -1152,22 +1153,10 @@ class Expr(metaclass=_ExprMetaclass):
Expression base class, which adds suport for != to all expression types.
"""

def __ne__(self: EXPR, other_expr: EXPR) -> Unit: # type: ignore[override, empty-body]
"""
Compare whether to expressions are not equal.

:param self: The expression to compare.
:param other_expr: The other expression to compare to, which must be of the same type.
:meta public:
"""
def __ne__(self, other: NoReturn) -> NoReturn: # type: ignore[override, empty-body]
...

def __eq__(self, other: NoReturn) -> NoReturn: # type: ignore[override, empty-body]
"""
Equality is currently not supported.

We only add this method so that if you try to use it MyPy will warn you.
"""
...


Expand Down Expand Up @@ -1492,6 +1481,11 @@ def eq(expr: EXPR) -> _EqBuilder[EXPR]:
return _EqBuilder(expr)


def ne(expr: EXPR) -> _NeBuilder[EXPR]:
"""Check if the given expression is not equal to the given value."""
return _NeBuilder(expr)


def panic(message: str) -> Action:
"""Raise an error with the given message."""
return Panic(message)
Expand Down Expand Up @@ -1596,6 +1590,30 @@ def __str__(self) -> str:
return f"eq({self.expr})"


@dataclass
class _NeBuilder(Generic[EXPR]):
expr: EXPR

def to(self, expr: EXPR) -> Unit:
l_expr = cast(RuntimeExpr, self.expr)
return cast(
Unit,
RuntimeExpr(
BUILTINS._mod_decls,
TypedExprDecl(
JustTypeRef("Unit"),
CallDecl(
FunctionRef("!="),
(l_expr.__egg_typed_expr__, convert_to_same_type(expr, l_expr).__egg_typed_expr__),
),
),
),
)

def __str__(self) -> str:
return f"ne({self.expr})"


@dataclass
class _SetBuilder(Generic[EXPR]):
lhs: Expr
Expand Down
2 changes: 1 addition & 1 deletion python/egglog/examples/bool.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
egraph.check(eq(T & T).to(T))
egraph.check(eq(T & F).to(F))
egraph.check(eq(T | F).to(T))
egraph.check((T | F) != F)
egraph.check(ne(T | F).to(F))

egraph.check(eq(i64(1).bool_lt(2)).to(T))
egraph.check(eq(i64(2).bool_lt(1)).to(F))
Expand Down
8 changes: 4 additions & 4 deletions python/egglog/examples/lambda_.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def freer(t: Term) -> StringSet:
union(t.eval()).with_(Val(i1 + i2))
),
rule(eq(t).to(t1 == t2), eq(t1.eval()).to(t2.eval())).then(union(t.eval()).with_(Val.TRUE)),
rule(eq(t).to(t1 == t2), eq(t1.eval()).to(v1), eq(t2.eval()).to(v2), v1 != v2).then(
rule(eq(t).to(t1 == t2), eq(t1.eval()).to(v1), eq(t2.eval()).to(v2), ne(v1).to(v2)).then(
union(t.eval()).with_(Val.FALSE)
),
rule(eq(v).to(t.eval())).then(union(t).with_(Term.val(v))),
Expand Down Expand Up @@ -154,12 +154,12 @@ def freer(t: Term) -> StringSet:
# let-var-same
rewrite(let_(x, t, Term.var(x))).to(t),
# let-var-diff
rewrite(let_(x, t, Term.var(y))).to(Term.var(y), x != y),
rewrite(let_(x, t, Term.var(y))).to(Term.var(y), ne(x).to(y)),
# let-lam-same
rewrite(let_(x, t, lam(x, t1))).to(lam(x, t1)),
# let-lam-diff
rewrite(let_(x, t, lam(y, t1))).to(lam(y, let_(x, t, t1)), x != y, eq(fv).to(freer(t)), fv.not_contains(y)),
rule(eq(t).to(let_(x, t1, lam(y, t2))), x != y, eq(fv).to(freer(t1)), fv.contains(y)).then(
rewrite(let_(x, t, lam(y, t1))).to(lam(y, let_(x, t, t1)), ne(x).to(y), eq(fv).to(freer(t)), fv.not_contains(y)),
rule(eq(t).to(let_(x, t1, lam(y, t2))), ne(x).to(y), eq(fv).to(freer(t1)), fv.contains(y)).then(
union(t).with_(lam(t.v(), let_(x, t1, let_(y, Term.var(t.v()), t2))))
),
)
Expand Down
2 changes: 1 addition & 1 deletion python/egglog/examples/resolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def pred(x: i64Like) -> Boolean: # type: ignore[empty-body]
union(~p0 | (~p1 | (p2 | F))).with_(T),
)
egraph.run(10)
egraph.check(T != F)
egraph.check(ne(T).to(F))
egraph.check(eq(p0).to(F))
egraph.check(eq(p2).to(F))
egraph
6 changes: 3 additions & 3 deletions python/egglog/exp/array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,7 +278,7 @@ def __bool__(self) -> bool:
@array_api_module.register
def _int(i: i64, j: i64, r: Boolean, o: Int):
yield rewrite(Int(i) == Int(i)).to(TRUE)
yield rule(eq(r).to(Int(i) == Int(j)), i != j).then(union(r).with_(FALSE))
yield rule(eq(r).to(Int(i) == Int(j)), ne(i).to(j)).then(union(r).with_(FALSE))

yield rewrite(Int(i) >= Int(i)).to(TRUE)
yield rule(eq(r).to(Int(i) >= Int(j)), i > j).then(union(r).with_(TRUE))
Expand Down Expand Up @@ -666,7 +666,7 @@ def _tuple_value(
# Includes
rewrite(TupleValue.EMPTY.includes(v)).to(FALSE),
rewrite(TupleValue(v).includes(v)).to(TRUE),
rewrite(TupleValue(v).includes(v2)).to(FALSE, v != v2),
rewrite(TupleValue(v).includes(v2)).to(FALSE, ne(v).to(v2)),
rewrite((ti + ti2).includes(v)).to(ti.includes(v) | ti2.includes(v)),
]

Expand Down Expand Up @@ -1503,7 +1503,7 @@ def _assume_value_one_of(x: NDArray, v: Value, vs: TupleValue, idx: TupleInt):
def _ndarray_value_isfinite(arr: NDArray, x: Value, xs: TupleValue, i: Int, f: f64, b: Boolean):
yield rewrite(Value.int(i).isfinite()).to(TRUE)
yield rewrite(Value.bool(b).isfinite()).to(TRUE)
yield rewrite(Value.float(Float(f)).isfinite()).to(TRUE, f != f64(math.nan))
yield rewrite(Value.float(Float(f)).isfinite()).to(TRUE, ne(f).to(f64(math.nan)))

# a sum of an array is finite if all the values are finite
yield rewrite(isfinite(sum(arr))).to(NDArray.scalar(Value.bool(arr.index(ALL_INDICES).isfinite())))
Expand Down
3 changes: 2 additions & 1 deletion python/egglog/exp/array_api_program_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,8 @@ def _float_program(f: Float, g: Float, f64_: f64, i: Int, r: Rational):
yield rewrite(float_program(f * g)).to(Program("(") + float_program(f) + " * " + float_program(g) + ")")
yield rewrite(float_program(f / g)).to(Program("(") + float_program(f) + " / " + float_program(g) + ")")
yield rewrite(float_program(Float.rational(r))).to(
Program("float(") + Program(r.numer.to_string()) + " / " + Program(r.denom.to_string()) + ")", r.denom != i64(1)
Program("float(") + Program(r.numer.to_string()) + " / " + Program(r.denom.to_string()) + ")",
ne(r.denom).to(i64(1)),
)
yield rewrite(float_program(Float.rational(r))).to(
Program("float(") + Program(r.numer.to_string()) + ")", eq(r.denom).to(i64(1))
Expand Down
18 changes: 9 additions & 9 deletions python/egglog/exp/program_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ def _compile(
yield rule(
stmt,
p.compile(i),
p1.parent != p,
ne(p1.parent).to(p),
eq(s1).to(p1.expr),
).then(
set_(p.statements).to(join(s1, "\n")),
Expand Down Expand Up @@ -214,7 +214,7 @@ def _compile(
# Otherwise, if its not equal to either input, its not an identifier
yield rule(program_add, eq(p.expr).to(p1.expr), eq(b).to(p1.is_identifer)).then(set_(p.is_identifer).to(b))
yield rule(program_add, eq(p.expr).to(p2.expr), eq(b).to(p2.is_identifer)).then(set_(p.is_identifer).to(b))
yield rule(program_add, p.expr != p1.expr, p.expr != p2.expr).then(set_(p.is_identifer).to(Bool(False)))
yield rule(program_add, ne(p.expr).to(p1.expr), ne(p.expr).to(p2.expr)).then(set_(p.is_identifer).to(Bool(False)))

# Set parent of p1
yield rule(program_add, p.compile(i)).then(
Expand All @@ -228,7 +228,7 @@ def _compile(
yield rule(program_add, p.compile(i), p1.next_sym).then(set_(p2.parent).to(p))

# Compile p2, if p1 parent not equal, but p2 parent equal
yield rule(program_add, p.compile(i), p1.parent != p, eq(p2.parent).to(p)).then(p2.compile(i))
yield rule(program_add, p.compile(i), ne(p1.parent).to(p), eq(p2.parent).to(p)).then(p2.compile(i))

# Compile p2, if p1 parent eqal
yield rule(program_add, p.compile(i2), eq(p1.parent).to(p), eq(i).to(p1.next_sym), eq(p2.parent).to(p)).then(
Expand Down Expand Up @@ -259,8 +259,8 @@ def _compile(
yield rule(
program_add,
p.compile(i),
p1.parent != p,
p2.parent != p,
ne(p1.parent).to(p),
ne(p2.parent).to(p),
).then(
set_(p.statements).to(String("")),
set_(p.next_sym).to(i),
Expand All @@ -269,7 +269,7 @@ def _compile(
yield rule(
program_add,
eq(p1.parent).to(p),
p2.parent != p,
ne(p2.parent).to(p),
eq(s1).to(p1.statements),
eq(i).to(p1.next_sym),
).then(
Expand All @@ -280,7 +280,7 @@ def _compile(
yield rule(
program_add,
eq(p2.parent).to(p),
p1.parent != p,
ne(p1.parent).to(p),
eq(s2).to(p2.statements),
eq(i).to(p2.next_sym),
).then(
Expand Down Expand Up @@ -319,7 +319,7 @@ def _compile(
# 1. b. If p1 parent is not p, then just use assign as statement, next sym of i
yield rule(
program_assign,
p1.parent != p,
ne(p1.parent).to(p),
p.compile(i),
eq(s2).to(p1.expr),
eq(p1.is_identifer).to(Bool(False)),
Expand Down Expand Up @@ -347,7 +347,7 @@ def _compile(
# 1. b. If p1 parent is not p, then just use assign as statement, next sym of i
yield rule(
program_assign,
p1.parent != p,
ne(p1.parent).to(p),
p.compile(i),
eq(s2).to(p1.expr),
eq(p1.is_identifer).to(Bool(True)),
Expand Down
16 changes: 7 additions & 9 deletions python/egglog/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,15 +435,13 @@ def __post_init__(self) -> None:
self.__egg_callable_ref__ = PropertyRef(self.class_name, self.__egg_method_name__)
else:
self.__egg_callable_ref__ = MethodRef(self.class_name, self.__egg_method_name__)
# Special case for __ne__ which does not have a normal function defintion since
# it relies of type parameters
if self.__egg_method_name__ == "__ne__":
self.__egg_fn_decl__ = None
else:
try:
self.__egg_fn_decl__ = self.__egg_self__.__egg_decls__.get_function_decl(self.__egg_callable_ref__)
except KeyError as e:
raise AttributeError(f"Class {self.class_name} does not have method {self.__egg_method_name__}") from e
try:
self.__egg_fn_decl__ = self.__egg_self__.__egg_decls__.get_function_decl(self.__egg_callable_ref__)
except KeyError as e:
msg = f"Class {self.class_name} does not have method {self.__egg_method_name__}"
if self.__egg_method_name__ == "__ne__":
msg += ". Did you mean to use the ne(...).to(...)?"
raise AttributeError(msg) from e

def __call__(self, *args: object, **kwargs) -> RuntimeExpr | None:
args = (self.__egg_self__, *args)
Expand Down
1 change: 0 additions & 1 deletion python/tests/test_array_api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# import pytest
from egglog.exp.array_api import *
from egglog.exp.array_api_numba import array_api_numba_module
from egglog.exp.array_api_program_gen import *
Expand Down
26 changes: 25 additions & 1 deletion python/tests/test_high_level.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,9 +359,33 @@ def test_f64_negation() -> None:

def test_not_equals():
egraph = EGraph()
egraph.check(i64(10) != i64(2))
egraph.check(ne(i64(10)).to(i64(2)))


def test_custom_equality():
egraph = EGraph()

@egraph.class_
class Boolean(Expr):
def __init__(self, value: BoolLike) -> None:
...

def __eq__(self, other: Boolean) -> Boolean: # type: ignore[override]
...

def __ne__(self, other: Boolean) -> Boolean: # type: ignore[override]
...

egraph.register(rewrite(Boolean(True) == Boolean(True)).to(Boolean(False)))
egraph.register(rewrite(Boolean(True) != Boolean(True)).to(Boolean(True)))

should_be_true = Boolean(True) == Boolean(True)
should_be_false = Boolean(True) != Boolean(True)
egraph.register(should_be_true, should_be_false)
egraph.run(10)
egraph.check(eq(should_be_true).to(Boolean(False)))
egraph.check(eq(should_be_false).to(Boolean(True)))

class TestMutate:
def test_setitem_defaults(self):
egraph = EGraph()
Expand Down
4 changes: 2 additions & 2 deletions test-data/unit/check-high-level.test
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from egglog import *
_ = i64(0) == i64(0) # E: Unsupported operand types for == ("i64" and "i64")

[case notEqAllowed]
[case notEqNotAllowed]
from egglog import *
_ = i64(0) != i64(0) # type: Unit
_ = i64(0) != i64(0) # E: Unsupported operand types for != ("i64" and "i64")

[case eqToAllowed]
from egglog import *
Expand Down
Loading