Skip to content

Commit

Permalink
This should fix the map fusion fix.
Browse files Browse the repository at this point in the history
  • Loading branch information
philip-paul-mueller committed Feb 4, 2025
1 parent 059fae7 commit 053bc4e
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 93 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -73,11 +73,6 @@ class MapFusionHelper(transformation.SingleStateTransformation):
# `False` then the fusion will be rejected.
_apply_fusion_callback: Optional[FusionCallback]

# Maps SDFGs to the set of data that can not be removed,
# because they transmit data _between states_, such data will be made 'shared'.
# This variable acts as a cache, and is managed by 'is_shared_data()'.
_shared_data: Dict[SDFG, Set[str]]

def __init__(
self,
only_inner_maps: Optional[bool] = None,
Expand Down Expand Up @@ -380,107 +375,94 @@ def rename_map_parameters(
def is_shared_data(
self,
data: nodes.AccessNode,
state: dace.SDFGState,
sdfg: dace.SDFG,
) -> bool:
"""Tests if `data` is interstate data, an can not be removed.
"""Tests if `data` is shared data, i.e. it can not be removed from the SDFG.
Interstate data is used to transmit data between multiple states or by
extension within the state. Thus it must be classified as a shared output.
This function will go through the SDFG and collect the names of all data
containers that should be classified as shared. Note that this is an over
approximation as it does not take the location into account, i.e. "is no longer
used".
Depending on the situation, the function will not perform a scan of the whole SDFG:
1) If `data` is non transient then the function will return `True`, as non transient data
must be reconstructed always.
2) If the AccessNode `data` has more than one outgoing edge or more than one incoming edge
it is classified as shared.
3) If `FindSingleUseData` is in the pipeline it will be used and no scan will be performed.
4) The function will perform a scan.
Args:
transient: The transient that should be checked.
sdfg: The SDFG containing the array.
:param data: The transient that should be checked.
:param state: The state in which the fusion is performed.
:param sdfg: The SDFG in which we want to perform the fusing.
Note:
The function computes this set once for every SDFG and then caches it.
There is no mechanism to detect if the cache must be evicted. However,
as long as no additional data is added, there is no problem.
"""
if sdfg not in self._shared_data:
self._compute_shared_data(sdfg)
return data.data in self._shared_data[sdfg]

def _compute_shared_data(
# If `data` is non transient then return `True` as the intermediate can not be removed.
if not data.desc(sdfg).transient:
return True

# This means the data is consumed by multiple Maps, through the same AccessNode, in this state
# Note currently multiple incoming edges are not handled, but in the spirit of this function
# we consider such AccessNodes as shared, because we can not remove the intermediate.
if state.out_degree(data) > 1:
return True
if state.in_degree(data) > 1:
return True

# We have to perform the full scan of the SDFG.
return self._scan_sdfg_if_data_is_shared(data=data, state=state, sdfg=sdfg)

def _scan_sdfg_if_data_is_shared(
self,
data: nodes.AccessNode,
state: dace.SDFGState,
sdfg: dace.SDFG,
) -> None:
"""Updates the internal set of shared data/interstate data of `self` for `sdfg`.
"""Scans `sdfg` to determine if `data` is shared.
See the documentation for `self.is_shared_data()` for a description.
Essentially, this function determines if the intermediate AccessNode `data`
can be removed or if it has to be restored as an output of the Map.
A data descriptor is classified as shared if any of the following is true:
- `data` is non transient data.
- `data` has more than one incoming and/or outgoing edge.
- There are other AccessNodes beside `data` that refer to the same data.
- The data is accessed on an interstate edge.
Args:
sdfg: The SDFG for which the set of shared data should be computed.
This function should not be called directly. Instead it is called indirectly
by `is_shared_data()` if there is no short cut.
:param data: The AccessNode that should be checked if it is shared.
:param sdfg: The SDFG for which the set of shared data should be computed.
"""
# Shared data of this SDFG.
shared_data: Set[str] = set()

# All global data can not be removed, so it must always be shared.
for data_name, data_desc in sdfg.arrays.items():
if not data_desc.transient:
shared_data.add(data_name)
elif isinstance(data_desc, dace.data.Scalar):
shared_data.add(data_name)

# We go through all states and classify the nodes/data:
# - Data is referred to in different states.
# - The access node is a view (both have to survive).
# - Transient sink or source node.
# - The access node has output degree larger than 1 (input degrees larger
# than one, will always be partitioned as shared anyway).
prevously_seen_data: Set[str] = set()
interstate_read_symbols: Set[str] = set()
for state in sdfg.nodes():
for access_node in state.data_nodes():
if access_node.data in shared_data:
# The data was already classified to be shared data
pass

elif access_node.data in prevously_seen_data:
# We have seen this data before, either in this state or in
# a previous one, but we did not classify it as shared back then
shared_data.add(access_node.data)

if state.in_degree(access_node) == 0:
# (Transient) sink nodes are used in other states, or simplify
# will get rid of them.
shared_data.add(access_node.data)

elif (
state.out_degree(access_node) != 1
): # state.out_degree() == 0 or state.out_degree() > 1
# The access node is either a source node (it is shared in another
# state) or the node has a degree larger than one, so it is used
# in this state somewhere else.
shared_data.add(access_node.data)

elif self.is_view(node=access_node, sdfg=sdfg):
# To ensure that the write to the view happens, both have to be shared.
viewed_data: str = self.track_view(
view=access_node, state=state, sdfg=sdfg
).data
shared_data.update([access_node.data, viewed_data])
prevously_seen_data.update([access_node.data, viewed_data])

else:
# The node was not classified as shared data, so we record that
# we saw it. Note that a node that was immediately classified
# as shared node will never be added to this set, but a data
# that was found twice will be inside this list.
prevously_seen_data.add(access_node.data)

# Now we are collecting all symbols that interstate edges read from.
if not data.desc(sdfg).transient:
return True

# See description in `is_shared_data()` for more.
if state.out_degree(data) > 1:
return True
if state.in_degree(data) > 1:
return True

data_name: str = data.data
for state in sdfg.states():
for dnode in state.data_nodes():
if dnode is data:
# We have found the `data` AccessNode, which we must ignore.
continue
if dnode.data == data_name:
# We found a different AccessNode that refers to the same data
# as `data`. Thus `data` is shared.
return True

# Test if the data is referenced in the interstate edges.
for edge in sdfg.edges():
interstate_read_symbols.update(edge.data.read_symbols())
if data_name in edge.data.free_symbols:
# The data is used in the inter state edges. So it is shared.
return True

# We also have to keep everything the edges refer to and is an array.
shared_data.update(interstate_read_symbols.intersection(prevously_seen_data))
# Test if they are accessed in a condition of a loop or conditional block.
for cfr in sdfg.all_control_flow_regions():
if data_name in cfr.used_symbols(all_symbols=True, with_contents=False):
return True

# Update the internal cache
self._shared_data[sdfg] = shared_data
# The `data` is not used anywhere else, thus `data` is not shared.
return False

def _compute_multi_write_data(
self,
Expand Down Expand Up @@ -522,7 +504,7 @@ def _compute_multi_write_data(

def is_node_reachable_from(
self,
graph: Union[dace.SDFG, dace.SDFGState],
graph: dace.SDFGState,
begin: nodes.Node,
end: nodes.Node,
) -> bool:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,7 @@ def partition_first_outputs(
# node can be removed (`\mathbb{E}`) or has to be restored (`\mathbb{S}`).
# Note that "removed" here means that it is reconstructed by a new
# output of the second map.
if self.is_shared_data(intermediate_node, sdfg):
if self.is_shared_data(data=intermediate_node, state=state, sdfg=sdfg):
# The intermediate data is used somewhere else, either in this or another state.
shared_outputs.add(out_edge)
else:
Expand Down

0 comments on commit 053bc4e

Please sign in to comment.