diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3259a74f38..862aa46d66 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -62,7 +62,7 @@ repos: ## version = re.search('black==([0-9\.]*)', open("constraints.txt").read())[1] ## print(f"rev: '{version}' # version from constraints.txt") ##]]] - rev: '23.12.1' # version from constraints.txt + rev: '24.1.1' # version from constraints.txt ##[[[end]]] hooks: - id: black @@ -162,7 +162,7 @@ repos: ##]]] - astunparse==1.6.3 - attrs==23.2.0 - - black==23.12.1 + - black==24.1.1 - boltons==23.1.1 - cached-property==1.5.2 - click==8.1.7 @@ -175,7 +175,7 @@ repos: - importlib-resources==6.1.1 - jinja2==3.1.3 - lark==1.1.9 - - mako==1.3.1 + - mako==1.3.2 - nanobind==1.8.0 - ninja==1.11.1.1 - numpy==1.24.4 diff --git a/constraints.txt b/constraints.txt index 343615b421..1aa47d8340 100644 --- a/constraints.txt +++ b/constraints.txt @@ -11,7 +11,7 @@ astunparse==1.6.3 ; python_version < "3.9" # via dace, gt4py (pyproject.toml) attrs==23.2.0 # via flake8-bugbear, flake8-eradicate, gt4py (pyproject.toml), hypothesis, jsonschema, referencing babel==2.14.0 # via sphinx backcall==0.2.0 # via ipython -black==23.12.1 # via gt4py (pyproject.toml) +black==24.1.1 # via gt4py (pyproject.toml) blinker==1.7.0 # via flask boltons==23.1.1 # via gt4py (pyproject.toml) build==1.0.3 # via pip-tools @@ -29,8 +29,8 @@ cogapp==3.3.0 # via -r requirements-dev.in colorama==0.4.6 # via tox comm==0.2.1 # via ipykernel contourpy==1.1.1 # via matplotlib -coverage==7.4.0 # via -r requirements-dev.in, coverage, pytest-cov -cryptography==42.0.1 # via types-paramiko, types-pyopenssl, types-redis +coverage==7.4.1 # via -r requirements-dev.in, coverage, pytest-cov +cryptography==42.0.2 # via types-paramiko, types-pyopenssl, types-redis cycler==0.12.1 # via matplotlib cytoolz==0.12.3 # via gt4py (pyproject.toml) dace==0.15.1 # via gt4py (pyproject.toml) @@ -39,7 +39,7 @@ debugpy==1.8.0 # via ipykernel decorator==5.1.1 # via ipython deepdiff==6.7.1 # via gt4py (pyproject.toml) devtools==0.12.2 # via gt4py (pyproject.toml) -dill==0.3.7 # via dace +dill==0.3.8 # via dace distlib==0.3.8 # via virtualenv docutils==0.20.1 # via restructuredtext-lint, sphinx, sphinx-rtd-theme eradicate==2.3.0 # via flake8-eradicate @@ -47,7 +47,7 @@ exceptiongroup==1.2.0 # via hypothesis, pytest execnet==2.0.2 # via pytest-cache, pytest-xdist executing==2.0.1 # via devtools, stack-data factory-boy==3.3.0 # via -r requirements-dev.in, pytest-factoryboy -faker==22.5.1 # via factory-boy +faker==22.6.0 # via factory-boy fastjsonschema==2.19.1 # via nbformat filelock==3.13.1 # via tox, virtualenv flake8==7.0.0 # via -r requirements-dev.in, flake8-bugbear, flake8-builtins, flake8-debugger, flake8-docstrings, flake8-eradicate, flake8-mutable, flake8-pyproject, flake8-rst-docstrings @@ -64,7 +64,7 @@ fonttools==4.47.2 # via matplotlib fparser==0.1.3 # via dace frozendict==2.4.0 # via gt4py (pyproject.toml) gridtools-cpp==2.3.1 # via gt4py (pyproject.toml) -hypothesis==6.97.0 # via -r requirements-dev.in, gt4py (pyproject.toml) +hypothesis==6.97.3 # via -r requirements-dev.in, gt4py (pyproject.toml) identify==2.5.33 # via pre-commit idna==3.6 # via requests imagesize==1.4.1 # via sphinx @@ -85,7 +85,7 @@ jupyter-core==5.7.1 # via ipykernel, jupyter-client, nbformat jupytext==1.16.1 # via -r requirements-dev.in kiwisolver==1.4.5 # via matplotlib lark==1.1.9 # via gt4py (pyproject.toml) -mako==1.3.1 # via gt4py (pyproject.toml) +mako==1.3.2 # via gt4py (pyproject.toml) 
markdown-it-py==3.0.0 # via jupytext, mdit-py-plugins markupsafe==2.1.4 # via jinja2, mako, werkzeug matplotlib==3.7.4 # via -r requirements-dev.in @@ -99,7 +99,7 @@ mypy-extensions==1.0.0 # via black, mypy nanobind==1.8.0 # via gt4py (pyproject.toml) nbclient==0.6.8 # via nbmake nbformat==5.9.2 # via jupytext, nbclient, nbmake -nbmake==1.4.6 # via -r requirements-dev.in +nbmake==1.5.0 # via -r requirements-dev.in nest-asyncio==1.6.0 # via ipykernel, nbclient networkx==3.1 # via dace ninja==1.11.1.1 # via gt4py (pyproject.toml) @@ -115,7 +115,7 @@ pillow==10.2.0 # via matplotlib pip-tools==7.3.0 # via -r requirements-dev.in pipdeptree==2.13.2 # via -r requirements-dev.in pkgutil-resolve-name==1.3.10 # via jsonschema -platformdirs==4.1.0 # via black, jupyter-core, tox, virtualenv +platformdirs==4.2.0 # via black, jupyter-core, tox, virtualenv pluggy==1.4.0 # via pytest, tox ply==3.11 # via dace pre-commit==3.5.0 # via -r requirements-dev.in @@ -132,20 +132,20 @@ pygments==2.17.2 # via -r requirements-dev.in, devtools, flake8-rst-doc pyparsing==3.1.1 # via matplotlib pyproject-api==1.6.1 # via tox pyproject-hooks==1.0.0 # via build -pytest==7.4.4 # via -r requirements-dev.in, gt4py (pyproject.toml), nbmake, pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist +pytest==8.0.0 # via -r requirements-dev.in, gt4py (pyproject.toml), nbmake, pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist pytest-cache==1.0 # via -r requirements-dev.in pytest-cov==4.1.0 # via -r requirements-dev.in pytest-factoryboy==2.6.0 # via -r requirements-dev.in pytest-xdist==3.5.0 # via -r requirements-dev.in, pytest-xdist python-dateutil==2.8.2 # via faker, jupyter-client, matplotlib -pytz==2023.3.post1 # via babel +pytz==2023.4 # via babel pyyaml==6.0.1 # via dace, jupytext, pre-commit pyzmq==25.1.2 # via ipykernel, jupyter-client -referencing==0.32.1 # via jsonschema, jsonschema-specifications +referencing==0.33.0 # via jsonschema, jsonschema-specifications requests==2.31.0 # via dace, sphinx restructuredtext-lint==1.4.0 # via flake8-rst-docstrings rpds-py==0.17.1 # via jsonschema, referencing -ruff==0.1.14 # via -r requirements-dev.in +ruff==0.1.15 # via -r requirements-dev.in setuptools-scm==8.0.4 # via fparser six==1.16.0 # via asttokens, astunparse, python-dateutil snowballstemmer==2.2.0 # via pydocstyle, sphinx @@ -193,7 +193,7 @@ types-datetimerange==2.0.0.6 # via types-all types-decorator==5.1.8.20240106 # via types-all types-deprecated==1.2.9.20240106 # via types-all types-docopt==0.6.11.4 # via types-all -types-docutils==0.20.0.20240125 # via types-all +types-docutils==0.20.0.20240126 # via types-all types-emoji==2.1.0.3 # via types-all types-enum34==1.1.8 # via types-all types-fb303==1.0.0 # via types-all, types-scribe @@ -209,7 +209,7 @@ types-itsdangerous==1.1.6 # via types-all types-jack-client==0.5.10.20240106 # via types-all types-jinja2==2.11.9 # via types-all, types-flask types-kazoo==0.1.3 # via types-all -types-markdown==3.5.0.20240106 # via types-all +types-markdown==3.5.0.20240129 # via types-all types-markupsafe==1.1.10 # via types-all, types-jinja2 types-maxminddb==1.5.0 # via types-all, types-geoip2 types-mock==5.1.0.20240106 # via types-all @@ -222,20 +222,20 @@ types-pathlib2==2.3.0 # via types-all types-pillow==10.2.0.20240125 # via types-all types-pkg-resources==0.1.3 # via types-all types-polib==1.2.0.20240115 # via types-all -types-protobuf==4.24.0.20240106 # via types-all +types-protobuf==4.24.0.20240129 # via types-all types-pyaudio==0.2.16.20240106 # via types-all 
types-pycurl==7.45.2.20240106 # via types-all types-pyfarmhash==0.3.1.2 # via types-all types-pyjwt==1.7.1 # via types-all types-pymssql==2.1.0 # via types-all types-pymysql==1.1.0.1 # via types-all -types-pyopenssl==23.3.0.20240106 # via types-redis +types-pyopenssl==24.0.0.20240130 # via types-redis types-pyrfc3339==1.1.1.5 # via types-all types-pysftp==0.2.17.20240106 # via types-all types-python-dateutil==2.8.19.20240106 # via types-all, types-datetimerange types-python-gflags==3.1.7.3 # via types-all -types-python-slugify==8.0.0.3 # via types-all -types-pytz==2023.3.1.1 # via types-all, types-tzlocal +types-python-slugify==8.0.2.20240127 # via types-all +types-pytz==2023.4.0.20240130 # via types-all, types-tzlocal types-pyvmomi==8.0.0.6 # via types-all types-pyyaml==6.0.12.12 # via types-all types-redis==4.6.0.20240106 # via types-all @@ -258,7 +258,7 @@ types-waitress==2.1.4.20240106 # via types-all types-werkzeug==1.0.9 # via types-all, types-flask types-xxhash==3.0.5.2 # via types-all typing-extensions==4.5.0 # via black, faker, gt4py (pyproject.toml), ipython, mypy, pytest-factoryboy, setuptools-scm -urllib3==2.1.0 # via requests, types-requests +urllib3==2.2.0 # via requests, types-requests virtualenv==20.25.0 # via pre-commit, tox wcwidth==0.2.13 # via prompt-toolkit websockets==12.0 # via dace diff --git a/min-extra-requirements-test.txt b/min-extra-requirements-test.txt index 3c6cd3d9ff..7200018616 100644 --- a/min-extra-requirements-test.txt +++ b/min-extra-requirements-test.txt @@ -61,7 +61,7 @@ pipdeptree==2.3 pre-commit==2.17 psutil==5.0 pybind11==2.5 -pygments==2.7 +pygments==2.7.3 pytest-cache==1.0 pytest-cov==2.8 pytest-factoryboy==2.0.3 diff --git a/min-requirements-test.txt b/min-requirements-test.txt index d2ebaba331..259663ffc4 100644 --- a/min-requirements-test.txt +++ b/min-requirements-test.txt @@ -58,7 +58,7 @@ pipdeptree==2.3 pre-commit==2.17 psutil==5.0 pybind11==2.5 -pygments==2.7 +pygments==2.7.3 pytest-cache==1.0 pytest-cov==2.8 pytest-factoryboy==2.0.3 diff --git a/pyproject.toml b/pyproject.toml index 51cfc267d5..5a1618fc49 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -162,7 +162,9 @@ ignore = [ 'DAR', # Disable dargling errors by default 'E203', # Whitespace before ':' (black formatter breaks this sometimes) 'E501', # Line too long (using Bugbear's B950 warning) - 'W503' # Line break occurred before a binary operator + 'W503', # Line break occurred before a binary operator + 'E701', # Multiple statements on one line, see https://github.com/psf/black/issues/3887 + 'E704' # Multiple statements on one line, see https://github.com/psf/black/issues/3887 ] max-complexity = 15 max-line-length = 100 # It should be the same as in `tool.black.line-length` above diff --git a/requirements-dev.in b/requirements-dev.in index 59ddb733d0..4bb05ecbc5 100644 --- a/requirements-dev.in +++ b/requirements-dev.in @@ -32,7 +32,7 @@ pipdeptree>=2.3 pip-tools>=6.10 pre-commit>=2.17 psutil>=5.0 -pygments>=2.7 +pygments>=2.7.3 pytest-cache>=1.0 pytest-cov>=2.8 pytest-factoryboy>=2.0.3 diff --git a/requirements-dev.txt b/requirements-dev.txt index abfa99a2ae..e54e56ad62 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -11,7 +11,7 @@ astunparse==1.6.3 ; python_version < "3.9" # via dace, gt4py (pyproject.toml) attrs==23.2.0 # via flake8-bugbear, flake8-eradicate, gt4py (pyproject.toml), hypothesis, jsonschema, referencing babel==2.14.0 # via sphinx backcall==0.2.0 # via ipython -black==23.12.1 # via gt4py (pyproject.toml) +black==24.1.1 # via gt4py 
(pyproject.toml) blinker==1.7.0 # via flask boltons==23.1.1 # via gt4py (pyproject.toml) build==1.0.3 # via pip-tools @@ -29,8 +29,8 @@ cogapp==3.3.0 # via -r requirements-dev.in colorama==0.4.6 # via tox comm==0.2.1 # via ipykernel contourpy==1.1.1 # via matplotlib -coverage[toml]==7.4.0 # via -r requirements-dev.in, coverage, pytest-cov -cryptography==42.0.1 # via types-paramiko, types-pyopenssl, types-redis +coverage[toml]==7.4.1 # via -r requirements-dev.in, coverage, pytest-cov +cryptography==42.0.2 # via types-paramiko, types-pyopenssl, types-redis cycler==0.12.1 # via matplotlib cytoolz==0.12.3 # via gt4py (pyproject.toml) dace==0.15.1 # via gt4py (pyproject.toml) @@ -39,7 +39,7 @@ debugpy==1.8.0 # via ipykernel decorator==5.1.1 # via ipython deepdiff==6.7.1 # via gt4py (pyproject.toml) devtools==0.12.2 # via gt4py (pyproject.toml) -dill==0.3.7 # via dace +dill==0.3.8 # via dace distlib==0.3.8 # via virtualenv docutils==0.20.1 # via restructuredtext-lint, sphinx, sphinx-rtd-theme eradicate==2.3.0 # via flake8-eradicate @@ -47,7 +47,7 @@ exceptiongroup==1.2.0 # via hypothesis, pytest execnet==2.0.2 # via pytest-cache, pytest-xdist executing==2.0.1 # via devtools, stack-data factory-boy==3.3.0 # via -r requirements-dev.in, pytest-factoryboy -faker==22.5.1 # via factory-boy +faker==22.6.0 # via factory-boy fastjsonschema==2.19.1 # via nbformat filelock==3.13.1 # via tox, virtualenv flake8==7.0.0 # via -r requirements-dev.in, flake8-bugbear, flake8-builtins, flake8-debugger, flake8-docstrings, flake8-eradicate, flake8-mutable, flake8-pyproject, flake8-rst-docstrings @@ -64,7 +64,7 @@ fonttools==4.47.2 # via matplotlib fparser==0.1.3 # via dace frozendict==2.4.0 # via gt4py (pyproject.toml) gridtools-cpp==2.3.1 # via gt4py (pyproject.toml) -hypothesis==6.97.0 # via -r requirements-dev.in, gt4py (pyproject.toml) +hypothesis==6.97.3 # via -r requirements-dev.in, gt4py (pyproject.toml) identify==2.5.33 # via pre-commit idna==3.6 # via requests imagesize==1.4.1 # via sphinx @@ -85,7 +85,7 @@ jupyter-core==5.7.1 # via ipykernel, jupyter-client, nbformat jupytext==1.16.1 # via -r requirements-dev.in kiwisolver==1.4.5 # via matplotlib lark==1.1.9 # via gt4py (pyproject.toml) -mako==1.3.1 # via gt4py (pyproject.toml) +mako==1.3.2 # via gt4py (pyproject.toml) markdown-it-py==3.0.0 # via jupytext, mdit-py-plugins markupsafe==2.1.4 # via jinja2, mako, werkzeug matplotlib==3.7.4 # via -r requirements-dev.in @@ -99,7 +99,7 @@ mypy-extensions==1.0.0 # via black, mypy nanobind==1.8.0 # via gt4py (pyproject.toml) nbclient==0.6.8 # via nbmake nbformat==5.9.2 # via jupytext, nbclient, nbmake -nbmake==1.4.6 # via -r requirements-dev.in +nbmake==1.5.0 # via -r requirements-dev.in nest-asyncio==1.6.0 # via ipykernel, nbclient networkx==3.1 # via dace ninja==1.11.1.1 # via gt4py (pyproject.toml) @@ -115,7 +115,7 @@ pillow==10.2.0 # via matplotlib pip-tools==7.3.0 # via -r requirements-dev.in pipdeptree==2.13.2 # via -r requirements-dev.in pkgutil-resolve-name==1.3.10 # via jsonschema -platformdirs==4.1.0 # via black, jupyter-core, tox, virtualenv +platformdirs==4.2.0 # via black, jupyter-core, tox, virtualenv pluggy==1.4.0 # via pytest, tox ply==3.11 # via dace pre-commit==3.5.0 # via -r requirements-dev.in @@ -132,20 +132,20 @@ pygments==2.17.2 # via -r requirements-dev.in, devtools, flake8-rst-doc pyparsing==3.1.1 # via matplotlib pyproject-api==1.6.1 # via tox pyproject-hooks==1.0.0 # via build -pytest==7.4.4 # via -r requirements-dev.in, gt4py (pyproject.toml), nbmake, pytest-cache, pytest-cov, 
pytest-factoryboy, pytest-xdist +pytest==8.0.0 # via -r requirements-dev.in, gt4py (pyproject.toml), nbmake, pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist pytest-cache==1.0 # via -r requirements-dev.in pytest-cov==4.1.0 # via -r requirements-dev.in pytest-factoryboy==2.6.0 # via -r requirements-dev.in pytest-xdist[psutil]==3.5.0 # via -r requirements-dev.in, pytest-xdist python-dateutil==2.8.2 # via faker, jupyter-client, matplotlib -pytz==2023.3.post1 # via babel +pytz==2023.4 # via babel pyyaml==6.0.1 # via dace, jupytext, pre-commit pyzmq==25.1.2 # via ipykernel, jupyter-client -referencing==0.32.1 # via jsonschema, jsonschema-specifications +referencing==0.33.0 # via jsonschema, jsonschema-specifications requests==2.31.0 # via dace, sphinx restructuredtext-lint==1.4.0 # via flake8-rst-docstrings rpds-py==0.17.1 # via jsonschema, referencing -ruff==0.1.14 # via -r requirements-dev.in +ruff==0.1.15 # via -r requirements-dev.in setuptools-scm==8.0.4 # via fparser six==1.16.0 # via asttokens, astunparse, python-dateutil snowballstemmer==2.2.0 # via pydocstyle, sphinx @@ -193,7 +193,7 @@ types-datetimerange==2.0.0.6 # via types-all types-decorator==5.1.8.20240106 # via types-all types-deprecated==1.2.9.20240106 # via types-all types-docopt==0.6.11.4 # via types-all -types-docutils==0.20.0.20240125 # via types-all +types-docutils==0.20.0.20240126 # via types-all types-emoji==2.1.0.3 # via types-all types-enum34==1.1.8 # via types-all types-fb303==1.0.0 # via types-all, types-scribe @@ -209,7 +209,7 @@ types-itsdangerous==1.1.6 # via types-all types-jack-client==0.5.10.20240106 # via types-all types-jinja2==2.11.9 # via types-all, types-flask types-kazoo==0.1.3 # via types-all -types-markdown==3.5.0.20240106 # via types-all +types-markdown==3.5.0.20240129 # via types-all types-markupsafe==1.1.10 # via types-all, types-jinja2 types-maxminddb==1.5.0 # via types-all, types-geoip2 types-mock==5.1.0.20240106 # via types-all @@ -222,20 +222,20 @@ types-pathlib2==2.3.0 # via types-all types-pillow==10.2.0.20240125 # via types-all types-pkg-resources==0.1.3 # via types-all types-polib==1.2.0.20240115 # via types-all -types-protobuf==4.24.0.20240106 # via types-all +types-protobuf==4.24.0.20240129 # via types-all types-pyaudio==0.2.16.20240106 # via types-all types-pycurl==7.45.2.20240106 # via types-all types-pyfarmhash==0.3.1.2 # via types-all types-pyjwt==1.7.1 # via types-all types-pymssql==2.1.0 # via types-all types-pymysql==1.1.0.1 # via types-all -types-pyopenssl==23.3.0.20240106 # via types-redis +types-pyopenssl==24.0.0.20240130 # via types-redis types-pyrfc3339==1.1.1.5 # via types-all types-pysftp==0.2.17.20240106 # via types-all types-python-dateutil==2.8.19.20240106 # via types-all, types-datetimerange types-python-gflags==3.1.7.3 # via types-all -types-python-slugify==8.0.0.3 # via types-all -types-pytz==2023.3.1.1 # via types-all, types-tzlocal +types-python-slugify==8.0.2.20240127 # via types-all +types-pytz==2023.4.0.20240130 # via types-all, types-tzlocal types-pyvmomi==8.0.0.6 # via types-all types-pyyaml==6.0.12.12 # via types-all types-redis==4.6.0.20240106 # via types-all @@ -258,7 +258,7 @@ types-waitress==2.1.4.20240106 # via types-all types-werkzeug==1.0.9 # via types-all, types-flask types-xxhash==3.0.5.2 # via types-all typing-extensions==4.5.0 # via black, faker, gt4py (pyproject.toml), ipython, mypy, pytest-factoryboy, setuptools-scm -urllib3==2.1.0 # via requests, types-requests +urllib3==2.2.0 # via requests, types-requests virtualenv==20.25.0 # via 
pre-commit, tox wcwidth==0.2.13 # via prompt-toolkit websockets==12.0 # via dace diff --git a/src/gt4py/_core/definitions.py b/src/gt4py/_core/definitions.py index 091fa77e3f..6237704f69 100644 --- a/src/gt4py/_core/definitions.py +++ b/src/gt4py/_core/definitions.py @@ -165,28 +165,23 @@ class DTypeKind(eve.StrEnum): @overload -def dtype_kind(sc_type: Type[BoolT]) -> Literal[DTypeKind.BOOL]: - ... +def dtype_kind(sc_type: Type[BoolT]) -> Literal[DTypeKind.BOOL]: ... @overload -def dtype_kind(sc_type: Type[IntT]) -> Literal[DTypeKind.INT]: - ... +def dtype_kind(sc_type: Type[IntT]) -> Literal[DTypeKind.INT]: ... @overload -def dtype_kind(sc_type: Type[UnsignedIntT]) -> Literal[DTypeKind.UINT]: - ... +def dtype_kind(sc_type: Type[UnsignedIntT]) -> Literal[DTypeKind.UINT]: ... @overload -def dtype_kind(sc_type: Type[FloatingT]) -> Literal[DTypeKind.FLOAT]: - ... +def dtype_kind(sc_type: Type[FloatingT]) -> Literal[DTypeKind.FLOAT]: ... @overload -def dtype_kind(sc_type: Type[ScalarT]) -> DTypeKind: - ... +def dtype_kind(sc_type: Type[ScalarT]) -> DTypeKind: ... def dtype_kind(sc_type: Type[ScalarT]) -> DTypeKind: @@ -360,8 +355,7 @@ class GTDimsInterface(Protocol): """ @property - def __gt_dims__(self) -> Tuple[str, ...]: - ... + def __gt_dims__(self) -> Tuple[str, ...]: ... class GTOriginInterface(Protocol): @@ -372,8 +366,7 @@ class GTOriginInterface(Protocol): """ @property - def __gt_origin__(self) -> Tuple[int, ...]: - ... + def __gt_origin__(self) -> Tuple[int, ...]: ... # -- Device representation -- @@ -443,61 +436,43 @@ def __iter__(self) -> Iterator[DeviceTypeT | int]: class NDArrayObject(Protocol): @property - def ndim(self) -> int: - ... + def ndim(self) -> int: ... @property - def shape(self) -> tuple[int, ...]: - ... + def shape(self) -> tuple[int, ...]: ... @property - def dtype(self) -> Any: - ... + def dtype(self) -> Any: ... - def astype(self, dtype: npt.DTypeLike) -> NDArrayObject: - ... + def astype(self, dtype: npt.DTypeLike) -> NDArrayObject: ... - def __getitem__(self, item: Any) -> NDArrayObject: - ... + def __getitem__(self, item: Any) -> NDArrayObject: ... - def __abs__(self) -> NDArrayObject: - ... + def __abs__(self) -> NDArrayObject: ... - def __neg__(self) -> NDArrayObject: - ... + def __neg__(self) -> NDArrayObject: ... - def __add__(self, other: NDArrayObject | Scalar) -> NDArrayObject: - ... + def __add__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... - def __radd__(self, other: Any) -> NDArrayObject: - ... + def __radd__(self, other: Any) -> NDArrayObject: ... - def __sub__(self, other: NDArrayObject | Scalar) -> NDArrayObject: - ... + def __sub__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... - def __rsub__(self, other: Any) -> NDArrayObject: - ... + def __rsub__(self, other: Any) -> NDArrayObject: ... - def __mul__(self, other: NDArrayObject | Scalar) -> NDArrayObject: - ... + def __mul__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... - def __rmul__(self, other: Any) -> NDArrayObject: - ... + def __rmul__(self, other: Any) -> NDArrayObject: ... - def __floordiv__(self, other: NDArrayObject | Scalar) -> NDArrayObject: - ... + def __floordiv__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... - def __rfloordiv__(self, other: Any) -> NDArrayObject: - ... + def __rfloordiv__(self, other: Any) -> NDArrayObject: ... - def __truediv__(self, other: NDArrayObject | Scalar) -> NDArrayObject: - ... + def __truediv__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... 
- def __rtruediv__(self, other: Any) -> NDArrayObject: - ... + def __rtruediv__(self, other: Any) -> NDArrayObject: ... - def __pow__(self, other: NDArrayObject | Scalar) -> NDArrayObject: - ... + def __pow__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... def __eq__(self, other: NDArrayObject | Scalar) -> NDArrayObject: # type: ignore[override] # mypy wants to return `bool` ... @@ -517,11 +492,8 @@ def __lt__(self, other: NDArrayObject | Scalar) -> NDArrayObject: # type: ignor def __le__(self, other: NDArrayObject | Scalar) -> NDArrayObject: # type: ignore[misc] # Forward operator is not callable ... - def __and__(self, other: NDArrayObject | Scalar) -> NDArrayObject: - ... + def __and__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... - def __or__(self, other: NDArrayObject | Scalar) -> NDArrayObject: - ... + def __or__(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... - def __xor(self, other: NDArrayObject | Scalar) -> NDArrayObject: - ... + def __xor(self, other: NDArrayObject | Scalar) -> NDArrayObject: ... diff --git a/src/gt4py/cartesian/backend/base.py b/src/gt4py/cartesian/backend/base.py index 62e36de721..669110161e 100644 --- a/src/gt4py/cartesian/backend/base.py +++ b/src/gt4py/cartesian/backend/base.py @@ -305,8 +305,7 @@ def make_module_source(self, *, args_data: Optional[ModuleData] = None, **kwargs class MakeModuleSourceCallable(Protocol): - def __call__(self, *, args_data: Optional[ModuleData] = None, **kwargs: Any) -> str: - ... + def __call__(self, *, args_data: Optional[ModuleData] = None, **kwargs: Any) -> str: ... class PurePythonBackendCLIMixin(CLIBackendMixin): diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py index 5dae025acb..b02c765ad7 100644 --- a/src/gt4py/cartesian/backend/dace_backend.py +++ b/src/gt4py/cartesian/backend/dace_backend.py @@ -684,14 +684,14 @@ def generate_entry_params(self, stencil_ir: gtir.Stencil, sdfg: dace.SDFG) -> Li if name in sdfg.arrays: data = sdfg.arrays[name] assert isinstance(data, dace.data.Array) - res[ - name - ] = "py::{pybind_type} {name}, std::array {name}_origin".format( - pybind_type="object" - if self.backend.storage_info["device"] == "gpu" - else "buffer", - name=name, - ndim=len(data.shape), + res[name] = ( + "py::{pybind_type} {name}, std::array {name}_origin".format( + pybind_type=( + "object" if self.backend.storage_info["device"] == "gpu" else "buffer" + ), + name=name, + ndim=len(data.shape), + ) ) elif name in sdfg.symbols and not name.startswith("__"): assert name in sdfg.symbols diff --git a/src/gt4py/cartesian/backend/gtcpp_backend.py b/src/gt4py/cartesian/backend/gtcpp_backend.py index e0f982b8be..c69b5b5088 100644 --- a/src/gt4py/cartesian/backend/gtcpp_backend.py +++ b/src/gt4py/cartesian/backend/gtcpp_backend.py @@ -88,9 +88,9 @@ def visit_FieldDecl(self, node: gtcpp.FieldDecl, **kwargs): sid_ndim = domain_ndim + data_ndim if kwargs["external_arg"]: return "py::{pybind_type} {name}, std::array {name}_origin".format( - pybind_type="object" - if self.backend.storage_info["device"] == "gpu" - else "buffer", + pybind_type=( + "object" if self.backend.storage_info["device"] == "gpu" else "buffer" + ), name=node.name, sid_ndim=sid_ndim, ) diff --git a/src/gt4py/cartesian/backend/pyext_builder.py b/src/gt4py/cartesian/backend/pyext_builder.py index e12669ae0f..1ffa5a412d 100644 --- a/src/gt4py/cartesian/backend/pyext_builder.py +++ b/src/gt4py/cartesian/backend/pyext_builder.py @@ -179,8 +179,7 @@ def build_pybind_ext( build_path: 
str, target_path: str, **kwargs: str, -) -> Tuple[str, str]: - ... +) -> Tuple[str, str]: ... @overload @@ -198,8 +197,7 @@ def build_pybind_ext( build_ext_class: Type = None, verbose: bool = False, clean: bool = False, -) -> Tuple[str, str]: - ... +) -> Tuple[str, str]: ... def build_pybind_ext( diff --git a/src/gt4py/cartesian/frontend/defir_to_gtir.py b/src/gt4py/cartesian/frontend/defir_to_gtir.py index f2ee544900..eb53e49ac5 100644 --- a/src/gt4py/cartesian/frontend/defir_to_gtir.py +++ b/src/gt4py/cartesian/frontend/defir_to_gtir.py @@ -489,18 +489,18 @@ def visit_If(self, node: If) -> Union[gtir.FieldIfStmt, gtir.ScalarIfStmt]: return gtir.FieldIfStmt( cond=cond, true_branch=gtir.BlockStmt(body=self.visit(node.main_body)), - false_branch=gtir.BlockStmt(body=self.visit(node.else_body)) - if node.else_body - else None, + false_branch=( + gtir.BlockStmt(body=self.visit(node.else_body)) if node.else_body else None + ), loc=location_to_source_location(node.loc), ) else: return gtir.ScalarIfStmt( cond=cond, true_branch=gtir.BlockStmt(body=self.visit(node.main_body)), - false_branch=gtir.BlockStmt(body=self.visit(node.else_body)) - if node.else_body - else None, + false_branch=( + gtir.BlockStmt(body=self.visit(node.else_body)) if node.else_body else None + ), loc=location_to_source_location(node.loc), ) diff --git a/src/gt4py/cartesian/frontend/gtscript_frontend.py b/src/gt4py/cartesian/frontend/gtscript_frontend.py index f665410b30..2df8c106ce 100644 --- a/src/gt4py/cartesian/frontend/gtscript_frontend.py +++ b/src/gt4py/cartesian/frontend/gtscript_frontend.py @@ -1154,9 +1154,11 @@ def visit_Subscript(self, node: ast.Subscript): result.offset = {axis: value for axis, value in zip(field_axes, index)} elif isinstance(node.value, ast.Subscript): result.data_index = [ - nodes.ScalarLiteral(value=value, data_type=nodes.DataType.INT32) - if isinstance(value, numbers.Integral) - else value + ( + nodes.ScalarLiteral(value=value, data_type=nodes.DataType.INT32) + if isinstance(value, numbers.Integral) + else value + ) for value in index ] if len(result.data_index) != len(self.fields[result.name].data_dims): @@ -1321,9 +1323,11 @@ def visit_If(self, node: ast.If) -> list: condition=self.visit(node.test), loc=nodes.Location.from_ast_node(node), main_body=nodes.BlockStmt(stmts=main_stmts, loc=nodes.Location.from_ast_node(node)), - else_body=nodes.BlockStmt(stmts=else_stmts, loc=nodes.Location.from_ast_node(node)) - if else_stmts - else None, + else_body=( + nodes.BlockStmt(stmts=else_stmts, loc=nodes.Location.from_ast_node(node)) + if else_stmts + else None + ), ) ) diff --git a/src/gt4py/cartesian/gtc/cuir/oir_to_cuir.py b/src/gt4py/cartesian/gtc/cuir/oir_to_cuir.py index 567d128c29..de1ca93557 100644 --- a/src/gt4py/cartesian/gtc/cuir/oir_to_cuir.py +++ b/src/gt4py/cartesian/gtc/cuir/oir_to_cuir.py @@ -29,8 +29,7 @@ class SymbolNameCreator(Protocol): - def __call__(self, name: str) -> str: - ... + def __call__(self, name: str) -> str: ... 
def _make_axis_offset_expr(bound: common.AxisBound, axis_index: int) -> cuir.Expr: diff --git a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py index 48b129fa87..9a214441ad 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py @@ -118,18 +118,26 @@ def all_regions_same(scope_nodes): len( set( ( - None - if mask.intervals[axis.to_idx()].start is None - else mask.intervals[axis.to_idx()].start.level, - None - if mask.intervals[axis.to_idx()].start is None - else mask.intervals[axis.to_idx()].start.offset, - None - if mask.intervals[axis.to_idx()].end is None - else mask.intervals[axis.to_idx()].end.level, - None - if mask.intervals[axis.to_idx()].end is None - else mask.intervals[axis.to_idx()].end.offset, + ( + None + if mask.intervals[axis.to_idx()].start is None + else mask.intervals[axis.to_idx()].start.level + ), + ( + None + if mask.intervals[axis.to_idx()].start is None + else mask.intervals[axis.to_idx()].start.offset + ), + ( + None + if mask.intervals[axis.to_idx()].end is None + else mask.intervals[axis.to_idx()].end.level + ), + ( + None + if mask.intervals[axis.to_idx()].end is None + else mask.intervals[axis.to_idx()].end.offset + ), ) for mask in eve.walk_values(scope_nodes).if_isinstance(common.HorizontalMask) ) diff --git a/src/gt4py/cartesian/gtc/dace/expansion_specification.py b/src/gt4py/cartesian/gtc/dace/expansion_specification.py index 57146ef2a8..7c99146426 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion_specification.py +++ b/src/gt4py/cartesian/gtc/dace/expansion_specification.py @@ -173,9 +173,11 @@ def _order_as_spec(computation_node, expansion_order): expansion_specification.append( Loop( axis=axis, - stride=-1 - if computation_node.oir_node.loop_order == common.LoopOrder.BACKWARD - else 1, + stride=( + -1 + if computation_node.oir_node.loop_order == common.LoopOrder.BACKWARD + else 1 + ), ) ) elif item == "Sections": diff --git a/src/gt4py/cartesian/gtc/gtcpp/oir_to_gtcpp.py b/src/gt4py/cartesian/gtc/gtcpp/oir_to_gtcpp.py index 82991af1d4..58cddffd5f 100644 --- a/src/gt4py/cartesian/gtc/gtcpp/oir_to_gtcpp.py +++ b/src/gt4py/cartesian/gtc/gtcpp/oir_to_gtcpp.py @@ -94,8 +94,7 @@ def _make_axis_offset_expr( class SymbolNameCreator(Protocol): - def __call__(self, name: str) -> str: - ... + def __call__(self, name: str) -> str: ... 
class OIRToGTCpp(eve.NodeTranslator, eve.VisitorWithSymbolTableTrait): diff --git a/src/gt4py/cartesian/testing/input_strategies.py b/src/gt4py/cartesian/testing/input_strategies.py index 5f3ed32572..008b859929 100644 --- a/src/gt4py/cartesian/testing/input_strategies.py +++ b/src/gt4py/cartesian/testing/input_strategies.py @@ -142,7 +142,11 @@ def scalar_value_st(dtype, min_value, max_value, allow_nan=False): """Hypothesis strategy for `dtype` scalar values in range [min_value, max_value].""" allow_infinity = not (np.isfinite(min_value) and np.isfinite(max_value)) - if issubclass(dtype.type, numbers.Real): + if issubclass(dtype.type, numbers.Integral): + value_st = hyp_st.integers(min_value, max_value) + elif issubclass( + dtype.type, numbers.Real + ): # after numbers.Integral because np.int32 is a subclass of numbers.Real value_st = hyp_st.floats( min_value, max_value, @@ -150,8 +154,6 @@ def scalar_value_st(dtype, min_value, max_value, allow_nan=False): allow_nan=allow_nan, width=dtype.itemsize * 8, ) - elif issubclass(dtype.type, numbers.Integral): - value_st = hyp_st.integers(min_value, max_value) return value_st.map(dtype.type) diff --git a/src/gt4py/cartesian/type_hints.py b/src/gt4py/cartesian/type_hints.py index a1af6b93d1..3a776ba847 100644 --- a/src/gt4py/cartesian/type_hints.py +++ b/src/gt4py/cartesian/type_hints.py @@ -21,8 +21,7 @@ class StencilFunc(Protocol): __name__: str __module__: str - def __call__(self, *args: Any, **kwargs: Dict[str, Any]) -> None: - ... + def __call__(self, *args: Any, **kwargs: Dict[str, Any]) -> None: ... class AnnotatedStencilFunc(StencilFunc, Protocol): diff --git a/src/gt4py/cartesian/utils/attrib.py b/src/gt4py/cartesian/utils/attrib.py index f2f77769ec..da53e5c128 100644 --- a/src/gt4py/cartesian/utils/attrib.py +++ b/src/gt4py/cartesian/utils/attrib.py @@ -240,16 +240,13 @@ def attribute(of, optional=False, **kwargs): class AttributeClassLike: - def validate(self): - ... + def validate(self): ... @property - def attributes(self): - ... + def attributes(self): ... @property - def as_dict(self): - ... + def as_dict(self): ... def attribclass(cls_or_none=None, **kwargs): diff --git a/src/gt4py/eve/codegen.py b/src/gt4py/eve/codegen.py index 3a964c92a9..72f0e8858f 100644 --- a/src/gt4py/eve/codegen.py +++ b/src/gt4py/eve/codegen.py @@ -641,15 +641,13 @@ def __init_subclass__(cls, *, inherit_templates: bool = True, **kwargs: Any) -> @overload @classmethod - def apply(cls, root: LeafNode, **kwargs: Any) -> str: - ... + def apply(cls, root: LeafNode, **kwargs: Any) -> str: ... @overload @classmethod def apply( # noqa: F811 # redefinition of symbol cls, root: CollectionNode, **kwargs: Any - ) -> Collection[str]: - ... + ) -> Collection[str]: ... @classmethod def apply( # noqa: F811 # redefinition of symbol diff --git a/src/gt4py/eve/datamodels/core.py b/src/gt4py/eve/datamodels/core.py index bc744b3ccc..11ad824aab 100644 --- a/src/gt4py/eve/datamodels/core.py +++ b/src/gt4py/eve/datamodels/core.py @@ -84,8 +84,7 @@ class _AttrsClassTP(Protocol): class DataModelTP(_AttrsClassTP, xtyping.DevToolsPrettyPrintable, Protocol): - def __init__(self, *args: Any, **kwargs: Any) -> None: - ... + def __init__(self, *args: Any, **kwargs: Any) -> None: ... 
__datamodel_fields__: ClassVar[utils.FrozenNamespace[Attribute]] = cast( utils.FrozenNamespace[Attribute], None @@ -116,8 +115,7 @@ class GenericDataModelTP(DataModelTP, Protocol): @classmethod def __class_getitem__( cls: Type[GenericDataModelTP], args: Union[Type, Tuple[Type, ...]] - ) -> Union[DataModelTP, GenericDataModelTP]: - ... + ) -> Union[DataModelTP, GenericDataModelTP]: ... _DM = TypeVar("_DM", bound="DataModel") @@ -280,8 +278,7 @@ def datamodel( coerce: bool = _COERCE_DEFAULT, generic: bool = _GENERIC_DEFAULT, type_validation_factory: Optional[FieldTypeValidatorFactory] = DefaultFieldTypeValidatorFactory, -) -> Callable[[Type[_T]], Type[_T]]: - ... +) -> Callable[[Type[_T]], Type[_T]]: ... @overload @@ -300,8 +297,7 @@ def datamodel( # noqa: F811 # redefinion of unused symbol coerce: bool = _COERCE_DEFAULT, generic: bool = _GENERIC_DEFAULT, type_validation_factory: Optional[FieldTypeValidatorFactory] = DefaultFieldTypeValidatorFactory, -) -> Type[_T]: - ... +) -> Type[_T]: ... # TODO(egparedes): Use @dataclass_transform(eq_default=True, field_specifiers=("field",)) @@ -410,8 +406,7 @@ def __call__( type_validation_factory: Optional[ FieldTypeValidatorFactory ] = DefaultFieldTypeValidatorFactory, - ) -> Union[Type[_T], Callable[[Type[_T]], Type[_T]]]: - ... + ) -> Union[Type[_T], Callable[[Type[_T]], Type[_T]]]: ... frozenmodel: _DataModelDecoratorTP = functools.partial(datamodel, frozen=True) @@ -424,13 +419,11 @@ def __call__( if xtyping.TYPE_CHECKING: class DataModel(DataModelTP): - def __init__(self, *args: Any, **kwargs: Any) -> None: - ... + def __init__(self, *args: Any, **kwargs: Any) -> None: ... def __pretty__( self, fmt: Callable[[Any], Any], **kwargs: Any - ) -> Generator[Any, None, None]: - ... + ) -> Generator[Any, None, None]: ... else: # TODO(egparedes): use @dataclass_transform(eq_default=True, field_specifiers=("field",)) @@ -453,9 +446,9 @@ def __init_subclass__( cls, /, *, - repr: bool # noqa: A002 # shadowing 'repr' python builtin - | None - | Literal["inherited"] = "inherited", + repr: ( # noqa: A002 # shadowing 'repr' python builtin + bool | None | Literal["inherited"] + ) = "inherited", eq: bool | None | Literal["inherited"] = "inherited", order: bool | None | Literal["inherited"] = "inherited", unsafe_hash: bool | None | Literal["inherited"] = "inherited", @@ -463,8 +456,9 @@ def __init_subclass__( match_args: bool | Literal["inherited"] = "inherited", kw_only: bool | Literal["inherited"] = "inherited", coerce: bool | Literal["inherited"] = "inherited", - type_validation_factory: Optional[FieldTypeValidatorFactory] - | Literal["inherited"] = "inherited", + type_validation_factory: ( + Optional[FieldTypeValidatorFactory] | Literal["inherited"] + ) = "inherited", **kwargs: Any, ) -> None: dm_opts = kwargs.pop(_DM_OPTS, []) @@ -519,10 +513,9 @@ def field( metadata: Optional[Mapping[Any, Any]] = None, kw_only: bool = _KW_ONLY_DEFAULT, converter: Callable[[Any], Any] | Literal["coerce"] | None = None, - validator: AttrsValidator - | FieldValidator - | Sequence[AttrsValidator | FieldValidator] - | None = None, + validator: ( + AttrsValidator | FieldValidator | Sequence[AttrsValidator | FieldValidator] | None + ) = None, ) -> Any: # attr.s lies in some typings """Define a new attribute on a class with advanced options. @@ -1373,8 +1366,7 @@ class GenericDataModel(GenericDataModelTP): @classmethod def __class_getitem__( cls: Type[GenericDataModelTP], args: Union[Type, Tuple[Type, ...]] - ) -> Union[DataModelTP, GenericDataModelTP]: - ... 
+ ) -> Union[DataModelTP, GenericDataModelTP]: ... else: diff --git a/src/gt4py/eve/extended_typing.py b/src/gt4py/eve/extended_typing.py index 3ee447ca6c..82076d1a9c 100644 --- a/src/gt4py/eve/extended_typing.py +++ b/src/gt4py/eve/extended_typing.py @@ -177,19 +177,16 @@ class NonDataDescriptor(Protocol[_C, _V]): @overload def __get__( self, _instance: Literal[None], _owner_type: Optional[Type[_C]] = None - ) -> NonDataDescriptor[_C, _V]: - ... + ) -> NonDataDescriptor[_C, _V]: ... @overload def __get__( # noqa: F811 # redefinion of unused member self, _instance: _C, _owner_type: Optional[Type[_C]] = None - ) -> _V: - ... + ) -> _V: ... def __get__( # noqa: F811 # redefinion of unused member self, _instance: Optional[_C], _owner_type: Optional[Type[_C]] = None - ) -> _V | NonDataDescriptor[_C, _V]: - ... + ) -> _V | NonDataDescriptor[_C, _V]: ... class DataDescriptor(NonDataDescriptor[_C, _V], Protocol): @@ -198,11 +195,9 @@ class DataDescriptor(NonDataDescriptor[_C, _V], Protocol): See https://docs.python.org/3/howto/descriptor.html for further information. """ - def __set__(self, _instance: _C, _value: _V) -> None: - ... + def __set__(self, _instance: _C, _value: _V) -> None: ... - def __delete__(self, _instance: _C) -> None: - ... + def __delete__(self, _instance: _C) -> None: ... # -- Based on typeshed definitions -- @@ -220,26 +215,20 @@ class HashlibAlgorithm(Protocol): block_size: int name: str - def __init__(self, data: ReadableBuffer = ...) -> None: - ... + def __init__(self, data: ReadableBuffer = ...) -> None: ... - def copy(self) -> HashlibAlgorithm: - ... + def copy(self) -> HashlibAlgorithm: ... - def update(self, data: ReadableBuffer) -> None: - ... + def update(self, data: ReadableBuffer) -> None: ... - def digest(self) -> bytes: - ... + def digest(self) -> bytes: ... - def hexdigest(self) -> str: - ... + def hexdigest(self) -> str: ... # -- Third party protocols -- class SupportsArray(Protocol): - def __array__(self, dtype: Optional[npt.DTypeLike] = None, /) -> npt.NDArray[Any]: - ... + def __array__(self, dtype: Optional[npt.DTypeLike] = None, /) -> npt.NDArray[Any]: ... def supports_array(value: Any) -> TypeGuard[SupportsArray]: @@ -248,8 +237,7 @@ def supports_array(value: Any) -> TypeGuard[SupportsArray]: class ArrayInterface(Protocol): @property - def __array_interface__(self) -> Dict[str, Any]: - ... + def __array_interface__(self) -> Dict[str, Any]: ... class ArrayInterfaceTypedDict(TypedDict): @@ -265,8 +253,7 @@ class ArrayInterfaceTypedDict(TypedDict): class StrictArrayInterface(Protocol): @property - def __array_interface__(self) -> ArrayInterfaceTypedDict: - ... + def __array_interface__(self) -> ArrayInterfaceTypedDict: ... def supports_array_interface(value: Any) -> TypeGuard[ArrayInterface]: @@ -275,8 +262,7 @@ def supports_array_interface(value: Any) -> TypeGuard[ArrayInterface]: class CUDAArrayInterface(Protocol): @property - def __cuda_array_interface__(self) -> Dict[str, Any]: - ... + def __cuda_array_interface__(self) -> Dict[str, Any]: ... class CUDAArrayInterfaceTypedDict(TypedDict): @@ -292,8 +278,7 @@ class CUDAArrayInterfaceTypedDict(TypedDict): class StrictCUDAArrayInterface(Protocol): @property - def __cuda_array_interface__(self) -> CUDAArrayInterfaceTypedDict: - ... + def __cuda_array_interface__(self) -> CUDAArrayInterfaceTypedDict: ... 
def supports_cuda_array_interface(value: Any) -> TypeGuard[CUDAArrayInterface]: @@ -305,19 +290,15 @@ def supports_cuda_array_interface(value: Any) -> TypeGuard[CUDAArrayInterface]: class MultiStreamDLPackBuffer(Protocol): - def __dlpack__(self, *, stream: Optional[int] = None) -> Any: - ... + def __dlpack__(self, *, stream: Optional[int] = None) -> Any: ... - def __dlpack_device__(self) -> DLPackDevice: - ... + def __dlpack_device__(self) -> DLPackDevice: ... class SingleStreamDLPackBuffer(Protocol): - def __dlpack__(self, *, stream: None = None) -> Any: - ... + def __dlpack__(self, *, stream: None = None) -> Any: ... - def __dlpack_device__(self) -> DLPackDevice: - ... + def __dlpack_device__(self) -> DLPackDevice: ... DLPackBuffer: TypeAlias = Union[MultiStreamDLPackBuffer, SingleStreamDLPackBuffer] @@ -333,8 +314,9 @@ def supports_dlpack(value: Any) -> TypeGuard[DLPackBuffer]: class DevToolsPrettyPrintable(Protocol): """Used by python-devtools (https://python-devtools.helpmanual.io/).""" - def __pretty__(self, fmt: Callable[[Any], Any], **kwargs: Any) -> Generator[Any, None, None]: - ... + def __pretty__( + self, fmt: Callable[[Any], Any], **kwargs: Any + ) -> Generator[Any, None, None]: ... # -- Added functionality -- @@ -357,8 +339,7 @@ def extended_runtime_checkable( *, instance_check_shortcut: bool = True, subclass_check_with_data_members: bool = False, -) -> Callable[[_ProtoT], _ProtoT]: - ... +) -> Callable[[_ProtoT], _ProtoT]: ... @overload @@ -367,8 +348,7 @@ def extended_runtime_checkable( *, instance_check_shortcut: bool = True, subclass_check_with_data_members: bool = False, -) -> _ProtoT: - ... +) -> _ProtoT: ... def extended_runtime_checkable( # noqa: C901 # too complex but unavoidable @@ -414,9 +394,11 @@ def _decorator(cls: _ProtoT) -> _ProtoT: _allow_reckless_class_checks = getattr( _typing, - "_allow_reckless_class_checks" - if hasattr(_typing, "_allow_reckless_class_checks") - else "_allow_reckless_class_cheks", # There is a typo in 3.8 and 3.9 + ( + "_allow_reckless_class_checks" + if hasattr(_typing, "_allow_reckless_class_checks") + else "_allow_reckless_class_cheks" + ), # There is a typo in 3.8 and 3.9 ) _get_protocol_attrs = ( diff --git a/src/gt4py/eve/trees.py b/src/gt4py/eve/trees.py index 74c5bd41bb..7bfd22cdf7 100644 --- a/src/gt4py/eve/trees.py +++ b/src/gt4py/eve/trees.py @@ -62,12 +62,10 @@ class TreeLike(abc.ABC): # noqa: B024 class Tree(Protocol): @abc.abstractmethod - def iter_children_values(self) -> Iterable: - ... + def iter_children_values(self) -> Iterable: ... @abc.abstractmethod - def iter_children_items(self) -> Iterable[Tuple[TreeKey, Any]]: - ... + def iter_children_items(self) -> Iterable[Tuple[TreeKey, Any]]: ... TreeLike.register(Tree) diff --git a/src/gt4py/eve/type_validation.py b/src/gt4py/eve/type_validation.py index 65f492ebfe..124957fa20 100644 --- a/src/gt4py/eve/type_validation.py +++ b/src/gt4py/eve/type_validation.py @@ -110,8 +110,7 @@ def __call__( globalns: Optional[Dict[str, Any]] = None, localns: Optional[Dict[str, Any]] = None, **kwargs: Any, - ) -> FixedTypeValidator: - ... + ) -> FixedTypeValidator: ... @overload def __call__( # noqa: F811 # redefinion of unused member @@ -123,8 +122,7 @@ def __call__( # noqa: F811 # redefinion of unused member globalns: Optional[Dict[str, Any]] = None, localns: Optional[Dict[str, Any]] = None, **kwargs: Any, - ) -> Optional[FixedTypeValidator]: - ... + ) -> Optional[FixedTypeValidator]: ... 
@abc.abstractmethod def __call__( # noqa: F811 # redefinion of unused member @@ -169,8 +167,7 @@ def __call__( globalns: Optional[Dict[str, Any]] = None, localns: Optional[Dict[str, Any]] = None, **kwargs: Any, - ) -> FixedTypeValidator: - ... + ) -> FixedTypeValidator: ... @overload def __call__( # noqa: F811 # redefinion of unused member @@ -182,8 +179,7 @@ def __call__( # noqa: F811 # redefinion of unused member globalns: Optional[Dict[str, Any]] = None, localns: Optional[Dict[str, Any]] = None, **kwargs: Any, - ) -> Optional[FixedTypeValidator]: - ... + ) -> Optional[FixedTypeValidator]: ... def __call__( # noqa: F811,C901 # redefinion of unused member / complex but well organized in cases self, diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py index 624407f319..8e634c4b11 100644 --- a/src/gt4py/eve/utils.py +++ b/src/gt4py/eve/utils.py @@ -241,15 +241,13 @@ def partial(self, *args: Any, **kwargs: Any) -> fluid_partial: @overload def with_fluid_partial( func: Literal[None] = None, *args: Any, **kwargs: Any -) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: - ... +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: ... @overload def with_fluid_partial( # noqa: F811 # redefinition of unused function func: Callable[_P, _T], *args: Any, **kwargs: Any -) -> Callable[_P, _T]: - ... +) -> Callable[_P, _T]: ... def with_fluid_partial( # noqa: F811 # redefinition of unused function @@ -286,15 +284,13 @@ def _decorator(func: Callable[..., Any]) -> Callable[..., Any]: @overload def optional_lru_cache( func: Literal[None] = None, *, maxsize: Optional[int] = 128, typed: bool = False -) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: - ... +) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: ... @overload def optional_lru_cache( # noqa: F811 # redefinition of unused function func: Callable[_P, _T], *, maxsize: Optional[int] = 128, typed: bool = False -) -> Callable[_P, _T]: - ... +) -> Callable[_P, _T]: ... def optional_lru_cache( # noqa: F811 # redefinition of unused function @@ -1228,12 +1224,10 @@ def unzip(self) -> XIterable[Tuple[Any, ...]]: return XIterable(zip(*self.iterator)) @typing.overload - def islice(self, __stop: int) -> XIterable[T]: - ... + def islice(self, __stop: int) -> XIterable[T]: ... @typing.overload - def islice(self, __start: int, __stop: int, __step: int = 1) -> XIterable[T]: - ... + def islice(self, __start: int, __stop: int, __step: int = 1) -> XIterable[T]: ... def islice( self, @@ -1315,18 +1309,17 @@ def unique(self, *, key: Union[NOTHING, Callable] = NOTHING) -> XIterable[T]: @typing.overload def groupby( self, key: str, *other_keys: str, as_dict: bool = False - ) -> XIterable[Tuple[Any, List[T]]]: - ... + ) -> XIterable[Tuple[Any, List[T]]]: ... @typing.overload - def groupby(self, key: List[Any], *, as_dict: bool = False) -> XIterable[Tuple[Any, List[T]]]: - ... + def groupby( + self, key: List[Any], *, as_dict: bool = False + ) -> XIterable[Tuple[Any, List[T]]]: ... @typing.overload def groupby( self, key: Callable[[T], Any], *, as_dict: bool = False - ) -> XIterable[Tuple[Any, List[T]]]: - ... + ) -> XIterable[Tuple[Any, List[T]]]: ... def groupby( self, @@ -1454,8 +1447,7 @@ def reduceby( *, as_dict: Literal[False], init: Union[S, NothingType], - ) -> XIterable[Tuple[str, S]]: - ... + ) -> XIterable[Tuple[str, S]]: ... @typing.overload def reduceby( @@ -1466,8 +1458,7 @@ def reduceby( *attr_keys: str, as_dict: Literal[False], init: Union[S, NothingType], - ) -> XIterable[Tuple[Tuple[str, ...], S]]: - ... 
+ ) -> XIterable[Tuple[Tuple[str, ...], S]]: ... @typing.overload def reduceby( @@ -1477,8 +1468,7 @@ def reduceby( *, as_dict: Literal[True], init: Union[S, NothingType], - ) -> Dict[str, S]: - ... + ) -> Dict[str, S]: ... @typing.overload def reduceby( @@ -1489,8 +1479,7 @@ def reduceby( *attr_keys: str, as_dict: Literal[True], init: Union[S, NothingType], - ) -> Dict[Tuple[str, ...], S]: - ... + ) -> Dict[Tuple[str, ...], S]: ... @typing.overload def reduceby( @@ -1500,8 +1489,7 @@ def reduceby( *, as_dict: Literal[False], init: Union[S, NothingType], - ) -> XIterable[Tuple[K, S]]: - ... + ) -> XIterable[Tuple[K, S]]: ... @typing.overload def reduceby( @@ -1511,8 +1499,7 @@ def reduceby( *, as_dict: Literal[True], init: Union[S, NothingType], - ) -> Dict[K, S]: - ... + ) -> Dict[K, S]: ... @typing.overload def reduceby( @@ -1522,8 +1509,7 @@ def reduceby( *, as_dict: Literal[False], init: Union[S, NothingType], - ) -> XIterable[Tuple[K, S]]: - ... + ) -> XIterable[Tuple[K, S]]: ... @typing.overload def reduceby( @@ -1533,8 +1519,7 @@ def reduceby( *, as_dict: Literal[True], init: Union[S, NothingType], - ) -> Dict[K, S]: - ... + ) -> Dict[K, S]: ... def reduceby( self, diff --git a/src/gt4py/next/allocators.py b/src/gt4py/next/allocators.py index 44203bf6d8..559e78eb3e 100644 --- a/src/gt4py/next/allocators.py +++ b/src/gt4py/next/allocators.py @@ -58,8 +58,7 @@ class FieldBufferAllocatorProtocol(Protocol[core_defs.DeviceTypeT]): @property @abc.abstractmethod - def __gt_device_type__(self) -> core_defs.DeviceTypeT: - ... + def __gt_device_type__(self) -> core_defs.DeviceTypeT: ... @abc.abstractmethod def __gt_allocate__( @@ -68,8 +67,7 @@ def __gt_allocate__( dtype: core_defs.DType[core_defs.ScalarT], device_id: int = 0, aligned_index: Optional[Sequence[common.NamedIndex]] = None, # absolute position - ) -> core_allocators.TensorBuffer[core_defs.DeviceTypeT, core_defs.ScalarT]: - ... + ) -> core_allocators.TensorBuffer[core_defs.DeviceTypeT, core_defs.ScalarT]: ... def is_field_allocator(obj: Any) -> TypeGuard[FieldBufferAllocatorProtocol]: @@ -87,8 +85,7 @@ class FieldBufferAllocatorFactoryProtocol(Protocol[core_defs.DeviceTypeT]): @property @abc.abstractmethod - def __gt_allocator__(self) -> FieldBufferAllocatorProtocol[core_defs.DeviceTypeT]: - ... + def __gt_allocator__(self) -> FieldBufferAllocatorProtocol[core_defs.DeviceTypeT]: ... def is_field_allocator_factory(obj: Any) -> TypeGuard[FieldBufferAllocatorFactoryProtocol]: @@ -178,9 +175,9 @@ def __gt_allocate__( if TYPE_CHECKING: - __TensorFieldAllocatorAsFieldAllocatorInterfaceT: type[ - FieldBufferAllocatorProtocol - ] = BaseFieldBufferAllocator + __TensorFieldAllocatorAsFieldAllocatorInterfaceT: type[FieldBufferAllocatorProtocol] = ( + BaseFieldBufferAllocator + ) def horizontal_first_layout_mapper( diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 33a0591813..d8ffc2057b 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -160,8 +160,7 @@ def __repr__(self) -> str: return f"UnitRange({self.start}, {self.stop})" @overload - def __getitem__(self, index: int) -> int: - ... + def __getitem__(self, index: int) -> int: ... @overload def __getitem__(self, index: slice) -> UnitRange: # noqa: F811 # redefine unused @@ -414,8 +413,7 @@ def is_finite(cls, obj: Domain) -> TypeGuard[FiniteDomain]: return all(UnitRange.is_finite(rng) for rng in obj.ranges) @overload - def __getitem__(self, index: int) -> tuple[Dimension, _Rng]: - ... + def __getitem__(self, index: int) -> tuple[Dimension, _Rng]: ... 
@overload def __getitem__(self, index: slice) -> Self: # noqa: F811 # redefine unused @@ -424,8 +422,7 @@ def __getitem__(self, index: slice) -> Self: # noqa: F811 # redefine unused @overload def __getitem__( # noqa: F811 # redefine unused self, index: Dimension - ) -> tuple[Dimension, _Rng]: - ... + ) -> tuple[Dimension, _Rng]: ... def __getitem__( # noqa: F811 # redefine unused self, index: int | slice | Dimension @@ -571,8 +568,7 @@ def _broadcast_ranges( _R = TypeVar("_R", _Value, tuple[_Value, ...]) class GTBuiltInFuncDispatcher(Protocol): - def __call__(self, func: fbuiltins.BuiltInFunction[_R, _P], /) -> Callable[_P, _R]: - ... + def __call__(self, func: fbuiltins.BuiltInFunction[_R, _P], /) -> Callable[_P, _R]: ... # TODO(havogt): we need to describe when this interface should be used instead of the `Field` protocol. @@ -601,56 +597,45 @@ class Field(GTFieldInterface, Protocol[DimsT, core_defs.ScalarT]): __gt_builtin_func__: ClassVar[GTBuiltInFuncDispatcher] @property - def domain(self) -> Domain: - ... + def domain(self) -> Domain: ... @property def __gt_domain__(self) -> Domain: return self.domain @property - def codomain(self) -> type[core_defs.ScalarT] | Dimension: - ... + def codomain(self) -> type[core_defs.ScalarT] | Dimension: ... @property - def dtype(self) -> core_defs.DType[core_defs.ScalarT]: - ... + def dtype(self) -> core_defs.DType[core_defs.ScalarT]: ... @property - def ndarray(self) -> core_defs.NDArrayObject: - ... + def ndarray(self) -> core_defs.NDArrayObject: ... def __str__(self) -> str: return f"⟨{self.domain!s} → {self.dtype}⟩" @abc.abstractmethod - def asnumpy(self) -> np.ndarray: - ... + def asnumpy(self) -> np.ndarray: ... @abc.abstractmethod - def remap(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> Field: - ... + def remap(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> Field: ... @abc.abstractmethod - def restrict(self, item: AnyIndexSpec) -> Field | core_defs.ScalarT: - ... + def restrict(self, item: AnyIndexSpec) -> Field | core_defs.ScalarT: ... # Operators @abc.abstractmethod - def __call__(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> Field: - ... + def __call__(self, index_field: ConnectivityField | fbuiltins.FieldOffset) -> Field: ... @abc.abstractmethod - def __getitem__(self, item: AnyIndexSpec) -> Field | core_defs.ScalarT: - ... + def __getitem__(self, item: AnyIndexSpec) -> Field | core_defs.ScalarT: ... @abc.abstractmethod - def __abs__(self) -> Field: - ... + def __abs__(self) -> Field: ... @abc.abstractmethod - def __neg__(self) -> Field: - ... + def __neg__(self) -> Field: ... @abc.abstractmethod def __invert__(self) -> Field: @@ -665,48 +650,37 @@ def __ne__(self, other: Any) -> Field: # type: ignore[override] # mypy wants re ... @abc.abstractmethod - def __add__(self, other: Field | core_defs.ScalarT) -> Field: - ... + def __add__(self, other: Field | core_defs.ScalarT) -> Field: ... @abc.abstractmethod - def __radd__(self, other: Field | core_defs.ScalarT) -> Field: - ... + def __radd__(self, other: Field | core_defs.ScalarT) -> Field: ... @abc.abstractmethod - def __sub__(self, other: Field | core_defs.ScalarT) -> Field: - ... + def __sub__(self, other: Field | core_defs.ScalarT) -> Field: ... @abc.abstractmethod - def __rsub__(self, other: Field | core_defs.ScalarT) -> Field: - ... + def __rsub__(self, other: Field | core_defs.ScalarT) -> Field: ... @abc.abstractmethod - def __mul__(self, other: Field | core_defs.ScalarT) -> Field: - ... 
+ def __mul__(self, other: Field | core_defs.ScalarT) -> Field: ... @abc.abstractmethod - def __rmul__(self, other: Field | core_defs.ScalarT) -> Field: - ... + def __rmul__(self, other: Field | core_defs.ScalarT) -> Field: ... @abc.abstractmethod - def __floordiv__(self, other: Field | core_defs.ScalarT) -> Field: - ... + def __floordiv__(self, other: Field | core_defs.ScalarT) -> Field: ... @abc.abstractmethod - def __rfloordiv__(self, other: Field | core_defs.ScalarT) -> Field: - ... + def __rfloordiv__(self, other: Field | core_defs.ScalarT) -> Field: ... @abc.abstractmethod - def __truediv__(self, other: Field | core_defs.ScalarT) -> Field: - ... + def __truediv__(self, other: Field | core_defs.ScalarT) -> Field: ... @abc.abstractmethod - def __rtruediv__(self, other: Field | core_defs.ScalarT) -> Field: - ... + def __rtruediv__(self, other: Field | core_defs.ScalarT) -> Field: ... @abc.abstractmethod - def __pow__(self, other: Field | core_defs.ScalarT) -> Field: - ... + def __pow__(self, other: Field | core_defs.ScalarT) -> Field: ... @abc.abstractmethod def __and__(self, other: Field | core_defs.ScalarT) -> Field: @@ -734,8 +708,7 @@ def is_field( @extended_runtime_checkable class MutableField(Field[DimsT, core_defs.ScalarT], Protocol[DimsT, core_defs.ScalarT]): @abc.abstractmethod - def __setitem__(self, index: AnyIndexSpec, value: Field | core_defs.ScalarT) -> None: - ... + def __setitem__(self, index: AnyIndexSpec, value: Field | core_defs.ScalarT) -> None: ... def is_mutable_field( @@ -759,8 +732,7 @@ class ConnectivityKind(enum.Flag): class ConnectivityField(Field[DimsT, core_defs.IntegralScalar], Protocol[DimsT, DimT]): @property @abc.abstractmethod - def codomain(self) -> DimT: - ... + def codomain(self) -> DimT: ... @property def kind(self) -> ConnectivityKind: @@ -771,8 +743,7 @@ def kind(self) -> ConnectivityKind: ) @abc.abstractmethod - def inverse_image(self, image_range: UnitRange | NamedRange) -> Sequence[NamedRange]: - ... + def inverse_image(self, image_range: UnitRange | NamedRange) -> Sequence[NamedRange]: ... # Operators def __abs__(self) -> Never: @@ -1076,9 +1047,9 @@ class FieldBuiltinFuncRegistry: dispatching (via ChainMap) to its parent's registries. 
""" - _builtin_func_map: collections.ChainMap[ - fbuiltins.BuiltInFunction, Callable - ] = collections.ChainMap() + _builtin_func_map: collections.ChainMap[fbuiltins.BuiltInFunction, Callable] = ( + collections.ChainMap() + ) def __init_subclass__(cls, **kwargs): cls._builtin_func_map = collections.ChainMap( diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index 52a61b40bb..3a22df1032 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -131,8 +131,9 @@ def dtype(self) -> core_defs.DType[core_defs.ScalarT]: @classmethod def from_array( cls, - data: npt.ArrayLike - | core_defs.NDArrayObject, # TODO: NDArrayObject should be part of ArrayLike + data: ( + npt.ArrayLike | core_defs.NDArrayObject + ), # TODO: NDArrayObject should be part of ArrayLike /, *, domain: common.DomainLike, @@ -476,9 +477,10 @@ def restrict(self, index: common.AnyIndexSpec) -> common.Field | core_defs.Integ NdArrayField.register_builtin_func(fbuiltins.where, _make_builtin("where", "where")) -def _make_reduction( - builtin_name: str, array_builtin_name: str -) -> Callable[..., NdArrayField[common.DimsT, core_defs.ScalarT],]: +def _make_reduction(builtin_name: str, array_builtin_name: str) -> Callable[ + ..., + NdArrayField[common.DimsT, core_defs.ScalarT], +]: def _builtin_op( field: NdArrayField[common.DimsT, core_defs.ScalarT], axis: common.Dimension ) -> NdArrayField[common.DimsT, core_defs.ScalarT]: diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 9f8537f59b..6510be560e 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -490,13 +490,13 @@ def itir(self): @typing.overload -def program(definition: types.FunctionType) -> Program: - ... +def program(definition: types.FunctionType) -> Program: ... @typing.overload -def program(*, backend: Optional[ppi.ProgramExecutor]) -> Callable[[types.FunctionType], Program]: - ... +def program( + *, backend: Optional[ppi.ProgramExecutor] +) -> Callable[[types.FunctionType], Program]: ... def program( @@ -748,15 +748,13 @@ def __call__( @typing.overload def field_operator( definition: types.FunctionType, *, backend: Optional[ppi.ProgramExecutor] -) -> FieldOperator[foast.FieldOperator]: - ... +) -> FieldOperator[foast.FieldOperator]: ... @typing.overload def field_operator( *, backend: Optional[ppi.ProgramExecutor] -) -> Callable[[types.FunctionType], FieldOperator[foast.FieldOperator]]: - ... +) -> Callable[[types.FunctionType], FieldOperator[foast.FieldOperator]]: ... def field_operator(definition=None, *, backend=eve.NOTHING, grid_type=None): @@ -793,8 +791,7 @@ def scan_operator( init: core_defs.Scalar, backend: Optional[str], grid_type: GridType, -) -> FieldOperator[foast.ScanOperator]: - ... +) -> FieldOperator[foast.ScanOperator]: ... @typing.overload @@ -805,8 +802,7 @@ def scan_operator( init: core_defs.Scalar, backend: Optional[str], grid_type: GridType, -) -> Callable[[types.FunctionType], FieldOperator[foast.ScanOperator]]: - ... +) -> Callable[[types.FunctionType], FieldOperator[foast.ScanOperator]]: ... 
diff --git a/src/gt4py/next/ffront/dialect_parser.py b/src/gt4py/next/ffront/dialect_parser.py
index c04e978e51..07490db27c 100644
--- a/src/gt4py/next/ffront/dialect_parser.py
+++ b/src/gt4py/next/ffront/dialect_parser.py
@@ -39,12 +39,16 @@ def parse_source_definition(source_definition: SourceDefinition) -> ast.AST:
             line=err.lineno + source_definition.line_offset,
             column=err.offset + source_definition.column_offset,
             filename=source_definition.filename,
-            end_line=err.end_lineno + source_definition.line_offset
-            if err.end_lineno is not None
-            else None,
-            end_column=err.end_offset + source_definition.column_offset
-            if err.end_offset is not None
-            else None,
+            end_line=(
+                err.end_lineno + source_definition.line_offset
+                if err.end_lineno is not None
+                else None
+            ),
+            end_column=(
+                err.end_offset + source_definition.column_offset
+                if err.end_offset is not None
+                else None
+            ),
         )
         raise errors.DSLError(loc, err.msg).with_traceback(err.__traceback__)
diff --git a/src/gt4py/next/ffront/field_operator_ast.py b/src/gt4py/next/ffront/field_operator_ast.py
index 6b772227b2..322a6df2e0 100644
--- a/src/gt4py/next/ffront/field_operator_ast.py
+++ b/src/gt4py/next/ffront/field_operator_ast.py
@@ -153,8 +153,7 @@ class Call(Expr):
     kwargs: dict[str, Expr]
-class Stmt(LocatedNode):
-    ...
+class Stmt(LocatedNode): ...
 class Starred(Expr):
diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py
index 5e289af664..64fea7935c 100644
--- a/src/gt4py/next/ffront/foast_passes/type_deduction.py
+++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py
@@ -438,9 +438,9 @@ def visit_IfStmt(self, node: foast.IfStmt, **kwargs) -> foast.IfStmt:
                     f"got types '{true_type}' and '{false_type}.",
                 )
             # TODO: properly patch symtable (new node?)
-            symtable[sym].type = new_node.annex.propagated_symbols[
-                sym
-            ].type = new_true_branch.annex.symtable[sym].type
+            symtable[sym].type = new_node.annex.propagated_symbols[sym].type = (
+                new_true_branch.annex.symtable[sym].type
+            )
         return new_node
diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py
index 0c9ab4ab27..c0e618a42d 100644
--- a/src/gt4py/next/ffront/foast_to_itir.py
+++ b/src/gt4py/next/ffront/foast_to_itir.py
@@ -430,5 +430,4 @@ def _process_elements(
         return self._map(im.lambda_("expr")(process_func(current_el_expr)), obj)
-class FieldOperatorLoweringError(Exception):
-    ...
+class FieldOperatorLoweringError(Exception): ...
diff --git a/src/gt4py/next/ffront/program_ast.py b/src/gt4py/next/ffront/program_ast.py
index 14151fc243..4ff8265f70 100644
--- a/src/gt4py/next/ffront/program_ast.py
+++ b/src/gt4py/next/ffront/program_ast.py
@@ -93,8 +93,7 @@ class Slice(Expr):
     step: Literal[None]
-class Stmt(LocatedNode):
-    ...
+class Stmt(LocatedNode): ...
 class Program(LocatedNode, SymbolTableTrait):
diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py
index 6d610fd136..6985aea853 100644
--- a/src/gt4py/next/iterator/embedded.py
+++ b/src/gt4py/next/iterator/embedded.py
@@ -80,8 +80,7 @@
 )
-class SparseTag(Tag):
-    ...
+class SparseTag(Tag): ...
 class NeighborTableOffsetProvider:
@@ -156,14 +155,11 @@ class ItIterator(Protocol):
     `ItIterator` to avoid name clashes with `Iterator` from `typing` and `collections.abc`.
     """
-    def shift(self, *offsets: OffsetPart) -> ItIterator:
-        ...
+    def shift(self, *offsets: OffsetPart) -> ItIterator: ...
-    def can_deref(self) -> bool:
-        ...
+    def can_deref(self) -> bool: ...
-    def deref(self) -> Any:
-        ...
+    def deref(self) -> Any: ...
 @runtime_checkable
@@ -172,13 +168,11 @@ class LocatedField(Protocol):
     @property
     @abc.abstractmethod
-    def __gt_domain__(self) -> common.Domain:
-        ...
+    def __gt_domain__(self) -> common.Domain: ...
     # TODO(havogt): define generic Protocol to provide a concrete return type
     @abc.abstractmethod
-    def field_getitem(self, indices: NamedFieldIndices) -> Any:
-        ...
+    def field_getitem(self, indices: NamedFieldIndices) -> Any: ...
     @property
     def __gt_origin__(self) -> tuple[int, ...]:
@@ -191,8 +185,7 @@ class MutableLocatedField(LocatedField, Protocol):
     # TODO(havogt): define generic Protocol to provide a concrete return type
     @abc.abstractmethod
-    def field_setitem(self, indices: NamedFieldIndices, value: Any) -> None:
-        ...
+    def field_setitem(self, indices: NamedFieldIndices, value: Any) -> None: ...
 #: Column range used in column mode (`column_axis != None`) in the current closure execution context.
@@ -705,8 +698,7 @@ def _make_tuple(
     named_indices: NamedFieldIndices,
     *,
     column_axis: Tag,
-) -> tuple[tuple | Column, ...]:
-    ...
+) -> tuple[tuple | Column, ...]: ...
 @overload
@@ -722,8 +714,7 @@ def _make_tuple(
 @overload
 def _make_tuple(
     field_or_tuple: LocatedField, named_indices: NamedFieldIndices, *, column_axis: Tag
-) -> Column:
-    ...
+) -> Column: ...
 @overload
@@ -732,8 +723,7 @@ def _make_tuple(
     named_indices: NamedFieldIndices,
     *,
     column_axis: Literal[None] = None,
-) -> npt.DTypeLike | Undefined:
-    ...
+) -> npt.DTypeLike | Undefined: ...
 def _make_tuple(
@@ -974,13 +964,11 @@ def get_ordered_indices(axes: Iterable[Axis], pos: NamedFieldIndices) -> tuple[F
 @overload
-def _shift_range(range_or_index: range, offset: int) -> slice:
-    ...
+def _shift_range(range_or_index: range, offset: int) -> slice: ...
 @overload
-def _shift_range(range_or_index: common.IntIndex, offset: int) -> common.IntIndex:
-    ...
+def _shift_range(range_or_index: common.IntIndex, offset: int) -> common.IntIndex: ...
 def _shift_range(range_or_index: range | common.IntIndex, offset: int) -> ArrayIndex:
@@ -994,13 +982,11 @@ def _shift_range(range_or_index: range | common.IntIndex, offset: int) -> ArrayI
 @overload
-def _range2slice(r: range) -> slice:
-    ...
+def _range2slice(r: range) -> slice: ...
 @overload
-def _range2slice(r: common.IntIndex) -> common.IntIndex:
-    ...
+def _range2slice(r: common.IntIndex) -> common.IntIndex: ...
 def _range2slice(r: range | common.IntIndex) -> slice | common.IntIndex:
@@ -1288,8 +1274,7 @@ def impl(it: ItIterator) -> ItIterator:
 DT = TypeVar("DT")
-class _List(tuple, Generic[DT]):
-    ...
+class _List(tuple, Generic[DT]): ...
 @dataclasses.dataclass(frozen=True)
@@ -1424,8 +1409,7 @@ def is_tuple_of_field(field) -> bool:
     )
-class TupleFieldMeta(type):
-    ...
+class TupleFieldMeta(type): ...
 class TupleField(metaclass=TupleFieldMeta):
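The `_List` and `TupleFieldMeta` hunks above show the companion rule for classes: a class whose body is only `...` (a marker or tag type) is collapsed onto the `class` line as well. A small sketch under that assumption, with made-up names:

from typing import Generic, TypeVar

T = TypeVar("T")


# Marker classes with no body of their own are written on one line.
class FrozenList(tuple, Generic[T]): ...


class ConfigError(RuntimeError): ...


# They behave exactly like the multi-line spelling.
items: FrozenList[int] = FrozenList((1, 2, 3))
assert isinstance(items, tuple)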
diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py
index 37abbec9e7..10caecc591 100644
--- a/src/gt4py/next/iterator/ir.py
+++ b/src/gt4py/next/iterator/ir.py
@@ -45,9 +45,9 @@ class Sym(Node):  # helper
     # TODO(tehrengruber): Revisit. Using strings is a workaround to avoid coupling with the
     # type inference.
     kind: typing.Literal["Iterator", "Value", None] = None
-    dtype: Optional[
-        tuple[str, bool]
-    ] = None  # format: name of primitive type, boolean indicating if it is a list
+    dtype: Optional[tuple[str, bool]] = (
+        None  # format: name of primitive type, boolean indicating if it is a list
+    )
     @datamodels.validator("kind")
     def _kind_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribute, value: str):
@@ -63,8 +63,7 @@ def _dtype_validator(self: datamodels.DataModelTP, attribute: datamodels.Attribu
 @noninstantiable
-class Expr(Node):
-    ...
+class Expr(Node): ...
 class Literal(Expr):
diff --git a/src/gt4py/next/iterator/runtime.py b/src/gt4py/next/iterator/runtime.py
index e12ae84dbc..5de4839b55 100644
--- a/src/gt4py/next/iterator/runtime.py
+++ b/src/gt4py/next/iterator/runtime.py
@@ -43,12 +43,10 @@ def offset(value):
     return Offset(value)
-class CartesianDomain(dict):
-    ...
+class CartesianDomain(dict): ...
-class UnstructuredDomain(dict):
-    ...
+class UnstructuredDomain(dict): ...
 # dependency inversion, register fendef for embedded execution or for tracing/parsing here
diff --git a/src/gt4py/next/otf/compilation/compiler.py b/src/gt4py/next/otf/compilation/compiler.py
index 9fd20b16e2..29541a3ae5 100644
--- a/src/gt4py/next/otf/compilation/compiler.py
+++ b/src/gt4py/next/otf/compilation/compiler.py
@@ -41,8 +41,7 @@ def __call__(
         self,
         source: stages.CompilableSource[SrcL, LS, TgtL],
         cache_strategy: cache.Strategy,
-    ) -> stages.BuildSystemProject[SrcL, LS, TgtL]:
-        ...
+    ) -> stages.BuildSystemProject[SrcL, LS, TgtL]: ...
 @dataclasses.dataclass(frozen=True)
@@ -88,5 +87,4 @@ def __call__(
         )
-class CompilationError(RuntimeError):
-    ...
+class CompilationError(RuntimeError): ...
diff --git a/src/gt4py/next/otf/languages.py b/src/gt4py/next/otf/languages.py
index b0d01d91ab..2397878271 100644
--- a/src/gt4py/next/otf/languages.py
+++ b/src/gt4py/next/otf/languages.py
@@ -57,8 +57,7 @@ class Python(LanguageTag):
     ...
-class NanobindSrcL(LanguageTag):
-    ...
+class NanobindSrcL(LanguageTag): ...
 class Cpp(NanobindSrcL):
diff --git a/src/gt4py/next/otf/stages.py b/src/gt4py/next/otf/stages.py
index a21bc83c0b..bd7f59e7aa 100644
--- a/src/gt4py/next/otf/stages.py
+++ b/src/gt4py/next/otf/stages.py
@@ -107,15 +107,13 @@ class BuildSystemProject(Protocol[SrcL_co, SettingT_co, TgtL_co]):
     and is not responsible for importing the results into Python.
     """
-    def build(self) -> None:
-        ...
+    def build(self) -> None: ...
 class CompiledProgram(Protocol):
     """Executable python representation of a program."""
-    def __call__(self, *args, **kwargs) -> None:
-        ...
+    def __call__(self, *args, **kwargs) -> None: ...
 def _unique_libs(*args: interface.LibraryDependency) -> tuple[interface.LibraryDependency, ...]:
diff --git a/src/gt4py/next/otf/step_types.py b/src/gt4py/next/otf/step_types.py
index 5eeb5c495b..43def259ab 100644
--- a/src/gt4py/next/otf/step_types.py
+++ b/src/gt4py/next/otf/step_types.py
@@ -46,8 +46,7 @@ class BindingStep(Protocol[SrcL, LS, TgtL]):
     def __call__(
         self, program_source: stages.ProgramSource[SrcL, LS]
-    ) -> stages.CompilableSource[SrcL, LS, TgtL]:
-        ...
+    ) -> stages.CompilableSource[SrcL, LS, TgtL]: ...
 class CompilationStep(
@@ -56,5 +55,6 @@ class CompilationStep(
 ):
     """Compile program source code and bindings into a python callable (CompilableSource -> CompiledProgram)."""
-    def __call__(self, source: stages.CompilableSource[SrcL, LS, TgtL]) -> stages.CompiledProgram:
-        ...
+    def __call__(
+        self, source: stages.CompilableSource[SrcL, LS, TgtL]
+    ) -> stages.CompiledProgram: ...
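The `dtype` hunk in `iterator/ir.py` above (like the `_builtin_func_map` hunk earlier) shows how an assignment that no longer fits on one line is now wrapped: the right-hand side is put in parentheses, with any trailing comment kept next to the value, instead of splitting the subscript of the annotation. A hedged sketch with invented names:

from typing import Optional


class Symbol:
    # Black <= 23 split the subscript on the left-hand side; Black >= 24
    # parenthesizes the right-hand side and keeps the trailing comment with it.
    dtype: Optional[tuple[str, bool]] = (
        None  # name of the primitive type, plus a flag marking list types
    )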
diff --git a/src/gt4py/next/otf/workflow.py b/src/gt4py/next/otf/workflow.py
index 3a82f9c738..4bdb4bbb41 100644
--- a/src/gt4py/next/otf/workflow.py
+++ b/src/gt4py/next/otf/workflow.py
@@ -61,8 +61,7 @@ class Workflow(Protocol[StartT_contra, EndT_co]):
     - take a single input argument
     """
-    def __call__(self, inp: StartT_contra) -> EndT_co:
-        ...
+    def __call__(self, inp: StartT_contra) -> EndT_co: ...
 class ReplaceEnabledWorkflowMixin(Workflow[StartT_contra, EndT_co], Protocol):
diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_im_ir.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_im_ir.py
index f0843919fe..a62f50fc44 100644
--- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_im_ir.py
+++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_im_ir.py
@@ -21,8 +21,7 @@
 from gt4py.next.program_processors.codegens.gtfn.gtfn_ir_common import Expr, Sym, SymRef
-class Stmt(Node):
-    ...
+class Stmt(Node): ...
 class AssignStmt(Stmt):
@@ -35,8 +34,7 @@ class InitStmt(AssignStmt):
     init_type: str = "auto"
-class EmptyListInitializer(Expr):
-    ...
+class EmptyListInitializer(Expr): ...
 class Conditional(Stmt):
diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_common.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_common.py
index 79d4c18828..cb9aeffb90 100644
--- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_common.py
+++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_ir_common.py
@@ -25,8 +25,7 @@ class Sym(Node):  # helper
     id: Coerced[SymbolName]  # noqa: A003
-class Expr(Node):
-    ...
+class Expr(Node): ...
 class SymRef(Expr):
diff --git a/src/gt4py/next/program_processors/processor_interface.py b/src/gt4py/next/program_processors/processor_interface.py
index 95d3d2ca35..0c280202b8 100644
--- a/src/gt4py/next/program_processors/processor_interface.py
+++ b/src/gt4py/next/program_processors/processor_interface.py
@@ -40,14 +40,12 @@
 class ProgramProcessorCallable(Protocol[OutputT]):
-    def __call__(self, program: itir.FencilDefinition, *args, **kwargs) -> OutputT:
-        ...
+    def __call__(self, program: itir.FencilDefinition, *args, **kwargs) -> OutputT: ...
 class ProgramProcessor(ProgramProcessorCallable[OutputT], Protocol[OutputT, ProcessorKindT]):
     @property
-    def kind(self) -> type[ProcessorKindT]:
-        ...
+    def kind(self) -> type[ProcessorKindT]: ...
 class ProgramFormatter(ProgramProcessor[str, "ProgramFormatter"], Protocol):
@@ -234,8 +232,7 @@ class ProgramBackend(
     ProgramProcessor[None, "ProgramExecutor"],
     next_allocators.FieldBufferAllocatorFactoryProtocol[core_defs.DeviceTypeT],
     Protocol[core_defs.DeviceTypeT],
-):
-    ...
+): ...
 def is_program_backend(obj: Callable) -> TypeGuard[ProgramBackend]:
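The `Workflow` and `ProgramProcessorCallable` hunks above are callback-style Protocols whose only body is a stub; after the reformat each stub is a single line. A small usage sketch showing the idea (the names below are illustrative, not gt4py APIs):

from typing import Protocol, runtime_checkable


@runtime_checkable
class Step(Protocol):
    # A structural "callable" interface: anything with a matching __call__ conforms.
    def __call__(self, inp: str) -> str: ...


class Upper:
    def __call__(self, inp: str) -> str:
        return inp.upper()


def run(step: Step, data: str) -> str:
    return step(data)


assert isinstance(Upper(), Step)  # runtime_checkable only checks that the method exists
assert run(Upper(), "abc") == "ABC"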
diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py
index 525a5c694e..073c856d86 100644
--- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py
+++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py
@@ -113,9 +113,11 @@ def _make_array_shape_and_strides(
     dtype = dace.int64
     sorted_dims = [dim for _, dim in get_sorted_dims(dims)] if sort_dims else dims
     shape = [
-        neighbor_tables[dim.value].max_neighbors
-        if dim.kind == DimensionKind.LOCAL
-        else dace.symbol(unique_name(f"{name}_shape{i}"), dtype)
+        (
+            neighbor_tables[dim.value].max_neighbors
+            if dim.kind == DimensionKind.LOCAL
+            else dace.symbol(unique_name(f"{name}_shape{i}"), dtype)
+        )
         for i, dim in enumerate(sorted_dims)
     ]
     strides = [dace.symbol(unique_name(f"{name}_stride{i}"), dtype) for i, _ in enumerate(shape)]
@@ -348,11 +350,15 @@ def visit_StencilClosure(
         # Map SDFG tasklet arguments to parameters
         input_local_names = [
-            input_transients_mapping[input_name]
-            if input_name in input_transients_mapping
-            else input_name
-            if input_name in input_field_names
-            else cast(ValueExpr, program_arg_syms[input_name]).value.data
+            (
+                input_transients_mapping[input_name]
+                if input_name in input_transients_mapping
+                else (
+                    input_name
+                    if input_name in input_field_names
+                    else cast(ValueExpr, program_arg_syms[input_name]).value.data
+                )
+            )
             for input_name in input_names
         ]
         input_memlets = [
@@ -380,9 +386,11 @@ def visit_StencilClosure(
             create_memlet_at(
                 output_name,
                 tuple(
-                    f"i_{dim}"
-                    if f"i_{dim}" in map_ranges
-                    else f"0:{closure_sdfg.arrays[output_name].shape[scan_dim_index]}"
+                    (
+                        f"i_{dim}"
+                        if f"i_{dim}" in map_ranges
+                        else f"0:{closure_sdfg.arrays[output_name].shape[scan_dim_index]}"
+                    )
                     for dim, _ in closure_domain
                 ),
             )
diff --git a/src/gt4py/next/type_inference.py b/src/gt4py/next/type_inference.py
index 9b5d9070e3..10ae524451 100644
--- a/src/gt4py/next/type_inference.py
+++ b/src/gt4py/next/type_inference.py
@@ -94,13 +94,11 @@ def visit_TypeVar(self, node: V, *, index_map: dict[int, int]) -> V:
 @typing.overload
-def freshen(dtypes: list[T]) -> list[T]:
-    ...
+def freshen(dtypes: list[T]) -> list[T]: ...
 @typing.overload
-def freshen(dtypes: T) -> T:
-    ...
+def freshen(dtypes: T) -> T: ...
 def freshen(dtypes: list[T] | T) -> list[T] | T:
@@ -325,15 +323,13 @@ def _handle_constraint(self, constraint: tuple[_Box, _Box]) -> bool:
 @typing.overload
 def unify(
     dtypes: list[Type], constraints: set[tuple[Type, Type]]
-) -> tuple[list[Type], list[tuple[Type, Type]]]:
-    ...
+) -> tuple[list[Type], list[tuple[Type, Type]]]: ...
 @typing.overload
 def unify(
     dtypes: Type, constraints: set[tuple[Type, Type]]
-) -> tuple[Type, list[tuple[Type, Type]]]:
-    ...
+) -> tuple[Type, list[tuple[Type, Type]]]: ...
 def unify(
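The dace `itir_to_sdfg.py` hunks above show the rule for conditional expressions that span several lines: the whole `x if cond else y` chain is wrapped in parentheses, including when it is the element expression of a comprehension. A minimal sketch with invented data:

# Illustrative only: sizes for "local" dimensions come from a lookup table,
# everything else falls back to a placeholder, mirroring the shape expression above.
local_sizes = {"edge": 4, "vertex": 6}
dims = ["horizontal", "edge", "vertical", "vertex"]

shape = [
    (
        local_sizes[dim]
        if dim in local_sizes
        else -1  # placeholder for a symbolic size
    )
    for dim in dims
]
assert shape == [-1, 4, -1, 6]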
diff --git a/src/gt4py/next/type_system/type_info.py b/src/gt4py/next/type_system/type_info.py
index 20fa8bd791..7c4c8e6e23 100644
--- a/src/gt4py/next/type_system/type_info.py
+++ b/src/gt4py/next/type_system/type_info.py
@@ -118,8 +118,9 @@ def constituents_yielder(symbol_type: ts.TypeSpec):
 def apply_to_primitive_constituents(
     symbol_type: ts.TypeSpec,
-    fun: Callable[[ts.TypeSpec], ts.TypeSpec]
-    | Callable[[ts.TypeSpec, tuple[int, ...]], ts.TypeSpec],
+    fun: (
+        Callable[[ts.TypeSpec], ts.TypeSpec] | Callable[[ts.TypeSpec, tuple[int, ...]], ts.TypeSpec]
+    ),
     with_path_arg=False,
     _path=(),
 ):
diff --git a/src/gt4py/storage/allocators.py b/src/gt4py/storage/allocators.py
index 061f79f146..0482ec1e65 100644
--- a/src/gt4py/storage/allocators.py
+++ b/src/gt4py/storage/allocators.py
@@ -156,8 +156,7 @@ class BufferAllocator(Protocol[core_defs.DeviceTypeT]):
     """Protocol for buffer allocators."""
     @property
-    def device_type(self) -> core_defs.DeviceTypeT:
-        ...
+    def device_type(self) -> core_defs.DeviceTypeT: ...
     def allocate(
         self,
@@ -321,20 +320,17 @@ class _NumPyLibStridesModule(Protocol):
         @staticmethod
         def as_strided(
             ndarray: core_defs.NDArrayObject, **kwargs: Any
-        ) -> core_defs.NDArrayObject:
-            ...
+        ) -> core_defs.NDArrayObject: ...
         stride_tricks: _NumPyLibStridesModule
     lib: _NumPyLibModule
     @staticmethod
-    def empty(shape: Tuple[int, ...], dtype: Any) -> _NDBuffer:
-        ...
+    def empty(shape: Tuple[int, ...], dtype: Any) -> _NDBuffer: ...
     @staticmethod
-    def byte_bounds(ndarray: _NDBuffer) -> Tuple[int, int]:
-        ...
+    def byte_bounds(ndarray: _NDBuffer) -> Tuple[int, int]: ...
 def is_valid_nplike_allocation_ns(obj: Any) -> TypeGuard[ValidNumPyLikeAllocationNS]:
diff --git a/tests/cartesian_tests/integration_tests/multi_feature_tests/stencil_definitions.py b/tests/cartesian_tests/integration_tests/multi_feature_tests/stencil_definitions.py
index 79056c2914..4ac239fdd2 100644
--- a/tests/cartesian_tests/integration_tests/multi_feature_tests/stencil_definitions.py
+++ b/tests/cartesian_tests/integration_tests/multi_feature_tests/stencil_definitions.py
@@ -142,11 +142,7 @@ def native_functions(field_a: Field3D, field_b: Field3D):
         field_b = (
             trunc_res
             if isfinite(trunc_res)
-            else field_a
-            if isinf(trunc_res)
-            else field_b
-            if isnan(trunc_res)
-            else 0.0
+            else field_a if isinf(trunc_res) else field_b if isnan(trunc_res) else 0.0
         )
diff --git a/tests/eve_tests/unit_tests/test_extended_typing.py b/tests/eve_tests/unit_tests/test_extended_typing.py
index 733e12577c..d90a577bf9 100644
--- a/tests/eve_tests/unit_tests/test_extended_typing.py
+++ b/tests/eve_tests/unit_tests/test_extended_typing.py
@@ -413,12 +413,10 @@ class B:
 def test_is_protocol():
     class AProtocol(typing.Protocol):
-        def do_something(self, value: int) -> int:
-            ...
+        def do_something(self, value: int) -> int: ...
     class NotProtocol(AProtocol):
-        def do_something_else(self, value: float) -> float:
-            ...
+        def do_something_else(self, value: float) -> float: ...
     class AXProtocol(xtyping.Protocol):
         A = 1
@@ -427,8 +425,7 @@ class NotXProtocol(AXProtocol):
         A = 1
     class AgainProtocol(AProtocol, xtyping.Protocol):
-        def do_something_else(self, value: float) -> float:
-            ...
+        def do_something_else(self, value: float) -> float: ...
     assert xtyping.is_protocol(AProtocol)
     assert xtyping.is_protocol(AXProtocol)
@@ -440,16 +437,13 @@ def do_something_else(self, value: float) -> float:
 def test_get_partial_type_hints():
-    def f1(a: int) -> float:
-        ...
+    def f1(a: int) -> float: ...
     assert xtyping.get_partial_type_hints(f1) == {"a": int, "return": float}
-    class MissingRef:
-        ...
+    class MissingRef: ...
-    def f_partial(a: int) -> MissingRef:
-        ...
+    def f_partial(a: int) -> MissingRef: ...
     # This is expected behavior because this test file uses
     # 'from __future__ import annotations' and therefore local
@@ -467,8 +461,7 @@ def f_partial(a: int) -> MissingRef:
         "return": int,
     }
-    def f_nested_partial(a: int) -> Dict[str, MissingRef]:
-        ...
+    def f_nested_partial(a: int) -> Dict[str, MissingRef]: ...
     assert xtyping.get_partial_type_hints(f_nested_partial) == {
         "a": int,
@@ -500,8 +493,7 @@ def test_eval_forward_ref():
         == Dict[str, Tuple[int, float]]
     )
-    class MissingRef:
-        ...
+    class MissingRef: ...
     assert (
         xtyping.eval_forward_ref("Callable[[int], MissingRef]", localns={"MissingRef": MissingRef})
@@ -559,19 +551,16 @@ def test_infer_type():
     assert xtyping.infer_type(str) == Type[str]
-    class A:
-        ...
+    class A: ...
     assert xtyping.infer_type(A()) == A
     assert xtyping.infer_type(A) == Type[A]
-    def f1():
-        ...
+    def f1(): ...
     assert xtyping.infer_type(f1) == Callable[[], Any]
-    def f2(a: int, b: float) -> None:
-        ...
+    def f2(a: int, b: float) -> None: ...
     assert xtyping.infer_type(f2) == Callable[[int, float], type(None)]
@@ -579,8 +568,7 @@ def f3(
         a: Dict[Tuple[str, ...], List[int]],
         b: List[Callable[[List[int]], Set[Set[int]]]],
         c: Type[List[int]],
-    ) -> Any:
-        ...
+    ) -> Any: ...
     assert (
         xtyping.infer_type(f3)
@@ -594,8 +582,7 @@ def f3(
         ]
     )
-    def f4(a: int, b: float, *, foo: Tuple[str, ...] = ()) -> None:
-        ...
+    def f4(a: int, b: float, *, foo: Tuple[str, ...] = ()) -> None: ...
     assert xtyping.infer_type(f4) == Callable[[int, float], type(None)]
     assert (
diff --git a/tests/next_tests/definitions.py b/tests/next_tests/definitions.py
index dbb2366f47..70a0e7d090 100644
--- a/tests/next_tests/definitions.py
+++ b/tests/next_tests/definitions.py
@@ -61,12 +61,10 @@ class ProgramBackendId(_PythonObjectIdMixin, str, enum.Enum):
 class ExecutionAndAllocatorDescriptor(Protocol):
     # Used for test infrastructure, consider implementing this in gt4py when refactoring otf
     @property
-    def executor(self) -> Optional[ppi.ProgramExecutor]:
-        ...
+    def executor(self) -> Optional[ppi.ProgramExecutor]: ...
     @property
-    def allocator(self) -> next_allocators.FieldBufferAllocatorProtocol:
-        ...
+    def allocator(self) -> next_allocators.FieldBufferAllocatorProtocol: ...
 @dataclasses.dataclass(frozen=True)
diff --git a/tests/next_tests/integration_tests/cases.py b/tests/next_tests/integration_tests/cases.py
index 03a0a9f5a7..7d55e26118 100644
--- a/tests/next_tests/integration_tests/cases.py
+++ b/tests/next_tests/integration_tests/cases.py
@@ -95,8 +95,7 @@ class DataInitializer(Protocol):
     @property
-    def scalar_value(self) -> ScalarValue:
-        ...
+    def scalar_value(self) -> ScalarValue: ...
     def scalar(self, dtype: np.typing.DTypeLike) -> ScalarValue:
         # some unlikely numpy dtypes are picky about arguments
@@ -107,8 +106,7 @@ def field(
         allocator: next_allocators.FieldBufferAllocatorProtocol,
         sizes: dict[gtx.Dimension, int],
        dtype: np.typing.DTypeLike,
-    ) -> FieldValue:
-        ...
+    ) -> FieldValue: ...
     def from_case(
         self: Self,
@@ -249,22 +247,19 @@ def __getattr__(self, name: str) -> Any:
 @typing.overload
-def make_builder(*args: Callable) -> Callable[..., Builder]:
-    ...
+def make_builder(*args: Callable) -> Callable[..., Builder]: ...
 @typing.overload
 def make_builder(
     *args: Literal[None], **kwargs: dict[str, Any]
-) -> Callable[[Callable], Callable[..., Builder]]:
-    ...
+) -> Callable[[Callable], Callable[..., Builder]]: ...
 @typing.overload
 def make_builder(
     *args: Optional[Callable], **kwargs: dict[str, Any]
-) -> Callable[[Callable], Callable[..., Builder]] | Callable[..., Builder]:
-    ...
+) -> Callable[[Callable], Callable[..., Builder]] | Callable[..., Builder]: ...
 # TODO(ricoh): Think about improving the type hints using `typing.ParamSpec`.
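The `make_builder` overloads above, together with the `tuple[...]` return types in the test hunks that follow, use the last recurring pattern: a long return annotation is no longer wrapped in an extra pair of parentheses; instead the subscript itself is split and `]:` closes the signature. A hedged sketch with invented names:

from typing import Callable


# Old style:
#
#     def pipeline() -> (
#         tuple[
#             Callable[[int], int],
#             Callable[[int], int],
#         ]
#     ):
#
# New style: the subscript hugs the return arrow and is split directly.
def pipeline() -> tuple[
    Callable[[int], int],
    Callable[[int], int],
]:
    return (lambda x: x + 1, lambda x: x * 2)


inc, dbl = pipeline()
assert dbl(inc(1)) == 4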
@@ -305,8 +300,7 @@ def setter(self: Builder) -> Builder:
     argspec = inspect.getfullargspec(func)
     @dataclasses.dataclass(frozen=True)
-    class NewBuilder(Builder):
-        ...
+    class NewBuilder(Builder): ...
     for argname in argspec.args + argspec.kwonlyargs:
         setattr(NewBuilder, argname, make_setter(argname))
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py
index 9482860d13..6b7737df67 100644
--- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py
@@ -1099,7 +1099,12 @@ def test_tuple_unpacking(cartesian_case):
     @gtx.field_operator
     def unpack(
         inp: cases.IField,
-    ) -> tuple[cases.IField, cases.IField, cases.IField, cases.IField,]:
+    ) -> tuple[
+        cases.IField,
+        cases.IField,
+        cases.IField,
+        cases.IField,
+    ]:
         a, b, c, d = (inp + 2, inp + 3, inp + 5, inp + 7)
         return a, b, c, d
diff --git a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py
index 2dd4b91c48..bc92efc02c 100644
--- a/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py
+++ b/tests/next_tests/unit_tests/ffront_tests/test_foast_to_itir.py
@@ -526,15 +526,13 @@ def reduction(e1: gtx.Field[[Edge], float64], e2: gtx.Field[[Vertex, V2EDim], fl
 def test_builtin_int_constructors():
-    def int_constrs() -> (
-        tuple[
-            int32,
-            int32,
-            int64,
-            int32,
-            int64,
-        ]
-    ):
+    def int_constrs() -> tuple[
+        int32,
+        int32,
+        int64,
+        int32,
+        int64,
+    ]:
         return 1, int32(1), int64(1), int32("1"), int64("1")
     parsed = FieldOperatorParser.apply_to_function(int_constrs)
@@ -552,17 +550,15 @@ def int_constrs() -> (
 def test_builtin_float_constructors():
-    def float_constrs() -> (
-        tuple[
-            float,
-            float,
-            float32,
-            float64,
-            float,
-            float32,
-            float64,
-        ]
-    ):
+    def float_constrs() -> tuple[
+        float,
+        float,
+        float32,
+        float64,
+        float,
+        float32,
+        float64,
+    ]:
         return (
             0.1,
             float(0.1),
diff --git a/tests/next_tests/unit_tests/test_type_inference.py b/tests/next_tests/unit_tests/test_type_inference.py
index 74178e7548..3db67320f1 100644
--- a/tests/next_tests/unit_tests/test_type_inference.py
+++ b/tests/next_tests/unit_tests/test_type_inference.py
@@ -20,8 +20,7 @@ class Foo(ti.Type):
         bar: ti.Type
         baz: ti.Type
-    class Bar(ti.Type):
-        ...
+    class Bar(ti.Type): ...
     r = ti._Renamer()
     actual = [