diff --git a/src/gt4py/cartesian/frontend/base.py b/src/gt4py/cartesian/frontend/base.py index 3ba54f3356..5e542cd36d 100644 --- a/src/gt4py/cartesian/frontend/base.py +++ b/src/gt4py/cartesian/frontend/base.py @@ -74,6 +74,7 @@ def generate( externals: Dict[str, Any], dtypes: Dict[Type, Type], options: BuildOptions, + backend_name: str, ) -> gtir.Stencil: """ Generate a StencilDefinition from a stencil Python function. diff --git a/src/gt4py/cartesian/frontend/gtscript_frontend.py b/src/gt4py/cartesian/frontend/gtscript_frontend.py index d21aba674c..962d175eb1 100644 --- a/src/gt4py/cartesian/frontend/gtscript_frontend.py +++ b/src/gt4py/cartesian/frontend/gtscript_frontend.py @@ -708,6 +708,7 @@ def __init__( fields: dict, parameters: dict, local_symbols: dict, + backend_name: str, *, domain: nodes.Domain, temp_decls: Optional[Dict[str, nodes.FieldDecl]] = None, @@ -721,6 +722,7 @@ def __init__( isinstance(value, (type, np.dtype)) for value in local_symbols.values() ) + self.backend_name = backend_name self.fields = fields self.parameters = parameters self.local_symbols = local_symbols @@ -1432,11 +1434,26 @@ def visit_Assign(self, node: ast.Assign) -> list: for t in node.targets[0].elts if isinstance(node.targets[0], ast.Tuple) else node.targets: name, spatial_offset, data_index = self._parse_assign_target(t) if spatial_offset: - if any(offset != 0 for offset in spatial_offset): + if spatial_offset[0] != 0 or spatial_offset[1] != 0: raise GTScriptSyntaxError( - message="Assignment to non-zero offsets is not supported.", + message="Assignment to non-zero offsets is not supported in IJ.", loc=nodes.Location.from_ast_node(t), ) + # Case of K-offset + if len(spatial_offset) == 3 and spatial_offset[2] != 0: + if self.iteration_order == nodes.IterationOrder.PARALLEL: + raise GTScriptSyntaxError( + message="Assignment to non-zero offsets in K is not available in PARALLEL. Choose FORWARD or BACKWARD.", + loc=nodes.Location.from_ast_node(t), + ) + if self.backend_name in ["gt:gpu", "dace:gpu"]: + import cupy as cp + + if cp.cuda.runtime.runtimeGetVersion() < 12000: + raise GTScriptSyntaxError( + message=f"Assignment to non-zero offsets in K is not available in {self.backend_name} for CUDA<12. Please update CUDA.", + loc=nodes.Location.from_ast_node(t), + ) if not self._is_known(name): if name in self.temp_decls: @@ -1997,7 +2014,7 @@ def extract_arg_descriptors(self): return api_signature, fields_decls, parameter_decls - def run(self): + def run(self, backend_name: str): assert ( isinstance(self.ast_root, ast.Module) and "body" in self.ast_root._fields @@ -2055,6 +2072,7 @@ def run(self): fields=fields_decls, parameters=parameter_decls, local_symbols={}, # Not used + backend_name=backend_name, domain=domain, temp_decls=temp_decls, dtypes=self.dtypes, @@ -2110,14 +2128,14 @@ def prepare_stencil_definition(cls, definition, externals): return GTScriptParser.annotate_definition(definition, externals) @classmethod - def generate(cls, definition, externals, dtypes, options): + def generate(cls, definition, externals, dtypes, options, backend_name): if options.build_info is not None: start_time = time.perf_counter() if not hasattr(definition, "_gtscript_"): cls.prepare_stencil_definition(definition, externals) translator = GTScriptParser(definition, externals=externals, dtypes=dtypes, options=options) - definition_ir = translator.run() + definition_ir = translator.run(backend_name) # GTIR only supports LatLonGrids if definition_ir.domain != nodes.Domain.LatLonGrid(): diff --git a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py index d5b1c91466..a8a3a3cb54 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py @@ -321,17 +321,29 @@ def visit_FieldAccess( is_target: bool, targets: Set[eve.SymbolRef], var_offset_fields: Set[eve.SymbolRef], + K_write_with_offset: Set[eve.SymbolRef], **kwargs: Any, ) -> Union[dcir.IndexAccess, dcir.ScalarAccess]: + """Generate the relevant accessor to match the memlet that was previously setup. + + When a Field is written in K, we force the usage of the OUT memlet throughout the stencil + to make sure all side effects are being properly resolved. Frontend checks ensure that no + parallel code issues sips here. + """ + res: Union[dcir.IndexAccess, dcir.ScalarAccess] - if node.name in var_offset_fields: + if node.name in var_offset_fields.union(K_write_with_offset): + # If write in K, we consider the variable to always be a target + is_target = is_target or node.name in targets or node.name in K_write_with_offset + name = get_tasklet_symbol(node.name, node.offset, is_target=is_target) res = dcir.IndexAccess( - name=node.name + "__", + name=name, offset=self.visit( node.offset, - is_target=False, + is_target=is_target, targets=targets, var_offset_fields=var_offset_fields, + K_write_with_offset=K_write_with_offset, **kwargs, ), data_index=node.data_index, @@ -799,11 +811,23 @@ def visit_VerticalLoop( ) ) + # Variable offsets var_offset_fields = { acc.name for acc in node.walk_values().if_isinstance(oir.FieldAccess) if isinstance(acc.offset, oir.VariableKOffset) } + + # We book keep - all write offset to K + K_write_with_offset = set() + for assign_node in node.walk_values().if_isinstance(oir.AssignStmt): + if isinstance(assign_node.left, oir.FieldAccess): + if ( + isinstance(assign_node.left.offset, common.CartesianOffset) + and assign_node.left.offset.k != 0 + ): + K_write_with_offset.add(assign_node.left.name) + sections_idx = next( idx for idx, item in enumerate(global_ctx.library_node.expansion_specification) @@ -821,6 +845,7 @@ def visit_VerticalLoop( iteration_ctx=iteration_ctx, symbol_collector=symbol_collector, var_offset_fields=var_offset_fields, + K_write_with_offset=K_write_with_offset, **kwargs, ) ) diff --git a/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py b/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py index c219667a4a..696dc27387 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/tasklet_codegen.py @@ -81,7 +81,12 @@ def visit_IndexAccess( # if this node is not a target, it will still use the symbol of the write memlet if the # field was previously written in the same memlet. memlets = kwargs["read_memlets"] + kwargs["write_memlets"] - memlet = next(mem for mem in memlets if mem.connector == node.name) + try: + memlet = next(mem for mem in memlets if mem.connector == node.name) + except StopIteration: + raise ValueError( + "Memlet connector and tasklet variable mismatch, DaCe IR error." + ) from None index_strs = [] if node.offset is not None: diff --git a/src/gt4py/cartesian/gtc/dace/utils.py b/src/gt4py/cartesian/gtc/dace/utils.py index 9be2e9a07d..b5c23d2735 100644 --- a/src/gt4py/cartesian/gtc/dace/utils.py +++ b/src/gt4py/cartesian/gtc/dace/utils.py @@ -61,9 +61,9 @@ def get_tasklet_symbol( name: eve.SymbolRef, offset: Union[CartesianOffset, VariableKOffset], is_target: bool ): if is_target: - return f"__{name}" + return f"gtOUT__{name}" - acc_name = name + "__" + acc_name = f"gtIN__{name}" if offset is not None: offset_strs = [] for axis in dcir.Axis.dims_3d(): @@ -230,9 +230,12 @@ def _make_access_info( region, he_grid, grid_subset, + is_write, ) -> dcir.FieldAccessInfo: + # Check we have expression offsets in K + # OR write offsets in K offset = [offset_node.to_dict()[k] for k in "ijk"] - if isinstance(offset_node, oir.VariableKOffset): + if isinstance(offset_node, oir.VariableKOffset) or (offset[2] != 0 and is_write): variable_offset_axes = [dcir.Axis.K] else: variable_offset_axes = [] @@ -291,6 +294,7 @@ def visit_FieldAccess( region=region, he_grid=he_grid, grid_subset=grid_subset, + is_write=is_write, ) ctx.access_infos[node.name] = access_info.union( ctx.access_infos.get(node.name, access_info) diff --git a/src/gt4py/cartesian/stencil_builder.py b/src/gt4py/cartesian/stencil_builder.py index 07d58f25f5..c0f58c0bc9 100644 --- a/src/gt4py/cartesian/stencil_builder.py +++ b/src/gt4py/cartesian/stencil_builder.py @@ -277,7 +277,9 @@ def gtir_pipeline(self) -> GtirPipeline: return self._build_data.get("gtir_pipeline") or self._build_data.setdefault( "gtir_pipeline", GtirPipeline( - self.frontend.generate(self.definition, self.externals, self.dtypes, self.options), + self.frontend.generate( + self.definition, self.externals, self.dtypes, self.options, self.backend.name + ), self.stencil_id, ), ) diff --git a/tests/cartesian_tests/definitions.py b/tests/cartesian_tests/definitions.py index 9ed4e3dfb3..7499ad4a95 100644 --- a/tests/cartesian_tests/definitions.py +++ b/tests/cartesian_tests/definitions.py @@ -15,6 +15,7 @@ import datetime +import numpy as np import pytest from gt4py import cartesian as gt4pyc @@ -54,3 +55,14 @@ def _get_backends_with_storage_info(storage_info_kind: str): @pytest.fixture() def id_version(): return gt_utils.shashed_id(str(datetime.datetime.now())) + + +def get_array_library(backend: str): + """Return device ready array maker library""" + backend_cls = gt4pyc.backend.from_name(backend) + assert backend_cls is not None + if backend_cls.storage_info["device"] == "gpu": + assert cp is not None + return cp + else: + return np diff --git a/tests/cartesian_tests/integration_tests/feature_tests/test_field_layouts.py b/tests/cartesian_tests/integration_tests/feature_tests/test_field_layouts.py index d032e16419..c1b4e58f97 100644 --- a/tests/cartesian_tests/integration_tests/feature_tests/test_field_layouts.py +++ b/tests/cartesian_tests/integration_tests/feature_tests/test_field_layouts.py @@ -6,13 +6,12 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause -import numpy as np import pytest from gt4py import cartesian as gt4pyc, storage as gt_storage from gt4py.cartesian import gtscript -from cartesian_tests.definitions import ALL_BACKENDS, PERFORMANCE_BACKENDS +from cartesian_tests.definitions import ALL_BACKENDS, PERFORMANCE_BACKENDS, get_array_library from cartesian_tests.integration_tests.multi_feature_tests.stencil_definitions import copy_stencil @@ -22,20 +21,10 @@ cp = None -def _get_array_library(backend: str): - backend_cls = gt4pyc.backend.from_name(backend) - assert backend_cls is not None - if backend_cls.storage_info["device"] == "gpu": - assert cp is not None - return cp - else: - return np - - @pytest.mark.parametrize("backend", ALL_BACKENDS) @pytest.mark.parametrize("order", ["C", "F"]) def test_numpy_allocators(backend, order): - xp = _get_array_library(backend) + xp = get_array_library(backend) shape = (20, 10, 5) inp = xp.array(xp.random.randn(*shape), order=order, dtype=xp.float_) outp = xp.zeros(shape=shape, order=order, dtype=xp.float_) @@ -48,7 +37,7 @@ def test_numpy_allocators(backend, order): @pytest.mark.parametrize("backend", PERFORMANCE_BACKENDS) def test_bad_layout_warns(backend): - xp = _get_array_library(backend) + xp = get_array_library(backend) backend_cls = gt4pyc.backend.from_name(backend) assert backend_cls is not None diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py index cb8bb8c5d9..976f9a89af 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_code_generation.py @@ -27,7 +27,7 @@ ) from gt4py.storage.cartesian import utils as storage_utils -from cartesian_tests.definitions import ALL_BACKENDS, CPU_BACKENDS +from cartesian_tests.definitions import ALL_BACKENDS, CPU_BACKENDS, get_array_library from cartesian_tests.integration_tests.multi_feature_tests.stencil_definitions import ( EXTERNALS_REGISTRY as externals_registry, REGISTRY as stencil_definitions, @@ -190,12 +190,20 @@ def stencil( assert field_3d.shape == full_shape[:] field_2d = gt_storage.zeros( - full_shape[:-1], dtype, backend=backend, aligned_index=aligned_index[:-1], dimensions="IJ" + full_shape[:-1], + dtype, + backend=backend, + aligned_index=aligned_index[:-1], + dimensions="IJ", ) assert field_2d.shape == full_shape[:-1] field_1d = gt_storage.ones( - full_shape[-1:], dtype, backend=backend, aligned_index=(aligned_index[-1],), dimensions="K" + full_shape[-1:], + dtype, + backend=backend, + aligned_index=(aligned_index[-1],), + dimensions="K", ) assert list(field_1d.shape) == [full_shape[-1]] @@ -273,7 +281,8 @@ def copy_2to3( def test_lower_dimensional_inputs_2d_to_3d_forward(backend): @gtscript.stencil(backend=backend) def copy_2to3( - inp: gtscript.Field[gtscript.IJ, np.float_], outp: gtscript.Field[gtscript.IJK, np.float_] + inp: gtscript.Field[gtscript.IJ, np.float_], + outp: gtscript.Field[gtscript.IJK, np.float_], ): with computation(FORWARD), interval(...): outp[0, 0, 0] = inp @@ -574,6 +583,125 @@ def test(out: Field[np.float64], inp: Field[np.float64]): test(out, inp) +@pytest.mark.parametrize("backend", ALL_BACKENDS) +def test_K_offset_write(backend): + # Cuda generates bad code for the K offset + if backend == "cuda": + pytest.skip("cuda K-offset write generates bad code") + if backend in ["gt:gpu", "dace:gpu"]: + import cupy as cp + + if cp.cuda.runtime.runtimeGetVersion() < 12000: + pytest.skip( + f"{backend} backend with CUDA 11 and/or GCC 10.3 is not capable of K offset write, update CUDA/GCC if possible" + ) + + arraylib = get_array_library(backend) + array_shape = (1, 1, 4) + K_values = arraylib.arange(start=40, stop=44) + + # Simple case of writing ot an offset. + # A is untouched + # B is written in K+1 and should have K_values, except for the first element (FORWARD) + @gtscript.stencil(backend=backend) + def simple(A: Field[np.float64], B: Field[np.float64]): + with computation(FORWARD), interval(...): + B[0, 0, 1] = A + + A = gt_storage.zeros( + backend=backend, aligned_index=(0, 0, 0), shape=array_shape, dtype=np.float64 + ) + A[:, :, :] = K_values + B = gt_storage.zeros( + backend=backend, aligned_index=(0, 0, 0), shape=array_shape, dtype=np.float64 + ) + simple(A, B) + assert (B[:, :, 0] == 0).all() + assert (B[:, :, 1:3] == K_values[0:2]).all() + + # Order of operations: FORWARD with negative offset + # means while A is update B will have non-updated values of A + # Because of the interval, value of B[0] is 0 + @gtscript.stencil(backend=backend) + def forward(A: Field[np.float64], B: Field[np.float64], scalar: np.float64): + with computation(FORWARD), interval(1, None): + A[0, 0, -1] = scalar + B[0, 0, 0] = A + + A = gt_storage.zeros( + backend=backend, aligned_index=(0, 0, 0), shape=array_shape, dtype=np.float64 + ) + A[:, :, :] = K_values + B = gt_storage.zeros( + backend=backend, aligned_index=(0, 0, 0), shape=array_shape, dtype=np.float64 + ) + forward(A, B, 2.0) + assert (A[:, :, :3] == 2.0).all() + assert (A[:, :, 3] == K_values[3]).all() + assert (B[:, :, 0] == 0).all() + assert (B[:, :, 1:] == K_values[1:]).all() + + # Order of operations: BACKWARD with negative offset + # means A is update B will get the updated values of A + @gtscript.stencil(backend=backend) + def backward(A: Field[np.float64], B: Field[np.float64], scalar: np.float64): + with computation(BACKWARD), interval(1, None): + A[0, 0, -1] = scalar + B[0, 0, 0] = A + + A = gt_storage.zeros( + backend=backend, aligned_index=(0, 0, 0), shape=array_shape, dtype=np.float64 + ) + A[:, :, :] = K_values + B = gt_storage.empty( + backend=backend, aligned_index=(0, 0, 0), shape=array_shape, dtype=np.float64 + ) + backward(A, B, 2.0) + assert (A[:, :, :3] == 2.0).all() + assert (A[:, :, 3] == K_values[3]).all() + assert (B[:, :, :] == A[:, :, :]).all() + + +@pytest.mark.parametrize("backend", ALL_BACKENDS) +def test_K_offset_write_conditional(backend): + if backend == "cuda": + pytest.skip("Cuda backend is not capable of K offset write") + if backend in ["gt:gpu", "dace:gpu"]: + import cupy as cp + + if cp.cuda.runtime.runtimeGetVersion() < 12000: + pytest.skip( + f"{backend} backend with CUDA 11 and/or GCC 10.3 is not capable of K offset write, update CUDA/GCC if possible" + ) + + arraylib = get_array_library(backend) + array_shape = (1, 1, 4) + K_values = arraylib.arange(start=40, stop=44) + + @gtscript.stencil(backend=backend) + def column_physics_conditional(A: Field[np.float64], B: Field[np.float64], scalar: np.float64): + with computation(BACKWARD), interval(1, None): + if A > 0 and B > 0: + A[0, 0, -1] = scalar + B[0, 0, 1] = A + lev = 1 + while A >= 0 and B >= 0: + A[0, 0, lev] = -1 + B = -1 + lev = lev + 1 + + A = gt_storage.zeros( + backend=backend, aligned_index=(0, 0, 0), shape=array_shape, dtype=np.float64 + ) + A[:, :, :] = K_values + B = gt_storage.ones( + backend=backend, aligned_index=(0, 0, 0), shape=array_shape, dtype=np.float64 + ) + column_physics_conditional(A, B, 2.0) + assert (A[0, 0, :] == arraylib.array([2, 2, -1, -1])).all() + assert (B[0, 0, :] == arraylib.array([1, -1, 2, 42])).all() + + @pytest.mark.parametrize("backend", ALL_BACKENDS) def test_direct_datadims_index(backend): F64_VEC4 = (np.float64, (2, 2, 2, 2)) diff --git a/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py b/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py index 1034176789..e62f878746 100644 --- a/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py +++ b/tests/cartesian_tests/unit_tests/frontend_tests/test_gtscript_frontend.py @@ -19,6 +19,7 @@ from gt4py.cartesian.frontend import gtscript_frontend as gt_frontend, nodes from gt4py.cartesian.gtscript import ( __INLINED, + FORWARD, IJ, IJK, PARALLEL, @@ -62,7 +63,7 @@ def parse_definition( ) definition_ir = gt_frontend.GTScriptParser( definition_func, externals=externals or {}, options=build_options, dtypes=dtypes - ).run() + ).run("numpy") setattr(definition_func, "__annotations__", original_annotations) @@ -108,7 +109,9 @@ def definition_func(inout_field: gtscript.Field[float]): with pytest.raises(gt_frontend.GTScriptSymbolError, match=r".*MISSING_CONSTANT.*"): parse_definition( - definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition_func, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) def definition_func(inout_field: gtscript.Field[float]): @@ -116,10 +119,13 @@ def definition_func(inout_field: gtscript.Field[float]): inout_field = inout_field[0, 0, 0] + GLOBAL_NESTED_CONSTANTS.missing with pytest.raises( - gt_frontend.GTScriptDefinitionError, match=r".*GLOBAL_NESTED_CONSTANTS.missing.*" + gt_frontend.GTScriptDefinitionError, + match=r".*GLOBAL_NESTED_CONSTANTS.missing.*", ): parse_definition( - definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition_func, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) def test_recursive_function_imports(self): @@ -200,7 +206,9 @@ def definition_func(inout_field: gtscript.Field[float]): with pytest.raises(gt_frontend.GTScriptDefinitionError, match=r".*WRONG_VALUE_CONSTANT.*"): parse_definition( - definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition_func, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) @@ -215,7 +223,9 @@ def definition_func(inout_field: gtscript.Field[float]): with pytest.raises(TypeError, match=r"func is not a gtscript function"): parse_definition( - definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition_func, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) def test_use_in_expr(self): @@ -290,7 +300,9 @@ def definition_func(inout_field: gtscript.Field[float]): "Please assign the function results to symbols first.", ): parse_definition( - definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition_func, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) def test_use_in_call_arg_multiple_return(self): @@ -316,7 +328,9 @@ def definition_func(inout_field: gtscript.Field[float]): "Please assign the function results to symbols first.", ): parse_definition( - definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition_func, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) def test_recursive_function_call_two_externals(self): @@ -412,7 +426,9 @@ def definition_func(in_field: gtscript.Field[float], out_field: gtscript.Field[f with pytest.raises(gt_frontend.GTScriptSyntaxError): parse_definition( - definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition_func, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) def test_bad_dup_add(self): @@ -422,7 +438,9 @@ def definition_func(in_field: gtscript.Field[float], out_field: gtscript.Field[f with pytest.raises(gt_frontend.GTScriptSyntaxError): parse_definition( - definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition_func, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) def test_bad_dup_axis(self): @@ -432,7 +450,9 @@ def definition_func(in_field: gtscript.Field[float], out_field: gtscript.Field[f with pytest.raises(gt_frontend.GTScriptSyntaxError): parse_definition( - definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition_func, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) def test_bad_out_of_order(self): @@ -442,7 +462,9 @@ def definition_func(in_field: gtscript.Field[float], out_field: gtscript.Field[f with pytest.raises(gt_frontend.GTScriptSyntaxError): parse_definition( - definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition_func, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) @@ -495,7 +517,9 @@ def definition_func(inout_field: gtscript.Field[float]): with pytest.raises(gt_frontend.GTScriptDefinitionError, match=r".*MISSING_CONSTANT.*"): parse_definition( - definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition_func, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) def definition_func(inout_field: gtscript.Field[float]): @@ -613,10 +637,13 @@ def definition_func(field: gtscript.Field[float]): field = 0 with pytest.raises( - gt_frontend.GTScriptSyntaxError, match="Invalid interval range specification" + gt_frontend.GTScriptSyntaxError, + match="Invalid interval range specification", ): parse_definition( - definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition_func, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) def test_error_do_not_mix(self): @@ -626,7 +653,9 @@ def definition_func(field: gtscript.Field[float]): with pytest.raises(gt_frontend.GTScriptSyntaxError, match="Two-argument syntax"): parse_definition( - definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition_func, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) def test_reversed_interval(self): @@ -635,10 +664,13 @@ def definition_func(field: gtscript.Field[float]): field = 0 with pytest.raises( - gt_frontend.GTScriptSyntaxError, match="Invalid interval range specification" + gt_frontend.GTScriptSyntaxError, + match="Invalid interval range specification", ): parse_definition( - definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition_func, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) def test_overlapping_intervals_none(self): @@ -651,7 +683,9 @@ def definition_func(field: gtscript.Field[float]): with pytest.raises(gt_frontend.GTScriptSyntaxError, match="Overlapping intervals"): parse_definition( - definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition_func, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) def test_overlapping_intervals(self): @@ -664,7 +698,9 @@ def definition_func(field: gtscript.Field[float]): with pytest.raises(gt_frontend.GTScriptSyntaxError, match="Overlapping intervals"): parse_definition( - definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition_func, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) def test_nonoverlapping_intervals(self): @@ -806,7 +842,10 @@ def _stage_laplacian_y(dy, phi): @gtscript.function def _stage_laplacian(dx, dy, phi): - from gt4py.cartesian.__externals__ import stage_laplacian_x, stage_laplacian_y + from gt4py.cartesian.__externals__ import ( + stage_laplacian_x, + stage_laplacian_y, + ) lap_x = stage_laplacian_x(dx=dx, phi=phi) lap_y = stage_laplacian_y(dy=dy, phi=phi) @@ -876,10 +915,13 @@ def definition_func(phi: gtscript.Field[np.float64]): phi = test_no_return(phi) with pytest.raises( - gt_frontend.GTScriptSyntaxError, match="should have a single return statement" + gt_frontend.GTScriptSyntaxError, + match="should have a single return statement", ): parse_definition( - definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition_func, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) def test_number_return_args(self): @@ -896,7 +938,9 @@ def definition_func(phi: gtscript.Field[np.float64]): match="Number of returns values does not match arguments on left side", ): parse_definition( - definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition_func, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) def test_multiple_return(self): @@ -910,10 +954,13 @@ def definition_func(phi: gtscript.Field[np.float64]): phi = test_multiple_return(phi) with pytest.raises( - gt_frontend.GTScriptSyntaxError, match="should have a single return statement" + gt_frontend.GTScriptSyntaxError, + match="should have a single return statement", ): parse_definition( - definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition_func, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) def test_conditional_return(self): @@ -1217,7 +1264,9 @@ def definition_func(inout_field: gtscript.Field[float]): with pytest.raises(gt_frontend.GTScriptSyntaxError): parse_definition( - definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition_func, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) @@ -1271,7 +1320,8 @@ def definition_func( ) @pytest.mark.parametrize( - "id_case,test_dtype", list(enumerate([str, np.uint32, np.uint64, dict, map, bytes])) + "id_case,test_dtype", + list(enumerate([str, np.uint32, np.uint64, dict, map, bytes])), ) def test_invalid_inlined_dtypes(self, id_case, test_dtype): with pytest.raises(ValueError, match=r".*data type descriptor.*"): @@ -1285,11 +1335,14 @@ def definition_func( out_field = in_field + param @pytest.mark.parametrize( - "id_case,test_dtype", list(enumerate([str, np.uint32, np.uint64, dict, map, bytes])) + "id_case,test_dtype", + list(enumerate([str, np.uint32, np.uint64, dict, map, bytes])), ) def test_invalid_external_dtypes(self, id_case, test_dtype): def definition_func( - in_field: gtscript.Field["dtype"], out_field: gtscript.Field["dtype"], param: "dtype" + in_field: gtscript.Field["dtype"], + out_field: gtscript.Field["dtype"], + param: "dtype", ): with computation(PARALLEL), interval(...): out_field = in_field + param @@ -1343,7 +1396,10 @@ def func(in_field: gtscript.Field[np.float_], out_field: gtscript.Field[np.float with pytest.raises(gt_frontend.GTScriptSyntaxError): - def func(in_field: gtscript.Field[np.float_], out_field: gtscript.Field[np.float_]): + def func( + in_field: gtscript.Field[np.float_], + out_field: gtscript.Field[np.float_], + ): with computation(PARALLEL), interval(...): out_field[0, 0, 1] = in_field @@ -1364,7 +1420,7 @@ def func(in_field: gtscript.Field[np.float_], out_field: gtscript.Field[np.float with pytest.raises( gt_frontend.GTScriptSyntaxError, - match="Assignment to non-zero offsets is not supported.", + match="Assignment to non-zero offsets in K is not available in PARALLEL. Choose FORWARD or BACKWARD.", ): parse_definition( func, @@ -1403,16 +1459,21 @@ def definition_func( with pytest.raises( gt_frontend.GTScriptSyntaxError, - match="Assignment to non-zero offsets is not supported.", + match="Assignment to non-zero offsets in K is not available in PARALLEL. Choose FORWARD or BACKWARD.", ): parse_definition( - definition_func, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition_func, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) def test_slice(self): with pytest.raises(gt_frontend.GTScriptSyntaxError): - def func(in_field: gtscript.Field[np.float_], out_field: gtscript.Field[np.float_]): + def func( + in_field: gtscript.Field[np.float_], + out_field: gtscript.Field[np.float_], + ): with computation(PARALLEL), interval(...): out_field[:, :, :] = in_field @@ -1421,7 +1482,10 @@ def func(in_field: gtscript.Field[np.float_], out_field: gtscript.Field[np.float def test_string(self): with pytest.raises(gt_frontend.GTScriptSyntaxError): - def func(in_field: gtscript.Field[np.float_], out_field: gtscript.Field[np.float_]): + def func( + in_field: gtscript.Field[np.float_], + out_field: gtscript.Field[np.float_], + ): with computation(PARALLEL), interval(...): out_field["a_key"] = in_field @@ -1437,6 +1501,24 @@ def func(in_field: gtscript.Field[np.float_]): parse_definition(func, name=inspect.stack()[0][3], module=self.__class__.__name__) + def test_K_offset_write(self): + def func(out: gtscript.Field[np.float64], inp: gtscript.Field[np.float64]): + with computation(FORWARD), interval(...): + out[0, 0, 1] = inp + + parse_definition(func, name=inspect.stack()[0][3], module=self.__class__.__name__) + + with pytest.raises( + gt_frontend.GTScriptSyntaxError, + match=r"(.*?)Assignment to non-zero offsets in K is not available in PARALLEL. Choose FORWARD or BACKWARD.(.*)", + ): + + def func(out: gtscript.Field[np.float64], inp: gtscript.Field[np.float64]): + with computation(PARALLEL), interval(...): + out[0, 0, 1] = inp + + parse_definition(func, name=inspect.stack()[0][3], module=self.__class__.__name__) + def test_datadims_direct_access(self): # Check classic data dimensions are working def data_dims( @@ -1499,7 +1581,9 @@ def data_dims_with_at( out_field = global_field.A[1, 0, 2] parse_definition( - data_dims_with_at, name=inspect.stack()[0][3], module=self.__class__.__name__ + data_dims_with_at, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) @@ -1541,7 +1625,9 @@ def definition_bw( match=r"(.*?)Intervals must be specified in order of execution(.*)", ): parse_definition( - definition, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) @@ -1653,7 +1739,11 @@ def sumdiff_defs( @pytest.mark.parametrize("dtype_scalar", [int, np.float32, np.float64]) def test_set_arg_dtypes(self, dtype_in, dtype_out, dtype_scalar): definition = self.sumdiff_defs - dtypes = {"dtype_in": dtype_in, "dtype_out": dtype_out, "dtype_scalar": dtype_scalar} + dtypes = { + "dtype_in": dtype_in, + "dtype_out": dtype_out, + "dtype_scalar": dtype_scalar, + } original_annotations = gtscript._set_arg_dtypes(definition, dtypes) @@ -1701,10 +1791,17 @@ def test_set_arg_dtypes(self, dtype_in, dtype_out, dtype_scalar): @pytest.mark.parametrize("dtype_scalar", [int, np.float32, np.float64]) def test_parsing(self, dtype_in, dtype_out, dtype_scalar): definition = self.sumdiff_defs - dtypes = {"dtype_in": dtype_in, "dtype_out": dtype_out, "dtype_scalar": dtype_scalar} + dtypes = { + "dtype_in": dtype_in, + "dtype_out": dtype_out, + "dtype_scalar": dtype_scalar, + } parse_definition( - definition, dtypes=dtypes, name=inspect.stack()[0][3], module=self.__class__.__name__ + definition, + dtypes=dtypes, + name=inspect.stack()[0][3], + module=self.__class__.__name__, ) annotations = getattr(definition, "__annotations__", {}) diff --git a/tests/cartesian_tests/unit_tests/frontend_tests/test_ir_maker.py b/tests/cartesian_tests/unit_tests/frontend_tests/test_ir_maker.py index e054b7a715..7857ae3e00 100644 --- a/tests/cartesian_tests/unit_tests/frontend_tests/test_ir_maker.py +++ b/tests/cartesian_tests/unit_tests/frontend_tests/test_ir_maker.py @@ -13,7 +13,7 @@ def test_AugAssign(): - ir_maker = IRMaker(None, None, None, domain=None) + ir_maker = IRMaker(None, None, None, None, domain=None) aug_assign = ast.parse("a += 1", feature_version=PYTHON_AST_VERSION).body[0] _, result = ir_maker.visit_AugAssign(aug_assign)