From 053bc4e91a31bcd4ff2637eed23fdfa70afe032e Mon Sep 17 00:00:00 2001 From: Philip Mueller Date: Tue, 4 Feb 2025 09:52:30 +0100 Subject: [PATCH] This should fix the map fusion fix. --- .../dace/transformations/map_fusion_helper.py | 166 ++++++++---------- .../dace/transformations/map_fusion_serial.py | 2 +- 2 files changed, 75 insertions(+), 93 deletions(-) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_helper.py b/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_helper.py index eceb07ed82..f0460d2315 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_helper.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_helper.py @@ -73,11 +73,6 @@ class MapFusionHelper(transformation.SingleStateTransformation): # `False` then the fusion will be rejected. _apply_fusion_callback: Optional[FusionCallback] - # Maps SDFGs to the set of data that can not be removed, - # because they transmit data _between states_, such data will be made 'shared'. - # This variable acts as a cache, and is managed by 'is_shared_data()'. - _shared_data: Dict[SDFG, Set[str]] - def __init__( self, only_inner_maps: Optional[bool] = None, @@ -380,107 +375,94 @@ def rename_map_parameters( def is_shared_data( self, data: nodes.AccessNode, + state: dace.SDFGState, sdfg: dace.SDFG, ) -> bool: - """Tests if `data` is interstate data, an can not be removed. + """Tests if `data` is shared data, i.e. it can not be removed from the SDFG. - Interstate data is used to transmit data between multiple state or by - extension within the state. Thus it must be classified as a shared output. - This function will go through the SDFG to and collect the names of all data - container that should be classified as shared. Note that this is an over - approximation as it does not take the location into account, i.e. "is no longer - used". 
+ Depending on the situation, the function will not perform a scan of the whole SDFG: + 1) If `data` is non transient then the function will return `True`, as non transient data + must be reconstructed always. + 2) If the AccessNode `data` has more than one outgoing edge or more than one incoming edge + it is classified as shared. + 3) If `FindSingleUseData` is in the pipeline it will be used and no scan will be performed. + 4) The function will perform a scan. - Args: - transient: The transient that should be checked. - sdfg: The SDFG containing the array. + :param data: The transient that should be checked. + :param state: The state in which the fusion is performed. + :param sdfg: The SDFG in which we want to perform the fusing. - Note: - The function computes the this set once for every SDFG and then caches it. - There is no mechanism to detect if the cache must be evicted. However, - as long as no additional data is added, there is no problem. """ - if sdfg not in self._shared_data: - self._compute_shared_data(sdfg) - return data.data in self._shared_data[sdfg] - - def _compute_shared_data( + # If `data` is non transient then return `True` as the intermediate can not be removed. + if not data.desc(sdfg).transient: + return True + + # This means the data is consumed by multiple Maps, through the same AccessNode, in this state + # Note currently multiple incoming edges are not handled, but in the spirit of this function + # we consider such AccessNodes as shared, because we can not remove the intermediate. + if state.out_degree(data) > 1: + return True + if state.in_degree(data) > 1: + return True + + # We have to perform the full scan of the SDFG. + return self._scan_sdfg_if_data_is_shared(data=data, state=state, sdfg=sdfg) + + def _scan_sdfg_if_data_is_shared( self, + data: nodes.AccessNode, + state: dace.SDFGState, sdfg: dace.SDFG, - ) -> None: + ) -> bool: - """Updates the internal set of shared data/interstate data of `self` for `sdfg`.
+ """Scans `sdfg` to determine if `data` is shared. - See the documentation for `self.is_shared_data()` for a description. + Essentially, this function determines if the intermediate AccessNode `data` + can be removed or if it has to be restored as output of the Map. + A data descriptor is classified as shared if any of the following is true: + - `data` is non transient data. + - `data` has more than one incoming or outgoing edge. + - There are other AccessNodes beside `data` that refer to the same data. + - The data is accessed on an interstate edge. - Args: - sdfg: The SDFG for which the set of shared data should be computed. + This function should not be called directly. Instead it is called indirectly + by `is_shared_data()` if there is no short cut. + + :param data: The AccessNode that should be checked if it is shared. + :param sdfg: The SDFG that is scanned for other uses of `data`. """ - # Shared data of this SDFG. - shared_data: Set[str] = set() - - # All global data can not be removed, so it must always be shared. - for data_name, data_desc in sdfg.arrays.items(): - if not data_desc.transient: - shared_data.add(data_name) - elif isinstance(data_desc, dace.data.Scalar): - shared_data.add(data_name) - - # We go through all states and classify the nodes/data: - # - Data is referred to in different states. - # - The access node is a view (both have to survive). - # - Transient sink or source node. - # - The access node has output degree larger than 1 (input degrees larger - # than one, will always be partitioned as shared anyway).
- prevously_seen_data: Set[str] = set() - interstate_read_symbols: Set[str] = set() - for state in sdfg.nodes(): - for access_node in state.data_nodes(): - if access_node.data in shared_data: - # The data was already classified to be shared data - pass - - elif access_node.data in prevously_seen_data: - # We have seen this data before, either in this state or in - # a previous one, but we did not classifies it as shared back then - shared_data.add(access_node.data) - - if state.in_degree(access_node) == 0: - # (Transient) sink nodes are used in other states, or simplify - # will get rid of them. - shared_data.add(access_node.data) - - elif ( - state.out_degree(access_node) != 1 - ): # state.out_degree() == 0 or state.out_degree() > 1 - # The access node is either a source node (it is shared in another - # state) or the node has a degree larger than one, so it is used - # in this state somewhere else. - shared_data.add(access_node.data) - - elif self.is_view(node=access_node, sdfg=sdfg): - # To ensure that the write to the view happens, both have to be shared. - viewed_data: str = self.track_view( - view=access_node, state=state, sdfg=sdfg - ).data - shared_data.update([access_node.data, viewed_data]) - prevously_seen_data.update([access_node.data, viewed_data]) - - else: - # The node was not classified as shared data, so we record that - # we saw it. Note that a node that was immediately classified - # as shared node will never be added to this set, but a data - # that was found twice will be inside this list. - prevously_seen_data.add(access_node.data) - - # Now we are collecting all symbols that interstate edges read from. + if not data.desc(sdfg).transient: + return True + + # See description in `is_shared_data()` for more. 
+ if state.out_degree(data) > 1: + return True + if state.in_degree(data) > 1: + return True + + data_name: str = data.data + for state in sdfg.states(): + for dnode in state.data_nodes(): + if dnode is data: + # We have found the `data` AccessNode, which we must ignore. + continue + if dnode.data == data_name: + # We found a different AccessNode that refers to the same data + # as `data`. Thus `data` is shared. + return True + + # Test if the data is referenced in the interstate edges. for edge in sdfg.edges(): - interstate_read_symbols.update(edge.data.read_symbols()) + if data_name in edge.data.free_symbols: + # The data is used in the inter state edges. So it is shared. + return True - # We also have to keep everything the edges referrers to and is an array. - shared_data.update(interstate_read_symbols.intersection(prevously_seen_data)) + # Test if they are accessed in a condition of a loop or conditional block. + for cfr in sdfg.all_control_flow_regions(): + if data_name in cfr.used_symbols(all_symbols=True, with_contents=False): + return True - # Update the internal cache - self._shared_data[sdfg] = shared_data + # The `data` is not used anywhere else, thus `data` is not shared. + return False def _compute_multi_write_data( self, @@ -522,7 +504,7 @@ def _compute_multi_write_data( def is_node_reachable_from( self, - graph: Union[dace.SDFG, dace.SDFGState], + graph: dace.SDFGState, begin: nodes.Node, end: nodes.Node, ) -> bool: diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_serial.py b/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_serial.py index 2cdcc455d4..445788b73a 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_serial.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_serial.py @@ -621,7 +621,7 @@ def partition_first_outputs( # node can be removed (`\mathbb{E}`) or has to be restored (`\mathbb{S}`). 
# Note that "removed" here means that it is reconstructed by a new # output of the second map. - if self.is_shared_data(intermediate_node, sdfg): + if self.is_shared_data(data=intermediate_node, state=state, sdfg=sdfg): # The intermediate data is used somewhere else, either in this or another state. shared_outputs.add(out_edge) else: