From 137ccee6a41fb046b9a079527b6bf732928b42fe Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Thu, 14 Sep 2023 01:28:21 -0400 Subject: [PATCH 01/23] Other failed experiment --- python/tests/test_high_level.py | 97 ++++++++++++++++++++++++++++++++- 1 file changed, 95 insertions(+), 2 deletions(-) diff --git a/python/tests/test_high_level.py b/python/tests/test_high_level.py index 19095e15..624474db 100644 --- a/python/tests/test_high_level.py +++ b/python/tests/test_high_level.py @@ -8,6 +8,7 @@ import pytest from egglog import * +from egglog.builtins import count_matches from egglog.declarations import ( CallDecl, FunctionRef, @@ -434,7 +435,7 @@ def __radd__(self, other: Math) -> Math: @pytest.mark.xfail(reason="https://github.com/egraphs-good/egglog/issues/229") def test_imperative(): - egraph = EGraph(seminaive=False) + egraph = EGraph() @egraph.function(merge=lambda old, new: join(old, new), default=String("")) def statements() -> String: @@ -505,6 +506,98 @@ def _rules(s: String, y_expr: String, z_expr: String, x: Math, i: i64, y: Math, y = egraph.let("y", Math(2) * (Math.var("x") + Math(3))) - egraph.run(3) + egraph.run(10) + egraph.check(eq(y.expr).to(String("_1"))) + egraph.check(eq(statements()).to(String("_0 = x + 3\n_1 = 2 * _0\n"))) + + +@pytest.mark.xfail(reason="applies rules too many times b/c keeps matching") +def test_imperative_stable(): + # More stable version of imperative, which uses idempotent merge function + egraph = EGraph() + + @egraph.function(merge=lambda old, new: new) + def statements() -> String: + ... + + egraph.register(set_(statements()).to(String(""))) + + @egraph.function(merge=lambda old, new: old + new, default=i64(0)) + def gensym() -> i64: + ... + + @egraph.class_ + class Math(Expr): + @egraph.method(egg_fn="Num") + def __init__(self, value: i64Like) -> None: + ... + + @egraph.method(egg_fn="Var") + @classmethod + def var(cls, v: StringLike) -> Math: + ... 
+ + @egraph.method(egg_fn="Add") + def __add__(self, other: Math) -> Math: + ... + + @egraph.method(egg_fn="Mul") + def __mul__(self, other: Math) -> Math: + ... + + @egraph.method(egg_fn="expr") # type: ignore[misc] + @property + def expr(self) -> String: + ... + + @egraph.register + def _rules( + s: String, + y_expr: String, + z_expr: String, + old_statements: String, + x: Math, + i: i64, + y: Math, + z: Math, + old_gensym: i64, + ): + gensym_var = join("_", gensym().to_string()) + yield rule( + eq(x).to(Math.var(s)), + ).then( + set_(x.expr).to(s), + ) + + yield rule( + eq(x).to(Math(i)), + ).then( + set_(x.expr).to(i.to_string()), + ) + + yield rule( + eq(x).to(y + z), + eq(y_expr).to(y.expr), + eq(z_expr).to(z.expr), + eq(old_statements).to(statements()), + ).then( + set_(x.expr).to(gensym_var), + set_(statements()).to(join(old_statements, gensym_var, " = ", y_expr, " + ", z_expr, "\n")), + set_(gensym()).to(i64(1)), + ) + yield rule( + eq(x).to(y * z), + eq(y_expr).to(y.expr), + eq(z_expr).to(z.expr), + eq(old_statements).to(statements()), + ).then( + set_(x.expr).to(gensym_var), + set_(statements()).to(join(old_statements, gensym_var, " = ", y_expr, " * ", z_expr, "\n")), + set_(gensym()).to(i64(1)), + ) + + y = egraph.let("y", Math(2) * (Math.var("x") + Math(3))) + + egraph.run(10) egraph.check(eq(y.expr).to(String("_1"))) egraph.check(eq(statements()).to(String("_0 = x + 3\n_1 = 2 * _0\n"))) From 37b31fdf2c672fdcfd0b9f649e676bd72cc84e8b Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Wed, 20 Sep 2023 15:50:12 -0400 Subject: [PATCH 02/23] Tmp --- docs/changelog.md | 2 ++ docs/reference/egglog-translation.md | 15 -------- docs/reference/python-integration.md | 52 ++++++++++++++++++++++++++-- python/egglog/builtins.py | 9 +++-- python/egglog/egraph.py | 44 +++++++++++++++++------ 5 files changed, 92 insertions(+), 30 deletions(-) diff --git a/docs/changelog.md b/docs/changelog.md index 6df7aa9d..5ee3364e 100644 --- a/docs/changelog.md +++ 
b/docs/changelog.md @@ -16,6 +16,8 @@ _This project uses semantic versioning. Before 1.0.0, this means that every brea - Add ability to inline leaves $n$ times instead of just once for visualization [#48](https://github.com/metadsl/egglog-python/pull/48) - Add `Relation` and `PrintOverallStatistics` low level commands [#46](https://github.com/metadsl/egglog-python/pull/46) - Adds `count-matches` and `replace` string commands [#46](https://github.com/metadsl/egglog-python/pull/46) +- Adds ability for custom user defined types in a union for proper static typing with conversions +- Adds `py_eval` function to `EGraph` as a helper to eval Python code. ### Bug fixes diff --git a/docs/reference/egglog-translation.md b/docs/reference/egglog-translation.md index 25877037..6a2c73a3 100644 --- a/docs/reference/egglog-translation.md +++ b/docs/reference/egglog-translation.md @@ -245,21 +245,6 @@ As shown above, we can also use the `@classmethod` and `@property` decorators to Note that reflected methods (i.e. `__radd__`) are handled as a special case. If defined, they won't create their own egglog functions. Instead, whenever a reflected method is called, we will try to find the corresponding non-reflected method and call that instead. -#### Custom Type Promotion - -Similar to how an `int` can be automatically upcasted to an `i64`, we also support registering conversion to your custom types. For example: - -```{code-cell} python -converter(int, Math, Math) -converter(str, Math, Math.var) - -Math(2) + 30 + "x" -# equal to -Math(2) + Math(i64(30)) + Math.var(String("x")) -``` - -Regstering a conversion from A to B will also register all transitively reachable conversions from A to B. - ### Declarations In egglog, the `(declare ...)` command is syntactic sugar for a nullary function. 
In Python, these can be declare either as class variables or with the toplevel `egraph.constant` function: diff --git a/docs/reference/python-integration.md b/docs/reference/python-integration.md index 3fa8cd59..6a4fda09 100644 --- a/docs/reference/python-integration.md +++ b/docs/reference/python-integration.md @@ -87,11 +87,59 @@ locals_expr = egraph.save_object(locals()) globals_expr = egraph.save_object(globals()) # Need `one` to map to the expression for `1` not the Python object of the expression amended_globals = globals_expr.dict_update(PyObject.from_string("one"), one) -evalled = py_eval("one + 2", locals_expr, amended_globals) +evalled = py_eval("my_add(one, 2)", locals_expr, amended_globals) assert egraph.load_object(egraph.extract(evalled)) == 3 ``` -This is a bit subtle at the moment, and we plan on adding an easier wrapper to eval arbitrary Python code in the future. +### Simpler Eval + +Instead of using the above low level primitive for evaling, there is a higher level wrapper function, `egraph.eval_fn`. + +It takes in a Python function and converts it to a function of PyObjects, by using `py_eval` +under the hood. + +The above code could be re-written like this: + +```{code-cell} python +def my_add(a, b): + return a + b + +evalled = egraph.eval_fn(lambda a: my_add(a, 2))(one) +assert egraph.load_object(egraph.extract(evalled)) == 3 +``` + +#### Custom Type Promotion + +Similar to how an `int` can be automatically upcasted to an `i64`, we also support registering conversion to your custom types. 
For example: + +```{code-cell} python +converter(i64, Math, Math) +converter(String, Math, Math.var) + +Math(2) + i64(30) + String("x") +# equal to +Math(2) + Math(i64(30)) + Math.var(String("x")) +``` + +Registering a conversion from A to B will also register all transitively reachable conversions from A to B, so you can also use: + +```{code-cell} python +Math(2) + 30 + "x" +``` + +If you want to have this work with the static type checker, you can define your own `Union` type, which MUST +have the Egglog class as the first item in the union. For example, in this case you could then define: + +```{code-cell} python +from typing import Union +MathLike = Union[Math, i64Like, StringLike] + +@egraph.function +def some_math_fn(x: MathLike) -> MathLike: + ... + +some_math_fn(10) +``` ## "Preserved" methods diff --git a/python/egglog/builtins.py b/python/egglog/builtins.py index 39d3765f..d64aadbe 100644 --- a/python/egglog/builtins.py +++ b/python/egglog/builtins.py @@ -27,7 +27,7 @@ ] -StringLike = Union[str, "String"] +StringLike = Union["String", str] @BUILTINS.class_ @@ -48,7 +48,7 @@ def join(*strings: StringLike) -> String: # type: ignore[empty-body] converter(str, String, String) # The types which can be convertered into an i64 -i64Like = Union[int, "i64"] +i64Like = Union["i64", int] @BUILTINS.class_(egg_sort="i64") @@ -159,7 +159,7 @@ def count_matches(s: StringLike, pattern: StringLike) -> i64: # type: ignore[em ... -f64Like = Union[float, "f64"] +f64Like = Union["f64", float] @BUILTINS.class_(egg_sort="f64") @@ -318,6 +318,9 @@ def __sub__(self, other: Set[T]) -> Set[T]: # type: ignore[empty-body] def __and__(self, other: Set[T]) -> Set[T]: # type: ignore[empty-body] ... + # def peek(self) -> T: + # ... 
+ @BUILTINS.class_(egg_sort="Rational") class Rational(Expr): diff --git a/python/egglog/egraph.py b/python/egglog/egraph.py index 5c1d59db..16626584 100644 --- a/python/egglog/egraph.py +++ b/python/egglog/egraph.py @@ -18,6 +18,7 @@ Literal, NoReturn, Optional, + Protocol, TypeVar, Union, cast, @@ -93,6 +94,11 @@ ALWAYS_MUTATES_SELF = {"__setitem__", "__delitem__"} +class PyObjectFunction(Protocol): + def __call__(self, *__args: PyObject) -> PyObject: + ... + + @dataclass class _BaseModule(ABC): """ @@ -490,17 +496,10 @@ def _resolve_type_annotation( ) -> TypeOrVarRef: if isinstance(tp, TypeVar): return ClassTypeVarRef(cls_typevars.index(tp)) - # If there is a union, it should be of a literal and another type to allow type promotion + # If there is a union, then we assume the first item is the type we want, and the others are types that can be converted to that type. if get_origin(tp) == Union: - args = get_args(tp) - if len(args) != 2: - raise TypeError("Union types are only supported for type promotion") - fst, snd = args - if fst in {int, str, float}: - return self._resolve_type_annotation(snd, cls_typevars, cls_type_and_name) - if snd in {int, str, float}: - return self._resolve_type_annotation(fst, cls_typevars, cls_type_and_name) - raise TypeError("Union types are only supported for type promotion") + first, *_rest = get_args(tp) + return self._resolve_type_annotation(first, cls_typevars, cls_type_and_name) # If this is the type for the class, use the class name if cls_type_and_name and tp == cls_type_and_name[0]: @@ -879,6 +878,31 @@ def load_object(self, obj: PyObject) -> object: expr = typed_expr_decl.to_egg(self._mod_decls) return self._egraph.load_object(expr) + def eval_fn(self, fn: Callable) -> PyObjectFunction: + """ + Takes a python callable and maps it to a callable which takes + and returns PyObjects. + + It translates it to a call which uses `py_eval` to call the function, passing in the + args as locals, and using the globals from function. 
+ """ + from .builtins import py_eval + + fn_globals = self.save_object(fn.__globals__) + fn_locals = self.save_object({"__fn": fn}) + + def inner(*__args: PyObject, __fn_locals=fn_locals) -> PyObject: + new_kvs: list[PyObject] = [] + eval_str = "__fn(" + for i, arg in enumerate(__args): + new_kvs.append(self.save_object(f"__arg_{i}")) + new_kvs.append(arg) + eval_str += f"__arg_{i}, " + eval_str += ")" + return py_eval(eval_str, fn_locals.dict_update(new_kvs), fn_globals) + + return inner + @classmethod def current(cls) -> EGraph: """ From 6d767b34a5f170f7d85e3f925229204077ab1fd7 Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Wed, 20 Sep 2023 15:55:01 -0400 Subject: [PATCH 03/23] Fixes --- python/egglog/egraph.py | 4 ++-- python/egglog/examples/lambda.py | 1 - python/egglog/examples/ndarrays.py | 1 - python/egglog/exp/array_api.py | 1 - python/tests/test_high_level.py | 2 -- 5 files changed, 2 insertions(+), 7 deletions(-) diff --git a/python/egglog/egraph.py b/python/egglog/egraph.py index 16626584..fb32f040 100644 --- a/python/egglog/egraph.py +++ b/python/egglog/egraph.py @@ -677,7 +677,7 @@ class EGraph(_BaseModule): _decl_stack: list[Declarations] = field(default_factory=list, repr=False) _token: Optional[Token[EGraph]] = None - def __post_init__(self, modules: list[Module], seminaive) -> None: + def __post_init__(self, modules: list[Module], seminaive: bool) -> None: # type: ignore super().__post_init__(modules) self._egraph = bindings.EGraph(seminaive=seminaive) for m in self._flatted_deps: @@ -899,7 +899,7 @@ def inner(*__args: PyObject, __fn_locals=fn_locals) -> PyObject: new_kvs.append(arg) eval_str += f"__arg_{i}, " eval_str += ")" - return py_eval(eval_str, fn_locals.dict_update(new_kvs), fn_globals) + return py_eval(eval_str, fn_locals.dict_update(*new_kvs), fn_globals) return inner diff --git a/python/egglog/examples/lambda.py b/python/egglog/examples/lambda.py index a5888b12..fd671517 100644 --- a/python/egglog/examples/lambda.py +++ 
b/python/egglog/examples/lambda.py @@ -2,7 +2,6 @@ Lambda Calculus =============== """ -# mypy: disable-error-code=empty-body from __future__ import annotations from typing import Callable, ClassVar diff --git a/python/egglog/examples/ndarrays.py b/python/egglog/examples/ndarrays.py index e1a1c610..599611fd 100644 --- a/python/egglog/examples/ndarrays.py +++ b/python/egglog/examples/ndarrays.py @@ -4,7 +4,6 @@ Example of building NDarray in the vein of Mathemetics of Arrays. """ -# mypy: disable-error-code=empty-body from __future__ import annotations from egglog import * diff --git a/python/egglog/exp/array_api.py b/python/egglog/exp/array_api.py index f5c5acec..9876a66d 100644 --- a/python/egglog/exp/array_api.py +++ b/python/egglog/exp/array_api.py @@ -1,4 +1,3 @@ -# mypy: disable-error-code=empty-body from __future__ import annotations import itertools diff --git a/python/tests/test_high_level.py b/python/tests/test_high_level.py index 624474db..1cea7ca5 100644 --- a/python/tests/test_high_level.py +++ b/python/tests/test_high_level.py @@ -1,4 +1,3 @@ -# mypy: disable-error-code="empty-body" from __future__ import annotations import importlib @@ -8,7 +7,6 @@ import pytest from egglog import * -from egglog.builtins import count_matches from egglog.declarations import ( CallDecl, FunctionRef, From ae4cd198652bd3d88853582e299429f1589da4dc Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Wed, 27 Sep 2023 15:24:08 -0400 Subject: [PATCH 04/23] First sort of working version but much too slow --- pyproject.toml | 2 +- python/egglog/egraph.py | 27 +++-- python/egglog/examples/lambda.py | 2 + python/egglog/examples/ndarrays.py | 2 + python/egglog/exp/array_api.py | 2 + python/egglog/exp/program_gen.py | 156 +++++++++++++++++++++++++++++ python/tests/test_high_level.py | 93 +++++++++++++++++ python/tests/test_program_gen.py | 67 +++++++++++++ 8 files changed, 342 insertions(+), 9 deletions(-) create mode 100644 python/egglog/exp/program_gen.py create mode 100644 
python/tests/test_program_gen.py diff --git a/pyproject.toml b/pyproject.toml index badc204b..a4510256 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,7 +60,7 @@ check_untyped_defs = true strict_equality = true warn_unused_configs = true allow_redefinition = true -enable_incomplete_feature = ["Unpack", "TypeVarTuple"] +# enable_incomplete_feature = ["Unpack", "TypeVarTuple"] exclude = ["__snapshots__", "_build", "^conftest.py$"] [tool.maturin] diff --git a/python/egglog/egraph.py b/python/egglog/egraph.py index fb32f040..825a4fd6 100644 --- a/python/egglog/egraph.py +++ b/python/egglog/egraph.py @@ -7,8 +7,7 @@ from dataclasses import InitVar, dataclass, field from inspect import Parameter, currentframe, signature from types import FunctionType -from typing import _GenericAlias # type: ignore[attr-defined] -from typing import ( +from typing import ( # type: ignore[attr-defined] TYPE_CHECKING, Any, Callable, @@ -19,8 +18,10 @@ NoReturn, Optional, Protocol, + TypedDict, TypeVar, Union, + _GenericAlias, cast, get_type_hints, overload, @@ -28,7 +29,7 @@ import graphviz from egglog.declarations import REFLECTED_BINARY_METHODS, Declarations -from typing_extensions import ParamSpec, get_args, get_origin +from typing_extensions import ParamSpec, Unpack, get_args, get_origin from . 
import bindings from .declarations import * @@ -664,6 +665,12 @@ def save_object(self, obj: object) -> PyObject: return cast("PyObject", RuntimeExpr(self._mod_decls, typed_expr_decl)) +class GraphvizKwargs(TypedDict, total=False): + max_functions: Optional[int] + max_calls_per_function: Optional[int] + n_inline_leaves: int + + @dataclass class EGraph(_BaseModule): """ @@ -695,7 +702,7 @@ def _repr_mimebundle_(self, *args, **kwargs): return {"image/svg+xml": self.graphviz().pipe(format="svg", quiet=True, encoding="utf-8")} - def graphviz(self, **kwargs) -> graphviz.Source: + def graphviz(self, **kwargs: Unpack[GraphvizKwargs]) -> graphviz.Source: return graphviz.Source(self._egraph.to_graphviz_string(**kwargs)) def _repr_html_(self) -> str: @@ -707,13 +714,17 @@ def _repr_html_(self) -> str: """ return self.graphviz().pipe(format="svg", quiet=True).decode() - def display(self, **kwargs): + def display(self, **kwargs: Unpack[GraphvizKwargs]): """ Displays the e-graph in the notebook. """ - from IPython.display import SVG, display + graphviz = self.graphviz(**kwargs) + if hasattr(__builtins__, "__IPYTHON__"): + from IPython.display import SVG, display - display(SVG(self.graphviz(**kwargs).pipe(format="svg", quiet=True, encoding="utf-8"))) + display(SVG(graphviz.pipe(format="svg", quiet=True, encoding="utf-8"))) + else: + graphviz.view() @overload def simplify(self, expr: EXPR, limit: int, /, *until: Fact, ruleset: Optional[Ruleset] = None) -> EXPR: @@ -1127,7 +1138,7 @@ def __str__(self) -> str: def _to_egg_action(self, mod_decls: ModuleDeclarations) -> bindings.Set: egg_call = self._call.__egg_typed_expr__.expr.to_egg(mod_decls) if not isinstance(egg_call, bindings.Call): - raise ValueError(f"Can only create a call with a call for the lhs, got {self._call}") + raise ValueError(f"Can only create a set with a call for the lhs, got {self._call}") return bindings.Set( egg_call.name, egg_call.args, diff --git a/python/egglog/examples/lambda.py 
b/python/egglog/examples/lambda.py index fd671517..a89e87ea 100644 --- a/python/egglog/examples/lambda.py +++ b/python/egglog/examples/lambda.py @@ -1,3 +1,5 @@ +# mypy: disable-error-code="empty-body" + """ Lambda Calculus =============== diff --git a/python/egglog/examples/ndarrays.py b/python/egglog/examples/ndarrays.py index 599611fd..4f5b9780 100644 --- a/python/egglog/examples/ndarrays.py +++ b/python/egglog/examples/ndarrays.py @@ -1,3 +1,5 @@ +# mypy: disable-error-code="empty-body" + """ N-Dimensional Arrays ==================== diff --git a/python/egglog/exp/array_api.py b/python/egglog/exp/array_api.py index 9876a66d..286dbc8f 100644 --- a/python/egglog/exp/array_api.py +++ b/python/egglog/exp/array_api.py @@ -1,3 +1,5 @@ +# mypy: disable-error-code="empty-body" + from __future__ import annotations import itertools diff --git a/python/egglog/exp/program_gen.py b/python/egglog/exp/program_gen.py new file mode 100644 index 00000000..d9ba6558 --- /dev/null +++ b/python/egglog/exp/program_gen.py @@ -0,0 +1,156 @@ +# mypy: disable-error-code="empty-body" +""" +Builds up imperative string expressions from a functional expression. +""" +from __future__ import annotations + +from typing import Union + +from egglog import * + +program_gen_module = Module() + +ProgramLike = Union["Program", StringLike] + + +@program_gen_module.class_ +class Program(Expr): + """ + Semanticallly represents an expression with a number of ordered statements that it depends on to run. + + The expression and statements are all represented as strings. + """ + + def __init__(self, expr: StringLike) -> None: + """ + Create a program based on a string expression. + """ + ... + + def __add__(self, other: ProgramLike) -> Program: + """ + Concats the strings of the two expressions and also the statements. + """ + ... + + def statement(self, statement: ProgramLike) -> Program: + """ + Uses the expression of the statement and adds that as a statement to the program. + """ + ... 
+ + def assign(self) -> Program: + """ + Returns a new program with the expression assigned to a gensym. + """ + ... + + +converter(String, Program, Program) + + +@program_gen_module.class_ +class Compiler(Expr): + def __init__( + self, + # The mapping from programs to their compiled of the expressions + compiled_programs: Map[Program, Program] = Map[Program, Program].empty(), + # The next gensym counter + sym_counter: i64Like = i64(0), + # The cumulative list of statements seperated by newlines, all stored as the expression of a program + compiled_statements: Program = Program(""), + # The compiled expression from the last `compile` call + compiled_expr: Program = Program(""), + ) -> None: + ... + + def compile(self, program: ProgramLike) -> Compiler: + ... + + @property + def expr(self) -> Program: + ... + + def set_expr(self, expr: Program) -> Compiler: + ... + + def add_statement(self, statements: Program) -> Compiler: + ... + + @property + def string(self) -> String: + ... + + def added_sym(self) -> Compiler: + ... + + @property + def next_sym(self) -> Program: + ... 
+ + +@program_gen_module.register +def _compile( + s: String, + s1: String, + s2: String, + p: Program, + p1: Program, + p2: Program, + c: Compiler, + statements: Program, + expr: Program, + i: i64, + m: Map[Program, Program], +): + # Combining two strings is just joining them + yield rewrite(Program(s1) + Program(s2)).to(Program(join(s1, s2))) + + compiler = Compiler(m, i, statements, expr) + # Compiling a program that is already in the compiled programs is a no-op, but the expression is updated + yield rewrite(compiler.compile(p)).to(compiler.set_expr(m[p]), m.contains(p)) + # Compiling a string just gives that string + program_expr = Program(s) + yield rewrite(compiler.compile(program_expr)).to( + Compiler(m.insert(program_expr, program_expr), i, statements, program_expr), m.not_contains(program_expr) + ) + # Compiling a statement means that we should use the expression of the statement as a statement and use the expression + # of the underlying program + program_statement = p.statement(p1) + p_compiled = compiler.compile(p) + p1_compiled = p_compiled.compile(p1) + yield rewrite(compiler.compile(program_statement)).to( + p1_compiled.add_statement(p1_compiled.expr).set_expr(p_compiled.expr), m.not_contains(program_statement) + ) + + # Compiling an addition is the same as compiling one, then the other, then setting the expression as the addition + # of the two + program_add = p1 + p2 + p1_compiled = compiler.compile(p1) + p2_compiled = p1_compiled.compile(p2) + yield rewrite(compiler.compile(program_add)).to( + p2_compiled.set_expr(p1_compiled.expr + p2_compiled.expr), m.not_contains(program_add) + ) + + # Compiling an assign is the same as compiling the expression, adding an assign statement, then setting the + # expression as the gensym + program_assign = p.assign() + p_compiled = compiler.compile(p) + yield rewrite(compiler.compile(program_assign)).to( + p_compiled.add_statement(p_compiled.next_sym + " = " + p_compiled.expr) + .set_expr(p_compiled.next_sym) + 
.added_sym(), + m.not_contains(program_assign), + ) + + yield rewrite(compiler.set_expr(p)).to(Compiler(m, i, statements, p)) + yield rewrite(compiler.add_statement(p)).to(Compiler(m, i, statements + "\n" + p, expr), m.not_contains(p)) + yield rewrite(compiler.add_statement(p)).to(compiler, m.contains(p)) + yield rewrite(compiler.expr).to(expr) + yield rewrite(compiler.added_sym()).to(Compiler(m, i + 1, statements, expr)) + yield rewrite(compiler.next_sym).to(Program(join("_", i.to_string()))) + + # Set `to_string` to the compiled statements added to the compiled expression + yield rule( + eq(c).to(Compiler(m, i, Program(s1), Program(s2))), + ).then(set_(c.string).to(join(s1, "\n", s2, "\n"))) diff --git a/python/tests/test_high_level.py b/python/tests/test_high_level.py index 1cea7ca5..41b11ba3 100644 --- a/python/tests/test_high_level.py +++ b/python/tests/test_high_level.py @@ -1,3 +1,4 @@ +# mypy: disable-error-code="empty-body" from __future__ import annotations import importlib @@ -599,3 +600,95 @@ def _rules( egraph.run(10) egraph.check(eq(y.expr).to(String("_1"))) egraph.check(eq(statements()).to(String("_0 = x + 3\n_1 = 2 * _0\n"))) + + +def test_imperative_python(): + # Tries implementing the same functionality but with a PyObject + # More stable version of imperative, which uses idempotent merge function + egraph = EGraph() + + @egraph.function(merge=lambda old, new: new) + def statements() -> String: + ... + + egraph.register(set_(statements()).to(String(""))) + + @egraph.function(merge=lambda old, new: old + new, default=i64(0)) + def gensym() -> i64: + ... + + @egraph.class_ + class Math(Expr): + @egraph.method(egg_fn="Num") + def __init__(self, value: i64Like) -> None: + ... + + @egraph.method(egg_fn="Var") + @classmethod + def var(cls, v: StringLike) -> Math: + ... + + @egraph.method(egg_fn="Add") + def __add__(self, other: Math) -> Math: + ... + + @egraph.method(egg_fn="Mul") + def __mul__(self, other: Math) -> Math: + ... 
+ + @egraph.method(egg_fn="expr") # type: ignore[misc] + @property + def expr(self) -> String: + ... + + @egraph.register + def _rules( + s: String, + y_expr: String, + z_expr: String, + old_statements: String, + x: Math, + i: i64, + y: Math, + z: Math, + old_gensym: i64, + ): + gensym_var = join("_", gensym().to_string()) + yield rule( + eq(x).to(Math.var(s)), + ).then( + set_(x.expr).to(s), + ) + + yield rule( + eq(x).to(Math(i)), + ).then( + set_(x.expr).to(i.to_string()), + ) + + yield rule( + eq(x).to(y + z), + eq(y_expr).to(y.expr), + eq(z_expr).to(z.expr), + eq(old_statements).to(statements()), + ).then( + set_(x.expr).to(gensym_var), + set_(statements()).to(join(old_statements, gensym_var, " = ", y_expr, " + ", z_expr, "\n")), + set_(gensym()).to(i64(1)), + ) + yield rule( + eq(x).to(y * z), + eq(y_expr).to(y.expr), + eq(z_expr).to(z.expr), + eq(old_statements).to(statements()), + ).then( + set_(x.expr).to(gensym_var), + set_(statements()).to(join(old_statements, gensym_var, " = ", y_expr, " * ", z_expr, "\n")), + set_(gensym()).to(i64(1)), + ) + + y = egraph.let("y", Math(2) * (Math.var("x") + Math(3))) + + egraph.run(10) + egraph.check(eq(y.expr).to(String("_1"))) + egraph.check(eq(statements()).to(String("_0 = x + 3\n_1 = 2 * _0\n"))) diff --git a/python/tests/test_program_gen.py b/python/tests/test_program_gen.py new file mode 100644 index 00000000..f406a0ed --- /dev/null +++ b/python/tests/test_program_gen.py @@ -0,0 +1,67 @@ +# mypy: disable-error-code="empty-body" +from __future__ import annotations + +from egglog import * +from egglog.exp.program_gen import * + + +def test_to_string(snapshot_py) -> None: + egraph = EGraph([program_gen_module]) + + @egraph.class_ + class Math(Expr): + def __init__(self, value: i64Like) -> None: + ... + + @classmethod + def var(cls, v: StringLike) -> Math: + ... + + def __add__(self, other: Math) -> Math: + ... + + def __mul__(self, other: Math) -> Math: + ... + + def __neg__(self) -> Math: + ... 
+ + @property + def program(self) -> Program: + ... + + @egraph.function + def assume_pos(x: Math) -> Math: + ... + + @egraph.register + def _rules( + s: String, + y_expr: String, + z_expr: String, + old_statements: String, + x: Math, + i: i64, + y: Math, + z: Math, + old_gensym: i64, + ): + yield rewrite(Math.var(s).program).to(Program(s)) + yield rewrite(Math(i).program).to(Program(i.to_string())) + yield rewrite((y + z).program).to((y.program + " + " + z.program).assign()) + yield rewrite((y * z).program).to((y.program + " * " + z.program).assign()) + yield rewrite((-y).program).to(Program("-(") + y.program + ")") + assigned_x = x.program.assign() + yield rewrite(assume_pos(x).program).to(assigned_x.statement(Program("assert ") + assigned_x + " > 0")) + + first = assume_pos(-Math.var("x")) + Math(3) + y = egraph.let("y", Math(2) * first + Math(0) + first) + compiled = Compiler().compile(y.program) + egraph.register(compiled) + egraph.run(100) + egraph.display(max_calls_per_function=40, n_inline_leaves=2) + assert egraph.load_object(egraph.extract(PyObject.from_string(compiled.string))) == snapshot_py + + # egraph.run(10) + # egraph.check(eq(y.expr).to(String("_1"))) + # egraph.check(eq(y.statements).to(String("_0 = x + -3\n_1 = 2 * _0\n"))) From 0e02ffcc40a4e0113bd267a798b9b58f560c9bcd Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Wed, 27 Sep 2023 19:51:58 -0400 Subject: [PATCH 05/23] Change fact to take any expression not just unit --- python/egglog/egraph.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/egglog/egraph.py b/python/egglog/egraph.py index 825a4fd6..803a5ab3 100644 --- a/python/egglog/egraph.py +++ b/python/egglog/egraph.py @@ -1502,7 +1502,7 @@ def _action_like(action_like: ActionLike) -> Action: return action_like -FactLike = Union[Fact, Unit] +FactLike = Union[Fact, Expr] def _fact_likes(fact_likes: Iterable[FactLike]) -> tuple[Fact, ...]: From 23064285d839212772cc16a9d4272356e9f6a477 Mon Sep 17 00:00:00 2001 
From: Saul Shanabrook Date: Wed, 27 Sep 2023 19:53:34 -0400 Subject: [PATCH 06/23] Add first working version --- python/egglog/exp/program_gen.py | 138 ++++++++++-------- .../test_program_gen/test_to_string.py | 5 + python/tests/test_program_gen.py | 21 ++- 3 files changed, 93 insertions(+), 71 deletions(-) create mode 100644 python/tests/__snapshots__/test_program_gen/test_to_string.py diff --git a/python/egglog/exp/program_gen.py b/python/egglog/exp/program_gen.py index d9ba6558..d6fb3f3d 100644 --- a/python/egglog/exp/program_gen.py +++ b/python/egglog/exp/program_gen.py @@ -45,48 +45,35 @@ def assign(self) -> Program: """ ... - -converter(String, Program, Program) - - -@program_gen_module.class_ -class Compiler(Expr): - def __init__( - self, - # The mapping from programs to their compiled of the expressions - compiled_programs: Map[Program, Program] = Map[Program, Program].empty(), - # The next gensym counter - sym_counter: i64Like = i64(0), - # The cumulative list of statements seperated by newlines, all stored as the expression of a program - compiled_statements: Program = Program(""), - # The compiled expression from the last `compile` call - compiled_expr: Program = Program(""), - ) -> None: - ... - - def compile(self, program: ProgramLike) -> Compiler: - ... - @property - def expr(self) -> Program: - ... - - def set_expr(self, expr: Program) -> Compiler: + def expr(self) -> String: + """ + Returns the expression of the program, if it's been compiled + """ ... - def add_statement(self, statements: Program) -> Compiler: + @property + def statements(self) -> String: + """ + Returns the statements of the program, if it's been compiled + """ ... @property - def string(self) -> String: + def next_sym(self) -> i64: + """ + Returns the next gensym to use. + """ ... - def added_sym(self) -> Compiler: - ... + @program_gen_module.method(default=Unit()) + def compile(self, next_sym: i64 = i64(0)) -> Unit: + """ + Triggers compilation of the program. 
+ """ - @property - def next_sym(self) -> Program: - ... + +converter(String, Program, Program) @program_gen_module.register @@ -94,10 +81,12 @@ def _compile( s: String, s1: String, s2: String, + s3: String, + s4: String, p: Program, p1: Program, p2: Program, - c: Compiler, + # c: Compiler, statements: Program, expr: Program, i: i64, @@ -106,51 +95,72 @@ def _compile( # Combining two strings is just joining them yield rewrite(Program(s1) + Program(s2)).to(Program(join(s1, s2))) - compiler = Compiler(m, i, statements, expr) - # Compiling a program that is already in the compiled programs is a no-op, but the expression is updated - yield rewrite(compiler.compile(p)).to(compiler.set_expr(m[p]), m.contains(p)) # Compiling a string just gives that string program_expr = Program(s) - yield rewrite(compiler.compile(program_expr)).to( - Compiler(m.insert(program_expr, program_expr), i, statements, program_expr), m.not_contains(program_expr) + yield rule(program_expr.compile(i)).then( + set_(program_expr.expr).to(s), + set_(program_expr.statements).to(String("")), + set_(program_expr.next_sym).to(i), ) # Compiling a statement means that we should use the expression of the statement as a statement and use the expression # of the underlying program program_statement = p.statement(p1) - p_compiled = compiler.compile(p) - p1_compiled = p_compiled.compile(p1) - yield rewrite(compiler.compile(program_statement)).to( - p1_compiled.add_statement(p1_compiled.expr).set_expr(p_compiled.expr), m.not_contains(program_statement) + # First compile the expression + yield rule(program_statement.compile(i)).then(p.compile(i)) + # Then, when the expression is compiled, compile the statement, and set the expr of the whole statement + yield rule( + eq(p2).to(program_statement), + eq(i).to(p.next_sym), + eq(s).to(p.expr), + ).then(p1.compile(i), set_(p2.expr).to(s)) + # When both are compiled, add the statements of both + the expr of p1 to the statements of p + yield rule( + 
eq(p2).to(program_statement), + eq(s1).to(p.statements), + eq(s2).to(p1.statements), + eq(s).to(p1.expr), + eq(i).to(p1.next_sym), + ).then( + set_(p2.statements).to(join(s1, s2, s, "\n")), + set_(p2.next_sym).to(i), ) # Compiling an addition is the same as compiling one, then the other, then setting the expression as the addition # of the two program_add = p1 + p2 - p1_compiled = compiler.compile(p1) - p2_compiled = p1_compiled.compile(p2) - yield rewrite(compiler.compile(program_add)).to( - p2_compiled.set_expr(p1_compiled.expr + p2_compiled.expr), m.not_contains(program_add) + # Compile the first + yield rule(program_add.compile(i)).then(p1.compile(i)) + # Once the first is finished, do the second + yield rule(program_add, eq(i).to(p1.next_sym)).then(p2.compile(i)) + # Once the second is finished, set the the addition to the addition of the two expressions + yield rule( + eq(p).to(program_add), + eq(s1).to(p1.expr), + eq(s2).to(p2.expr), + eq(s3).to(p1.statements), + eq(s4).to(p2.statements), + eq(i).to(p2.next_sym), + ).then( + set_(p.expr).to(join(s1, s2)), + set_(p.statements).to(join(s3, s4)), + set_(p.next_sym).to(i), ) # Compiling an assign is the same as compiling the expression, adding an assign statement, then setting the # expression as the gensym program_assign = p.assign() - p_compiled = compiler.compile(p) - yield rewrite(compiler.compile(program_assign)).to( - p_compiled.add_statement(p_compiled.next_sym + " = " + p_compiled.expr) - .set_expr(p_compiled.next_sym) - .added_sym(), - m.not_contains(program_assign), - ) - - yield rewrite(compiler.set_expr(p)).to(Compiler(m, i, statements, p)) - yield rewrite(compiler.add_statement(p)).to(Compiler(m, i, statements + "\n" + p, expr), m.not_contains(p)) - yield rewrite(compiler.add_statement(p)).to(compiler, m.contains(p)) - yield rewrite(compiler.expr).to(expr) - yield rewrite(compiler.added_sym()).to(Compiler(m, i + 1, statements, expr)) - yield rewrite(compiler.next_sym).to(Program(join("_", 
i.to_string()))) + # Compile the expression + yield rule(program_assign.compile(i)).then(p.compile(i)) + # Once the expression is compiled, add the assign statement to the statements and set the expr - # Set `to_string` to the compiled statements added to the compiled expression + symbol = join(String("_"), i.to_string()) yield rule( - eq(c).to(Compiler(m, i, Program(s1), Program(s2))), - ).then(set_(c.string).to(join(s1, "\n", s2, "\n"))) + eq(p1).to(program_assign), + eq(s1).to(p.statements), + eq(s2).to(p.expr), + eq(i).to(p.next_sym), + ).then( + set_(p1.statements).to(join(s1, symbol, " = ", s2, "\n")), + set_(p1.expr).to(symbol), + set_(p1.next_sym).to(i + 1), + ) diff --git a/python/tests/__snapshots__/test_program_gen/test_to_string.py b/python/tests/__snapshots__/test_program_gen/test_to_string.py new file mode 100644 index 00000000..78b44912 --- /dev/null +++ b/python/tests/__snapshots__/test_program_gen/test_to_string.py @@ -0,0 +1,5 @@ +_0 = -(x) +_0 = -(x) +assert _0 > 0 +_1 = _0 + x +_1 diff --git a/python/tests/test_program_gen.py b/python/tests/test_program_gen.py index f406a0ed..739374dc 100644 --- a/python/tests/test_program_gen.py +++ b/python/tests/test_program_gen.py @@ -26,6 +26,7 @@ def __mul__(self, other: Math) -> Math: def __neg__(self) -> Math: ... + @egraph.method(cost=1000) @property def program(self) -> Program: ... 
@@ -54,13 +55,19 @@ def _rules( assigned_x = x.program.assign() yield rewrite(assume_pos(x).program).to(assigned_x.statement(Program("assert ") + assigned_x + " > 0")) - first = assume_pos(-Math.var("x")) + Math(3) - y = egraph.let("y", Math(2) * first + Math(0) + first) - compiled = Compiler().compile(y.program) - egraph.register(compiled) - egraph.run(100) - egraph.display(max_calls_per_function=40, n_inline_leaves=2) - assert egraph.load_object(egraph.extract(PyObject.from_string(compiled.string))) == snapshot_py + first = assume_pos(-Math.var("x")) + Math.var("x") + with egraph: + y = first + egraph.register(y.program) + egraph.run(10) + p = egraph.extract(y.program) + egraph.register(p) + egraph.register(p.compile()) + egraph.run(40) + # egraph.display(n_inline_leaves=1) + e = egraph.load_object(egraph.extract(PyObject.from_string(p.expr))) + stmts = egraph.load_object(egraph.extract(PyObject.from_string(p.statements))) + assert (stmts + e + "\n") == snapshot_py # egraph.run(10) # egraph.check(eq(y.expr).to(String("_1"))) From 8395a0eb12bb45d741fe47a1530862f00fed5e6a Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Thu, 28 Sep 2023 09:53:47 -0400 Subject: [PATCH 07/23] Almost working with parents set --- python/egglog/builtins.py | 3 - python/egglog/exp/program_gen.py | 176 ++++++++++++++---- .../test_program_gen/test_to_string.py | 1 - 3 files changed, 139 insertions(+), 41 deletions(-) diff --git a/python/egglog/builtins.py b/python/egglog/builtins.py index d64aadbe..1db82c3b 100644 --- a/python/egglog/builtins.py +++ b/python/egglog/builtins.py @@ -318,9 +318,6 @@ def __sub__(self, other: Set[T]) -> Set[T]: # type: ignore[empty-body] def __and__(self, other: Set[T]) -> Set[T]: # type: ignore[empty-body] ... - # def peek(self) -> T: - # ... 
- @BUILTINS.class_(egg_sort="Rational") class Rational(Expr): diff --git a/python/egglog/exp/program_gen.py b/python/egglog/exp/program_gen.py index d6fb3f3d..16678e13 100644 --- a/python/egglog/exp/program_gen.py +++ b/python/egglog/exp/program_gen.py @@ -45,6 +45,12 @@ def assign(self) -> Program: """ ... + def expr_to_statement(self) -> Program: + """ + Returns a new program with the expression as a statement and the new expression empty. + """ + ... + @property def expr(self) -> String: """ @@ -62,7 +68,7 @@ def statements(self) -> String: @property def next_sym(self) -> i64: """ - Returns the next gensym to use. + Returns the next gensym to use. This is set after calling `compile(i)` on a program. """ ... @@ -72,6 +78,16 @@ def compile(self, next_sym: i64 = i64(0)) -> Unit: Triggers compilation of the program. """ + @program_gen_module.method(merge=lambda old, new: old) # type: ignore[misc] + @property + def parent(self) -> Program: + """ + Returns the parent of the program, if it's been compiled into the parent. + + Only keeps the original parent, not any additional ones, so that each set of statements is only added once. + """ + ... 
+ converter(String, Program, Program) @@ -102,65 +118,151 @@ def _compile( set_(program_expr.statements).to(String("")), set_(program_expr.next_sym).to(i), ) - # Compiling a statement means that we should use the expression of the statement as a statement and use the expression - # of the underlying program - program_statement = p.statement(p1) - # First compile the expression - yield rule(program_statement.compile(i)).then(p.compile(i)) - # Then, when the expression is compiled, compile the statement, and set the expr of the whole statement + + ## + # Statement + ## + # Compiling a statement means that we should use the expression of the statement as a statement and use the expression of the first + yield rewrite(p1.statement(p2)).to(p1 + p2.expr_to_statement()) + + ## + # Expr to statement + ## + stmt = p1.expr_to_statement() + # 1. Set parent + yield rule(eq(p).to(stmt), p.compile(i)).then(set_(p1.parent).to(p)) + # 2. Compile p1 if parent set + yield rule(eq(p).to(stmt), p.compile(i), eq(p1.parent).to(stmt)).then(p1.compile(i)) + # 3.a. If parent not set, set statements to expr yield rule( - eq(p2).to(program_statement), - eq(i).to(p.next_sym), - eq(s).to(p.expr), - ).then(p1.compile(i), set_(p2.expr).to(s)) - # When both are compiled, add the statements of both + the expr of p1 to the statements of p + eq(p).to(stmt), + p.compile(i), + p1.parent != p, + eq(s1).to(p1.expr), + ).then( + set_(p.statements).to(join(s1, "\n")), + set_(p.next_sym).to(i), + set_(p.expr).to(String("")), + ) + # 3.b. 
If parent set, set statements to expr + statements yield rule( - eq(p2).to(program_statement), - eq(s1).to(p.statements), + eq(p).to(stmt), + eq(p1.parent).to(stmt), + eq(s1).to(p1.expr), eq(s2).to(p1.statements), - eq(s).to(p1.expr), eq(i).to(p1.next_sym), ).then( - set_(p2.statements).to(join(s1, s2, s, "\n")), - set_(p2.next_sym).to(i), + set_(p.statements).to(join(s2, s1, "\n")), + set_(p.next_sym).to(i), + set_(p.expr).to(String("")), ) + ## + # Addition + ## + # Compiling an addition is the same as compiling one, then the other, then setting the expression as the addition # of the two program_add = p1 + p2 - # Compile the first - yield rule(program_add.compile(i)).then(p1.compile(i)) - # Once the first is finished, do the second - yield rule(program_add, eq(i).to(p1.next_sym)).then(p2.compile(i)) - # Once the second is finished, set the the addition to the addition of the two expressions + + # Set parents + yield rule(eq(p).to(program_add), p.compile(i)).then(set_(p1.parent).to(p), set_(p2.parent).to(p)) + + # Compile p1, if p1 parent set + yield rule(eq(p).to(program_add), p.compile(i), eq(p1.parent).to(program_add)).then(p1.compile(i)) + + # Compile p2, if p1 parent not set + yield rule(eq(p).to(program_add), p.compile(i), p1.parent != p).then(p2.compile(i)) + + # Compile p2, if p1 parent set + yield rule(eq(p).to(program_add), eq(p1.parent).to(program_add), eq(i).to(p1.next_sym)).then(p2.compile(i)) + + # Set p expr to join of p1 and p2 yield rule( eq(p).to(program_add), eq(s1).to(p1.expr), eq(s2).to(p2.expr), - eq(s3).to(p1.statements), - eq(s4).to(p2.statements), - eq(i).to(p2.next_sym), ).then( set_(p.expr).to(join(s1, s2)), - set_(p.statements).to(join(s3, s4)), + ) + # Set p statements to join and next sym to p2 if both parents set + yield rule( + eq(p).to(program_add), + eq(p1.parent).to(p), + eq(p2.parent).to(p), + eq(s1).to(p1.statements), + eq(s2).to(p2.statements), + eq(i).to(p2.next_sym), + ).then( + set_(p.statements).to(join(s1, s2)), + 
set_(p.next_sym).to(i), + ) + # Set p statements to empty and next sym to i if neither parents set + yield rule( + eq(p).to(program_add), + p.compile(i), + p1.parent != p, + p2.parent != p, + ).then( + set_(p.statements).to(String("")), + set_(p.next_sym).to(i), + ) + # Set p statements to p1 and next sym to p1 if p1 parent set and p2 parent not set + yield rule( + eq(p).to(program_add), + eq(p1.parent).to(p), + p2.parent != p, + eq(s1).to(p1.statements), + eq(i).to(p1.next_sym), + ).then( + set_(p.statements).to(s1), set_(p.next_sym).to(i), ) + # Set p statements to p2 and next sym to p2 if p2 parent set and p1 parent not set + yield rule( + eq(p).to(program_add), + eq(p2.parent).to(p), + p1.parent != p, + eq(s2).to(p2.statements), + eq(i).to(p2.next_sym), + ).then( + set_(p.statements).to(s2), + set_(p.next_sym).to(i), + ) + + ## + # Assign + ## # Compiling an assign is the same as compiling the expression, adding an assign statement, then setting the # expression as the gensym - program_assign = p.assign() - # Compile the expression - yield rule(program_assign.compile(i)).then(p.compile(i)) - # Once the expression is compiled, add the assign statement to the statements and set the expr + program_assign = p1.assign() + # Set parent + yield rule(eq(p).to(program_assign), p.compile(i)).then(set_(p1.parent).to(p)) + # If parent set, compile the expression + yield rule(eq(p).to(program_assign), p.compile(i), eq(p1.parent).to(program_assign)).then(p1.compile(i)) + # If p1 parent is p, then use statements of p, next sym of p symbol = join(String("_"), i.to_string()) yield rule( - eq(p1).to(program_assign), - eq(s1).to(p.statements), - eq(s2).to(p.expr), - eq(i).to(p.next_sym), + eq(p).to(program_assign), + eq(p1.parent).to(p), + eq(s1).to(p1.statements), + eq(i).to(p1.next_sym), + eq(s2).to(p1.expr), + ).then( + set_(p.statements).to(join(s1, symbol, " = ", s2, "\n")), + set_(p.expr).to(symbol), + set_(p.next_sym).to(i + 1), + ) + # If p1 parent is not p, then just use 
assign as statement, next sym of i + yield rule( + eq(p).to(program_assign), + p1.parent != p, + p.compile(i), + eq(s2).to(p1.expr), ).then( - set_(p1.statements).to(join(s1, symbol, " = ", s2, "\n")), - set_(p1.expr).to(symbol), - set_(p1.next_sym).to(i + 1), + set_(p.statements).to(join(symbol, " = ", s2, "\n")), + set_(p.expr).to(symbol), + set_(p.next_sym).to(i + 1), ) diff --git a/python/tests/__snapshots__/test_program_gen/test_to_string.py b/python/tests/__snapshots__/test_program_gen/test_to_string.py index 78b44912..c2148732 100644 --- a/python/tests/__snapshots__/test_program_gen/test_to_string.py +++ b/python/tests/__snapshots__/test_program_gen/test_to_string.py @@ -1,5 +1,4 @@ _0 = -(x) -_0 = -(x) assert _0 > 0 _1 = _0 + x _1 From 2e2456927fd37a6aa5dffb718d77b269f5e1d9e4 Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Thu, 28 Sep 2023 09:57:14 -0400 Subject: [PATCH 08/23] Real working version! --- python/egglog/exp/program_gen.py | 13 ++++++++----- .../test_program_gen/test_to_string.py | 6 ++++-- python/tests/test_program_gen.py | 8 ++++---- 3 files changed, 16 insertions(+), 11 deletions(-) diff --git a/python/egglog/exp/program_gen.py b/python/egglog/exp/program_gen.py index 16678e13..e711abb2 100644 --- a/python/egglog/exp/program_gen.py +++ b/python/egglog/exp/program_gen.py @@ -165,16 +165,19 @@ def _compile( # of the two program_add = p1 + p2 - # Set parents - yield rule(eq(p).to(program_add), p.compile(i)).then(set_(p1.parent).to(p), set_(p2.parent).to(p)) + # Set parent of p1 + yield rule(eq(p).to(program_add), p.compile(i)).then(set_(p1.parent).to(p)) - # Compile p1, if p1 parent set + # Compile p1, if p1 parent equal yield rule(eq(p).to(program_add), p.compile(i), eq(p1.parent).to(program_add)).then(p1.compile(i)) - # Compile p2, if p1 parent not set + # Set parent of p2, once p1 compiled + yield rule(eq(p).to(program_add), p1.next_sym).then(set_(p2.parent).to(p)) + + # Compile p2, if p1 parent not equal yield 
rule(eq(p).to(program_add), p.compile(i), p1.parent != p).then(p2.compile(i)) - # Compile p2, if p1 parent set + # Compile p2, if p1 parent eqal yield rule(eq(p).to(program_add), eq(p1.parent).to(program_add), eq(i).to(p1.next_sym)).then(p2.compile(i)) # Set p expr to join of p1 and p2 diff --git a/python/tests/__snapshots__/test_program_gen/test_to_string.py b/python/tests/__snapshots__/test_program_gen/test_to_string.py index c2148732..cc9697b6 100644 --- a/python/tests/__snapshots__/test_program_gen/test_to_string.py +++ b/python/tests/__snapshots__/test_program_gen/test_to_string.py @@ -1,4 +1,6 @@ -_0 = -(x) +_0 = -x assert _0 > 0 _1 = _0 + x -_1 +_2 = _1 + 2 +_3 = _2 + _1 +_3 diff --git a/python/tests/test_program_gen.py b/python/tests/test_program_gen.py index 739374dc..450d93fc 100644 --- a/python/tests/test_program_gen.py +++ b/python/tests/test_program_gen.py @@ -51,19 +51,19 @@ def _rules( yield rewrite(Math(i).program).to(Program(i.to_string())) yield rewrite((y + z).program).to((y.program + " + " + z.program).assign()) yield rewrite((y * z).program).to((y.program + " * " + z.program).assign()) - yield rewrite((-y).program).to(Program("-(") + y.program + ")") + yield rewrite((-y).program).to(Program("-") + y.program) assigned_x = x.program.assign() yield rewrite(assume_pos(x).program).to(assigned_x.statement(Program("assert ") + assigned_x + " > 0")) first = assume_pos(-Math.var("x")) + Math.var("x") with egraph: - y = first + y = first + Math(2) + first egraph.register(y.program) - egraph.run(10) + egraph.run(100) p = egraph.extract(y.program) egraph.register(p) egraph.register(p.compile()) - egraph.run(40) + egraph.run(100) # egraph.display(n_inline_leaves=1) e = egraph.load_object(egraph.extract(PyObject.from_string(p.expr))) stmts = egraph.load_object(egraph.extract(PyObject.from_string(p.statements))) From 698c047e27fe930367bf6bfbe57d1e1e0c6cd7e6 Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Sun, 1 Oct 2023 12:37:57 -0400 Subject: [PATCH 
09/23] Get function generation working --- python/egglog/exp/array_api.py | 271 ++++-------------- python/egglog/exp/program_gen.py | 37 ++- .../test_program_gen/test_to_string.py | 13 +- python/tests/test_program_gen.py | 27 +- 4 files changed, 114 insertions(+), 234 deletions(-) diff --git a/python/egglog/exp/array_api.py b/python/egglog/exp/array_api.py index 286dbc8f..de8f111e 100644 --- a/python/egglog/exp/array_api.py +++ b/python/egglog/exp/array_api.py @@ -15,6 +15,8 @@ from egglog.egraph import Action from egglog.runtime import RuntimeExpr +from .program_gen import * + # Pretend that exprs are numbers b/c sklearn does isinstance checks numbers.Integral.register(RuntimeExpr) @@ -1486,88 +1488,68 @@ def _size(x: NDArray): # Depends on `np` as a global variable. ## -array_api_module_string = Module([array_api_module]) - - -@array_api_module_string.function(merge=lambda old, new: new, default=i64(0)) -def gensym() -> i64: - ... - - -gensym_var = join("_", gensym().to_string()) - - -def add_line(*v: StringLike) -> Action: - return set_(statements()).to(join(" ", *v, "\n")) - - -incr_gensym = set_(gensym()).to(gensym() + 1) - - -@array_api_module_string.function(merge=lambda old, new: join(old, new), default=String("")) -def statements() -> String: - ... +array_api_module_string = Module([array_api_module, program_gen_module]) @array_api_module_string.function() -def ndarray_expr(x: NDArray) -> String: +def ndarray_program(x: NDArray) -> Program: ... @array_api_module_string.function() -def dtype_expr(x: DType) -> String: +def dtype_program(x: DType) -> Program: ... @array_api_module_string.function() -def tuple_int_expr(x: TupleInt) -> String: +def tuple_int_program(x: TupleInt) -> Program: ... @array_api_module_string.function() -def int_expr(x: Int) -> String: +def int_program(x: Int) -> Program: ... @array_api_module_string.function() -def tuple_value_expr(x: TupleValue) -> String: +def tuple_value_program(x: TupleValue) -> Program: ... 
@array_api_module_string.function() -def value_expr(x: Value) -> String: +def value_program(x: Value) -> Program: ... array_api_module_string.register( - set_(dtype_expr(DType.float64)).to(String("np.float64")), - set_(dtype_expr(DType.int64)).to(String("np.int64")), + union(dtype_program(DType.float64)).with_(Program("np.float64")), + union(dtype_program(DType.int64)).with_(Program("np.int64")), ) @array_api_module_string.function -def bool_expr(x: Bool) -> String: +def bool_program(x: Bool) -> Program: ... array_api_module_string.register( - set_(bool_expr(TRUE)).to(String("True")), - set_(bool_expr(FALSE)).to(String("False")), + union(bool_program(TRUE)).with_(Program("True")), + union(bool_program(FALSE)).with_(Program("False")), ) @array_api_module_string.function -def float_expr(x: Float) -> String: +def float_program(x: Float) -> Program: ... @array_api_module_string.function -def tuple_ndarray_expr(x: TupleNDArray) -> String: +def tuple_ndarray_program(x: TupleNDArray) -> Program: ... @array_api_module_string.function -def optional_dtype_expr(x: OptionalDType) -> String: +def optional_dtype_program(x: OptionalDType) -> Program: ... 
@@ -1611,232 +1593,97 @@ def _py_expr( optional_dtype_: OptionalDType, ): # Var - yield rule( - eq(x).to(NDArray.var(s)), - ).then( - set_(lhs=ndarray_expr(x)).to(s), - ) + yield rewrite(ndarray_program(NDArray.var(s))).to(Program(s)) # Asssume dtype z_assumed_dtype = copy(z) - assume_dtype(z_assumed_dtype, dtype=dtype) - yield rule( - eq(x).to(z_assumed_dtype), - eq(z_str).to(ndarray_expr(z)), - eq(dtype_str).to(dtype_expr(dtype)), - ).then( - set_(ndarray_expr(x)).to(z_str), - add_line("assert ", z_str, ".dtype == ", dtype_str), + assume_dtype(z_assumed_dtype, dtype) + z_program = ndarray_program(z) + yield rewrite(ndarray_program(z_assumed_dtype)).to( + z_program.statement(Program("assert ") + z_program + ".dtype == " + dtype_program(dtype)) ) - # assume shape z_assumed_shape = copy(z) assume_shape(z_assumed_shape, ti) - yield rule( - eq(x).to(z_assumed_shape), - eq(z_str).to(ndarray_expr(z)), - eq(ti_str).to(tuple_int_expr(ti)), - ).then( - set_(ndarray_expr(x)).to(z_str), - add_line("assert ", z_str, ".shape == ", ti_str), + yield rewrite(ndarray_program(z_assumed_shape)).to( + z_program.statement(Program("assert ") + z_program + ".shape == " + tuple_int_program(ti)) ) # tuple int - yield rule( - eq(ti).to(ti1 + ti2), - eq(ti_str1).to(tuple_int_expr(ti1)), - eq(ti_str2).to(tuple_int_expr(ti2)), - ).then( - set_(tuple_int_expr(ti)).to(join(ti_str1, " + ", ti_str2)), - ) - yield rule( - eq(ti).to(TupleInt(i)), - eq(i_str).to(int_expr(i)), - ).then( - set_(tuple_int_expr(ti)).to(join("(", i_str, ",)")), - ) + yield rewrite(tuple_int_program(ti1 + ti2)).to(tuple_int_program(ti1) + " + " + tuple_int_program(ti2)) + yield rewrite(tuple_int_program(TupleInt(i))).to(Program("(") + int_program(i) + ",)") # Int - yield rule( - eq(i).to(Int(i64_)), - ).then( - set_(int_expr(i)).to(i64_.to_string()), - ) + yield rewrite(int_program(Int(i64_))).to(Program(i64_.to_string())) # assume isfinite z_assumed_isfinite = copy(z) assume_isfinite(z_assumed_isfinite) - yield rule( - 
eq(x).to(z_assumed_isfinite), - eq(z_str).to(ndarray_expr(z)), - ).then( - set_(ndarray_expr(x)).to(z_str), - add_line("assert np.all(np.isfinite(", z_str, "))"), + yield rewrite(ndarray_program(z_assumed_isfinite)).to( + z_program.statement(Program("assert np.all(np.isfinite(") + z_program + "))") ) # Assume value_one_of z_assumed_value_one_of = copy(z) assume_value_one_of(z_assumed_value_one_of, tv) - yield rule( - eq(x).to(z_assumed_value_one_of), - # not_traversed(x), - eq(z_str).to(ndarray_expr(z)), - eq(tv_str).to(tuple_value_expr(tv)), - ).then( - set_(ndarray_expr(x)).to(z_str), - # traverse(x), - add_line("assert set(", z_str, ".flatten()) == set(", tv_str, ")"), + yield rewrite(ndarray_program(z_assumed_value_one_of)).to( + z_program.statement(Program("assert set(") + z_program + ".flatten()) == set(" + tuple_value_program(tv) + ")") ) - # print(r._to_egg_command(array_api_module_string._mod_decls)) - # yield r + # tuple values - yield rule( - eq(tv).to(tv1 + tv2), - eq(tv1_str).to(tuple_value_expr(tv1)), - eq(tv2_str).to(tuple_value_expr(tv2)), - ).then( - set_(tuple_value_expr(tv)).to(join(tv1_str, " + ", tv2_str)), - ) - yield rule( - eq(tv).to(TupleValue(v)), - eq(v_str).to(value_expr(v)), - ).then( - set_(tuple_value_expr(tv)).to(join("(", v_str, ",)")), - ) + yield rewrite(tuple_value_program(tv1 + tv2)).to(tuple_value_program(tv1) + " + " + tuple_value_program(tv2)) + yield rewrite(tuple_value_program(TupleValue(v))).to(Program("(") + value_program(v) + ",)") # Value - yield rule( - eq(v).to(Value.int(i)), - eq(i_str).to(int_expr(i)), - ).then( - set_(value_expr(v)).to(i_str), - ) - yield rule( - eq(v).to(Value.bool(b)), - eq(b_str).to(bool_expr(b)), - ).then( - set_(value_expr(v)).to(b_str), - ) - yield rule( - eq(v).to(Value.float(f)), - eq(f_str).to(float_expr(f)), - ).then( - set_(value_expr(v)).to(f_str), - ) + yield rewrite(value_program(Value.int(i))).to(int_program(i)) + yield rewrite(value_program(Value.bool(b))).to(bool_program(b)) + 
yield rewrite(value_program(Value.float(f))).to(float_program(f)) # Float - yield rule( - eq(f).to(Float(f64_)), - ).then( - set_(float_expr(f)).to(f64_.to_string()), - ) + yield rewrite(float_program(Float(f64_))).to(Program(f64_.to_string())) # reshape (don't include copy, since not present in numpy) - yield rule( - eq(x).to(reshape(y, ti, ob)), - eq(y_str).to(ndarray_expr(y)), - eq(ti_str).to(tuple_int_expr(ti)), - ).then( - set_(ndarray_expr(x)).to(gensym_var), - add_line(gensym_var, " = ", y_str, ".reshape(", ti_str, ")"), - incr_gensym, + yield rewrite(ndarray_program(reshape(y, ti, ob))).to( + (ndarray_program(y) + ".reshape(" + tuple_int_program(ti) + ")").assign() ) # astype - yield rule( - eq(x).to(astype(y, dtype)), - eq(y_str).to(ndarray_expr(y)), - eq(dtype_str).to(dtype_expr(dtype)), - ).then( - set_(ndarray_expr(x)).to(gensym_var), - add_line(gensym_var, " = ", y_str, ".astype(", dtype_str, ")"), - incr_gensym, + yield rewrite(ndarray_program(astype(y, dtype))).to( + (ndarray_program(y) + ".astype(" + dtype_program(dtype) + ")").assign() ) # unique_counts(x) => unique(x, return_counts=True) - yield rule( - eq(tnd).to(unique_counts(y)), - eq(y_str).to(ndarray_expr(y)), - ).then( - set_(tuple_ndarray_expr(tnd)).to(gensym_var), - add_line(gensym_var, " = np.unique(", y_str, ", return_counts=True)"), - incr_gensym, + yield rewrite(tuple_ndarray_program(unique_counts(x))).to( + (Program("np.unique(") + ndarray_program(x) + ", return_counts=True)").assign() ) + # Tuple ndarray indexing - yield rule( - eq(x).to(tnd[i]), - eq(tnd_str).to(tuple_ndarray_expr(tnd)), - eq(i_str).to(int_expr(i)), - ).then( - set_(ndarray_expr(x)).to(join(tnd_str, "[", i_str, "]")), - ) + yield rewrite(ndarray_program(tnd[i])).to(tuple_ndarray_program(tnd) + "[" + int_program(i) + "]") # ndarray scalar # TODO: Use dtype and shape and indexing instead? 
- yield rule( - eq(x).to(NDArray.scalar(v)), - eq(v_str).to(value_expr(v)), - ).then( - set_(ndarray_expr(x)).to(gensym_var), - add_line(gensym_var, " = np.array(", v_str, ")"), - incr_gensym, - ) + # TODO: SPecify dtype? + yield rewrite(ndarray_program(NDArray.scalar(v))).to(Program("np.array(") + value_program(v) + ")") # zeros - yield rule( - eq(x).to(zeros(ti, optional_dtype_, optional_device_)), - eq(ti_str).to(tuple_int_expr(ti)), - eq(dtype_str).to(optional_dtype_expr(optional_dtype_)), - ).then( - set_(ndarray_expr(x)).to(gensym_var), - add_line(gensym_var, " = np.zeros(", ti_str, ", dtype=", dtype_str, ")"), - incr_gensym, + yield rewrite(ndarray_program(zeros(ti, optional_dtype_, optional_device_))).to( + ( + Program("np.zeros(") + tuple_int_program(ti) + ", dtype=" + optional_dtype_program(optional_dtype_) + ")" + ).assign() ) # Optional dtype - yield rule( - eq(optional_dtype_).to(OptionalDType.none), - ).then( - set_(optional_dtype_expr(optional_dtype_)).to(String("None")), - ) - yield rule( - eq(optional_dtype_).to(OptionalDType.some(dtype)), - eq(dtype_str).to(dtype_expr(dtype)), - ).then( - set_(optional_dtype_expr(optional_dtype_)).to(dtype_str), - ) + yield rewrite(optional_dtype_program(OptionalDType.none)).to(Program("None")) + yield rewrite(optional_dtype_program(OptionalDType.some(dtype))).to(dtype_program(dtype)) # unique_values - yield rule( - eq(x).to(unique_values(y)), - eq(y_str).to(ndarray_expr(y)), - ).then( - set_(ndarray_expr(x)).to(gensym_var), - add_line(gensym_var, " = np.unique(", y_str, ")"), - incr_gensym, - ) + yield rewrite(ndarray_program(unique_values(x))).to((Program("np.unique(") + ndarray_program(x) + ")").assign()) # reshape # NDARRAy ops - - yield rule( - eq(x).to(y + z), - eq(y_str).to(ndarray_expr(y)), - eq(z_str).to(ndarray_expr(z)), - ).then( - set_(ndarray_expr(x)).to(gensym_var), - add_line(gensym_var, " = ", y_str, " + ", z_str), - incr_gensym, - ) - - yield rule( - eq(x).to(y / z), - 
eq(y_str).to(ndarray_expr(y)), - eq(z_str).to(ndarray_expr(z)), - ).then( - set_(ndarray_expr(x)).to(gensym_var), - add_line(gensym_var, " = ", y_str, " / ", z_str), - incr_gensym, - ) + yield rewrite(ndarray_program(x + y)).to((ndarray_program(x) + " + " + ndarray_program(y)).assign()) + yield rewrite(ndarray_program(x - y)).to((ndarray_program(x) + " - " + ndarray_program(y)).assign()) + yield rewrite(ndarray_program(x * y)).to((ndarray_program(x) + " * " + ndarray_program(y)).assign()) + yield rewrite(ndarray_program(x / y)).to((ndarray_program(x) + " / " + ndarray_program(y)).assign()) @array_api_module_string.class_ diff --git a/python/egglog/exp/program_gen.py b/python/egglog/exp/program_gen.py index e711abb2..d174a382 100644 --- a/python/egglog/exp/program_gen.py +++ b/python/egglog/exp/program_gen.py @@ -4,6 +4,7 @@ """ from __future__ import annotations +from turtle import st from typing import Union from egglog import * @@ -45,6 +46,12 @@ def assign(self) -> Program: """ ... + def function_two(self, name: StringLike, arg1: ProgramLike, arg2: ProgramLike) -> Program: + """ + Returns a new program defining a function with two arguments. + """ + ... + def expr_to_statement(self) -> Program: """ Returns a new program with the expression as a statement and the new expression empty. 
@@ -99,9 +106,11 @@ def _compile( s2: String, s3: String, s4: String, + s5: String, p: Program, p1: Program, p2: Program, + p3: Program, # c: Compiler, statements: Program, expr: Program, @@ -109,7 +118,7 @@ def _compile( m: Map[Program, Program], ): # Combining two strings is just joining them - yield rewrite(Program(s1) + Program(s2)).to(Program(join(s1, s2))) + # yield rewrite(Program(s1) + Program(s2)).to(Program(join(s1, s2))) # Compiling a string just gives that string program_expr = Program(s) @@ -269,3 +278,29 @@ def _compile( set_(p.expr).to(symbol), set_(p.next_sym).to(i + 1), ) + + ## + # Function two + + # When compiling a function, the two args, p2 and p3, should get compiled when we compile p1, and should just be vars. + fn_two = p1.function_two(s1, p2, p3) + # 1. Set parent of p1 + yield rule(eq(p).to(fn_two), fn_two.compile(i)).then(set_(p1.parent).to(p)) + # TODO: Compile vars? + # 2. Compile p1 if parent set + yield rule(eq(p).to(fn_two), p.compile(i), eq(p1.parent).to(fn_two)).then(p1.compile(i)) + # 3. 
Set statements to function body and the next sym to i + yield rule( + eq(p).to(fn_two), + p.compile(i), + eq(s2).to(p1.expr), + eq(s3).to(p1.statements), + eq(s4).to(p2.expr), + eq(s5).to(p3.expr), + ).then( + set_(p.statements).to( + join("def ", s1, "(", s4, ", ", s5, "):\n ", s3.replace("\n", "\n "), "return ", s2, "\n") + ), + set_(p.next_sym).to(i), + set_(p.expr).to(s1), + ) diff --git a/python/tests/__snapshots__/test_program_gen/test_to_string.py b/python/tests/__snapshots__/test_program_gen/test_to_string.py index cc9697b6..fe466ad7 100644 --- a/python/tests/__snapshots__/test_program_gen/test_to_string.py +++ b/python/tests/__snapshots__/test_program_gen/test_to_string.py @@ -1,6 +1,7 @@ -_0 = -x -assert _0 > 0 -_1 = _0 + x -_2 = _1 + 2 -_3 = _2 + _1 -_3 +def my_fn(x, y): + _0 = -x + assert _0 > 0 + _1 = _0 + y + _2 = _1 + 2 + _3 = _2 + _1 + return _3 diff --git a/python/tests/test_program_gen.py b/python/tests/test_program_gen.py index 450d93fc..94567be8 100644 --- a/python/tests/test_program_gen.py +++ b/python/tests/test_program_gen.py @@ -55,20 +55,17 @@ def _rules( assigned_x = x.program.assign() yield rewrite(assume_pos(x).program).to(assigned_x.statement(Program("assert ") + assigned_x + " > 0")) - first = assume_pos(-Math.var("x")) + Math.var("x") + first = assume_pos(-Math.var("x")) + Math.var("y") + fn = (first + Math(2) + first).program.function_two("my_fn", Math.var("x").program, Math.var("y").program) with egraph: - y = first + Math(2) + first - egraph.register(y.program) - egraph.run(100) - p = egraph.extract(y.program) - egraph.register(p) - egraph.register(p.compile()) - egraph.run(100) + egraph.register(fn) + egraph.run(200) + fn = egraph.extract(fn) + egraph.register(fn) + egraph.register(fn.compile()) + egraph.run(200) # egraph.display(n_inline_leaves=1) - e = egraph.load_object(egraph.extract(PyObject.from_string(p.expr))) - stmts = egraph.load_object(egraph.extract(PyObject.from_string(p.statements))) - assert (stmts + e + "\n") == 
snapshot_py - - # egraph.run(10) - # egraph.check(eq(y.expr).to(String("_1"))) - # egraph.check(eq(y.statements).to(String("_0 = x + -3\n_1 = 2 * _0\n"))) + expr = egraph.load_object(egraph.extract(PyObject.from_string(fn.expr))) + assert expr == "my_fn" # type: ignore + stmts = egraph.load_object(egraph.extract(PyObject.from_string(fn.statements))) + assert stmts == snapshot_py # type: ignore From cc09fe110545ab3e08fc057b0b5b23e059f55882 Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Sun, 1 Oct 2023 13:30:02 -0400 Subject: [PATCH 10/23] Add compilation to python value --- docs/reference/python-integration.md | 9 +- python/egglog/builtins.py | 18 +- python/egglog/exp/program_gen.py | 28 +++ python/tests/test_high_level.py | 272 +-------------------------- python/tests/test_program_gen.py | 100 ++++++---- src/py_object_sort.rs | 89 +++++++++ 6 files changed, 209 insertions(+), 307 deletions(-) diff --git a/docs/reference/python-integration.md b/docs/reference/python-integration.md index 6a4fda09..da61d70d 100644 --- a/docs/reference/python-integration.md +++ b/docs/reference/python-integration.md @@ -72,8 +72,13 @@ egraph.load_object(egraph.extract(PyObject.from_int(1))) We also support evaling arbitrary Python bode, given some locals and globals. This technically allows us to implement any Python method: ```{code-cell} python -empty_dict = egraph.save_object({}) -egraph.load_object(egraph.extract(py_eval("1 + 2", empty_dict, empty_dict))) +egraph.load_object(egraph.extract(py_eval("1 + 2"))) +``` + +Execing Python code is also supported. In this case, the return value will be the updated globals dict, which will be copied first before using. + +```{code-cell} python +egraph.load_object(egraph.extract(py_exec("x = 1 + 2"))) ``` Alongside this, we support a function `dict_update` method, which can allow you to combine some local local egglog expressions alongside, say, the locals and globals of the Python code you are evaling. 
diff --git a/python/egglog/builtins.py b/python/egglog/builtins.py index 1db82c3b..51b20db5 100644 --- a/python/egglog/builtins.py +++ b/python/egglog/builtins.py @@ -1,3 +1,4 @@ +# mypy: disable-error-code="empty-body" """ Builtin sorts and function to egg. """ @@ -24,6 +25,7 @@ "join", "PyObject", "py_eval", + "py_exec", ] @@ -454,8 +456,20 @@ def dict_update(dict, *keys_and_values: PyObject) -> PyObject: # type: ignore[e def from_int(cls, i: i64Like) -> PyObject: # type: ignore[empty-body] ... + @BUILTINS.method(egg_fn="py-dict") + @classmethod + def dict(cls, *keys_and_values: PyObject) -> PyObject: + ... + -# TODO: Maybe move to static method if we implement those? @BUILTINS.function(egg_fn="py-eval") -def py_eval(code: StringLike, locals: PyObject, globals: PyObject) -> PyObject: # type: ignore[empty-body] +def py_eval(code: StringLike, globals: PyObject = PyObject.dict(), locals: PyObject = PyObject.dict()) -> PyObject: # type: ignore[empty-body] + ... + + +@BUILTINS.function(egg_fn="py-exec") +def py_exec(code: StringLike, globals: PyObject = PyObject.dict(), locals: PyObject = PyObject.dict()) -> PyObject: + """ + Copies the locals, execs the Python code, and returns the locals with any updates. + """ ... diff --git a/python/egglog/exp/program_gen.py b/python/egglog/exp/program_gen.py index d174a382..d3a1ff3b 100644 --- a/python/egglog/exp/program_gen.py +++ b/python/egglog/exp/program_gen.py @@ -95,10 +95,38 @@ def parent(self) -> Program: """ ... + @program_gen_module.method(default=Unit()) + def eval_py_object(self, globals: PyObject) -> Unit: + """ + Evaluates the program and saves as the py_object + """ + + @property + def py_object(self) -> PyObject: + """ + Returns the python object of the program, if it's been evaluated. + """ + ... 
+ converter(String, Program, Program) +@program_gen_module.register +def _py_object(p: Program, expr: String, statements: String, g: PyObject): + # When we evaluate a program, we first want to compile to a string + yield rule(p.eval_py_object(g)).then(p.compile()) + # Then we want to evaluate the statements/expr + yield rule(p.eval_py_object(g), eq(p.statements).to(statements), eq(p.expr).to(expr)).then( + set_(p.py_object).to( + py_eval( + "l['___res']", + PyObject.dict(PyObject.from_string("l"), py_exec(join(statements, "\n", "___res = ", expr), g)), + ) + ) + ) + + @program_gen_module.register def _compile( s: String, diff --git a/python/tests/test_high_level.py b/python/tests/test_high_level.py index 41b11ba3..fc1fb304 100644 --- a/python/tests/test_high_level.py +++ b/python/tests/test_high_level.py @@ -312,6 +312,16 @@ def test_eval_local(self): res_simpl = egraph.simplify(res, 1) assert egraph.load_object(res_simpl) == "hithere" + def test_exec(self): + egraph = EGraph() + res = egraph.simplify(py_exec("x = 10"), 1) + assert egraph.load_object(res) == {"x": 10} + + def test_exec_globals(self): + egraph = EGraph() + res = egraph.simplify(py_exec("x = y + 1", egraph.save_object({"y": 10})), 1) + assert egraph.load_object(res) == {"x": 11} + def my_add(a, b): return a + b @@ -430,265 +440,3 @@ def __radd__(self, other: Math) -> Math: JustTypeRef("Math"), CallDecl(MethodRef("Math", "__add__"), (expr_parts(Math(i64(10))), expr_parts(Math(i64(5))))), ) - - -@pytest.mark.xfail(reason="https://github.com/egraphs-good/egglog/issues/229") -def test_imperative(): - egraph = EGraph() - - @egraph.function(merge=lambda old, new: join(old, new), default=String("")) - def statements() -> String: - ... - - @egraph.function(merge=lambda old, new: old + new, default=i64(0)) - def gensym() -> i64: - ... - - gensym_var = join("_", gensym().to_string()) - - @egraph.class_ - class Math(Expr): - @egraph.method(egg_fn="Num") - def __init__(self, value: i64Like) -> None: - ... 
- - @egraph.method(egg_fn="Var") - @classmethod - def var(cls, v: StringLike) -> Math: - ... - - @egraph.method(egg_fn="Add") - def __add__(self, other: Math) -> Math: - ... - - @egraph.method(egg_fn="Mul") - def __mul__(self, other: Math) -> Math: - ... - - @egraph.method(egg_fn="expr") # type: ignore[misc] - @property - def expr(self) -> String: - ... - - @egraph.register - def _rules(s: String, y_expr: String, z_expr: String, x: Math, i: i64, y: Math, z: Math): - yield rule( - eq(x).to(Math.var(s)), - ).then( - set_(x.expr).to(s), - ) - - yield rule( - eq(x).to(Math(i)), - ).then( - set_(x.expr).to(i.to_string()), - ) - - yield rule( - eq(x).to(y + z), - eq(y_expr).to(y.expr), - eq(z_expr).to(z.expr), - ).then( - set_(x.expr).to(gensym_var), - set_(statements()).to(join(gensym_var, " = ", y_expr, " + ", z_expr, "\n")), - set_(gensym()).to(i64(1)), - ) - yield rule( - eq(x).to(y * z), - eq(y_expr).to(y.expr), - eq(z_expr).to(z.expr), - ).then( - set_(x.expr).to(gensym_var), - set_(statements()).to(join(gensym_var, " = ", y_expr, " * ", z_expr, "\n")), - set_(gensym()).to(i64(1)), - ) - - y = egraph.let("y", Math(2) * (Math.var("x") + Math(3))) - - egraph.run(10) - egraph.check(eq(y.expr).to(String("_1"))) - egraph.check(eq(statements()).to(String("_0 = x + 3\n_1 = 2 * _0\n"))) - - -@pytest.mark.xfail(reason="applies rules too many times b/c keeps matching") -def test_imperative_stable(): - # More stable version of imperative, which uses idempotent merge function - egraph = EGraph() - - @egraph.function(merge=lambda old, new: new) - def statements() -> String: - ... - - egraph.register(set_(statements()).to(String(""))) - - @egraph.function(merge=lambda old, new: old + new, default=i64(0)) - def gensym() -> i64: - ... - - @egraph.class_ - class Math(Expr): - @egraph.method(egg_fn="Num") - def __init__(self, value: i64Like) -> None: - ... - - @egraph.method(egg_fn="Var") - @classmethod - def var(cls, v: StringLike) -> Math: - ... 
- - @egraph.method(egg_fn="Add") - def __add__(self, other: Math) -> Math: - ... - - @egraph.method(egg_fn="Mul") - def __mul__(self, other: Math) -> Math: - ... - - @egraph.method(egg_fn="expr") # type: ignore[misc] - @property - def expr(self) -> String: - ... - - @egraph.register - def _rules( - s: String, - y_expr: String, - z_expr: String, - old_statements: String, - x: Math, - i: i64, - y: Math, - z: Math, - old_gensym: i64, - ): - gensym_var = join("_", gensym().to_string()) - yield rule( - eq(x).to(Math.var(s)), - ).then( - set_(x.expr).to(s), - ) - - yield rule( - eq(x).to(Math(i)), - ).then( - set_(x.expr).to(i.to_string()), - ) - - yield rule( - eq(x).to(y + z), - eq(y_expr).to(y.expr), - eq(z_expr).to(z.expr), - eq(old_statements).to(statements()), - ).then( - set_(x.expr).to(gensym_var), - set_(statements()).to(join(old_statements, gensym_var, " = ", y_expr, " + ", z_expr, "\n")), - set_(gensym()).to(i64(1)), - ) - yield rule( - eq(x).to(y * z), - eq(y_expr).to(y.expr), - eq(z_expr).to(z.expr), - eq(old_statements).to(statements()), - ).then( - set_(x.expr).to(gensym_var), - set_(statements()).to(join(old_statements, gensym_var, " = ", y_expr, " * ", z_expr, "\n")), - set_(gensym()).to(i64(1)), - ) - - y = egraph.let("y", Math(2) * (Math.var("x") + Math(3))) - - egraph.run(10) - egraph.check(eq(y.expr).to(String("_1"))) - egraph.check(eq(statements()).to(String("_0 = x + 3\n_1 = 2 * _0\n"))) - - -def test_imperative_python(): - # Tries implementing the same functionality but with a PyObject - # More stable version of imperative, which uses idempotent merge function - egraph = EGraph() - - @egraph.function(merge=lambda old, new: new) - def statements() -> String: - ... - - egraph.register(set_(statements()).to(String(""))) - - @egraph.function(merge=lambda old, new: old + new, default=i64(0)) - def gensym() -> i64: - ... - - @egraph.class_ - class Math(Expr): - @egraph.method(egg_fn="Num") - def __init__(self, value: i64Like) -> None: - ... 
- - @egraph.method(egg_fn="Var") - @classmethod - def var(cls, v: StringLike) -> Math: - ... - - @egraph.method(egg_fn="Add") - def __add__(self, other: Math) -> Math: - ... - - @egraph.method(egg_fn="Mul") - def __mul__(self, other: Math) -> Math: - ... - - @egraph.method(egg_fn="expr") # type: ignore[misc] - @property - def expr(self) -> String: - ... - - @egraph.register - def _rules( - s: String, - y_expr: String, - z_expr: String, - old_statements: String, - x: Math, - i: i64, - y: Math, - z: Math, - old_gensym: i64, - ): - gensym_var = join("_", gensym().to_string()) - yield rule( - eq(x).to(Math.var(s)), - ).then( - set_(x.expr).to(s), - ) - - yield rule( - eq(x).to(Math(i)), - ).then( - set_(x.expr).to(i.to_string()), - ) - - yield rule( - eq(x).to(y + z), - eq(y_expr).to(y.expr), - eq(z_expr).to(z.expr), - eq(old_statements).to(statements()), - ).then( - set_(x.expr).to(gensym_var), - set_(statements()).to(join(old_statements, gensym_var, " = ", y_expr, " + ", z_expr, "\n")), - set_(gensym()).to(i64(1)), - ) - yield rule( - eq(x).to(y * z), - eq(y_expr).to(y.expr), - eq(z_expr).to(z.expr), - eq(old_statements).to(statements()), - ).then( - set_(x.expr).to(gensym_var), - set_(statements()).to(join(old_statements, gensym_var, " = ", y_expr, " * ", z_expr, "\n")), - set_(gensym()).to(i64(1)), - ) - - y = egraph.let("y", Math(2) * (Math.var("x") + Math(3))) - - egraph.run(10) - egraph.check(eq(y.expr).to(String("_1"))) - egraph.check(eq(statements()).to(String("_0 = x + 3\n_1 = 2 * _0\n"))) diff --git a/python/tests/test_program_gen.py b/python/tests/test_program_gen.py index 94567be8..b1dca680 100644 --- a/python/tests/test_program_gen.py +++ b/python/tests/test_program_gen.py @@ -1,60 +1,65 @@ # mypy: disable-error-code="empty-body" from __future__ import annotations +import math + from egglog import * from egglog.exp.program_gen import * +egraph = EGraph([program_gen_module]) -def test_to_string(snapshot_py) -> None: - egraph = EGraph([program_gen_module]) 
- - @egraph.class_ - class Math(Expr): - def __init__(self, value: i64Like) -> None: - ... - @classmethod - def var(cls, v: StringLike) -> Math: - ... +@egraph.class_ +class Math(Expr): + def __init__(self, value: i64Like) -> None: + ... - def __add__(self, other: Math) -> Math: - ... + @classmethod + def var(cls, v: StringLike) -> Math: + ... - def __mul__(self, other: Math) -> Math: - ... + def __add__(self, other: Math) -> Math: + ... - def __neg__(self) -> Math: - ... + def __mul__(self, other: Math) -> Math: + ... - @egraph.method(cost=1000) - @property - def program(self) -> Program: - ... + def __neg__(self) -> Math: + ... - @egraph.function - def assume_pos(x: Math) -> Math: + @egraph.method(cost=1000) + @property + def program(self) -> Program: ... - @egraph.register - def _rules( - s: String, - y_expr: String, - z_expr: String, - old_statements: String, - x: Math, - i: i64, - y: Math, - z: Math, - old_gensym: i64, - ): - yield rewrite(Math.var(s).program).to(Program(s)) - yield rewrite(Math(i).program).to(Program(i.to_string())) - yield rewrite((y + z).program).to((y.program + " + " + z.program).assign()) - yield rewrite((y * z).program).to((y.program + " * " + z.program).assign()) - yield rewrite((-y).program).to(Program("-") + y.program) - assigned_x = x.program.assign() - yield rewrite(assume_pos(x).program).to(assigned_x.statement(Program("assert ") + assigned_x + " > 0")) +@egraph.function +def assume_pos(x: Math) -> Math: + ... 
+ + +@egraph.register +def _rules( + s: String, + y_expr: String, + z_expr: String, + old_statements: String, + x: Math, + i: i64, + y: Math, + z: Math, + old_gensym: i64, +): + yield rewrite(Math.var(s).program).to(Program(s)) + yield rewrite(Math(i).program).to(Program(i.to_string())) + yield rewrite((y + z).program).to((y.program + " + " + z.program).assign()) + yield rewrite((y * z).program).to((y.program + " * " + z.program).assign()) + yield rewrite((-y).program).to(Program("-") + y.program) + assigned_x = x.program.assign() + yield rewrite(assume_pos(x).program).to(assigned_x.statement(Program("assert ") + assigned_x + " > 0")) + + +def test_to_string(snapshot_py) -> None: first = assume_pos(-Math.var("x")) + Math.var("y") fn = (first + Math(2) + first).program.function_two("my_fn", Math.var("x").program, Math.var("y").program) with egraph: @@ -69,3 +74,16 @@ def _rules( assert expr == "my_fn" # type: ignore stmts = egraph.load_object(egraph.extract(PyObject.from_string(fn.statements))) assert stmts == snapshot_py # type: ignore + + +def test_py_object(): + x = Math.var("x") + y = Math.var("y") + z = Math.var("z") + fn = (x + y + z).program.function_two("my_fn", x.program, y.program) + egraph.register(fn.compile()) + egraph.run(100) + egraph.register(fn.eval_py_object(egraph.save_object({"z": 10}))) + egraph.run(100) + res = egraph.load_object(egraph.extract(fn.py_object)) + assert res(1, 2) == 13 # type: ignore diff --git a/src/py_object_sort.rs b/src/py_object_sort.rs index b71a2cc7..7ec9ccd9 100644 --- a/src/py_object_sort.rs +++ b/src/py_object_sort.rs @@ -90,6 +90,15 @@ impl Sort for PyObjectSort { py_object: self.clone(), string: typeinfo.get_sort(), }); + typeinfo.add_primitive(Exec { + name: "py-exec".into(), + py_object: self.clone(), + string: typeinfo.get_sort(), + }); + typeinfo.add_primitive(Dict { + name: "py-dict".into(), + py_object: self.clone(), + }); typeinfo.add_primitive(DictUpdate { name: "py-dict-update".into(), py_object: 
self.clone(), @@ -202,6 +211,86 @@ impl PrimitiveLike for Eval { } } +/// Copies the locals, execs the Python string, then returns the copied version of the locals with any updates +/// (py-exec ) +struct Exec { + name: Symbol, + py_object: Arc, + string: Arc, +} + +impl PrimitiveLike for Exec { + fn name(&self) -> Symbol { + self.name + } + + fn accept(&self, types: &[ArcSort]) -> Option { + match types { + [str, locals, globals] + if str.name() == self.string.name() + && locals.name() == self.py_object.name() + && globals.name() == self.py_object.name() => + { + Some(self.py_object.clone()) + } + _ => None, + } + } + + fn apply(&self, values: &[Value]) -> Option { + let code: Symbol = Symbol::load(self.string.as_ref(), &values[0]); + let locals: PyObject = Python::with_gil(|py| { + let (_, globals) = self.py_object.load(&values[1]); + let globals = globals.downcast::(py).unwrap(); + let (_, locals) = self.py_object.load(&values[2]); + let locals = locals.downcast::(py).unwrap().copy().unwrap(); + py.run(code.into(), Some(globals), Some(locals)).unwrap(); + locals.into() + }); + Some(self.py_object.store(locals)) + } +} + +/// (py-dict [ ]*) +struct Dict { + name: Symbol, + py_object: Arc, +} + +impl PrimitiveLike for Dict { + fn name(&self) -> Symbol { + self.name + } + + fn accept(&self, types: &[ArcSort]) -> Option { + // Should have an even number of args + if types.len() % 2 != 0 { + return None; + } + for tp in types.iter() { + // All tps should be object + if tp.name() != self.py_object.name() { + return None; + } + } + Some(self.py_object.clone()) + } + + fn apply(&self, values: &[Value]) -> Option { + let dict: PyObject = Python::with_gil(|py| { + let dict = PyDict::new(py); + // Update the dict with the key-value pairs + for i in values.chunks_exact(2) { + let key = self.py_object.load(&i[0]).1; + let value = self.py_object.load(&i[1]).1; + dict.set_item(key, value).unwrap(); + } + dict.into() + }); + Some(self.py_object.store(dict)) + } +} + /// 
Supports calling (py-dict-update [ ]*) struct DictUpdate { name: Symbol, From 5dcae1e5fb18347eadd6242626f0faba70f2f5f2 Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Sun, 1 Oct 2023 13:47:40 -0400 Subject: [PATCH 11/23] Set default function name --- python/egglog/exp/program_gen.py | 4 ++-- python/tests/test_program_gen.py | 6 ++---- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/python/egglog/exp/program_gen.py b/python/egglog/exp/program_gen.py index d3a1ff3b..c55eee97 100644 --- a/python/egglog/exp/program_gen.py +++ b/python/egglog/exp/program_gen.py @@ -46,7 +46,7 @@ def assign(self) -> Program: """ ... - def function_two(self, name: StringLike, arg1: ProgramLike, arg2: ProgramLike) -> Program: + def function_two(self, arg1: ProgramLike, arg2: ProgramLike, name: StringLike = String("__fn")) -> Program: """ Returns a new program defining a function with two arguments. """ @@ -311,7 +311,7 @@ def _compile( # Function two # When compiling a function, the two args, p2 and p3, should get compiled when we compile p1, and should just be vars. - fn_two = p1.function_two(s1, p2, p3) + fn_two = p1.function_two(p2, p3, s1) # 1. Set parent of p1 yield rule(eq(p).to(fn_two), fn_two.compile(i)).then(set_(p1.parent).to(p)) # TODO: Compile vars? 
diff --git a/python/tests/test_program_gen.py b/python/tests/test_program_gen.py index b1dca680..d43eceb3 100644 --- a/python/tests/test_program_gen.py +++ b/python/tests/test_program_gen.py @@ -61,7 +61,7 @@ def _rules( def test_to_string(snapshot_py) -> None: first = assume_pos(-Math.var("x")) + Math.var("y") - fn = (first + Math(2) + first).program.function_two("my_fn", Math.var("x").program, Math.var("y").program) + fn = (first + Math(2) + first).program.function_two(Math.var("x").program, Math.var("y").program, "my_fn") with egraph: egraph.register(fn) egraph.run(200) @@ -80,9 +80,7 @@ def test_py_object(): x = Math.var("x") y = Math.var("y") z = Math.var("z") - fn = (x + y + z).program.function_two("my_fn", x.program, y.program) - egraph.register(fn.compile()) - egraph.run(100) + fn = (x + y + z).program.function_two(x.program, y.program) egraph.register(fn.eval_py_object(egraph.save_object({"z": 10}))) egraph.run(100) res = egraph.load_object(egraph.extract(fn.py_object)) From fa8309eda919e9a343ae3e011280f05a579fef87 Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Sun, 1 Oct 2023 13:53:20 -0400 Subject: [PATCH 12/23] Start moving array api into new form --- python/egglog/exp/array_api.py | 30 ------------------- .../test_array_api/test_sklearn_lda.py | 6 ++-- .../test_array_api/test_to_source.py | 27 ----------------- python/tests/test_array_api.py | 25 ++++++---------- 4 files changed, 12 insertions(+), 76 deletions(-) delete mode 100644 python/tests/__snapshots__/test_array_api/test_to_source.py diff --git a/python/egglog/exp/array_api.py b/python/egglog/exp/array_api.py index de8f111e..02b5716a 100644 --- a/python/egglog/exp/array_api.py +++ b/python/egglog/exp/array_api.py @@ -12,7 +12,6 @@ import numpy as np from egglog import * from egglog.bindings import EggSmolError -from egglog.egraph import Action from egglog.runtime import RuntimeExpr from .program_gen import * @@ -1684,32 +1683,3 @@ def _py_expr( yield rewrite(ndarray_program(x - 
y)).to((ndarray_program(x) + " - " + ndarray_program(y)).assign()) yield rewrite(ndarray_program(x * y)).to((ndarray_program(x) + " * " + ndarray_program(y)).assign()) yield rewrite(ndarray_program(x / y)).to((ndarray_program(x) + " / " + ndarray_program(y)).assign()) - - -@array_api_module_string.class_ -class FunctionExprTwo(Expr): - """ - Python expression that takes two NDArrays as arguments and returns an NDArray. - """ - - def __init__(self, name: StringLike, res: NDArray, arg_1: NDArray, arg_2: NDArray) -> None: - ... - - @property - def source(self) -> String: - ... - - -fn_ruleset = array_api_module_string.ruleset("fn") - - -@array_api_module_string.register -def _function_expr(name: String, res: NDArray, arg1: String, arg2: String, f: FunctionExprTwo, s: String): - yield rule( - eq(f).to(FunctionExprTwo(name, res, NDArray.var(arg1), NDArray.var(arg2))), - ruleset=fn_ruleset, - ).then( - set_(f.source).to( - join("def ", name, "(", arg1, ", ", arg2, "):\n", statements(), " return ", ndarray_expr(res), "\n") - ), - ) diff --git a/python/tests/__snapshots__/test_array_api/test_sklearn_lda.py b/python/tests/__snapshots__/test_array_api/test_sklearn_lda.py index 8263e212..26c6da6c 100644 --- a/python/tests/__snapshots__/test_array_api/test_sklearn_lda.py +++ b/python/tests/__snapshots__/test_array_api/test_sklearn_lda.py @@ -28,7 +28,7 @@ _NDArray_7 = std(_NDArray_6, OptionalIntOrTuple.int(Int(0))) _NDArray_7[ndarray_index(std(_NDArray_6, OptionalIntOrTuple.int(Int(0))) == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(Value.float(Float(1.0))) _TupleNDArray_1 = svd(sqrt(NDArray.scalar(Value.int(NDArray.scalar(Value.float(Float(1.0))).to_int() / Int(147)))) * (_NDArray_6 / _NDArray_7), FALSE) -_Slice_1 = Slice(OptionalInt.none, OptionalInt.some(sum(astype(_TupleNDArray_1[Int(1)] > NDArray.scalar(Value.float(Float(0.0001))), DType.int32)).to_int())) +_Slice_1 = Slice(OptionalInt.none, OptionalInt.some(astype(sum(_TupleNDArray_1[Int(1)] > 
NDArray.scalar(Value.float(Float(0.0001)))), DType.int32).to_int())) _NDArray_8 = (_TupleNDArray_1[Int(2)][IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(_Slice_1)) + _MultiAxisIndexKey_1)] / _NDArray_7).T / _TupleNDArray_1[ Int(1) ][IndexKey.slice(_Slice_1)] @@ -49,8 +49,8 @@ Slice( OptionalInt.none, OptionalInt.some( - sum( - astype(_TupleNDArray_2[Int(1)] > (NDArray.scalar(Value.float(Float(0.0001))) * _TupleNDArray_2[Int(1)][IndexKey.int(Int(0))]), DType.int32) + astype( + sum(_TupleNDArray_2[Int(1)] > (NDArray.scalar(Value.float(Float(0.0001))) * _TupleNDArray_2[Int(1)][IndexKey.int(Int(0))])), DType.int32 ).to_int() ), ) diff --git a/python/tests/__snapshots__/test_array_api/test_to_source.py b/python/tests/__snapshots__/test_array_api/test_to_source.py deleted file mode 100644 index a867c69f..00000000 --- a/python/tests/__snapshots__/test_array_api/test_to_source.py +++ /dev/null @@ -1,27 +0,0 @@ -def my_fn(X, y): - assert y.dtype == np.int64 - assert X.dtype == np.float64 - assert y.dtype == np.int64 - assert X.dtype == np.float64 - _0 = np.array(150.0) - assert y.shape == (150,) - assert y.shape == (150,) - assert X.shape == (150,) + (4,) - assert X.shape == (150,) + (4,) - assert y.shape == (150,) - assert y.shape == (150,) - assert set(y.flatten()) == set((0,) + (1,) + (2,)) - assert set(y.flatten()) == set((0,) + (1,) + (2,)) - _1 = y.reshape((-1,)) - _1 = y.reshape((-1,)) - _2 = np.unique(_1, return_counts=True) - _2 = np.unique(_1, return_counts=True) - _3 = np.unique(_1) - _4 = _2[1].astype(np.float64) - _5 = _4 / _0 - _6 = np.zeros((3,) + (4,), dtype=np.float64) - _6 = np.zeros((3,) + (4,), dtype=np.float64) - _7 = _5 + X - _8 = _7 + _6 - _8 = _7 + _6 - return _8 diff --git a/python/tests/test_array_api.py b/python/tests/test_array_api.py index 9cd25e55..74249274 100644 --- a/python/tests/test_array_api.py +++ b/python/tests/test_array_api.py @@ -28,6 +28,8 @@ def test_tuple_value_includes(): def test_to_source(snapshot_py): + 
import numpy + _NDArray_1 = NDArray.var("X") X_orig = copy(_NDArray_1) assume_dtype(_NDArray_1, DType.float64) @@ -48,25 +50,17 @@ def test_to_source(snapshot_py): OptionalDevice.some(_NDArray_1.device), ) res = _NDArray_4 + _NDArray_1 + _NDArray_5 - + fn = ndarray_program(res).function_two(ndarray_program(X_orig), ndarray_program(Y_orig)) egraph = EGraph([array_api_module_string]) - fn = egraph.let("fn", FunctionExprTwo("my_fn", res, X_orig, Y_orig)) - - egraph.run(20) - # while egraph.run((run())).updated: - # print(egraph.load_object(egraph.extract(PyObject.from_string(statements())))) - # egraph.graphviz().render(view=True) - # egraph.graphviz(n_inline_leaves=3).render("inlined", view=True) + egraph.register(fn.eval_py_object(egraph.save_object({"np": numpy}))) - egraph.run(run(fn_ruleset)) - fn_source = egraph.load_object(egraph.extract(PyObject.from_string(fn.source))) + egraph.run(100) + egraph.display(n_inline_leaves=1) + fn_source = egraph.load_object(egraph.extract(PyObject.from_string(fn.statements))) assert fn_source == snapshot_py - locals_: dict[str, object] = {} - exec(fn_source, {"np": np}, locals_) # type: ignore - fn: object = locals_["my_fn"] -@pytest.mark.xfail(reason="unstable output") +# @pytest.mark.xfail(raises=TODO) def test_sklearn_lda(snapshot_py): from sklearn import config_context from sklearn.discriminant_analysis import LinearDiscriminantAnalysis @@ -88,8 +82,7 @@ def test_sklearn_lda(snapshot_py): with EGraph([array_api_module]) as egraph: egraph.register(X_r2) - egraph.run((run() * 10)) - # egraph.run((run() * 10).saturate()) + egraph.run((run() * 10).saturate()) # egraph.graphviz(n_inline_leaves=3).render("3", view=True) res = egraph.extract(X_r2) From afec7b6f20d61a55e6c101744377ff7e25b321a5 Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Sun, 1 Oct 2023 13:54:35 -0400 Subject: [PATCH 13/23] Lint fixes --- pyproject.toml | 2 +- python/egglog/exp/program_gen.py | 1 - python/tests/test_array_api.py | 1 - 
python/tests/test_program_gen.py | 4 +--- 4 files changed, 2 insertions(+), 6 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index a4510256..28fbd61d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,7 +60,7 @@ check_untyped_defs = true strict_equality = true warn_unused_configs = true allow_redefinition = true -# enable_incomplete_feature = ["Unpack", "TypeVarTuple"] +enable_incomplete_feature = ["Unpack"] exclude = ["__snapshots__", "_build", "^conftest.py$"] [tool.maturin] diff --git a/python/egglog/exp/program_gen.py b/python/egglog/exp/program_gen.py index c55eee97..b2410b5c 100644 --- a/python/egglog/exp/program_gen.py +++ b/python/egglog/exp/program_gen.py @@ -4,7 +4,6 @@ """ from __future__ import annotations -from turtle import st from typing import Union from egglog import * diff --git a/python/tests/test_array_api.py b/python/tests/test_array_api.py index 74249274..90404054 100644 --- a/python/tests/test_array_api.py +++ b/python/tests/test_array_api.py @@ -1,4 +1,3 @@ -import pytest from egglog.exp.array_api import * diff --git a/python/tests/test_program_gen.py b/python/tests/test_program_gen.py index d43eceb3..e7cc001b 100644 --- a/python/tests/test_program_gen.py +++ b/python/tests/test_program_gen.py @@ -1,8 +1,6 @@ # mypy: disable-error-code="empty-body" from __future__ import annotations -import math - from egglog import * from egglog.exp.program_gen import * @@ -27,7 +25,7 @@ def __mul__(self, other: Math) -> Math: def __neg__(self) -> Math: ... - @egraph.method(cost=1000) + @egraph.method(cost=1000) # type: ignore @property def program(self) -> Program: ... 
From c3d1bb226f570c194785598ab6b00fb0866e2b59 Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Wed, 4 Oct 2023 07:05:57 -0400 Subject: [PATCH 14/23] Make getting started tutorial clearer --- docs/tutorials/getting-started.ipynb | 1771 ++++++++++++++++++-------- 1 file changed, 1232 insertions(+), 539 deletions(-) diff --git a/docs/tutorials/getting-started.ipynb b/docs/tutorials/getting-started.ipynb index 9931fe86..95396b89 100644 --- a/docs/tutorials/getting-started.ipynb +++ b/docs/tutorials/getting-started.ipynb @@ -1,562 +1,1255 @@ { - "cells": [ - { - "attachments": {}, - "cell_type": "markdown", - "id": "ffabb623", - "metadata": { - "tags": [] - }, - "source": [ - "# Getting Started - Matrix Multiplication\n", - "\n", - "In this tutorial, you will learn how to:\n", - "\n", - "1. Install `egglog` Python\n", - "2. Create a representation for matrices and some simplification rules for them. This will be based off of the [matrix multiplication example](https://github.com/egraphs-good/egglog/blob/08a6e8f/tests/matrix.egg) in the egglog repository. By using our high level wrapper, we can rely on Python's built in static type checker to check the correctness of your representation.\n", - "3. Try out using our library in an interactive notebook.\n", - "\n", - "## Install egglog Python\n", - "\n", - "First, you will need to have a working Python interpreter. In this tutorial, we will [use `miniconda`](https://docs.conda.io/en/latest/miniconda.html) to create a new Python environment and activate it:\n", - "\n", - "```bash\n", - "$ brew install miniconda\n", - "$ conda create -n egglog-python python=3.11\n", - "$ conda activate egglog-python\n", - "```\n", - "\n", - "Then we want to install `egglog` Python. `egglog` Python can run on any recent Python version, and is tested on 3.8 - 3.11. 
To install it, run:\n", - "\n", - "```bash\n", - "$ pip install egglog\n", - "```\n", - "\n", - "To test you have installed it correctly, run:\n", - "\n", - "```bash\n", - "$ python -m 'import egglog'\n", - "```\n", - "\n", - "We also want to install `mypy` for static type checking. This is not required, but it will help us write correct representations. To install it, run:\n", - "\n", - "```bash\n", - "$ pip install mypy\n", - "```\n", - "\n", - "## Creating an E-Graph\n", - "\n", - "In this tutorial, we will use [VS Code](https://code.visualstudio.com/) to create file, `matrix.py`, to include our egraph\n", - "and the simplification rules:\n" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "7369b71b", - "metadata": {}, - "outputs": [], - "source": [ - "from __future__ import annotations\n", - "\n", - "from egglog import *\n", - "\n", - "egraph = EGraph()" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "814a51c5", - "metadata": {}, - "source": [ - "## Defining Dimensions\n", - "\n", - "We will start by defining a representation for integers, which we will use to represent\n", - "the dimensions of the matrix:\n" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "04fa991a", - "metadata": {}, - "outputs": [], - "source": [ - "@egraph.class_\n", - "class Dim(Expr):\n", - " \"\"\"\n", - " A dimension of a matix.\n", - "\n", - " >>> Dim(3) * Dim.named(\"n\")\n", - " Dim(3) * Dim.named(\"n\")\n", - " \"\"\"\n", - "\n", - " def __init__(self, value: i64Like) -> None:\n", - " ...\n", - "\n", - " @classmethod\n", - " def named(cls, name: StringLike) -> Dim:\n", - " ...\n", - "\n", - " def __mul__(self, other: Dim) -> Dim:\n", - " ..." 
- ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "f5098a2b", - "metadata": { - "tags": [] - }, - "source": [ - "As you can see, you must wrap any class with the `egraph.class_` to register\n", - "it with the egraph and be able to use it like a Python class.\n", - "\n", - "### Testing in a notebook\n", - "\n", - "We can try out this by [creating a new notebook](https://code.visualstudio.com/docs/datascience/jupyter-notebooks#_create-or-open-a-jupyter-notebook) which imports this file:\n", - "\n", - "```python\n", - "from matrix import *\n", - "```\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "fd43c7ef", - "metadata": {}, - "source": [ - "We can then create a new `Dim` object:\n" - ] - }, + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "ffabb623", + "metadata": { + "tags": [] + }, + "source": [ + "# Getting Started - Matrix Multiplication\n", + "\n", + "In this tutorial, you will learn how to:\n", + "\n", + "1. Install `egglog` Python\n", + "2. Create a representation for matrices and some simplification rules for them. This will be based off of the [matrix multiplication example](https://github.com/egraphs-good/egglog/blob/08a6e8f/tests/matrix.egg) in the egglog repository. By using our high level wrapper, we can rely on Python's built in static type checker to check the correctness of your representation.\n", + "3. Try out using our library in an interactive notebook.\n", + "\n", + "## Install egglog Python\n", + "\n", + "First, you will need to have a working Python interpreter. In this tutorial, we will [use `miniconda`](https://docs.conda.io/en/latest/miniconda.html) to create a new Python environment and activate it:\n", + "\n", + "```bash\n", + "$ brew install miniconda\n", + "$ conda create -n egglog-python python=3.11\n", + "$ conda activate egglog-python\n", + "```\n", + "\n", + "Then we want to install `egglog` Python. 
`egglog` Python can run on any recent Python version, and is tested on 3.8 - 3.11. To install it, run:\n", + "\n", + "```bash\n", + "$ pip install egglog\n", + "```\n", + "\n", + "To test you have installed it correctly, run:\n", + "\n", + "```bash\n", + "$ python -c 'import egglog'\n", + "```\n", + "\n", + "We also want to install `mypy` for static type checking. This is not required, but it will help us write correct representations. To install it, run:\n", + "\n", + "```bash\n", + "$ pip install mypy\n", + "```\n", + "\n", + "## Creating an E-Graph\n", + "\n", + "In this tutorial, we will use [VS Code](https://code.visualstudio.com/) to create a file, `matrix.py`, to include our egraph\n", + "and the simplification rules:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "7369b71b", + "metadata": {}, + "outputs": [], + "source": [ + "from __future__ import annotations\n", + "\n", + "from egglog import *\n", + "\n", + "egraph = EGraph()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "814a51c5", + "metadata": {}, + "source": [ + "## Defining Dimensions\n", + "\n", + "We will start by defining a representation for integers, which we will use to represent\n", + "the dimensions of the matrix:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "04fa991a", + "metadata": {}, + "outputs": [], + "source": [ + "@egraph.class_\n", + "class Dim(Expr):\n", + " \"\"\"\n", + " A dimension of a matrix.\n", + "\n", + " >>> Dim(3) * Dim.named(\"n\")\n", + " Dim(3) * Dim.named(\"n\")\n", + " \"\"\"\n", + "\n", + " def __init__(self, value: i64Like) -> None:\n", + " ...\n", + "\n", + " @classmethod\n", + " def named(cls, name: StringLike) -> Dim:\n", + " ...\n", + "\n", + " def __mul__(self, other: Dim) -> Dim:\n", + " ..."
+ ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "f5098a2b", + "metadata": { + "tags": [] + }, + "source": [ + "As you can see, you must wrap any class with the `egraph.class_` decorator to register\n", + "it with the egraph and be able to use it like a Python class.\n", + "\n", + "### Testing in a notebook\n", + "\n", + "We can try this out by [creating a new notebook](https://code.visualstudio.com/docs/datascience/jupyter-notebooks#_create-or-open-a-jupyter-notebook) which imports this file:\n", + "\n", + "```python\n", + "from matrix import *\n", + "```\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "fd43c7ef", + "metadata": {}, + "source": [ + "We can then create a new `Dim` object:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "b6424530", + "metadata": {}, + "outputs": [ { - "cell_type": "code", - "execution_count": 8, - "id": "b6424530", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(Dim.named(\"x\") * Dim(10)) * Dim(10)" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } + "data": { + "text/html": [ + "
(Dim.named("x") * Dim(10)) * Dim(10)\n",
+       "
\n" ], - "source": [ - "x = Dim.named(\"x\")\n", - "ten = Dim(10)\n", - "res = x * ten * ten\n", - "res" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "ef5ebb16", - "metadata": {}, - "source": [ - "We see that the output is not evaluated, it's just a representation of the computation as well as the type. This is because we haven't defined any simplification rules yet.\n", - "\n", - "We can also try to create a dimension from an invalid type, or use it in an invalid way, we get a type error before we even run the code:\n", - "\n", - "```python\n", - "x - ten\n", - "```\n", - "\n", - "![Screenshot of VS Code showing a type error](./screenshot-1.png)\n", - "\n", - "## Dimension Replacements\n", - "\n", - "Now we will register some replacements for our dimensions and see how we can interface with egg to get it\n", - "to execute them.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "b06b1749", - "metadata": {}, - "outputs": [], - "source": [ - "a, b, c = vars_(\"a b c\", Dim)\n", - "i, j = vars_(\"i j\", i64)\n", - "egraph.register(\n", - " rewrite(a * (b * c)).to((a * b) * c),\n", - " rewrite((a * b) * c).to(a * (b * c)),\n", - " rewrite(Dim(i) * Dim(j)).to(Dim(i * j)),\n", - " rewrite(a * b).to(b * a),\n", - ")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "167722d1-60b8-452a-ae54-6a8df4db5b00", - "metadata": {}, - "source": [ - "You might notice that unlike a traditional term rewriting system, we don't specify any order for these rewrites. They will be executed until the graph is fully saturated, meaning that no new terms are created.\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "a4d2c911", - "metadata": {}, - "source": [ - "We can also see how the type checking can help us. 
If we try to create a rewrite from a `Dim` to an `i64` we see that we get a type error:\n", - "\n", - "![Screenshot of VS Code showing a type error](./screenshot-2.png)\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "76dc1672-dba6-44ab-b9f1-aa01de685fb1", - "metadata": {}, - "source": [ - "### Testing\n", - "\n", - "Going back to the notebook, we can test out the that the rewrites are working.\n", - "We can run some number of iterations and extract out the lowest cost expression which is equivalent to our variable:\n" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "31afa12e-da68-4398-91fa-14523f6c099a", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Dim.named(\"x\") * Dim(100)" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } + "text/latex": [ + "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", + "\\PY{p}{(}\\PY{n}{Dim}\\PY{o}{.}\\PY{n}{named}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{x}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)} \\PY{o}{*} \\PY{n}{Dim}\\PY{p}{(}\\PY{l+m+mi}{10}\\PY{p}{)}\\PY{p}{)} \\PY{o}{*} \\PY{n}{Dim}\\PY{p}{(}\\PY{l+m+mi}{10}\\PY{p}{)}\n", + "\\end{Verbatim}\n" ], - "source": [ - "egraph.simplify(res, 10)" + "text/plain": [ + "(Dim.named(\"x\") * Dim(10)) * Dim(10)" ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "7e44104c-d87b-441d-a717-92d42aab9d37", - "metadata": {}, - "source": [ - "## Matrix Expressions\n", - "\n", - "Now that we have defined dimensions, we can define matrices as well as some functions on them:\n" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "c5b96cfb", - "metadata": {}, - "outputs": [], - "source": [ - "@egraph.class_\n", - "class Matrix(Expr):\n", - " @classmethod\n", - " def identity(cls, dim: Dim) -> Matrix:\n", - " \"\"\"\n", - " Create an identity matrix of the given dimension.\n", - " \"\"\"\n", - " ...\n", - "\n", - " @classmethod\n", - " def named(cls, name: 
StringLike) -> Matrix:\n", - " \"\"\"\n", - " Create a named matrix.\n", - " \"\"\"\n", - " ...\n", - "\n", - " def __matmul__(self, other: Matrix) -> Matrix:\n", - " \"\"\"\n", - " Matrix multiplication.\n", - " \"\"\"\n", - " ...\n", - "\n", - " def nrows(self) -> Dim:\n", - " \"\"\"\n", - " Number of rows in the matrix.\n", - " \"\"\"\n", - " ...\n", - "\n", - " def ncols(self) -> Dim:\n", - " \"\"\"\n", - " Number of columns in the matrix.\n", - " \"\"\"\n", - " ...\n", - "\n", - "\n", - "@egraph.function\n", - "def kron(a: Matrix, b: Matrix) -> Matrix:\n", - " \"\"\"\n", - " Kronecker product of two matrices.\n", - "\n", - " https://en.wikipedia.org/wiki/Kronecker_product#Definition\n", - " \"\"\"\n", - " ..." - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "be8e6526", - "metadata": {}, - "source": [ - "### Rows/cols Replacements\n", - "\n", - "We can also define some replacements to understand the number of rows and columns of a matrix:\n" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "cb2b4fb8", - "metadata": {}, - "outputs": [], - "source": [ - "A, B, C, D = vars_(\"A B C D\", Matrix)\n", - "egraph.register(\n", - " # The dimensions of a kronecker product are the product of the dimensions\n", - " rewrite(kron(A, B).nrows()).to(A.nrows() * B.nrows()),\n", - " rewrite(kron(A, B).ncols()).to(A.ncols() * B.ncols()),\n", - " # The dimensions of a matrix multiplication are the number of rows of the first\n", - " # matrix and the number of columns of the second matrix.\n", - " rewrite((A @ B).nrows()).to(A.nrows()),\n", - " rewrite((A @ B).ncols()).to(B.ncols()),\n", - " # The dimensions of an identity matrix are the input dimension\n", - " rewrite(Matrix.identity(a).nrows()).to(a),\n", - " rewrite(Matrix.identity(a).ncols()).to(a),\n", - ")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "13b969e8", - "metadata": {}, - "source": [ - "We can try these out in our notebook (after restarting and 
re-importing) to compute the dimensions after some operations:\n" - ] - }, + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "x = Dim.named(\"x\")\n", + "ten = Dim(10)\n", + "res = x * ten * ten\n", + "res" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "ef5ebb16", + "metadata": {}, + "source": [ + "We see that the output is not evaluated; it's just a representation of the computation as well as the type. This is because we haven't defined any simplification rules yet.\n", + "\n", + "If we try to create a dimension from an invalid type, or use it in an invalid way, we get a type error before we even run the code:\n", + "\n", + "```python\n", + "x - ten\n", + "```\n", + "\n", + "![Screenshot of VS Code showing a type error](./screenshot-1.png)\n", + "\n", + "## Dimension Replacements\n", + "\n", + "Now we will register some replacements for our dimensions and see how we can interface with egg to get it\n", + "to execute them.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "b06b1749", + "metadata": {}, + "outputs": [], + "source": [ + "a, b, c = vars_(\"a b c\", Dim)\n", + "i, j = vars_(\"i j\", i64)\n", + "egraph.register(\n", + " rewrite(a * (b * c)).to((a * b) * c),\n", + " rewrite((a * b) * c).to(a * (b * c)),\n", + " rewrite(Dim(i) * Dim(j)).to(Dim(i * j)),\n", + " rewrite(a * b).to(b * a),\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "167722d1-60b8-452a-ae54-6a8df4db5b00", + "metadata": {}, + "source": [ + "You might notice that unlike a traditional term rewriting system, we don't specify any order for these rewrites. They will be executed until the graph is fully saturated, meaning that no new terms are created.\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "a4d2c911", + "metadata": {}, + "source": [ + "We can also see how the type checking can help us. 
If we try to create a rewrite from a `Dim` to an `i64` we see that we get a type error:\n", + "\n", + "![Screenshot of VS Code showing a type error](./screenshot-2.png)\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "76dc1672-dba6-44ab-b9f1-aa01de685fb1", + "metadata": {}, + "source": [ + "### Testing\n", + "\n", + "Going back to the notebook, we can test out that the rewrites are working.\n", + "We can run some number of iterations and extract out the lowest cost expression which is equivalent to our variable:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "31afa12e-da68-4398-91fa-14523f6c099a", + "metadata": { + "tags": [] + }, + "outputs": [ { - "cell_type": "code", - "execution_count": 13, - "id": "8d18be2d", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Dim.named(\"y\")\n", - "Dim.named(\"x\")\n" - ] - } + "data": { + "text/html": [ + "
Dim.named("x") * Dim(100)\n",
+       "
\n" ], - "source": [ - "# If we multiply two identity matrices, we should be able to get the number of columns of the result\n", - "x = Matrix.identity(Dim.named(\"x\"))\n", - "y = Matrix.identity(Dim.named(\"y\"))\n", - "x_mult_y = x @ y\n", - "print(egraph.simplify(x_mult_y.ncols(), 10))\n", - "print(egraph.simplify(x_mult_y.nrows(), 10))" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "2f2c68c3", - "metadata": {}, - "source": [ - "### Operation replacements\n", - "\n", - "We can also define some replacements for matrix operations:\n" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "18a91684", - "metadata": {}, - "outputs": [], - "source": [ - "egraph.register(\n", - " # Multiplication by an identity matrix is the same as the other matrix\n", - " rewrite(A @ Matrix.identity(a)).to(A),\n", - " rewrite(Matrix.identity(a) @ A).to(A),\n", - " # Matrix multiplication is associative\n", - " rewrite((A @ B) @ C).to(A @ (B @ C)),\n", - " rewrite(A @ (B @ C)).to((A @ B) @ C),\n", - " # Kronecker product is associative\n", - " rewrite(kron(A, kron(B, C))).to(kron(kron(A, B), C)),\n", - " rewrite(kron(kron(A, B), C)).to(kron(A, kron(B, C))),\n", - " # Kronecker product distributes over matrix multiplication\n", - " rewrite(kron(A @ C, B @ D)).to(kron(A, B) @ kron(C, D)),\n", - " rewrite(kron(A, B) @ kron(C, D)).to(\n", - " kron(A @ C, B @ D),\n", - " # Only when the dimensions match\n", - " eq(A.ncols()).to(C.nrows()),\n", - " eq(B.ncols()).to(D.nrows()),\n", - " ),\n", - ")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "1cd649dc", - "metadata": {}, - "source": [ - "In our previous tests, we had to add the `ncols` and `nrows` operations to the e-graph seperately in order to have them be simplified. 
We can write some \"demand\" rules which automatically add these operations to the e-graph when they are needed:\n" + "text/latex": [ + "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", + "\\PY{n}{Dim}\\PY{o}{.}\\PY{n}{named}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{x}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)} \\PY{o}{*} \\PY{n}{Dim}\\PY{p}{(}\\PY{l+m+mi}{100}\\PY{p}{)}\n", + "\\end{Verbatim}\n" + ], + "text/plain": [ + "Dim.named(\"x\") * Dim(100)" ] - }, + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "egraph.simplify(res, 10)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "7e44104c-d87b-441d-a717-92d42aab9d37", + "metadata": {}, + "source": [ + "## Matrix Expressions\n", + "\n", + "Now that we have defined dimensions, we can define matrices as well as some functions on them:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "c5b96cfb", + "metadata": {}, + "outputs": [], + "source": [ + "@egraph.class_\n", + "class Matrix(Expr):\n", + " @classmethod\n", + " def identity(cls, dim: Dim) -> Matrix:\n", + " \"\"\"\n", + " Create an identity matrix of the given dimension.\n", + " \"\"\"\n", + " ...\n", + "\n", + " @classmethod\n", + " def named(cls, name: StringLike) -> Matrix:\n", + " \"\"\"\n", + " Create a named matrix.\n", + " \"\"\"\n", + " ...\n", + "\n", + " def __matmul__(self, other: Matrix) -> Matrix:\n", + " \"\"\"\n", + " Matrix multiplication.\n", + " \"\"\"\n", + " ...\n", + "\n", + " def nrows(self) -> Dim:\n", + " \"\"\"\n", + " Number of rows in the matrix.\n", + " \"\"\"\n", + " ...\n", + "\n", + " def ncols(self) -> Dim:\n", + " \"\"\"\n", + " Number of columns in the matrix.\n", + " \"\"\"\n", + " ...\n", + "\n", + "\n", + "@egraph.function\n", + "def kron(a: Matrix, b: Matrix) -> Matrix:\n", + " \"\"\"\n", + " Kronecker product of two matrices.\n", + "\n", + " https://en.wikipedia.org/wiki/Kronecker_product#Definition\n", + " \"\"\"\n", + " ..." 
+ ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "be8e6526", + "metadata": {}, + "source": [ + "### Rows/cols Replacements\n", + "\n", + "We can also define some replacements to understand the number of rows and columns of a matrix:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "cb2b4fb8", + "metadata": {}, + "outputs": [], + "source": [ + "A, B, C, D = vars_(\"A B C D\", Matrix)\n", + "egraph.register(\n", + " # The dimensions of a kronecker product are the product of the dimensions\n", + " rewrite(kron(A, B).nrows()).to(A.nrows() * B.nrows()),\n", + " rewrite(kron(A, B).ncols()).to(A.ncols() * B.ncols()),\n", + " # The dimensions of a matrix multiplication are the number of rows of the first\n", + " # matrix and the number of columns of the second matrix.\n", + " rewrite((A @ B).nrows()).to(A.nrows()),\n", + " rewrite((A @ B).ncols()).to(B.ncols()),\n", + " # The dimensions of an identity matrix are the input dimension\n", + " rewrite(Matrix.identity(a).nrows()).to(a),\n", + " rewrite(Matrix.identity(a).ncols()).to(a),\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "13b969e8", + "metadata": {}, + "source": [ + "We can try these out in our notebook (after restarting and re-importing) to compute the dimensions after some operations:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "8d18be2d", + "metadata": {}, + "outputs": [ { - "cell_type": "code", - "execution_count": 15, - "id": "303ce7f3", - "metadata": {}, - "outputs": [], - "source": [ - "egraph.register(\n", - " # demand rows and columns when we multiply matrices\n", - " rule(eq(C).to(A @ B)).then(\n", - " let(\"1\", A.ncols()),\n", - " let(\"2\", A.nrows()),\n", - " let(\"3\", B.nrows()),\n", - " let(\"4\", B.ncols()),\n", - " ),\n", - " # demand rows and columns when we take the kronecker product\n", - " rule(eq(C).to(kron(A, B))).then(\n", - " let(\"1\", A.ncols()),\n", - " let(\"2\", A.nrows()),\n", - " 
let(\"3\", B.nrows()),\n", - " let(\"4\", B.ncols()),\n", - " ),\n", - ")" - ] - }, + "name": "stdout", + "output_type": "stream", + "text": [ + "Dim.named(\"y\")\n", + "Dim.named(\"x\")\n" + ] + } + ], + "source": [ + "# If we multiply two identity matrices, we should be able to get the number of columns of the result\n", + "x = Matrix.identity(Dim.named(\"x\"))\n", + "y = Matrix.identity(Dim.named(\"y\"))\n", + "x_mult_y = x @ y\n", + "print(egraph.simplify(x_mult_y.ncols(), 10))\n", + "print(egraph.simplify(x_mult_y.nrows(), 10))" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "2f2c68c3", + "metadata": {}, + "source": [ + "### Operation replacements\n", + "\n", + "We can also define some replacements for matrix operations:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "18a91684", + "metadata": {}, + "outputs": [], + "source": [ + "egraph.register(\n", + " # Multiplication by an identity matrix is the same as the other matrix\n", + " rewrite(A @ Matrix.identity(a)).to(A),\n", + " rewrite(Matrix.identity(a) @ A).to(A),\n", + " # Matrix multiplication is associative\n", + " rewrite((A @ B) @ C).to(A @ (B @ C)),\n", + " rewrite(A @ (B @ C)).to((A @ B) @ C),\n", + " # Kronecker product is associative\n", + " rewrite(kron(A, kron(B, C))).to(kron(kron(A, B), C)),\n", + " rewrite(kron(kron(A, B), C)).to(kron(A, kron(B, C))),\n", + " # Kronecker product distributes over matrix multiplication\n", + " rewrite(kron(A @ C, B @ D)).to(kron(A, B) @ kron(C, D)),\n", + " rewrite(kron(A, B) @ kron(C, D)).to(\n", + " kron(A @ C, B @ D),\n", + " # Only when the dimensions match\n", + " eq(A.ncols()).to(C.nrows()),\n", + " eq(B.ncols()).to(D.nrows()),\n", + " ),\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "1cd649dc", + "metadata": {}, + "source": [ + "In our previous tests, we had to add the `ncols` and `nrows` operations to the e-graph separately in order to have them be simplified. 
We can write some \"demand\" rules which automatically add these operations to the e-graph when they are needed:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "303ce7f3", + "metadata": {}, + "outputs": [], + "source": [ + "egraph.register(\n", + " # demand rows and columns when we multiply matrices\n", + " rule(A @ B).then(\n", + " A.ncols(),\n", + " A.nrows(),\n", + " B.nrows(),\n", + " B.ncols(),\n", + " ),\n", + " # demand rows and columns when we take the kronecker product\n", + " rule(kron(A, B)).then(\n", + " A.ncols(),\n", + " A.nrows(),\n", + " B.nrows(),\n", + " B.ncols(),\n", + " ),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "334a2cc4-0004-415a-a8fb-4c5ef2e26aec", + "metadata": {}, + "source": [ + "For example, if we have `X @ Y` in the egraph, it will add expression for the columns of each as well:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "c79a9105-e8fe-4545-b7c6-262648f82aad", + "metadata": {}, + "outputs": [ { - "attachments": {}, - "cell_type": "markdown", - "id": "bd9e94de", - "metadata": {}, - "source": [ - "We can try this out in our notebook, by multiplying some matrices and checking their dimensions:\n" + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "outer_cluster_1\n", + "\n", + "\n", + "cluster_1\n", + "\n", + "\n", + "\n", + "outer_cluster_0\n", + "\n", + "\n", + "cluster_0\n", + "\n", + "\n", + "\n", + "outer_cluster_2\n", + "\n", + "\n", + "cluster_2\n", + "\n", + "\n", + "\n", + "outer_cluster_String-1316606400713378063\n", + "\n", + "\n", + "cluster_String-1316606400713378063\n", + "\n", + "\n", + "\n", + "outer_cluster_String-4801791173778264996\n", + "\n", + "\n", + "cluster_String-4801791173778264996\n", + "\n", + "\n", + "\n", + "\n", + "Matrix_named-1316606400713378063:s->String-1316606400713378063\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Matrix_named-4801791173778264996:s->String-4801791173778264996\n", + "\n", + "\n", + "\n", + "\n", + 
"\n", + "Matrix___matmul__-5871781006564002453:s->Matrix_named-1316606400713378063\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Matrix___matmul__-5871781006564002453:s->Matrix_named-4801791173778264996\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Matrix_named-1316606400713378063\n", + "\n", + "\n", + "Matrix_named\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "String-1316606400713378063\n", + "\n", + "\n", + ""X"\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Matrix_named-4801791173778264996\n", + "\n", + "\n", + "Matrix_named\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "String-4801791173778264996\n", + "\n", + "\n", + ""Y"\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Matrix___matmul__-5871781006564002453\n", + "\n", + "\n", + "Matrix___matmul__\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "" + ], + "text/plain": [ + "" ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "code", - "execution_count": 16, - "id": "bb50ade6", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "kron(Matrix.named(\"A\"), Matrix.named(\"B\"))" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "outer_cluster_String-4801791173778264996\n", + "\n", + "\n", + "cluster_String-4801791173778264996\n", + "\n", + "\n", + "\n", + "outer_cluster_String-1316606400713378063\n", + "\n", + "\n", + "cluster_String-1316606400713378063\n", + "\n", + "\n", + "\n", + "outer_cluster_2\n", + "\n", + "\n", + "cluster_2\n", + "\n", + "\n", + "\n", + "outer_cluster_0\n", + "\n", + "\n", + "cluster_0\n", + "\n", + "\n", + "\n", + "outer_cluster_1\n", + "\n", + "\n", + "cluster_1\n", + "\n", + "\n", + "\n", + "outer_cluster_4\n", + "\n", + "\n", + "cluster_4\n", + "\n", + "\n", + "\n", + "outer_cluster_3\n", + "\n", + "\n", + "cluster_3\n", + "\n", + "\n", + "\n", + "outer_cluster_5\n", + "\n", + "\n", + "cluster_5\n", + "\n", 
+ "\n", + "\n", + "outer_cluster_6\n", + "\n", + "\n", + "cluster_6\n", + "\n", + "\n", + "\n", + "\n", + "Matrix___matmul__-5871781006564002453:s->Matrix_named-1316606400713378063\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Matrix___matmul__-5871781006564002453:s->Matrix_named-4801791173778264996\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Matrix_named-1316606400713378063:s->String-1316606400713378063\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Matrix_named-4801791173778264996:s->String-4801791173778264996\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Matrix_nrows-0:s->Matrix_named-1316606400713378063\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Matrix_ncols-0:s->Matrix_named-1316606400713378063\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Matrix_nrows-5871781006564002453:s->Matrix_named-4801791173778264996\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Matrix_ncols-5871781006564002453:s->Matrix_named-4801791173778264996\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Matrix___matmul__-5871781006564002453\n", + "\n", + "\n", + "Matrix___matmul__\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Matrix_named-1316606400713378063\n", + "\n", + "\n", + "Matrix_named\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Matrix_named-4801791173778264996\n", + "\n", + "\n", + "Matrix_named\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "String-1316606400713378063\n", + "\n", + "\n", + ""X"\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "String-4801791173778264996\n", + "\n", + "\n", + ""Y"\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Matrix_nrows-0\n", + "\n", + "\n", + "Matrix_nrows\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Matrix_ncols-0\n", + "\n", + "\n", + "Matrix_ncols\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Matrix_nrows-5871781006564002453\n", + "\n", + "\n", + "Matrix_nrows\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Matrix_ncols-5871781006564002453\n", + "\n", + "\n", + 
"Matrix_ncols\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "" ], - "source": [ - "# Define a number of dimensions\n", - "n, m, p = Dim.named(\"n\"), Dim.named(\"m\"), Dim.named(\"p\")\n", - "# Define a number of matrices\n", - "A, B, C = Matrix.named(\"A\"), Matrix.named(\"B\"), Matrix.named(\"C\")\n", - "# Set each to be a square matrix of the given dimension\n", - "egraph.register(\n", - " union(A.nrows()).with_(n),\n", - " union(A.ncols()).with_(n),\n", - " union(B.nrows()).with_(m),\n", - " union(B.ncols()).with_(m),\n", - " union(C.nrows()).with_(p),\n", - " union(C.ncols()).with_(p),\n", - ")\n", - "# Create an example which should equal the kronecker product of A and B\n", - "ex1 = kron(Matrix.identity(n), B) @ kron(A, Matrix.identity(m))\n", - "egraph.simplify(ex1, 20)" + "text/plain": [ + "" ] - }, + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "with egraph:\n", + " egraph.register(Matrix.named(\"X\") @ Matrix.named(\"Y\"))\n", + " egraph.display()\n", + " egraph.run(1)\n", + " egraph.display()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "bd9e94de", + "metadata": {}, + "source": [ + "We can try this out in our notebook, by multiplying some matrices and checking their dimensions:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "bb50ade6", + "metadata": {}, + "outputs": [ { - "attachments": {}, - "cell_type": "markdown", - "id": "554321e2", - "metadata": {}, - "source": [ - "We can make sure that if the rows/columns do not line up, then the transformation will not be applied:\n" + "data": { + "text/html": [ + "
kron(Matrix.named("A"), Matrix.named("B"))\n",
+       "
\n" + ], + "text/latex": [ + "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", + "\\PY{n}{kron}\\PY{p}{(}\\PY{n}{Matrix}\\PY{o}{.}\\PY{n}{named}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{A}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{Matrix}\\PY{o}{.}\\PY{n}{named}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{B}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{)}\n", + "\\end{Verbatim}\n" + ], + "text/plain": [ + "kron(Matrix.named(\"A\"), Matrix.named(\"B\"))" ] - }, + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Define a number of dimensions\n", + "n, m, p = Dim.named(\"n\"), Dim.named(\"m\"), Dim.named(\"p\")\n", + "# Define a number of matrices\n", + "A, B, C = Matrix.named(\"A\"), Matrix.named(\"B\"), Matrix.named(\"C\")\n", + "# Set each to be a square matrix of the given dimension\n", + "egraph.register(\n", + " union(A.nrows()).with_(n),\n", + " union(A.ncols()).with_(n),\n", + " union(B.nrows()).with_(m),\n", + " union(B.ncols()).with_(m),\n", + " union(C.nrows()).with_(p),\n", + " union(C.ncols()).with_(p),\n", + ")\n", + "# Create an example which should equal the kronecker product of A and B\n", + "ex1 = kron(Matrix.identity(n), B) @ kron(A, Matrix.identity(m))\n", + "egraph.simplify(ex1, 20)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "554321e2", + "metadata": {}, + "source": [ + "We can make sure that if the rows/columns do not line up, then the transformation will not be applied:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "d8dea199", + "metadata": {}, + "outputs": [ { - "cell_type": "code", - "execution_count": 17, - "id": "d8dea199", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "kron(Matrix.identity(Dim.named(\"p\")), Matrix.named(\"C\")) @ kron(Matrix.named(\"A\"), Matrix.identity(Dim.named(\"m\")))" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } + "data": { + "text/html": [ + "
kron(Matrix.identity(Dim.named("p")), Matrix.named("C")) @ kron(Matrix.named("A"), Matrix.identity(Dim.named("m")))\n",
+       "
\n" ], - "source": [ - "ex2 = kron(Matrix.identity(p), C) @ kron(A, Matrix.identity(m))\n", - "egraph.simplify(ex2, 20)" + "text/latex": [ + "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", + "\\PY{n}{kron}\\PY{p}{(}\\PY{n}{Matrix}\\PY{o}{.}\\PY{n}{identity}\\PY{p}{(}\\PY{n}{Dim}\\PY{o}{.}\\PY{n}{named}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{p}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{n}{Matrix}\\PY{o}{.}\\PY{n}{named}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{C}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{)} \\PY{o}{@} \\PY{n}{kron}\\PY{p}{(}\\PY{n}{Matrix}\\PY{o}{.}\\PY{n}{named}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{A}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{Matrix}\\PY{o}{.}\\PY{n}{identity}\\PY{p}{(}\\PY{n}{Dim}\\PY{o}{.}\\PY{n}{named}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{m}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n", + "\\end{Verbatim}\n" + ], + "text/plain": [ + "kron(Matrix.identity(Dim.named(\"p\")), Matrix.named(\"C\")) @ kron(Matrix.named(\"A\"), Matrix.identity(Dim.named(\"m\")))" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b0b13665", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "file_format": "mystnb", - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" + }, + "metadata": {}, + "output_type": "display_data" } + ], + "source": [ + "ex2 = kron(Matrix.identity(p), C) @ kron(A, Matrix.identity(m))\n", + "egraph.simplify(ex2, 20)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b0b13665", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "file_format": "mystnb", + "kernelspec": { + 
"display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" }, - "nbformat": 4, - "nbformat_minor": 5 + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 } From 6d001b8c4e09cf6079d4d2695e4fd7d5e806987d Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Wed, 4 Oct 2023 07:06:25 -0400 Subject: [PATCH 15/23] fix displaying graphs in notebook --- python/egglog/egraph.py | 5 +++-- python/egglog/ipython_magic.py | 6 +++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/python/egglog/egraph.py b/python/egglog/egraph.py index 803a5ab3..ff352a05 100644 --- a/python/egglog/egraph.py +++ b/python/egglog/egraph.py @@ -33,6 +33,7 @@ from . import bindings from .declarations import * +from .ipython_magic import IN_IPYTHON from .monkeypatch import monkeypatch_forward_ref from .runtime import * from .runtime import _resolve_callable, class_to_ref @@ -719,12 +720,12 @@ def display(self, **kwargs: Unpack[GraphvizKwargs]): Displays the e-graph in the notebook. 
""" graphviz = self.graphviz(**kwargs) - if hasattr(__builtins__, "__IPYTHON__"): + if IN_IPYTHON: from IPython.display import SVG, display display(SVG(graphviz.pipe(format="svg", quiet=True, encoding="utf-8"))) else: - graphviz.view() + graphviz.render(view=True, format="svg", quiet=True) @overload def simplify(self, expr: EXPR, limit: int, /, *until: Fact, ruleset: Optional[Ruleset] = None) -> EXPR: diff --git a/python/egglog/ipython_magic.py b/python/egglog/ipython_magic.py index a080e5e8..cb7472a6 100644 --- a/python/egglog/ipython_magic.py +++ b/python/egglog/ipython_magic.py @@ -4,11 +4,11 @@ try: get_ipython() # type: ignore[name-defined] - in_ipython = True + IN_IPYTHON = True except NameError: - in_ipython = False + IN_IPYTHON = False -if in_ipython: +if IN_IPYTHON: import graphviz from IPython.core.magic import needs_local_scope, register_cell_magic From 0720ada02082f74dee6eeb409971e9d29a4e5d1b Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Wed, 4 Oct 2023 08:44:52 -0400 Subject: [PATCH 16/23] Try not to extract parents --- python/egglog/exp/program_gen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/egglog/exp/program_gen.py b/python/egglog/exp/program_gen.py index b2410b5c..e9f46038 100644 --- a/python/egglog/exp/program_gen.py +++ b/python/egglog/exp/program_gen.py @@ -84,7 +84,7 @@ def compile(self, next_sym: i64 = i64(0)) -> Unit: Triggers compilation of the program. 
""" - @program_gen_module.method(merge=lambda old, new: old) # type: ignore[misc] + @program_gen_module.method(merge=lambda old, new: old, cost=1000) # type: ignore[misc] @property def parent(self) -> Program: """ From 0d12b8fc5533473ea5f9366a8ea3b3349c7d4abd Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Wed, 4 Oct 2023 08:45:04 -0400 Subject: [PATCH 17/23] Fix compilation of program args --- python/egglog/exp/program_gen.py | 28 +++++++++++++++++++--------- 1 file changed, 19 insertions(+), 9 deletions(-) diff --git a/python/egglog/exp/program_gen.py b/python/egglog/exp/program_gen.py index e9f46038..07dca96a 100644 --- a/python/egglog/exp/program_gen.py +++ b/python/egglog/exp/program_gen.py @@ -100,6 +100,8 @@ def eval_py_object(self, globals: PyObject) -> Unit: Evaluates the program and saves as the py_object """ + # Only allow it to be set once, b/c hash of functions not stable + @program_gen_module.method(merge=lambda old, new: old) # type: ignore[misc] @property def py_object(self) -> PyObject: """ @@ -210,11 +212,13 @@ def _compile( # Set parent of p2, once p1 compiled yield rule(eq(p).to(program_add), p1.next_sym).then(set_(p2.parent).to(p)) - # Compile p2, if p1 parent not equal - yield rule(eq(p).to(program_add), p.compile(i), p1.parent != p).then(p2.compile(i)) + # Compile p2, if p1 parent not equal, but p2 parent equal + yield rule(eq(p).to(program_add), p.compile(i), p1.parent != p, eq(p2.parent).to(p)).then(p2.compile(i)) # Compile p2, if p1 parent eqal - yield rule(eq(p).to(program_add), eq(p1.parent).to(program_add), eq(i).to(p1.next_sym)).then(p2.compile(i)) + yield rule(eq(p).to(program_add), eq(p1.parent).to(program_add), eq(i).to(p1.next_sym), eq(p2.parent).to(p)).then( + p2.compile(i) + ) # Set p expr to join of p1 and p2 yield rule( @@ -311,12 +315,18 @@ def _compile( # When compiling a function, the two args, p2 and p3, should get compiled when we compile p1, and should just be vars. fn_two = p1.function_two(p2, p3, s1) - # 1. 
Set parent of p1 - yield rule(eq(p).to(fn_two), fn_two.compile(i)).then(set_(p1.parent).to(p)) - # TODO: Compile vars? - # 2. Compile p1 if parent set - yield rule(eq(p).to(fn_two), p.compile(i), eq(p1.parent).to(fn_two)).then(p1.compile(i)) - # 3. Set statements to function body and the next sym to i + # 1. Set parents of both args to p and compile them + # Assumes that this if the first thing to compile, so no need to check, and assumes that compiling args doesn't result in any + # change in the next sym + yield rule(eq(p).to(fn_two), p.compile(i)).then( + set_(p2.parent).to(p), + set_(p3.parent).to(p), + set_(p1.parent).to(p), + p2.compile(i), + p3.compile(i), + p1.compile(i), + ) + # 2. Set statements to function body and the next sym to i yield rule( eq(p).to(fn_two), p.compile(i), From 3e8bda5638161575efe4403454593fbc9b10b65b Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Wed, 4 Oct 2023 08:45:45 -0400 Subject: [PATCH 18/23] Fix array api test --- .../test_array_api/test_to_source.py | 12 ++++++++++++ python/tests/test_array_api.py | 19 ++++++++++++------- 2 files changed, 24 insertions(+), 7 deletions(-) create mode 100644 python/tests/__snapshots__/test_array_api/test_to_source.py diff --git a/python/tests/__snapshots__/test_array_api/test_to_source.py b/python/tests/__snapshots__/test_array_api/test_to_source.py new file mode 100644 index 00000000..68d6db71 --- /dev/null +++ b/python/tests/__snapshots__/test_array_api/test_to_source.py @@ -0,0 +1,12 @@ +def __fn(X, y): + assert y.dtype == np.int64 + assert y.shape == (150,) + assert set(y.flatten()) == set((0,) + (1,) + (2,)) + _0 = y.reshape((-1,)) + _1 = np.zeros((3,) + (4,), dtype=np.float64) + _2 = _0 + _1 + _3 = np.unique(_0, return_counts=True) + _4 = _3[1].astype(np.float64) + _5 = _4 / np.array(150.0) + _6 = _2 + _5 + return _6 diff --git a/python/tests/test_array_api.py b/python/tests/test_array_api.py index 90404054..057b3e73 100644 --- a/python/tests/test_array_api.py +++ 
b/python/tests/test_array_api.py @@ -48,14 +48,19 @@ def test_to_source(snapshot_py): OptionalDType.some(DType.float64), OptionalDevice.some(_NDArray_1.device), ) - res = _NDArray_4 + _NDArray_1 + _NDArray_5 - fn = ndarray_program(res).function_two(ndarray_program(X_orig), ndarray_program(Y_orig)) + res = _NDArray_3 + _NDArray_5 + _NDArray_4 egraph = EGraph([array_api_module_string]) - egraph.register(fn.eval_py_object(egraph.save_object({"np": numpy}))) - - egraph.run(100) - egraph.display(n_inline_leaves=1) - fn_source = egraph.load_object(egraph.extract(PyObject.from_string(fn.statements))) + with egraph: + egraph.register(res) + egraph.run(10000) + res = egraph.extract(res) + fn = ndarray_program(res).function_two(ndarray_program(X_orig), ndarray_program(Y_orig)) + with egraph: + egraph.register(fn.eval_py_object(egraph.save_object({"np": numpy}))) + egraph.run(10000) + fn = egraph.extract(fn) + # egraph.display(n_inline_leaves=0, split_primitive_outputs=True) + fn_source = egraph.load_object(egraph.extract(PyObject.from_string(fn.statements))) assert fn_source == snapshot_py From f85a074c6391051e0deaa1959c23e7cbdf2e1164 Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Wed, 4 Oct 2023 08:46:28 -0400 Subject: [PATCH 19/23] Add split primitive outputs option --- Cargo.lock | 180 ++++++++++++++++++------------------- Cargo.toml | 4 +- python/egglog/bindings.pyi | 1 + python/egglog/egraph.py | 1 + src/egraph.rs | 6 +- 5 files changed, 94 insertions(+), 98 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c0b04622..9d31d4af 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -41,30 +41,29 @@ checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5" [[package]] name = "anstream" -version = "0.3.2" +version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ca84f3628370c59db74ee214b3263d58f9aadd9b4fe7e711fd87dc452b7f163" +checksum = "2ab91ebe16eb252986481c5b62f6098f3b698a45e34b5b98200cf20dd2484a44" 
dependencies = [ "anstyle", "anstyle-parse", "anstyle-query", "anstyle-wincon", "colorchoice", - "is-terminal", "utf8parse", ] [[package]] name = "anstyle" -version = "1.0.1" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a30da5c5f2d5e72842e00bcb57657162cdabef0931f40e2deb9b4140440cecd" +checksum = "7079075b41f533b8c61d2a4d073c4676e1f8b249ff94a393b0595db304e0dd87" [[package]] name = "anstyle-parse" -version = "0.2.1" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "938874ff5980b03a87c5524b3ae5b59cf99b1d6bc836848df7bc5ada9643c333" +checksum = "317b9a89c1868f5ea6ff1d9539a69f45dffc21ce321ac1fd1160dfa48c8e2140" dependencies = [ "utf8parse", ] @@ -80,9 +79,9 @@ dependencies = [ [[package]] name = "anstyle-wincon" -version = "1.0.1" +version = "3.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "180abfa45703aebe0093f79badacc01b8fd4ea2e35118747e5811127f926e188" +checksum = "f0699d10d2f4d628a98ee7b57b289abbc98ff3bad977cb3152709d4bf2330628" dependencies = [ "anstyle", "windows-sys", @@ -132,9 +131,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.3.3" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "630be753d4e58660abd17930c71b647fe46c27ea6b63cc59e1e3851406972e42" +checksum = "b4682ae6287fcf752ecaabbfcc7b6f9b72aa33933dc23a554d853aea8eea8635" [[package]] name = "block-buffer" @@ -147,9 +146,12 @@ dependencies = [ [[package]] name = "cc" -version = "1.0.79" +version = "1.0.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f" +checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" +dependencies = [ + "libc", +] [[package]] name = "cfg-if" @@ -159,20 +161,19 @@ checksum = 
"baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "clap" -version = "4.3.15" +version = "4.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f644d0dac522c8b05ddc39aaaccc5b136d5dc4ff216610c5641e3be5becf56c" +checksum = "d04704f56c2cde07f43e8e2c154b43f216dc5c92fc98ada720177362f953b956" dependencies = [ "clap_builder", "clap_derive", - "once_cell", ] [[package]] name = "clap_builder" -version = "4.3.15" +version = "4.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af410122b9778e024f9e0fb35682cc09cc3f85cad5e8d3ba8f47a9702df6e73d" +checksum = "0e231faeaca65ebd1ea3c737966bf858971cd38c3849107aa3ea7de90a804e45" dependencies = [ "anstream", "anstyle", @@ -182,21 +183,21 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.3.12" +version = "4.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "54a9bb5758fc5dfe728d1019941681eccaf0cf8a4189b692a0ee2f2ecf90a050" +checksum = "0862016ff20d69b84ef8247369fabf5c008a7417002411897d40ee1f4532b873" dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.27", + "syn 2.0.32", ] [[package]] name = "clap_lex" -version = "0.5.0" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2da6da31387c7e4ef160ffab6d5e7f00c42626fe39aea70a7b0f1773f7dd6c1b" +checksum = "cd7cc57abe963c6d3b9d8be5b06ba7c8957a930305ca90304f24ef040aa6f961" [[package]] name = "colorchoice" @@ -284,13 +285,13 @@ checksum = "675e35c02a51bb4d4618cb4885b3839ce6d1787c97b664474d9208d074742e20" [[package]] name = "egglog" version = "0.1.0" -source = "git+https://github.com/egraphs-good/egglog?rev=4d67f262a6f27aa5cfb62a2cfc7df968959105df#4d67f262a6f27aa5cfb62a2cfc7df968959105df" +source = "git+https://github.com/egraphs-good/egglog?rev=45d05e727cceaab13413b4e51a60ee3be9fbf403#45d05e727cceaab13413b4e51a60ee3be9fbf403" dependencies = [ "clap", "egraph-serialize", "env_logger", - "hashbrown 
0.14.0", - "indexmap 2.0.0", + "hashbrown 0.14.1", + "indexmap", "instant", "lalrpop", "lalrpop-util 0.20.0", @@ -327,7 +328,7 @@ version = "0.1.0" source = "git+https://github.com/saulshanabrook/egraph-serialize?rev=a3f6fef9b958a335367d80d51e028c6db886fb6e#a3f6fef9b958a335367d80d51e028c6db886fb6e" dependencies = [ "graphviz-rust", - "indexmap 2.0.0", + "indexmap", "once_cell", "ordered-float", "serde", @@ -336,9 +337,9 @@ dependencies = [ [[package]] name = "either" -version = "1.8.1" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91" +checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" [[package]] name = "ena" @@ -370,9 +371,9 @@ checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" [[package]] name = "errno" -version = "0.3.1" +version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4bcfec3a70f97c962c307b2d2c56e358cf1d00b558d74262b5f929ee8cc7e73a" +checksum = "add4f07d43996f76ef320709726a556a9d4f965d9410d8d0271132d2f8293480" dependencies = [ "errno-dragonfly", "libc", @@ -391,9 +392,9 @@ dependencies = [ [[package]] name = "fastrand" -version = "2.0.0" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6999dc1837253364c2ebb0704ba97994bd874e8f195d665c50b7548f6ea92764" +checksum = "25cbce373ec4653f1a01a31e8a5e5ec0c622dc27ff9c4e6606eefef5cbbed4a5" [[package]] name = "fixedbitset" @@ -449,9 +450,9 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.14.0" +version = "0.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a" +checksum = "7dfda62a12f55daeae5015f81b0baea145391cb4520f86c248fc615d72640d12" dependencies = [ "ahash 0.8.3", "allocator-api2", @@ -465,9 +466,9 @@ checksum = 
"95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" [[package]] name = "hermit-abi" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "443144c8cdadd93ebf52ddb4056d257f5b52c04d3c804e657d19eb73fc33668b" +checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" [[package]] name = "humantime" @@ -477,22 +478,12 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" [[package]] name = "indexmap" -version = "1.9.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" -dependencies = [ - "autocfg", - "hashbrown 0.12.3", -] - -[[package]] -name = "indexmap" -version = "2.0.0" +version = "2.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d5477fe2230a79769d8dc68e0eabf5437907c0457a5614a9e8dddb67f65eb65d" +checksum = "8adf3ddd720272c6ea8bf59463c04e0f93d0bbf7c5439b691bca2987e0270897" dependencies = [ "equivalent", - "hashbrown 0.14.0", + "hashbrown 0.14.1", "serde", ] @@ -614,9 +605,9 @@ checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3" [[package]] name = "linux-raw-sys" -version = "0.4.3" +version = "0.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09fc20d2ca12cb9f044c93e3bd6d32d523e6e2ec3db4f7b2939cd99026ecd3f0" +checksum = "3852614a3bd9ca9804678ba6be5e3b8ce76dfc902cae004e3e0c44051b6e88db" [[package]] name = "lock_api" @@ -739,19 +730,20 @@ dependencies = [ [[package]] name = "pest" -version = "2.7.2" +version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1acb4a4365a13f749a93f1a094a7805e5cfa0955373a9de860d962eaa3a5fe5a" +checksum = "c022f1e7b65d6a24c0dbbd5fb344c66881bc01f3e5ae74a1c8100f2f985d98a4" dependencies = [ + "memchr", "thiserror", "ucd-trie", ] [[package]] name = "pest_derive" -version = "2.7.2" +version = "2.7.4" 
source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "666d00490d4ac815001da55838c500eafb0320019bbaa44444137c48b443a853" +checksum = "35513f630d46400a977c4cb58f78e1bfbe01434316e60c37d27b9ad6139c66d8" dependencies = [ "pest", "pest_generator", @@ -759,22 +751,22 @@ dependencies = [ [[package]] name = "pest_generator" -version = "2.7.2" +version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68ca01446f50dbda87c1786af8770d535423fa8a53aec03b8f4e3d7eb10e0929" +checksum = "bc9fc1b9e7057baba189b5c626e2d6f40681ae5b6eb064dc7c7834101ec8123a" dependencies = [ "pest", "pest_meta", "proc-macro2", "quote", - "syn 2.0.27", + "syn 2.0.32", ] [[package]] name = "pest_meta" -version = "2.7.2" +version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56af0a30af74d0445c0bf6d9d051c979b516a1a5af790d251daee76005420a48" +checksum = "1df74e9e7ec4053ceb980e7c0c8bd3594e977fde1af91daba9c928e8e8c6708d" dependencies = [ "once_cell", "pest", @@ -783,12 +775,12 @@ dependencies = [ [[package]] name = "petgraph" -version = "0.6.3" +version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4dd7d28ee937e54fe3080c91faa1c3a46c06de6252988a7f4592ba2310ef22a4" +checksum = "e1d3afd2628e69da2be385eb6f2fd57c8ac7977ceeff6dc166ff1657b0e386a9" dependencies = [ "fixedbitset", - "indexmap 1.9.3", + "indexmap", ] [[package]] @@ -1005,11 +997,11 @@ checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" [[package]] name = "rustix" -version = "0.38.4" +version = "0.38.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a962918ea88d644592894bc6dc55acc6c0956488adcebbfb6e273506b7fd6e5" +checksum = "d7db8590df6dfcd144d22afd1b83b36c21a18d7cbc1dc4bb5295a8712e9eb662" dependencies = [ - "bitflags 2.3.3", + "bitflags 2.4.0", "errno", "libc", "linux-raw-sys", @@ -1036,31 +1028,31 @@ checksum = 
"94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] name = "serde" -version = "1.0.179" +version = "1.0.188" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a5bf42b8d227d4abf38a1ddb08602e229108a517cd4e5bb28f9c7eaafdce5c0" +checksum = "cf9e0fcba69a370eed61bcf2b728575f726b50b55cba78064753d708ddc7549e" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.179" +version = "1.0.188" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "741e124f5485c7e60c03b043f79f320bff3527f4bbf12cf3831750dc46a0ec2c" +checksum = "4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.27", + "syn 2.0.32", ] [[package]] name = "serde_json" -version = "1.0.105" +version = "1.0.107" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "693151e1ac27563d6dbcec9dee9fbd5da8539b20fa14ad3752b2e6d363ace360" +checksum = "6b420ce6e3d8bd882e9b243c6eed35dbc9a6110c9769e74b584e0d68d1f20c65" dependencies = [ - "indexmap 2.0.0", + "indexmap", "itoa", "ryu", "serde", @@ -1068,9 +1060,9 @@ dependencies = [ [[package]] name = "sha2" -version = "0.10.7" +version = "0.10.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "479fb9d862239e610720565ca91403019f2f00410f1864c5aa7479b950a76ed8" +checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" dependencies = [ "cfg-if", "cpufeatures", @@ -1079,9 +1071,9 @@ dependencies = [ [[package]] name = "siphasher" -version = "0.3.10" +version = "0.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7bd3e3206899af3f8b12af284fafc038cc1dc2b41d1b89dd17297221c5d225de" +checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d" [[package]] name = "smallvec" @@ -1135,9 +1127,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.27" +version = "2.0.32" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "b60f673f44a8255b9c8c657daf66a596d435f2da81a555b06dc644d080ba45e0" +checksum = "239814284fd6f1a4ffe4ca893952cdd93c224b6a1571c9a9eadd670295c0c9e2" dependencies = [ "proc-macro2", "quote", @@ -1152,9 +1144,9 @@ checksum = "df8e77cb757a61f51b947ec4a7e3646efd825b73561db1c232a8ccb639e611a0" [[package]] name = "tempfile" -version = "3.7.1" +version = "3.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc02fddf48964c42031a0b3fe0428320ecf3a73c401040fc0096f97794310651" +checksum = "cb94d2f3cc536af71caac6b6fcebf65860b347e7ce0cc9ebe8f70d3e521054ef" dependencies = [ "cfg-if", "fastrand", @@ -1176,31 +1168,31 @@ dependencies = [ [[package]] name = "termcolor" -version = "1.2.0" +version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be55cf8942feac5c765c2c993422806843c9a9a45d4d5c407ad6dd2ea95eb9b6" +checksum = "6093bad37da69aab9d123a8091e4be0aa4a03e4d601ec641c327398315f62b64" dependencies = [ "winapi-util", ] [[package]] name = "thiserror" -version = "1.0.43" +version = "1.0.49" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a35fc5b8971143ca348fa6df4f024d4d55264f3468c71ad1c2f365b0a4d58c42" +checksum = "1177e8c6d7ede7afde3585fd2513e611227efd6481bd78d2e82ba1ce16557ed4" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.43" +version = "1.0.49" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "463fe12d7993d3b327787537ce8dd4dfa058de32fc2b195ef3cde03dc4771e8f" +checksum = "10712f02019e9288794769fba95cd6847df9874d49d871d062172f9dd41bc4cc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.27", + "syn 2.0.32", ] [[package]] @@ -1214,9 +1206,9 @@ dependencies = [ [[package]] name = "typenum" -version = "1.16.0" +version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = 
"497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba" +checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" [[package]] name = "ucd-trie" @@ -1278,9 +1270,9 @@ checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" [[package]] name = "winapi-util" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70ec6ce85bb158151cae5e5c87f95a8e97d2c0c4b001223f33a334e3ce5de178" +checksum = "f29e6f9198ba0d26b4c9f07dbe6f9ed633e1f3d5b8b414090084349e46a52596" dependencies = [ "winapi", ] diff --git a/Cargo.toml b/Cargo.toml index 8d928a56..6dee8d3d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,9 +10,9 @@ crate-type = ["cdylib"] [dependencies] pyo3 = { version = "0.18.1", features = ["extension-module"] } -egglog = { git = "https://github.com/egraphs-good/egglog", rev = "4d67f262a6f27aa5cfb62a2cfc7df968959105df" } +egglog = { git = "https://github.com/egraphs-good/egglog", rev = "45d05e727cceaab13413b4e51a60ee3be9fbf403" } # egglog = { git = "https://github.com/oflatt/egg-smol", rev = "f6df3ff831b65405665e1751b0ef71c61b025432" } -# egglog = { git = "https://github.com/saulshanabrook/egg-smol", rev = "c01695618ed4de2fbfa8116476e208bc1ca86612" } +# egglog = { git = "https://github.com/saulshanabrook/egg-smol", rev = "38b3014b34399cc78887ede09c845b2a5d6c7d19" } pyo3-log = "0.8.1" log = "0.4.17" diff --git a/python/egglog/bindings.pyi b/python/egglog/bindings.pyi index b62ff026..c3236fc9 100644 --- a/python/egglog/bindings.pyi +++ b/python/egglog/bindings.pyi @@ -19,6 +19,7 @@ class EGraph: max_functions: Optional[int] = None, max_calls_per_function: Optional[int] = None, n_inline_leaves: int = 0, + split_primitive_outputs: bool = False, ) -> str: ... def save_object(self, __o: object, /) -> _Expr: ... def load_object(self, __e: _Expr, /) -> object: ... 
diff --git a/python/egglog/egraph.py b/python/egglog/egraph.py index ff352a05..00a9e806 100644 --- a/python/egglog/egraph.py +++ b/python/egglog/egraph.py @@ -670,6 +670,7 @@ class GraphvizKwargs(TypedDict, total=False): max_functions: Optional[int] max_calls_per_function: Optional[int] n_inline_leaves: int + split_primitive_outputs: bool @dataclass diff --git a/src/egraph.rs b/src/egraph.rs index aae4c1bc..4f805c23 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -83,14 +83,15 @@ impl EGraph { /// Returns the EGraph as graphviz string. #[pyo3( - signature = (*, max_functions=None, max_calls_per_function=None, n_inline_leaves=0), - text_signature = "(self, *, max_functions=None, max_calls_per_function=None, n_inline_leaves=0)" + signature = (*, max_functions=None, max_calls_per_function=None, n_inline_leaves=0, split_primitive_outputs=false), + text_signature = "(self, *, max_functions=None, max_calls_per_function=None, n_inline_leaves=0, split_primitive_outputs=False)" )] fn to_graphviz_string( &self, max_functions: Option, max_calls_per_function: Option, n_inline_leaves: usize, + split_primitive_outputs: bool, ) -> String { info!("Getting graphviz"); // TODO: Expose full serialized e-graph in the future @@ -98,6 +99,7 @@ impl EGraph { max_functions, max_calls_per_function, include_temporary_functions: false, + split_primitive_outputs, }); for _ in 0..n_inline_leaves { serialized.inline_leaves(); From 80debbc87697616b8b856587ef50a0fa4b54ae1a Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Wed, 4 Oct 2023 08:47:59 -0400 Subject: [PATCH 20/23] Highlight paths in graphviz output --- python/egglog/egraph.py | 38 +++++++++++++++++++++++++++++++++++--- 1 file changed, 35 insertions(+), 3 deletions(-) diff --git a/python/egglog/egraph.py b/python/egglog/egraph.py index 00a9e806..60b4327d 100644 --- a/python/egglog/egraph.py +++ b/python/egglog/egraph.py @@ -1,6 +1,8 @@ from __future__ import annotations import inspect +import pathlib +import tempfile from abc 
import ABC, abstractmethod from contextvars import ContextVar, Token from copy import deepcopy @@ -705,7 +707,37 @@ def _repr_mimebundle_(self, *args, **kwargs): return {"image/svg+xml": self.graphviz().pipe(format="svg", quiet=True, encoding="utf-8")} def graphviz(self, **kwargs: Unpack[GraphvizKwargs]) -> graphviz.Source: - return graphviz.Source(self._egraph.to_graphviz_string(**kwargs)) + original = self._egraph.to_graphviz_string(**kwargs) + # Add link to stylesheet to the graph, so that edges light up on hover + # https://gist.github.com/sverweij/93e324f67310f66a8f5da5c2abe94682 + styles = """/* the lines within the edges */ + .edge:active path, + .edge:hover path { + stroke: fuchsia; + stroke-width: 3; + stroke-opacity: 1; + } + /* arrows are typically drawn with a polygon */ + .edge:active polygon, + .edge:hover polygon { + stroke: fuchsia; + stroke-width: 3; + fill: fuchsia; + stroke-opacity: 1; + fill-opacity: 1; + } + /* If you happen to have text and want to color that as well... 
*/ + .edge:active text, + .edge:hover text { + fill: fuchsia; + }""" + p = pathlib.Path(tempfile.gettempdir()) / "graphviz-styles.css" + p.write_text(styles) + with_stylesheet = original.replace("{", f'{{stylesheet="{str(p)}"', 1) + return graphviz.Source(with_stylesheet) + + def graphviz_svg(self, **kwargs: Unpack[GraphvizKwargs]) -> str: + return graphviz.pipe(format="svg", quiet=True, encoding="utf-8") def _repr_html_(self) -> str: """ @@ -714,7 +746,7 @@ def _repr_html_(self) -> str: until this PR is merged and released https://github.com/sphinx-gallery/sphinx-gallery/pull/1138 """ - return self.graphviz().pipe(format="svg", quiet=True).decode() + return self.graphviz_svg() def display(self, **kwargs: Unpack[GraphvizKwargs]): """ @@ -724,7 +756,7 @@ def display(self, **kwargs: Unpack[GraphvizKwargs]): if IN_IPYTHON: from IPython.display import SVG, display - display(SVG(graphviz.pipe(format="svg", quiet=True, encoding="utf-8"))) + display(SVG(self.graphviz_svg(**kwargs))) else: graphviz.render(view=True, format="svg", quiet=True) From 1c65e75b99b719e05387afc7851133751bf622fa Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Wed, 4 Oct 2023 08:48:09 -0400 Subject: [PATCH 21/23] Increase cost of emitting var --- python/egglog/exp/array_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/egglog/exp/array_api.py b/python/egglog/exp/array_api.py index 02b5716a..8631c8dd 100644 --- a/python/egglog/exp/array_api.py +++ b/python/egglog/exp/array_api.py @@ -699,7 +699,7 @@ class NDArray(Expr): def __init__(self, py_array: PyObject) -> None: ... - @array_api_module.method(cost=100) + @array_api_module.method(cost=200) @classmethod def var(cls, name: StringLike) -> NDArray: ... 
From b5bfbd23ad827bf537f39ee62ebf2a59f2d008b3 Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Wed, 4 Oct 2023 08:54:51 -0400 Subject: [PATCH 22/23] Fix docs --- docs/reference/python-integration.md | 9 +++++++++ python/egglog/egraph.py | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/docs/reference/python-integration.md b/docs/reference/python-integration.md index da61d70d..8850dcf4 100644 --- a/docs/reference/python-integration.md +++ b/docs/reference/python-integration.md @@ -118,6 +118,15 @@ assert egraph.load_object(egraph.extract(evalled)) == 3 Similar to how an `int` can be automatically upcasted to an `i64`, we also support registering conversion to your custom types. For example: ```{code-cell} python +@egraph.class_ +class Math(Expr): + def __init__(self, x: i64Like) -> None: ... + + @classmethod + def var(cls, name: StringLike) -> Math: ... + + def __add__(self, other: Math) -> Math: ... + converter(i64, Math, Math) converter(String, Math, Math.var) diff --git a/python/egglog/egraph.py b/python/egglog/egraph.py index 60b4327d..1d8b87ba 100644 --- a/python/egglog/egraph.py +++ b/python/egglog/egraph.py @@ -737,7 +737,7 @@ def graphviz(self, **kwargs: Unpack[GraphvizKwargs]) -> graphviz.Source: return graphviz.Source(with_stylesheet) def graphviz_svg(self, **kwargs: Unpack[GraphvizKwargs]) -> str: - return graphviz.pipe(format="svg", quiet=True, encoding="utf-8") + return self.graphviz(**kwargs).pipe(format="svg", quiet=True, encoding="utf-8") def _repr_html_(self) -> str: """ From e305089c482241b13b61faa4ba952686d9250256 Mon Sep 17 00:00:00 2001 From: Saul Shanabrook Date: Wed, 4 Oct 2023 09:00:29 -0400 Subject: [PATCH 23/23] Adds changelog entries --- docs/changelog.md | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/docs/changelog.md b/docs/changelog.md index 5ee3364e..53e2fec5 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -4,7 +4,7 @@ _This project uses semantic versioning. 
Before 1.0.0, this means that every brea ## Unreleased -- Bump [egglog dep](https://github.com/egraphs-good/egglog/compare/c83fc750878755eb610a314da90f9273b3bfe25d...4d67f262a6f27aa5cfb62a2cfc7df968959105df) +- Bump [egglog dep](https://github.com/egraphs-good/egglog/compare/c83fc750878755eb610a314da90f9273b3bfe25d...45d05e727cceaab13413b4e51a60ee3be9fbf403) ### Breaking Changes @@ -16,11 +16,16 @@ _This project uses semantic versioning. Before 1.0.0, this means that every brea - Add ability to inline leaves $n$ times instead of just once for visualization [#48](https://github.com/metadsl/egglog-python/pull/48) - Add `Relation` and `PrintOverallStatistics` low level commands [#46](https://github.com/metadsl/egglog-python/pull/46) - Adds `count-matches` and `replace` string commands [#46](https://github.com/metadsl/egglog-python/pull/46) -- Adds ability for custom user defined types in a union for proper static typing with conversions -- Adds `py_eval` function to `EGraph` as a helper to eval Python code. +- Adds ability for custom user defined types in a union for proper static typing with conversions [#49](https://github.com/metadsl/egglog-python/pull/49) +- Adds `py_eval` function to `EGraph` as a helper to eval Python code. [#49](https://github.com/metadsl/egglog-python/pull/49) +- Adds on hover behavior for edges in graphviz SVG output to make them easier to trace [#49](https://github.com/metadsl/egglog-python/pull/49) +- Adds `egglog.exp.program_gen` module that will compile expressions into Python statements/functions [#49](https://github.com/metadsl/egglog-python/pull/49) +- Adds `py_exec` primitive function for executing Python code [#49](https://github.com/metadsl/egglog-python/pull/49) ### Bug fixes +- Clean up example in tutorial with demand based expression generation [#49](https://github.com/metadsl/egglog-python/pull/49) + ### Uncategorized - Added initial supported for Python objects [#31](https://github.com/metadsl/egglog-python/pull/31)