Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make visualizations interactive #212

Merged
merged 3 commits into from
Oct 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,5 @@ Source.*
3
4
inlined
visualizer.tgz
package
24 changes: 24 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@

all: python/egglog/visualizer.js



# download visualizer release from github
#
visualizer.tgz:
curl -s https://api.github.com/repos/egraphs-good/egraph-visualizer/releases/latest \
| grep "browser_download_url.*tgz" \
| cut -d : -f 2,3 \
| tr -d \" \
| wget -qi - -O visualizer.tgz

# extract visualizer release
python/egglog/visualizer.js python/egglog/visualizer.css: visualizer.tgz
tar -xzf visualizer.tgz
rm visualizer.tgz
mv package/dist/index.js python/egglog/visualizer.js
mv package/dist/style.css python/egglog/visualizer.css
rm -rf package

clean:
rm -rf package python/egglog/visualizer.css python/egglog/visualizer.js visualizer.tgz
1 change: 1 addition & 0 deletions docs/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ _This project uses semantic versioning_
- Adds source annotations to expressions for tracebacks
- Adds ability to inline other functions besides primitives in serialized output
- Adds `remove` and `set` methods to `Vec`
- Upgrades to use the new egraph-visualizer so we can have interactive visualizations

## 7.2.0 (2024-05-23)

Expand Down
13 changes: 2 additions & 11 deletions docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
file_format: mystnb
---

# [`egglog`](https://github.com/egraphs-good/egglog/) Python
# `egglog` Python

`egglog` is a Python package that provides bindings to the Rust library of the same name,
`egglog` is a Python package that provides bindings to [the Rust library of the same name](https://github.com/egraphs-good/egglog/),
allowing you to use e-graphs in Python for optimization, symbolic computation, and analysis.

It wraps the Rust library [`egglog`](https://github.com/egraphs-good/egglog) which
Expand All @@ -13,14 +13,10 @@ See the ["Better Together: Unifying Datalog and Equality Saturation"](https://ar

> We present egglog, a fixpoint reasoning system that unifies Datalog and equality saturation (EqSat). Like Datalog, it supports efficient incremental execution, cooperating analyses, and lattice-based reasoning. Like EqSat, it supports term rewriting, efficient congruence closure, and extraction of optimized terms.

## [Installation](./reference/usage)

```shell
pip install egglog
```

## Example

```{code-cell} python
from __future__ import annotations
from egglog import *
Expand Down Expand Up @@ -49,15 +45,10 @@ def _num_rule(a: Num, b: Num, c: Num, i: i64, j: i64):
yield rewrite(Num(i) * Num(j)).to(Num(i * j))

egraph.saturate()
```

```{code-cell} python
egraph.check(eq(expr1).to(expr2))
egraph.extract(expr1)
```

## Contents

```{toctree}
:maxdepth: 2
tutorials
Expand Down
5 changes: 4 additions & 1 deletion docs/reference/contributing.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,14 @@ conda activate egglog-python
Then install the package in editable mode with the development dependencies:

```bash
maturin develop -E .[dev]
maturin develop -E dev,docs,test,array
```

Anytime you change the rust code, you can run `maturin develop -E` to recompile the rust code.

If you would like to download a new version of the visualizer source, run `make clean; make`. This will download
the most recent released version from the github actions artifact in the [egraph-visualizer](https://github.com/egraphs-good/egraph-visualizer) repo. It is checked in because it's a pain to get cargo to include only one git ignored file while ignoring the rest of the files that were ignored.

### Running Tests

To run the tests, you can use the `pytest` command:
Expand Down
27 changes: 27 additions & 0 deletions docs/reference/python-integration.md
Original file line number Diff line number Diff line change
Expand Up @@ -537,3 +537,30 @@ egraph.run(other_math_ruleset * 2)
egraph.check(eq(x).to(WrappedMath(math_float(3.14)) + WrappedMath(math_float(3.14))))
egraph
```

## Visualization

The default renderer for the e-graph in a Jupyter Notebook [an interactive Javascript visualizer](https://github.com/egraphs-good/egraph-visualizer):

```{code-cell} python
egraph
```

You can also customize the visualization through using the <inv:egglog.EGraph.display> method:

```{code-cell} python
egraph.display()
```

If you would like to visualize the progression of the e-graph over time, you can use the <inv:egglog.EGraph.saturate> method to
run a number of iterations and then visualize the e-graph at each step:

```{code-cell} python
egraph = EGraph()
egraph.register(Math(2) + Math(100))
i, j = vars_("i j", i64)
r = ruleset(
rewrite(Math(i) + Math(j)).to(Math(i + j)),
)
egraph.saturate(r)
```
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ classifiers = [
"Topic :: Software Development :: Interpreters",
"Typing :: Typed",
]
dependencies = ["typing-extensions", "black", "graphviz"]
dependencies = ["typing-extensions", "black", "graphviz", "anywidget"]

[project.optional-dependencies]

Expand Down
2 changes: 1 addition & 1 deletion python/egglog/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ def resolve_literal(
# Try all parent types as well, if we are converting from a Python type
for arg_type_instance in arg_type.__mro__ if isinstance(arg_type, type) else [arg_type]:
try:
fn = CONVERSIONS[(cast(TypeName | type, arg_type_instance), tp_name)][1]
fn = CONVERSIONS[(arg_type_instance, tp_name)][1]
except KeyError:
continue
break
Expand Down
191 changes: 88 additions & 103 deletions python/egglog/egraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@
from .thunk import *

if TYPE_CHECKING:
import ipywidgets

from .builtins import Bool, PyObject, String, f64, i64


Expand Down Expand Up @@ -973,6 +971,7 @@ class GraphvizKwargs(TypedDict, total=False):
n_inline_leaves: int
split_primitive_outputs: bool
split_functions: list[object]
include_temporary_functions: bool


@dataclass
Expand Down Expand Up @@ -1015,82 +1014,8 @@ def as_egglog_string(self) -> str:
raise ValueError(msg)
return cmds

def _repr_mimebundle_(self, *args, **kwargs):
"""
Returns the graphviz representation of the e-graph.
"""
return {"image/svg+xml": self.graphviz().pipe(format="svg", quiet=True, encoding="utf-8")}

def graphviz(self, **kwargs: Unpack[GraphvizKwargs]) -> graphviz.Source:
# By default we want to split primitive outputs
split_primitive_outputs = kwargs.pop("split_primitive_outputs", True)
split_additional_functions = kwargs.pop("split_functions", [])
n_inline = kwargs.pop("n_inline_leaves", 0)
serialized = self._egraph.serialize(
[],
max_functions=kwargs.pop("max_functions", None),
max_calls_per_function=kwargs.pop("max_calls_per_function", None),
include_temporary_functions=False,
)
if split_primitive_outputs or split_additional_functions:
additional_ops = set(map(self._callable_to_egg, split_additional_functions))
serialized.split_e_classes(self._egraph, additional_ops)
serialized.map_ops(self._state.op_mapping())

for _ in range(n_inline):
serialized.inline_leaves()
original = serialized.to_dot()
# Add link to stylesheet to the graph, so that edges light up on hover
# https://gist.github.com/sverweij/93e324f67310f66a8f5da5c2abe94682
styles = """/* the lines within the edges */
.edge:active path,
.edge:hover path {
stroke: fuchsia;
stroke-width: 3;
stroke-opacity: 1;
}
/* arrows are typically drawn with a polygon */
.edge:active polygon,
.edge:hover polygon {
stroke: fuchsia;
stroke-width: 3;
fill: fuchsia;
stroke-opacity: 1;
fill-opacity: 1;
}
/* If you happen to have text and want to color that as well... */
.edge:active text,
.edge:hover text {
fill: fuchsia;
}"""
p = pathlib.Path(tempfile.gettempdir()) / "graphviz-styles.css"
p.write_text(styles)
with_stylesheet = original.replace("{", f'{{stylesheet="{p!s}"', 1)
return graphviz.Source(with_stylesheet)

def graphviz_svg(self, **kwargs: Unpack[GraphvizKwargs]) -> str:
return self.graphviz(**kwargs).pipe(format="svg", quiet=True, encoding="utf-8")

def _repr_html_(self) -> str:
"""
Add a _repr_html_ to be an SVG to work with sphinx gallery.

ala https://github.com/xflr6/graphviz/pull/121
until this PR is merged and released
https://github.com/sphinx-gallery/sphinx-gallery/pull/1138
"""
return self.graphviz_svg()

def display(self, **kwargs: Unpack[GraphvizKwargs]) -> None:
"""
Displays the e-graph in the notebook.
"""
if IN_IPYTHON:
from IPython.display import SVG, display

display(SVG(self.graphviz_svg(**kwargs)))
else:
self.graphviz(**kwargs).render(view=True, format="svg", quiet=True)
def _ipython_display_(self) -> None:
self.display()

def input(self, fn: Callable[..., String], path: str) -> None:
"""
Expand Down Expand Up @@ -1319,40 +1244,100 @@ def eval(self, expr: Expr) -> object:
return self._egraph.eval_py_object(egg_expr)
raise TypeError(f"Eval not implemented for {typed_expr.tp}")

def saturate(
def _serialize(
self,
schedule: Schedule | None = None,
*,
max: int = 1000,
performance: bool = False,
**kwargs: Unpack[GraphvizKwargs],
) -> ipywidgets.Widget:
from .graphviz_widget import graphviz_widget_with_slider
) -> bindings.SerializedEGraph:
max_functions = kwargs.pop("max_functions", None)
max_calls_per_function = kwargs.pop("max_calls_per_function", None)
split_primitive_outputs = kwargs.pop("split_primitive_outputs", True)
split_functions = kwargs.pop("split_functions", [])
include_temporary_functions = kwargs.pop("include_temporary_functions", False)
n_inline_leaves = kwargs.pop("n_inline_leaves", 1)
serialized = self._egraph.serialize(
[],
max_functions=max_functions,
max_calls_per_function=max_calls_per_function,
include_temporary_functions=include_temporary_functions,
)
if split_primitive_outputs or split_functions:
additional_ops = set(map(self._callable_to_egg, split_functions))
serialized.split_e_classes(self._egraph, additional_ops)
serialized.map_ops(self._state.op_mapping())

dots = [str(self.graphviz(**kwargs))]
i = 0
while self.run(schedule or 1).updated and i < max:
i += 1
dots.append(str(self.graphviz(**kwargs)))
return graphviz_widget_with_slider(dots, performance=performance)
for _ in range(n_inline_leaves):
serialized.inline_leaves()

def saturate_to_html(
self, file: str = "tmp.html", performance: bool = False, **kwargs: Unpack[GraphvizKwargs]
) -> None:
# raise NotImplementedError("Upstream bugs prevent rendering to HTML")
return serialized

# import panel
def _graphviz(self, **kwargs: Unpack[GraphvizKwargs]) -> graphviz.Source:
serialized = self._serialize(**kwargs)

# panel.extension("ipywidgets")
original = serialized.to_dot()
# Add link to stylesheet to the graph, so that edges light up on hover
# https://gist.github.com/sverweij/93e324f67310f66a8f5da5c2abe94682
styles = """/* the lines within the edges */
.edge:active path,
.edge:hover path {
stroke: fuchsia;
stroke-width: 3;
stroke-opacity: 1;
}
/* arrows are typically drawn with a polygon */
.edge:active polygon,
.edge:hover polygon {
stroke: fuchsia;
stroke-width: 3;
fill: fuchsia;
stroke-opacity: 1;
fill-opacity: 1;
}
/* If you happen to have text and want to color that as well... */
.edge:active text,
.edge:hover text {
fill: fuchsia;
}"""
p = pathlib.Path(tempfile.gettempdir()) / "graphviz-styles.css"
p.write_text(styles)
with_stylesheet = original.replace("{", f'{{stylesheet="{p!s}"', 1)
return graphviz.Source(with_stylesheet)

widget = self.saturate(performance=performance, **kwargs)
# panel.panel(widget).save(file)
def display(self, graphviz: bool = False, **kwargs: Unpack[GraphvizKwargs]) -> None:
"""
Displays the e-graph.

If in IPython it will display it inline, otherwise it will write it to a file and open it.
"""
from IPython.display import SVG, display

from .visualizer_widget import VisualizerWidget

if graphviz:
if IN_IPYTHON:
svg = self._graphviz(**kwargs).pipe(format="svg", quiet=True, encoding="utf-8")
display(SVG(svg))
else:
self._graphviz(**kwargs).render(view=True, format="svg", quiet=True)
else:
serialized = self._serialize(**kwargs)
VisualizerWidget(egraphs=[serialized.to_json()]).display_or_open()

from ipywidgets.embed import embed_minimal_html
def saturate(self, schedule: Schedule | None = None, *, max: int = 1000, **kwargs: Unpack[GraphvizKwargs]) -> None:
"""
Saturate the egraph, running the given schedule until the egraph is saturated.
It serializes the egraph at each step and returns a widget to visualize the egraph.
"""
from .visualizer_widget import VisualizerWidget

def to_json() -> str:
return self._serialize(**kwargs).to_json()

embed_minimal_html("tmp.html", views=[widget], drop_defaults=False)
# Use panel while this issue persists
# https://github.com/jupyter-widgets/ipywidgets/issues/3761#issuecomment-1755563436
egraphs = [to_json()]
i = 0
while self.run(schedule or 1).updated and i < max:
i += 1
egraphs.append(to_json())
VisualizerWidget(egraphs=egraphs).display_or_open()

@classmethod
def current(cls) -> EGraph:
Expand Down
5 changes: 1 addition & 4 deletions python/egglog/exp/array_api_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@ def jit(fn: X) -> X:
"""
Jit compiles a function
"""
from IPython.display import SVG

# 1. Create variables for each of the two args in the functions
sig = inspect.signature(fn)
arg1, arg2 = sig.parameters.keys()
Expand All @@ -25,13 +23,12 @@ def jit(fn: X) -> X:
egraph.register(res)
egraph.run(array_api_numba_schedule)
res_optimized = egraph.extract(res)
svg = SVG(egraph.graphviz_svg(split_primitive_outputs=True, n_inline_leaves=3))
egraph.display(split_primitive_outputs=True, n_inline_leaves=3)

egraph = EGraph()
fn_program = ndarray_function_two(res_optimized, NDArray.var(arg1), NDArray.var(arg2))
egraph.register(fn_program)
egraph.run(array_api_program_gen_schedule)
fn = cast(X, egraph.eval(fn_program.py_object))
fn.egraph = svg # type: ignore[attr-defined]
fn.expr = res_optimized # type: ignore[attr-defined]
return fn
Loading
Loading