diff --git a/docs/changelog.md b/docs/changelog.md index 2d1ccf1..8b938fd 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -6,6 +6,7 @@ _This project uses semantic versioning_ - Fix pretty printing of lambda functions - Add support for subsuming rewrite generated by default function and method definitions +- Add better error message when using @function in class (thanks @shinawy) ## 8.0.1 (2024-10-24) diff --git a/python/egglog/egraph.py b/python/egglog/egraph.py index 1cc4135..c237d36 100644 --- a/python/egglog/egraph.py +++ b/python/egglog/egraph.py @@ -577,23 +577,25 @@ def _generate_class_decls( # noqa: C901,PLR0912 decl = FunctionDecl(special_function_name, builtin=True, egg_name=egg_fn) decls.set_function_decl(ref, decl) continue - - _, add_rewrite = _fn_decl( - decls, - egg_fn, - ref, - fn, - locals, - default, - cost, - merge, - on_merge, - mutates, - builtin, - ruleset=ruleset, - unextractable=unextractable, - subsume=subsume, - ) + try: + _, add_rewrite = _fn_decl( + decls, + egg_fn, + ref, + fn, + locals, + default, + cost, + merge, + on_merge, + mutates, + builtin, + ruleset=ruleset, + unextractable=unextractable, + subsume=subsume, + ) + except ValueError as e: + raise ValueError(f"Error processing {cls_name}.{method_name}: {e}") from e if not builtin and not isinstance(ref, InitRef) and not mutates: add_default_funcs.append(add_rewrite) @@ -721,6 +723,9 @@ def _fn_decl( """ Sets the function decl for the function object and returns the ref as well as a thunk that sets the default callable. """ + if isinstance(fn, RuntimeFunction): + msg = "Inside of classes, wrap methods with the `method` decorator, not `function`" + raise ValueError(msg) # noqa: TRY004 if not isinstance(fn, FunctionType): raise NotImplementedError(f"Can only generate function decls for functions not {fn} {type(fn)}") diff --git a/python/tests/test_high_level.py b/python/tests/test_high_level.py index 5edda45..c22fb99 100644 --- a/python/tests/test_high_level.py +++ b/python/tests/test_high_level.py @@ -762,3 +762,15 @@ def test_inserting_map(self): def test_creating_map(self): EGraph().simplify(Map[String, i64].empty(), 1) + + +def test_helpful_error_function_class(): + class E(Expr): + @function(cost=10) + def __init__(self) -> None: ... + + with pytest.raises( + ValueError, + match="Error processing E.__init__: Inside of classes, wrap methods with the `method` decorator, not `function`", + ): + E()