diff --git a/Cargo.lock b/Cargo.lock index 28f0cf09..af973ee1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -41,30 +41,29 @@ checksum = "0942ffc6dcaadf03badf6e6a2d0228460359d5e34b57ccdc720b7382dfbd5ec5" [[package]] name = "anstream" -version = "0.3.2" +version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ca84f3628370c59db74ee214b3263d58f9aadd9b4fe7e711fd87dc452b7f163" +checksum = "2ab91ebe16eb252986481c5b62f6098f3b698a45e34b5b98200cf20dd2484a44" dependencies = [ "anstyle", "anstyle-parse", "anstyle-query", "anstyle-wincon", "colorchoice", - "is-terminal", "utf8parse", ] [[package]] name = "anstyle" -version = "1.0.1" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a30da5c5f2d5e72842e00bcb57657162cdabef0931f40e2deb9b4140440cecd" +checksum = "7079075b41f533b8c61d2a4d073c4676e1f8b249ff94a393b0595db304e0dd87" [[package]] name = "anstyle-parse" -version = "0.2.1" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "938874ff5980b03a87c5524b3ae5b59cf99b1d6bc836848df7bc5ada9643c333" +checksum = "317b9a89c1868f5ea6ff1d9539a69f45dffc21ce321ac1fd1160dfa48c8e2140" dependencies = [ "utf8parse", ] @@ -80,9 +79,9 @@ dependencies = [ [[package]] name = "anstyle-wincon" -version = "1.0.1" +version = "3.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "180abfa45703aebe0093f79badacc01b8fd4ea2e35118747e5811127f926e188" +checksum = "f0699d10d2f4d628a98ee7b57b289abbc98ff3bad977cb3152709d4bf2330628" dependencies = [ "anstyle", "windows-sys", @@ -132,9 +131,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.3.3" +version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "630be753d4e58660abd17930c71b647fe46c27ea6b63cc59e1e3851406972e42" +checksum = "b4682ae6287fcf752ecaabbfcc7b6f9b72aa33933dc23a554d853aea8eea8635" [[package]] name = "block-buffer" @@ -147,9 +146,12 @@ dependencies = [ [[package]] name = "cc" -version = "1.0.79" +version = "1.0.83" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50d30906286121d95be3d479533b458f87493b30a4b5f79a607db8f5d11aa91f" +checksum = "f1174fb0b6ec23863f8b971027804a42614e347eafb0a95bf0b12cdae21fc4d0" +dependencies = [ + "libc", +] [[package]] name = "cfg-if" @@ -159,20 +161,19 @@ checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "clap" -version = "4.3.15" +version = "4.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f644d0dac522c8b05ddc39aaaccc5b136d5dc4ff216610c5641e3be5becf56c" +checksum = "d04704f56c2cde07f43e8e2c154b43f216dc5c92fc98ada720177362f953b956" dependencies = [ "clap_builder", "clap_derive", - "once_cell", ] [[package]] name = "clap_builder" -version = "4.3.15" +version = "4.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af410122b9778e024f9e0fb35682cc09cc3f85cad5e8d3ba8f47a9702df6e73d" +checksum = "0e231faeaca65ebd1ea3c737966bf858971cd38c3849107aa3ea7de90a804e45" dependencies = [ "anstream", "anstyle", @@ -182,21 +183,21 @@ dependencies = [ [[package]] name = "clap_derive" -version = "4.3.12" +version = "4.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "54a9bb5758fc5dfe728d1019941681eccaf0cf8a4189b692a0ee2f2ecf90a050" +checksum = "0862016ff20d69b84ef8247369fabf5c008a7417002411897d40ee1f4532b873" dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.27", + "syn 2.0.32", ] [[package]] name = "clap_lex" -version = "0.5.0" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2da6da31387c7e4ef160ffab6d5e7f00c42626fe39aea70a7b0f1773f7dd6c1b" +checksum = "cd7cc57abe963c6d3b9d8be5b06ba7c8957a930305ca90304f24ef040aa6f961" [[package]] name = "colorchoice" @@ -284,13 +285,13 @@ checksum = "675e35c02a51bb4d4618cb4885b3839ce6d1787c97b664474d9208d074742e20" [[package]] name = "egglog" version = "0.1.0" -source = "git+https://github.com/egraphs-good/egglog?rev=4d67f262a6f27aa5cfb62a2cfc7df968959105df#4d67f262a6f27aa5cfb62a2cfc7df968959105df" +source = "git+https://github.com/egraphs-good/egglog?rev=45d05e727cceaab13413b4e51a60ee3be9fbf403#45d05e727cceaab13413b4e51a60ee3be9fbf403" dependencies = [ "clap", "egraph-serialize", "env_logger", - "hashbrown 0.14.0", - "indexmap 2.0.0", + "hashbrown 0.14.1", + "indexmap", "instant", "lalrpop", "lalrpop-util 0.20.0", @@ -327,7 +328,7 @@ version = "0.1.0" source = "git+https://github.com/saulshanabrook/egraph-serialize?rev=a3f6fef9b958a335367d80d51e028c6db886fb6e#a3f6fef9b958a335367d80d51e028c6db886fb6e" dependencies = [ "graphviz-rust", - "indexmap 2.0.0", + "indexmap", "once_cell", "ordered-float", "serde", @@ -336,9 +337,9 @@ dependencies = [ [[package]] name = "either" -version = "1.8.1" +version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7fcaabb2fef8c910e7f4c7ce9f67a1283a1715879a7c230ca9d6d1ae31f16d91" +checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" [[package]] name = "ena" @@ -370,9 +371,9 @@ checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" [[package]] name = "errno" -version = "0.3.1" +version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4bcfec3a70f97c962c307b2d2c56e358cf1d00b558d74262b5f929ee8cc7e73a" +checksum = "add4f07d43996f76ef320709726a556a9d4f965d9410d8d0271132d2f8293480" dependencies = [ "errno-dragonfly", "libc", @@ -391,9 +392,9 @@ dependencies = [ [[package]] name = "fastrand" -version = "2.0.0" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6999dc1837253364c2ebb0704ba97994bd874e8f195d665c50b7548f6ea92764" +checksum = "25cbce373ec4653f1a01a31e8a5e5ec0c622dc27ff9c4e6606eefef5cbbed4a5" [[package]] name = "fixedbitset" @@ -449,9 +450,9 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.14.0" +version = "0.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c6201b9ff9fd90a5a3bac2e56a830d0caa509576f0e503818ee82c181b3437a" +checksum = "7dfda62a12f55daeae5015f81b0baea145391cb4520f86c248fc615d72640d12" dependencies = [ "ahash 0.8.3", "allocator-api2", @@ -465,9 +466,9 @@ checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" [[package]] name = "hermit-abi" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "443144c8cdadd93ebf52ddb4056d257f5b52c04d3c804e657d19eb73fc33668b" +checksum = "d77f7ec81a6d05a3abb01ab6eb7590f6083d08449fe5a1c8b1e620283546ccb7" [[package]] name = "humantime" @@ -477,22 +478,12 @@ checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4" [[package]] name = "indexmap" -version = "1.9.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bd070e393353796e801d209ad339e89596eb4c8d430d18ede6a1cced8fafbd99" -dependencies = [ - "autocfg", - "hashbrown 0.12.3", -] - -[[package]] -name = "indexmap" -version = "2.0.0" +version = "2.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d5477fe2230a79769d8dc68e0eabf5437907c0457a5614a9e8dddb67f65eb65d" +checksum = "8adf3ddd720272c6ea8bf59463c04e0f93d0bbf7c5439b691bca2987e0270897" dependencies = [ "equivalent", - "hashbrown 0.14.0", + "hashbrown 0.14.1", "serde", ] @@ -614,9 +605,9 @@ checksum = "b4668fb0ea861c1df094127ac5f1da3409a82116a4ba74fca2e58ef927159bb3" [[package]] name = "linux-raw-sys" -version = "0.4.3" +version = "0.4.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09fc20d2ca12cb9f044c93e3bd6d32d523e6e2ec3db4f7b2939cd99026ecd3f0" +checksum = "3852614a3bd9ca9804678ba6be5e3b8ce76dfc902cae004e3e0c44051b6e88db" [[package]] name = "lock_api" @@ -739,19 +730,20 @@ dependencies = [ [[package]] name = "pest" -version = "2.7.2" +version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1acb4a4365a13f749a93f1a094a7805e5cfa0955373a9de860d962eaa3a5fe5a" +checksum = "c022f1e7b65d6a24c0dbbd5fb344c66881bc01f3e5ae74a1c8100f2f985d98a4" dependencies = [ + "memchr", "thiserror", "ucd-trie", ] [[package]] name = "pest_derive" -version = "2.7.2" +version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "666d00490d4ac815001da55838c500eafb0320019bbaa44444137c48b443a853" +checksum = "35513f630d46400a977c4cb58f78e1bfbe01434316e60c37d27b9ad6139c66d8" dependencies = [ "pest", "pest_generator", @@ -759,22 +751,22 @@ dependencies = [ [[package]] name = "pest_generator" -version = "2.7.2" +version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68ca01446f50dbda87c1786af8770d535423fa8a53aec03b8f4e3d7eb10e0929" +checksum = "bc9fc1b9e7057baba189b5c626e2d6f40681ae5b6eb064dc7c7834101ec8123a" dependencies = [ "pest", "pest_meta", "proc-macro2", "quote", - "syn 2.0.27", + "syn 2.0.32", ] [[package]] name = "pest_meta" -version = "2.7.2" +version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "56af0a30af74d0445c0bf6d9d051c979b516a1a5af790d251daee76005420a48" +checksum = "1df74e9e7ec4053ceb980e7c0c8bd3594e977fde1af91daba9c928e8e8c6708d" dependencies = [ "once_cell", "pest", @@ -783,12 +775,12 @@ dependencies = [ [[package]] name = "petgraph" -version = "0.6.3" +version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4dd7d28ee937e54fe3080c91faa1c3a46c06de6252988a7f4592ba2310ef22a4" +checksum = "e1d3afd2628e69da2be385eb6f2fd57c8ac7977ceeff6dc166ff1657b0e386a9" dependencies = [ "fixedbitset", - "indexmap 1.9.3", + "indexmap", ] [[package]] @@ -1005,11 +997,11 @@ checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" [[package]] name = "rustix" -version = "0.38.4" +version = "0.38.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a962918ea88d644592894bc6dc55acc6c0956488adcebbfb6e273506b7fd6e5" +checksum = "d7db8590df6dfcd144d22afd1b83b36c21a18d7cbc1dc4bb5295a8712e9eb662" dependencies = [ - "bitflags 2.3.3", + "bitflags 2.4.0", "errno", "libc", "linux-raw-sys", @@ -1036,31 +1028,31 @@ checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] name = "serde" -version = "1.0.179" +version = "1.0.188" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0a5bf42b8d227d4abf38a1ddb08602e229108a517cd4e5bb28f9c7eaafdce5c0" +checksum = "cf9e0fcba69a370eed61bcf2b728575f726b50b55cba78064753d708ddc7549e" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.179" +version = "1.0.188" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "741e124f5485c7e60c03b043f79f320bff3527f4bbf12cf3831750dc46a0ec2c" +checksum = "4eca7ac642d82aa35b60049a6eccb4be6be75e599bd2e9adb5f875a737654af2" dependencies = [ "proc-macro2", "quote", - "syn 2.0.27", + "syn 2.0.32", ] [[package]] name = "serde_json" -version = "1.0.105" +version = "1.0.107" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "693151e1ac27563d6dbcec9dee9fbd5da8539b20fa14ad3752b2e6d363ace360" +checksum = "6b420ce6e3d8bd882e9b243c6eed35dbc9a6110c9769e74b584e0d68d1f20c65" dependencies = [ - "indexmap 2.0.0", + "indexmap", "itoa", "ryu", "serde", @@ -1068,9 +1060,9 @@ dependencies = [ [[package]] name = "sha2" -version = "0.10.7" +version = "0.10.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "479fb9d862239e610720565ca91403019f2f00410f1864c5aa7479b950a76ed8" +checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" dependencies = [ "cfg-if", "cpufeatures", @@ -1079,9 +1071,9 @@ dependencies = [ [[package]] name = "siphasher" -version = "0.3.10" +version = "0.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7bd3e3206899af3f8b12af284fafc038cc1dc2b41d1b89dd17297221c5d225de" +checksum = "38b58827f4464d87d377d175e90bf58eb00fd8716ff0a62f80356b5e61555d0d" [[package]] name = "smallvec" @@ -1135,9 +1127,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.27" +version = "2.0.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b60f673f44a8255b9c8c657daf66a596d435f2da81a555b06dc644d080ba45e0" +checksum = "239814284fd6f1a4ffe4ca893952cdd93c224b6a1571c9a9eadd670295c0c9e2" dependencies = [ "proc-macro2", "quote", @@ -1152,9 +1144,9 @@ checksum = "df8e77cb757a61f51b947ec4a7e3646efd825b73561db1c232a8ccb639e611a0" [[package]] name = "tempfile" -version = "3.7.1" +version = "3.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc02fddf48964c42031a0b3fe0428320ecf3a73c401040fc0096f97794310651" +checksum = "cb94d2f3cc536af71caac6b6fcebf65860b347e7ce0cc9ebe8f70d3e521054ef" dependencies = [ "cfg-if", "fastrand", @@ -1176,31 +1168,31 @@ dependencies = [ [[package]] name = "termcolor" -version = "1.2.0" +version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "be55cf8942feac5c765c2c993422806843c9a9a45d4d5c407ad6dd2ea95eb9b6" +checksum = "6093bad37da69aab9d123a8091e4be0aa4a03e4d601ec641c327398315f62b64" dependencies = [ "winapi-util", ] [[package]] name = "thiserror" -version = "1.0.43" +version = "1.0.49" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a35fc5b8971143ca348fa6df4f024d4d55264f3468c71ad1c2f365b0a4d58c42" +checksum = "1177e8c6d7ede7afde3585fd2513e611227efd6481bd78d2e82ba1ce16557ed4" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.43" +version = "1.0.49" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "463fe12d7993d3b327787537ce8dd4dfa058de32fc2b195ef3cde03dc4771e8f" +checksum = "10712f02019e9288794769fba95cd6847df9874d49d871d062172f9dd41bc4cc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.27", + "syn 2.0.32", ] [[package]] @@ -1214,9 +1206,9 @@ dependencies = [ [[package]] name = "typenum" -version = "1.16.0" +version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "497961ef93d974e23eb6f433eb5fe1b7930b659f06d12dec6fc44a8f554c0bba" +checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" [[package]] name = "ucd-trie" @@ -1278,9 +1270,9 @@ checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" [[package]] name = "winapi-util" -version = "0.1.5" +version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70ec6ce85bb158151cae5e5c87f95a8e97d2c0c4b001223f33a334e3ce5de178" +checksum = "f29e6f9198ba0d26b4c9f07dbe6f9ed633e1f3d5b8b414090084349e46a52596" dependencies = [ "winapi", ] diff --git a/Cargo.toml b/Cargo.toml index a10585e3..99d87ce4 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,9 +10,9 @@ crate-type = ["cdylib"] [dependencies] pyo3 = { version = "0.18.1", features = ["extension-module"] } -egglog = { git = "https://github.com/egraphs-good/egglog", rev = "4d67f262a6f27aa5cfb62a2cfc7df968959105df" } +egglog = { git = "https://github.com/egraphs-good/egglog", rev = "45d05e727cceaab13413b4e51a60ee3be9fbf403" } # egglog = { git = "https://github.com/oflatt/egg-smol", rev = "f6df3ff831b65405665e1751b0ef71c61b025432" } -# egglog = { git = "https://github.com/saulshanabrook/egg-smol", rev = "c01695618ed4de2fbfa8116476e208bc1ca86612" } +# egglog = { git = "https://github.com/saulshanabrook/egg-smol", rev = "38b3014b34399cc78887ede09c845b2a5d6c7d19" } pyo3-log = "0.8.1" log = "0.4.17" diff --git a/docs/changelog.md b/docs/changelog.md index ca99dc78..d8f6de38 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -4,6 +4,20 @@ _This project uses semantic versioning. Before 1.0.0, this means that every brea ## Unreleased +- Bump [egglog dep](https://github.com/egraphs-good/egglog/compare/4d67f262a6f27aa5cfb62a2cfc7df968959105df...45d05e727cceaab13413b4e51a60ee3be9fbf403) + +### New Features + +- Adds ability for custom user defined types in a union for proper static typing with conversions [#49](https://github.com/metadsl/egglog-python/pull/49) +- Adds `py_eval` function to `EGraph` as a helper to eval Python code. [#49](https://github.com/metadsl/egglog-python/pull/49) +- Adds on hover behavior for edges in graphviz SVG output to make them easier to trace [#49](https://github.com/metadsl/egglog-python/pull/49) +- Adds `egglog.exp.program_gen` module that will compile expressions into Python statements/functions [#49](https://github.com/metadsl/egglog-python/pull/49) +- Adds `py_exec` primitive function for executing Python code [#49](https://github.com/metadsl/egglog-python/pull/49) + +### Bug fixes + +- Clean up example in tutorial with demand based expression generation [#49](https://github.com/metadsl/egglog-python/pull/49) + ## 0.6.0 (2023-09-20) - Bump [egglog dep](https://github.com/egraphs-good/egglog/compare/c83fc750878755eb610a314da90f9273b3bfe25d...4d67f262a6f27aa5cfb62a2cfc7df968959105df) @@ -19,8 +33,6 @@ _This project uses semantic versioning. Before 1.0.0, this means that every brea - Add `Relation` and `PrintOverallStatistics` low level commands [#46](https://github.com/metadsl/egglog-python/pull/46) - Adds `count-matches` and `replace` string commands [#46](https://github.com/metadsl/egglog-python/pull/46) -### Bug fixes - ### Uncategorized - Added initial supported for Python objects [#31](https://github.com/metadsl/egglog-python/pull/31) diff --git a/docs/reference/egglog-translation.md b/docs/reference/egglog-translation.md index 25877037..6a2c73a3 100644 --- a/docs/reference/egglog-translation.md +++ b/docs/reference/egglog-translation.md @@ -245,21 +245,6 @@ As shown above, we can also use the `@classmethod` and `@property` decorators to Note that reflected methods (i.e. `__radd__`) are handled as a special case. If defined, they won't create their own egglog functions. Instead, whenever a reflected method is called, we will try to find the corresponding non-reflected method and call that instead. -#### Custom Type Promotion - -Similar to how an `int` can be automatically upcasted to an `i64`, we also support registering conversion to your custom types. For example: - -```{code-cell} python -converter(int, Math, Math) -converter(str, Math, Math.var) - -Math(2) + 30 + "x" -# equal to -Math(2) + Math(i64(30)) + Math.var(String("x")) -``` - -Regstering a conversion from A to B will also register all transitively reachable conversions from A to B. - ### Declarations In egglog, the `(declare ...)` command is syntactic sugar for a nullary function. In Python, these can be declare either as class variables or with the toplevel `egraph.constant` function: diff --git a/docs/reference/python-integration.md b/docs/reference/python-integration.md index 3fa8cd59..8850dcf4 100644 --- a/docs/reference/python-integration.md +++ b/docs/reference/python-integration.md @@ -72,8 +72,13 @@ egraph.load_object(egraph.extract(PyObject.from_int(1))) We also support evaling arbitrary Python bode, given some locals and globals. This technically allows us to implement any Python method: ```{code-cell} python -empty_dict = egraph.save_object({}) -egraph.load_object(egraph.extract(py_eval("1 + 2", empty_dict, empty_dict))) +egraph.load_object(egraph.extract(py_eval("1 + 2"))) +``` + +Execing Python code is also supported. In this case, the return value will be the updated globals dict, which will be copied first before using. + +```{code-cell} python +egraph.load_object(egraph.extract(py_exec("x = 1 + 2"))) ``` Alongside this, we support a function `dict_update` method, which can allow you to combine some local local egglog expressions alongside, say, the locals and globals of the Python code you are evaling. @@ -87,11 +92,68 @@ locals_expr = egraph.save_object(locals()) globals_expr = egraph.save_object(globals()) # Need `one` to map to the expression for `1` not the Python object of the expression amended_globals = globals_expr.dict_update(PyObject.from_string("one"), one) -evalled = py_eval("one + 2", locals_expr, amended_globals) +evalled = py_eval("my_add(one, 2)", locals_expr, amended_globals) assert egraph.load_object(egraph.extract(evalled)) == 3 ``` -This is a bit subtle at the moment, and we plan on adding an easier wrapper to eval arbitrary Python code in the future. +### Simpler Eval + +Instead of using the above low level primitive for evaling, there is a higher level wrapper function, `egraph.eval_fn`. + +It takes in a Python function and converts it to a function of PyObjects, by using `py_eval` +under the hood. + +The above code code be re-written like this: + +```{code-cell} python +def my_add(a, b): + return a + b + +evalled = egraph.eval_fn(lambda a: my_add(a, 2))(one) +assert egraph.load_object(egraph.extract(evalled)) == 3 +``` + +#### Custom Type Promotion + +Similar to how an `int` can be automatically upcasted to an `i64`, we also support registering conversion to your custom types. For example: + +```{code-cell} python +@egraph.class_ +class Math(Expr): + def __init__(self, x: i64Like) -> None: ... + + @classmethod + def var(cls, name: StringLike) -> Math: ... + + def __add__(self, other: Math) -> Math: ... + +converter(i64, Math, Math) +converter(String, Math, Math.var) + +Math(2) + i64(30) + String("x") +# equal to +Math(2) + Math(i64(30)) + Math.var(String("x")) +``` + +Regstering a conversion from A to B will also register all transitively reachable conversions from A to B, so you can also use: + +```{code-cell} python +Math(2) + 30 + "x" +``` + +If you want to have this work with the static type checker, you can define your own `Union` type, which MUST include +have the Egglog class as the first item in the union. For example, in this case you could then define: + +```{code-cell} python +from typing import Union +MathLike = Union[Math, i64Like, StringLike] + +@egraph.function +def some_math_fn(x: MathLike) -> MathLike: + ... + +some_math_fn(10) +``` ## "Preserved" methods diff --git a/docs/tutorials/getting-started.ipynb b/docs/tutorials/getting-started.ipynb index 9931fe86..95396b89 100644 --- a/docs/tutorials/getting-started.ipynb +++ b/docs/tutorials/getting-started.ipynb @@ -1,562 +1,1255 @@ { - "cells": [ - { - "attachments": {}, - "cell_type": "markdown", - "id": "ffabb623", - "metadata": { - "tags": [] - }, - "source": [ - "# Getting Started - Matrix Multiplication\n", - "\n", - "In this tutorial, you will learn how to:\n", - "\n", - "1. Install `egglog` Python\n", - "2. Create a representation for matrices and some simplification rules for them. This will be based off of the [matrix multiplication example](https://github.com/egraphs-good/egglog/blob/08a6e8f/tests/matrix.egg) in the egglog repository. By using our high level wrapper, we can rely on Python's built in static type checker to check the correctness of your representation.\n", - "3. Try out using our library in an interactive notebook.\n", - "\n", - "## Install egglog Python\n", - "\n", - "First, you will need to have a working Python interpreter. In this tutorial, we will [use `miniconda`](https://docs.conda.io/en/latest/miniconda.html) to create a new Python environment and activate it:\n", - "\n", - "```bash\n", - "$ brew install miniconda\n", - "$ conda create -n egglog-python python=3.11\n", - "$ conda activate egglog-python\n", - "```\n", - "\n", - "Then we want to install `egglog` Python. `egglog` Python can run on any recent Python version, and is tested on 3.8 - 3.11. To install it, run:\n", - "\n", - "```bash\n", - "$ pip install egglog\n", - "```\n", - "\n", - "To test you have installed it correctly, run:\n", - "\n", - "```bash\n", - "$ python -m 'import egglog'\n", - "```\n", - "\n", - "We also want to install `mypy` for static type checking. This is not required, but it will help us write correct representations. To install it, run:\n", - "\n", - "```bash\n", - "$ pip install mypy\n", - "```\n", - "\n", - "## Creating an E-Graph\n", - "\n", - "In this tutorial, we will use [VS Code](https://code.visualstudio.com/) to create file, `matrix.py`, to include our egraph\n", - "and the simplification rules:\n" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "7369b71b", - "metadata": {}, - "outputs": [], - "source": [ - "from __future__ import annotations\n", - "\n", - "from egglog import *\n", - "\n", - "egraph = EGraph()" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "814a51c5", - "metadata": {}, - "source": [ - "## Defining Dimensions\n", - "\n", - "We will start by defining a representation for integers, which we will use to represent\n", - "the dimensions of the matrix:\n" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "04fa991a", - "metadata": {}, - "outputs": [], - "source": [ - "@egraph.class_\n", - "class Dim(Expr):\n", - " \"\"\"\n", - " A dimension of a matix.\n", - "\n", - " >>> Dim(3) * Dim.named(\"n\")\n", - " Dim(3) * Dim.named(\"n\")\n", - " \"\"\"\n", - "\n", - " def __init__(self, value: i64Like) -> None:\n", - " ...\n", - "\n", - " @classmethod\n", - " def named(cls, name: StringLike) -> Dim:\n", - " ...\n", - "\n", - " def __mul__(self, other: Dim) -> Dim:\n", - " ..." - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "f5098a2b", - "metadata": { - "tags": [] - }, - "source": [ - "As you can see, you must wrap any class with the `egraph.class_` to register\n", - "it with the egraph and be able to use it like a Python class.\n", - "\n", - "### Testing in a notebook\n", - "\n", - "We can try out this by [creating a new notebook](https://code.visualstudio.com/docs/datascience/jupyter-notebooks#_create-or-open-a-jupyter-notebook) which imports this file:\n", - "\n", - "```python\n", - "from matrix import *\n", - "```\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "fd43c7ef", - "metadata": {}, - "source": [ - "We can then create a new `Dim` object:\n" - ] - }, + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "ffabb623", + "metadata": { + "tags": [] + }, + "source": [ + "# Getting Started - Matrix Multiplication\n", + "\n", + "In this tutorial, you will learn how to:\n", + "\n", + "1. Install `egglog` Python\n", + "2. Create a representation for matrices and some simplification rules for them. This will be based off of the [matrix multiplication example](https://github.com/egraphs-good/egglog/blob/08a6e8f/tests/matrix.egg) in the egglog repository. By using our high level wrapper, we can rely on Python's built in static type checker to check the correctness of your representation.\n", + "3. Try out using our library in an interactive notebook.\n", + "\n", + "## Install egglog Python\n", + "\n", + "First, you will need to have a working Python interpreter. In this tutorial, we will [use `miniconda`](https://docs.conda.io/en/latest/miniconda.html) to create a new Python environment and activate it:\n", + "\n", + "```bash\n", + "$ brew install miniconda\n", + "$ conda create -n egglog-python python=3.11\n", + "$ conda activate egglog-python\n", + "```\n", + "\n", + "Then we want to install `egglog` Python. `egglog` Python can run on any recent Python version, and is tested on 3.8 - 3.11. To install it, run:\n", + "\n", + "```bash\n", + "$ pip install egglog\n", + "```\n", + "\n", + "To test you have installed it correctly, run:\n", + "\n", + "```bash\n", + "$ python -m 'import egglog'\n", + "```\n", + "\n", + "We also want to install `mypy` for static type checking. This is not required, but it will help us write correct representations. To install it, run:\n", + "\n", + "```bash\n", + "$ pip install mypy\n", + "```\n", + "\n", + "## Creating an E-Graph\n", + "\n", + "In this tutorial, we will use [VS Code](https://code.visualstudio.com/) to create file, `matrix.py`, to include our egraph\n", + "and the simplification rules:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "7369b71b", + "metadata": {}, + "outputs": [], + "source": [ + "from __future__ import annotations\n", + "\n", + "from egglog import *\n", + "\n", + "egraph = EGraph()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "814a51c5", + "metadata": {}, + "source": [ + "## Defining Dimensions\n", + "\n", + "We will start by defining a representation for integers, which we will use to represent\n", + "the dimensions of the matrix:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "04fa991a", + "metadata": {}, + "outputs": [], + "source": [ + "@egraph.class_\n", + "class Dim(Expr):\n", + " \"\"\"\n", + " A dimension of a matix.\n", + "\n", + " >>> Dim(3) * Dim.named(\"n\")\n", + " Dim(3) * Dim.named(\"n\")\n", + " \"\"\"\n", + "\n", + " def __init__(self, value: i64Like) -> None:\n", + " ...\n", + "\n", + " @classmethod\n", + " def named(cls, name: StringLike) -> Dim:\n", + " ...\n", + "\n", + " def __mul__(self, other: Dim) -> Dim:\n", + " ..." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "f5098a2b", + "metadata": { + "tags": [] + }, + "source": [ + "As you can see, you must wrap any class with the `egraph.class_` to register\n", + "it with the egraph and be able to use it like a Python class.\n", + "\n", + "### Testing in a notebook\n", + "\n", + "We can try out this by [creating a new notebook](https://code.visualstudio.com/docs/datascience/jupyter-notebooks#_create-or-open-a-jupyter-notebook) which imports this file:\n", + "\n", + "```python\n", + "from matrix import *\n", + "```\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "fd43c7ef", + "metadata": {}, + "source": [ + "We can then create a new `Dim` object:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "b6424530", + "metadata": {}, + "outputs": [ { - "cell_type": "code", - "execution_count": 8, - "id": "b6424530", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "(Dim.named(\"x\") * Dim(10)) * Dim(10)" - ] - }, - "execution_count": 8, - "metadata": {}, - "output_type": "execute_result" - } + "data": { + "text/html": [ + "
(Dim.named("x") * Dim(10)) * Dim(10)\n",
+       "
\n" ], - "source": [ - "x = Dim.named(\"x\")\n", - "ten = Dim(10)\n", - "res = x * ten * ten\n", - "res" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "ef5ebb16", - "metadata": {}, - "source": [ - "We see that the output is not evaluated, it's just a representation of the computation as well as the type. This is because we haven't defined any simplification rules yet.\n", - "\n", - "We can also try to create a dimension from an invalid type, or use it in an invalid way, we get a type error before we even run the code:\n", - "\n", - "```python\n", - "x - ten\n", - "```\n", - "\n", - "![Screenshot of VS Code showing a type error](./screenshot-1.png)\n", - "\n", - "## Dimension Replacements\n", - "\n", - "Now we will register some replacements for our dimensions and see how we can interface with egg to get it\n", - "to execute them.\n" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "b06b1749", - "metadata": {}, - "outputs": [], - "source": [ - "a, b, c = vars_(\"a b c\", Dim)\n", - "i, j = vars_(\"i j\", i64)\n", - "egraph.register(\n", - " rewrite(a * (b * c)).to((a * b) * c),\n", - " rewrite((a * b) * c).to(a * (b * c)),\n", - " rewrite(Dim(i) * Dim(j)).to(Dim(i * j)),\n", - " rewrite(a * b).to(b * a),\n", - ")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "167722d1-60b8-452a-ae54-6a8df4db5b00", - "metadata": {}, - "source": [ - "You might notice that unlike a traditional term rewriting system, we don't specify any order for these rewrites. They will be executed until the graph is fully saturated, meaning that no new terms are created.\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "a4d2c911", - "metadata": {}, - "source": [ - "We can also see how the type checking can help us. If we try to create a rewrite from a `Dim` to an `i64` we see that we get a type error:\n", - "\n", - "![Screenshot of VS Code showing a type error](./screenshot-2.png)\n" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "76dc1672-dba6-44ab-b9f1-aa01de685fb1", - "metadata": {}, - "source": [ - "### Testing\n", - "\n", - "Going back to the notebook, we can test out the that the rewrites are working.\n", - "We can run some number of iterations and extract out the lowest cost expression which is equivalent to our variable:\n" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "31afa12e-da68-4398-91fa-14523f6c099a", - "metadata": { - "tags": [] - }, - "outputs": [ - { - "data": { - "text/plain": [ - "Dim.named(\"x\") * Dim(100)" - ] - }, - "execution_count": 5, - "metadata": {}, - "output_type": "execute_result" - } + "text/latex": [ + "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", + "\\PY{p}{(}\\PY{n}{Dim}\\PY{o}{.}\\PY{n}{named}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{x}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)} \\PY{o}{*} \\PY{n}{Dim}\\PY{p}{(}\\PY{l+m+mi}{10}\\PY{p}{)}\\PY{p}{)} \\PY{o}{*} \\PY{n}{Dim}\\PY{p}{(}\\PY{l+m+mi}{10}\\PY{p}{)}\n", + "\\end{Verbatim}\n" ], - "source": [ - "egraph.simplify(res, 10)" + "text/plain": [ + "(Dim.named(\"x\") * Dim(10)) * Dim(10)" ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "7e44104c-d87b-441d-a717-92d42aab9d37", - "metadata": {}, - "source": [ - "## Matrix Expressions\n", - "\n", - "Now that we have defined dimensions, we can define matrices as well as some functions on them:\n" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "c5b96cfb", - "metadata": {}, - "outputs": [], - "source": [ - "@egraph.class_\n", - "class Matrix(Expr):\n", - " @classmethod\n", - " def identity(cls, dim: Dim) -> Matrix:\n", - " \"\"\"\n", - " Create an identity matrix of the given dimension.\n", - " \"\"\"\n", - " ...\n", - "\n", - " @classmethod\n", - " def named(cls, name: StringLike) -> Matrix:\n", - " \"\"\"\n", - " Create a named matrix.\n", - " \"\"\"\n", - " ...\n", - "\n", - " def __matmul__(self, other: Matrix) -> Matrix:\n", - " \"\"\"\n", - " Matrix multiplication.\n", - " \"\"\"\n", - " ...\n", - "\n", - " def nrows(self) -> Dim:\n", - " \"\"\"\n", - " Number of rows in the matrix.\n", - " \"\"\"\n", - " ...\n", - "\n", - " def ncols(self) -> Dim:\n", - " \"\"\"\n", - " Number of columns in the matrix.\n", - " \"\"\"\n", - " ...\n", - "\n", - "\n", - "@egraph.function\n", - "def kron(a: Matrix, b: Matrix) -> Matrix:\n", - " \"\"\"\n", - " Kronecker product of two matrices.\n", - "\n", - " https://en.wikipedia.org/wiki/Kronecker_product#Definition\n", - " \"\"\"\n", - " ..." - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "be8e6526", - "metadata": {}, - "source": [ - "### Rows/cols Replacements\n", - "\n", - "We can also define some replacements to understand the number of rows and columns of a matrix:\n" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "cb2b4fb8", - "metadata": {}, - "outputs": [], - "source": [ - "A, B, C, D = vars_(\"A B C D\", Matrix)\n", - "egraph.register(\n", - " # The dimensions of a kronecker product are the product of the dimensions\n", - " rewrite(kron(A, B).nrows()).to(A.nrows() * B.nrows()),\n", - " rewrite(kron(A, B).ncols()).to(A.ncols() * B.ncols()),\n", - " # The dimensions of a matrix multiplication are the number of rows of the first\n", - " # matrix and the number of columns of the second matrix.\n", - " rewrite((A @ B).nrows()).to(A.nrows()),\n", - " rewrite((A @ B).ncols()).to(B.ncols()),\n", - " # The dimensions of an identity matrix are the input dimension\n", - " rewrite(Matrix.identity(a).nrows()).to(a),\n", - " rewrite(Matrix.identity(a).ncols()).to(a),\n", - ")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "13b969e8", - "metadata": {}, - "source": [ - "We can try these out in our notebook (after restarting and re-importing) to compute the dimensions after some operations:\n" - ] - }, + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "x = Dim.named(\"x\")\n", + "ten = Dim(10)\n", + "res = x * ten * ten\n", + "res" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "ef5ebb16", + "metadata": {}, + "source": [ + "We see that the output is not evaluated, it's just a representation of the computation as well as the type. This is because we haven't defined any simplification rules yet.\n", + "\n", + "We can also try to create a dimension from an invalid type, or use it in an invalid way, we get a type error before we even run the code:\n", + "\n", + "```python\n", + "x - ten\n", + "```\n", + "\n", + "![Screenshot of VS Code showing a type error](./screenshot-1.png)\n", + "\n", + "## Dimension Replacements\n", + "\n", + "Now we will register some replacements for our dimensions and see how we can interface with egg to get it\n", + "to execute them.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "b06b1749", + "metadata": {}, + "outputs": [], + "source": [ + "a, b, c = vars_(\"a b c\", Dim)\n", + "i, j = vars_(\"i j\", i64)\n", + "egraph.register(\n", + " rewrite(a * (b * c)).to((a * b) * c),\n", + " rewrite((a * b) * c).to(a * (b * c)),\n", + " rewrite(Dim(i) * Dim(j)).to(Dim(i * j)),\n", + " rewrite(a * b).to(b * a),\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "167722d1-60b8-452a-ae54-6a8df4db5b00", + "metadata": {}, + "source": [ + "You might notice that unlike a traditional term rewriting system, we don't specify any order for these rewrites. They will be executed until the graph is fully saturated, meaning that no new terms are created.\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "a4d2c911", + "metadata": {}, + "source": [ + "We can also see how the type checking can help us. If we try to create a rewrite from a `Dim` to an `i64` we see that we get a type error:\n", + "\n", + "![Screenshot of VS Code showing a type error](./screenshot-2.png)\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "76dc1672-dba6-44ab-b9f1-aa01de685fb1", + "metadata": {}, + "source": [ + "### Testing\n", + "\n", + "Going back to the notebook, we can test out the that the rewrites are working.\n", + "We can run some number of iterations and extract out the lowest cost expression which is equivalent to our variable:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "31afa12e-da68-4398-91fa-14523f6c099a", + "metadata": { + "tags": [] + }, + "outputs": [ { - "cell_type": "code", - "execution_count": 13, - "id": "8d18be2d", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Dim.named(\"y\")\n", - "Dim.named(\"x\")\n" - ] - } + "data": { + "text/html": [ + "
Dim.named("x") * Dim(100)\n",
+       "
\n" ], - "source": [ - "# If we multiply two identity matrices, we should be able to get the number of columns of the result\n", - "x = Matrix.identity(Dim.named(\"x\"))\n", - "y = Matrix.identity(Dim.named(\"y\"))\n", - "x_mult_y = x @ y\n", - "print(egraph.simplify(x_mult_y.ncols(), 10))\n", - "print(egraph.simplify(x_mult_y.nrows(), 10))" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "2f2c68c3", - "metadata": {}, - "source": [ - "### Operation replacements\n", - "\n", - "We can also define some replacements for matrix operations:\n" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "18a91684", - "metadata": {}, - "outputs": [], - "source": [ - "egraph.register(\n", - " # Multiplication by an identity matrix is the same as the other matrix\n", - " rewrite(A @ Matrix.identity(a)).to(A),\n", - " rewrite(Matrix.identity(a) @ A).to(A),\n", - " # Matrix multiplication is associative\n", - " rewrite((A @ B) @ C).to(A @ (B @ C)),\n", - " rewrite(A @ (B @ C)).to((A @ B) @ C),\n", - " # Kronecker product is associative\n", - " rewrite(kron(A, kron(B, C))).to(kron(kron(A, B), C)),\n", - " rewrite(kron(kron(A, B), C)).to(kron(A, kron(B, C))),\n", - " # Kronecker product distributes over matrix multiplication\n", - " rewrite(kron(A @ C, B @ D)).to(kron(A, B) @ kron(C, D)),\n", - " rewrite(kron(A, B) @ kron(C, D)).to(\n", - " kron(A @ C, B @ D),\n", - " # Only when the dimensions match\n", - " eq(A.ncols()).to(C.nrows()),\n", - " eq(B.ncols()).to(D.nrows()),\n", - " ),\n", - ")" - ] - }, - { - "attachments": {}, - "cell_type": "markdown", - "id": "1cd649dc", - "metadata": {}, - "source": [ - "In our previous tests, we had to add the `ncols` and `nrows` operations to the e-graph seperately in order to have them be simplified. We can write some \"demand\" rules which automatically add these operations to the e-graph when they are needed:\n" + "text/latex": [ + "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", + "\\PY{n}{Dim}\\PY{o}{.}\\PY{n}{named}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{x}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)} \\PY{o}{*} \\PY{n}{Dim}\\PY{p}{(}\\PY{l+m+mi}{100}\\PY{p}{)}\n", + "\\end{Verbatim}\n" + ], + "text/plain": [ + "Dim.named(\"x\") * Dim(100)" ] - }, + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "egraph.simplify(res, 10)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "7e44104c-d87b-441d-a717-92d42aab9d37", + "metadata": {}, + "source": [ + "## Matrix Expressions\n", + "\n", + "Now that we have defined dimensions, we can define matrices as well as some functions on them:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "c5b96cfb", + "metadata": {}, + "outputs": [], + "source": [ + "@egraph.class_\n", + "class Matrix(Expr):\n", + " @classmethod\n", + " def identity(cls, dim: Dim) -> Matrix:\n", + " \"\"\"\n", + " Create an identity matrix of the given dimension.\n", + " \"\"\"\n", + " ...\n", + "\n", + " @classmethod\n", + " def named(cls, name: StringLike) -> Matrix:\n", + " \"\"\"\n", + " Create a named matrix.\n", + " \"\"\"\n", + " ...\n", + "\n", + " def __matmul__(self, other: Matrix) -> Matrix:\n", + " \"\"\"\n", + " Matrix multiplication.\n", + " \"\"\"\n", + " ...\n", + "\n", + " def nrows(self) -> Dim:\n", + " \"\"\"\n", + " Number of rows in the matrix.\n", + " \"\"\"\n", + " ...\n", + "\n", + " def ncols(self) -> Dim:\n", + " \"\"\"\n", + " Number of columns in the matrix.\n", + " \"\"\"\n", + " ...\n", + "\n", + "\n", + "@egraph.function\n", + "def kron(a: Matrix, b: Matrix) -> Matrix:\n", + " \"\"\"\n", + " Kronecker product of two matrices.\n", + "\n", + " https://en.wikipedia.org/wiki/Kronecker_product#Definition\n", + " \"\"\"\n", + " ..." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "be8e6526", + "metadata": {}, + "source": [ + "### Rows/cols Replacements\n", + "\n", + "We can also define some replacements to understand the number of rows and columns of a matrix:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "cb2b4fb8", + "metadata": {}, + "outputs": [], + "source": [ + "A, B, C, D = vars_(\"A B C D\", Matrix)\n", + "egraph.register(\n", + " # The dimensions of a kronecker product are the product of the dimensions\n", + " rewrite(kron(A, B).nrows()).to(A.nrows() * B.nrows()),\n", + " rewrite(kron(A, B).ncols()).to(A.ncols() * B.ncols()),\n", + " # The dimensions of a matrix multiplication are the number of rows of the first\n", + " # matrix and the number of columns of the second matrix.\n", + " rewrite((A @ B).nrows()).to(A.nrows()),\n", + " rewrite((A @ B).ncols()).to(B.ncols()),\n", + " # The dimensions of an identity matrix are the input dimension\n", + " rewrite(Matrix.identity(a).nrows()).to(a),\n", + " rewrite(Matrix.identity(a).ncols()).to(a),\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "13b969e8", + "metadata": {}, + "source": [ + "We can try these out in our notebook (after restarting and re-importing) to compute the dimensions after some operations:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "8d18be2d", + "metadata": {}, + "outputs": [ { - "cell_type": "code", - "execution_count": 15, - "id": "303ce7f3", - "metadata": {}, - "outputs": [], - "source": [ - "egraph.register(\n", - " # demand rows and columns when we multiply matrices\n", - " rule(eq(C).to(A @ B)).then(\n", - " let(\"1\", A.ncols()),\n", - " let(\"2\", A.nrows()),\n", - " let(\"3\", B.nrows()),\n", - " let(\"4\", B.ncols()),\n", - " ),\n", - " # demand rows and columns when we take the kronecker product\n", - " rule(eq(C).to(kron(A, B))).then(\n", - " let(\"1\", A.ncols()),\n", - " let(\"2\", A.nrows()),\n", - " let(\"3\", B.nrows()),\n", - " let(\"4\", B.ncols()),\n", - " ),\n", - ")" - ] - }, + "name": "stdout", + "output_type": "stream", + "text": [ + "Dim.named(\"y\")\n", + "Dim.named(\"x\")\n" + ] + } + ], + "source": [ + "# If we multiply two identity matrices, we should be able to get the number of columns of the result\n", + "x = Matrix.identity(Dim.named(\"x\"))\n", + "y = Matrix.identity(Dim.named(\"y\"))\n", + "x_mult_y = x @ y\n", + "print(egraph.simplify(x_mult_y.ncols(), 10))\n", + "print(egraph.simplify(x_mult_y.nrows(), 10))" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "2f2c68c3", + "metadata": {}, + "source": [ + "### Operation replacements\n", + "\n", + "We can also define some replacements for matrix operations:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "18a91684", + "metadata": {}, + "outputs": [], + "source": [ + "egraph.register(\n", + " # Multiplication by an identity matrix is the same as the other matrix\n", + " rewrite(A @ Matrix.identity(a)).to(A),\n", + " rewrite(Matrix.identity(a) @ A).to(A),\n", + " # Matrix multiplication is associative\n", + " rewrite((A @ B) @ C).to(A @ (B @ C)),\n", + " rewrite(A @ (B @ C)).to((A @ B) @ C),\n", + " # Kronecker product is associative\n", + " rewrite(kron(A, kron(B, C))).to(kron(kron(A, B), C)),\n", + " rewrite(kron(kron(A, B), C)).to(kron(A, kron(B, C))),\n", + " # Kronecker product distributes over matrix multiplication\n", + " rewrite(kron(A @ C, B @ D)).to(kron(A, B) @ kron(C, D)),\n", + " rewrite(kron(A, B) @ kron(C, D)).to(\n", + " kron(A @ C, B @ D),\n", + " # Only when the dimensions match\n", + " eq(A.ncols()).to(C.nrows()),\n", + " eq(B.ncols()).to(D.nrows()),\n", + " ),\n", + ")" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "1cd649dc", + "metadata": {}, + "source": [ + "In our previous tests, we had to add the `ncols` and `nrows` operations to the e-graph seperately in order to have them be simplified. We can write some \"demand\" rules which automatically add these operations to the e-graph when they are needed:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "303ce7f3", + "metadata": {}, + "outputs": [], + "source": [ + "egraph.register(\n", + " # demand rows and columns when we multiply matrices\n", + " rule(A @ B).then(\n", + " A.ncols(),\n", + " A.nrows(),\n", + " B.nrows(),\n", + " B.ncols(),\n", + " ),\n", + " # demand rows and columns when we take the kronecker product\n", + " rule(kron(A, B)).then(\n", + " A.ncols(),\n", + " A.nrows(),\n", + " B.nrows(),\n", + " B.ncols(),\n", + " ),\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "334a2cc4-0004-415a-a8fb-4c5ef2e26aec", + "metadata": {}, + "source": [ + "For example, if we have `X @ Y` in the egraph, it will add expression for the columns of each as well:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "c79a9105-e8fe-4545-b7c6-262648f82aad", + "metadata": {}, + "outputs": [ { - "attachments": {}, - "cell_type": "markdown", - "id": "bd9e94de", - "metadata": {}, - "source": [ - "We can try this out in our notebook, by multiplying some matrices and checking their dimensions:\n" + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "outer_cluster_1\n", + "\n", + "\n", + "cluster_1\n", + "\n", + "\n", + "\n", + "outer_cluster_0\n", + "\n", + "\n", + "cluster_0\n", + "\n", + "\n", + "\n", + "outer_cluster_2\n", + "\n", + "\n", + "cluster_2\n", + "\n", + "\n", + "\n", + "outer_cluster_String-1316606400713378063\n", + "\n", + "\n", + "cluster_String-1316606400713378063\n", + "\n", + "\n", + "\n", + "outer_cluster_String-4801791173778264996\n", + "\n", + "\n", + "cluster_String-4801791173778264996\n", + "\n", + "\n", + "\n", + "\n", + "Matrix_named-1316606400713378063:s->String-1316606400713378063\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Matrix_named-4801791173778264996:s->String-4801791173778264996\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Matrix___matmul__-5871781006564002453:s->Matrix_named-1316606400713378063\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Matrix___matmul__-5871781006564002453:s->Matrix_named-4801791173778264996\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Matrix_named-1316606400713378063\n", + "\n", + "\n", + "Matrix_named\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "String-1316606400713378063\n", + "\n", + "\n", + ""X"\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Matrix_named-4801791173778264996\n", + "\n", + "\n", + "Matrix_named\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "String-4801791173778264996\n", + "\n", + "\n", + ""Y"\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Matrix___matmul__-5871781006564002453\n", + "\n", + "\n", + "Matrix___matmul__\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "" + ], + "text/plain": [ + "" ] + }, + "metadata": {}, + "output_type": "display_data" }, { - "cell_type": "code", - "execution_count": 16, - "id": "bb50ade6", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "kron(Matrix.named(\"A\"), Matrix.named(\"B\"))" - ] - }, - "execution_count": 16, - "metadata": {}, - "output_type": "execute_result" - } + "data": { + "image/svg+xml": [ + "\n", + "\n", + "\n", + "\n", + "outer_cluster_String-4801791173778264996\n", + "\n", + "\n", + "cluster_String-4801791173778264996\n", + "\n", + "\n", + "\n", + "outer_cluster_String-1316606400713378063\n", + "\n", + "\n", + "cluster_String-1316606400713378063\n", + "\n", + "\n", + "\n", + "outer_cluster_2\n", + "\n", + "\n", + "cluster_2\n", + "\n", + "\n", + "\n", + "outer_cluster_0\n", + "\n", + "\n", + "cluster_0\n", + "\n", + "\n", + "\n", + "outer_cluster_1\n", + "\n", + "\n", + "cluster_1\n", + "\n", + "\n", + "\n", + "outer_cluster_4\n", + "\n", + "\n", + "cluster_4\n", + "\n", + "\n", + "\n", + "outer_cluster_3\n", + "\n", + "\n", + "cluster_3\n", + "\n", + "\n", + "\n", + "outer_cluster_5\n", + "\n", + "\n", + "cluster_5\n", + "\n", + "\n", + "\n", + "outer_cluster_6\n", + "\n", + "\n", + "cluster_6\n", + "\n", + "\n", + "\n", + "\n", + "Matrix___matmul__-5871781006564002453:s->Matrix_named-1316606400713378063\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Matrix___matmul__-5871781006564002453:s->Matrix_named-4801791173778264996\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Matrix_named-1316606400713378063:s->String-1316606400713378063\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Matrix_named-4801791173778264996:s->String-4801791173778264996\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Matrix_nrows-0:s->Matrix_named-1316606400713378063\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Matrix_ncols-0:s->Matrix_named-1316606400713378063\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Matrix_nrows-5871781006564002453:s->Matrix_named-4801791173778264996\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Matrix_ncols-5871781006564002453:s->Matrix_named-4801791173778264996\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Matrix___matmul__-5871781006564002453\n", + "\n", + "\n", + "Matrix___matmul__\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Matrix_named-1316606400713378063\n", + "\n", + "\n", + "Matrix_named\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Matrix_named-4801791173778264996\n", + "\n", + "\n", + "Matrix_named\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "String-1316606400713378063\n", + "\n", + "\n", + ""X"\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "String-4801791173778264996\n", + "\n", + "\n", + ""Y"\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Matrix_nrows-0\n", + "\n", + "\n", + "Matrix_nrows\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Matrix_ncols-0\n", + "\n", + "\n", + "Matrix_ncols\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Matrix_nrows-5871781006564002453\n", + "\n", + "\n", + "Matrix_nrows\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "Matrix_ncols-5871781006564002453\n", + "\n", + "\n", + "Matrix_ncols\n", + "\n", + "\n", + "\n", + "\n", + "\n", + "" ], - "source": [ - "# Define a number of dimensions\n", - "n, m, p = Dim.named(\"n\"), Dim.named(\"m\"), Dim.named(\"p\")\n", - "# Define a number of matrices\n", - "A, B, C = Matrix.named(\"A\"), Matrix.named(\"B\"), Matrix.named(\"C\")\n", - "# Set each to be a square matrix of the given dimension\n", - "egraph.register(\n", - " union(A.nrows()).with_(n),\n", - " union(A.ncols()).with_(n),\n", - " union(B.nrows()).with_(m),\n", - " union(B.ncols()).with_(m),\n", - " union(C.nrows()).with_(p),\n", - " union(C.ncols()).with_(p),\n", - ")\n", - "# Create an example which should equal the kronecker product of A and B\n", - "ex1 = kron(Matrix.identity(n), B) @ kron(A, Matrix.identity(m))\n", - "egraph.simplify(ex1, 20)" + "text/plain": [ + "" ] - }, + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "with egraph:\n", + " egraph.register(Matrix.named(\"X\") @ Matrix.named(\"Y\"))\n", + " egraph.display()\n", + " egraph.run(1)\n", + " egraph.display()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "bd9e94de", + "metadata": {}, + "source": [ + "We can try this out in our notebook, by multiplying some matrices and checking their dimensions:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "bb50ade6", + "metadata": {}, + "outputs": [ { - "attachments": {}, - "cell_type": "markdown", - "id": "554321e2", - "metadata": {}, - "source": [ - "We can make sure that if the rows/columns do not line up, then the transformation will not be applied:\n" + "data": { + "text/html": [ + "
kron(Matrix.named("A"), Matrix.named("B"))\n",
+       "
\n" + ], + "text/latex": [ + "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", + "\\PY{n}{kron}\\PY{p}{(}\\PY{n}{Matrix}\\PY{o}{.}\\PY{n}{named}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{A}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{Matrix}\\PY{o}{.}\\PY{n}{named}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{B}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{)}\n", + "\\end{Verbatim}\n" + ], + "text/plain": [ + "kron(Matrix.named(\"A\"), Matrix.named(\"B\"))" ] - }, + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "# Define a number of dimensions\n", + "n, m, p = Dim.named(\"n\"), Dim.named(\"m\"), Dim.named(\"p\")\n", + "# Define a number of matrices\n", + "A, B, C = Matrix.named(\"A\"), Matrix.named(\"B\"), Matrix.named(\"C\")\n", + "# Set each to be a square matrix of the given dimension\n", + "egraph.register(\n", + " union(A.nrows()).with_(n),\n", + " union(A.ncols()).with_(n),\n", + " union(B.nrows()).with_(m),\n", + " union(B.ncols()).with_(m),\n", + " union(C.nrows()).with_(p),\n", + " union(C.ncols()).with_(p),\n", + ")\n", + "# Create an example which should equal the kronecker product of A and B\n", + "ex1 = kron(Matrix.identity(n), B) @ kron(A, Matrix.identity(m))\n", + "egraph.simplify(ex1, 20)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "554321e2", + "metadata": {}, + "source": [ + "We can make sure that if the rows/columns do not line up, then the transformation will not be applied:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "d8dea199", + "metadata": {}, + "outputs": [ { - "cell_type": "code", - "execution_count": 17, - "id": "d8dea199", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "kron(Matrix.identity(Dim.named(\"p\")), Matrix.named(\"C\")) @ kron(Matrix.named(\"A\"), Matrix.identity(Dim.named(\"m\")))" - ] - }, - "execution_count": 17, - "metadata": {}, - "output_type": "execute_result" - } + "data": { + "text/html": [ + "
kron(Matrix.identity(Dim.named("p")), Matrix.named("C")) @ kron(Matrix.named("A"), Matrix.identity(Dim.named("m")))\n",
+       "
\n" ], - "source": [ - "ex2 = kron(Matrix.identity(p), C) @ kron(A, Matrix.identity(m))\n", - "egraph.simplify(ex2, 20)" + "text/latex": [ + "\\begin{Verbatim}[commandchars=\\\\\\{\\}]\n", + "\\PY{n}{kron}\\PY{p}{(}\\PY{n}{Matrix}\\PY{o}{.}\\PY{n}{identity}\\PY{p}{(}\\PY{n}{Dim}\\PY{o}{.}\\PY{n}{named}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{p}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{)}\\PY{p}{,} \\PY{n}{Matrix}\\PY{o}{.}\\PY{n}{named}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{C}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{)} \\PY{o}{@} \\PY{n}{kron}\\PY{p}{(}\\PY{n}{Matrix}\\PY{o}{.}\\PY{n}{named}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{A}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{,} \\PY{n}{Matrix}\\PY{o}{.}\\PY{n}{identity}\\PY{p}{(}\\PY{n}{Dim}\\PY{o}{.}\\PY{n}{named}\\PY{p}{(}\\PY{l+s+s2}{\\PYZdq{}}\\PY{l+s+s2}{m}\\PY{l+s+s2}{\\PYZdq{}}\\PY{p}{)}\\PY{p}{)}\\PY{p}{)}\n", + "\\end{Verbatim}\n" + ], + "text/plain": [ + "kron(Matrix.identity(Dim.named(\"p\")), Matrix.named(\"C\")) @ kron(Matrix.named(\"A\"), Matrix.identity(Dim.named(\"m\")))" ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "b0b13665", - "metadata": {}, - "outputs": [], - "source": [] - } - ], - "metadata": { - "file_format": "mystnb", - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.10.12" + }, + "metadata": {}, + "output_type": "display_data" } + ], + "source": [ + "ex2 = kron(Matrix.identity(p), C) @ kron(A, Matrix.identity(m))\n", + "egraph.simplify(ex2, 20)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b0b13665", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "file_format": "mystnb", + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" }, - "nbformat": 4, - "nbformat_minor": 5 + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 } diff --git a/pyproject.toml b/pyproject.toml index badc204b..28fbd61d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,7 +60,7 @@ check_untyped_defs = true strict_equality = true warn_unused_configs = true allow_redefinition = true -enable_incomplete_feature = ["Unpack", "TypeVarTuple"] +enable_incomplete_feature = ["Unpack"] exclude = ["__snapshots__", "_build", "^conftest.py$"] [tool.maturin] diff --git a/python/egglog/bindings.pyi b/python/egglog/bindings.pyi index b62ff026..c3236fc9 100644 --- a/python/egglog/bindings.pyi +++ b/python/egglog/bindings.pyi @@ -19,6 +19,7 @@ class EGraph: max_functions: Optional[int] = None, max_calls_per_function: Optional[int] = None, n_inline_leaves: int = 0, + split_primitive_outputs: bool = False, ) -> str: ... def save_object(self, __o: object, /) -> _Expr: ... def load_object(self, __e: _Expr, /) -> object: ... diff --git a/python/egglog/builtins.py b/python/egglog/builtins.py index 39d3765f..51b20db5 100644 --- a/python/egglog/builtins.py +++ b/python/egglog/builtins.py @@ -1,3 +1,4 @@ +# mypy: disable-error-code="empty-body" """ Builtin sorts and function to egg. """ @@ -24,10 +25,11 @@ "join", "PyObject", "py_eval", + "py_exec", ] -StringLike = Union[str, "String"] +StringLike = Union["String", str] @BUILTINS.class_ @@ -48,7 +50,7 @@ def join(*strings: StringLike) -> String: # type: ignore[empty-body] converter(str, String, String) # The types which can be convertered into an i64 -i64Like = Union[int, "i64"] +i64Like = Union["i64", int] @BUILTINS.class_(egg_sort="i64") @@ -159,7 +161,7 @@ def count_matches(s: StringLike, pattern: StringLike) -> i64: # type: ignore[em ... -f64Like = Union[float, "f64"] +f64Like = Union["f64", float] @BUILTINS.class_(egg_sort="f64") @@ -454,8 +456,20 @@ def dict_update(dict, *keys_and_values: PyObject) -> PyObject: # type: ignore[e def from_int(cls, i: i64Like) -> PyObject: # type: ignore[empty-body] ... + @BUILTINS.method(egg_fn="py-dict") + @classmethod + def dict(cls, *keys_and_values: PyObject) -> PyObject: + ... + -# TODO: Maybe move to static method if we implement those? @BUILTINS.function(egg_fn="py-eval") -def py_eval(code: StringLike, locals: PyObject, globals: PyObject) -> PyObject: # type: ignore[empty-body] +def py_eval(code: StringLike, globals: PyObject = PyObject.dict(), locals: PyObject = PyObject.dict()) -> PyObject: # type: ignore[empty-body] + ... + + +@BUILTINS.function(egg_fn="py-exec") +def py_exec(code: StringLike, globals: PyObject = PyObject.dict(), locals: PyObject = PyObject.dict()) -> PyObject: + """ + Copies the locals, execs the Python code, and returns the locals with any updates. + """ ... diff --git a/python/egglog/egraph.py b/python/egglog/egraph.py index 5c1d59db..1d8b87ba 100644 --- a/python/egglog/egraph.py +++ b/python/egglog/egraph.py @@ -1,14 +1,15 @@ from __future__ import annotations import inspect +import pathlib +import tempfile from abc import ABC, abstractmethod from contextvars import ContextVar, Token from copy import deepcopy from dataclasses import InitVar, dataclass, field from inspect import Parameter, currentframe, signature from types import FunctionType -from typing import _GenericAlias # type: ignore[attr-defined] -from typing import ( +from typing import ( # type: ignore[attr-defined] TYPE_CHECKING, Any, Callable, @@ -18,8 +19,11 @@ Literal, NoReturn, Optional, + Protocol, + TypedDict, TypeVar, Union, + _GenericAlias, cast, get_type_hints, overload, @@ -27,10 +31,11 @@ import graphviz from egglog.declarations import REFLECTED_BINARY_METHODS, Declarations -from typing_extensions import ParamSpec, get_args, get_origin +from typing_extensions import ParamSpec, Unpack, get_args, get_origin from . import bindings from .declarations import * +from .ipython_magic import IN_IPYTHON from .monkeypatch import monkeypatch_forward_ref from .runtime import * from .runtime import _resolve_callable, class_to_ref @@ -93,6 +98,11 @@ ALWAYS_MUTATES_SELF = {"__setitem__", "__delitem__"} +class PyObjectFunction(Protocol): + def __call__(self, *__args: PyObject) -> PyObject: + ... + + @dataclass class _BaseModule(ABC): """ @@ -490,17 +500,10 @@ def _resolve_type_annotation( ) -> TypeOrVarRef: if isinstance(tp, TypeVar): return ClassTypeVarRef(cls_typevars.index(tp)) - # If there is a union, it should be of a literal and another type to allow type promotion + # If there is a union, then we assume the first item is the type we want, and the others are types that can be converted to that type. if get_origin(tp) == Union: - args = get_args(tp) - if len(args) != 2: - raise TypeError("Union types are only supported for type promotion") - fst, snd = args - if fst in {int, str, float}: - return self._resolve_type_annotation(snd, cls_typevars, cls_type_and_name) - if snd in {int, str, float}: - return self._resolve_type_annotation(fst, cls_typevars, cls_type_and_name) - raise TypeError("Union types are only supported for type promotion") + first, *_rest = get_args(tp) + return self._resolve_type_annotation(first, cls_typevars, cls_type_and_name) # If this is the type for the class, use the class name if cls_type_and_name and tp == cls_type_and_name[0]: @@ -665,6 +668,13 @@ def save_object(self, obj: object) -> PyObject: return cast("PyObject", RuntimeExpr(self._mod_decls, typed_expr_decl)) +class GraphvizKwargs(TypedDict, total=False): + max_functions: Optional[int] + max_calls_per_function: Optional[int] + n_inline_leaves: int + split_primitive_outputs: bool + + @dataclass class EGraph(_BaseModule): """ @@ -678,7 +688,7 @@ class EGraph(_BaseModule): _decl_stack: list[Declarations] = field(default_factory=list, repr=False) _token: Optional[Token[EGraph]] = None - def __post_init__(self, modules: list[Module], seminaive) -> None: + def __post_init__(self, modules: list[Module], seminaive: bool) -> None: # type: ignore super().__post_init__(modules) self._egraph = bindings.EGraph(seminaive=seminaive) for m in self._flatted_deps: @@ -696,8 +706,38 @@ def _repr_mimebundle_(self, *args, **kwargs): return {"image/svg+xml": self.graphviz().pipe(format="svg", quiet=True, encoding="utf-8")} - def graphviz(self, **kwargs) -> graphviz.Source: - return graphviz.Source(self._egraph.to_graphviz_string(**kwargs)) + def graphviz(self, **kwargs: Unpack[GraphvizKwargs]) -> graphviz.Source: + original = self._egraph.to_graphviz_string(**kwargs) + # Add link to stylesheet to the graph, so that edges light up on hover + # https://gist.github.com/sverweij/93e324f67310f66a8f5da5c2abe94682 + styles = """/* the lines within the edges */ + .edge:active path, + .edge:hover path { + stroke: fuchsia; + stroke-width: 3; + stroke-opacity: 1; + } + /* arrows are typically drawn with a polygon */ + .edge:active polygon, + .edge:hover polygon { + stroke: fuchsia; + stroke-width: 3; + fill: fuchsia; + stroke-opacity: 1; + fill-opacity: 1; + } + /* If you happen to have text and want to color that as well... */ + .edge:active text, + .edge:hover text { + fill: fuchsia; + }""" + p = pathlib.Path(tempfile.gettempdir()) / "graphviz-styles.css" + p.write_text(styles) + with_stylesheet = original.replace("{", f'{{stylesheet="{str(p)}"', 1) + return graphviz.Source(with_stylesheet) + + def graphviz_svg(self, **kwargs: Unpack[GraphvizKwargs]) -> str: + return self.graphviz(**kwargs).pipe(format="svg", quiet=True, encoding="utf-8") def _repr_html_(self) -> str: """ @@ -706,15 +746,19 @@ def _repr_html_(self) -> str: until this PR is merged and released https://github.com/sphinx-gallery/sphinx-gallery/pull/1138 """ - return self.graphviz().pipe(format="svg", quiet=True).decode() + return self.graphviz_svg() - def display(self, **kwargs): + def display(self, **kwargs: Unpack[GraphvizKwargs]): """ Displays the e-graph in the notebook. """ - from IPython.display import SVG, display + graphviz = self.graphviz(**kwargs) + if IN_IPYTHON: + from IPython.display import SVG, display - display(SVG(self.graphviz(**kwargs).pipe(format="svg", quiet=True, encoding="utf-8"))) + display(SVG(self.graphviz_svg(**kwargs))) + else: + graphviz.render(view=True, format="svg", quiet=True) @overload def simplify(self, expr: EXPR, limit: int, /, *until: Fact, ruleset: Optional[Ruleset] = None) -> EXPR: @@ -879,6 +923,31 @@ def load_object(self, obj: PyObject) -> object: expr = typed_expr_decl.to_egg(self._mod_decls) return self._egraph.load_object(expr) + def eval_fn(self, fn: Callable) -> PyObjectFunction: + """ + Takes a python callable and maps it to a callable which takes + and returns PyObjects. + + It translates it to a call which uses `py_eval` to call the function, passing in the + args as locals, and using the globals from function. + """ + from .builtins import py_eval + + fn_globals = self.save_object(fn.__globals__) + fn_locals = self.save_object({"__fn": fn}) + + def inner(*__args: PyObject, __fn_locals=fn_locals) -> PyObject: + new_kvs: list[PyObject] = [] + eval_str = "__fn(" + for i, arg in enumerate(__args): + new_kvs.append(self.save_object(f"__arg_{i}")) + new_kvs.append(arg) + eval_str += f"__arg_{i}, " + eval_str += ")" + return py_eval(eval_str, fn_locals.dict_update(*new_kvs), fn_globals) + + return inner + @classmethod def current(cls) -> EGraph: """ @@ -1103,7 +1172,7 @@ def __str__(self) -> str: def _to_egg_action(self, mod_decls: ModuleDeclarations) -> bindings.Set: egg_call = self._call.__egg_typed_expr__.expr.to_egg(mod_decls) if not isinstance(egg_call, bindings.Call): - raise ValueError(f"Can only create a call with a call for the lhs, got {self._call}") + raise ValueError(f"Can only create a set with a call for the lhs, got {self._call}") return bindings.Set( egg_call.name, egg_call.args, @@ -1467,7 +1536,7 @@ def _action_like(action_like: ActionLike) -> Action: return action_like -FactLike = Union[Fact, Unit] +FactLike = Union[Fact, Expr] def _fact_likes(fact_likes: Iterable[FactLike]) -> tuple[Fact, ...]: diff --git a/python/egglog/examples/lambda.py b/python/egglog/examples/lambda.py index a5888b12..a89e87ea 100644 --- a/python/egglog/examples/lambda.py +++ b/python/egglog/examples/lambda.py @@ -1,8 +1,9 @@ +# mypy: disable-error-code="empty-body" + """ Lambda Calculus =============== """ -# mypy: disable-error-code=empty-body from __future__ import annotations from typing import Callable, ClassVar diff --git a/python/egglog/examples/ndarrays.py b/python/egglog/examples/ndarrays.py index e1a1c610..4f5b9780 100644 --- a/python/egglog/examples/ndarrays.py +++ b/python/egglog/examples/ndarrays.py @@ -1,10 +1,11 @@ +# mypy: disable-error-code="empty-body" + """ N-Dimensional Arrays ==================== Example of building NDarray in the vein of Mathemetics of Arrays. """ -# mypy: disable-error-code=empty-body from __future__ import annotations from egglog import * diff --git a/python/egglog/exp/array_api.py b/python/egglog/exp/array_api.py index f5c5acec..8631c8dd 100644 --- a/python/egglog/exp/array_api.py +++ b/python/egglog/exp/array_api.py @@ -1,4 +1,5 @@ -# mypy: disable-error-code=empty-body +# mypy: disable-error-code="empty-body" + from __future__ import annotations import itertools @@ -11,9 +12,10 @@ import numpy as np from egglog import * from egglog.bindings import EggSmolError -from egglog.egraph import Action from egglog.runtime import RuntimeExpr +from .program_gen import * + # Pretend that exprs are numbers b/c sklearn does isinstance checks numbers.Integral.register(RuntimeExpr) @@ -697,7 +699,7 @@ class NDArray(Expr): def __init__(self, py_array: PyObject) -> None: ... - @array_api_module.method(cost=100) + @array_api_module.method(cost=200) @classmethod def var(cls, name: StringLike) -> NDArray: ... @@ -1485,88 +1487,68 @@ def _size(x: NDArray): # Depends on `np` as a global variable. ## -array_api_module_string = Module([array_api_module]) - - -@array_api_module_string.function(merge=lambda old, new: new, default=i64(0)) -def gensym() -> i64: - ... - - -gensym_var = join("_", gensym().to_string()) - - -def add_line(*v: StringLike) -> Action: - return set_(statements()).to(join(" ", *v, "\n")) - - -incr_gensym = set_(gensym()).to(gensym() + 1) - - -@array_api_module_string.function(merge=lambda old, new: join(old, new), default=String("")) -def statements() -> String: - ... +array_api_module_string = Module([array_api_module, program_gen_module]) @array_api_module_string.function() -def ndarray_expr(x: NDArray) -> String: +def ndarray_program(x: NDArray) -> Program: ... @array_api_module_string.function() -def dtype_expr(x: DType) -> String: +def dtype_program(x: DType) -> Program: ... @array_api_module_string.function() -def tuple_int_expr(x: TupleInt) -> String: +def tuple_int_program(x: TupleInt) -> Program: ... @array_api_module_string.function() -def int_expr(x: Int) -> String: +def int_program(x: Int) -> Program: ... @array_api_module_string.function() -def tuple_value_expr(x: TupleValue) -> String: +def tuple_value_program(x: TupleValue) -> Program: ... @array_api_module_string.function() -def value_expr(x: Value) -> String: +def value_program(x: Value) -> Program: ... array_api_module_string.register( - set_(dtype_expr(DType.float64)).to(String("np.float64")), - set_(dtype_expr(DType.int64)).to(String("np.int64")), + union(dtype_program(DType.float64)).with_(Program("np.float64")), + union(dtype_program(DType.int64)).with_(Program("np.int64")), ) @array_api_module_string.function -def bool_expr(x: Bool) -> String: +def bool_program(x: Bool) -> Program: ... array_api_module_string.register( - set_(bool_expr(TRUE)).to(String("True")), - set_(bool_expr(FALSE)).to(String("False")), + union(bool_program(TRUE)).with_(Program("True")), + union(bool_program(FALSE)).with_(Program("False")), ) @array_api_module_string.function -def float_expr(x: Float) -> String: +def float_program(x: Float) -> Program: ... @array_api_module_string.function -def tuple_ndarray_expr(x: TupleNDArray) -> String: +def tuple_ndarray_program(x: TupleNDArray) -> Program: ... @array_api_module_string.function -def optional_dtype_expr(x: OptionalDType) -> String: +def optional_dtype_program(x: OptionalDType) -> Program: ... @@ -1610,258 +1592,94 @@ def _py_expr( optional_dtype_: OptionalDType, ): # Var - yield rule( - eq(x).to(NDArray.var(s)), - ).then( - set_(lhs=ndarray_expr(x)).to(s), - ) + yield rewrite(ndarray_program(NDArray.var(s))).to(Program(s)) # Asssume dtype z_assumed_dtype = copy(z) - assume_dtype(z_assumed_dtype, dtype=dtype) - yield rule( - eq(x).to(z_assumed_dtype), - eq(z_str).to(ndarray_expr(z)), - eq(dtype_str).to(dtype_expr(dtype)), - ).then( - set_(ndarray_expr(x)).to(z_str), - add_line("assert ", z_str, ".dtype == ", dtype_str), + assume_dtype(z_assumed_dtype, dtype) + z_program = ndarray_program(z) + yield rewrite(ndarray_program(z_assumed_dtype)).to( + z_program.statement(Program("assert ") + z_program + ".dtype == " + dtype_program(dtype)) ) - # assume shape z_assumed_shape = copy(z) assume_shape(z_assumed_shape, ti) - yield rule( - eq(x).to(z_assumed_shape), - eq(z_str).to(ndarray_expr(z)), - eq(ti_str).to(tuple_int_expr(ti)), - ).then( - set_(ndarray_expr(x)).to(z_str), - add_line("assert ", z_str, ".shape == ", ti_str), + yield rewrite(ndarray_program(z_assumed_shape)).to( + z_program.statement(Program("assert ") + z_program + ".shape == " + tuple_int_program(ti)) ) # tuple int - yield rule( - eq(ti).to(ti1 + ti2), - eq(ti_str1).to(tuple_int_expr(ti1)), - eq(ti_str2).to(tuple_int_expr(ti2)), - ).then( - set_(tuple_int_expr(ti)).to(join(ti_str1, " + ", ti_str2)), - ) - yield rule( - eq(ti).to(TupleInt(i)), - eq(i_str).to(int_expr(i)), - ).then( - set_(tuple_int_expr(ti)).to(join("(", i_str, ",)")), - ) + yield rewrite(tuple_int_program(ti1 + ti2)).to(tuple_int_program(ti1) + " + " + tuple_int_program(ti2)) + yield rewrite(tuple_int_program(TupleInt(i))).to(Program("(") + int_program(i) + ",)") # Int - yield rule( - eq(i).to(Int(i64_)), - ).then( - set_(int_expr(i)).to(i64_.to_string()), - ) + yield rewrite(int_program(Int(i64_))).to(Program(i64_.to_string())) # assume isfinite z_assumed_isfinite = copy(z) assume_isfinite(z_assumed_isfinite) - yield rule( - eq(x).to(z_assumed_isfinite), - eq(z_str).to(ndarray_expr(z)), - ).then( - set_(ndarray_expr(x)).to(z_str), - add_line("assert np.all(np.isfinite(", z_str, "))"), + yield rewrite(ndarray_program(z_assumed_isfinite)).to( + z_program.statement(Program("assert np.all(np.isfinite(") + z_program + "))") ) # Assume value_one_of z_assumed_value_one_of = copy(z) assume_value_one_of(z_assumed_value_one_of, tv) - yield rule( - eq(x).to(z_assumed_value_one_of), - # not_traversed(x), - eq(z_str).to(ndarray_expr(z)), - eq(tv_str).to(tuple_value_expr(tv)), - ).then( - set_(ndarray_expr(x)).to(z_str), - # traverse(x), - add_line("assert set(", z_str, ".flatten()) == set(", tv_str, ")"), + yield rewrite(ndarray_program(z_assumed_value_one_of)).to( + z_program.statement(Program("assert set(") + z_program + ".flatten()) == set(" + tuple_value_program(tv) + ")") ) - # print(r._to_egg_command(array_api_module_string._mod_decls)) - # yield r + # tuple values - yield rule( - eq(tv).to(tv1 + tv2), - eq(tv1_str).to(tuple_value_expr(tv1)), - eq(tv2_str).to(tuple_value_expr(tv2)), - ).then( - set_(tuple_value_expr(tv)).to(join(tv1_str, " + ", tv2_str)), - ) - yield rule( - eq(tv).to(TupleValue(v)), - eq(v_str).to(value_expr(v)), - ).then( - set_(tuple_value_expr(tv)).to(join("(", v_str, ",)")), - ) + yield rewrite(tuple_value_program(tv1 + tv2)).to(tuple_value_program(tv1) + " + " + tuple_value_program(tv2)) + yield rewrite(tuple_value_program(TupleValue(v))).to(Program("(") + value_program(v) + ",)") # Value - yield rule( - eq(v).to(Value.int(i)), - eq(i_str).to(int_expr(i)), - ).then( - set_(value_expr(v)).to(i_str), - ) - yield rule( - eq(v).to(Value.bool(b)), - eq(b_str).to(bool_expr(b)), - ).then( - set_(value_expr(v)).to(b_str), - ) - yield rule( - eq(v).to(Value.float(f)), - eq(f_str).to(float_expr(f)), - ).then( - set_(value_expr(v)).to(f_str), - ) + yield rewrite(value_program(Value.int(i))).to(int_program(i)) + yield rewrite(value_program(Value.bool(b))).to(bool_program(b)) + yield rewrite(value_program(Value.float(f))).to(float_program(f)) # Float - yield rule( - eq(f).to(Float(f64_)), - ).then( - set_(float_expr(f)).to(f64_.to_string()), - ) + yield rewrite(float_program(Float(f64_))).to(Program(f64_.to_string())) # reshape (don't include copy, since not present in numpy) - yield rule( - eq(x).to(reshape(y, ti, ob)), - eq(y_str).to(ndarray_expr(y)), - eq(ti_str).to(tuple_int_expr(ti)), - ).then( - set_(ndarray_expr(x)).to(gensym_var), - add_line(gensym_var, " = ", y_str, ".reshape(", ti_str, ")"), - incr_gensym, + yield rewrite(ndarray_program(reshape(y, ti, ob))).to( + (ndarray_program(y) + ".reshape(" + tuple_int_program(ti) + ")").assign() ) # astype - yield rule( - eq(x).to(astype(y, dtype)), - eq(y_str).to(ndarray_expr(y)), - eq(dtype_str).to(dtype_expr(dtype)), - ).then( - set_(ndarray_expr(x)).to(gensym_var), - add_line(gensym_var, " = ", y_str, ".astype(", dtype_str, ")"), - incr_gensym, + yield rewrite(ndarray_program(astype(y, dtype))).to( + (ndarray_program(y) + ".astype(" + dtype_program(dtype) + ")").assign() ) # unique_counts(x) => unique(x, return_counts=True) - yield rule( - eq(tnd).to(unique_counts(y)), - eq(y_str).to(ndarray_expr(y)), - ).then( - set_(tuple_ndarray_expr(tnd)).to(gensym_var), - add_line(gensym_var, " = np.unique(", y_str, ", return_counts=True)"), - incr_gensym, + yield rewrite(tuple_ndarray_program(unique_counts(x))).to( + (Program("np.unique(") + ndarray_program(x) + ", return_counts=True)").assign() ) + # Tuple ndarray indexing - yield rule( - eq(x).to(tnd[i]), - eq(tnd_str).to(tuple_ndarray_expr(tnd)), - eq(i_str).to(int_expr(i)), - ).then( - set_(ndarray_expr(x)).to(join(tnd_str, "[", i_str, "]")), - ) + yield rewrite(ndarray_program(tnd[i])).to(tuple_ndarray_program(tnd) + "[" + int_program(i) + "]") # ndarray scalar # TODO: Use dtype and shape and indexing instead? - yield rule( - eq(x).to(NDArray.scalar(v)), - eq(v_str).to(value_expr(v)), - ).then( - set_(ndarray_expr(x)).to(gensym_var), - add_line(gensym_var, " = np.array(", v_str, ")"), - incr_gensym, - ) + # TODO: SPecify dtype? + yield rewrite(ndarray_program(NDArray.scalar(v))).to(Program("np.array(") + value_program(v) + ")") # zeros - yield rule( - eq(x).to(zeros(ti, optional_dtype_, optional_device_)), - eq(ti_str).to(tuple_int_expr(ti)), - eq(dtype_str).to(optional_dtype_expr(optional_dtype_)), - ).then( - set_(ndarray_expr(x)).to(gensym_var), - add_line(gensym_var, " = np.zeros(", ti_str, ", dtype=", dtype_str, ")"), - incr_gensym, + yield rewrite(ndarray_program(zeros(ti, optional_dtype_, optional_device_))).to( + ( + Program("np.zeros(") + tuple_int_program(ti) + ", dtype=" + optional_dtype_program(optional_dtype_) + ")" + ).assign() ) # Optional dtype - yield rule( - eq(optional_dtype_).to(OptionalDType.none), - ).then( - set_(optional_dtype_expr(optional_dtype_)).to(String("None")), - ) - yield rule( - eq(optional_dtype_).to(OptionalDType.some(dtype)), - eq(dtype_str).to(dtype_expr(dtype)), - ).then( - set_(optional_dtype_expr(optional_dtype_)).to(dtype_str), - ) + yield rewrite(optional_dtype_program(OptionalDType.none)).to(Program("None")) + yield rewrite(optional_dtype_program(OptionalDType.some(dtype))).to(dtype_program(dtype)) # unique_values - yield rule( - eq(x).to(unique_values(y)), - eq(y_str).to(ndarray_expr(y)), - ).then( - set_(ndarray_expr(x)).to(gensym_var), - add_line(gensym_var, " = np.unique(", y_str, ")"), - incr_gensym, - ) + yield rewrite(ndarray_program(unique_values(x))).to((Program("np.unique(") + ndarray_program(x) + ")").assign()) # reshape # NDARRAy ops - - yield rule( - eq(x).to(y + z), - eq(y_str).to(ndarray_expr(y)), - eq(z_str).to(ndarray_expr(z)), - ).then( - set_(ndarray_expr(x)).to(gensym_var), - add_line(gensym_var, " = ", y_str, " + ", z_str), - incr_gensym, - ) - - yield rule( - eq(x).to(y / z), - eq(y_str).to(ndarray_expr(y)), - eq(z_str).to(ndarray_expr(z)), - ).then( - set_(ndarray_expr(x)).to(gensym_var), - add_line(gensym_var, " = ", y_str, " / ", z_str), - incr_gensym, - ) - - -@array_api_module_string.class_ -class FunctionExprTwo(Expr): - """ - Python expression that takes two NDArrays as arguments and returns an NDArray. - """ - - def __init__(self, name: StringLike, res: NDArray, arg_1: NDArray, arg_2: NDArray) -> None: - ... - - @property - def source(self) -> String: - ... - - -fn_ruleset = array_api_module_string.ruleset("fn") - - -@array_api_module_string.register -def _function_expr(name: String, res: NDArray, arg1: String, arg2: String, f: FunctionExprTwo, s: String): - yield rule( - eq(f).to(FunctionExprTwo(name, res, NDArray.var(arg1), NDArray.var(arg2))), - ruleset=fn_ruleset, - ).then( - set_(f.source).to( - join("def ", name, "(", arg1, ", ", arg2, "):\n", statements(), " return ", ndarray_expr(res), "\n") - ), - ) + yield rewrite(ndarray_program(x + y)).to((ndarray_program(x) + " + " + ndarray_program(y)).assign()) + yield rewrite(ndarray_program(x - y)).to((ndarray_program(x) + " - " + ndarray_program(y)).assign()) + yield rewrite(ndarray_program(x * y)).to((ndarray_program(x) + " * " + ndarray_program(y)).assign()) + yield rewrite(ndarray_program(x / y)).to((ndarray_program(x) + " / " + ndarray_program(y)).assign()) diff --git a/python/egglog/exp/program_gen.py b/python/egglog/exp/program_gen.py new file mode 100644 index 00000000..07dca96a --- /dev/null +++ b/python/egglog/exp/program_gen.py @@ -0,0 +1,343 @@ +# mypy: disable-error-code="empty-body" +""" +Builds up imperative string expressions from a functional expression. +""" +from __future__ import annotations + +from typing import Union + +from egglog import * + +program_gen_module = Module() + +ProgramLike = Union["Program", StringLike] + + +@program_gen_module.class_ +class Program(Expr): + """ + Semanticallly represents an expression with a number of ordered statements that it depends on to run. + + The expression and statements are all represented as strings. + """ + + def __init__(self, expr: StringLike) -> None: + """ + Create a program based on a string expression. + """ + ... + + def __add__(self, other: ProgramLike) -> Program: + """ + Concats the strings of the two expressions and also the statements. + """ + ... + + def statement(self, statement: ProgramLike) -> Program: + """ + Uses the expression of the statement and adds that as a statement to the program. + """ + ... + + def assign(self) -> Program: + """ + Returns a new program with the expression assigned to a gensym. + """ + ... + + def function_two(self, arg1: ProgramLike, arg2: ProgramLike, name: StringLike = String("__fn")) -> Program: + """ + Returns a new program defining a function with two arguments. + """ + ... + + def expr_to_statement(self) -> Program: + """ + Returns a new program with the expression as a statement and the new expression empty. + """ + ... + + @property + def expr(self) -> String: + """ + Returns the expression of the program, if it's been compiled + """ + ... + + @property + def statements(self) -> String: + """ + Returns the statements of the program, if it's been compiled + """ + ... + + @property + def next_sym(self) -> i64: + """ + Returns the next gensym to use. This is set after calling `compile(i)` on a program. + """ + ... + + @program_gen_module.method(default=Unit()) + def compile(self, next_sym: i64 = i64(0)) -> Unit: + """ + Triggers compilation of the program. + """ + + @program_gen_module.method(merge=lambda old, new: old, cost=1000) # type: ignore[misc] + @property + def parent(self) -> Program: + """ + Returns the parent of the program, if it's been compiled into the parent. + + Only keeps the original parent, not any additional ones, so that each set of statements is only added once. + """ + ... + + @program_gen_module.method(default=Unit()) + def eval_py_object(self, globals: PyObject) -> Unit: + """ + Evaluates the program and saves as the py_object + """ + + # Only allow it to be set once, b/c hash of functions not stable + @program_gen_module.method(merge=lambda old, new: old) # type: ignore[misc] + @property + def py_object(self) -> PyObject: + """ + Returns the python object of the program, if it's been evaluated. + """ + ... + + +converter(String, Program, Program) + + +@program_gen_module.register +def _py_object(p: Program, expr: String, statements: String, g: PyObject): + # When we evaluate a program, we first want to compile to a string + yield rule(p.eval_py_object(g)).then(p.compile()) + # Then we want to evaluate the statements/expr + yield rule(p.eval_py_object(g), eq(p.statements).to(statements), eq(p.expr).to(expr)).then( + set_(p.py_object).to( + py_eval( + "l['___res']", + PyObject.dict(PyObject.from_string("l"), py_exec(join(statements, "\n", "___res = ", expr), g)), + ) + ) + ) + + +@program_gen_module.register +def _compile( + s: String, + s1: String, + s2: String, + s3: String, + s4: String, + s5: String, + p: Program, + p1: Program, + p2: Program, + p3: Program, + # c: Compiler, + statements: Program, + expr: Program, + i: i64, + m: Map[Program, Program], +): + # Combining two strings is just joining them + # yield rewrite(Program(s1) + Program(s2)).to(Program(join(s1, s2))) + + # Compiling a string just gives that string + program_expr = Program(s) + yield rule(program_expr.compile(i)).then( + set_(program_expr.expr).to(s), + set_(program_expr.statements).to(String("")), + set_(program_expr.next_sym).to(i), + ) + + ## + # Statement + ## + # Compiling a statement means that we should use the expression of the statement as a statement and use the expression of the first + yield rewrite(p1.statement(p2)).to(p1 + p2.expr_to_statement()) + + ## + # Expr to statement + ## + stmt = p1.expr_to_statement() + # 1. Set parent + yield rule(eq(p).to(stmt), p.compile(i)).then(set_(p1.parent).to(p)) + # 2. Compile p1 if parent set + yield rule(eq(p).to(stmt), p.compile(i), eq(p1.parent).to(stmt)).then(p1.compile(i)) + # 3.a. If parent not set, set statements to expr + yield rule( + eq(p).to(stmt), + p.compile(i), + p1.parent != p, + eq(s1).to(p1.expr), + ).then( + set_(p.statements).to(join(s1, "\n")), + set_(p.next_sym).to(i), + set_(p.expr).to(String("")), + ) + # 3.b. If parent set, set statements to expr + statements + yield rule( + eq(p).to(stmt), + eq(p1.parent).to(stmt), + eq(s1).to(p1.expr), + eq(s2).to(p1.statements), + eq(i).to(p1.next_sym), + ).then( + set_(p.statements).to(join(s2, s1, "\n")), + set_(p.next_sym).to(i), + set_(p.expr).to(String("")), + ) + + ## + # Addition + ## + + # Compiling an addition is the same as compiling one, then the other, then setting the expression as the addition + # of the two + program_add = p1 + p2 + + # Set parent of p1 + yield rule(eq(p).to(program_add), p.compile(i)).then(set_(p1.parent).to(p)) + + # Compile p1, if p1 parent equal + yield rule(eq(p).to(program_add), p.compile(i), eq(p1.parent).to(program_add)).then(p1.compile(i)) + + # Set parent of p2, once p1 compiled + yield rule(eq(p).to(program_add), p1.next_sym).then(set_(p2.parent).to(p)) + + # Compile p2, if p1 parent not equal, but p2 parent equal + yield rule(eq(p).to(program_add), p.compile(i), p1.parent != p, eq(p2.parent).to(p)).then(p2.compile(i)) + + # Compile p2, if p1 parent eqal + yield rule(eq(p).to(program_add), eq(p1.parent).to(program_add), eq(i).to(p1.next_sym), eq(p2.parent).to(p)).then( + p2.compile(i) + ) + + # Set p expr to join of p1 and p2 + yield rule( + eq(p).to(program_add), + eq(s1).to(p1.expr), + eq(s2).to(p2.expr), + ).then( + set_(p.expr).to(join(s1, s2)), + ) + # Set p statements to join and next sym to p2 if both parents set + yield rule( + eq(p).to(program_add), + eq(p1.parent).to(p), + eq(p2.parent).to(p), + eq(s1).to(p1.statements), + eq(s2).to(p2.statements), + eq(i).to(p2.next_sym), + ).then( + set_(p.statements).to(join(s1, s2)), + set_(p.next_sym).to(i), + ) + # Set p statements to empty and next sym to i if neither parents set + yield rule( + eq(p).to(program_add), + p.compile(i), + p1.parent != p, + p2.parent != p, + ).then( + set_(p.statements).to(String("")), + set_(p.next_sym).to(i), + ) + # Set p statements to p1 and next sym to p1 if p1 parent set and p2 parent not set + yield rule( + eq(p).to(program_add), + eq(p1.parent).to(p), + p2.parent != p, + eq(s1).to(p1.statements), + eq(i).to(p1.next_sym), + ).then( + set_(p.statements).to(s1), + set_(p.next_sym).to(i), + ) + # Set p statements to p2 and next sym to p2 if p2 parent set and p1 parent not set + yield rule( + eq(p).to(program_add), + eq(p2.parent).to(p), + p1.parent != p, + eq(s2).to(p2.statements), + eq(i).to(p2.next_sym), + ).then( + set_(p.statements).to(s2), + set_(p.next_sym).to(i), + ) + + ## + # Assign + ## + + # Compiling an assign is the same as compiling the expression, adding an assign statement, then setting the + # expression as the gensym + program_assign = p1.assign() + # Set parent + yield rule(eq(p).to(program_assign), p.compile(i)).then(set_(p1.parent).to(p)) + # If parent set, compile the expression + yield rule(eq(p).to(program_assign), p.compile(i), eq(p1.parent).to(program_assign)).then(p1.compile(i)) + + # If p1 parent is p, then use statements of p, next sym of p + symbol = join(String("_"), i.to_string()) + yield rule( + eq(p).to(program_assign), + eq(p1.parent).to(p), + eq(s1).to(p1.statements), + eq(i).to(p1.next_sym), + eq(s2).to(p1.expr), + ).then( + set_(p.statements).to(join(s1, symbol, " = ", s2, "\n")), + set_(p.expr).to(symbol), + set_(p.next_sym).to(i + 1), + ) + # If p1 parent is not p, then just use assign as statement, next sym of i + yield rule( + eq(p).to(program_assign), + p1.parent != p, + p.compile(i), + eq(s2).to(p1.expr), + ).then( + set_(p.statements).to(join(symbol, " = ", s2, "\n")), + set_(p.expr).to(symbol), + set_(p.next_sym).to(i + 1), + ) + + ## + # Function two + + # When compiling a function, the two args, p2 and p3, should get compiled when we compile p1, and should just be vars. + fn_two = p1.function_two(p2, p3, s1) + # 1. Set parents of both args to p and compile them + # Assumes that this if the first thing to compile, so no need to check, and assumes that compiling args doesn't result in any + # change in the next sym + yield rule(eq(p).to(fn_two), p.compile(i)).then( + set_(p2.parent).to(p), + set_(p3.parent).to(p), + set_(p1.parent).to(p), + p2.compile(i), + p3.compile(i), + p1.compile(i), + ) + # 2. Set statements to function body and the next sym to i + yield rule( + eq(p).to(fn_two), + p.compile(i), + eq(s2).to(p1.expr), + eq(s3).to(p1.statements), + eq(s4).to(p2.expr), + eq(s5).to(p3.expr), + ).then( + set_(p.statements).to( + join("def ", s1, "(", s4, ", ", s5, "):\n ", s3.replace("\n", "\n "), "return ", s2, "\n") + ), + set_(p.next_sym).to(i), + set_(p.expr).to(s1), + ) diff --git a/python/egglog/ipython_magic.py b/python/egglog/ipython_magic.py index a080e5e8..cb7472a6 100644 --- a/python/egglog/ipython_magic.py +++ b/python/egglog/ipython_magic.py @@ -4,11 +4,11 @@ try: get_ipython() # type: ignore[name-defined] - in_ipython = True + IN_IPYTHON = True except NameError: - in_ipython = False + IN_IPYTHON = False -if in_ipython: +if IN_IPYTHON: import graphviz from IPython.core.magic import needs_local_scope, register_cell_magic diff --git a/python/tests/__snapshots__/test_array_api/test_sklearn_lda.py b/python/tests/__snapshots__/test_array_api/test_sklearn_lda.py index 8263e212..26c6da6c 100644 --- a/python/tests/__snapshots__/test_array_api/test_sklearn_lda.py +++ b/python/tests/__snapshots__/test_array_api/test_sklearn_lda.py @@ -28,7 +28,7 @@ _NDArray_7 = std(_NDArray_6, OptionalIntOrTuple.int(Int(0))) _NDArray_7[ndarray_index(std(_NDArray_6, OptionalIntOrTuple.int(Int(0))) == NDArray.scalar(Value.int(Int(0))))] = NDArray.scalar(Value.float(Float(1.0))) _TupleNDArray_1 = svd(sqrt(NDArray.scalar(Value.int(NDArray.scalar(Value.float(Float(1.0))).to_int() / Int(147)))) * (_NDArray_6 / _NDArray_7), FALSE) -_Slice_1 = Slice(OptionalInt.none, OptionalInt.some(sum(astype(_TupleNDArray_1[Int(1)] > NDArray.scalar(Value.float(Float(0.0001))), DType.int32)).to_int())) +_Slice_1 = Slice(OptionalInt.none, OptionalInt.some(astype(sum(_TupleNDArray_1[Int(1)] > NDArray.scalar(Value.float(Float(0.0001)))), DType.int32).to_int())) _NDArray_8 = (_TupleNDArray_1[Int(2)][IndexKey.multi_axis(MultiAxisIndexKey(MultiAxisIndexKeyItem.slice(_Slice_1)) + _MultiAxisIndexKey_1)] / _NDArray_7).T / _TupleNDArray_1[ Int(1) ][IndexKey.slice(_Slice_1)] @@ -49,8 +49,8 @@ Slice( OptionalInt.none, OptionalInt.some( - sum( - astype(_TupleNDArray_2[Int(1)] > (NDArray.scalar(Value.float(Float(0.0001))) * _TupleNDArray_2[Int(1)][IndexKey.int(Int(0))]), DType.int32) + astype( + sum(_TupleNDArray_2[Int(1)] > (NDArray.scalar(Value.float(Float(0.0001))) * _TupleNDArray_2[Int(1)][IndexKey.int(Int(0))])), DType.int32 ).to_int() ), ) diff --git a/python/tests/__snapshots__/test_array_api/test_to_source.py b/python/tests/__snapshots__/test_array_api/test_to_source.py index a867c69f..68d6db71 100644 --- a/python/tests/__snapshots__/test_array_api/test_to_source.py +++ b/python/tests/__snapshots__/test_array_api/test_to_source.py @@ -1,27 +1,12 @@ -def my_fn(X, y): +def __fn(X, y): assert y.dtype == np.int64 - assert X.dtype == np.float64 - assert y.dtype == np.int64 - assert X.dtype == np.float64 - _0 = np.array(150.0) - assert y.shape == (150,) - assert y.shape == (150,) - assert X.shape == (150,) + (4,) - assert X.shape == (150,) + (4,) assert y.shape == (150,) - assert y.shape == (150,) - assert set(y.flatten()) == set((0,) + (1,) + (2,)) assert set(y.flatten()) == set((0,) + (1,) + (2,)) - _1 = y.reshape((-1,)) - _1 = y.reshape((-1,)) - _2 = np.unique(_1, return_counts=True) - _2 = np.unique(_1, return_counts=True) - _3 = np.unique(_1) - _4 = _2[1].astype(np.float64) - _5 = _4 / _0 - _6 = np.zeros((3,) + (4,), dtype=np.float64) - _6 = np.zeros((3,) + (4,), dtype=np.float64) - _7 = _5 + X - _8 = _7 + _6 - _8 = _7 + _6 - return _8 + _0 = y.reshape((-1,)) + _1 = np.zeros((3,) + (4,), dtype=np.float64) + _2 = _0 + _1 + _3 = np.unique(_0, return_counts=True) + _4 = _3[1].astype(np.float64) + _5 = _4 / np.array(150.0) + _6 = _2 + _5 + return _6 diff --git a/python/tests/__snapshots__/test_program_gen/test_to_string.py b/python/tests/__snapshots__/test_program_gen/test_to_string.py new file mode 100644 index 00000000..fe466ad7 --- /dev/null +++ b/python/tests/__snapshots__/test_program_gen/test_to_string.py @@ -0,0 +1,7 @@ +def my_fn(x, y): + _0 = -x + assert _0 > 0 + _1 = _0 + y + _2 = _1 + 2 + _3 = _2 + _1 + return _3 diff --git a/python/tests/test_array_api.py b/python/tests/test_array_api.py index 9cd25e55..057b3e73 100644 --- a/python/tests/test_array_api.py +++ b/python/tests/test_array_api.py @@ -1,4 +1,3 @@ -import pytest from egglog.exp.array_api import * @@ -28,6 +27,8 @@ def test_tuple_value_includes(): def test_to_source(snapshot_py): + import numpy + _NDArray_1 = NDArray.var("X") X_orig = copy(_NDArray_1) assume_dtype(_NDArray_1, DType.float64) @@ -47,26 +48,23 @@ def test_to_source(snapshot_py): OptionalDType.some(DType.float64), OptionalDevice.some(_NDArray_1.device), ) - res = _NDArray_4 + _NDArray_1 + _NDArray_5 - + res = _NDArray_3 + _NDArray_5 + _NDArray_4 egraph = EGraph([array_api_module_string]) - fn = egraph.let("fn", FunctionExprTwo("my_fn", res, X_orig, Y_orig)) - - egraph.run(20) - # while egraph.run((run())).updated: - # print(egraph.load_object(egraph.extract(PyObject.from_string(statements())))) - # egraph.graphviz().render(view=True) - # egraph.graphviz(n_inline_leaves=3).render("inlined", view=True) - - egraph.run(run(fn_ruleset)) - fn_source = egraph.load_object(egraph.extract(PyObject.from_string(fn.source))) + with egraph: + egraph.register(res) + egraph.run(10000) + res = egraph.extract(res) + fn = ndarray_program(res).function_two(ndarray_program(X_orig), ndarray_program(Y_orig)) + with egraph: + egraph.register(fn.eval_py_object(egraph.save_object({"np": numpy}))) + egraph.run(10000) + fn = egraph.extract(fn) + # egraph.display(n_inline_leaves=0, split_primitive_outputs=True) + fn_source = egraph.load_object(egraph.extract(PyObject.from_string(fn.statements))) assert fn_source == snapshot_py - locals_: dict[str, object] = {} - exec(fn_source, {"np": np}, locals_) # type: ignore - fn: object = locals_["my_fn"] -@pytest.mark.xfail(reason="unstable output") +# @pytest.mark.xfail(raises=TODO) def test_sklearn_lda(snapshot_py): from sklearn import config_context from sklearn.discriminant_analysis import LinearDiscriminantAnalysis @@ -88,8 +86,7 @@ def test_sklearn_lda(snapshot_py): with EGraph([array_api_module]) as egraph: egraph.register(X_r2) - egraph.run((run() * 10)) - # egraph.run((run() * 10).saturate()) + egraph.run((run() * 10).saturate()) # egraph.graphviz(n_inline_leaves=3).render("3", view=True) res = egraph.extract(X_r2) diff --git a/python/tests/test_high_level.py b/python/tests/test_high_level.py index 19095e15..fc1fb304 100644 --- a/python/tests/test_high_level.py +++ b/python/tests/test_high_level.py @@ -312,6 +312,16 @@ def test_eval_local(self): res_simpl = egraph.simplify(res, 1) assert egraph.load_object(res_simpl) == "hithere" + def test_exec(self): + egraph = EGraph() + res = egraph.simplify(py_exec("x = 10"), 1) + assert egraph.load_object(res) == {"x": 10} + + def test_exec_globals(self): + egraph = EGraph() + res = egraph.simplify(py_exec("x = y + 1", egraph.save_object({"y": 10})), 1) + assert egraph.load_object(res) == {"x": 11} + def my_add(a, b): return a + b @@ -430,81 +440,3 @@ def __radd__(self, other: Math) -> Math: JustTypeRef("Math"), CallDecl(MethodRef("Math", "__add__"), (expr_parts(Math(i64(10))), expr_parts(Math(i64(5))))), ) - - -@pytest.mark.xfail(reason="https://github.com/egraphs-good/egglog/issues/229") -def test_imperative(): - egraph = EGraph(seminaive=False) - - @egraph.function(merge=lambda old, new: join(old, new), default=String("")) - def statements() -> String: - ... - - @egraph.function(merge=lambda old, new: old + new, default=i64(0)) - def gensym() -> i64: - ... - - gensym_var = join("_", gensym().to_string()) - - @egraph.class_ - class Math(Expr): - @egraph.method(egg_fn="Num") - def __init__(self, value: i64Like) -> None: - ... - - @egraph.method(egg_fn="Var") - @classmethod - def var(cls, v: StringLike) -> Math: - ... - - @egraph.method(egg_fn="Add") - def __add__(self, other: Math) -> Math: - ... - - @egraph.method(egg_fn="Mul") - def __mul__(self, other: Math) -> Math: - ... - - @egraph.method(egg_fn="expr") # type: ignore[misc] - @property - def expr(self) -> String: - ... - - @egraph.register - def _rules(s: String, y_expr: String, z_expr: String, x: Math, i: i64, y: Math, z: Math): - yield rule( - eq(x).to(Math.var(s)), - ).then( - set_(x.expr).to(s), - ) - - yield rule( - eq(x).to(Math(i)), - ).then( - set_(x.expr).to(i.to_string()), - ) - - yield rule( - eq(x).to(y + z), - eq(y_expr).to(y.expr), - eq(z_expr).to(z.expr), - ).then( - set_(x.expr).to(gensym_var), - set_(statements()).to(join(gensym_var, " = ", y_expr, " + ", z_expr, "\n")), - set_(gensym()).to(i64(1)), - ) - yield rule( - eq(x).to(y * z), - eq(y_expr).to(y.expr), - eq(z_expr).to(z.expr), - ).then( - set_(x.expr).to(gensym_var), - set_(statements()).to(join(gensym_var, " = ", y_expr, " * ", z_expr, "\n")), - set_(gensym()).to(i64(1)), - ) - - y = egraph.let("y", Math(2) * (Math.var("x") + Math(3))) - - egraph.run(3) - egraph.check(eq(y.expr).to(String("_1"))) - egraph.check(eq(statements()).to(String("_0 = x + 3\n_1 = 2 * _0\n"))) diff --git a/python/tests/test_program_gen.py b/python/tests/test_program_gen.py new file mode 100644 index 00000000..e7cc001b --- /dev/null +++ b/python/tests/test_program_gen.py @@ -0,0 +1,85 @@ +# mypy: disable-error-code="empty-body" +from __future__ import annotations + +from egglog import * +from egglog.exp.program_gen import * + +egraph = EGraph([program_gen_module]) + + +@egraph.class_ +class Math(Expr): + def __init__(self, value: i64Like) -> None: + ... + + @classmethod + def var(cls, v: StringLike) -> Math: + ... + + def __add__(self, other: Math) -> Math: + ... + + def __mul__(self, other: Math) -> Math: + ... + + def __neg__(self) -> Math: + ... + + @egraph.method(cost=1000) # type: ignore + @property + def program(self) -> Program: + ... + + +@egraph.function +def assume_pos(x: Math) -> Math: + ... + + +@egraph.register +def _rules( + s: String, + y_expr: String, + z_expr: String, + old_statements: String, + x: Math, + i: i64, + y: Math, + z: Math, + old_gensym: i64, +): + yield rewrite(Math.var(s).program).to(Program(s)) + yield rewrite(Math(i).program).to(Program(i.to_string())) + yield rewrite((y + z).program).to((y.program + " + " + z.program).assign()) + yield rewrite((y * z).program).to((y.program + " * " + z.program).assign()) + yield rewrite((-y).program).to(Program("-") + y.program) + assigned_x = x.program.assign() + yield rewrite(assume_pos(x).program).to(assigned_x.statement(Program("assert ") + assigned_x + " > 0")) + + +def test_to_string(snapshot_py) -> None: + first = assume_pos(-Math.var("x")) + Math.var("y") + fn = (first + Math(2) + first).program.function_two(Math.var("x").program, Math.var("y").program, "my_fn") + with egraph: + egraph.register(fn) + egraph.run(200) + fn = egraph.extract(fn) + egraph.register(fn) + egraph.register(fn.compile()) + egraph.run(200) + # egraph.display(n_inline_leaves=1) + expr = egraph.load_object(egraph.extract(PyObject.from_string(fn.expr))) + assert expr == "my_fn" # type: ignore + stmts = egraph.load_object(egraph.extract(PyObject.from_string(fn.statements))) + assert stmts == snapshot_py # type: ignore + + +def test_py_object(): + x = Math.var("x") + y = Math.var("y") + z = Math.var("z") + fn = (x + y + z).program.function_two(x.program, y.program) + egraph.register(fn.eval_py_object(egraph.save_object({"z": 10}))) + egraph.run(100) + res = egraph.load_object(egraph.extract(fn.py_object)) + assert res(1, 2) == 13 # type: ignore diff --git a/src/egraph.rs b/src/egraph.rs index aae4c1bc..4f805c23 100644 --- a/src/egraph.rs +++ b/src/egraph.rs @@ -83,14 +83,15 @@ impl EGraph { /// Returns the EGraph as graphviz string. #[pyo3( - signature = (*, max_functions=None, max_calls_per_function=None, n_inline_leaves=0), - text_signature = "(self, *, max_functions=None, max_calls_per_function=None, n_inline_leaves=0)" + signature = (*, max_functions=None, max_calls_per_function=None, n_inline_leaves=0, split_primitive_outputs=false), + text_signature = "(self, *, max_functions=None, max_calls_per_function=None, n_inline_leaves=0, split_primitive_outputs=False)" )] fn to_graphviz_string( &self, max_functions: Option, max_calls_per_function: Option, n_inline_leaves: usize, + split_primitive_outputs: bool, ) -> String { info!("Getting graphviz"); // TODO: Expose full serialized e-graph in the future @@ -98,6 +99,7 @@ impl EGraph { max_functions, max_calls_per_function, include_temporary_functions: false, + split_primitive_outputs, }); for _ in 0..n_inline_leaves { serialized.inline_leaves(); diff --git a/src/py_object_sort.rs b/src/py_object_sort.rs index b71a2cc7..7ec9ccd9 100644 --- a/src/py_object_sort.rs +++ b/src/py_object_sort.rs @@ -90,6 +90,15 @@ impl Sort for PyObjectSort { py_object: self.clone(), string: typeinfo.get_sort(), }); + typeinfo.add_primitive(Exec { + name: "py-exec".into(), + py_object: self.clone(), + string: typeinfo.get_sort(), + }); + typeinfo.add_primitive(Dict { + name: "py-dict".into(), + py_object: self.clone(), + }); typeinfo.add_primitive(DictUpdate { name: "py-dict-update".into(), py_object: self.clone(), @@ -202,6 +211,86 @@ impl PrimitiveLike for Eval { } } +/// Copies the locals, execs the Python string, then returns the copied version of the locals with any updates +/// (py-exec ) +struct Exec { + name: Symbol, + py_object: Arc, + string: Arc, +} + +impl PrimitiveLike for Exec { + fn name(&self) -> Symbol { + self.name + } + + fn accept(&self, types: &[ArcSort]) -> Option { + match types { + [str, locals, globals] + if str.name() == self.string.name() + && locals.name() == self.py_object.name() + && globals.name() == self.py_object.name() => + { + Some(self.py_object.clone()) + } + _ => None, + } + } + + fn apply(&self, values: &[Value]) -> Option { + let code: Symbol = Symbol::load(self.string.as_ref(), &values[0]); + let locals: PyObject = Python::with_gil(|py| { + let (_, globals) = self.py_object.load(&values[1]); + let globals = globals.downcast::(py).unwrap(); + let (_, locals) = self.py_object.load(&values[2]); + let locals = locals.downcast::(py).unwrap().copy().unwrap(); + py.run(code.into(), Some(globals), Some(locals)).unwrap(); + locals.into() + }); + Some(self.py_object.store(locals)) + } +} + +/// (py-dict [ ]*) +struct Dict { + name: Symbol, + py_object: Arc, +} + +impl PrimitiveLike for Dict { + fn name(&self) -> Symbol { + self.name + } + + fn accept(&self, types: &[ArcSort]) -> Option { + // Should have an even number of args + if types.len() % 2 != 0 { + return None; + } + for tp in types.iter() { + // All tps should be object + if tp.name() != self.py_object.name() { + return None; + } + } + Some(self.py_object.clone()) + } + + fn apply(&self, values: &[Value]) -> Option { + let dict: PyObject = Python::with_gil(|py| { + let dict = PyDict::new(py); + // Update the dict with the key-value pairs + for i in values.chunks_exact(2) { + let key = self.py_object.load(&i[0]).1; + let value = self.py_object.load(&i[1]).1; + dict.set_item(key, value).unwrap(); + } + dict.into() + }); + Some(self.py_object.store(dict)) + } +} + /// Supports calling (py-dict-update [ ]*) struct DictUpdate { name: Symbol,