Skip to content

Commit

Permalink
Minor edit
Browse files Browse the repository at this point in the history
  • Loading branch information
edopao committed Apr 30, 2024
1 parent a10b614 commit aef4265
Show file tree
Hide file tree
Showing 3 changed files with 12 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -33,17 +33,9 @@
from gt4py.next.program_processors.runners.dace_fieldview.gtir_tasklet_codegen import (
GtirTaskletCodegen,
)
from gt4py.next.program_processors.runners.dace_fieldview.utility import as_dace_type
from gt4py.next.type_system import type_specifications as ts


def unique_name(prefix: str) -> str:
unique_id = getattr(unique_name, "_unique_id", 0) # static variable
setattr(unique_name, "_unique_id", unique_id + 1) # noqa: B010 [set-attr-with-constant]

return f"{prefix}_{unique_id}"


class GtirDataflowBuilder(eve.NodeVisitor):
"""Translates a GTIR `ir.Stmt` node to a dataflow graph."""

Expand All @@ -58,24 +50,6 @@ def __init__(
self._sdfg = sdfg
self._data_types = data_types

def _add_local_storage(
self, type_: ts.DataType, shape: list[str]
) -> tuple[str, dace.data.Data]:
name = unique_name("var")
if isinstance(type_, ts.FieldType):
dtype = as_dace_type(type_.dtype)
assert len(type_.dims) == len(shape)
# TODO: for now we let DaCe decide the array strides, evaluate if symblic strides should be used
name, data = self._sdfg.add_array(
name, shape, dtype, find_new_name=True, transient=True
)
else:
assert isinstance(type_, ts.ScalarType)
assert len(shape) == 0
dtype = as_dace_type(type_)
name, data = self._sdfg.add_scalar(name, dtype, find_new_name=True, transient=True)
return name, data

def visit_domain(self, node: itir.Expr) -> Sequence[tuple[Dimension, str, str]]:
domain = []
assert cpm.is_call_to(node, ("cartesian_domain", "unstructured_domain"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from gt4py.eve import codegen
from gt4py.next.iterator import ir as itir
from gt4py.next.iterator.ir_utils import common_pattern_matcher as cpm
from gt4py.next.program_processors.runners.dace_fieldview.utility import as_dace_type
from gt4py.next.program_processors.runners.dace_fieldview.utility import as_dace_type, unique_name
from gt4py.next.type_system import type_specifications as ts


Expand Down Expand Up @@ -112,14 +112,15 @@ def __call__(
return self._build(), self._state

@final
def _add_local_storage(self, data_type: ts.DataType, shape: list[str]) -> dace.nodes.AccessNode:
name = f"{self._state.label}_var"
def _add_local_storage(
self, data_type: ts.FieldType | ts.ScalarType, shape: list[str]
) -> dace.nodes.AccessNode:
name = unique_name("var")
if isinstance(data_type, ts.FieldType):
assert len(data_type.dims) == len(shape)
dtype = as_dace_type(data_type.dtype)
name, _ = self._sdfg.add_array(name, shape, dtype, find_new_name=True, transient=True)
else:
assert isinstance(data_type, ts.ScalarType)
assert len(shape) == 0
dtype = as_dace_type(data_type)
name, _ = self._sdfg.add_scalar(name, dtype, find_new_name=True, transient=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,10 @@ def filter_connectivities(offset_provider: Mapping[str, Any]) -> dict[str, Conne
for offset, table in offset_provider.items()
if isinstance(table, Connectivity)
}


def unique_name(prefix: str) -> str:
unique_id = getattr(unique_name, "_unique_id", 0) # static variable
setattr(unique_name, "_unique_id", unique_id + 1) # noqa: B010 [set-attr-with-constant]

return f"{prefix}_{unique_id}"

0 comments on commit aef4265

Please sign in to comment.