Skip to content

Commit

Permalink
Allow K offset writes in dace backends
Browse files Browse the repository at this point in the history
Rename connector for better tracking in SDFG
var_offset_fields -> offset_in_K_fields
  • Loading branch information
FlorianDeconinck committed Apr 30, 2024
1 parent b5f112a commit 6916121
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 29 deletions.
60 changes: 39 additions & 21 deletions src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,23 +326,34 @@ def visit_FieldAccess(
*,
is_target: bool,
targets: Set[eve.SymbolRef],
var_offset_fields: Set[eve.SymbolRef],
offset_in_K_fields: Set[eve.SymbolRef],
**kwargs: Any,
) -> Union[dcir.IndexAccess, dcir.ScalarAccess]:
"""Generate the relevant accessor to match the memlet that was previously setup
This function was written when offset writing was forbidden. It has been refactor
to allow offset write in K with minimum changes.
"""

res: Union[dcir.IndexAccess, dcir.ScalarAccess]
if node.name in var_offset_fields:
res = dcir.IndexAccess(
name=node.name + "__",
offset=self.visit(
node.offset,
is_target=False,
targets=targets,
var_offset_fields=var_offset_fields,
**kwargs,
),
data_index=node.data_index,
dtype=node.dtype,
)
if node.name in offset_in_K_fields:
is_target = is_target or node.name in targets
name = get_tasklet_symbol(node.name, node.offset, is_target=is_target)
if is_target:
res = dcir.IndexAccess(
name=name,
offset=self.visit(
node.offset,
is_target=is_target,
targets=targets,
offset_in_K_fields=offset_in_K_fields,
**kwargs,
),
data_index=node.data_index,
dtype=node.dtype,
)
else:
res = dcir.ScalarAccess(name=name, dtype=node.dtype)
else:
is_target = is_target or (
node.name in targets and node.offset == common.CartesianOffset.zero()
Expand All @@ -354,11 +365,7 @@ def visit_FieldAccess(
)
else:
res = dcir.ScalarAccess(name=name, dtype=node.dtype)
# Because we allow writing in K, we allow targets who have been
# writing in K to be omitted so the original connector can be re-used.
# Previous guardrails in `gtscript` restrict the writes
# to non-parallel K-loop so we don't have issues.
if is_target and node.offset.to_dict()["k"] == 0:
if is_target:
targets.add(node.name)
return res

Expand Down Expand Up @@ -826,11 +833,22 @@ def visit_VerticalLoop(
)
)

var_offset_fields = {
# Offsets in K can be both:
# - read indexed via an expression,
# - write offset in K (both in expression and on scalar)
# We get all indexed via expression
offset_in_K_fields = {
acc.name
for acc in node.walk_values().if_isinstance(oir.FieldAccess)
if isinstance(acc.offset, oir.VariableKOffset)
}
# We add write offset to K
for assign_node in node.walk_values().if_isinstance(oir.AssignStmt):
if isinstance(assign_node.left, oir.FieldAccess):
acc = assign_node.left
if isinstance(acc.offset, common.CartesianOffset) and acc.offset.k != 0:
offset_in_K_fields.add(acc.name)

sections_idx = next(
idx
for idx, item in enumerate(global_ctx.library_node.expansion_specification)
Expand All @@ -847,7 +865,7 @@ def visit_VerticalLoop(
global_ctx=global_ctx,
iteration_ctx=iteration_ctx,
symbol_collector=symbol_collector,
var_offset_fields=var_offset_fields,
offset_in_K_fields=offset_in_K_fields,
**kwargs,
)
)
Expand Down
10 changes: 7 additions & 3 deletions src/gt4py/cartesian/gtc/dace/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,9 +62,9 @@ def replace_strides(arrays, get_layout_map):

def get_tasklet_symbol(name, offset, is_target):
if is_target:
return f"__{name}"
return f"gtOUT__{name}"

acc_name = name + "__"
acc_name = f"gtIN__{name}"
if offset is not None:
offset_strs = []
for axis in dcir.Axis.dims_3d():
Expand Down Expand Up @@ -231,9 +231,12 @@ def _make_access_info(
region,
he_grid,
grid_subset,
is_write,
) -> "dcir.FieldAccessInfo":
# Check we have expression offsets in K
# OR write offsets in K
offset = [offset_node.to_dict()[k] for k in "ijk"]
if isinstance(offset_node, oir.VariableKOffset):
if isinstance(offset_node, oir.VariableKOffset) or (offset[2] != 0 and is_write):
variable_offset_axes = [dcir.Axis.K]
else:
variable_offset_axes = []
Expand Down Expand Up @@ -292,6 +295,7 @@ def visit_FieldAccess(
region=region,
he_grid=he_grid,
grid_subset=grid_subset,
is_write=is_write,
)
ctx.access_infos[node.name] = access_info.union(
ctx.access_infos.get(node.name, access_info)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -673,11 +673,8 @@ def backward(A: Field[np.float64], B: Field[np.float64], scalar: np.float64):
@pytest.mark.parametrize("backend", ALL_BACKENDS)
def test_K_offset_write_conditional(backend):
# While loop have a bug in `dace:X` backends where
# the read-connector is used which means you can never
# update the field in a while.
# Logged in: https://github.com/GridTools/gt4py/issues/1496
if backend.startswith("dace") or backend == "cuda":
pytest.skip("DaCe backends have a bug when handling while loop.")
if backend == "cuda":
pytest.skip("Cuda backend is not capable of K offset write")

arraylib = _get_array_library(backend)
array_shape = (1, 1, 4)
Expand Down

0 comments on commit 6916121

Please sign in to comment.