Skip to content

Commit

Permalink
Merge pull request #212 from egraphs-good/interactive-visualizer
Browse files Browse the repository at this point in the history
Make visualizations interactive
  • Loading branch information
saulshanabrook authored Oct 16, 2024
2 parents b7c6762 + ae215ba commit 17ffd1a
Show file tree
Hide file tree
Showing 16 changed files with 35,944 additions and 211 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,5 @@ Source.*
3
4
inlined
visualizer.tgz
package
24 changes: 24 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@

all: python/egglog/visualizer.js



# download visualizer release from github
#
visualizer.tgz:
curl -s https://api.github.com/repos/egraphs-good/egraph-visualizer/releases/latest \
| grep "browser_download_url.*tgz" \
| cut -d : -f 2,3 \
| tr -d \" \
| wget -qi - -O visualizer.tgz

# extract visualizer release
python/egglog/visualizer.js python/egglog/visualizer.css: visualizer.tgz
tar -xzf visualizer.tgz
rm visualizer.tgz
mv package/dist/index.js python/egglog/visualizer.js
mv package/dist/style.css python/egglog/visualizer.css
rm -rf package

clean:
rm -rf package python/egglog/visualizer.css python/egglog/visualizer.js visualizer.tgz
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ _This project uses semantic versioning_
- Adds source annotations to expressions for tracebacks
- Adds ability to inline other functions besides primitives in serialized output
- Adds `remove` and `set` methods to `Vec`
- Upgrades to use the new egraph-visualizer so we can have interactive visualizations

## 7.2.0 (2024-05-23)

Expand Down
13 changes: 2 additions & 11 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
file_format: mystnb
---

# [`egglog`](https://github.com/egraphs-good/egglog/) Python
# `egglog` Python

`egglog` is a Python package that provides bindings to the Rust library of the same name,
`egglog` is a Python package that provides bindings to [the Rust library of the same name](https://github.com/egraphs-good/egglog/),
allowing you to use e-graphs in Python for optimization, symbolic computation, and analysis.

It wraps the Rust library [`egglog`](https://github.com/egraphs-good/egglog) which
Expand All @@ -13,14 +13,10 @@ See the ["Better Together: Unifying Datalog and Equality Saturation"](https://ar

> We present egglog, a fixpoint reasoning system that unifies Datalog and equality saturation (EqSat). Like Datalog, it supports efficient incremental execution, cooperating analyses, and lattice-based reasoning. Like EqSat, it supports term rewriting, efficient congruence closure, and extraction of optimized terms.
## [Installation](./reference/usage)

```shell
pip install egglog
```

## Example

```{code-cell} python
from __future__ import annotations
from egglog import *
Expand Down Expand Up @@ -49,15 +45,10 @@ def _num_rule(a: Num, b: Num, c: Num, i: i64, j: i64):
yield rewrite(Num(i) * Num(j)).to(Num(i * j))
egraph.saturate()
```

```{code-cell} python
egraph.check(eq(expr1).to(expr2))
egraph.extract(expr1)
```

## Contents

```{toctree}
:maxdepth: 2
tutorials
Expand Down
5 changes: 4 additions & 1 deletion docs/reference/contributing.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,14 @@ conda activate egglog-python
Then install the package in editable mode with the development dependencies:

```bash
maturin develop -E .[dev]
maturin develop -E dev,docs,test,array
```

Anytime you change the rust code, you can run `maturin develop -E` to recompile the rust code.

If you would like to download a new version of the visualizer source, run `make clean; make`. This will download
the most recent released version from the github actions artifact in the [egraph-visualizer](https://github.com/egraphs-good/egraph-visualizer) repo. It is checked in because it's a pain to get cargo to include only one git ignored file while ignoring the rest of the files that were ignored.

### Running Tests

To run the tests, you can use the `pytest` command:
Expand Down
27 changes: 27 additions & 0 deletions docs/reference/python-integration.md
Original file line number Diff line number Diff line change
Expand Up @@ -537,3 +537,30 @@ egraph.run(other_math_ruleset * 2)
egraph.check(eq(x).to(WrappedMath(math_float(3.14)) + WrappedMath(math_float(3.14))))
egraph
```

## Visualization

The default renderer for the e-graph in a Jupyter Notebook [an interactive Javascript visualizer](https://github.com/egraphs-good/egraph-visualizer):

```{code-cell} python
egraph
```

You can also customize the visualization through using the <inv:egglog.EGraph.display> method:

```{code-cell} python
egraph.display()
```

If you would like to visualize the progression of the e-graph over time, you can use the <inv:egglog.EGraph.saturate> method to
run a number of iterations and then visualize the e-graph at each step:

```{code-cell} python
egraph = EGraph()
egraph.register(Math(2) + Math(100))
i, j = vars_("i j", i64)
r = ruleset(
rewrite(Math(i) + Math(j)).to(Math(i + j)),
)
egraph.saturate(r)
```
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ classifiers = [
"Topic :: Software Development :: Interpreters",
"Typing :: Typed",
]
dependencies = ["typing-extensions", "black", "graphviz"]
dependencies = ["typing-extensions", "black", "graphviz", "anywidget"]

[project.optional-dependencies]

Expand Down
2 changes: 1 addition & 1 deletion python/egglog/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def resolve_literal(
# Try all parent types as well, if we are converting from a Python type
for arg_type_instance in arg_type.__mro__ if isinstance(arg_type, type) else [arg_type]:
try:
fn = CONVERSIONS[(cast(TypeName | type, arg_type_instance), tp_name)][1]
fn = CONVERSIONS[(arg_type_instance, tp_name)][1]
except KeyError:
continue
break
Expand Down
191 changes: 88 additions & 103 deletions python/egglog/egraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@
from .thunk import *

if TYPE_CHECKING:
import ipywidgets

from .builtins import Bool, PyObject, String, f64, i64


Expand Down Expand Up @@ -973,6 +971,7 @@ class GraphvizKwargs(TypedDict, total=False):
n_inline_leaves: int
split_primitive_outputs: bool
split_functions: list[object]
include_temporary_functions: bool


@dataclass
Expand Down Expand Up @@ -1015,82 +1014,8 @@ def as_egglog_string(self) -> str:
raise ValueError(msg)
return cmds

def _repr_mimebundle_(self, *args, **kwargs):
"""
Returns the graphviz representation of the e-graph.
"""
return {"image/svg+xml": self.graphviz().pipe(format="svg", quiet=True, encoding="utf-8")}

def graphviz(self, **kwargs: Unpack[GraphvizKwargs]) -> graphviz.Source:
# By default we want to split primitive outputs
split_primitive_outputs = kwargs.pop("split_primitive_outputs", True)
split_additional_functions = kwargs.pop("split_functions", [])
n_inline = kwargs.pop("n_inline_leaves", 0)
serialized = self._egraph.serialize(
[],
max_functions=kwargs.pop("max_functions", None),
max_calls_per_function=kwargs.pop("max_calls_per_function", None),
include_temporary_functions=False,
)
if split_primitive_outputs or split_additional_functions:
additional_ops = set(map(self._callable_to_egg, split_additional_functions))
serialized.split_e_classes(self._egraph, additional_ops)
serialized.map_ops(self._state.op_mapping())

for _ in range(n_inline):
serialized.inline_leaves()
original = serialized.to_dot()
# Add link to stylesheet to the graph, so that edges light up on hover
# https://gist.github.com/sverweij/93e324f67310f66a8f5da5c2abe94682
styles = """/* the lines within the edges */
.edge:active path,
.edge:hover path {
stroke: fuchsia;
stroke-width: 3;
stroke-opacity: 1;
}
/* arrows are typically drawn with a polygon */
.edge:active polygon,
.edge:hover polygon {
stroke: fuchsia;
stroke-width: 3;
fill: fuchsia;
stroke-opacity: 1;
fill-opacity: 1;
}
/* If you happen to have text and want to color that as well... */
.edge:active text,
.edge:hover text {
fill: fuchsia;
}"""
p = pathlib.Path(tempfile.gettempdir()) / "graphviz-styles.css"
p.write_text(styles)
with_stylesheet = original.replace("{", f'{{stylesheet="{p!s}"', 1)
return graphviz.Source(with_stylesheet)

def graphviz_svg(self, **kwargs: Unpack[GraphvizKwargs]) -> str:
return self.graphviz(**kwargs).pipe(format="svg", quiet=True, encoding="utf-8")

def _repr_html_(self) -> str:
"""
Add a _repr_html_ to be an SVG to work with sphinx gallery.
ala https://github.com/xflr6/graphviz/pull/121
until this PR is merged and released
https://github.com/sphinx-gallery/sphinx-gallery/pull/1138
"""
return self.graphviz_svg()

def display(self, **kwargs: Unpack[GraphvizKwargs]) -> None:
"""
Displays the e-graph in the notebook.
"""
if IN_IPYTHON:
from IPython.display import SVG, display

display(SVG(self.graphviz_svg(**kwargs)))
else:
self.graphviz(**kwargs).render(view=True, format="svg", quiet=True)
def _ipython_display_(self) -> None:
self.display()

def input(self, fn: Callable[..., String], path: str) -> None:
"""
Expand Down Expand Up @@ -1319,40 +1244,100 @@ def eval(self, expr: Expr) -> object:
return self._egraph.eval_py_object(egg_expr)
raise TypeError(f"Eval not implemented for {typed_expr.tp}")

def saturate(
def _serialize(
self,
schedule: Schedule | None = None,
*,
max: int = 1000,
performance: bool = False,
**kwargs: Unpack[GraphvizKwargs],
) -> ipywidgets.Widget:
from .graphviz_widget import graphviz_widget_with_slider
) -> bindings.SerializedEGraph:
max_functions = kwargs.pop("max_functions", None)
max_calls_per_function = kwargs.pop("max_calls_per_function", None)
split_primitive_outputs = kwargs.pop("split_primitive_outputs", True)
split_functions = kwargs.pop("split_functions", [])
include_temporary_functions = kwargs.pop("include_temporary_functions", False)
n_inline_leaves = kwargs.pop("n_inline_leaves", 1)
serialized = self._egraph.serialize(
[],
max_functions=max_functions,
max_calls_per_function=max_calls_per_function,
include_temporary_functions=include_temporary_functions,
)
if split_primitive_outputs or split_functions:
additional_ops = set(map(self._callable_to_egg, split_functions))
serialized.split_e_classes(self._egraph, additional_ops)
serialized.map_ops(self._state.op_mapping())

dots = [str(self.graphviz(**kwargs))]
i = 0
while self.run(schedule or 1).updated and i < max:
i += 1
dots.append(str(self.graphviz(**kwargs)))
return graphviz_widget_with_slider(dots, performance=performance)
for _ in range(n_inline_leaves):
serialized.inline_leaves()

def saturate_to_html(
self, file: str = "tmp.html", performance: bool = False, **kwargs: Unpack[GraphvizKwargs]
) -> None:
# raise NotImplementedError("Upstream bugs prevent rendering to HTML")
return serialized

# import panel
def _graphviz(self, **kwargs: Unpack[GraphvizKwargs]) -> graphviz.Source:
serialized = self._serialize(**kwargs)

# panel.extension("ipywidgets")
original = serialized.to_dot()
# Add link to stylesheet to the graph, so that edges light up on hover
# https://gist.github.com/sverweij/93e324f67310f66a8f5da5c2abe94682
styles = """/* the lines within the edges */
.edge:active path,
.edge:hover path {
stroke: fuchsia;
stroke-width: 3;
stroke-opacity: 1;
}
/* arrows are typically drawn with a polygon */
.edge:active polygon,
.edge:hover polygon {
stroke: fuchsia;
stroke-width: 3;
fill: fuchsia;
stroke-opacity: 1;
fill-opacity: 1;
}
/* If you happen to have text and want to color that as well... */
.edge:active text,
.edge:hover text {
fill: fuchsia;
}"""
p = pathlib.Path(tempfile.gettempdir()) / "graphviz-styles.css"
p.write_text(styles)
with_stylesheet = original.replace("{", f'{{stylesheet="{p!s}"', 1)
return graphviz.Source(with_stylesheet)

widget = self.saturate(performance=performance, **kwargs)
# panel.panel(widget).save(file)
def display(self, graphviz: bool = False, **kwargs: Unpack[GraphvizKwargs]) -> None:
"""
Displays the e-graph.
If in IPython it will display it inline, otherwise it will write it to a file and open it.
"""
from IPython.display import SVG, display

from .visualizer_widget import VisualizerWidget

if graphviz:
if IN_IPYTHON:
svg = self._graphviz(**kwargs).pipe(format="svg", quiet=True, encoding="utf-8")
display(SVG(svg))
else:
self._graphviz(**kwargs).render(view=True, format="svg", quiet=True)
else:
serialized = self._serialize(**kwargs)
VisualizerWidget(egraphs=[serialized.to_json()]).display_or_open()

from ipywidgets.embed import embed_minimal_html
def saturate(self, schedule: Schedule | None = None, *, max: int = 1000, **kwargs: Unpack[GraphvizKwargs]) -> None:
"""
Saturate the egraph, running the given schedule until the egraph is saturated.
It serializes the egraph at each step and returns a widget to visualize the egraph.
"""
from .visualizer_widget import VisualizerWidget

def to_json() -> str:
return self._serialize(**kwargs).to_json()

embed_minimal_html("tmp.html", views=[widget], drop_defaults=False)
# Use panel while this issue persists
# https://github.com/jupyter-widgets/ipywidgets/issues/3761#issuecomment-1755563436
egraphs = [to_json()]
i = 0
while self.run(schedule or 1).updated and i < max:
i += 1
egraphs.append(to_json())
VisualizerWidget(egraphs=egraphs).display_or_open()

@classmethod
def current(cls) -> EGraph:
Expand Down
5 changes: 1 addition & 4 deletions python/egglog/exp/array_api_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ def jit(fn: X) -> X:
"""
Jit compiles a function
"""
from IPython.display import SVG

# 1. Create variables for each of the two args in the functions
sig = inspect.signature(fn)
arg1, arg2 = sig.parameters.keys()
Expand All @@ -25,13 +23,12 @@ def jit(fn: X) -> X:
egraph.register(res)
egraph.run(array_api_numba_schedule)
res_optimized = egraph.extract(res)
svg = SVG(egraph.graphviz_svg(split_primitive_outputs=True, n_inline_leaves=3))
egraph.display(split_primitive_outputs=True, n_inline_leaves=3)

egraph = EGraph()
fn_program = ndarray_function_two(res_optimized, NDArray.var(arg1), NDArray.var(arg2))
egraph.register(fn_program)
egraph.run(array_api_program_gen_schedule)
fn = cast(X, egraph.eval(fn_program.py_object))
fn.egraph = svg # type: ignore[attr-defined]
fn.expr = res_optimized # type: ignore[attr-defined]
return fn
Loading

0 comments on commit 17ffd1a

Please sign in to comment.