From 6a5ae7ef79b877e6f461560c3f78171aad6e283e Mon Sep 17 00:00:00 2001 From: Hannes Vogt Date: Tue, 23 Apr 2024 17:29:26 +0200 Subject: [PATCH] refactor[next]: move lift_mode itir test fixture into program_processor (#1533) Currently we have a mix of specifying the backend which already comes with a lift_mode default and separately the lift_mode fixture in some tests. The default was not overwritten in the roundtrip backend. Now we remove the separate `lift_mode` and add extra backends with the lift_mode set. Note: we don't run double_roundtrip with temporaries. Longer term we should refactor all itir tests to use the ffront test infrastructure. --------- Co-authored-by: Till Ehrengruber --- docs/user/next/QuickstartGuide.md | 2 +- src/gt4py/next/__init__.py | 2 +- src/gt4py/next/constructors.py | 12 ++--- .../next/iterator/transforms/pass_manager.py | 4 -- .../codegens/gtfn/gtfn_module.py | 20 +------- .../program_processors/formatters/gtfn.py | 1 - .../runners/dace_iterator/__init__.py | 3 -- .../runners/dace_iterator/workflow.py | 17 +------ .../runners/double_roundtrip.py | 6 +-- .../next/program_processors/runners/gtfn.py | 2 - .../program_processors/runners/roundtrip.py | 23 ++++++--- tests/next_tests/definitions.py | 8 +++- .../feature_tests/iterator_tests/test_scan.py | 5 +- .../iterator_tests/test_trivial.py | 14 ++---- .../ffront_tests/test_ffront_fvm_nabla.py | 6 ++- .../iterator_tests/test_anton_toy.py | 15 ++---- .../iterator_tests/test_column_stencil.py | 17 +++---- .../iterator_tests/test_fvm_nabla.py | 29 +++++------- .../iterator_tests/test_hdiff.py | 18 ++----- .../iterator_tests/test_vertical_advection.py | 25 ++++------ .../test_with_toy_connectivity.py | 47 ++++++------------- tests/next_tests/unit_tests/conftest.py | 13 +---- 22 files changed, 92 insertions(+), 197 deletions(-) diff --git a/docs/user/next/QuickstartGuide.md b/docs/user/next/QuickstartGuide.md index f8ff64a980..81604c7770 100644 --- a/docs/user/next/QuickstartGuide.md +++ b/docs/user/next/QuickstartGuide.md @@ -77,7 +77,7 @@ array_of_ones_numpy = np.ones((grid_shape[0], grid_shape[1])) field_of_ones = gtx.ones( domain={I: range(grid_shape[0]), J: range(grid_shape[0])}, dtype=np.float64, - allocator=gtx.program_processors.runners.roundtrip.backend + allocator=gtx.itir_python ) ``` diff --git a/src/gt4py/next/__init__.py b/src/gt4py/next/__init__.py index e79e2f5517..f33b9c5127 100644 --- a/src/gt4py/next/__init__.py +++ b/src/gt4py/next/__init__.py @@ -53,7 +53,7 @@ run_gtfn_cached as gtfn_cpu, run_gtfn_gpu_cached as gtfn_gpu, ) -from .program_processors.runners.roundtrip import backend as itir_python +from .program_processors.runners.roundtrip import default as itir_python __all__ = [ diff --git a/src/gt4py/next/constructors.py b/src/gt4py/next/constructors.py index 43d9cb81b9..18a89ec07a 100644 --- a/src/gt4py/next/constructors.py +++ b/src/gt4py/next/constructors.py @@ -67,9 +67,8 @@ def empty( Initialize a field in one dimension with a backend and a range domain: >>> from gt4py import next as gtx - >>> from gt4py.next.program_processors.runners import roundtrip >>> IDim = gtx.Dimension("I") - >>> a = gtx.empty({IDim: range(3, 10)}, allocator=roundtrip.backend) + >>> a = gtx.empty({IDim: range(3, 10)}, allocator=gtx.itir_python) >>> a.shape (7,) @@ -109,9 +108,8 @@ def zeros( Examples: >>> from gt4py import next as gtx - >>> from gt4py.next.program_processors.runners import roundtrip >>> IDim = gtx.Dimension("I") - >>> gtx.zeros({IDim: range(3, 10)}, allocator=roundtrip.backend).ndarray + >>> gtx.zeros({IDim: range(3, 10)}, allocator=gtx.itir_python).ndarray array([0., 0., 0., 0., 0., 0., 0.]) """ field = empty( @@ -137,9 +135,8 @@ def ones( Examples: >>> from gt4py import next as gtx - >>> from gt4py.next.program_processors.runners import roundtrip >>> IDim = gtx.Dimension("I") - >>> gtx.ones({IDim: range(3, 10)}, allocator=roundtrip.backend).ndarray + >>> gtx.ones({IDim: range(3, 10)}, allocator=gtx.itir_python).ndarray array([1., 1., 1., 1., 1., 1., 1.]) """ field = empty( @@ -171,9 +168,8 @@ def full( Examples: >>> from gt4py import next as gtx - >>> from gt4py.next.program_processors.runners import roundtrip >>> IDim = gtx.Dimension("I") - >>> gtx.full({IDim: 3}, 5, allocator=roundtrip.backend).ndarray + >>> gtx.full({IDim: 3}, 5, allocator=gtx.itir_python).ndarray array([5, 5, 5]) """ field = empty( diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 8a19203275..32b42f8d2b 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -17,7 +17,6 @@ from gt4py.eve import utils as eve_utils from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms import simple_inline_heuristic from gt4py.next.iterator.transforms.collapse_list_get import CollapseListGet from gt4py.next.iterator.transforms.collapse_tuple import CollapseTuple from gt4py.next.iterator.transforms.constant_folding import ConstantFolding @@ -41,14 +40,11 @@ class LiftMode(enum.Enum): FORCE_INLINE = enum.auto() USE_TEMPORARIES = enum.auto() - SIMPLE_HEURISTIC = enum.auto() def _inline_lifts(ir, lift_mode): if lift_mode == LiftMode.FORCE_INLINE: return InlineLifts().visit(ir) - elif lift_mode == LiftMode.SIMPLE_HEURISTIC: - return InlineLifts(simple_inline_heuristic.is_eligible_for_inlining).visit(ir) elif lift_mode == LiftMode.USE_TEMPORARIES: return InlineLifts( flags=InlineLifts.Flag.INLINE_TRIVIAL_DEREF_LIFT diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index d28e513093..52b9376b82 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -16,7 +16,6 @@ import dataclasses import functools -import warnings from typing import Any, Callable, Final, Optional import factory @@ -182,26 +181,13 @@ def _preprocess_program( self, program: itir.FencilDefinition, offset_provider: dict[str, Connectivity | Dimension], - runtime_lift_mode: Optional[LiftMode], ) -> itir.FencilDefinition | global_tmps.FencilWithTemporaries: - # TODO(tehrengruber): Remove `lift_mode` from call interface. It has been implicitly added - # to the interface of all (or at least all of concern) backends, but instead should be - # configured in the backend itself (like it is here), until then we respect the argument - # here and warn the user if it differs from the one configured. - lift_mode = runtime_lift_mode or self.lift_mode - if runtime_lift_mode and runtime_lift_mode != self.lift_mode: - warnings.warn( - f"GTFN Backend was configured for LiftMode `{self.lift_mode!s}`, but " - f"overriden to be {runtime_lift_mode!s} at runtime.", - stacklevel=2, - ) - if not self.enable_itir_transforms: return program apply_common_transforms = functools.partial( pass_manager.apply_common_transforms, - lift_mode=lift_mode, + lift_mode=self.lift_mode, offset_provider=offset_provider, # sid::composite (via hymap) supports assigning from tuple with more elements to tuple with fewer elements unconditionally_collapse_tuples=True, @@ -228,9 +214,8 @@ def generate_stencil_source( program: itir.FencilDefinition, offset_provider: dict[str, Connectivity | Dimension], column_axis: Optional[common.Dimension], - runtime_lift_mode: Optional[LiftMode] = None, ) -> str: - new_program = self._preprocess_program(program, offset_provider, runtime_lift_mode) + new_program = self._preprocess_program(program, offset_provider) program_itir = fencil_to_program.FencilToProgram().apply( new_program ) # TODO(havogt): should be removed after refactoring to combined IR @@ -278,7 +263,6 @@ def __call__( program, inp.kwargs["offset_provider"], inp.kwargs.get("column_axis", None), - inp.kwargs.get("lift_mode", None), ) source_code = interface.format_source( self._language_settings(), diff --git a/src/gt4py/next/program_processors/formatters/gtfn.py b/src/gt4py/next/program_processors/formatters/gtfn.py index 6c8d4478c2..632974a787 100644 --- a/src/gt4py/next/program_processors/formatters/gtfn.py +++ b/src/gt4py/next/program_processors/formatters/gtfn.py @@ -29,5 +29,4 @@ def format_cpp(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> str program, offset_provider=kwargs.get("offset_provider", None), column_axis=kwargs.get("column_axis", None), - runtime_lift_mode=kwargs.get("lift_mode", None), ) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 2b56dc0420..2bbf068d53 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -248,9 +248,6 @@ def build_sdfg_from_itir( load_sdfg_from_file: Allows to read the SDFG from file, instead of generating it, for debug only. save_sdfg: If `True`, the default the SDFG is stored as a file and can be loaded, this allows to skip the lowering step, requires `load_sdfg_from_file` set to `True`. use_field_canonical_representation: If `True`, assume that the fields dimensions are sorted alphabetically. - - Notes: - Currently only the `FORCE_INLINE` liftmode is supported and the value of `lift_mode` is ignored. """ sdfg_filename = f"_dacegraphs/gt4py/{program.id}.sdfg" diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py b/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py index 96a40e7450..1a7a36b8c5 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/workflow.py @@ -15,7 +15,6 @@ from __future__ import annotations import dataclasses -import warnings from typing import Callable, Optional, cast import dace @@ -62,22 +61,9 @@ def generate_sdfg( arg_types: list[ts.TypeSpec], offset_provider: dict[str, common.Dimension | common.Connectivity], column_axis: Optional[common.Dimension], - runtime_lift_mode: Optional[LiftMode] = None, ) -> dace.SDFG: on_gpu = True if self.device_type == core_defs.DeviceType.CUDA else False - # TODO(tehrengruber): Remove `lift_mode` from call interface. It has been implicitly added - # to the interface of all (or at least all of concern) backends, but instead should be - # configured in the backend itself (like it is here), until then we respect the argument - # here and warn the user if it differs from the one configured. - lift_mode = runtime_lift_mode or self.lift_mode - if runtime_lift_mode and runtime_lift_mode != self.lift_mode: - warnings.warn( - f"DaCe Backend was configured for LiftMode `{self.lift_mode!s}`, but " - f"overriden to be {runtime_lift_mode!s} at runtime.", - stacklevel=2, - ) - return build_sdfg_from_itir( program, arg_types, @@ -85,7 +71,7 @@ def generate_sdfg( auto_optimize=self.auto_optimize, on_gpu=on_gpu, column_axis=column_axis, - lift_mode=lift_mode, + lift_mode=self.lift_mode, symbolic_domain_sizes=self.symbolic_domain_sizes, temporary_extraction_heuristics=self.temporary_extraction_heuristics, load_sdfg_from_file=False, @@ -105,7 +91,6 @@ def __call__( arg_types, inp.kwargs["offset_provider"], inp.kwargs.get("column_axis", None), - inp.kwargs.get("lift_mode", None), ) param_types = tuple( diff --git a/src/gt4py/next/program_processors/runners/double_roundtrip.py b/src/gt4py/next/program_processors/runners/double_roundtrip.py index e6220ea879..0b0f71c2f7 100644 --- a/src/gt4py/next/program_processors/runners/double_roundtrip.py +++ b/src/gt4py/next/program_processors/runners/double_roundtrip.py @@ -19,8 +19,6 @@ backend = next_backend.Backend( - executor=roundtrip.RoundtripExecutorFactory( - dispatch_backend=roundtrip.RoundtripExecutorFactory() - ), - allocator=roundtrip.backend.allocator, + executor=roundtrip.RoundtripExecutorFactory(dispatch_backend=roundtrip.default.executor), + allocator=roundtrip.default.allocator, ) diff --git a/src/gt4py/next/program_processors/runners/gtfn.py b/src/gt4py/next/program_processors/runners/gtfn.py index 39ec607323..f88dcc5825 100644 --- a/src/gt4py/next/program_processors/runners/gtfn.py +++ b/src/gt4py/next/program_processors/runners/gtfn.py @@ -109,8 +109,6 @@ def compilation_hash(otf_closure: stages.ProgramCall) -> int: content_hash(tuple(from_value(arg) for arg in otf_closure.args)), id(offset_provider) if offset_provider else None, otf_closure.kwargs.get("column_axis", None), - # TODO(tehrengruber): Remove `lift_mode` from call interface. - otf_closure.kwargs.get("lift_mode", None), ) ) diff --git a/src/gt4py/next/program_processors/runners/roundtrip.py b/src/gt4py/next/program_processors/runners/roundtrip.py index aba2dbf71b..d90e7e5c8f 100644 --- a/src/gt4py/next/program_processors/runners/roundtrip.py +++ b/src/gt4py/next/program_processors/runners/roundtrip.py @@ -26,7 +26,7 @@ from gt4py.eve import codegen from gt4py.eve.codegen import FormatTemplate as as_fmt, MakoTemplate as as_mako -from gt4py.next import allocators as next_allocators, backend as next_backend, common +from gt4py.next import allocators as next_allocators, backend as next_backend, common, config from gt4py.next.iterator import embedded, ir as itir, transforms as itir_transforms from gt4py.next.iterator.transforms import global_tmps as gtmps_transform from gt4py.next.otf import stages, workflow @@ -103,8 +103,6 @@ def visit_Temporary(self, node: itir.Temporary, **kwargs: Any) -> str: return f"{node.id} = {_create_tmp(axes, origin, shape, node.dtype)}" -_BACKEND_NAME = "roundtrip" - _FENCIL_CACHE: dict[int, Callable] = {} @@ -131,6 +129,8 @@ def fencil_generator( # caching mechanism cache_key = hash((ir, lift_mode, debug, use_embedded, tuple(offset_provider.items()))) if cache_key in _FENCIL_CACHE: + if debug: + print(f"Using cached fencil for key {cache_key}") return _FENCIL_CACHE[cache_key] ir = itir_transforms.apply_common_transforms( @@ -209,10 +209,11 @@ def execute_roundtrip( *args: Any, column_axis: Optional[common.Dimension] = None, offset_provider: dict[str, embedded.NeighborTableOffsetProvider], - debug: bool = False, + debug: Optional[bool] = None, lift_mode: itir_transforms.LiftMode = itir_transforms.LiftMode.FORCE_INLINE, dispatch_backend: Optional[ppi.ProgramExecutor] = None, ) -> None: + debug = debug if debug is not None else config.DEBUG fencil = fencil_generator( ir, offset_provider=offset_provider, @@ -230,15 +231,16 @@ def execute_roundtrip( @dataclasses.dataclass(frozen=True) class Roundtrip(workflow.Workflow[stages.ProgramCall, stages.CompiledProgram]): - debug: bool = False + debug: Optional[bool] = None lift_mode: itir_transforms.LiftMode = itir_transforms.LiftMode.FORCE_INLINE use_embedded: bool = True def __call__(self, inp: stages.ProgramCall) -> stages.CompiledProgram: + debug = config.DEBUG if self.debug is None else self.debug return fencil_generator( inp.program, offset_provider=inp.kwargs.get("offset_provider", None), - debug=self.debug, + debug=debug, lift_mode=self.lift_mode, use_embedded=self.use_embedded, ) @@ -275,7 +277,14 @@ class Params: executor = RoundtripExecutorFactory(name="roundtrip") +executor_with_temporaries = RoundtripExecutorFactory( + name="roundtrip_with_temporaries", + roundtrip_workflow=RoundtripFactory(lift_mode=itir_transforms.LiftMode.USE_TEMPORARIES), +) -backend = next_backend.Backend( +default = next_backend.Backend( executor=executor, allocator=next_allocators.StandardCPUFieldBufferAllocator() ) +with_temporaries = next_backend.Backend( + executor=executor_with_temporaries, allocator=next_allocators.StandardCPUFieldBufferAllocator() +) diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py index 86eac69712..c7573aa8f3 100644 --- a/tests/next_tests/definitions.py +++ b/tests/next_tests/definitions.py @@ -54,7 +54,8 @@ class ProgramBackendId(_PythonObjectIdMixin, str, enum.Enum): "gt4py.next.program_processors.runners.gtfn.run_gtfn_with_temporaries" ) GTFN_GPU = "gt4py.next.program_processors.runners.gtfn.run_gtfn_gpu" - ROUNDTRIP = "gt4py.next.program_processors.runners.roundtrip.backend" + ROUNDTRIP = "gt4py.next.program_processors.runners.roundtrip.default" + ROUNDTRIP_WITH_TEMPORARIES = "gt4py.next.program_processors.runners.roundtrip.with_temporaries" DOUBLE_ROUNDTRIP = "gt4py.next.program_processors.runners.double_roundtrip.backend" @@ -192,4 +193,9 @@ class ProgramFormatterId(_PythonObjectIdMixin, str, enum.Enum): (USES_REDUCTION_WITH_ONLY_SPARSE_FIELDS, XFAIL, REDUCTION_WITH_ONLY_SPARSE_FIELDS_MESSAGE) ], ProgramBackendId.ROUNDTRIP: [(USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE)], + ProgramBackendId.ROUNDTRIP_WITH_TEMPORARIES: [ + (USES_SPARSE_FIELDS_AS_OUTPUT, XFAIL, UNSUPPORTED_MESSAGE), + (USES_DYNAMIC_OFFSETS, XFAIL, UNSUPPORTED_MESSAGE), + (USES_STRIDED_NEIGHBOR_OFFSET, XFAIL, UNSUPPORTED_MESSAGE), + ], } diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py index fce1aa3960..ef38e23e60 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_scan.py @@ -20,11 +20,11 @@ from gt4py.next.iterator.runtime import fundef, offset from next_tests.integration_tests.cases import IDim, KDim -from next_tests.unit_tests.conftest import lift_mode, program_processor, run_processor +from next_tests.unit_tests.conftest import program_processor, run_processor @pytest.mark.uses_index_fields -def test_scan_in_stencil(program_processor, lift_mode): +def test_scan_in_stencil(program_processor): program_processor, validate = program_processor isize = 1 @@ -54,7 +54,6 @@ def wrapped(inp): program_processor, inp, out=out, - lift_mode=lift_mode, offset_provider={"Koff": KDim}, column_axis=KDim, ) diff --git a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py index f85b9b4035..7bb023aabb 100644 --- a/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py +++ b/tests/next_tests/integration_tests/feature_tests/iterator_tests/test_trivial.py @@ -21,7 +21,7 @@ from gt4py.next.iterator.runtime import closure, fendef, fundef, offset from next_tests.integration_tests.cases import IDim, JDim, KDim -from next_tests.unit_tests.conftest import lift_mode, program_processor, run_processor +from next_tests.unit_tests.conftest import program_processor, run_processor I = offset("I") @@ -44,7 +44,7 @@ def baz(baz_inp): return deref(lift(bar)(baz_inp)) -def test_trivial(program_processor, lift_mode): +def test_trivial(program_processor): program_processor, validate = program_processor rng = np.random.default_rng() @@ -60,7 +60,6 @@ def test_trivial(program_processor, lift_mode): program_processor, inp_s, out=out_s, - lift_mode=lift_mode, offset_provider={"I": IDim, "J": JDim}, ) @@ -73,12 +72,9 @@ def stencil_shifted_arg_to_lift(inp): return deref(lift(deref)(shift(I, -1)(inp))) -def test_shifted_arg_to_lift(program_processor, lift_mode): +def test_shifted_arg_to_lift(program_processor): program_processor, validate = program_processor - if lift_mode != transforms.LiftMode.FORCE_INLINE: - pytest.xfail("shifted input arguments not supported for lift_mode != LiftMode.FORCE_INLINE") - rng = np.random.default_rng() inp = rng.uniform(size=(5, 7)) out = np.zeros_like(inp) @@ -95,7 +91,6 @@ def test_shifted_arg_to_lift(program_processor, lift_mode): program_processor, inp_s, out=out_s, - lift_mode=lift_mode, offset_provider={"I": IDim, "J": JDim}, ) @@ -113,7 +108,7 @@ def fen_direct_deref(i_size, j_size, out, inp): ) -def test_direct_deref(program_processor, lift_mode): +def test_direct_deref(program_processor): program_processor, validate = program_processor rng = np.random.default_rng() @@ -129,7 +124,6 @@ def test_direct_deref(program_processor, lift_mode): *out.shape, out_s, inp_s, - lift_mode=lift_mode, offset_provider=dict(), ) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py index aeed607b01..bdb50f27ff 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_ffront_fvm_nabla.py @@ -87,7 +87,9 @@ def test_ffront_compute_zavgS(exec_alloc_descriptor): atlas_utils.AtlasTable(setup.edges2node_connectivity).asnumpy(), Edge, Vertex, 2, False ) - compute_zavgS.with_backend(executor)(pp, S_M[0], out=zavgS, offset_provider={"E2V": e2v}) + compute_zavgS.with_backend(exec_alloc_descriptor)( + pp, S_M[0], out=zavgS, offset_provider={"E2V": e2v} + ) assert_close(-199755464.25741270, np.min(zavgS.asnumpy())) assert_close(388241977.58389181, np.max(zavgS.asnumpy())) @@ -113,7 +115,7 @@ def test_ffront_nabla(exec_alloc_descriptor): atlas_utils.AtlasTable(setup.nodes2edge_connectivity).asnumpy(), Vertex, Edge, 7 ) - pnabla.with_backend(executor)( + pnabla.with_backend(exec_alloc_descriptor)( pp, S_M, sign, vol, out=(pnabla_MXX, pnabla_MYY), offset_provider={"E2V": e2v, "V2E": v2e} ) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py index bcea9e0901..445255f391 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_anton_toy.py @@ -20,7 +20,7 @@ from gt4py.next.iterator.runtime import closure, fendef, fundef, offset from gt4py.next.program_processors.runners import gtfn -from next_tests.unit_tests.conftest import lift_mode, program_processor, run_processor +from next_tests.unit_tests.conftest import program_processor, run_processor @fundef @@ -79,18 +79,13 @@ def naive_lap(inp): @pytest.mark.uses_origin -def test_anton_toy(program_processor, lift_mode): +def test_anton_toy(program_processor): program_processor, validate = program_processor if program_processor in [ - gtfn.run_gtfn.executor, - gtfn.run_gtfn_imperative.executor, gtfn.run_gtfn_with_temporaries.executor, ]: - from gt4py.next.iterator import transforms - - if lift_mode != transforms.LiftMode.FORCE_INLINE: - pytest.xfail("TODO: issue with temporaries that crashes the application") + pytest.xfail("TODO: issue with temporaries that crashes the application") shape = [5, 7, 9] rng = np.random.default_rng() @@ -102,9 +97,7 @@ def test_anton_toy(program_processor, lift_mode): out = gtx.as_field([IDim, JDim, KDim], np.zeros(shape)) ref = naive_lap(inp) - run_processor( - fencil, program_processor, shape[0], shape[1], shape[2], out, inp, lift_mode=lift_mode - ) + run_processor(fencil, program_processor, shape[0], shape[1], shape[2], out, inp) if validate: assert np.allclose(out.asnumpy(), ref) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py index c7c8cf6c57..7f6caa9de0 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_column_stencil.py @@ -21,7 +21,7 @@ from gt4py.next.iterator.runtime import closure, fendef, fundef, offset from next_tests.integration_tests.cases import IDim, KDim -from next_tests.unit_tests.conftest import lift_mode, program_processor, run_processor +from next_tests.unit_tests.conftest import program_processor, run_processor I = offset("I") @@ -78,7 +78,7 @@ def basic_stencils(request): @pytest.mark.uses_origin -def test_basic_column_stencils(program_processor, lift_mode, basic_stencils): +def test_basic_column_stencils(program_processor, basic_stencils): program_processor, validate = program_processor stencil, ref_fun, inp_fun = basic_stencils @@ -99,7 +99,6 @@ def test_basic_column_stencils(program_processor, lift_mode, basic_stencils): out=out, offset_provider={"I": IDim, "K": KDim}, column_axis=KDim, - lift_mode=lift_mode, ) if validate: @@ -153,7 +152,7 @@ def k_level_condition_upper_tuple(k_idx, k_level): ], ) @pytest.mark.uses_tuple_args -def test_k_level_condition(program_processor, lift_mode, fun, k_level, inp_function, ref_function): +def test_k_level_condition(program_processor, fun, k_level, inp_function, ref_function): program_processor, validate = program_processor k_size = 5 @@ -170,7 +169,6 @@ def test_k_level_condition(program_processor, lift_mode, fun, k_level, inp_funct out=out, offset_provider={"K": KDim}, column_axis=KDim, - lift_mode=lift_mode, ) if validate: @@ -201,7 +199,7 @@ def ksum_fencil(i_size, k_start, k_end, inp, out): "kstart, reference", [(0, np.asarray([[0, 1, 3, 6, 10, 15, 21]])), (2, np.asarray([[0, 0, 2, 5, 9, 14, 20]]))], ) -def test_ksum_scan(program_processor, lift_mode, kstart, reference): +def test_ksum_scan(program_processor, kstart, reference): program_processor, validate = program_processor shape = [1, 7] inp = gtx.as_field([IDim, KDim], np.array(np.broadcast_to(np.arange(0.0, 7.0), shape))) @@ -216,7 +214,6 @@ def test_ksum_scan(program_processor, lift_mode, kstart, reference): inp, out, offset_provider={"I": IDim, "K": KDim}, - lift_mode=lift_mode, ) if validate: @@ -238,7 +235,7 @@ def ksum_back_fencil(i_size, k_size, inp, out): ) -def test_ksum_back_scan(program_processor, lift_mode): +def test_ksum_back_scan(program_processor): program_processor, validate = program_processor shape = [1, 7] inp = gtx.as_field([IDim, KDim], np.array(np.broadcast_to(np.arange(0.0, 7.0), shape))) @@ -254,7 +251,6 @@ def test_ksum_back_scan(program_processor, lift_mode): inp, out, offset_provider={"I": IDim, "K": KDim}, - lift_mode=lift_mode, ) if validate: @@ -300,7 +296,7 @@ def kdoublesum_fencil(i_size, k_start, k_end, inp0, inp1, out): ), ], ) -def test_kdoublesum_scan(program_processor, lift_mode, kstart, reference): +def test_kdoublesum_scan(program_processor, kstart, reference): program_processor, validate = program_processor pytest.xfail("structured dtype input/output currently unsupported") shape = [1, 7] @@ -321,7 +317,6 @@ def test_kdoublesum_scan(program_processor, lift_mode, kstart, reference): inp1, out, offset_provider={"I": IDim, "K": KDim}, - lift_mode=lift_mode, ) if validate: diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py index d29ef68d4e..f83430ef9f 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_fvm_nabla.py @@ -42,7 +42,7 @@ assert_close, nabla_setup, ) -from next_tests.unit_tests.conftest import lift_mode, program_processor, run_processor +from next_tests.unit_tests.conftest import program_processor, run_processor Vertex = gtx.Dimension("Vertex") @@ -116,7 +116,7 @@ def nabla(n_nodes, out, pp, S_MXX, S_MYY, sign, vol): @pytest.mark.requires_atlas -def test_compute_zavgS(program_processor, lift_mode): +def test_compute_zavgS(program_processor): program_processor, validate = program_processor setup = nabla_setup() @@ -137,7 +137,6 @@ def test_compute_zavgS(program_processor, lift_mode): pp, S_MXX, offset_provider={"E2V": e2v}, - lift_mode=lift_mode, ) if validate: @@ -152,7 +151,6 @@ def test_compute_zavgS(program_processor, lift_mode): pp, S_MYY, offset_provider={"E2V": e2v}, - lift_mode=lift_mode, ) if validate: assert_close(-1000788897.3202186, np.min(zavgS.asnumpy())) @@ -165,7 +163,7 @@ def compute_zavgS2_fencil(n_edges, out, pp, S_M): @pytest.mark.requires_atlas -def test_compute_zavgS2(program_processor, lift_mode): +def test_compute_zavgS2(program_processor): program_processor, validate = program_processor setup = nabla_setup() @@ -190,7 +188,6 @@ def test_compute_zavgS2(program_processor, lift_mode): pp, S, offset_provider={"E2V": e2v}, - lift_mode=lift_mode, ) if validate: @@ -202,10 +199,9 @@ def test_compute_zavgS2(program_processor, lift_mode): @pytest.mark.requires_atlas -def test_nabla(program_processor, lift_mode): +def test_nabla(program_processor): program_processor, validate = program_processor - if lift_mode != LiftMode.FORCE_INLINE: - pytest.xfail("shifted input arguments not supported for lift_mode != LiftMode.FORCE_INLINE") + setup = nabla_setup() sign = gtx.as_field([Vertex, V2EDim], setup.sign_field) @@ -234,7 +230,6 @@ def test_nabla(program_processor, lift_mode): sign, vol, offset_provider={"E2V": e2v, "V2E": v2e}, - lift_mode=lift_mode, ) if validate: @@ -255,7 +250,7 @@ def nabla2(n_nodes, out, pp, S, sign, vol): @pytest.mark.requires_atlas -def test_nabla2(program_processor, lift_mode): +def test_nabla2(program_processor): program_processor, validate = program_processor setup = nabla_setup() @@ -274,7 +269,9 @@ def test_nabla2(program_processor, lift_mode): AtlasTable(setup.nodes2edge_connectivity), Vertex, Edge, 7 ) - nabla2( + run_processor( + nabla2, + program_processor, setup.nodes_size, (pnabla_MXX, pnabla_MYY), pp, @@ -282,8 +279,6 @@ def test_nabla2(program_processor, lift_mode): sign, vol, offset_provider={"E2V": e2v, "V2E": v2e}, - program_processor=program_processor, - lift_mode=lift_mode, ) if validate: @@ -334,10 +329,9 @@ def nabla_sign(n_nodes, out_MXX, out_MYY, pp, S_MXX, S_MYY, vol, node_index, is_ @pytest.mark.requires_atlas -def test_nabla_sign(program_processor, lift_mode): +def test_nabla_sign(program_processor): program_processor, validate = program_processor - if lift_mode != LiftMode.FORCE_INLINE: - pytest.xfail("test is broken due to bad lift semantics in iterator IR") + setup = nabla_setup() is_pole_edge = gtx.as_field([Edge], setup.is_pole_edge_field) @@ -368,7 +362,6 @@ def test_nabla_sign(program_processor, lift_mode): gtx.index_field(Vertex), is_pole_edge, offset_provider={"E2V": e2v, "V2E": v2e}, - lift_mode=lift_mode, ) if validate: diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py index 5d369c3a8f..3abdd7cd5a 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_hdiff.py @@ -24,7 +24,7 @@ from next_tests.integration_tests.multi_feature_tests.iterator_tests.hdiff_reference import ( hdiff_reference, ) -from next_tests.unit_tests.conftest import lift_mode, program_processor, run_processor +from next_tests.unit_tests.conftest import program_processor, run_processor I = offset("I") @@ -72,18 +72,8 @@ def hdiff(inp, coeff, out, x, y): @pytest.mark.uses_origin -def test_hdiff(hdiff_reference, program_processor, lift_mode): +def test_hdiff(hdiff_reference, program_processor): program_processor, validate = program_processor - if program_processor in [ - gtfn.run_gtfn.executor, - gtfn.run_gtfn_imperative.executor, - gtfn.run_gtfn_with_temporaries.executor, - ]: - # TODO(tehrengruber): check if still true - from gt4py.next.iterator import transforms - - if lift_mode != transforms.LiftMode.FORCE_INLINE: - pytest.xfail("Temporaries are not compatible with origins.") inp, coeff, out = hdiff_reference shape = (out.shape[0], out.shape[1]) @@ -92,9 +82,7 @@ def test_hdiff(hdiff_reference, program_processor, lift_mode): coeff_s = gtx.as_field([IDim, JDim], coeff[:, :, 0]) out_s = gtx.as_field([IDim, JDim], np.zeros_like(coeff[:, :, 0])) - run_processor( - hdiff, program_processor, inp_s, coeff_s, out_s, shape[0], shape[1], lift_mode=lift_mode - ) + run_processor(hdiff, program_processor, inp_s, coeff_s, out_s, shape[0], shape[1]) if validate: assert np.allclose(out[:, :, 0], out_s.asnumpy()) diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py index 820e9415bc..921d1ae116 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_vertical_advection.py @@ -23,7 +23,7 @@ from gt4py.next.program_processors.runners import gtfn from next_tests.integration_tests.cases import IDim, JDim, KDim -from next_tests.unit_tests.conftest import lift_mode, program_processor, run_processor +from next_tests.unit_tests.conftest import program_processor, run_processor @fundef @@ -110,23 +110,15 @@ def fen_solve_tridiag2(i_size, j_size, k_size, a, b, c, d, x): @pytest.mark.parametrize("fencil", [fen_solve_tridiag, fen_solve_tridiag2]) @pytest.mark.uses_lift_expressions -def test_tridiag(fencil, tridiag_reference, program_processor, lift_mode): +def test_tridiag(fencil, tridiag_reference, program_processor): program_processor, validate = program_processor - if ( - program_processor - in [ - gtfn.run_gtfn.executor, - gtfn.run_gtfn_imperative.executor, - gtfn.run_gtfn_with_temporaries.executor, - gtfn_formatters.format_cpp, - ] - and lift_mode == LiftMode.FORCE_INLINE - ): + if program_processor in [ + gtfn.run_gtfn.executor, + gtfn.run_gtfn_imperative.executor, + gtfn_formatters.format_cpp, + ]: pytest.skip("gtfn does only support lifted scans when using temporaries") - if ( - program_processor == gtfn.run_gtfn_with_temporaries.executor - or lift_mode == LiftMode.USE_TEMPORARIES - ): + if program_processor == gtfn.run_gtfn_with_temporaries.executor: pytest.xfail("tuple_get on columns not supported.") a, b, c, d, x = tridiag_reference shape = a.shape @@ -150,7 +142,6 @@ def test_tridiag(fencil, tridiag_reference, program_processor, lift_mode): x_s, offset_provider={}, column_axis=KDim, - lift_mode=lift_mode, ) if validate: diff --git a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py index 714e568b8f..dfd54debb6 100644 --- a/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py +++ b/tests/next_tests/integration_tests/multi_feature_tests/iterator_tests/test_with_toy_connectivity.py @@ -48,7 +48,7 @@ v2e_arr, v2v_arr, ) -from next_tests.unit_tests.conftest import lift_mode, program_processor, run_processor +from next_tests.unit_tests.conftest import program_processor, run_processor def edge_index_field(): # TODO replace by gtx.index_field once supported in bindings @@ -84,7 +84,7 @@ def sum_edges_to_vertices_reduce(in_edges): "stencil", [sum_edges_to_vertices, sum_edges_to_vertices_list_get_neighbors, sum_edges_to_vertices_reduce], ) -def test_sum_edges_to_vertices(program_processor, lift_mode, stencil): +def test_sum_edges_to_vertices(program_processor, stencil): program_processor, validate = program_processor inp = edge_index_field() out = gtx.as_field([Vertex], np.zeros([9], dtype=inp.dtype)) @@ -96,7 +96,6 @@ def test_sum_edges_to_vertices(program_processor, lift_mode, stencil): inp, out=out, offset_provider={"V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4)}, - lift_mode=lift_mode, ) if validate: assert np.allclose(out.asnumpy(), ref) @@ -107,7 +106,7 @@ def map_neighbors(in_edges): return reduce(plus, 0)(map_(plus)(neighbors(V2E, in_edges), neighbors(V2E, in_edges))) -def test_map_neighbors(program_processor, lift_mode): +def test_map_neighbors(program_processor): program_processor, validate = program_processor inp = edge_index_field() out = gtx.as_field([Vertex], np.zeros([9], dtype=inp.dtype)) @@ -119,7 +118,6 @@ def test_map_neighbors(program_processor, lift_mode): inp, out=out, offset_provider={"V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4)}, - lift_mode=lift_mode, ) if validate: assert np.allclose(out.asnumpy(), ref) @@ -131,7 +129,7 @@ def map_make_const_list(in_edges): @pytest.mark.uses_constant_fields -def test_map_make_const_list(program_processor, lift_mode): +def test_map_make_const_list(program_processor): program_processor, validate = program_processor inp = edge_index_field() out = gtx.as_field([Vertex], np.zeros([9], inp.dtype)) @@ -143,7 +141,6 @@ def test_map_make_const_list(program_processor, lift_mode): inp, out=out, offset_provider={"V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4)}, - lift_mode=lift_mode, ) if validate: assert np.allclose(out.asnumpy(), ref) @@ -154,7 +151,7 @@ def first_vertex_neigh_of_first_edge_neigh_of_cells(in_vertices): return deref(shift(E2V, 0)(shift(C2E, 0)(in_vertices))) -def test_first_vertex_neigh_of_first_edge_neigh_of_cells_fencil(program_processor, lift_mode): +def test_first_vertex_neigh_of_first_edge_neigh_of_cells_fencil(program_processor): program_processor, validate = program_processor inp = vertex_index_field() out = gtx.as_field([Cell], np.zeros([9], dtype=inp.dtype)) @@ -169,7 +166,6 @@ def test_first_vertex_neigh_of_first_edge_neigh_of_cells_fencil(program_processo "E2V": gtx.NeighborTableOffsetProvider(e2v_arr, Edge, Vertex, 2), "C2E": gtx.NeighborTableOffsetProvider(c2e_arr, Cell, Edge, 4), }, - lift_mode=lift_mode, ) if validate: assert np.allclose(out.asnumpy(), ref) @@ -180,7 +176,7 @@ def sparse_stencil(non_sparse, inp): return reduce(lambda a, b, c: a + c, 0)(neighbors(V2E, non_sparse), deref(inp)) -def test_sparse_input_field(program_processor, lift_mode): +def test_sparse_input_field(program_processor): program_processor, validate = program_processor non_sparse = gtx.as_field([Edge], np.zeros(18, dtype=np.int32)) @@ -196,14 +192,13 @@ def test_sparse_input_field(program_processor, lift_mode): inp, out=out, offset_provider={"V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4)}, - lift_mode=lift_mode, ) if validate: assert np.allclose(out.asnumpy(), ref) -def test_sparse_input_field_v2v(program_processor, lift_mode): +def test_sparse_input_field_v2v(program_processor): program_processor, validate = program_processor non_sparse = gtx.as_field([Edge], np.zeros(18, dtype=np.int32)) @@ -222,7 +217,6 @@ def test_sparse_input_field_v2v(program_processor, lift_mode): "V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4), "V2E": gtx.NeighborTableOffsetProvider(v2e_arr, Vertex, Edge, 4), }, - lift_mode=lift_mode, ) if validate: @@ -235,7 +229,7 @@ def slice_sparse_stencil(sparse): @pytest.mark.uses_sparse_fields -def test_slice_sparse(program_processor, lift_mode): +def test_slice_sparse(program_processor): program_processor, validate = program_processor inp = gtx.as_field([Vertex, V2VDim], v2v_arr) out = gtx.as_field([Vertex], np.zeros([9], dtype=inp.dtype)) @@ -248,7 +242,6 @@ def test_slice_sparse(program_processor, lift_mode): inp, out=out, offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, - lift_mode=lift_mode, ) if validate: @@ -261,7 +254,7 @@ def slice_twice_sparse_stencil(sparse): @pytest.mark.xfail(reason="Field with more than one sparse dimension is not implemented.") -def test_slice_twice_sparse(program_processor, lift_mode): +def test_slice_twice_sparse(program_processor): program_processor, validate = program_processor inp = gtx.as_field([Vertex, V2VDim, V2VDim], v2v_arr[v2v_arr]) out = gtx.as_field([Vertex], np.zeros([9], dtype=inp.dtype)) @@ -273,7 +266,6 @@ def test_slice_twice_sparse(program_processor, lift_mode): inp, out=out, offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, - lift_mode=lift_mode, ) if validate: @@ -286,7 +278,7 @@ def shift_sliced_sparse_stencil(sparse): @pytest.mark.uses_sparse_fields -def test_shift_sliced_sparse(program_processor, lift_mode): +def test_shift_sliced_sparse(program_processor): program_processor, validate = program_processor inp = gtx.as_field([Vertex, V2VDim], v2v_arr) out = gtx.as_field([Vertex], np.zeros([9], dtype=inp.dtype)) @@ -299,7 +291,6 @@ def test_shift_sliced_sparse(program_processor, lift_mode): inp, out=out, offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, - lift_mode=lift_mode, ) if validate: @@ -312,7 +303,7 @@ def slice_shifted_sparse_stencil(sparse): @pytest.mark.uses_sparse_fields -def test_slice_shifted_sparse(program_processor, lift_mode): +def test_slice_shifted_sparse(program_processor): program_processor, validate = program_processor inp = gtx.as_field([Vertex, V2VDim], v2v_arr) out = gtx.as_field([Vertex], np.zeros([9], dtype=inp.dtype)) @@ -325,7 +316,6 @@ def test_slice_shifted_sparse(program_processor, lift_mode): inp, out=out, offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, - lift_mode=lift_mode, ) if validate: @@ -342,7 +332,7 @@ def lift_stencil(inp): return deref(shift(V2V, 2)(lift(deref_stencil)(inp))) -def test_lift(program_processor, lift_mode): +def test_lift(program_processor): program_processor, validate = program_processor inp = vertex_index_field() out = gtx.as_field([Vertex], np.zeros([9], dtype=inp.dtype)) @@ -354,7 +344,6 @@ def test_lift(program_processor, lift_mode): inp, out=out, offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, - lift_mode=lift_mode, ) if validate: assert np.allclose(out.asnumpy(), ref) @@ -366,7 +355,7 @@ def sparse_shifted_stencil(inp): @pytest.mark.uses_sparse_fields -def test_shift_sparse_input_field(program_processor, lift_mode): +def test_shift_sparse_input_field(program_processor): program_processor, validate = program_processor inp = gtx.as_field([Vertex, V2VDim], v2v_arr) out = gtx.as_field([Vertex], np.zeros([9], dtype=inp.dtype)) @@ -378,7 +367,6 @@ def test_shift_sparse_input_field(program_processor, lift_mode): inp, out=out, offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, - lift_mode=lift_mode, ) if validate: @@ -396,7 +384,7 @@ def shift_sparse_stencil2(inp): @pytest.mark.uses_sparse_fields -def test_shift_sparse_input_field2(program_processor, lift_mode): +def test_shift_sparse_input_field2(program_processor): program_processor, validate = program_processor if program_processor in [ gtfn.run_gtfn, @@ -423,7 +411,6 @@ def test_shift_sparse_input_field2(program_processor, lift_mode): inp, out=out1, offset_provider=offset_provider, - lift_mode=lift_mode, ) run_processor( shift_sparse_stencil2[domain], @@ -431,7 +418,6 @@ def test_shift_sparse_input_field2(program_processor, lift_mode): inp_sparse, out=out2, offset_provider=offset_provider, - lift_mode=lift_mode, ) if validate: @@ -448,10 +434,8 @@ def sum_(a, b): @pytest.mark.uses_sparse_fields @pytest.mark.uses_reduction_with_only_sparse_fields -def test_sparse_shifted_stencil_reduce(program_processor, lift_mode): +def test_sparse_shifted_stencil_reduce(program_processor): program_processor, validate = program_processor - if lift_mode != transforms.LiftMode.FORCE_INLINE: - pytest.xfail("shifted input arguments not supported for lift_mode != LiftMode.FORCE_INLINE") inp = gtx.as_field([Vertex, V2VDim], v2v_arr) out = gtx.as_field([Vertex], np.zeros([9], dtype=inp.dtype)) @@ -472,7 +456,6 @@ def test_sparse_shifted_stencil_reduce(program_processor, lift_mode): inp, out=out, offset_provider={"V2V": gtx.NeighborTableOffsetProvider(v2v_arr, Vertex, Vertex, 4)}, - lift_mode=lift_mode, ) if validate: diff --git a/tests/next_tests/unit_tests/conftest.py b/tests/next_tests/unit_tests/conftest.py index c9406884e6..84a2d459e5 100644 --- a/tests/next_tests/unit_tests/conftest.py +++ b/tests/next_tests/unit_tests/conftest.py @@ -35,18 +35,6 @@ import next_tests -@pytest.fixture( - params=[ - transforms.LiftMode.FORCE_INLINE, - transforms.LiftMode.USE_TEMPORARIES, - transforms.LiftMode.SIMPLE_HEURISTIC, - ], - ids=lambda p: f"lift_mode={p.name}", -) -def lift_mode(request): - return request.param - - OPTIONAL_PROCESSORS = [] if dace_iterator: OPTIONAL_PROCESSORS.append((next_tests.definitions.OptionalProgramBackendId.DACE_CPU, True)) @@ -62,6 +50,7 @@ def lift_mode(request): params=[ (None, True), (next_tests.definitions.ProgramBackendId.ROUNDTRIP, True), + (next_tests.definitions.ProgramBackendId.ROUNDTRIP_WITH_TEMPORARIES, True), (next_tests.definitions.ProgramBackendId.DOUBLE_ROUNDTRIP, True), (next_tests.definitions.ProgramBackendId.GTFN_CPU, True), (next_tests.definitions.ProgramBackendId.GTFN_CPU_IMPERATIVE, True),