Skip to content

Commit

Permalink
fix[next][dace]: remove unused connectivities (#1797)
Browse files Browse the repository at this point in the history
By design, the arrays for connectivity tables are initially created as
transient and marked as non-transient during the lowering when they are
used. At the end of lowering to SDFG, the unused connectivities (still
transient arrays) should be removed.
  • Loading branch information
edopao authored Jan 15, 2025
1 parent 21e7f64 commit 33bb68b
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 2 deletions.
16 changes: 15 additions & 1 deletion src/gt4py/next/program_processors/runners/dace_common/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,11 @@
from gt4py.next.type_system import type_specifications as ts


# arrays for connectivity tables use the following prefix
CONNECTIVITY_INDENTIFIER_PREFIX: Final[str] = "connectivity_"
CONNECTIVITY_INDENTIFIER_RE: Final[re.Pattern] = re.compile(r"^connectivity_(.+)$")


# regex to match the symbols for field shape and strides
FIELD_SYMBOL_RE: Final[re.Pattern] = re.compile(r"__.+_(size|stride)_\d+")

Expand Down Expand Up @@ -48,7 +53,16 @@ def as_itir_type(dtype: dace.typeclass) -> ts.ScalarType:


def connectivity_identifier(name: str) -> str:
return f"connectivity_{name}"
return f"{CONNECTIVITY_INDENTIFIER_PREFIX}{name}"


def is_connectivity_identifier(
name: str, offset_provider_type: gtx_common.OffsetProviderType
) -> bool:
m = CONNECTIVITY_INDENTIFIER_RE.match(name)
if m is None:
return False
return m[1] in offset_provider_type


def field_symbol_name(field_name: str, axis: int, sym: Literal["size", "stride"]) -> str:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -401,7 +401,7 @@ def _construct_local_view(self, field: MemletExpr | ValueExpr) -> ValueExpr:
view_shape = tuple(desc.shape[i] for i in local_dim_indices)
view_strides = tuple(desc.strides[i] for i in local_dim_indices)
view, _ = self.sdfg.add_view(
f"{field.dc_node.data}_view",
f"view_{field.dc_node.data}",
view_shape,
desc.dtype,
strides=view_strides,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,18 @@ def visit_Program(self, node: gtir.Program) -> dace.SDFG:
head_state._debuginfo = dace_utils.debug_info(stmt, default=sdfg.debuginfo)
head_state = self.visit(stmt, sdfg=sdfg, state=head_state)

# remove unused connectivity tables (by design, arrays are marked as non-transient when they are used)
for nsdfg in sdfg.all_sdfgs_recursive():
unused_connectivities = [
data
for data, datadesc in nsdfg.arrays.items()
if dace_utils.is_connectivity_identifier(data, self.offset_provider_type)
and datadesc.transient
]
for data in unused_connectivities:
assert isinstance(nsdfg.arrays[data], dace.data.Array)
nsdfg.arrays.pop(data)

# Create the call signature for the SDFG.
# Only the arguments required by the GT4Py program, i.e. `node.params`, are added
# as positional arguments. The implicit arguments, such as the offset providers or
Expand Down

0 comments on commit 33bb68b

Please sign in to comment.