diff --git a/Cargo.lock b/Cargo.lock
index 28f0cf09..af973ee1 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -41,30 +41,29 @@ checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5"
[[package]]
name = "anstream"
-version = "0.3.2"
+version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0ca84f3628370c59db74ee214b3263d58f9aadd9b4fe7e711fd87dc452b7f163"
+checksum = "2ab91ebe16eb252986481c5b62f6098f3b698a45e34b5b98200cf20dd2484a44"
dependencies = [
"anstyle",
"anstyle-parse",
"anstyle-query",
"anstyle-wincon",
"colorchoice",
- "is-terminal",
"utf8parse",
]
[[package]]
name = "anstyle"
-version = "1.0.1"
+version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3a30da5c5f2d5e72842e00bcb57657162cdabef0931f40e2deb9b4140440cecd"
+checksum = "7079075b41f533b8c61d2a4d073c4676e1f8b249ff94a393b0595db304e0dd87"
[[package]]
name = "anstyle-parse"
-version = "0.2.1"
+version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "938874ff5980b03a87c5524b3ae5b59cf99b1d6bc836848df7bc5ada9643c333"
+checksum = "317b9a89c1868f5ea6ff1d9539a69f45dffc21ce321ac1fd1160dfa48c8e2140"
dependencies = [
"utf8parse",
]
@@ -80,9 +79,9 @@ dependencies = [
[[package]]
name = "anstyle-wincon"
-version = "1.0.1"
+version = "3.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "180abfa45703aebe0093f79badacc01b8fd4ea2e35118747e5811127f926e188"
+checksum = "f0699d10d2f4d628a98ee7b57b289abbc98ff3bad977cb3152709d4bf2330628"
dependencies = [
"anstyle",
"windows-sys",
@@ -132,9 +131,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
[[package]]
name = "bitflags"
-version = "2.3.3"
+version = "2.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "630be753d4e58660abd17930c71b647fe46c27ea6b63cc59e1e3851406972e42"
+checksum = "b4682ae6287fcf752ecaabbfcc7b6f9b72aa33933dc23a554d853aea8eea8635"
[[package]]
name = "block-buffer"
@@ -147,9 +146,12 @@ dependencies = [
[[package]]
name = "cc"
-version = "1.0.79"
+version = "1.0.83"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f"
+checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0"
+dependencies = [
+ "libc",
+]
[[package]]
name = "cfg-if"
@@ -159,20 +161,19 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "clap"
-version = "4.3.15"
+version = "4.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "8f644d0dac522c8b05ddc39aaaccc5b136d5dc4ff216610c5641e3be5becf56c"
+checksum = "d04704f56c2cde07f43e8e2c154b43f216dc5c92fc98ada720177362f953b956"
dependencies = [
"clap_builder",
"clap_derive",
- "once_cell",
]
[[package]]
name = "clap_builder"
-version = "4.3.15"
+version = "4.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "af410122b9778e024f9e0fb35682cc09cc3f85cad5e8d3ba8f47a9702df6e73d"
+checksum = "0e231faeaca65ebd1ea3c737966bf858971cd38c3849107aa3ea7de90a804e45"
dependencies = [
"anstream",
"anstyle",
@@ -182,21 +183,21 @@ dependencies = [
[[package]]
name = "clap_derive"
-version = "4.3.12"
+version = "4.4.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "54a9bb5758fc5dfe728d1019941681eccaf0cf8a4189b692a0ee2f2ecf90a050"
+checksum = "0862016ff20d69b84ef8247369fabf5c008a7417002411897d40ee1f4532b873"
dependencies = [
"heck",
"proc-macro2",
"quote",
- "syn 2.0.27",
+ "syn 2.0.32",
]
[[package]]
name = "clap_lex"
-version = "0.5.0"
+version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2da6da31387c7e4ef160ffab6d5e7f00c42626fe39aea70a7b0f1773f7dd6c1b"
+checksum = "cd7cc57abe963c6d3b9d8be5b06ba7c8957a930305ca90304f24ef040aa6f961"
[[package]]
name = "colorchoice"
@@ -284,13 +285,13 @@ checksum = "675e35c02a51bb4d4618cb4885b3839ce6d1787c97b664474d9208d074742e20"
[[package]]
name = "egglog"
version = "0.1.0"
-source = "git+https://github.com/egraphs-good/egglog?rev=4d67f262a6f27aa5cfb62a2cfc7df968959105df#4d67f262a6f27aa5cfb62a2cfc7df968959105df"
+source = "git+https://github.com/egraphs-good/egglog?rev=45d05e727cceaab13413b4e51a60ee3be9fbf403#45d05e727cceaab13413b4e51a60ee3be9fbf403"
dependencies = [
"clap",
"egraph-serialize",
"env_logger",
- "hashbrown 0.14.0",
- "indexmap 2.0.0",
+ "hashbrown 0.14.1",
+ "indexmap",
"instant",
"lalrpop",
"lalrpop-util 0.20.0",
@@ -327,7 +328,7 @@ version = "0.1.0"
source = "git+https://github.com/saulshanabrook/egraph-serialize?rev=a3f6fef9b958a335367d80d51e028c6db886fb6e#a3f6fef9b958a335367d80d51e028c6db886fb6e"
dependencies = [
"graphviz-rust",
- "indexmap 2.0.0",
+ "indexmap",
"once_cell",
"ordered-float",
"serde",
@@ -336,9 +337,9 @@ dependencies = [
[[package]]
name = "either"
-version = "1.8.1"
+version = "1.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91"
+checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07"
[[package]]
name = "ena"
@@ -370,9 +371,9 @@ checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5"
[[package]]
name = "errno"
-version = "0.3.1"
+version = "0.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4bcfec3a70f97c962c307b2d2c56e358cf1d00b558d74262b5f929ee8cc7e73a"
+checksum = "add4f07d43996f76ef320709726a556a9d4f965d9410d8d0271132d2f8293480"
dependencies = [
"errno-dragonfly",
"libc",
@@ -391,9 +392,9 @@ dependencies = [
[[package]]
name = "fastrand"
-version = "2.0.0"
+version = "2.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "6999dc1837253364c2ebb0704ba97994bd874e8f195d665c50b7548f6ea92764"
+checksum = "25cbce373ec4653f1a01a31e8a5e5ec0c622dc27ff9c4e6606eefef5cbbed4a5"
[[package]]
name = "fixedbitset"
@@ -449,9 +450,9 @@ dependencies = [
[[package]]
name = "hashbrown"
-version = "0.14.0"
+version = "0.14.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a"
+checksum = "7dfda62a12f55daeae5015f81b0baea145391cb4520f86c248fc615d72640d12"
dependencies = [
"ahash 0.8.3",
"allocator-api2",
@@ -465,9 +466,9 @@ checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8"
[[package]]
name = "hermit-abi"
-version = "0.3.2"
+version = "0.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "443144c8cdadd93ebf52ddb4056d257f5b52c04d3c804e657d19eb73fc33668b"
+checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7"
[[package]]
name = "humantime"
@@ -477,22 +478,12 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4"
[[package]]
name = "indexmap"
-version = "1.9.3"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99"
-dependencies = [
- "autocfg",
- "hashbrown 0.12.3",
-]
-
-[[package]]
-name = "indexmap"
-version = "2.0.0"
+version = "2.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "d5477fe2230a79769d8dc68e0eabf5437907c0457a5614a9e8dddb67f65eb65d"
+checksum = "8adf3ddd720272c6ea8bf59463c04e0f93d0bbf7c5439b691bca2987e0270897"
dependencies = [
"equivalent",
- "hashbrown 0.14.0",
+ "hashbrown 0.14.1",
"serde",
]
@@ -614,9 +605,9 @@ checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3"
[[package]]
name = "linux-raw-sys"
-version = "0.4.3"
+version = "0.4.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "09fc20d2ca12cb9f044c93e3bd6d32d523e6e2ec3db4f7b2939cd99026ecd3f0"
+checksum = "3852614a3bd9ca9804678ba6be5e3b8ce76dfc902cae004e3e0c44051b6e88db"
[[package]]
name = "lock_api"
@@ -739,19 +730,20 @@ dependencies = [
[[package]]
name = "pest"
-version = "2.7.2"
+version = "2.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1acb4a4365a13f749a93f1a094a7805e5cfa0955373a9de860d962eaa3a5fe5a"
+checksum = "c022f1e7b65d6a24c0dbbd5fb344c66881bc01f3e5ae74a1c8100f2f985d98a4"
dependencies = [
+ "memchr",
"thiserror",
"ucd-trie",
]
[[package]]
name = "pest_derive"
-version = "2.7.2"
+version = "2.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "666d00490d4ac815001da55838c500eafb0320019bbaa44444137c48b443a853"
+checksum = "35513f630d46400a977c4cb58f78e1bfbe01434316e60c37d27b9ad6139c66d8"
dependencies = [
"pest",
"pest_generator",
@@ -759,22 +751,22 @@ dependencies = [
[[package]]
name = "pest_generator"
-version = "2.7.2"
+version = "2.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "68ca01446f50dbda87c1786af8770d535423fa8a53aec03b8f4e3d7eb10e0929"
+checksum = "bc9fc1b9e7057baba189b5c626e2d6f40681ae5b6eb064dc7c7834101ec8123a"
dependencies = [
"pest",
"pest_meta",
"proc-macro2",
"quote",
- "syn 2.0.27",
+ "syn 2.0.32",
]
[[package]]
name = "pest_meta"
-version = "2.7.2"
+version = "2.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "56af0a30af74d0445c0bf6d9d051c979b516a1a5af790d251daee76005420a48"
+checksum = "1df74e9e7ec4053ceb980e7c0c8bd3594e977fde1af91daba9c928e8e8c6708d"
dependencies = [
"once_cell",
"pest",
@@ -783,12 +775,12 @@ dependencies = [
[[package]]
name = "petgraph"
-version = "0.6.3"
+version = "0.6.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4dd7d28ee937e54fe3080c91faa1c3a46c06de6252988a7f4592ba2310ef22a4"
+checksum = "e1d3afd2628e69da2be385eb6f2fd57c8ac7977ceeff6dc166ff1657b0e386a9"
dependencies = [
"fixedbitset",
- "indexmap 1.9.3",
+ "indexmap",
]
[[package]]
@@ -1005,11 +997,11 @@ checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
[[package]]
name = "rustix"
-version = "0.38.4"
+version = "0.38.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0a962918ea88d644592894bc6dc55acc6c0956488adcebbfb6e273506b7fd6e5"
+checksum = "d7db8590df6dfcd144d22afd1b83b36c21a18d7cbc1dc4bb5295a8712e9eb662"
dependencies = [
- "bitflags 2.3.3",
+ "bitflags 2.4.0",
"errno",
"libc",
"linux-raw-sys",
@@ -1036,31 +1028,31 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "serde"
-version = "1.0.179"
+version = "1.0.188"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "0a5bf42b8d227d4abf38a1ddb08602e229108a517cd4e5bb28f9c7eaafdce5c0"
+checksum = "cf9e0fcba69a370eed61bcf2b728575f726b50b55cba78064753d708ddc7549e"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
-version = "1.0.179"
+version = "1.0.188"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "741e124f5485c7e60c03b043f79f320bff3527f4bbf12cf3831750dc46a0ec2c"
+checksum = "4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.27",
+ "syn 2.0.32",
]
[[package]]
name = "serde_json"
-version = "1.0.105"
+version = "1.0.107"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "693151e1ac27563d6dbcec9dee9fbd5da8539b20fa14ad3752b2e6d363ace360"
+checksum = "6b420ce6e3d8bd882e9b243c6eed35dbc9a6110c9769e74b584e0d68d1f20c65"
dependencies = [
- "indexmap 2.0.0",
+ "indexmap",
"itoa",
"ryu",
"serde",
@@ -1068,9 +1060,9 @@ dependencies = [
[[package]]
name = "sha2"
-version = "0.10.7"
+version = "0.10.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "479fb9d862239e610720565ca91403019f2f00410f1864c5aa7479b950a76ed8"
+checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8"
dependencies = [
"cfg-if",
"cpufeatures",
@@ -1079,9 +1071,9 @@ dependencies = [
[[package]]
name = "siphasher"
-version = "0.3.10"
+version = "0.3.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "7bd3e3206899af3f8b12af284fafc038cc1dc2b41d1b89dd17297221c5d225de"
+checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d"
[[package]]
name = "smallvec"
@@ -1135,9 +1127,9 @@ dependencies = [
[[package]]
name = "syn"
-version = "2.0.27"
+version = "2.0.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "b60f673f44a8255b9c8c657daf66a596d435f2da81a555b06dc644d080ba45e0"
+checksum = "239814284fd6f1a4ffe4ca893952cdd93c224b6a1571c9a9eadd670295c0c9e2"
dependencies = [
"proc-macro2",
"quote",
@@ -1152,9 +1144,9 @@ checksum = "df8e77cb757a61f51b947ec4a7e3646efd825b73561db1c232a8ccb639e611a0"
[[package]]
name = "tempfile"
-version = "3.7.1"
+version = "3.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "dc02fddf48964c42031a0b3fe0428320ecf3a73c401040fc0096f97794310651"
+checksum = "cb94d2f3cc536af71caac6b6fcebf65860b347e7ce0cc9ebe8f70d3e521054ef"
dependencies = [
"cfg-if",
"fastrand",
@@ -1176,31 +1168,31 @@ dependencies = [
[[package]]
name = "termcolor"
-version = "1.2.0"
+version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "be55cf8942feac5c765c2c993422806843c9a9a45d4d5c407ad6dd2ea95eb9b6"
+checksum = "6093bad37da69aab9d123a8091e4be0aa4a03e4d601ec641c327398315f62b64"
dependencies = [
"winapi-util",
]
[[package]]
name = "thiserror"
-version = "1.0.43"
+version = "1.0.49"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "a35fc5b8971143ca348fa6df4f024d4d55264f3468c71ad1c2f365b0a4d58c42"
+checksum = "1177e8c6d7ede7afde3585fd2513e611227efd6481bd78d2e82ba1ce16557ed4"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
-version = "1.0.43"
+version = "1.0.49"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "463fe12d7993d3b327787537ce8dd4dfa058de32fc2b195ef3cde03dc4771e8f"
+checksum = "10712f02019e9288794769fba95cd6847df9874d49d871d062172f9dd41bc4cc"
dependencies = [
"proc-macro2",
"quote",
- "syn 2.0.27",
+ "syn 2.0.32",
]
[[package]]
@@ -1214,9 +1206,9 @@ dependencies = [
[[package]]
name = "typenum"
-version = "1.16.0"
+version = "1.17.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba"
+checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825"
[[package]]
name = "ucd-trie"
@@ -1278,9 +1270,9 @@ checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6"
[[package]]
name = "winapi-util"
-version = "0.1.5"
+version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "70ec6ce85bb158151cae5e5c87f95a8e97d2c0c4b001223f33a334e3ce5de178"
+checksum = "f29e6f9198ba0d26b4c9f07dbe6f9ed633e1f3d5b8b414090084349e46a52596"
dependencies = [
"winapi",
]
diff --git a/Cargo.toml b/Cargo.toml
index a10585e3..99d87ce4 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -10,9 +10,9 @@ crate-type = ["cdylib"]
[dependencies]
pyo3 = { version = "0.18.1", features = ["extension-module"] }
-egglog = { git = "https://github.com/egraphs-good/egglog", rev = "4d67f262a6f27aa5cfb62a2cfc7df968959105df" }
+egglog = { git = "https://github.com/egraphs-good/egglog", rev = "45d05e727cceaab13413b4e51a60ee3be9fbf403" }
# egglog = { git = "https://github.com/oflatt/egg-smol", rev = "f6df3ff831b65405665e1751b0ef71c61b025432" }
-# egglog = { git = "https://github.com/saulshanabrook/egg-smol", rev = "c01695618ed4de2fbfa8116476e208bc1ca86612" }
+# egglog = { git = "https://github.com/saulshanabrook/egg-smol", rev = "38b3014b34399cc78887ede09c845b2a5d6c7d19" }
pyo3-log = "0.8.1"
log = "0.4.17"
diff --git a/docs/changelog.md b/docs/changelog.md
index ca99dc78..d8f6de38 100644
--- a/docs/changelog.md
+++ b/docs/changelog.md
@@ -4,6 +4,20 @@ _This project uses semantic versioning. Before 1.0.0, this means that every brea
## Unreleased
+- Bump [egglog dep](https://github.com/egraphs-good/egglog/compare/4d67f262a6f27aa5cfb62a2cfc7df968959105df...45d05e727cceaab13413b4e51a60ee3be9fbf403)
+
+### New Features
+
+- Adds ability for custom user defined types in a union for proper static typing with conversions [#49](https://github.com/metadsl/egglog-python/pull/49)
+- Adds `py_eval` function to `EGraph` as a helper to eval Python code. [#49](https://github.com/metadsl/egglog-python/pull/49)
+- Adds on hover behavior for edges in graphviz SVG output to make them easier to trace [#49](https://github.com/metadsl/egglog-python/pull/49)
+- Adds `egglog.exp.program_gen` module that will compile expressions into Python statements/functions [#49](https://github.com/metadsl/egglog-python/pull/49)
+- Adds `py_exec` primitive function for executing Python code [#49](https://github.com/metadsl/egglog-python/pull/49)
+
+### Bug fixes
+
+- Clean up example in tutorial with demand based expression generation [#49](https://github.com/metadsl/egglog-python/pull/49)
+
## 0.6.0 (2023-09-20)
- Bump [egglog dep](https://github.com/egraphs-good/egglog/compare/c83fc750878755eb610a314da90f9273b3bfe25d...4d67f262a6f27aa5cfb62a2cfc7df968959105df)
@@ -19,8 +33,6 @@ _This project uses semantic versioning. Before 1.0.0, this means that every brea
- Add `Relation` and `PrintOverallStatistics` low level commands [#46](https://github.com/metadsl/egglog-python/pull/46)
- Adds `count-matches` and `replace` string commands [#46](https://github.com/metadsl/egglog-python/pull/46)
-### Bug fixes
-
### Uncategorized
- Added initial supported for Python objects [#31](https://github.com/metadsl/egglog-python/pull/31)
diff --git a/docs/reference/egglog-translation.md b/docs/reference/egglog-translation.md
index 25877037..6a2c73a3 100644
--- a/docs/reference/egglog-translation.md
+++ b/docs/reference/egglog-translation.md
@@ -245,21 +245,6 @@ As shown above, we can also use the `@classmethod` and `@property` decorators to
Note that reflected methods (i.e. `__radd__`) are handled as a special case. If defined, they won't create their own egglog functions.
Instead, whenever a reflected method is called, we will try to find the corresponding non-reflected method and call that instead.
-#### Custom Type Promotion
-
-Similar to how an `int` can be automatically upcasted to an `i64`, we also support registering conversion to your custom types. For example:
-
-```{code-cell} python
-converter(int, Math, Math)
-converter(str, Math, Math.var)
-
-Math(2) + 30 + "x"
-# equal to
-Math(2) + Math(i64(30)) + Math.var(String("x"))
-```
-
-Regstering a conversion from A to B will also register all transitively reachable conversions from A to B.
-
### Declarations
In egglog, the `(declare ...)` command is syntactic sugar for a nullary function. In Python, these can be declare either as class variables or with the toplevel `egraph.constant` function:
diff --git a/docs/reference/python-integration.md b/docs/reference/python-integration.md
index 3fa8cd59..8850dcf4 100644
--- a/docs/reference/python-integration.md
+++ b/docs/reference/python-integration.md
@@ -72,8 +72,13 @@ egraph.load_object(egraph.extract(PyObject.from_int(1)))
We also support evaling arbitrary Python bode, given some locals and globals. This technically allows us to implement any Python method:
```{code-cell} python
-empty_dict = egraph.save_object({})
-egraph.load_object(egraph.extract(py_eval("1 + 2", empty_dict, empty_dict)))
+egraph.load_object(egraph.extract(py_eval("1 + 2")))
+```
+
+Execing Python code is also supported. In this case, the return value will be the updated globals dict, which will be copied first before using.
+
+```{code-cell} python
+egraph.load_object(egraph.extract(py_exec("x = 1 + 2")))
```
Alongside this, we support a function `dict_update` method, which can allow you to combine some local local egglog expressions alongside, say, the locals and globals of the Python code you are evaling.
@@ -87,11 +92,68 @@ locals_expr = egraph.save_object(locals())
globals_expr = egraph.save_object(globals())
# Need `one` to map to the expression for `1` not the Python object of the expression
amended_globals = globals_expr.dict_update(PyObject.from_string("one"), one)
-evalled = py_eval("one + 2", locals_expr, amended_globals)
+evalled = py_eval("my_add(one, 2)", locals_expr, amended_globals)
assert egraph.load_object(egraph.extract(evalled)) == 3
```
-This is a bit subtle at the moment, and we plan on adding an easier wrapper to eval arbitrary Python code in the future.
+### Simpler Eval
+
+Instead of using the above low level primitive for evaling, there is a higher level wrapper function, `egraph.eval_fn`.
+
+It takes in a Python function and converts it to a function of PyObjects, by using `py_eval`
+under the hood.
+
+The above code code be re-written like this:
+
+```{code-cell} python
+def my_add(a, b):
+ return a + b
+
+evalled = egraph.eval_fn(lambda a: my_add(a, 2))(one)
+assert egraph.load_object(egraph.extract(evalled)) == 3
+```
+
+#### Custom Type Promotion
+
+Similar to how an `int` can be automatically upcasted to an `i64`, we also support registering conversion to your custom types. For example:
+
+```{code-cell} python
+@egraph.class_
+class Math(Expr):
+ def __init__(self, x: i64Like) -> None: ...
+
+ @classmethod
+ def var(cls, name: StringLike) -> Math: ...
+
+ def __add__(self, other: Math) -> Math: ...
+
+converter(i64, Math, Math)
+converter(String, Math, Math.var)
+
+Math(2) + i64(30) + String("x")
+# equal to
+Math(2) + Math(i64(30)) + Math.var(String("x"))
+```
+
+Regstering a conversion from A to B will also register all transitively reachable conversions from A to B, so you can also use:
+
+```{code-cell} python
+Math(2) + 30 + "x"
+```
+
+If you want to have this work with the static type checker, you can define your own `Union` type, which MUST include
+have the Egglog class as the first item in the union. For example, in this case you could then define:
+
+```{code-cell} python
+from typing import Union
+MathLike = Union[Math, i64Like, StringLike]
+
+@egraph.function
+def some_math_fn(x: MathLike) -> MathLike:
+ ...
+
+some_math_fn(10)
+```
## "Preserved" methods
diff --git a/docs/tutorials/getting-started.ipynb b/docs/tutorials/getting-started.ipynb
index 9931fe86..95396b89 100644
--- a/docs/tutorials/getting-started.ipynb
+++ b/docs/tutorials/getting-started.ipynb
@@ -1,562 +1,1255 @@
{
- "cells": [
- {
- "attachments": {},
- "cell_type": "markdown",
- "id": "ffabb623",
- "metadata": {
- "tags": []
- },
- "source": [
- "# Getting Started - Matrix Multiplication\n",
- "\n",
- "In this tutorial, you will learn how to:\n",
- "\n",
- "1. Install `egglog` Python\n",
- "2. Create a representation for matrices and some simplification rules for them. This will be based off of the [matrix multiplication example](https://github.com/egraphs-good/egglog/blob/08a6e8f/tests/matrix.egg) in the egglog repository. By using our high level wrapper, we can rely on Python's built in static type checker to check the correctness of your representation.\n",
- "3. Try out using our library in an interactive notebook.\n",
- "\n",
- "## Install egglog Python\n",
- "\n",
- "First, you will need to have a working Python interpreter. In this tutorial, we will [use `miniconda`](https://docs.conda.io/en/latest/miniconda.html) to create a new Python environment and activate it:\n",
- "\n",
- "```bash\n",
- "$ brew install miniconda\n",
- "$ conda create -n egglog-python python=3.11\n",
- "$ conda activate egglog-python\n",
- "```\n",
- "\n",
- "Then we want to install `egglog` Python. `egglog` Python can run on any recent Python version, and is tested on 3.8 - 3.11. To install it, run:\n",
- "\n",
- "```bash\n",
- "$ pip install egglog\n",
- "```\n",
- "\n",
- "To test you have installed it correctly, run:\n",
- "\n",
- "```bash\n",
- "$ python -m 'import egglog'\n",
- "```\n",
- "\n",
- "We also want to install `mypy` for static type checking. This is not required, but it will help us write correct representations. To install it, run:\n",
- "\n",
- "```bash\n",
- "$ pip install mypy\n",
- "```\n",
- "\n",
- "## Creating an E-Graph\n",
- "\n",
- "In this tutorial, we will use [VS Code](https://code.visualstudio.com/) to create file, `matrix.py`, to include our egraph\n",
- "and the simplification rules:\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 1,
- "id": "7369b71b",
- "metadata": {},
- "outputs": [],
- "source": [
- "from __future__ import annotations\n",
- "\n",
- "from egglog import *\n",
- "\n",
- "egraph = EGraph()"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "id": "814a51c5",
- "metadata": {},
- "source": [
- "## Defining Dimensions\n",
- "\n",
- "We will start by defining a representation for integers, which we will use to represent\n",
- "the dimensions of the matrix:\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 2,
- "id": "04fa991a",
- "metadata": {},
- "outputs": [],
- "source": [
- "@egraph.class_\n",
- "class Dim(Expr):\n",
- " \"\"\"\n",
- " A dimension of a matix.\n",
- "\n",
- " >>> Dim(3) * Dim.named(\"n\")\n",
- " Dim(3) * Dim.named(\"n\")\n",
- " \"\"\"\n",
- "\n",
- " def __init__(self, value: i64Like) -> None:\n",
- " ...\n",
- "\n",
- " @classmethod\n",
- " def named(cls, name: StringLike) -> Dim:\n",
- " ...\n",
- "\n",
- " def __mul__(self, other: Dim) -> Dim:\n",
- " ..."
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "id": "f5098a2b",
- "metadata": {
- "tags": []
- },
- "source": [
- "As you can see, you must wrap any class with the `egraph.class_` to register\n",
- "it with the egraph and be able to use it like a Python class.\n",
- "\n",
- "### Testing in a notebook\n",
- "\n",
- "We can try out this by [creating a new notebook](https://code.visualstudio.com/docs/datascience/jupyter-notebooks#_create-or-open-a-jupyter-notebook) which imports this file:\n",
- "\n",
- "```python\n",
- "from matrix import *\n",
- "```\n"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "id": "fd43c7ef",
- "metadata": {},
- "source": [
- "We can then create a new `Dim` object:\n"
- ]
- },
+ "cells": [
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "id": "ffabb623",
+ "metadata": {
+ "tags": []
+ },
+ "source": [
+ "# Getting Started - Matrix Multiplication\n",
+ "\n",
+ "In this tutorial, you will learn how to:\n",
+ "\n",
+ "1. Install `egglog` Python\n",
+ "2. Create a representation for matrices and some simplification rules for them. This will be based off of the [matrix multiplication example](https://github.com/egraphs-good/egglog/blob/08a6e8f/tests/matrix.egg) in the egglog repository. By using our high level wrapper, we can rely on Python's built in static type checker to check the correctness of your representation.\n",
+ "3. Try out using our library in an interactive notebook.\n",
+ "\n",
+ "## Install egglog Python\n",
+ "\n",
+ "First, you will need to have a working Python interpreter. In this tutorial, we will [use `miniconda`](https://docs.conda.io/en/latest/miniconda.html) to create a new Python environment and activate it:\n",
+ "\n",
+ "```bash\n",
+ "$ brew install miniconda\n",
+ "$ conda create -n egglog-python python=3.11\n",
+ "$ conda activate egglog-python\n",
+ "```\n",
+ "\n",
+ "Then we want to install `egglog` Python. `egglog` Python can run on any recent Python version, and is tested on 3.8 - 3.11. To install it, run:\n",
+ "\n",
+ "```bash\n",
+ "$ pip install egglog\n",
+ "```\n",
+ "\n",
+ "To test you have installed it correctly, run:\n",
+ "\n",
+ "```bash\n",
+ "$ python -m 'import egglog'\n",
+ "```\n",
+ "\n",
+ "We also want to install `mypy` for static type checking. This is not required, but it will help us write correct representations. To install it, run:\n",
+ "\n",
+ "```bash\n",
+ "$ pip install mypy\n",
+ "```\n",
+ "\n",
+ "## Creating an E-Graph\n",
+ "\n",
+ "In this tutorial, we will use [VS Code](https://code.visualstudio.com/) to create file, `matrix.py`, to include our egraph\n",
+ "and the simplification rules:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "7369b71b",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "from __future__ import annotations\n",
+ "\n",
+ "from egglog import *\n",
+ "\n",
+ "egraph = EGraph()"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "id": "814a51c5",
+ "metadata": {},
+ "source": [
+ "## Defining Dimensions\n",
+ "\n",
+ "We will start by defining a representation for integers, which we will use to represent\n",
+ "the dimensions of the matrix:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "04fa991a",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@egraph.class_\n",
+ "class Dim(Expr):\n",
+ " \"\"\"\n",
+ " A dimension of a matix.\n",
+ "\n",
+ " >>> Dim(3) * Dim.named(\"n\")\n",
+ " Dim(3) * Dim.named(\"n\")\n",
+ " \"\"\"\n",
+ "\n",
+ " def __init__(self, value: i64Like) -> None:\n",
+ " ...\n",
+ "\n",
+ " @classmethod\n",
+ " def named(cls, name: StringLike) -> Dim:\n",
+ " ...\n",
+ "\n",
+ " def __mul__(self, other: Dim) -> Dim:\n",
+ " ..."
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "id": "f5098a2b",
+ "metadata": {
+ "tags": []
+ },
+ "source": [
+ "As you can see, you must wrap any class with the `egraph.class_` to register\n",
+ "it with the egraph and be able to use it like a Python class.\n",
+ "\n",
+ "### Testing in a notebook\n",
+ "\n",
+ "We can try out this by [creating a new notebook](https://code.visualstudio.com/docs/datascience/jupyter-notebooks#_create-or-open-a-jupyter-notebook) which imports this file:\n",
+ "\n",
+ "```python\n",
+ "from matrix import *\n",
+ "```\n"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "id": "fd43c7ef",
+ "metadata": {},
+ "source": [
+ "We can then create a new `Dim` object:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "b6424530",
+ "metadata": {},
+ "outputs": [
{
- "cell_type": "code",
- "execution_count": 8,
- "id": "b6424530",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "(Dim.named(\"x\") * Dim(10)) * Dim(10)"
- ]
- },
- "execution_count": 8,
- "metadata": {},
- "output_type": "execute_result"
- }
+ "data": {
+ "text/html": [
+ "
(Dim.named("x")*Dim(10))*Dim(10)\n",
+ "
\n"
],
- "source": [
- "x = Dim.named(\"x\")\n",
- "ten = Dim(10)\n",
- "res = x * ten * ten\n",
- "res"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "id": "ef5ebb16",
- "metadata": {},
- "source": [
- "We see that the output is not evaluated, it's just a representation of the computation as well as the type. This is because we haven't defined any simplification rules yet.\n",
- "\n",
- "We can also try to create a dimension from an invalid type, or use it in an invalid way, we get a type error before we even run the code:\n",
- "\n",
- "```python\n",
- "x - ten\n",
- "```\n",
- "\n",
- "![Screenshot of VS Code showing a type error](./screenshot-1.png)\n",
- "\n",
- "## Dimension Replacements\n",
- "\n",
- "Now we will register some replacements for our dimensions and see how we can interface with egg to get it\n",
- "to execute them.\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 4,
- "id": "b06b1749",
- "metadata": {},
- "outputs": [],
- "source": [
- "a, b, c = vars_(\"a b c\", Dim)\n",
- "i, j = vars_(\"i j\", i64)\n",
- "egraph.register(\n",
- " rewrite(a * (b * c)).to((a * b) * c),\n",
- " rewrite((a * b) * c).to(a * (b * c)),\n",
- " rewrite(Dim(i) * Dim(j)).to(Dim(i * j)),\n",
- " rewrite(a * b).to(b * a),\n",
- ")"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "id": "167722d1-60b8-452a-ae54-6a8df4db5b00",
- "metadata": {},
- "source": [
- "You might notice that unlike a traditional term rewriting system, we don't specify any order for these rewrites. They will be executed until the graph is fully saturated, meaning that no new terms are created.\n"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "id": "a4d2c911",
- "metadata": {},
- "source": [
- "We can also see how the type checking can help us. If we try to create a rewrite from a `Dim` to an `i64` we see that we get a type error:\n",
- "\n",
- "![Screenshot of VS Code showing a type error](./screenshot-2.png)\n"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "id": "76dc1672-dba6-44ab-b9f1-aa01de685fb1",
- "metadata": {},
- "source": [
- "### Testing\n",
- "\n",
- "Going back to the notebook, we can test out the that the rewrites are working.\n",
- "We can run some number of iterations and extract out the lowest cost expression which is equivalent to our variable:\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 5,
- "id": "31afa12e-da68-4398-91fa-14523f6c099a",
- "metadata": {
- "tags": []
- },
- "outputs": [
- {
- "data": {
- "text/plain": [
- "Dim.named(\"x\") * Dim(100)"
- ]
- },
- "execution_count": 5,
- "metadata": {},
- "output_type": "execute_result"
- }
+ "text/latex": [
+ "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n",
+ "\\PY{p}{(}\\PY{n}{Dim}\\PY{o}{.}\\PY{n}{named}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{x}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)} \\PY{o}{*} \\PY{n}{Dim}\\PY{p}{(}\\PY{l+m+mi}{10}\\PY{p}{)}\\PY{p}{)} \\PY{o}{*} \\PY{n}{Dim}\\PY{p}{(}\\PY{l+m+mi}{10}\\PY{p}{)}\n",
+ "\\end{Verbatim}\n"
],
- "source": [
- "egraph.simplify(res, 10)"
+ "text/plain": [
+ "(Dim.named(\"x\") * Dim(10)) * Dim(10)"
]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "id": "7e44104c-d87b-441d-a717-92d42aab9d37",
- "metadata": {},
- "source": [
- "## Matrix Expressions\n",
- "\n",
- "Now that we have defined dimensions, we can define matrices as well as some functions on them:\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 9,
- "id": "c5b96cfb",
- "metadata": {},
- "outputs": [],
- "source": [
- "@egraph.class_\n",
- "class Matrix(Expr):\n",
- " @classmethod\n",
- " def identity(cls, dim: Dim) -> Matrix:\n",
- " \"\"\"\n",
- " Create an identity matrix of the given dimension.\n",
- " \"\"\"\n",
- " ...\n",
- "\n",
- " @classmethod\n",
- " def named(cls, name: StringLike) -> Matrix:\n",
- " \"\"\"\n",
- " Create a named matrix.\n",
- " \"\"\"\n",
- " ...\n",
- "\n",
- " def __matmul__(self, other: Matrix) -> Matrix:\n",
- " \"\"\"\n",
- " Matrix multiplication.\n",
- " \"\"\"\n",
- " ...\n",
- "\n",
- " def nrows(self) -> Dim:\n",
- " \"\"\"\n",
- " Number of rows in the matrix.\n",
- " \"\"\"\n",
- " ...\n",
- "\n",
- " def ncols(self) -> Dim:\n",
- " \"\"\"\n",
- " Number of columns in the matrix.\n",
- " \"\"\"\n",
- " ...\n",
- "\n",
- "\n",
- "@egraph.function\n",
- "def kron(a: Matrix, b: Matrix) -> Matrix:\n",
- " \"\"\"\n",
- " Kronecker product of two matrices.\n",
- "\n",
- " https://en.wikipedia.org/wiki/Kronecker_product#Definition\n",
- " \"\"\"\n",
- " ..."
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "id": "be8e6526",
- "metadata": {},
- "source": [
- "### Rows/cols Replacements\n",
- "\n",
- "We can also define some replacements to understand the number of rows and columns of a matrix:\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 10,
- "id": "cb2b4fb8",
- "metadata": {},
- "outputs": [],
- "source": [
- "A, B, C, D = vars_(\"A B C D\", Matrix)\n",
- "egraph.register(\n",
- " # The dimensions of a kronecker product are the product of the dimensions\n",
- " rewrite(kron(A, B).nrows()).to(A.nrows() * B.nrows()),\n",
- " rewrite(kron(A, B).ncols()).to(A.ncols() * B.ncols()),\n",
- " # The dimensions of a matrix multiplication are the number of rows of the first\n",
- " # matrix and the number of columns of the second matrix.\n",
- " rewrite((A @ B).nrows()).to(A.nrows()),\n",
- " rewrite((A @ B).ncols()).to(B.ncols()),\n",
- " # The dimensions of an identity matrix are the input dimension\n",
- " rewrite(Matrix.identity(a).nrows()).to(a),\n",
- " rewrite(Matrix.identity(a).ncols()).to(a),\n",
- ")"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "id": "13b969e8",
- "metadata": {},
- "source": [
- "We can try these out in our notebook (after restarting and re-importing) to compute the dimensions after some operations:\n"
- ]
- },
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "x = Dim.named(\"x\")\n",
+ "ten = Dim(10)\n",
+ "res = x * ten * ten\n",
+ "res"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "id": "ef5ebb16",
+ "metadata": {},
+ "source": [
+ "We see that the output is not evaluated, it's just a representation of the computation as well as the type. This is because we haven't defined any simplification rules yet.\n",
+ "\n",
+ "We can also try to create a dimension from an invalid type, or use it in an invalid way, we get a type error before we even run the code:\n",
+ "\n",
+ "```python\n",
+ "x - ten\n",
+ "```\n",
+ "\n",
+ "![Screenshot of VS Code showing a type error](./screenshot-1.png)\n",
+ "\n",
+ "## Dimension Replacements\n",
+ "\n",
+ "Now we will register some replacements for our dimensions and see how we can interface with egg to get it\n",
+ "to execute them.\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "b06b1749",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "a, b, c = vars_(\"a b c\", Dim)\n",
+ "i, j = vars_(\"i j\", i64)\n",
+ "egraph.register(\n",
+ " rewrite(a * (b * c)).to((a * b) * c),\n",
+ " rewrite((a * b) * c).to(a * (b * c)),\n",
+ " rewrite(Dim(i) * Dim(j)).to(Dim(i * j)),\n",
+ " rewrite(a * b).to(b * a),\n",
+ ")"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "id": "167722d1-60b8-452a-ae54-6a8df4db5b00",
+ "metadata": {},
+ "source": [
+ "You might notice that unlike a traditional term rewriting system, we don't specify any order for these rewrites. They will be executed until the graph is fully saturated, meaning that no new terms are created.\n"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "id": "a4d2c911",
+ "metadata": {},
+ "source": [
+ "We can also see how the type checking can help us. If we try to create a rewrite from a `Dim` to an `i64` we see that we get a type error:\n",
+ "\n",
+ "![Screenshot of VS Code showing a type error](./screenshot-2.png)\n"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "id": "76dc1672-dba6-44ab-b9f1-aa01de685fb1",
+ "metadata": {},
+ "source": [
+ "### Testing\n",
+ "\n",
+ "Going back to the notebook, we can test out the that the rewrites are working.\n",
+ "We can run some number of iterations and extract out the lowest cost expression which is equivalent to our variable:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 5,
+ "id": "31afa12e-da68-4398-91fa-14523f6c099a",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
{
- "cell_type": "code",
- "execution_count": 13,
- "id": "8d18be2d",
- "metadata": {},
- "outputs": [
- {
- "name": "stdout",
- "output_type": "stream",
- "text": [
- "Dim.named(\"y\")\n",
- "Dim.named(\"x\")\n"
- ]
- }
+ "data": {
+ "text/html": [
+ "
Dim.named("x")*Dim(100)\n",
+ "
\n"
],
- "source": [
- "# If we multiply two identity matrices, we should be able to get the number of columns of the result\n",
- "x = Matrix.identity(Dim.named(\"x\"))\n",
- "y = Matrix.identity(Dim.named(\"y\"))\n",
- "x_mult_y = x @ y\n",
- "print(egraph.simplify(x_mult_y.ncols(), 10))\n",
- "print(egraph.simplify(x_mult_y.nrows(), 10))"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "id": "2f2c68c3",
- "metadata": {},
- "source": [
- "### Operation replacements\n",
- "\n",
- "We can also define some replacements for matrix operations:\n"
- ]
- },
- {
- "cell_type": "code",
- "execution_count": 14,
- "id": "18a91684",
- "metadata": {},
- "outputs": [],
- "source": [
- "egraph.register(\n",
- " # Multiplication by an identity matrix is the same as the other matrix\n",
- " rewrite(A @ Matrix.identity(a)).to(A),\n",
- " rewrite(Matrix.identity(a) @ A).to(A),\n",
- " # Matrix multiplication is associative\n",
- " rewrite((A @ B) @ C).to(A @ (B @ C)),\n",
- " rewrite(A @ (B @ C)).to((A @ B) @ C),\n",
- " # Kronecker product is associative\n",
- " rewrite(kron(A, kron(B, C))).to(kron(kron(A, B), C)),\n",
- " rewrite(kron(kron(A, B), C)).to(kron(A, kron(B, C))),\n",
- " # Kronecker product distributes over matrix multiplication\n",
- " rewrite(kron(A @ C, B @ D)).to(kron(A, B) @ kron(C, D)),\n",
- " rewrite(kron(A, B) @ kron(C, D)).to(\n",
- " kron(A @ C, B @ D),\n",
- " # Only when the dimensions match\n",
- " eq(A.ncols()).to(C.nrows()),\n",
- " eq(B.ncols()).to(D.nrows()),\n",
- " ),\n",
- ")"
- ]
- },
- {
- "attachments": {},
- "cell_type": "markdown",
- "id": "1cd649dc",
- "metadata": {},
- "source": [
- "In our previous tests, we had to add the `ncols` and `nrows` operations to the e-graph seperately in order to have them be simplified. We can write some \"demand\" rules which automatically add these operations to the e-graph when they are needed:\n"
+ "text/latex": [
+ "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n",
+ "\\PY{n}{Dim}\\PY{o}{.}\\PY{n}{named}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{x}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)} \\PY{o}{*} \\PY{n}{Dim}\\PY{p}{(}\\PY{l+m+mi}{100}\\PY{p}{)}\n",
+ "\\end{Verbatim}\n"
+ ],
+ "text/plain": [
+ "Dim.named(\"x\") * Dim(100)"
]
- },
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "egraph.simplify(res, 10)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "id": "7e44104c-d87b-441d-a717-92d42aab9d37",
+ "metadata": {},
+ "source": [
+ "## Matrix Expressions\n",
+ "\n",
+ "Now that we have defined dimensions, we can define matrices as well as some functions on them:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 6,
+ "id": "c5b96cfb",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "@egraph.class_\n",
+ "class Matrix(Expr):\n",
+ " @classmethod\n",
+ " def identity(cls, dim: Dim) -> Matrix:\n",
+ " \"\"\"\n",
+ " Create an identity matrix of the given dimension.\n",
+ " \"\"\"\n",
+ " ...\n",
+ "\n",
+ " @classmethod\n",
+ " def named(cls, name: StringLike) -> Matrix:\n",
+ " \"\"\"\n",
+ " Create a named matrix.\n",
+ " \"\"\"\n",
+ " ...\n",
+ "\n",
+ " def __matmul__(self, other: Matrix) -> Matrix:\n",
+ " \"\"\"\n",
+ " Matrix multiplication.\n",
+ " \"\"\"\n",
+ " ...\n",
+ "\n",
+ " def nrows(self) -> Dim:\n",
+ " \"\"\"\n",
+ " Number of rows in the matrix.\n",
+ " \"\"\"\n",
+ " ...\n",
+ "\n",
+ " def ncols(self) -> Dim:\n",
+ " \"\"\"\n",
+ " Number of columns in the matrix.\n",
+ " \"\"\"\n",
+ " ...\n",
+ "\n",
+ "\n",
+ "@egraph.function\n",
+ "def kron(a: Matrix, b: Matrix) -> Matrix:\n",
+ " \"\"\"\n",
+ " Kronecker product of two matrices.\n",
+ "\n",
+ " https://en.wikipedia.org/wiki/Kronecker_product#Definition\n",
+ " \"\"\"\n",
+ " ..."
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "id": "be8e6526",
+ "metadata": {},
+ "source": [
+ "### Rows/cols Replacements\n",
+ "\n",
+ "We can also define some replacements to understand the number of rows and columns of a matrix:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 7,
+ "id": "cb2b4fb8",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "A, B, C, D = vars_(\"A B C D\", Matrix)\n",
+ "egraph.register(\n",
+ " # The dimensions of a kronecker product are the product of the dimensions\n",
+ " rewrite(kron(A, B).nrows()).to(A.nrows() * B.nrows()),\n",
+ " rewrite(kron(A, B).ncols()).to(A.ncols() * B.ncols()),\n",
+ " # The dimensions of a matrix multiplication are the number of rows of the first\n",
+ " # matrix and the number of columns of the second matrix.\n",
+ " rewrite((A @ B).nrows()).to(A.nrows()),\n",
+ " rewrite((A @ B).ncols()).to(B.ncols()),\n",
+ " # The dimensions of an identity matrix are the input dimension\n",
+ " rewrite(Matrix.identity(a).nrows()).to(a),\n",
+ " rewrite(Matrix.identity(a).ncols()).to(a),\n",
+ ")"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "id": "13b969e8",
+ "metadata": {},
+ "source": [
+ "We can try these out in our notebook (after restarting and re-importing) to compute the dimensions after some operations:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 8,
+ "id": "8d18be2d",
+ "metadata": {},
+ "outputs": [
{
- "cell_type": "code",
- "execution_count": 15,
- "id": "303ce7f3",
- "metadata": {},
- "outputs": [],
- "source": [
- "egraph.register(\n",
- " # demand rows and columns when we multiply matrices\n",
- " rule(eq(C).to(A @ B)).then(\n",
- " let(\"1\", A.ncols()),\n",
- " let(\"2\", A.nrows()),\n",
- " let(\"3\", B.nrows()),\n",
- " let(\"4\", B.ncols()),\n",
- " ),\n",
- " # demand rows and columns when we take the kronecker product\n",
- " rule(eq(C).to(kron(A, B))).then(\n",
- " let(\"1\", A.ncols()),\n",
- " let(\"2\", A.nrows()),\n",
- " let(\"3\", B.nrows()),\n",
- " let(\"4\", B.ncols()),\n",
- " ),\n",
- ")"
- ]
- },
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Dim.named(\"y\")\n",
+ "Dim.named(\"x\")\n"
+ ]
+ }
+ ],
+ "source": [
+ "# If we multiply two identity matrices, we should be able to get the number of columns of the result\n",
+ "x = Matrix.identity(Dim.named(\"x\"))\n",
+ "y = Matrix.identity(Dim.named(\"y\"))\n",
+ "x_mult_y = x @ y\n",
+ "print(egraph.simplify(x_mult_y.ncols(), 10))\n",
+ "print(egraph.simplify(x_mult_y.nrows(), 10))"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "id": "2f2c68c3",
+ "metadata": {},
+ "source": [
+ "### Operation replacements\n",
+ "\n",
+ "We can also define some replacements for matrix operations:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 9,
+ "id": "18a91684",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "egraph.register(\n",
+ " # Multiplication by an identity matrix is the same as the other matrix\n",
+ " rewrite(A @ Matrix.identity(a)).to(A),\n",
+ " rewrite(Matrix.identity(a) @ A).to(A),\n",
+ " # Matrix multiplication is associative\n",
+ " rewrite((A @ B) @ C).to(A @ (B @ C)),\n",
+ " rewrite(A @ (B @ C)).to((A @ B) @ C),\n",
+ " # Kronecker product is associative\n",
+ " rewrite(kron(A, kron(B, C))).to(kron(kron(A, B), C)),\n",
+ " rewrite(kron(kron(A, B), C)).to(kron(A, kron(B, C))),\n",
+ " # Kronecker product distributes over matrix multiplication\n",
+ " rewrite(kron(A @ C, B @ D)).to(kron(A, B) @ kron(C, D)),\n",
+ " rewrite(kron(A, B) @ kron(C, D)).to(\n",
+ " kron(A @ C, B @ D),\n",
+ " # Only when the dimensions match\n",
+ " eq(A.ncols()).to(C.nrows()),\n",
+ " eq(B.ncols()).to(D.nrows()),\n",
+ " ),\n",
+ ")"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "id": "1cd649dc",
+ "metadata": {},
+ "source": [
+ "In our previous tests, we had to add the `ncols` and `nrows` operations to the e-graph seperately in order to have them be simplified. We can write some \"demand\" rules which automatically add these operations to the e-graph when they are needed:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 10,
+ "id": "303ce7f3",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "egraph.register(\n",
+ " # demand rows and columns when we multiply matrices\n",
+ " rule(A @ B).then(\n",
+ " A.ncols(),\n",
+ " A.nrows(),\n",
+ " B.nrows(),\n",
+ " B.ncols(),\n",
+ " ),\n",
+ " # demand rows and columns when we take the kronecker product\n",
+ " rule(kron(A, B)).then(\n",
+ " A.ncols(),\n",
+ " A.nrows(),\n",
+ " B.nrows(),\n",
+ " B.ncols(),\n",
+ " ),\n",
+ ")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "334a2cc4-0004-415a-a8fb-4c5ef2e26aec",
+ "metadata": {},
+ "source": [
+ "For example, if we have `X @ Y` in the egraph, it will add expression for the columns of each as well:"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "c79a9105-e8fe-4545-b7c6-262648f82aad",
+ "metadata": {},
+ "outputs": [
{
- "attachments": {},
- "cell_type": "markdown",
- "id": "bd9e94de",
- "metadata": {},
- "source": [
- "We can try this out in our notebook, by multiplying some matrices and checking their dimensions:\n"
+ "data": {
+ "image/svg+xml": [
+ ""
+ ],
+ "text/plain": [
+ ""
]
+ },
+ "metadata": {},
+ "output_type": "display_data"
},
{
- "cell_type": "code",
- "execution_count": 16,
- "id": "bb50ade6",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "kron(Matrix.named(\"A\"), Matrix.named(\"B\"))"
- ]
- },
- "execution_count": 16,
- "metadata": {},
- "output_type": "execute_result"
- }
+ "data": {
+ "image/svg+xml": [
+ ""
],
- "source": [
- "# Define a number of dimensions\n",
- "n, m, p = Dim.named(\"n\"), Dim.named(\"m\"), Dim.named(\"p\")\n",
- "# Define a number of matrices\n",
- "A, B, C = Matrix.named(\"A\"), Matrix.named(\"B\"), Matrix.named(\"C\")\n",
- "# Set each to be a square matrix of the given dimension\n",
- "egraph.register(\n",
- " union(A.nrows()).with_(n),\n",
- " union(A.ncols()).with_(n),\n",
- " union(B.nrows()).with_(m),\n",
- " union(B.ncols()).with_(m),\n",
- " union(C.nrows()).with_(p),\n",
- " union(C.ncols()).with_(p),\n",
- ")\n",
- "# Create an example which should equal the kronecker product of A and B\n",
- "ex1 = kron(Matrix.identity(n), B) @ kron(A, Matrix.identity(m))\n",
- "egraph.simplify(ex1, 20)"
+ "text/plain": [
+ ""
]
- },
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "with egraph:\n",
+ " egraph.register(Matrix.named(\"X\") @ Matrix.named(\"Y\"))\n",
+ " egraph.display()\n",
+ " egraph.run(1)\n",
+ " egraph.display()"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "id": "bd9e94de",
+ "metadata": {},
+ "source": [
+ "We can try this out in our notebook, by multiplying some matrices and checking their dimensions:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "bb50ade6",
+ "metadata": {},
+ "outputs": [
{
- "attachments": {},
- "cell_type": "markdown",
- "id": "554321e2",
- "metadata": {},
- "source": [
- "We can make sure that if the rows/columns do not line up, then the transformation will not be applied:\n"
+ "data": {
+ "text/html": [
+ "
kron(Matrix.named("A"),Matrix.named("B"))\n",
+ "
\n"
+ ],
+ "text/latex": [
+ "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n",
+ "\\PY{n}{kron}\\PY{p}{(}\\PY{n}{Matrix}\\PY{o}{.}\\PY{n}{named}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{A}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{Matrix}\\PY{o}{.}\\PY{n}{named}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{B}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{)}\n",
+ "\\end{Verbatim}\n"
+ ],
+ "text/plain": [
+ "kron(Matrix.named(\"A\"), Matrix.named(\"B\"))"
]
- },
+ },
+ "metadata": {},
+ "output_type": "display_data"
+ }
+ ],
+ "source": [
+ "# Define a number of dimensions\n",
+ "n, m, p = Dim.named(\"n\"), Dim.named(\"m\"), Dim.named(\"p\")\n",
+ "# Define a number of matrices\n",
+ "A, B, C = Matrix.named(\"A\"), Matrix.named(\"B\"), Matrix.named(\"C\")\n",
+ "# Set each to be a square matrix of the given dimension\n",
+ "egraph.register(\n",
+ " union(A.nrows()).with_(n),\n",
+ " union(A.ncols()).with_(n),\n",
+ " union(B.nrows()).with_(m),\n",
+ " union(B.ncols()).with_(m),\n",
+ " union(C.nrows()).with_(p),\n",
+ " union(C.ncols()).with_(p),\n",
+ ")\n",
+ "# Create an example which should equal the kronecker product of A and B\n",
+ "ex1 = kron(Matrix.identity(n), B) @ kron(A, Matrix.identity(m))\n",
+ "egraph.simplify(ex1, 20)"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "id": "554321e2",
+ "metadata": {},
+ "source": [
+ "We can make sure that if the rows/columns do not line up, then the transformation will not be applied:\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "d8dea199",
+ "metadata": {},
+ "outputs": [
{
- "cell_type": "code",
- "execution_count": 17,
- "id": "d8dea199",
- "metadata": {},
- "outputs": [
- {
- "data": {
- "text/plain": [
- "kron(Matrix.identity(Dim.named(\"p\")), Matrix.named(\"C\")) @ kron(Matrix.named(\"A\"), Matrix.identity(Dim.named(\"m\")))"
- ]
- },
- "execution_count": 17,
- "metadata": {},
- "output_type": "execute_result"
- }
+ "data": {
+ "text/html": [
+ "