Skip to content

Commit

Permalink
stages hasher: avoid recursing into non-stage dataclasses
Browse files Browse the repository at this point in the history
  • Loading branch information
DropD committed Apr 19, 2024
1 parent 63121fd commit 04e407b
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions src/gt4py/next/ffront/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,17 +111,26 @@ def cache_key(obj: Any, algorithm: Optional[str | Hasher_T] = None) -> str:

@functools.singledispatch
def update_cache_key(obj: Any, hasher: Hasher_T) -> None:
if dataclasses.is_dataclass(obj):
update_cache_key(obj.__class__, hasher)
for field in dataclasses.fields(obj):
update_cache_key(getattr(obj, field.name), hasher)
# the following is to avoid circular dependencies
elif hasattr(obj, "backend"): # assume it is a decorator wrapper
if hasattr(obj, "backend"): # assume it is a decorator wrapper
update_cache_key_fielop(obj, hasher)
else:
hasher.update(str(obj).encode())


@update_cache_key.register(FieldOperatorDefinition)
@update_cache_key.register(FoastOperatorDefinition)
@update_cache_key.register(FoastWithTypes)
@update_cache_key.register(FoastClosure)
@update_cache_key.register(ProgramDefinition)
@update_cache_key.register(PastProgramDefinition)
@update_cache_key.register(PastClosure)
def update_cache_key_stages(obj: Any, hasher: Hasher_T) -> None:
update_cache_key(obj.__class__, hasher)
for field in dataclasses.fields(obj):
update_cache_key(getattr(obj, field.name), hasher)


@update_cache_key.register
def update_cache_key_str(obj: str, hasher: Hasher_T) -> None:
hasher.update(str(obj).encode())
Expand Down

0 comments on commit 04e407b

Please sign in to comment.