From 0e0af6e7fb64e9e6150b84e1ce794df1dc53dfd2 Mon Sep 17 00:00:00 2001 From: Roman Cattaneo <1116746+romanc@users.noreply.github.com> Date: Thu, 23 Jan 2025 18:13:39 +0100 Subject: [PATCH] running new ruff version --- src/gt4py/cartesian/gtc/cuir/cuir_codegen.py | 2 +- .../gtc/dace/expansion_specification.py | 6 +- src/gt4py/cartesian/gtc/dace/symbol_utils.py | 2 +- .../cartesian/gtc/gtcpp/gtcpp_codegen.py | 2 +- .../passes/oir_optimizations/temporaries.py | 6 +- src/gt4py/cartesian/stencil_object.py | 6 +- src/gt4py/cartesian/testing/suites.py | 17 +- src/gt4py/cartesian/testing/utils.py | 6 +- src/gt4py/cartesian/utils/attrib.py | 5 +- src/gt4py/eve/datamodels/core.py | 8 +- src/gt4py/eve/utils.py | 4 +- src/gt4py/next/constructors.py | 4 +- src/gt4py/next/embedded/common.py | 5 +- .../ffront/foast_passes/type_deduction.py | 2 +- src/gt4py/next/ffront/foast_to_gtir.py | 3 +- src/gt4py/next/ffront/foast_to_past.py | 4 +- src/gt4py/next/ffront/func_to_foast.py | 6 +- src/gt4py/next/ffront/past_to_itir.py | 10 +- src/gt4py/next/iterator/builtins.py | 2 + src/gt4py/next/iterator/embedded.py | 6 +- .../iterator/transforms/collapse_tuple.py | 6 +- src/gt4py/next/iterator/transforms/cse.py | 6 +- .../next/iterator/transforms/infer_domain.py | 6 +- .../iterator/transforms/power_unrolling.py | 7 +- .../next/iterator/transforms/remap_symbols.py | 6 +- .../next/iterator/type_system/inference.py | 8 +- src/gt4py/next/otf/arguments.py | 10 +- src/gt4py/next/otf/binding/cpp_interface.py | 2 +- src/gt4py/next/otf/compilation/cache.py | 4 +- .../codegens/gtfn/gtfn_module.py | 3 +- .../runners/dace/transformations/gpu_utils.py | 6 +- .../transformations/local_double_buffering.py | 12 +- .../dace/transformations/map_fusion_serial.py | 6 +- .../runners/dace/transformations/strides.py | 12 +- .../runners/dace/workflow/decoration.py | 6 +- src/gt4py/next/type_system/type_info.py | 2 +- .../multi_feature_tests/test_dace_parsing.py | 2 +- .../unit_tests/test_caching.py | 2 +- .../test_oir_optimizations/test_caches.py | 234 +++++++++--------- .../transforms_tests/test_domain_inference.py | 2 +- .../test_distributed_buffer_relocator.py | 12 +- .../test_global_self_copy_elimination.py | 12 +- .../transformation_tests/test_gpu_utils.py | 28 ++- .../test_loop_blocking.py | 6 +- .../transformation_tests/test_map_order.py | 18 +- .../transformation_tests/test_strides.py | 44 ++-- tests/next_tests/unit_tests/test_common.py | 6 +- 47 files changed, 295 insertions(+), 279 deletions(-) diff --git a/src/gt4py/cartesian/gtc/cuir/cuir_codegen.py b/src/gt4py/cartesian/gtc/cuir/cuir_codegen.py index 96149a1723..4e34dc0360 100644 --- a/src/gt4py/cartesian/gtc/cuir/cuir_codegen.py +++ b/src/gt4py/cartesian/gtc/cuir/cuir_codegen.py @@ -70,7 +70,7 @@ def maybe_const(s): decl = symtable[node.name] if isinstance(decl, cuir.Temporary) and decl.data_dims: data_index_str = "+".join( - f"{index}*{int(np.prod(decl.data_dims[i + 1:], initial=1))}" + f"{index}*{int(np.prod(decl.data_dims[i + 1 :], initial=1))}" for i, index in enumerate(data_index) ) return f"{name}({offset})[{data_index_str}]" diff --git a/src/gt4py/cartesian/gtc/dace/expansion_specification.py b/src/gt4py/cartesian/gtc/dace/expansion_specification.py index af9a814843..c3c107dcb6 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion_specification.py +++ b/src/gt4py/cartesian/gtc/dace/expansion_specification.py @@ -254,9 +254,9 @@ def _populate_gpu_schedules(expansion_specification: List[ExpansionItem]): tiled = True break if not tiled: - assert any( - isinstance(item, Map) 
for item in expansion_specification - ), "needs at least one map to avoid dereferencing on CPU" + assert any(isinstance(item, Map) for item in expansion_specification), ( + "needs at least one map to avoid dereferencing on CPU" + ) for es in expansion_specification: if isinstance(es, Map): if es.schedule is None: diff --git a/src/gt4py/cartesian/gtc/dace/symbol_utils.py b/src/gt4py/cartesian/gtc/dace/symbol_utils.py index b9b6a49ce0..c2144d5837 100644 --- a/src/gt4py/cartesian/gtc/dace/symbol_utils.py +++ b/src/gt4py/cartesian/gtc/dace/symbol_utils.py @@ -61,7 +61,7 @@ def get_axis_bound_diff_str(axis_bound1, axis_bound2, var_name: str): var = var_name else: var = "" - return f"{sign}({var}{axis_bound1.offset-axis_bound2.offset:+d})" + return f"{sign}({var}{axis_bound1.offset - axis_bound2.offset:+d})" @lru_cache(maxsize=None) diff --git a/src/gt4py/cartesian/gtc/gtcpp/gtcpp_codegen.py b/src/gt4py/cartesian/gtc/gtcpp/gtcpp_codegen.py index 3105f4a8cb..1571b31d45 100644 --- a/src/gt4py/cartesian/gtc/gtcpp/gtcpp_codegen.py +++ b/src/gt4py/cartesian/gtc/gtcpp/gtcpp_codegen.py @@ -100,7 +100,7 @@ def visit_AccessorRef( temp = temp_decls[accessor_ref.name] data_index = "+".join( [ - f"{self.visit(index, in_data_index=True, **kwargs)}*{int(np.prod(temp.data_dims[i+1:], initial=1))}" + f"{self.visit(index, in_data_index=True, **kwargs)}*{int(np.prod(temp.data_dims[i + 1 :], initial=1))}" for i, index in enumerate(accessor_ref.data_index) ] ) diff --git a/src/gt4py/cartesian/gtc/passes/oir_optimizations/temporaries.py b/src/gt4py/cartesian/gtc/passes/oir_optimizations/temporaries.py index 2498dd0278..c07b2544a7 100644 --- a/src/gt4py/cartesian/gtc/passes/oir_optimizations/temporaries.py +++ b/src/gt4py/cartesian/gtc/passes/oir_optimizations/temporaries.py @@ -21,9 +21,9 @@ def visit_FieldAccess( ) -> Union[oir.FieldAccess, oir.ScalarAccess]: offsets = node.offset.to_dict() if node.name in tmps_name_map: - assert ( - offsets["i"] == offsets["j"] == offsets["k"] == 0 - ), "Non-zero offset in temporary that is replaced?!" + assert offsets["i"] == offsets["j"] == offsets["k"] == 0, ( + "Non-zero offset in temporary that is replaced?!" + ) return oir.ScalarAccess(name=tmps_name_map[node.name], dtype=node.dtype) return self.generic_visit(node, tmps_name_map=tmps_name_map, **kwargs) diff --git a/src/gt4py/cartesian/stencil_object.py b/src/gt4py/cartesian/stencil_object.py index b76415e17f..becaadba2c 100644 --- a/src/gt4py/cartesian/stencil_object.py +++ b/src/gt4py/cartesian/stencil_object.py @@ -502,9 +502,9 @@ def _normalize_origins( if field_origin is not None: field_origin_ndim = len(field_origin) if field_origin_ndim != field_info.ndim: - assert ( - field_origin_ndim == field_info.domain_ndim - ), f"Invalid origin specification ({field_origin}) for '{name}' field." + assert field_origin_ndim == field_info.domain_ndim, ( + f"Invalid origin specification ({field_origin}) for '{name}' field." 
+ ) origin[name] = (*field_origin, *((0,) * len(field_info.data_dims))) elif all_origin is not None: diff --git a/src/gt4py/cartesian/testing/suites.py b/src/gt4py/cartesian/testing/suites.py index f680a1dbef..e0202df66d 100644 --- a/src/gt4py/cartesian/testing/suites.py +++ b/src/gt4py/cartesian/testing/suites.py @@ -268,14 +268,14 @@ def _validate_new_args(cls, cls_name, cls_dict): assert isinstance(cls_dict["symbols"], collections.abc.Mapping), "Invalid 'symbols' mapping" # Check domain and ndims - assert 1 <= len(domain_range) <= 3 and all( - len(d) == 2 for d in domain_range - ), "Invalid 'domain_range' definition" + assert 1 <= len(domain_range) <= 3 and all(len(d) == 2 for d in domain_range), ( + "Invalid 'domain_range' definition" + ) if any(cls_name.endswith(suffix) for suffix in ("1D", "2D", "3D")): - assert cls_dict["ndims"] == int( - cls_name[-2:-1] - ), "Suite name does not match the actual 'ndims'" + assert cls_dict["ndims"] == int(cls_name[-2:-1]), ( + "Suite name does not match the actual 'ndims'" + ) # Check dtypes assert isinstance( @@ -386,7 +386,10 @@ class StencilTestSuite(metaclass=SuiteMeta): .. code-block:: python - {"float_symbols": (np.float32, np.float64), "int_symbols": (int, np.int_, np.int64)} + { + "float_symbols": (np.float32, np.float64), + "int_symbols": (int, np.int_, np.int64), + } domain_range : `Sequence` of pairs like `((int, int), (int, int) ... )` Required class attribute. diff --git a/src/gt4py/cartesian/testing/utils.py b/src/gt4py/cartesian/testing/utils.py index ad8d82eebd..c41301c464 100644 --- a/src/gt4py/cartesian/testing/utils.py +++ b/src/gt4py/cartesian/testing/utils.py @@ -38,9 +38,9 @@ def standardize_dtype_dict(dtypes): dtypes as 1-tuples) """ assert isinstance(dtypes, collections.abc.Mapping) - assert all( - (isinstance(k, str) or gt_utils.is_iterable_of(k, str)) for k in dtypes.keys() - ), "Invalid key in 'dtypes'." + assert all((isinstance(k, str) or gt_utils.is_iterable_of(k, str)) for k in dtypes.keys()), ( + "Invalid key in 'dtypes'." 
+ ) assert all( ( isinstance(k, (type, np.dtype)) diff --git a/src/gt4py/cartesian/utils/attrib.py b/src/gt4py/cartesian/utils/attrib.py index 2c1b0f3b87..cbcec19de6 100644 --- a/src/gt4py/cartesian/utils/attrib.py +++ b/src/gt4py/cartesian/utils/attrib.py @@ -263,8 +263,9 @@ def _make_attrs_class_wrapper(cls): for name, member in extra_members.items(): if name in cls.__dict__.keys(): raise ValueError( - "Name clashing with a existing '{name}' member" - " of the decorated class ".format(name=name) + "Name clashing with a existing '{name}' member of the decorated class ".format( + name=name + ) ) setattr(cls, name, member) diff --git a/src/gt4py/eve/datamodels/core.py b/src/gt4py/eve/datamodels/core.py index 31e63bdf9f..3c4ccfc587 100644 --- a/src/gt4py/eve/datamodels/core.py +++ b/src/gt4py/eve/datamodels/core.py @@ -212,7 +212,7 @@ def _field_type_validator_factory(type_annotation: TypeAnnotation, name: str) -> else: simple_validator = factory(type_annotation, name, required=True) return ValidatorAdapter( - simple_validator, f"{getattr(simple_validator,'__name__', 'TypeValidator')}" + simple_validator, f"{getattr(simple_validator, '__name__', 'TypeValidator')}" ) return _field_type_validator_factory @@ -915,9 +915,9 @@ def __attrs_post_init__(self: DataModel) -> None: return __attrs_post_init__ -def _make_devtools_pretty() -> ( - Callable[[DataModel, Callable[[Any], Any]], Generator[Any, None, None]] -): +def _make_devtools_pretty() -> Callable[ + [DataModel, Callable[[Any], Any]], Generator[Any, None, None] +]: def __pretty__( self: DataModel, fmt: Callable[[Any], Any], **kwargs: Any ) -> Generator[Any, None, None]: diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py index 96e41a7bd8..5c41cb99ba 100644 --- a/src/gt4py/eve/utils.py +++ b/src/gt4py/eve/utils.py @@ -374,7 +374,9 @@ def register_subclasses(*subclasses: Type) -> Callable[[Type], Type]: >>> @register_subclasses(MyVirtualSubclassA, MyVirtualSubclassB) ... class MyBaseClass(abc.ABC): ... pass - >>> issubclass(MyVirtualSubclassA, MyBaseClass) and issubclass(MyVirtualSubclassB, MyBaseClass) + >>> issubclass(MyVirtualSubclassA, MyBaseClass) and issubclass( + ... MyVirtualSubclassB, MyBaseClass + ... ) True """ diff --git a/src/gt4py/next/constructors.py b/src/gt4py/next/constructors.py index 7b39511674..668f675f0c 100644 --- a/src/gt4py/next/constructors.py +++ b/src/gt4py/next/constructors.py @@ -70,7 +70,9 @@ def empty( >>> from gt4py._core import definitions as core_defs >>> JDim = gtx.Dimension("J") - >>> b = gtx.empty({IDim: 3, JDim: 3}, int, device=core_defs.Device(core_defs.DeviceType.CPU, 0)) + >>> b = gtx.empty( + ... {IDim: 3, JDim: 3}, int, device=core_defs.Device(core_defs.DeviceType.CPU, 0) + ... ) >>> b.shape (3, 3) """ diff --git a/src/gt4py/next/embedded/common.py b/src/gt4py/next/embedded/common.py index 4c0f9ce6a6..2510aee8b4 100644 --- a/src/gt4py/next/embedded/common.py +++ b/src/gt4py/next/embedded/common.py @@ -121,7 +121,10 @@ def restrict_to_intersection( ... common.domain({I: (1, 3), J: (0, 3)}), ... ignore_dims=J, ... ) - >>> assert res == (common.domain({I: (1, 3), J: (1, 2)}), common.domain({I: (1, 3), J: (0, 3)})) + >>> assert res == ( + ... common.domain({I: (1, 3), J: (1, 2)}), + ... common.domain({I: (1, 3), J: (0, 3)}), + ... 
) """ ignore_dims_tuple = ignore_dims if isinstance(ignore_dims, tuple) else (ignore_dims,) intersection_without_ignore_dims = domain_intersection( diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 26bcadaef1..a391bbf934 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -428,7 +428,7 @@ def visit_IfStmt(self, node: foast.IfStmt, **kwargs: Any) -> foast.IfStmt: if not isinstance(new_node.condition.type, ts.ScalarType): raise errors.DSLError( node.location, - "Condition for 'if' must be scalar, " f"got '{new_node.condition.type}' instead.", + f"Condition for 'if' must be scalar, got '{new_node.condition.type}' instead.", ) if new_node.condition.type.kind != ts.ScalarKind.BOOL: diff --git a/src/gt4py/next/ffront/foast_to_gtir.py b/src/gt4py/next/ffront/foast_to_gtir.py index 4519b4e571..4e33b11666 100644 --- a/src/gt4py/next/ffront/foast_to_gtir.py +++ b/src/gt4py/next/ffront/foast_to_gtir.py @@ -275,8 +275,7 @@ def _visit_shift(self, node: foast.Call, **kwargs: Any) -> itir.Expr: )(current_expr) # `field(Dim + idx)` case foast.BinOp( - op=dialect_ast_enums.BinaryOperator.ADD - | dialect_ast_enums.BinaryOperator.SUB, + op=dialect_ast_enums.BinaryOperator.ADD | dialect_ast_enums.BinaryOperator.SUB, left=foast.Name(id=dimension), # TODO(tehrengruber): use type of lhs right=foast.Constant(value=offset_index), ): diff --git a/src/gt4py/next/ffront/foast_to_past.py b/src/gt4py/next/ffront/foast_to_past.py index 330bc79809..6244621dcb 100644 --- a/src/gt4py/next/ffront/foast_to_past.py +++ b/src/gt4py/next/ffront/foast_to_past.py @@ -77,7 +77,9 @@ class OperatorToProgram(workflow.Workflow[AOT_FOP, AOT_PRG]): ... column_axis=None, ... ) - >>> copy_program = op_to_prog(toolchain.CompilableProgram(copy.foast_stage, compile_time_args)) + >>> copy_program = op_to_prog( + ... toolchain.CompilableProgram(copy.foast_stage, compile_time_args) + ... ) >>> print(copy_program.data.past_node.id) __field_operator_copy diff --git a/src/gt4py/next/ffront/func_to_foast.py b/src/gt4py/next/ffront/func_to_foast.py index ebe12d3a8b..57852d9079 100644 --- a/src/gt4py/next/ffront/func_to_foast.py +++ b/src/gt4py/next/ffront/func_to_foast.py @@ -317,9 +317,9 @@ def visit_AnnAssign(self, node: ast.AnnAssign, **kwargs: Any) -> foast.Assign: raise errors.DSLError(self.get_location(node), "Can only assign to names.") if node.annotation is not None: - assert isinstance( - node.annotation, ast.Constant - ), "Annotations should be ast.Constant(string). Use StringifyAnnotationsPass" + assert isinstance(node.annotation, ast.Constant), ( + "Annotations should be ast.Constant(string). Use StringifyAnnotationsPass" + ) context = {**fbuiltins.BUILTINS, **self.closure_vars} annotation = eval(node.annotation.value, context) target_type = type_translation.from_type_hint(annotation, globalns=context) diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 4bc1dfb2f8..53b1c49f86 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -46,7 +46,9 @@ def past_to_gtir(inp: AOT_PRG) -> stages.CompilableProgram: ... return a >>> @gtx.program - ... def copy_program(a: gtx.Field[[IDim], gtx.float32], out: gtx.Field[[IDim], gtx.float32]): + ... def copy_program( + ... a: gtx.Field[[IDim], gtx.float32], out: gtx.Field[[IDim], gtx.float32] + ... ): ... 
copy(a, out=out) >>> compile_time_args = arguments.CompileTimeArgs( @@ -460,9 +462,9 @@ def _visit_stencil_call_out_arg( field_slice = None if isinstance(first_field, past.Subscript): - assert all( - isinstance(field, past.Subscript) for field in flattened - ), "Incompatible field in tuple: either all fields or no field must be sliced." + assert all(isinstance(field, past.Subscript) for field in flattened), ( + "Incompatible field in tuple: either all fields or no field must be sliced." + ) assert all( concepts.eq_nonlocated( first_field.slice_, diff --git a/src/gt4py/next/iterator/builtins.py b/src/gt4py/next/iterator/builtins.py index 959f451e01..e0802c3687 100644 --- a/src/gt4py/next/iterator/builtins.py +++ b/src/gt4py/next/iterator/builtins.py @@ -6,6 +6,8 @@ # Please, refer to the LICENSE file in the root directory. # SPDX-License-Identifier: BSD-3-Clause +# ruff: noqa: A005 Module `builtins` shadows a Python standard-library module + from gt4py.next.iterator.dispatcher import Dispatcher diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 970e88e8c5..33562ed5a4 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -642,9 +642,9 @@ def _is_list_of_complete_offsets( def group_offsets(*offsets: OffsetPart) -> list[CompleteOffset]: assert len(offsets) % 2 == 0 complete_offsets = [*zip(offsets[::2], offsets[1::2])] - assert _is_list_of_complete_offsets( - complete_offsets - ), f"Invalid sequence of offset parts: {offsets}" + assert _is_list_of_complete_offsets(complete_offsets), ( + f"Invalid sequence of offset parts: {offsets}" + ) return complete_offsets diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 0a0cf6d37e..9f287050ce 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -284,9 +284,9 @@ def transform_collapse_tuple_get_make_tuple( assert type_info.is_integer(node.args[0].type) make_tuple_call = node.args[1] idx = int(node.args[0].value) - assert idx < len( - make_tuple_call.args - ), f"Index {idx} is out of bounds for tuple of size {len(make_tuple_call.args)}" + assert idx < len(make_tuple_call.args), ( + f"Index {idx} is out of bounds for tuple of size {len(make_tuple_call.args)}" + ) return node.args[1].args[idx] return None diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index ccaaf563f5..e103171069 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -418,9 +418,9 @@ def apply( assert within_stencil is None within_stencil = False else: - assert ( - within_stencil is not None - ), "The expression's context must be specified using `within_stencil`." + assert within_stencil is not None, ( + "The expression's context must be specified using `within_stencil`." + ) offset_provider_type = offset_provider_type or {} node = itir_type_inference.infer( diff --git a/src/gt4py/next/iterator/transforms/infer_domain.py b/src/gt4py/next/iterator/transforms/infer_domain.py index f3c3185225..f2c612526e 100644 --- a/src/gt4py/next/iterator/transforms/infer_domain.py +++ b/src/gt4py/next/iterator/transforms/infer_domain.py @@ -466,9 +466,9 @@ def infer_program( See :func:`infer_expr` for more details. """ - assert ( - not program.function_definitions - ), "Domain propagation does not support function definitions." 
+ assert not program.function_definitions, ( + "Domain propagation does not support function definitions." + ) return itir.Program( id=program.id, diff --git a/src/gt4py/next/iterator/transforms/power_unrolling.py b/src/gt4py/next/iterator/transforms/power_unrolling.py index f4ef1521ac..1bc1844127 100644 --- a/src/gt4py/next/iterator/transforms/power_unrolling.py +++ b/src/gt4py/next/iterator/transforms/power_unrolling.py @@ -53,18 +53,19 @@ def visit_FunCall(self, node: ir.FunCall): remainder = exponent # Build target expression - ret = im.ref(f"power_{2 ** pow_max}") + ret = im.ref(f"power_{2**pow_max}") remainder -= 2**pow_cur while remainder > 0: pow_cur = _compute_integer_power_of_two(remainder) remainder -= 2**pow_cur - ret = im.multiplies_(ret, f"power_{2 ** pow_cur}") + ret = im.multiplies_(ret, f"power_{2**pow_cur}") # Nest target expression to avoid multiple redundant evaluations for i in range(pow_max, 0, -1): ret = im.let( - f"power_{2 ** i}", im.multiplies_(f"power_{2**(i-1)}", f"power_{2**(i-1)}") + f"power_{2**i}", + im.multiplies_(f"power_{2 ** (i - 1)}", f"power_{2 ** (i - 1)}"), )(ret) ret = im.let("power_1", base)(ret) diff --git a/src/gt4py/next/iterator/transforms/remap_symbols.py b/src/gt4py/next/iterator/transforms/remap_symbols.py index fb909dc5d0..5495f63ae1 100644 --- a/src/gt4py/next/iterator/transforms/remap_symbols.py +++ b/src/gt4py/next/iterator/transforms/remap_symbols.py @@ -26,9 +26,9 @@ def visit_Lambda(self, node: ir.Lambda, *, symbol_map: Dict[str, ir.Node]): return ir.Lambda(params=node.params, expr=self.visit(node.expr, symbol_map=new_symbol_map)) def generic_visit(self, node: ir.Node, **kwargs: Any): # type: ignore[override] - assert isinstance(node, SymbolTableTrait) == isinstance( - node, ir.Lambda - ), "found unexpected new symbol scope" + assert isinstance(node, SymbolTableTrait) == isinstance(node, ir.Lambda), ( + "found unexpected new symbol scope" + ) return super().generic_visit(node, **kwargs) diff --git a/src/gt4py/next/iterator/type_system/inference.py b/src/gt4py/next/iterator/type_system/inference.py index 901cb103da..a28d7b349b 100644 --- a/src/gt4py/next/iterator/type_system/inference.py +++ b/src/gt4py/next/iterator/type_system/inference.py @@ -230,9 +230,9 @@ def __call__( *args: type_synthesizer.TypeOrTypeSynthesizer, offset_provider_type: common.OffsetProviderType, ) -> Union[ts.TypeSpec, ObservableTypeSynthesizer]: - assert all( - isinstance(arg, (ts.TypeSpec, ObservableTypeSynthesizer)) for arg in args - ), "ObservableTypeSynthesizer can only be used with arguments that are TypeSpec or ObservableTypeSynthesizer" + assert all(isinstance(arg, (ts.TypeSpec, ObservableTypeSynthesizer)) for arg in args), ( + "ObservableTypeSynthesizer can only be used with arguments that are TypeSpec or ObservableTypeSynthesizer" + ) return_type_or_synthesizer = self.type_synthesizer( *args, offset_provider_type=offset_provider_type @@ -644,7 +644,7 @@ def visit_FunCall( return result def visit_Node(self, node: itir.Node, **kwargs): - raise NotImplementedError(f"No type rule for nodes of type " f"'{type(node).__name__}'.") + raise NotImplementedError(f"No type rule for nodes of type '{type(node).__name__}'.") infer = ITIRTypeInference.apply diff --git a/src/gt4py/next/otf/arguments.py b/src/gt4py/next/otf/arguments.py index c4235eaa9a..b6ec05235f 100644 --- a/src/gt4py/next/otf/arguments.py +++ b/src/gt4py/next/otf/arguments.py @@ -98,12 +98,10 @@ def jit_to_aot_args( return CompileTimeArgs.from_concrete_no_size(*inp.args, **inp.kwargs) -def 
adapted_jit_to_aot_args_factory() -> ( - workflow.Workflow[ - toolchain.CompilableProgram[DATA_T, JITArgs], - toolchain.CompilableProgram[DATA_T, CompileTimeArgs], - ] -): +def adapted_jit_to_aot_args_factory() -> workflow.Workflow[ + toolchain.CompilableProgram[DATA_T, JITArgs], + toolchain.CompilableProgram[DATA_T, CompileTimeArgs], +]: """Wrap `jit_to_aot` into a workflow adapter to fit into backend transform workflows.""" return toolchain.ArgsOnlyAdapter(jit_to_aot_args) diff --git a/src/gt4py/next/otf/binding/cpp_interface.py b/src/gt4py/next/otf/binding/cpp_interface.py index 17eee4d5c6..b9058350a3 100644 --- a/src/gt4py/next/otf/binding/cpp_interface.py +++ b/src/gt4py/next/otf/binding/cpp_interface.py @@ -36,7 +36,7 @@ def render_function_declaration(function: interface.Function, body: str) -> str: }}""" if template_params: return f""" - template <{', '.join(template_params)}> + template <{", ".join(template_params)}> {rendered_decl} """.strip() return rendered_decl diff --git a/src/gt4py/next/otf/compilation/cache.py b/src/gt4py/next/otf/compilation/cache.py index 8907cd81c0..03f73190f4 100644 --- a/src/gt4py/next/otf/compilation/cache.py +++ b/src/gt4py/next/otf/compilation/cache.py @@ -36,8 +36,8 @@ def _serialize_source(source: stages.ProgramSource) -> str: return f"""\ language: {source.language} name: {source.entry_point.name} - params: {', '.join(parameters)} - deps: {', '.join(dependencies)} + params: {", ".join(parameters)} + deps: {", ".join(dependencies)} src: {source.source_code} implicit_domain: {source.implicit_domain} """ diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index 48f15acffb..c5fcfe4cea 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -310,8 +310,7 @@ def _library_name(self) -> str: def _not_implemented_for_device_type(self) -> NotImplementedError: return NotImplementedError( - f"{self.__class__.__name__} is not implemented for " - f"device type {self.device_type.name}" + f"{self.__class__.__name__} is not implemented for device type {self.device_type.name}" ) diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py index 8bae56cd88..a2d7692e95 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/gpu_utils.py @@ -70,9 +70,9 @@ def gt_gpu_transformation( - Solve the fusing problem. - Currently only one block size for all maps is given, add more options. """ - assert ( - len(kwargs) == 0 - ), f"gt_gpu_transformation(): found unknown arguments: {', '.join(arg for arg in kwargs.keys())}" + assert len(kwargs) == 0, ( + f"gt_gpu_transformation(): found unknown arguments: {', '.join(arg for arg in kwargs.keys())}" + ) # Turn all global arrays (which we identify as input) into GPU memory. # This way the GPU transformation will not create this copying stuff. 
diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/local_double_buffering.py b/src/gt4py/next/program_processors/runners/dace/transformations/local_double_buffering.py index 02ecbe28e6..5be3169254 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/local_double_buffering.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/local_double_buffering.py @@ -342,12 +342,12 @@ def _check_if_map_must_be_handled( input_node, output_node = inout_datas[inout_data_name] input_edges = state.edges_between(input_node, map_entry) output_edges = state.edges_between(map_exit, output_node) - assert ( - len(input_edges) == 1 - ), f"Expected a single connection between input node and map entry, but found {len(input_edges)}." - assert ( - len(output_edges) == 1 - ), f"Expected a single connection between map exit and write back node, but found {len(output_edges)}." + assert len(input_edges) == 1, ( + f"Expected a single connection between input node and map entry, but found {len(input_edges)}." + ) + assert len(output_edges) == 1, ( + f"Expected a single connection between map exit and write back node, but found {len(output_edges)}." + ) # If there is only one edge on the inside of the map, that goes into an # AccessNode, then we assume it is double buffered. diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_serial.py b/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_serial.py index 2cdcc455d4..58dbd6e756 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_serial.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/map_fusion_serial.py @@ -593,9 +593,9 @@ def partition_first_outputs( # TODO(phimuell): Is this restriction necessary, I am not sure. return None consumer_subsets.append(inner_consumer_edge.data.src_subset) - assert ( - found_second_map - ), f"Found '{intermediate_node}' which looked like a pure node, but is not one." + assert found_second_map, ( + f"Found '{intermediate_node}' which looked like a pure node, but is not one." + ) assert len(consumer_subsets) != 0 # The consumer still uses the original symbols of the second map, so we must rename them. diff --git a/src/gt4py/next/program_processors/runners/dace/transformations/strides.py b/src/gt4py/next/program_processors/runners/dace/transformations/strides.py index 9af76e5b57..6f8ecebd9c 100644 --- a/src/gt4py/next/program_processors/runners/dace/transformations/strides.py +++ b/src/gt4py/next/program_processors/runners/dace/transformations/strides.py @@ -639,9 +639,9 @@ def _gt_find_toplevel_data_accesses( # We also check if it was ever found on the top level, this should # not happen, as everything should go through Maps. But some strange # DaCe transformation might do it. - assert ( - data not in top_level_data - ), f"Found {data} on the top level and inside a scope." + assert data not in top_level_data, ( + f"Found {data} on the top level and inside a scope." + ) not_top_level_data.add(data) continue @@ -656,9 +656,9 @@ def _gt_find_toplevel_data_accesses( continue # We have found a new data node that is on the top node and is unknown. - assert ( - data not in not_top_level_data - ), f"Found {data} on the top level and inside a scope." + assert data not in not_top_level_data, ( + f"Found {data} on the top level and inside a scope." 
+ ) desc: dace_data.Data = dnode.desc(sdfg) # Check if we only accept arrays diff --git a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py index 2ee99f5fa4..70197186a0 100644 --- a/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py +++ b/src/gt4py/next/program_processors/runners/dace/workflow/decoration.py @@ -73,9 +73,9 @@ def decorated_program( last_call_args[i] = actype(kwargs[arg_name]) else: # shape and strides of arrays are supposed not to change, and can therefore be omitted - assert gtx_dace_utils.is_field_symbol( - arg_name - ), f"argument '{arg_name}' not found." + assert gtx_dace_utils.is_field_symbol(arg_name), ( + f"argument '{arg_name}' not found." + ) if use_fast_call: return inp.fast_call() diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py index 26373c647f..7ba79ee817 100644 --- a/src/gt4py/next/type_system/type_info.py +++ b/src/gt4py/next/type_system/type_info.py @@ -534,7 +534,7 @@ def return_type( with_kwargs: dict[str, ts.TypeSpec], ) -> ts.TypeSpec: raise NotImplementedError( - f"Return type deduction of type " f"'{type(callable_type).__name__}' not implemented." + f"Return type deduction of type '{type(callable_type).__name__}' not implemented." ) diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_dace_parsing.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_dace_parsing.py index 9fafc27c85..3d91a0a205 100644 --- a/tests/cartesian_tests/integration_tests/multi_feature_tests/test_dace_parsing.py +++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/test_dace_parsing.py @@ -425,7 +425,7 @@ def call_stencil(): with pytest.raises( TypeError, match=re.escape( - "Only dace backends are supported in DaCe-orchestrated programs." ' (found "numpy")' + 'Only dace backends are supported in DaCe-orchestrated programs. (found "numpy")' ), ): call_stencil() diff --git a/tests/cartesian_tests/unit_tests/test_caching.py b/tests/cartesian_tests/unit_tests/test_caching.py index ebabdbe9c6..86339d6c60 100644 --- a/tests/cartesian_tests/unit_tests/test_caching.py +++ b/tests/cartesian_tests/unit_tests/test_caching.py @@ -100,7 +100,7 @@ def test_jit_version(builder): withdoc.backend.generate() assert could_load_stencil_from_cache(withdoc) - original.definition.__doc__ = "Added docstring." "" + original.definition.__doc__ = "Added docstring." assert not could_load_stencil_from_cache(original, catch_exceptions=True) assert not could_load_stencil_from_cache(duplicate, catch_exceptions=True) # fingerprint has changed and with it the file paths, new cache_info file does not exist. diff --git a/tests/cartesian_tests/unit_tests/test_gtc/test_passes/test_oir_optimizations/test_caches.py b/tests/cartesian_tests/unit_tests/test_gtc/test_passes/test_oir_optimizations/test_caches.py index 0ce9812660..374ce666ab 100644 --- a/tests/cartesian_tests/unit_tests/test_gtc/test_passes/test_oir_optimizations/test_caches.py +++ b/tests/cartesian_tests/unit_tests/test_gtc/test_passes/test_oir_optimizations/test_caches.py @@ -309,41 +309,41 @@ def test_fill_to_local_k_caches_basic_forward(): assert len(vertical_loop.sections) == 2, "number of vertical sections has changed" - assert ( - len(vertical_loop.sections[0].horizontal_executions[0].body) == 2 - ), "no or too many fill stmts introduced?" 
- assert ( - vertical_loop.sections[0].horizontal_executions[0].body[0].left.name == cache_name - ), "wrong fill destination" - assert ( - vertical_loop.sections[0].horizontal_executions[0].body[0].right.name == "foo" - ), "wrong fill source" + assert len(vertical_loop.sections[0].horizontal_executions[0].body) == 2, ( + "no or too many fill stmts introduced?" + ) + assert vertical_loop.sections[0].horizontal_executions[0].body[0].left.name == cache_name, ( + "wrong fill destination" + ) + assert vertical_loop.sections[0].horizontal_executions[0].body[0].right.name == "foo", ( + "wrong fill source" + ) assert ( vertical_loop.sections[0].horizontal_executions[0].body[0].left.offset.k == vertical_loop.sections[0].horizontal_executions[0].body[0].right.offset.k == 1 ), "wrong fill offset" - assert ( - vertical_loop.sections[0].horizontal_executions[0].body[1].left.name == cache_name - ), "wrong field name in cache access" - assert ( - vertical_loop.sections[0].horizontal_executions[0].body[1].right.name == cache_name - ), "wrong field name in cache access" - assert ( - vertical_loop.sections[0].horizontal_executions[0].body[1].right.offset.k == 1 - ), "wrong offset in cache access" - assert ( - len(vertical_loop.sections[1].horizontal_executions[0].body) == 1 - ), "too many fill stmts introduced?" - assert ( - vertical_loop.sections[1].horizontal_executions[0].body[0].left.name == cache_name - ), "wrong field name in cache access" - assert ( - vertical_loop.sections[1].horizontal_executions[0].body[0].right.name == cache_name - ), "wrong field name in cache access" - assert ( - vertical_loop.sections[1].horizontal_executions[0].body[0].right.offset.k == 0 - ), "wrong offset in cache access" + assert vertical_loop.sections[0].horizontal_executions[0].body[1].left.name == cache_name, ( + "wrong field name in cache access" + ) + assert vertical_loop.sections[0].horizontal_executions[0].body[1].right.name == cache_name, ( + "wrong field name in cache access" + ) + assert vertical_loop.sections[0].horizontal_executions[0].body[1].right.offset.k == 1, ( + "wrong offset in cache access" + ) + assert len(vertical_loop.sections[1].horizontal_executions[0].body) == 1, ( + "too many fill stmts introduced?" + ) + assert vertical_loop.sections[1].horizontal_executions[0].body[0].left.name == cache_name, ( + "wrong field name in cache access" + ) + assert vertical_loop.sections[1].horizontal_executions[0].body[0].right.name == cache_name, ( + "wrong field name in cache access" + ) + assert vertical_loop.sections[1].horizontal_executions[0].body[0].right.offset.k == 0, ( + "wrong offset in cache access" + ) def test_fill_to_local_k_caches_basic_backward(): @@ -384,41 +384,41 @@ def test_fill_to_local_k_caches_basic_backward(): assert len(vertical_loop.sections) == 2, "number of vertical sections has changed" - assert ( - len(vertical_loop.sections[0].horizontal_executions[0].body) == 2 - ), "no or too many fill stmts introduced?" - assert ( - vertical_loop.sections[0].horizontal_executions[0].body[0].left.name == cache_name - ), "wrong fill destination" - assert ( - vertical_loop.sections[0].horizontal_executions[0].body[0].right.name == "foo" - ), "wrong fill source" + assert len(vertical_loop.sections[0].horizontal_executions[0].body) == 2, ( + "no or too many fill stmts introduced?" 
+    )
+    assert vertical_loop.sections[0].horizontal_executions[0].body[0].left.name == cache_name, (
+        "wrong fill destination"
+    )
+    assert vertical_loop.sections[0].horizontal_executions[0].body[0].right.name == "foo", (
+        "wrong fill source"
+    )
     assert (
         vertical_loop.sections[0].horizontal_executions[0].body[0].left.offset.k
         == vertical_loop.sections[0].horizontal_executions[0].body[0].right.offset.k
         == -1
     ), "wrong fill offset"
-    assert (
-        vertical_loop.sections[0].horizontal_executions[0].body[1].left.name == cache_name
-    ), "wrong field name in cache access"
-    assert (
-        vertical_loop.sections[0].horizontal_executions[0].body[1].right.name == cache_name
-    ), "wrong field name in cache access"
-    assert (
-        vertical_loop.sections[0].horizontal_executions[0].body[1].right.offset.k == -1
-    ), "wrong offset in cache access"
-    assert (
-        len(vertical_loop.sections[1].horizontal_executions[0].body) == 1
-    ), "too many fill stmts introduced?"
-    assert (
-        vertical_loop.sections[1].horizontal_executions[0].body[0].left.name == cache_name
-    ), "wrong field name in cache access"
-    assert (
-        vertical_loop.sections[1].horizontal_executions[0].body[0].right.name == cache_name
-    ), "wrong field name in cache access"
-    assert (
-        vertical_loop.sections[1].horizontal_executions[0].body[0].right.offset.k == 0
-    ), "wrong offset in cache access"
+    assert vertical_loop.sections[0].horizontal_executions[0].body[1].left.name == cache_name, (
+        "wrong field name in cache access"
+    )
+    assert vertical_loop.sections[0].horizontal_executions[0].body[1].right.name == cache_name, (
+        "wrong field name in cache access"
+    )
+    assert vertical_loop.sections[0].horizontal_executions[0].body[1].right.offset.k == -1, (
+        "wrong offset in cache access"
+    )
+    assert len(vertical_loop.sections[1].horizontal_executions[0].body) == 1, (
+        "too many fill stmts introduced?"
+ ) + assert vertical_loop.sections[1].horizontal_executions[0].body[0].left.name == cache_name, ( + "wrong field name in cache access" + ) + assert vertical_loop.sections[1].horizontal_executions[0].body[0].right.name == cache_name, ( + "wrong field name in cache access" + ) + assert vertical_loop.sections[1].horizontal_executions[0].body[0].right.offset.k == 0, ( + "wrong offset in cache access" + ) def test_fill_to_local_k_caches_section_splitting_forward(): @@ -477,15 +477,15 @@ def test_fill_to_local_k_caches_section_splitting_forward(): == -1 and vertical_loop.sections[2].interval.end.offset == 0 ), "wrong interval offsets in split sections" - assert ( - len(vertical_loop.sections[0].horizontal_executions[0].body) == 4 - ), "wrong number of fill stmts" - assert ( - len(vertical_loop.sections[1].horizontal_executions[0].body) == 3 - ), "wrong number of fill stmts" - assert ( - len(vertical_loop.sections[2].horizontal_executions[0].body) == 1 - ), "wrong number of fill stmts" + assert len(vertical_loop.sections[0].horizontal_executions[0].body) == 4, ( + "wrong number of fill stmts" + ) + assert len(vertical_loop.sections[1].horizontal_executions[0].body) == 3, ( + "wrong number of fill stmts" + ) + assert len(vertical_loop.sections[2].horizontal_executions[0].body) == 1, ( + "wrong number of fill stmts" + ) def test_fill_to_local_k_caches_section_splitting_backward(): @@ -541,15 +541,15 @@ def test_fill_to_local_k_caches_section_splitting_backward(): == 1 and vertical_loop.sections[2].interval.start.offset == 0 ), "wrong interval offsets in split sections" - assert ( - len(vertical_loop.sections[0].horizontal_executions[0].body) == 4 - ), "wrong number of fill stmts" - assert ( - len(vertical_loop.sections[1].horizontal_executions[0].body) == 3 - ), "wrong number of fill stmts" - assert ( - len(vertical_loop.sections[2].horizontal_executions[0].body) == 1 - ), "wrong number of fill stmts" + assert len(vertical_loop.sections[0].horizontal_executions[0].body) == 4, ( + "wrong number of fill stmts" + ) + assert len(vertical_loop.sections[1].horizontal_executions[0].body) == 3, ( + "wrong number of fill stmts" + ) + assert len(vertical_loop.sections[2].horizontal_executions[0].body) == 1, ( + "wrong number of fill stmts" + ) def test_flush_to_local_k_caches_basic(): @@ -590,50 +590,50 @@ def test_flush_to_local_k_caches_basic(): assert len(vertical_loop.sections) == 2, "number of vertical sections has changed" - assert ( - len(vertical_loop.sections[0].horizontal_executions[0].body) == 2 - ), "no or too many flush stmts introduced?" - assert ( - vertical_loop.sections[0].horizontal_executions[0].body[0].left.name == cache_name - ), "wrong field name in cache access" - assert ( - vertical_loop.sections[0].horizontal_executions[0].body[0].right.name == cache_name - ), "wrong field name in cache access" - assert ( - vertical_loop.sections[0].horizontal_executions[0].body[0].right.offset.k == 0 - ), "wrong offset in cache access" - assert ( - vertical_loop.sections[0].horizontal_executions[0].body[1].left.name == "foo" - ), "wrong flush source" - assert ( - vertical_loop.sections[0].horizontal_executions[0].body[1].right.name == cache_name - ), "wrong flush destination" + assert len(vertical_loop.sections[0].horizontal_executions[0].body) == 2, ( + "no or too many flush stmts introduced?" 
+    )
+    assert vertical_loop.sections[0].horizontal_executions[0].body[0].left.name == cache_name, (
+        "wrong field name in cache access"
+    )
+    assert vertical_loop.sections[0].horizontal_executions[0].body[0].right.name == cache_name, (
+        "wrong field name in cache access"
+    )
+    assert vertical_loop.sections[0].horizontal_executions[0].body[0].right.offset.k == 0, (
+        "wrong offset in cache access"
+    )
+    assert vertical_loop.sections[0].horizontal_executions[0].body[1].left.name == "foo", (
+        "wrong flush source"
+    )
+    assert vertical_loop.sections[0].horizontal_executions[0].body[1].right.name == cache_name, (
+        "wrong flush destination"
+    )
     assert (
         vertical_loop.sections[0].horizontal_executions[0].body[1].left.offset.k
         == vertical_loop.sections[0].horizontal_executions[0].body[1].right.offset.k
         == 0
     ), "wrong flush offset"
-    assert (
-        len(vertical_loop.sections[1].horizontal_executions[0].body) == 2
-    ), "no or too many flush stmts introduced?"
-    assert (
-        vertical_loop.sections[1].horizontal_executions[0].body[0].left.name == cache_name
-    ), "wrong field name in cache access"
-    assert (
-        vertical_loop.sections[1].horizontal_executions[0].body[0].right.name == cache_name
-    ), "wrong field name in cache access"
-    assert (
-        vertical_loop.sections[1].horizontal_executions[0].body[0].right.offset.k == -1
-    ), "wrong offset in cache access"
-    assert (
-        vertical_loop.sections[1].horizontal_executions[0].body[1].left.name == "foo"
-    ), "wrong flush source"
-    assert (
-        vertical_loop.sections[1].horizontal_executions[0].body[1].right.name == cache_name
-    ), "wrong flush destination"
-    assert (
-        vertical_loop.sections[1].horizontal_executions[0].body[1].right.offset.k == 0
-    ), "wrong flush offset"
+    assert len(vertical_loop.sections[1].horizontal_executions[0].body) == 2, (
+        "no or too many flush stmts introduced?"
+    )
+    assert vertical_loop.sections[1].horizontal_executions[0].body[0].left.name == cache_name, (
+        "wrong field name in cache access"
+    )
+    assert vertical_loop.sections[1].horizontal_executions[0].body[0].right.name == cache_name, (
+        "wrong field name in cache access"
+    )
+    assert vertical_loop.sections[1].horizontal_executions[0].body[0].right.offset.k == -1, (
+        "wrong offset in cache access"
+    )
+    assert vertical_loop.sections[1].horizontal_executions[0].body[1].left.name == "foo", (
+        "wrong flush source"
+    )
+    assert vertical_loop.sections[1].horizontal_executions[0].body[1].right.name == cache_name, (
+        "wrong flush destination"
+    )
+    assert vertical_loop.sections[1].horizontal_executions[0].body[1].right.offset.k == 0, (
+        "wrong flush offset"
+    )
 
 
 def test_fill_flush_to_local_k_caches_basic_forward():
diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py
index 779ab738cb..e3a4b6986e 100644
--- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py
+++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_domain_inference.py
@@ -66,7 +66,7 @@ def setup_test_as_fieldop(
 ) -> tuple[itir.FunCall, itir.FunCall]:
     if refs is None:
         assert isinstance(stencil, itir.Lambda)
-        refs = [f"in_field{i+1}" for i in range(0, len(stencil.params))]
+        refs = [f"in_field{i + 1}" for i in range(0, len(stencil.params))]
 
     testee = im.as_fieldop(stencil)(*refs)
     expected = im.as_fieldop(stencil, domain)(*refs)
diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py
index 9241bae4bf..3bbf94779e 100644
--- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py
+++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_distributed_buffer_relocator.py
@@ -134,9 +134,9 @@ def test_distributed_buffer_global_memory_data_race():
     assert state2.number_of_nodes() == 2
 
 
-def _make_distributed_buffer_global_memory_data_race_sdfg2() -> (
-    tuple[dace.SDFG, dace.SDFGState, dace.SDFGState]
-):
+def _make_distributed_buffer_global_memory_data_race_sdfg2() -> tuple[
+    dace.SDFG, dace.SDFGState, dace.SDFGState
+]:
     sdfg = dace.SDFG(util.unique_name("distributed_buffer_global_memory_data_race2_sdfg"))
     arr_names = ["a", "b", "t"]
     for name in arr_names:
@@ -289,9 +289,9 @@ def test_distributed_buffer_global_memory_data_no_rance2():
     assert state2.number_of_nodes() == 0
 
 
-def _make_distributed_buffer_non_sink_temporary_sdfg() -> (
-    tuple[dace.SDFG, dace.SDFGState, dace.SDFGState]
-):
+def _make_distributed_buffer_non_sink_temporary_sdfg() -> tuple[
+    dace.SDFG, dace.SDFGState, dace.SDFGState
+]:
     sdfg = dace.SDFG(util.unique_name("distributed_buffer_non_sink_temporary_sdfg"))
     state = sdfg.add_state(is_start_block=True)
     wb_state = sdfg.add_state_after(state)
diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_global_self_copy_elimination.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_global_self_copy_elimination.py
index 1d98fef8c4..251f68668c 100644
--- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_global_self_copy_elimination.py
+++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_global_self_copy_elimination.py
@@ -56,9 +56,9 @@ def test_global_self_copy_elimination_only_pattern():
     assert count != 0
 
     assert sdfg.number_of_nodes() == 1
-    assert (
-        state.number_of_nodes() == 0
-    ), f"Expected that 0 access nodes remained, but {state.number_of_nodes()} were there."
+    assert state.number_of_nodes() == 0, (
+        f"Expected that 0 access nodes remained, but {state.number_of_nodes()} were there."
+    )
 
 
 def test_global_self_copy_elimination_g_downstream():
@@ -95,9 +95,9 @@ def test_global_self_copy_elimination_g_downstream():
     assert count != 0
 
     assert sdfg.number_of_nodes() == 2
-    assert (
-        state1.number_of_nodes() == 0
-    ), f"Expected that 0 access nodes remained, but {state.number_of_nodes()} were there."
+    assert state1.number_of_nodes() == 0, (
+        f"Expected that 0 access nodes remained, but {state.number_of_nodes()} were there."
+    )
     assert state2.number_of_nodes() == 5
     assert util.count_nodes(state2, dace_nodes.AccessNode) == 2
     assert util.count_nodes(state2, dace_nodes.MapEntry) == 1
diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py
index cdc66d4ffd..f0e73fbf48 100644
--- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py
+++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_gpu_utils.py
@@ -98,26 +98,28 @@ def test_trivial_gpu_map_promoter_1():
         validate=True,
         validate_all=True,
     )
-    assert (
-        nb_runs == 1
-    ), f"Expected that 'TrivialGPUMapElimination' applies once but it applied {nb_runs}."
+    assert nb_runs == 1, (
+        f"Expected that 'TrivialGPUMapElimination' applies once but it applied {nb_runs}."
+    )
     trivial_map_params = trivial_map_entry.map.params
     trivial_map_ranges = trivial_map_entry.map.range
     second_map_params = second_map_entry.map.params
     second_map_ranges = second_map_entry.map.range
-    assert (
-        second_map_params == org_second_map_params
-    ), "The transformation modified the parameter of the second map."
-    assert all(
-        org_rng == rng for org_rng, rng in zip(org_second_map_ranges, second_map_ranges)
-    ), "The transformation modified the range of the second map."
+    assert second_map_params == org_second_map_params, (
+        "The transformation modified the parameter of the second map."
+    )
+    assert all(org_rng == rng for org_rng, rng in zip(org_second_map_ranges, second_map_ranges)), (
+        "The transformation modified the range of the second map."
+    )
     assert all(
         t_rng == s_rng for t_rng, s_rng in zip(trivial_map_ranges, second_map_ranges, strict=True)
-    ), "Expected that the ranges are the same; trivial '{trivial_map_ranges}'; second '{second_map_ranges}'."
-    assert (
-        trivial_map_params == second_map_params
-    ), f"Expected the trivial map to have parameters '{second_map_params}', but it had '{trivial_map_params}'."
+    ), (
+        "Expected that the ranges are the same; trivial '{trivial_map_ranges}'; second '{second_map_ranges}'."
+    )
+    assert trivial_map_params == second_map_params, (
+        f"Expected the trivial map to have parameters '{second_map_params}', but it had '{trivial_map_params}'."
+    )
 
     assert sdfg.is_valid()
diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py
index a08cf12a5a..4f52793c40 100644
--- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py
+++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_loop_blocking.py
@@ -350,9 +350,9 @@ def test_chained_access() -> None:
         if not isinstance(node, dace_nodes.MapEntry):
             continue
         if state.scope_dict()[node] is None:
-            assert (
-                outer_map is None
-            ), f"Found multiple outer maps, first '{outer_map}', second '{node}'."
+            assert outer_map is None, (
+                f"Found multiple outer maps, first '{outer_map}', second '{node}'."
+            )
             outer_map = node
     assert outer_map is not None, "Could not found the outer map."
     assert len(outer_map.map.params) == 2
diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_order.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_order.py
index 762040e20d..f0426591fd 100644
--- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_order.py
+++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_map_order.py
@@ -44,19 +44,19 @@ def _perform_reorder_test(
     new_map_params = map_entry.map.params.copy()
 
     if len(expected_order) == 0:
-        assert (
-            apply_count == 0
-        ), f"Expected that the transformation was not applied. New map order: {map_entry.map.params}"
+        assert apply_count == 0, (
+            f"Expected that the transformation was not applied. New map order: {map_entry.map.params}"
+        )
         return
     else:
-        assert (
-            apply_count > 0
-        ), f"Expected that the transformation was applied. Old map order: {map_entry.map.params}; Expected order: {expected_order}"
+        assert apply_count > 0, (
+            f"Expected that the transformation was applied. Old map order: {map_entry.map.params}; Expected order: {expected_order}"
+        )
 
     assert len(expected_order) == len(new_map_params)
-    assert (
-        expected_order == new_map_params
-    ), f"Expected map order {expected_order} but got {new_map_params} instead."
+    assert expected_order == new_map_params, (
+        f"Expected map order {expected_order} but got {new_map_params} instead."
+ ) def _make_test_sdfg(map_params: list[str]) -> dace.SDFG: diff --git a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py index c89fe566c0..1e71234867 100644 --- a/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py +++ b/tests/next_tests/unit_tests/program_processor_tests/runners_tests/dace_tests/transformation_tests/test_strides.py @@ -110,9 +110,9 @@ def _make_strides_propagation_level2_sdfg() -> tuple[dace.SDFG, dace_nodes.Neste return sdfg, nsdfg -def _make_strides_propagation_level1_sdfg() -> ( - tuple[dace.SDFG, dace_nodes.NestedSDFG, dace_nodes.NestedSDFG] -): +def _make_strides_propagation_level1_sdfg() -> tuple[ + dace.SDFG, dace_nodes.NestedSDFG, dace_nodes.NestedSDFG +]: """Generates the level 1 SDFG (top) SDFG for `test_strides_propagation()`. Note that the SDFG is valid, but will be indeterminate. The only point of @@ -186,9 +186,9 @@ def test_strides_propagation_use_symbol_mapping(): exp_stride = f"{aname}_stride" actual_stride = adesc.strides[0] assert len(adesc.strides) == 1 - assert ( - str(actual_stride) == exp_stride - ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." + assert str(actual_stride) == exp_stride, ( + f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." + ) nsdfg = sdfg.parent_nsdfg_node if nsdfg is not None: @@ -218,9 +218,9 @@ def test_strides_propagation_use_symbol_mapping(): assert original_stride in nsdfg.symbol_mapping assert str(nsdfg.symbol_mapping[original_stride]) == target_symbol assert len(adesc.strides) == 1 - assert ( - str(adesc.strides[0]) == original_stride - ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." + assert str(adesc.strides[0]) == original_stride, ( + f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." + ) # Now we also propagate `c` thus now all data descriptors have the same stride gtx_transformations.gt_propagate_strides_of(sdfg_level1, "c1", ignore_symbol_mapping=False) @@ -229,14 +229,14 @@ def test_strides_propagation_use_symbol_mapping(): for aname, adesc in sdfg.arrays.items(): nsdfg = sdfg.parent_nsdfg_node original_stride = f"{aname}_stride" - target_symbol = f"{aname[0]}{level-1}_stride" + target_symbol = f"{aname[0]}{level - 1}_stride" if nsdfg is not None: assert original_stride in nsdfg.symbol_mapping assert str(nsdfg.symbol_mapping[original_stride]) == target_symbol assert len(adesc.strides) == 1 - assert ( - str(adesc.strides[0]) == original_stride - ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." + assert str(adesc.strides[0]) == original_stride, ( + f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." + ) def test_strides_propagation_ignore_symbol_mapping(): @@ -249,9 +249,9 @@ def test_strides_propagation_ignore_symbol_mapping(): exp_stride = f"{aname}_stride" actual_stride = adesc.strides[0] assert len(adesc.strides) == 1 - assert ( - str(actual_stride) == exp_stride - ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." + assert str(actual_stride) == exp_stride, ( + f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." 
+ ) nsdfg = sdfg.parent_nsdfg_node if nsdfg is not None: @@ -275,9 +275,9 @@ def test_strides_propagation_ignore_symbol_mapping(): else: exp_stride = f"{aname[0]}1_stride" assert len(adesc.strides) == 1 - assert ( - str(adesc.strides[0]) == exp_stride - ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." + assert str(adesc.strides[0]) == exp_stride, ( + f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." + ) nsdfg = sdfg.parent_nsdfg_node if nsdfg is not None: @@ -292,9 +292,9 @@ def test_strides_propagation_ignore_symbol_mapping(): exp_stride = f"{aname[0]}1_stride" original_stride = f"{aname}_stride" assert len(adesc.strides) == 1 - assert ( - str(adesc.strides[0]) == exp_stride - ), f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." + assert str(adesc.strides[0]) == exp_stride, ( + f"Expected that '{aname}' has strides '{exp_stride}', but found '{adesc.strides}'." + ) nsdfg = sdfg.parent_nsdfg_node if nsdfg is not None: diff --git a/tests/next_tests/unit_tests/test_common.py b/tests/next_tests/unit_tests/test_common.py index 8f46fc7ce1..6e196ab232 100644 --- a/tests/next_tests/unit_tests/test_common.py +++ b/tests/next_tests/unit_tests/test_common.py @@ -566,9 +566,9 @@ def test_domain_replace(index, named_ranges, domain, expected): assert new_domain == expected -def dimension_promotion_cases() -> ( - list[tuple[list[list[Dimension]], list[Dimension] | None, None | Pattern]] -): +def dimension_promotion_cases() -> list[ + tuple[list[list[Dimension]], list[Dimension] | None, None | Pattern] +]: raw_list = [ # list of list of dimensions, expected result, expected error message ([["I", "J"], ["I"]], ["I", "J"], None),