Skip to content

Commit

Permalink
Move device storage into OirSDFGBuilder
Browse files Browse the repository at this point in the history
For GPU targets, we have to configure the `storage_type` for transient
arrays. In addition, we have to set the library node's `device` property.
We can do both while building the SDFG instead of in separate passes afterwards.
  • Loading branch information
romanc committed Feb 12, 2025
1 parent c0695b4 commit 4ee0bd2
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 17 deletions.
15 changes: 3 additions & 12 deletions src/gt4py/cartesian/backend/dace_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,16 +132,6 @@ def _set_tile_sizes(sdfg: dace.SDFG):
node.tile_sizes_interpretation = "strides"


def _to_device(sdfg: dace.SDFG, device: str) -> None:
"""Update sdfg in place."""
if device == "gpu":
for array in sdfg.arrays.values():
array.storage = dace.StorageType.GPU_Global
for node, _ in sdfg.all_nodes_recursive():
if isinstance(node, StencilComputation):
node.device = dace.DeviceType.GPU


def _pre_expand_transformations(gtir_pipeline: GtirPipeline, sdfg: dace.SDFG, layout_map):
args_data = make_args_data_from_gtir(gtir_pipeline)

Expand Down Expand Up @@ -347,9 +337,10 @@ def _unexpanded_sdfg(self):
"oir_pipeline", DefaultPipeline()
)
oir_node = oir_pipeline.run(base_oir)
sdfg = OirSDFGBuilder().visit(oir_node)
sdfg = OirSDFGBuilder().visit(
oir_node, device=self.builder.backend.storage_info["device"]
)

_to_device(sdfg, self.builder.backend.storage_info["device"])
_pre_expand_transformations(
self.builder.gtir_pipeline,
sdfg,
Expand Down
2 changes: 2 additions & 0 deletions src/gt4py/cartesian/gtc/dace/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def __init__(
extents: Optional[Dict[int, Extent]] = None,
declarations: Optional[Dict[str, Decl]] = None,
expansion_order=None,
device: Optional[dace.DeviceType] = None,
*args,
**kwargs,
):
Expand All @@ -137,6 +138,7 @@ def __init__(
self.oir_node = typing.cast(PickledDataclassProperty, oir_node)
self.extents = extents_dict # type: ignore
self.declarations = declarations # type: ignore
self.device = device
self.symbol_mapping = {
decl.name: dace.symbol(
decl.name,
Expand Down
30 changes: 25 additions & 5 deletions src/gt4py/cartesian/gtc/dace/oir_to_dace.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import Dict
from typing import Dict, Literal

import dace
import dace.properties
Expand All @@ -33,6 +33,17 @@
)


# DaCe storage type handed to the SDFG as `storage=` when arrays are added,
# keyed by the backend's device name.
transient_storage_per_device: Dict[Literal["cpu", "gpu"], dace.StorageType] = {
    "cpu": dace.StorageType.Default,
    "gpu": dace.StorageType.GPU_Global,
}

# DaCe device type set on the library node's `device` property, keyed by the
# backend's device name.
device_type_per_device: Dict[Literal["cpu", "gpu"], dace.DeviceType] = {
    "cpu": dace.DeviceType.CPU,
    "gpu": dace.DeviceType.GPU,
}


class OirSDFGBuilder(eve.NodeVisitor):
@dataclass
class SDFGContext:
Expand Down Expand Up @@ -98,7 +109,13 @@ def _make_dace_subset(self, local_access_info, field):
global_access_info, local_access_info, self.decls[field].data_dims
)

def visit_VerticalLoop(self, node: oir.VerticalLoop, *, ctx: OirSDFGBuilder.SDFGContext):
def visit_VerticalLoop(
self,
node: oir.VerticalLoop,
*,
ctx: OirSDFGBuilder.SDFGContext,
device: Literal["cpu", "gpu"],
) -> None:
declarations = {
acc.name: ctx.decls[acc.name]
for acc in node.walk_values().if_isinstance(oir.FieldAccess, oir.ScalarAccess)
Expand All @@ -109,6 +126,7 @@ def visit_VerticalLoop(self, node: oir.VerticalLoop, *, ctx: OirSDFGBuilder.SDFG
extents=ctx.block_extents,
declarations=declarations,
oir_node=node,
device=device_type_per_device[device],
)

state = ctx.sdfg.add_state()
Expand Down Expand Up @@ -137,8 +155,8 @@ def visit_VerticalLoop(self, node: oir.VerticalLoop, *, ctx: OirSDFGBuilder.SDFG
library_node, connector_name, access_node, None, dace.Memlet(field, subset=subset)
)

def visit_Stencil(self, node: oir.Stencil):
ctx = OirSDFGBuilder.SDFGContext(stencil=node)
def visit_Stencil(self, node: oir.Stencil, *, device: Literal["cpu", "gpu"]) -> dace.SDFG:
ctx = OirSDFGBuilder.SDFGContext(node)
for param in node.params:
if isinstance(param, oir.FieldDecl):
dim_strs = [d for i, d in enumerate("IJK") if param.dimensions[i]] + [
Expand All @@ -153,6 +171,7 @@ def visit_Stencil(self, node: oir.Stencil):
],
dtype=data_type_to_dace_typeclass(param.dtype),
transient=False,
storage=transient_storage_per_device[device],
debuginfo=get_dace_debuginfo(param),
)
else:
Expand All @@ -172,8 +191,9 @@ def visit_Stencil(self, node: oir.Stencil):
dtype=data_type_to_dace_typeclass(decl.dtype),
transient=True,
lifetime=dace.AllocationLifetime.Persistent,
storage=transient_storage_per_device[device],
debuginfo=get_dace_debuginfo(decl),
)
self.generic_visit(node, ctx=ctx)
self.generic_visit(node, ctx=ctx, device=device)
ctx.sdfg.validate()
return ctx.sdfg

0 comments on commit 4ee0bd2

Please sign in to comment.