From 04e407ba53c671c99dc48b6cbad98820fb5e3125 Mon Sep 17 00:00:00 2001 From: DropD Date: Fri, 19 Apr 2024 14:59:06 +0200 Subject: [PATCH] stages hasher: avoid recursing into non-stage dataclasses --- src/gt4py/next/ffront/stages.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/src/gt4py/next/ffront/stages.py b/src/gt4py/next/ffront/stages.py index 30a38dc118..86279de137 100644 --- a/src/gt4py/next/ffront/stages.py +++ b/src/gt4py/next/ffront/stages.py @@ -111,17 +111,26 @@ def cache_key(obj: Any, algorithm: Optional[str | Hasher_T] = None) -> str: @functools.singledispatch def update_cache_key(obj: Any, hasher: Hasher_T) -> None: - if dataclasses.is_dataclass(obj): - update_cache_key(obj.__class__, hasher) - for field in dataclasses.fields(obj): - update_cache_key(getattr(obj, field.name), hasher) # the following is to avoid circular dependencies - elif hasattr(obj, "backend"): # assume it is a decorator wrapper + if hasattr(obj, "backend"): # assume it is a decorator wrapper update_cache_key_fielop(obj, hasher) else: hasher.update(str(obj).encode()) +@update_cache_key.register(FieldOperatorDefinition) +@update_cache_key.register(FoastOperatorDefinition) +@update_cache_key.register(FoastWithTypes) +@update_cache_key.register(FoastClosure) +@update_cache_key.register(ProgramDefinition) +@update_cache_key.register(PastProgramDefinition) +@update_cache_key.register(PastClosure) +def update_cache_key_stages(obj: Any, hasher: Hasher_T) -> None: + update_cache_key(obj.__class__, hasher) + for field in dataclasses.fields(obj): + update_cache_key(getattr(obj, field.name), hasher) + + @update_cache_key.register def update_cache_key_str(obj: str, hasher: Hasher_T) -> None: hasher.update(str(obj).encode())