Skip to content

Commit

Permalink
Added a unit test for the error.
Browse files Browse the repository at this point in the history
  • Loading branch information
philip-paul-mueller committed Feb 13, 2025
1 parent d916ae5 commit 156f546
Showing 1 changed file with 68 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -852,3 +852,71 @@ def test_loop_blocking_no_independent_nodes():
validate_all=True,
)
assert count == 1


import dace


def _make_only_last_two_elements_sdfg() -> dace.SDFG:
sdfg = dace.SDFG(util.unique_name("simple_block_sdfg"))
state = sdfg.add_state("state", is_start_block=True)
sdfg.add_symbol("N", dace.int32)
sdfg.add_symbol("B", dace.int32)
sdfg.add_symbol("M", dace.int32)

for name in "acb":
sdfg.add_array(
name,
shape=(20, 10),
dtype=dace.float64,
)

state.add_mapped_tasklet(
"computation",
map_ranges={"i": "B:N", "k": "(M-2):M"},
inputs={
"__in1": dace.Memlet("a[i, k]"),
"__in2": dace.Memlet("b[i, k]"),
},
code="__out = __in1 + __in2",
outputs={"__out": dace.Memlet("c[i, k]")},
external_edges=True,
)
sdfg.validate()

return sdfg


def test_only_last_two_elements_sdfg():
sdfg = _make_only_last_two_elements_sdfg()

def ref_comp(a, b, c, B, N, M):
for i in range(B, N):
for k in range(M - 2, M):
c[i, k] = a[i, k] + b[i, k]

count = sdfg.apply_transformations_repeated(
gtx_transformations.LoopBlocking(
blocking_size=1,
blocking_parameter="k",
require_independent_nodes=False,
),
validate=True,
validate_all=True,
)
assert count == 1

ref = {
"a": np.array(np.random.rand(20, 10), dtype=np.float64),
"b": np.array(np.random.rand(20, 10), dtype=np.float64),
"c": np.zeros((20, 10), dtype=np.float64),
"B": 0,
"N": 20,
"M": 6,
}
res = copy.deepcopy(ref)

ref_comp(**ref)
sdfg(**res)

assert np.allclose(ref["c"], res["c"])

0 comments on commit 156f546

Please sign in to comment.