Skip to content

Commit

Permalink
running new ruff version
Browse files Browse the repository at this point in the history
  • Loading branch information
romanc committed Jan 23, 2025
1 parent c0d5fb7 commit 0e0af6e
Show file tree
Hide file tree
Showing 47 changed files with 295 additions and 279 deletions.
2 changes: 1 addition & 1 deletion src/gt4py/cartesian/gtc/cuir/cuir_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def maybe_const(s):
decl = symtable[node.name]
if isinstance(decl, cuir.Temporary) and decl.data_dims:
data_index_str = "+".join(
f"{index}*{int(np.prod(decl.data_dims[i + 1:], initial=1))}"
f"{index}*{int(np.prod(decl.data_dims[i + 1 :], initial=1))}"
for i, index in enumerate(data_index)
)
return f"{name}({offset})[{data_index_str}]"
Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/cartesian/gtc/dace/expansion_specification.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,9 +254,9 @@ def _populate_gpu_schedules(expansion_specification: List[ExpansionItem]):
tiled = True
break
if not tiled:
assert any(
isinstance(item, Map) for item in expansion_specification
), "needs at least one map to avoid dereferencing on CPU"
assert any(isinstance(item, Map) for item in expansion_specification), (
"needs at least one map to avoid dereferencing on CPU"
)
for es in expansion_specification:
if isinstance(es, Map):
if es.schedule is None:
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/cartesian/gtc/dace/symbol_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def get_axis_bound_diff_str(axis_bound1, axis_bound2, var_name: str):
var = var_name
else:
var = ""
return f"{sign}({var}{axis_bound1.offset-axis_bound2.offset:+d})"
return f"{sign}({var}{axis_bound1.offset - axis_bound2.offset:+d})"


@lru_cache(maxsize=None)
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/cartesian/gtc/gtcpp/gtcpp_codegen.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ def visit_AccessorRef(
temp = temp_decls[accessor_ref.name]
data_index = "+".join(
[
f"{self.visit(index, in_data_index=True, **kwargs)}*{int(np.prod(temp.data_dims[i+1:], initial=1))}"
f"{self.visit(index, in_data_index=True, **kwargs)}*{int(np.prod(temp.data_dims[i + 1 :], initial=1))}"
for i, index in enumerate(accessor_ref.data_index)
]
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ def visit_FieldAccess(
) -> Union[oir.FieldAccess, oir.ScalarAccess]:
offsets = node.offset.to_dict()
if node.name in tmps_name_map:
assert (
offsets["i"] == offsets["j"] == offsets["k"] == 0
), "Non-zero offset in temporary that is replaced?!"
assert offsets["i"] == offsets["j"] == offsets["k"] == 0, (
"Non-zero offset in temporary that is replaced?!"
)
return oir.ScalarAccess(name=tmps_name_map[node.name], dtype=node.dtype)
return self.generic_visit(node, tmps_name_map=tmps_name_map, **kwargs)

Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/cartesian/stencil_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,9 +502,9 @@ def _normalize_origins(
if field_origin is not None:
field_origin_ndim = len(field_origin)
if field_origin_ndim != field_info.ndim:
assert (
field_origin_ndim == field_info.domain_ndim
), f"Invalid origin specification ({field_origin}) for '{name}' field."
assert field_origin_ndim == field_info.domain_ndim, (
f"Invalid origin specification ({field_origin}) for '{name}' field."
)
origin[name] = (*field_origin, *((0,) * len(field_info.data_dims)))

elif all_origin is not None:
Expand Down
17 changes: 10 additions & 7 deletions src/gt4py/cartesian/testing/suites.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,14 +268,14 @@ def _validate_new_args(cls, cls_name, cls_dict):
assert isinstance(cls_dict["symbols"], collections.abc.Mapping), "Invalid 'symbols' mapping"

# Check domain and ndims
assert 1 <= len(domain_range) <= 3 and all(
len(d) == 2 for d in domain_range
), "Invalid 'domain_range' definition"
assert 1 <= len(domain_range) <= 3 and all(len(d) == 2 for d in domain_range), (
"Invalid 'domain_range' definition"
)

if any(cls_name.endswith(suffix) for suffix in ("1D", "2D", "3D")):
assert cls_dict["ndims"] == int(
cls_name[-2:-1]
), "Suite name does not match the actual 'ndims'"
assert cls_dict["ndims"] == int(cls_name[-2:-1]), (
"Suite name does not match the actual 'ndims'"
)

# Check dtypes
assert isinstance(
Expand Down Expand Up @@ -386,7 +386,10 @@ class StencilTestSuite(metaclass=SuiteMeta):
.. code-block:: python
{"float_symbols": (np.float32, np.float64), "int_symbols": (int, np.int_, np.int64)}
{
"float_symbols": (np.float32, np.float64),
"int_symbols": (int, np.int_, np.int64),
}
domain_range : `Sequence` of pairs like `((int, int), (int, int) ... )`
Required class attribute.
Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/cartesian/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ def standardize_dtype_dict(dtypes):
dtypes as 1-tuples)
"""
assert isinstance(dtypes, collections.abc.Mapping)
assert all(
(isinstance(k, str) or gt_utils.is_iterable_of(k, str)) for k in dtypes.keys()
), "Invalid key in 'dtypes'."
assert all((isinstance(k, str) or gt_utils.is_iterable_of(k, str)) for k in dtypes.keys()), (
"Invalid key in 'dtypes'."
)
assert all(
(
isinstance(k, (type, np.dtype))
Expand Down
5 changes: 3 additions & 2 deletions src/gt4py/cartesian/utils/attrib.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,9 @@ def _make_attrs_class_wrapper(cls):
for name, member in extra_members.items():
if name in cls.__dict__.keys():
raise ValueError(
"Name clashing with a existing '{name}' member"
" of the decorated class ".format(name=name)
"Name clashing with a existing '{name}' member of the decorated class ".format(
name=name
)
)
setattr(cls, name, member)

Expand Down
8 changes: 4 additions & 4 deletions src/gt4py/eve/datamodels/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,7 +212,7 @@ def _field_type_validator_factory(type_annotation: TypeAnnotation, name: str) ->
else:
simple_validator = factory(type_annotation, name, required=True)
return ValidatorAdapter(
simple_validator, f"{getattr(simple_validator,'__name__', 'TypeValidator')}"
simple_validator, f"{getattr(simple_validator, '__name__', 'TypeValidator')}"
)

return _field_type_validator_factory
Expand Down Expand Up @@ -915,9 +915,9 @@ def __attrs_post_init__(self: DataModel) -> None:
return __attrs_post_init__


def _make_devtools_pretty() -> (
Callable[[DataModel, Callable[[Any], Any]], Generator[Any, None, None]]
):
def _make_devtools_pretty() -> Callable[
[DataModel, Callable[[Any], Any]], Generator[Any, None, None]
]:
def __pretty__(
self: DataModel, fmt: Callable[[Any], Any], **kwargs: Any
) -> Generator[Any, None, None]:
Expand Down
4 changes: 3 additions & 1 deletion src/gt4py/eve/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -374,7 +374,9 @@ def register_subclasses(*subclasses: Type) -> Callable[[Type], Type]:
>>> @register_subclasses(MyVirtualSubclassA, MyVirtualSubclassB)
... class MyBaseClass(abc.ABC):
... pass
>>> issubclass(MyVirtualSubclassA, MyBaseClass) and issubclass(MyVirtualSubclassB, MyBaseClass)
>>> issubclass(MyVirtualSubclassA, MyBaseClass) and issubclass(
... MyVirtualSubclassB, MyBaseClass
... )
True
"""
Expand Down
4 changes: 3 additions & 1 deletion src/gt4py/next/constructors.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,9 @@ def empty(
>>> from gt4py._core import definitions as core_defs
>>> JDim = gtx.Dimension("J")
>>> b = gtx.empty({IDim: 3, JDim: 3}, int, device=core_defs.Device(core_defs.DeviceType.CPU, 0))
>>> b = gtx.empty(
... {IDim: 3, JDim: 3}, int, device=core_defs.Device(core_defs.DeviceType.CPU, 0)
... )
>>> b.shape
(3, 3)
"""
Expand Down
5 changes: 4 additions & 1 deletion src/gt4py/next/embedded/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,10 @@ def restrict_to_intersection(
... common.domain({I: (1, 3), J: (0, 3)}),
... ignore_dims=J,
... )
>>> assert res == (common.domain({I: (1, 3), J: (1, 2)}), common.domain({I: (1, 3), J: (0, 3)}))
>>> assert res == (
... common.domain({I: (1, 3), J: (1, 2)}),
... common.domain({I: (1, 3), J: (0, 3)}),
... )
"""
ignore_dims_tuple = ignore_dims if isinstance(ignore_dims, tuple) else (ignore_dims,)
intersection_without_ignore_dims = domain_intersection(
Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/ffront/foast_passes/type_deduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ def visit_IfStmt(self, node: foast.IfStmt, **kwargs: Any) -> foast.IfStmt:
if not isinstance(new_node.condition.type, ts.ScalarType):
raise errors.DSLError(
node.location,
"Condition for 'if' must be scalar, " f"got '{new_node.condition.type}' instead.",
f"Condition for 'if' must be scalar, got '{new_node.condition.type}' instead.",
)

if new_node.condition.type.kind != ts.ScalarKind.BOOL:
Expand Down
3 changes: 1 addition & 2 deletions src/gt4py/next/ffront/foast_to_gtir.py
Original file line number Diff line number Diff line change
Expand Up @@ -275,8 +275,7 @@ def _visit_shift(self, node: foast.Call, **kwargs: Any) -> itir.Expr:
)(current_expr)
# `field(Dim + idx)`
case foast.BinOp(
op=dialect_ast_enums.BinaryOperator.ADD
| dialect_ast_enums.BinaryOperator.SUB,
op=dialect_ast_enums.BinaryOperator.ADD | dialect_ast_enums.BinaryOperator.SUB,
left=foast.Name(id=dimension), # TODO(tehrengruber): use type of lhs
right=foast.Constant(value=offset_index),
):
Expand Down
4 changes: 3 additions & 1 deletion src/gt4py/next/ffront/foast_to_past.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,9 @@ class OperatorToProgram(workflow.Workflow[AOT_FOP, AOT_PRG]):
... column_axis=None,
... )
>>> copy_program = op_to_prog(toolchain.CompilableProgram(copy.foast_stage, compile_time_args))
>>> copy_program = op_to_prog(
... toolchain.CompilableProgram(copy.foast_stage, compile_time_args)
... )
>>> print(copy_program.data.past_node.id)
__field_operator_copy
Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/next/ffront/func_to_foast.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,9 +317,9 @@ def visit_AnnAssign(self, node: ast.AnnAssign, **kwargs: Any) -> foast.Assign:
raise errors.DSLError(self.get_location(node), "Can only assign to names.")

if node.annotation is not None:
assert isinstance(
node.annotation, ast.Constant
), "Annotations should be ast.Constant(string). Use StringifyAnnotationsPass"
assert isinstance(node.annotation, ast.Constant), (
"Annotations should be ast.Constant(string). Use StringifyAnnotationsPass"
)
context = {**fbuiltins.BUILTINS, **self.closure_vars}
annotation = eval(node.annotation.value, context)
target_type = type_translation.from_type_hint(annotation, globalns=context)
Expand Down
10 changes: 6 additions & 4 deletions src/gt4py/next/ffront/past_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,9 @@ def past_to_gtir(inp: AOT_PRG) -> stages.CompilableProgram:
... return a
>>> @gtx.program
... def copy_program(a: gtx.Field[[IDim], gtx.float32], out: gtx.Field[[IDim], gtx.float32]):
... def copy_program(
... a: gtx.Field[[IDim], gtx.float32], out: gtx.Field[[IDim], gtx.float32]
... ):
... copy(a, out=out)
>>> compile_time_args = arguments.CompileTimeArgs(
Expand Down Expand Up @@ -460,9 +462,9 @@ def _visit_stencil_call_out_arg(

field_slice = None
if isinstance(first_field, past.Subscript):
assert all(
isinstance(field, past.Subscript) for field in flattened
), "Incompatible field in tuple: either all fields or no field must be sliced."
assert all(isinstance(field, past.Subscript) for field in flattened), (
"Incompatible field in tuple: either all fields or no field must be sliced."
)
assert all(
concepts.eq_nonlocated(
first_field.slice_,
Expand Down
2 changes: 2 additions & 0 deletions src/gt4py/next/iterator/builtins.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
# Please, refer to the LICENSE file in the root directory.
# SPDX-License-Identifier: BSD-3-Clause

# ruff: noqa: A005 Module `builtins` shadows a Python standard-library module

from gt4py.next.iterator.dispatcher import Dispatcher


Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/next/iterator/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,9 +642,9 @@ def _is_list_of_complete_offsets(
def group_offsets(*offsets: OffsetPart) -> list[CompleteOffset]:
assert len(offsets) % 2 == 0
complete_offsets = [*zip(offsets[::2], offsets[1::2])]
assert _is_list_of_complete_offsets(
complete_offsets
), f"Invalid sequence of offset parts: {offsets}"
assert _is_list_of_complete_offsets(complete_offsets), (
f"Invalid sequence of offset parts: {offsets}"
)
return complete_offsets


Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/next/iterator/transforms/collapse_tuple.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,9 +284,9 @@ def transform_collapse_tuple_get_make_tuple(
assert type_info.is_integer(node.args[0].type)
make_tuple_call = node.args[1]
idx = int(node.args[0].value)
assert idx < len(
make_tuple_call.args
), f"Index {idx} is out of bounds for tuple of size {len(make_tuple_call.args)}"
assert idx < len(make_tuple_call.args), (
f"Index {idx} is out of bounds for tuple of size {len(make_tuple_call.args)}"
)
return node.args[1].args[idx]
return None

Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/next/iterator/transforms/cse.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,9 +418,9 @@ def apply(
assert within_stencil is None
within_stencil = False
else:
assert (
within_stencil is not None
), "The expression's context must be specified using `within_stencil`."
assert within_stencil is not None, (
"The expression's context must be specified using `within_stencil`."
)

offset_provider_type = offset_provider_type or {}
node = itir_type_inference.infer(
Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/next/iterator/transforms/infer_domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,9 +466,9 @@ def infer_program(
See :func:`infer_expr` for more details.
"""
assert (
not program.function_definitions
), "Domain propagation does not support function definitions."
assert not program.function_definitions, (
"Domain propagation does not support function definitions."
)

return itir.Program(
id=program.id,
Expand Down
7 changes: 4 additions & 3 deletions src/gt4py/next/iterator/transforms/power_unrolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,18 +53,19 @@ def visit_FunCall(self, node: ir.FunCall):
remainder = exponent

# Build target expression
ret = im.ref(f"power_{2 ** pow_max}")
ret = im.ref(f"power_{2**pow_max}")
remainder -= 2**pow_cur
while remainder > 0:
pow_cur = _compute_integer_power_of_two(remainder)
remainder -= 2**pow_cur

ret = im.multiplies_(ret, f"power_{2 ** pow_cur}")
ret = im.multiplies_(ret, f"power_{2**pow_cur}")

# Nest target expression to avoid multiple redundant evaluations
for i in range(pow_max, 0, -1):
ret = im.let(
f"power_{2 ** i}", im.multiplies_(f"power_{2**(i-1)}", f"power_{2**(i-1)}")
f"power_{2**i}",
im.multiplies_(f"power_{2 ** (i - 1)}", f"power_{2 ** (i - 1)}"),
)(ret)
ret = im.let("power_1", base)(ret)

Expand Down
6 changes: 3 additions & 3 deletions src/gt4py/next/iterator/transforms/remap_symbols.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@ def visit_Lambda(self, node: ir.Lambda, *, symbol_map: Dict[str, ir.Node]):
return ir.Lambda(params=node.params, expr=self.visit(node.expr, symbol_map=new_symbol_map))

def generic_visit(self, node: ir.Node, **kwargs: Any): # type: ignore[override]
assert isinstance(node, SymbolTableTrait) == isinstance(
node, ir.Lambda
), "found unexpected new symbol scope"
assert isinstance(node, SymbolTableTrait) == isinstance(node, ir.Lambda), (
"found unexpected new symbol scope"
)
return super().generic_visit(node, **kwargs)


Expand Down
8 changes: 4 additions & 4 deletions src/gt4py/next/iterator/type_system/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,9 +230,9 @@ def __call__(
*args: type_synthesizer.TypeOrTypeSynthesizer,
offset_provider_type: common.OffsetProviderType,
) -> Union[ts.TypeSpec, ObservableTypeSynthesizer]:
assert all(
isinstance(arg, (ts.TypeSpec, ObservableTypeSynthesizer)) for arg in args
), "ObservableTypeSynthesizer can only be used with arguments that are TypeSpec or ObservableTypeSynthesizer"
assert all(isinstance(arg, (ts.TypeSpec, ObservableTypeSynthesizer)) for arg in args), (
"ObservableTypeSynthesizer can only be used with arguments that are TypeSpec or ObservableTypeSynthesizer"
)

return_type_or_synthesizer = self.type_synthesizer(
*args, offset_provider_type=offset_provider_type
Expand Down Expand Up @@ -644,7 +644,7 @@ def visit_FunCall(
return result

def visit_Node(self, node: itir.Node, **kwargs):
raise NotImplementedError(f"No type rule for nodes of type " f"'{type(node).__name__}'.")
raise NotImplementedError(f"No type rule for nodes of type '{type(node).__name__}'.")


infer = ITIRTypeInference.apply
Expand Down
10 changes: 4 additions & 6 deletions src/gt4py/next/otf/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,12 +98,10 @@ def jit_to_aot_args(
return CompileTimeArgs.from_concrete_no_size(*inp.args, **inp.kwargs)


def adapted_jit_to_aot_args_factory() -> (
workflow.Workflow[
toolchain.CompilableProgram[DATA_T, JITArgs],
toolchain.CompilableProgram[DATA_T, CompileTimeArgs],
]
):
def adapted_jit_to_aot_args_factory() -> workflow.Workflow[
toolchain.CompilableProgram[DATA_T, JITArgs],
toolchain.CompilableProgram[DATA_T, CompileTimeArgs],
]:
"""Wrap `jit_to_aot` into a workflow adapter to fit into backend transform workflows."""
return toolchain.ArgsOnlyAdapter(jit_to_aot_args)

Expand Down
2 changes: 1 addition & 1 deletion src/gt4py/next/otf/binding/cpp_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def render_function_declaration(function: interface.Function, body: str) -> str:
}}"""
if template_params:
return f"""
template <{', '.join(template_params)}>
template <{", ".join(template_params)}>
{rendered_decl}
""".strip()
return rendered_decl
Expand Down
Loading

0 comments on commit 0e0af6e

Please sign in to comment.