Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix[next][dace]: remove temporary arrays with runtime shape on the output of a mapped nested SDFG #1877

Merged
merged 6 commits into from
Feb 20, 2025
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 30 additions & 9 deletions src/gt4py/next/program_processors/runners/dace/gtir_dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,15 +232,36 @@ def connect(
dest: dace.nodes.AccessNode,
subset: dace_subsets.Range,
) -> None:
# retrieve the node which writes the result
last_node = self.state.in_edges(self.result.dc_node)[0].src
if isinstance(last_node, dace.nodes.Tasklet):
# the last transient node can be deleted
# Note that it could also be applied when `last_node` is a NestedSDFG,
# but an exception would be when the inner write to global data is a
# WCR memlet, because that prevents fusion of the outer map. This case
# happens for the reduce with skip values, which uses a map with WCR.
last_node_connector = self.state.in_edges(self.result.dc_node)[0].src_conn
write_edge = self.state.in_edges(self.result.dc_node)[0]
write_size = write_edge.data.dst_subset.num_elements()
# check the kind of node which writes the result
if isinstance(write_edge.src, dace.nodes.Tasklet):
# The temporary data written by a tasklet can be safely deleted
assert write_size.is_constant()
remove_last_node = True
elif isinstance(write_edge.src, dace.nodes.NestedSDFG):
if write_size.is_constant():
# Temporary data with compile-time size is allocated on the stack
# and therefore is safe to keep. We decide to keep it as a workaround
# for a dace issue with memlet propagation in combination with
# nested SDFGs containing conditional blocks. The output memlet
# of such blocks will be marked as dynamic because dace is not able
# to detect the exact size of a conditional branch dataflow, even
# in case of if-else expressions with exact same output data.
remove_last_node = False
else:
# In case the output data has runtime size it is necessary to remove
# it in order to avoid dynamic memory allocation inside a parallel
# map scope. Otherwise, the memory allocation will for sure lead
# to performance degradation, and eventually illegal memory issues
# when the gpu runs out of local memory.
remove_last_node = True
else:
remove_last_node = False

if remove_last_node:
last_node = write_edge.src
last_node_connector = write_edge.src_conn
self.state.remove_node(self.result.dc_node)
else:
last_node = self.result.dc_node
Expand Down