diff --git a/.bumpversion.cfg b/.bumpversion.cfg index d7a3acaac1..9e65fd9ae0 100644 --- a/.bumpversion.cfg +++ b/.bumpversion.cfg @@ -1,5 +1,5 @@ [bumpversion] -current_version = 1.0.1 +current_version = 1.0.2 parse = (?P<major>\d+)\.(?P<minor>\d+)(\.(?P<patch>\d+))? serialize = {major}.{minor}.{patch} diff --git a/.github/workflows/daily-ci.yml b/.github/workflows/daily-ci.yml index 77ba39a361..8631390dbb 100644 --- a/.github/workflows/daily-ci.yml +++ b/.github/workflows/daily-ci.yml @@ -14,7 +14,7 @@ jobs: daily-ci: strategy: matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] tox-module-factor: ["cartesian", "eve", "next", "storage"] os: ["ubuntu-latest"] requirements-file: ["requirements-dev.txt", "min-requirements-test.txt", "min-extra-requirements-test.txt"] diff --git a/.github/workflows/test-cartesian-fallback.yml b/.github/workflows/test-cartesian-fallback.yml index b2eaead47a..7e9a948e9c 100644 --- a/.github/workflows/test-cartesian-fallback.yml +++ b/.github/workflows/test-cartesian-fallback.yml @@ -16,7 +16,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] backends: [internal-cpu, dace-cpu] steps: diff --git a/.github/workflows/test-cartesian.yml b/.github/workflows/test-cartesian.yml index 2c2b97aaa6..ebdc4ce749 100644 --- a/.github/workflows/test-cartesian.yml +++ b/.github/workflows/test-cartesian.yml @@ -23,7 +23,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] backends: [internal-cpu, dace-cpu] steps: - uses: actions/checkout@v2 diff --git a/.github/workflows/test-eve-fallback.yml b/.github/workflows/test-eve-fallback.yml index 93dc308a53..fd7ab5452c 100644 --- a/.github/workflows/test-eve-fallback.yml +++ b/.github/workflows/test-eve-fallback.yml @@ -17,7 +17,7 @@ jobs: test-eve: strategy: matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] os: ["ubuntu-latest"] runs-on: ${{ matrix.os }} diff --git a/.github/workflows/test-eve.yml b/.github/workflows/test-eve.yml index 1322c573db..222b825f38 100644 --- a/.github/workflows/test-eve.yml +++ b/.github/workflows/test-eve.yml @@ -20,7 +20,7 @@ jobs: test-eve: strategy: matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] os: ["ubuntu-latest"] fail-fast: false @@ -68,4 +68,3 @@ jobs: # with: # name: info-py${{ matrix.python-version }}-${{ matrix.os }} # path: info.txt - diff --git a/.github/workflows/test-next-fallback.yml b/.github/workflows/test-next-fallback.yml index 8490a3e393..bdcc061db0 100644 --- a/.github/workflows/test-next-fallback.yml +++ b/.github/workflows/test-next-fallback.yml @@ -15,7 +15,7 @@ jobs: test-next: strategy: matrix: - python-version: ["3.10"] + python-version: ["3.10", "3.11"] tox-env-factor: ["nomesh", "atlas"] os: ["ubuntu-latest"] diff --git a/.github/workflows/test-next.yml b/.github/workflows/test-next.yml index 52f8c25386..4282a22da6 100644 --- a/.github/workflows/test-next.yml +++ b/.github/workflows/test-next.yml @@ -18,7 +18,7 @@ jobs: test-next: strategy: matrix: - python-version: ["3.10"] + python-version: ["3.10", "3.11"] tox-env-factor: ["nomesh", "atlas"] os: ["ubuntu-latest"] fail-fast: false diff --git a/.github/workflows/test-storage-fallback.yml b/.github/workflows/test-storage-fallback.yml index 0cbc735564..99e4923de8 100644 --- a/.github/workflows/test-storage-fallback.yml +++
b/.github/workflows/test-storage-fallback.yml @@ -18,7 +18,7 @@ jobs: test-storage: strategy: matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] backends: [internal-cpu, dace-cpu] os: ["ubuntu-latest"] diff --git a/.github/workflows/test-storage.yml b/.github/workflows/test-storage.yml index 1133353f30..34841ed71c 100644 --- a/.github/workflows/test-storage.yml +++ b/.github/workflows/test-storage.yml @@ -21,7 +21,7 @@ jobs: test-storage: strategy: matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.8", "3.9", "3.10", "3.11"] backends: [internal-cpu, dace-cpu] os: ["ubuntu-latest"] fail-fast: false @@ -70,4 +70,3 @@ jobs: # with: # name: info-py${{ matrix.python-version }}-${{ matrix.os }} # path: info.txt - diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index b1092fafd0..d9cfa0ff48 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -62,7 +62,7 @@ repos: ## version = re.search('black==([0-9\.]*)', open("constraints.txt").read())[1] ## print(f"rev: '{version}' # version from constraints.txt") ##]]] - rev: '23.9.1' # version from constraints.txt + rev: '23.11.0' # version from constraints.txt ##[[[end]]] hooks: - id: black @@ -73,7 +73,7 @@ repos: ## version = re.search('isort==([0-9\.]*)', open("constraints.txt").read())[1] ## print(f"rev: '{version}' # version from constraints.txt") ##]]] - rev: '5.12.0' # version from constraints.txt + rev: '5.13.0' # version from constraints.txt ##[[[end]]] hooks: - id: isort @@ -97,14 +97,14 @@ repos: ## print(f"- {pkg}==" + str(re.search(f'\n{pkg}==([0-9\.]*)', constraints)[1])) ##]]] - darglint==1.8.1 - - flake8-bugbear==23.9.16 - - flake8-builtins==2.1.0 + - flake8-bugbear==23.12.2 + - flake8-builtins==2.2.0 - flake8-debugger==4.1.2 - flake8-docstrings==1.7.0 - flake8-eradicate==1.5.0 - flake8-mutable==1.2.0 - flake8-pyproject==1.2.3 - - pygments==2.16.1 + - pygments==2.17.2 ##[[[end]]] # - flake8-rst-docstrings # Disabled for now due to random false positives exclude: | @@ -146,9 +146,9 @@ repos: ## version = re.search('mypy==([0-9\.]*)', open("constraints.txt").read())[1] ## print(f"#========= FROM constraints.txt: v{version} =========") ##]]] - #========= FROM constraints.txt: v1.5.1 ========= + #========= FROM constraints.txt: v1.7.1 ========= ##[[[end]]] - rev: v1.5.1 # MUST match version ^^^^ in constraints.txt (if the mirror is up-to-date) + rev: v1.7.1 # MUST match version ^^^^ in constraints.txt (if the mirror is up-to-date) hooks: - id: mypy additional_dependencies: # versions from constraints.txt @@ -162,26 +162,26 @@ repos: ##]]] - astunparse==1.6.3 - attrs==23.1.0 - - black==23.9.1 - - boltons==23.0.0 + - black==23.11.0 + - boltons==23.1.1 - cached-property==1.5.2 - click==8.1.7 - - cmake==3.27.5 + - cmake==3.27.9 - cytoolz==0.12.2 - - deepdiff==6.5.0 + - deepdiff==6.7.1 - devtools==0.12.2 - - frozendict==2.3.8 + - frozendict==2.3.10 - gridtools-cpp==2.3.1 - - importlib-resources==6.0.1 + - importlib-resources==6.1.1 - jinja2==3.1.2 - - lark==1.1.7 - - mako==1.2.4 - - nanobind==1.5.2 - - ninja==1.11.1 + - lark==1.1.8 + - mako==1.3.0 + - nanobind==1.8.0 + - ninja==1.11.1.1 - numpy==1.24.4 - - packaging==23.1 + - packaging==23.2 - pybind11==2.11.1 - - setuptools==68.2.2 + - setuptools==69.0.2 - tabulate==0.9.0 - typing-extensions==4.5.0 - xxhash==3.0.0 diff --git a/AUTHORS.md b/AUTHORS.md index 89aafb9971..6c76e5759e 100644 --- a/AUTHORS.md +++ b/AUTHORS.md @@ -9,6 +9,7 @@ - Deconinck, Florian. SSAI/NASA-GSFC - Ehrengruber, Till. 
ETH Zurich - CSCS - Elbert, Oliver D. NOAA-GFDL +- Faghih-Naini, Sara. ECMWF - Farabullini, Nicoletta. ETH Zurich - C2SM - George, Rhea. Allen Institute for AI - González Paredes, Enrique. ETH Zurich - CSCS diff --git a/CHANGELOG.md b/CHANGELOG.md index 519f7ff1db..87f3ee9d2b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,23 @@ Notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +## [1.0.2] - 2024-01-24 + +### Cartesian + +- Compatibility of `gt4py.next` Fields with `gt4py.cartesian` computations. +- Fixes for DaCe 0.15.1 compatibility. +- Added `log10` as native function. +- Make `scipy` optional: get `scipy` by installing `gt4py[full]` for best performance with `numpy` backend. + +### Storage + +- Refactored low-level storage allocation. + +### Next + +See commit history. + ## [1.0.1] - 2023-02-20 First version including the experimental `gt4py.next` aka _Declarative GT4Py_. The `gt4py.next` package is excluded from semantic versioning. diff --git a/constraints.txt b/constraints.txt index b334851af1..81abd64c6e 100644 --- a/constraints.txt +++ b/constraints.txt @@ -6,124 +6,136 @@ # aenum==3.1.15 # via dace alabaster==0.7.13 # via sphinx -asttokens==2.4.0 # via devtools +asttokens==2.4.1 # via devtools astunparse==1.6.3 ; python_version < "3.9" # via dace, gt4py (pyproject.toml) attrs==23.1.0 # via flake8-bugbear, flake8-eradicate, gt4py (pyproject.toml), hypothesis, jsonschema, referencing -babel==2.12.1 # via sphinx -black==23.9.1 # via gt4py (pyproject.toml) -blinker==1.6.2 # via flask -boltons==23.0.0 # via gt4py (pyproject.toml) +babel==2.13.1 # via sphinx +black==23.11.0 # via gt4py (pyproject.toml) +blinker==1.7.0 # via flask +boltons==23.1.1 # via gt4py (pyproject.toml) build==1.0.3 # via pip-tools cached-property==1.5.2 # via gt4py (pyproject.toml) -cachetools==5.3.1 # via tox -certifi==2023.7.22 # via requests -cffi==1.15.1 # via cryptography +cachetools==5.3.2 # via tox +cerberus==1.3.5 # via plette +certifi==2023.11.17 # via requests +cffi==1.16.0 # via cryptography cfgv==3.4.0 # via pre-commit chardet==5.2.0 # via tox -charset-normalizer==3.2.0 # via requests -clang-format==16.0.6 # via -r requirements-dev.in, gt4py (pyproject.toml) +charset-normalizer==3.3.2 # via requests +clang-format==17.0.6 # via -r requirements-dev.in, gt4py (pyproject.toml) click==8.1.7 # via black, flask, gt4py (pyproject.toml), pip-tools -cmake==3.27.5 # via gt4py (pyproject.toml) +cmake==3.27.9 # via dace, gt4py (pyproject.toml) cogapp==3.3.0 # via -r requirements-dev.in colorama==0.4.6 # via tox -coverage==7.3.1 # via -r requirements-dev.in, pytest-cov -cryptography==41.0.3 # via types-paramiko, types-pyopenssl, types-redis +coverage==7.3.2 # via -r requirements-dev.in, pytest-cov +cryptography==41.0.7 # via types-paramiko, types-pyopenssl, types-redis cytoolz==0.12.2 # via gt4py (pyproject.toml) -dace==0.14.4 # via gt4py (pyproject.toml) +dace==0.15.1 # via gt4py (pyproject.toml) darglint==1.8.1 # via -r requirements-dev.in -deepdiff==6.5.0 # via gt4py (pyproject.toml) +deepdiff==6.7.1 # via gt4py (pyproject.toml) devtools==0.12.2 # via gt4py (pyproject.toml) dill==0.3.7 # via dace -distlib==0.3.7 # via virtualenv -docutils==0.18.1 # via restructuredtext-lint, sphinx, sphinx-rtd-theme +distlib==0.3.7 # via requirementslib, virtualenv +distro==1.8.0 # via scikit-build +docopt==0.6.2 # via pipreqs +docutils==0.20.1 # via restructuredtext-lint, sphinx, sphinx-rtd-theme eradicate==2.3.0 # via 
flake8-eradicate -exceptiongroup==1.1.3 # via hypothesis, pytest +exceptiongroup==1.2.0 # via hypothesis, pytest execnet==2.0.2 # via pytest-cache, pytest-xdist -executing==1.2.0 # via devtools +executing==2.0.1 # via devtools factory-boy==3.3.0 # via -r requirements-dev.in, pytest-factoryboy -faker==19.6.1 # via factory-boy -fastjsonschema==2.18.0 # via nbformat -filelock==3.12.4 # via tox, virtualenv +faker==20.1.0 # via factory-boy +fastjsonschema==2.19.0 # via nbformat +filelock==3.13.1 # via tox, virtualenv flake8==6.1.0 # via -r requirements-dev.in, flake8-bugbear, flake8-builtins, flake8-debugger, flake8-docstrings, flake8-eradicate, flake8-mutable, flake8-pyproject, flake8-rst-docstrings -flake8-bugbear==23.9.16 # via -r requirements-dev.in -flake8-builtins==2.1.0 # via -r requirements-dev.in +flake8-bugbear==23.12.2 # via -r requirements-dev.in +flake8-builtins==2.2.0 # via -r requirements-dev.in flake8-debugger==4.1.2 # via -r requirements-dev.in flake8-docstrings==1.7.0 # via -r requirements-dev.in flake8-eradicate==1.5.0 # via -r requirements-dev.in flake8-mutable==1.2.0 # via -r requirements-dev.in flake8-pyproject==1.2.3 # via -r requirements-dev.in flake8-rst-docstrings==0.3.0 # via -r requirements-dev.in -flask==2.3.3 # via dace -frozendict==2.3.8 # via gt4py (pyproject.toml) +flask==3.0.0 # via dace +fparser==0.1.3 # via dace +frozendict==2.3.10 # via gt4py (pyproject.toml) gridtools-cpp==2.3.1 # via gt4py (pyproject.toml) -hypothesis==6.86.1 # via -r requirements-dev.in, gt4py (pyproject.toml) -identify==2.5.29 # via pre-commit -idna==3.4 # via requests +hypothesis==6.92.0 # via -r requirements-dev.in, gt4py (pyproject.toml) +identify==2.5.33 # via pre-commit +idna==3.6 # via requests imagesize==1.4.1 # via sphinx -importlib-metadata==6.8.0 # via build, flask, sphinx -importlib-resources==6.0.1 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications +importlib-metadata==7.0.0 # via build, flask, fparser, sphinx +importlib-resources==6.1.1 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications inflection==0.5.1 # via pytest-factoryboy iniconfig==2.0.0 # via pytest -isort==5.12.0 # via -r requirements-dev.in +isort==5.13.0 # via -r requirements-dev.in itsdangerous==2.1.2 # via flask jinja2==3.1.2 # via flask, gt4py (pyproject.toml), sphinx -jsonschema==4.19.0 # via nbformat -jsonschema-specifications==2023.7.1 # via jsonschema -jupyter-core==5.3.1 # via nbformat -jupytext==1.15.2 # via -r requirements-dev.in -lark==1.1.7 # via gt4py (pyproject.toml) -mako==1.2.4 # via gt4py (pyproject.toml) +jsonschema==4.20.0 # via nbformat +jsonschema-specifications==2023.11.2 # via jsonschema +jupyter-core==5.5.0 # via nbformat +jupytext==1.16.0 # via -r requirements-dev.in +lark==1.1.8 # via gt4py (pyproject.toml) +mako==1.3.0 # via gt4py (pyproject.toml) markdown-it-py==3.0.0 # via jupytext, mdit-py-plugins markupsafe==2.1.3 # via jinja2, mako, werkzeug mccabe==0.7.0 # via flake8 mdit-py-plugins==0.4.0 # via jupytext mdurl==0.1.2 # via markdown-it-py mpmath==1.3.0 # via sympy -mypy==1.5.1 # via -r requirements-dev.in +mypy==1.7.1 # via -r requirements-dev.in mypy-extensions==1.0.0 # via black, mypy -nanobind==1.5.2 # via gt4py (pyproject.toml) +nanobind==1.8.0 # via gt4py (pyproject.toml) nbformat==5.9.2 # via jupytext networkx==3.1 # via dace -ninja==1.11.1 # via gt4py (pyproject.toml) +ninja==1.11.1.1 # via gt4py (pyproject.toml) nodeenv==1.8.0 # via pre-commit numpy==1.24.4 # via dace, gt4py (pyproject.toml), 
types-jack-client ordered-set==4.1.0 # via deepdiff -packaging==23.1 # via black, build, gt4py (pyproject.toml), pyproject-api, pytest, sphinx, tox -pathspec==0.11.2 # via black +packaging==23.2 # via black, build, gt4py (pyproject.toml), jupytext, pyproject-api, pytest, scikit-build, setuptools-scm, sphinx, tox +pathspec==0.12.1 # via black +pep517==0.13.1 # via requirementslib +pip-api==0.0.30 # via isort pip-tools==7.3.0 # via -r requirements-dev.in -pipdeptree==2.13.0 # via -r requirements-dev.in +pipdeptree==2.13.1 # via -r requirements-dev.in +pipreqs==0.4.13 # via isort pkgutil-resolve-name==1.3.10 # via jsonschema -platformdirs==3.10.0 # via black, jupyter-core, tox, virtualenv +platformdirs==4.1.0 # via black, jupyter-core, requirementslib, tox, virtualenv +plette==0.4.4 # via requirementslib pluggy==1.3.0 # via pytest, tox ply==3.11 # via dace -pre-commit==3.4.0 # via -r requirements-dev.in -psutil==5.9.5 # via -r requirements-dev.in, pytest-xdist +pre-commit==3.5.0 # via -r requirements-dev.in +psutil==5.9.6 # via -r requirements-dev.in, pytest-xdist pybind11==2.11.1 # via gt4py (pyproject.toml) -pycodestyle==2.11.0 # via flake8, flake8-debugger +pycodestyle==2.11.1 # via flake8, flake8-debugger pycparser==2.21 # via cffi +pydantic==1.10.13 # via requirementslib pydocstyle==6.3.0 # via flake8-docstrings pyflakes==3.1.0 # via flake8 -pygments==2.16.1 # via -r requirements-dev.in, devtools, flake8-rst-docstrings, sphinx +pygments==2.17.2 # via -r requirements-dev.in, devtools, flake8-rst-docstrings, sphinx pyproject-api==1.6.1 # via tox pyproject-hooks==1.0.0 # via build -pytest==7.4.2 # via -r requirements-dev.in, gt4py (pyproject.toml), pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist +pytest==7.4.3 # via -r requirements-dev.in, gt4py (pyproject.toml), pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist pytest-cache==1.0 # via -r requirements-dev.in pytest-cov==4.1.0 # via -r requirements-dev.in -pytest-factoryboy==2.5.1 # via -r requirements-dev.in -pytest-xdist==3.3.1 # via -r requirements-dev.in +pytest-factoryboy==2.6.0 # via -r requirements-dev.in +pytest-xdist==3.5.0 # via -r requirements-dev.in python-dateutil==2.8.2 # via faker pytz==2023.3.post1 # via babel pyyaml==6.0.1 # via dace, jupytext, pre-commit -referencing==0.30.2 # via jsonschema, jsonschema-specifications -requests==2.31.0 # via dace, sphinx +referencing==0.32.0 # via jsonschema, jsonschema-specifications +requests==2.31.0 # via dace, requirementslib, sphinx, yarg +requirementslib==3.0.0 # via isort restructuredtext-lint==1.4.0 # via flake8-rst-docstrings -rpds-py==0.10.3 # via jsonschema, referencing -ruff==0.0.290 # via -r requirements-dev.in +rpds-py==0.13.2 # via jsonschema, referencing +ruff==0.1.7 # via -r requirements-dev.in +scikit-build==0.17.6 # via dace +setuptools-scm==8.0.4 # via fparser six==1.16.0 # via asttokens, astunparse, python-dateutil snowballstemmer==2.2.0 # via pydocstyle, sphinx sortedcontainers==2.4.0 # via hypothesis sphinx==7.1.2 # via -r requirements-dev.in, sphinx-rtd-theme, sphinxcontrib-jquery -sphinx-rtd-theme==1.3.0 # via -r requirements-dev.in +sphinx-rtd-theme==2.0.0 # via -r requirements-dev.in sphinxcontrib-applehelp==1.0.4 # via sphinx sphinxcontrib-devhelp==1.0.2 # via sphinx sphinxcontrib-htmlhelp==2.0.1 # via sphinx @@ -131,31 +143,32 @@ sphinxcontrib-jquery==4.1 # via sphinx-rtd-theme sphinxcontrib-jsmath==1.0.1 # via sphinx sphinxcontrib-qthelp==1.0.3 # via sphinx sphinxcontrib-serializinghtml==1.1.5 # via sphinx -sympy==1.12 # via dace, gt4py 
(pyproject.toml) +sympy==1.9 # via dace, gt4py (pyproject.toml) tabulate==0.9.0 # via gt4py (pyproject.toml) toml==0.10.2 # via jupytext -tomli==2.0.1 # via -r requirements-dev.in, black, build, coverage, flake8-pyproject, mypy, pip-tools, pyproject-api, pyproject-hooks, pytest, tox +tomli==2.0.1 # via -r requirements-dev.in, black, build, coverage, flake8-pyproject, mypy, pep517, pip-tools, pyproject-api, pyproject-hooks, pytest, scikit-build, setuptools-scm, tox +tomlkit==0.12.3 # via plette, requirementslib toolz==0.12.0 # via cytoolz -tox==4.11.3 # via -r requirements-dev.in -traitlets==5.10.0 # via jupyter-core, nbformat +tox==4.11.4 # via -r requirements-dev.in +traitlets==5.14.0 # via jupyter-core, nbformat types-aiofiles==23.2.0.0 # via types-all types-all==1.0.0 # via -r requirements-dev.in types-annoy==1.17.8.4 # via types-all types-atomicwrites==1.4.5.1 # via types-all types-backports==0.1.3 # via types-all types-backports-abc==0.5.2 # via types-all -types-bleach==6.0.0.4 # via types-all +types-bleach==6.1.0.1 # via types-all types-boto==2.49.18.9 # via types-all -types-cachetools==5.3.0.6 # via types-all +types-cachetools==5.3.0.7 # via types-all types-certifi==2021.10.8.3 # via types-all -types-cffi==1.15.1.15 # via types-jack-client +types-cffi==1.16.0.0 # via types-jack-client types-characteristic==14.3.7 # via types-all types-chardet==5.0.4.6 # via types-all types-click==7.1.8 # via types-all, types-flask -types-click-spinner==0.1.13.5 # via types-all +types-click-spinner==0.1.13.6 # via types-all types-colorama==0.4.15.12 # via types-all types-contextvars==2.4.7.3 # via types-all -types-croniter==1.4.0.1 # via types-all +types-croniter==2.0.0.0 # via types-all types-cryptography==3.3.23.2 # via types-all, types-openssl-python, types-pyjwt types-dataclasses==0.6.6 # via types-all types-dateparser==1.1.4.10 # via types-all @@ -176,44 +189,44 @@ types-futures==3.3.8 # via types-all types-geoip2==3.0.0 # via types-all types-ipaddress==1.0.8 # via types-all, types-maxminddb types-itsdangerous==1.1.6 # via types-all -types-jack-client==0.5.10.9 # via types-all +types-jack-client==0.5.10.10 # via types-all types-jinja2==2.11.9 # via types-all, types-flask types-kazoo==0.1.3 # via types-all -types-markdown==3.4.2.10 # via types-all +types-markdown==3.5.0.3 # via types-all types-markupsafe==1.1.10 # via types-all, types-jinja2 types-maxminddb==1.5.0 # via types-all, types-geoip2 -types-mock==5.1.0.2 # via types-all +types-mock==5.1.0.3 # via types-all types-mypy-extensions==1.0.0.5 # via types-all types-nmap==0.1.6 # via types-all types-openssl-python==0.1.3 # via types-all types-orjson==3.6.2 # via types-all -types-paramiko==3.3.0.0 # via types-all, types-pysftp +types-paramiko==3.3.0.2 # via types-all, types-pysftp types-pathlib2==2.3.0 # via types-all -types-pillow==10.0.0.3 # via types-all +types-pillow==10.1.0.2 # via types-all types-pkg-resources==0.1.3 # via types-all types-polib==1.2.0.1 # via types-all -types-protobuf==4.24.0.1 # via types-all +types-protobuf==4.24.0.4 # via types-all types-pyaudio==0.2.16.7 # via types-all types-pycurl==7.45.2.5 # via types-all types-pyfarmhash==0.3.1.2 # via types-all types-pyjwt==1.7.1 # via types-all types-pymssql==2.1.0 # via types-all types-pymysql==1.1.0.1 # via types-all -types-pyopenssl==23.2.0.2 # via types-redis +types-pyopenssl==23.3.0.0 # via types-redis types-pyrfc3339==1.1.1.5 # via types-all types-pysftp==0.2.17.6 # via types-all types-python-dateutil==2.8.19.14 # via types-all, types-datetimerange 
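The 1.0.2 changelog earlier in this diff adds `log10` as a native function in `gt4py.cartesian`. A minimal sketch of how a stencil might use it, assuming the `numpy` backend and that `log10` is importable from `gtscript` like the other native math functions; the stencil and field names are illustrative, not from this diff:

```python
import numpy as np
import gt4py.cartesian.gtscript as gtscript
from gt4py.cartesian.gtscript import PARALLEL, computation, interval, log10

@gtscript.stencil(backend="numpy")
def take_log10(inp: gtscript.Field[np.float64], out: gtscript.Field[np.float64]):
    # log10 is now lowered as a native function by each backend
    with computation(PARALLEL), interval(...):
        out = log10(inp)
```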
types-python-gflags==3.1.7.3 # via types-all types-python-slugify==8.0.0.3 # via types-all -types-pytz==2023.3.1.0 # via types-all, types-tzlocal +types-pytz==2023.3.1.1 # via types-all, types-tzlocal types-pyvmomi==8.0.0.6 # via types-all -types-pyyaml==6.0.12.11 # via types-all -types-redis==4.6.0.6 # via types-all -types-requests==2.31.0.2 # via types-all +types-pyyaml==6.0.12.12 # via types-all +types-redis==4.6.0.11 # via types-all +types-requests==2.31.0.10 # via types-all types-retry==0.9.9.4 # via types-all types-routes==2.5.0 # via types-all types-scribe==2.0.0 # via types-all -types-setuptools==68.2.0.0 # via types-cffi +types-setuptools==69.0.0.0 # via types-cffi types-simplejson==3.19.0.2 # via types-all types-singledispatch==4.1.0.0 # via types-all types-six==1.16.21.9 # via types-all @@ -222,21 +235,21 @@ types-termcolor==1.1.6.2 # via types-all types-toml==0.10.8.7 # via types-all types-tornado==5.1.1 # via types-all types-typed-ast==1.5.8.7 # via types-all -types-tzlocal==5.0.1.1 # via types-all +types-tzlocal==5.1.0.1 # via types-all types-ujson==5.8.0.1 # via types-all -types-urllib3==1.26.25.14 # via types-requests types-waitress==2.1.4.9 # via types-all types-werkzeug==1.0.9 # via types-all, types-flask types-xxhash==3.0.5.2 # via types-all -typing-extensions==4.5.0 # via black, faker, gt4py (pyproject.toml), mypy, pytest-factoryboy -urllib3==2.0.4 # via requests -virtualenv==20.24.5 # via pre-commit, tox -websockets==11.0.3 # via dace -werkzeug==2.3.7 # via flask -wheel==0.41.2 # via astunparse, pip-tools +typing-extensions==4.5.0 # via black, faker, gt4py (pyproject.toml), mypy, pydantic, pytest-factoryboy, setuptools-scm +urllib3==2.1.0 # via requests, types-requests +virtualenv==20.25.0 # via pre-commit, tox +websockets==12.0 # via dace +werkzeug==3.0.1 # via flask +wheel==0.42.0 # via astunparse, pip-tools, scikit-build xxhash==3.0.0 # via gt4py (pyproject.toml) -zipp==3.16.2 # via importlib-metadata, importlib-resources +yarg==0.1.9 # via pipreqs +zipp==3.17.0 # via importlib-metadata, importlib-resources # The following packages are considered to be unsafe in a requirements file: -pip==23.2.1 # via pip-tools -setuptools==68.2.2 # via gt4py (pyproject.toml), nodeenv, pip-tools +pip==23.3.1 # via pip-api, pip-tools, requirementslib +setuptools==69.0.2 # via gt4py (pyproject.toml), nodeenv, pip-tools, requirementslib, scikit-build, setuptools-scm diff --git a/examples/lap_cartesian_vs_next.ipynb b/examples/lap_cartesian_vs_next.ipynb new file mode 100644 index 0000000000..cb80122570 --- /dev/null +++ b/examples/lap_cartesian_vs_next.ipynb @@ -0,0 +1,189 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "GT4Py - GridTools for Python\n", + "\n", + "Copyright (c) 2014-2023, ETH Zurich\n", + "All rights reserved.\n", + "\n", + "This file is part of the GT4Py project and the GridTools framework.\n", + "GT4Py is free software: you can redistribute it and/or modify it under\n", + "the terms of the GNU General Public License as published by the\n", + "Free Software Foundation, either version 3 of the License, or any later\n", + "version.
See the LICENSE.txt file at the top-level directory of this\n", + "distribution for a copy of the license or check https://www.gnu.org/licenses/gpl-3.0.html.\n", + "\n", + "SPDX-License-Identifier: GPL-3.0-or-later" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Demonstrates gt4py.cartesian with gt4py.next compatibility" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Imports" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "nx = 32\n", + "ny = 32\n", + "nz = 1\n", + "dtype = np.float64" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Storages\n", + "--\n", + "\n", + "We create fields using the gt4py.next constructors. These fields are compatible with gt4py.cartesian when we use \"I\", \"J\", \"K\" as the dimension names." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "text/html": [ + "\n", + "\n" + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import gt4py.next as gtx\n", + "\n", + "allocator = gtx.itir_embedded # should match the executor\n", + "# allocator = gtx.gtfn_cpu\n", + "# allocator = gtx.gtfn_gpu\n", + "\n", + "# Note: for gt4py.next, names don't matter, for gt4py.cartesian they have to be \"I\", \"J\", \"K\"\n", + "I = gtx.Dimension(\"I\")\n", + "J = gtx.Dimension(\"J\")\n", + "K = gtx.Dimension(\"K\", kind=gtx.DimensionKind.VERTICAL)\n", + "\n", + "domain = gtx.domain({I: nx, J: ny, K: nz})\n", + "\n", + "inp = gtx.as_field(domain, np.fromfunction(lambda x, y, z: x**2+y**2, shape=(nx, ny, nz)), dtype, allocator=allocator)\n", + "out_cartesian = gtx.zeros(domain, dtype, allocator=allocator)\n", + "out_next = gtx.zeros(domain, dtype, allocator=allocator)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "gt4py.cartesian\n", + "--" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "import gt4py.cartesian.gtscript as gtscript\n", + "\n", + "cartesian_backend = \"numpy\"\n", + "# cartesian_backend = \"gt:cpu_ifirst\"\n", + "# cartesian_backend = \"gt:gpu\"\n", + "\n", + "@gtscript.stencil(backend=cartesian_backend)\n", + "def lap_cartesian(\n", + " inp: gtscript.Field[dtype],\n", + " out: gtscript.Field[dtype],\n", + "):\n", + " with computation(PARALLEL), interval(...):\n", + " out = -4.0 * inp[0, 0, 0] + inp[-1, 0, 0] + inp[1, 0, 0] + inp[0, -1, 0] + inp[0, 1, 0]\n", + "\n", + "lap_cartesian(inp=inp, out=out_cartesian, origin=(1, 1, 0), domain=(nx-2, ny-2, nz))" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "from gt4py.next import Field\n", + "\n", + "next_backend = gtx.itir_embedded\n", + "# next_backend = gtx.gtfn_cpu\n", + "# next_backend = gtx.gtfn_gpu\n", + "\n", + "Ioff = gtx.FieldOffset(\"I\", source=I, target=(I,))\n", + "Joff = gtx.FieldOffset(\"J\", source=J, target=(J,))\n", + "\n", + "@gtx.field_operator\n", + "def lap_next(inp: Field[[I, J, K], dtype]) -> Field[[I, J, K], dtype]:\n", + " return -4.0 * inp + inp(Ioff[-1]) + inp(Ioff[1]) + inp(Joff[-1]) + inp(Joff[1])\n", + "\n", + "@gtx.program(backend=next_backend)\n", + "def lap_next_program(inp: Field[[I, J, K], dtype], out: Field[[I, J, K], dtype]):\n", + " lap_next(inp, out=out[1:-1, 1:-1, :])\n", + "\n", + "lap_next_program(inp,
out_next, offset_provider={\"Ioff\": I, \"Joff\": J})" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "assert np.allclose(out_cartesian.asnumpy(), out_next.asnumpy())" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.13" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/min-extra-requirements-test.txt b/min-extra-requirements-test.txt index 17709206a0..fd7724bac9 100644 --- a/min-extra-requirements-test.txt +++ b/min-extra-requirements-test.txt @@ -25,7 +25,7 @@ cmake==3.22 cogapp==3.3 coverage[toml]==5.0 cytoolz==0.12.0 -dace==0.14.2 +dace==0.15.1 darglint==1.6 deepdiff==5.6.0 devtools==0.6 @@ -70,7 +70,7 @@ scipy==1.7.2 setuptools==65.5.0 sphinx==4.4 sphinx_rtd_theme==1.0 -sympy==1.7 +sympy==1.9 tabulate==0.8.10 tomli==2.0.1 tox==3.2.0 diff --git a/pyproject.toml b/pyproject.toml index 7d499b1f3f..87934b11fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -69,15 +69,15 @@ requires-python = '>=3.8' cuda = ['cupy>=12.0'] cuda11x = ['cupy-cuda11x>=12.0'] cuda12x = ['cupy-cuda12x>=12.0'] -dace = ['dace>=0.14.2,<0.15', 'sympy>=1.7'] +dace = ['dace>=0.15.1,<0.16', 'sympy>=1.9'] formatting = ['clang-format>=9.0'] # Always add all extra packages to 'full' for a simple full gt4py installation full = [ 'clang-format>=9.0', - 'dace>=0.14.2,<0.15', + 'dace>=0.15.1,<0.16', 'hypothesis>=6.0.0', 'pytest>=7.0', - 'sympy>=1.7', + 'sympy>=1.9', 'scipy>=1.7.2', 'jax[cpu]>=0.4.13' ] diff --git a/requirements-dev.txt b/requirements-dev.txt index d6dcc12d21..0fa523866f 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -6,124 +6,136 @@ # aenum==3.1.15 # via dace alabaster==0.7.13 # via sphinx -asttokens==2.4.0 # via devtools +asttokens==2.4.1 # via devtools astunparse==1.6.3 ; python_version < "3.9" # via dace, gt4py (pyproject.toml) attrs==23.1.0 # via flake8-bugbear, flake8-eradicate, gt4py (pyproject.toml), hypothesis, jsonschema, referencing -babel==2.12.1 # via sphinx -black==23.9.1 # via gt4py (pyproject.toml) -blinker==1.6.2 # via flask -boltons==23.0.0 # via gt4py (pyproject.toml) +babel==2.13.1 # via sphinx +black==23.11.0 # via gt4py (pyproject.toml) +blinker==1.7.0 # via flask +boltons==23.1.1 # via gt4py (pyproject.toml) build==1.0.3 # via pip-tools cached-property==1.5.2 # via gt4py (pyproject.toml) -cachetools==5.3.1 # via tox -certifi==2023.7.22 # via requests -cffi==1.15.1 # via cryptography +cachetools==5.3.2 # via tox +cerberus==1.3.5 # via plette +certifi==2023.11.17 # via requests +cffi==1.16.0 # via cryptography cfgv==3.4.0 # via pre-commit chardet==5.2.0 # via tox -charset-normalizer==3.2.0 # via requests -clang-format==16.0.6 # via -r requirements-dev.in, gt4py (pyproject.toml) +charset-normalizer==3.3.2 # via requests +clang-format==17.0.6 # via -r requirements-dev.in, gt4py (pyproject.toml) click==8.1.7 # via black, flask, gt4py (pyproject.toml), pip-tools -cmake==3.27.5 # via gt4py (pyproject.toml) +cmake==3.27.9 # via dace, gt4py (pyproject.toml) cogapp==3.3.0 # via -r requirements-dev.in colorama==0.4.6 # via tox -coverage[toml]==7.3.1 # via -r requirements-dev.in, pytest-cov -cryptography==41.0.3 # via types-paramiko, types-pyopenssl, types-redis 
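The `pyproject.toml` hunk above keeps `scipy>=1.7.2` in the `full` extra while the changelog makes `scipy` optional, so `pip install gt4py[full]` restores the fast path for the `numpy` backend. A hedged sketch of the import-time pattern this implies; the pure-`numpy` fallback detail is my assumption, not shown in this diff:

```python
# Assumed behavior of the optional dependency: present with `gt4py[full]`,
# absent (and silently fallen back from) with a plain `gt4py` install.
try:
    import scipy  # noqa: F401  # pulled in by the `full` extra
    HAVE_SCIPY = True
except ImportError:
    HAVE_SCIPY = False  # the numpy backend presumably uses a slower pure-numpy path
```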
+coverage[toml]==7.3.2 # via -r requirements-dev.in, pytest-cov +cryptography==41.0.7 # via types-paramiko, types-pyopenssl, types-redis cytoolz==0.12.2 # via gt4py (pyproject.toml) -dace==0.14.4 # via gt4py (pyproject.toml) +dace==0.15.1 # via gt4py (pyproject.toml) darglint==1.8.1 # via -r requirements-dev.in -deepdiff==6.5.0 # via gt4py (pyproject.toml) +deepdiff==6.7.1 # via gt4py (pyproject.toml) devtools==0.12.2 # via gt4py (pyproject.toml) dill==0.3.7 # via dace -distlib==0.3.7 # via virtualenv -docutils==0.18.1 # via restructuredtext-lint, sphinx, sphinx-rtd-theme +distlib==0.3.7 # via requirementslib, virtualenv +distro==1.8.0 # via scikit-build +docopt==0.6.2 # via pipreqs +docutils==0.20.1 # via restructuredtext-lint, sphinx, sphinx-rtd-theme eradicate==2.3.0 # via flake8-eradicate -exceptiongroup==1.1.3 # via hypothesis, pytest +exceptiongroup==1.2.0 # via hypothesis, pytest execnet==2.0.2 # via pytest-cache, pytest-xdist -executing==1.2.0 # via devtools +executing==2.0.1 # via devtools factory-boy==3.3.0 # via -r requirements-dev.in, pytest-factoryboy -faker==19.6.1 # via factory-boy -fastjsonschema==2.18.0 # via nbformat -filelock==3.12.4 # via tox, virtualenv +faker==20.1.0 # via factory-boy +fastjsonschema==2.19.0 # via nbformat +filelock==3.13.1 # via tox, virtualenv flake8==6.1.0 # via -r requirements-dev.in, flake8-bugbear, flake8-builtins, flake8-debugger, flake8-docstrings, flake8-eradicate, flake8-mutable, flake8-pyproject, flake8-rst-docstrings -flake8-bugbear==23.9.16 # via -r requirements-dev.in -flake8-builtins==2.1.0 # via -r requirements-dev.in +flake8-bugbear==23.12.2 # via -r requirements-dev.in +flake8-builtins==2.2.0 # via -r requirements-dev.in flake8-debugger==4.1.2 # via -r requirements-dev.in flake8-docstrings==1.7.0 # via -r requirements-dev.in flake8-eradicate==1.5.0 # via -r requirements-dev.in flake8-mutable==1.2.0 # via -r requirements-dev.in flake8-pyproject==1.2.3 # via -r requirements-dev.in flake8-rst-docstrings==0.3.0 # via -r requirements-dev.in -flask==2.3.3 # via dace -frozendict==2.3.8 # via gt4py (pyproject.toml) +flask==3.0.0 # via dace +fparser==0.1.3 # via dace +frozendict==2.3.10 # via gt4py (pyproject.toml) gridtools-cpp==2.3.1 # via gt4py (pyproject.toml) -hypothesis==6.86.1 # via -r requirements-dev.in, gt4py (pyproject.toml) -identify==2.5.29 # via pre-commit -idna==3.4 # via requests +hypothesis==6.92.0 # via -r requirements-dev.in, gt4py (pyproject.toml) +identify==2.5.33 # via pre-commit +idna==3.6 # via requests imagesize==1.4.1 # via sphinx -importlib-metadata==6.8.0 # via build, flask, sphinx -importlib-resources==6.0.1 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications +importlib-metadata==7.0.0 # via build, flask, fparser, sphinx +importlib-resources==6.1.1 ; python_version < "3.9" # via gt4py (pyproject.toml), jsonschema, jsonschema-specifications inflection==0.5.1 # via pytest-factoryboy iniconfig==2.0.0 # via pytest -isort==5.12.0 # via -r requirements-dev.in +isort==5.13.0 # via -r requirements-dev.in itsdangerous==2.1.2 # via flask jinja2==3.1.2 # via flask, gt4py (pyproject.toml), sphinx -jsonschema==4.19.0 # via nbformat -jsonschema-specifications==2023.7.1 # via jsonschema -jupyter-core==5.3.1 # via nbformat -jupytext==1.15.2 # via -r requirements-dev.in -lark==1.1.7 # via gt4py (pyproject.toml) -mako==1.2.4 # via gt4py (pyproject.toml) +jsonschema==4.20.0 # via nbformat +jsonschema-specifications==2023.11.2 # via jsonschema +jupyter-core==5.5.0 # via nbformat 
+jupytext==1.16.0 # via -r requirements-dev.in +lark==1.1.8 # via gt4py (pyproject.toml) +mako==1.3.0 # via gt4py (pyproject.toml) markdown-it-py==3.0.0 # via jupytext, mdit-py-plugins markupsafe==2.1.3 # via jinja2, mako, werkzeug mccabe==0.7.0 # via flake8 mdit-py-plugins==0.4.0 # via jupytext mdurl==0.1.2 # via markdown-it-py mpmath==1.3.0 # via sympy -mypy==1.5.1 # via -r requirements-dev.in +mypy==1.7.1 # via -r requirements-dev.in mypy-extensions==1.0.0 # via black, mypy -nanobind==1.5.2 # via gt4py (pyproject.toml) +nanobind==1.8.0 # via gt4py (pyproject.toml) nbformat==5.9.2 # via jupytext networkx==3.1 # via dace -ninja==1.11.1 # via gt4py (pyproject.toml) +ninja==1.11.1.1 # via gt4py (pyproject.toml) nodeenv==1.8.0 # via pre-commit numpy==1.24.4 # via dace, gt4py (pyproject.toml), types-jack-client ordered-set==4.1.0 # via deepdiff -packaging==23.1 # via black, build, gt4py (pyproject.toml), pyproject-api, pytest, sphinx, tox -pathspec==0.11.2 # via black +packaging==23.2 # via black, build, gt4py (pyproject.toml), jupytext, pyproject-api, pytest, scikit-build, setuptools-scm, sphinx, tox +pathspec==0.12.1 # via black +pep517==0.13.1 # via requirementslib +pip-api==0.0.30 # via isort pip-tools==7.3.0 # via -r requirements-dev.in -pipdeptree==2.13.0 # via -r requirements-dev.in +pipdeptree==2.13.1 # via -r requirements-dev.in +pipreqs==0.4.13 # via isort pkgutil-resolve-name==1.3.10 # via jsonschema -platformdirs==3.10.0 # via black, jupyter-core, tox, virtualenv +platformdirs==4.1.0 # via black, jupyter-core, requirementslib, tox, virtualenv +plette[validation]==0.4.4 # via requirementslib pluggy==1.3.0 # via pytest, tox ply==3.11 # via dace -pre-commit==3.4.0 # via -r requirements-dev.in -psutil==5.9.5 # via -r requirements-dev.in, pytest-xdist +pre-commit==3.5.0 # via -r requirements-dev.in +psutil==5.9.6 # via -r requirements-dev.in, pytest-xdist pybind11==2.11.1 # via gt4py (pyproject.toml) -pycodestyle==2.11.0 # via flake8, flake8-debugger +pycodestyle==2.11.1 # via flake8, flake8-debugger pycparser==2.21 # via cffi +pydantic==1.10.13 # via requirementslib pydocstyle==6.3.0 # via flake8-docstrings pyflakes==3.1.0 # via flake8 -pygments==2.16.1 # via -r requirements-dev.in, devtools, flake8-rst-docstrings, sphinx +pygments==2.17.2 # via -r requirements-dev.in, devtools, flake8-rst-docstrings, sphinx pyproject-api==1.6.1 # via tox pyproject-hooks==1.0.0 # via build -pytest==7.4.2 # via -r requirements-dev.in, gt4py (pyproject.toml), pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist +pytest==7.4.3 # via -r requirements-dev.in, gt4py (pyproject.toml), pytest-cache, pytest-cov, pytest-factoryboy, pytest-xdist pytest-cache==1.0 # via -r requirements-dev.in pytest-cov==4.1.0 # via -r requirements-dev.in -pytest-factoryboy==2.5.1 # via -r requirements-dev.in -pytest-xdist[psutil]==3.3.1 # via -r requirements-dev.in +pytest-factoryboy==2.6.0 # via -r requirements-dev.in +pytest-xdist[psutil]==3.5.0 # via -r requirements-dev.in python-dateutil==2.8.2 # via faker pytz==2023.3.post1 # via babel pyyaml==6.0.1 # via dace, jupytext, pre-commit -referencing==0.30.2 # via jsonschema, jsonschema-specifications -requests==2.31.0 # via dace, sphinx +referencing==0.32.0 # via jsonschema, jsonschema-specifications +requests==2.31.0 # via dace, requirementslib, sphinx, yarg +requirementslib==3.0.0 # via isort restructuredtext-lint==1.4.0 # via flake8-rst-docstrings -rpds-py==0.10.3 # via jsonschema, referencing -ruff==0.0.290 # via -r requirements-dev.in +rpds-py==0.13.2 # via jsonschema, 
referencing +ruff==0.1.7 # via -r requirements-dev.in +scikit-build==0.17.6 # via dace +setuptools-scm==8.0.4 # via fparser six==1.16.0 # via asttokens, astunparse, python-dateutil snowballstemmer==2.2.0 # via pydocstyle, sphinx sortedcontainers==2.4.0 # via hypothesis sphinx==7.1.2 # via -r requirements-dev.in, sphinx-rtd-theme, sphinxcontrib-jquery -sphinx-rtd-theme==1.3.0 # via -r requirements-dev.in +sphinx-rtd-theme==2.0.0 # via -r requirements-dev.in sphinxcontrib-applehelp==1.0.4 # via sphinx sphinxcontrib-devhelp==1.0.2 # via sphinx sphinxcontrib-htmlhelp==2.0.1 # via sphinx @@ -131,31 +143,32 @@ sphinxcontrib-jquery==4.1 # via sphinx-rtd-theme sphinxcontrib-jsmath==1.0.1 # via sphinx sphinxcontrib-qthelp==1.0.3 # via sphinx sphinxcontrib-serializinghtml==1.1.5 # via sphinx -sympy==1.12 # via dace, gt4py (pyproject.toml) +sympy==1.9 # via dace, gt4py (pyproject.toml) tabulate==0.9.0 # via gt4py (pyproject.toml) toml==0.10.2 # via jupytext -tomli==2.0.1 # via -r requirements-dev.in, black, build, coverage, flake8-pyproject, mypy, pip-tools, pyproject-api, pyproject-hooks, pytest, tox +tomli==2.0.1 # via -r requirements-dev.in, black, build, coverage, flake8-pyproject, mypy, pep517, pip-tools, pyproject-api, pyproject-hooks, pytest, scikit-build, setuptools-scm, tox +tomlkit==0.12.3 # via plette, requirementslib toolz==0.12.0 # via cytoolz -tox==4.11.3 # via -r requirements-dev.in -traitlets==5.10.0 # via jupyter-core, nbformat +tox==4.11.4 # via -r requirements-dev.in +traitlets==5.14.0 # via jupyter-core, nbformat types-aiofiles==23.2.0.0 # via types-all types-all==1.0.0 # via -r requirements-dev.in types-annoy==1.17.8.4 # via types-all types-atomicwrites==1.4.5.1 # via types-all types-backports==0.1.3 # via types-all types-backports-abc==0.5.2 # via types-all -types-bleach==6.0.0.4 # via types-all +types-bleach==6.1.0.1 # via types-all types-boto==2.49.18.9 # via types-all -types-cachetools==5.3.0.6 # via types-all +types-cachetools==5.3.0.7 # via types-all types-certifi==2021.10.8.3 # via types-all -types-cffi==1.15.1.15 # via types-jack-client +types-cffi==1.16.0.0 # via types-jack-client types-characteristic==14.3.7 # via types-all types-chardet==5.0.4.6 # via types-all types-click==7.1.8 # via types-all, types-flask -types-click-spinner==0.1.13.5 # via types-all +types-click-spinner==0.1.13.6 # via types-all types-colorama==0.4.15.12 # via types-all types-contextvars==2.4.7.3 # via types-all -types-croniter==1.4.0.1 # via types-all +types-croniter==2.0.0.0 # via types-all types-cryptography==3.3.23.2 # via types-all, types-openssl-python, types-pyjwt types-dataclasses==0.6.6 # via types-all types-dateparser==1.1.4.10 # via types-all @@ -176,44 +189,44 @@ types-futures==3.3.8 # via types-all types-geoip2==3.0.0 # via types-all types-ipaddress==1.0.8 # via types-all, types-maxminddb types-itsdangerous==1.1.6 # via types-all -types-jack-client==0.5.10.9 # via types-all +types-jack-client==0.5.10.10 # via types-all types-jinja2==2.11.9 # via types-all, types-flask types-kazoo==0.1.3 # via types-all -types-markdown==3.4.2.10 # via types-all +types-markdown==3.5.0.3 # via types-all types-markupsafe==1.1.10 # via types-all, types-jinja2 types-maxminddb==1.5.0 # via types-all, types-geoip2 -types-mock==5.1.0.2 # via types-all +types-mock==5.1.0.3 # via types-all types-mypy-extensions==1.0.0.5 # via types-all types-nmap==0.1.6 # via types-all types-openssl-python==0.1.3 # via types-all types-orjson==3.6.2 # via types-all -types-paramiko==3.3.0.0 # via types-all, types-pysftp 
+types-paramiko==3.3.0.2 # via types-all, types-pysftp types-pathlib2==2.3.0 # via types-all -types-pillow==10.0.0.3 # via types-all +types-pillow==10.1.0.2 # via types-all types-pkg-resources==0.1.3 # via types-all types-polib==1.2.0.1 # via types-all -types-protobuf==4.24.0.1 # via types-all +types-protobuf==4.24.0.4 # via types-all types-pyaudio==0.2.16.7 # via types-all types-pycurl==7.45.2.5 # via types-all types-pyfarmhash==0.3.1.2 # via types-all types-pyjwt==1.7.1 # via types-all types-pymssql==2.1.0 # via types-all types-pymysql==1.1.0.1 # via types-all -types-pyopenssl==23.2.0.2 # via types-redis +types-pyopenssl==23.3.0.0 # via types-redis types-pyrfc3339==1.1.1.5 # via types-all types-pysftp==0.2.17.6 # via types-all types-python-dateutil==2.8.19.14 # via types-all, types-datetimerange types-python-gflags==3.1.7.3 # via types-all types-python-slugify==8.0.0.3 # via types-all -types-pytz==2023.3.1.0 # via types-all, types-tzlocal +types-pytz==2023.3.1.1 # via types-all, types-tzlocal types-pyvmomi==8.0.0.6 # via types-all -types-pyyaml==6.0.12.11 # via types-all -types-redis==4.6.0.6 # via types-all -types-requests==2.31.0.2 # via types-all +types-pyyaml==6.0.12.12 # via types-all +types-redis==4.6.0.11 # via types-all +types-requests==2.31.0.10 # via types-all types-retry==0.9.9.4 # via types-all types-routes==2.5.0 # via types-all types-scribe==2.0.0 # via types-all -types-setuptools==68.2.0.0 # via types-cffi +types-setuptools==69.0.0.0 # via types-cffi types-simplejson==3.19.0.2 # via types-all types-singledispatch==4.1.0.0 # via types-all types-six==1.16.21.9 # via types-all @@ -222,21 +235,21 @@ types-termcolor==1.1.6.2 # via types-all types-toml==0.10.8.7 # via types-all types-tornado==5.1.1 # via types-all types-typed-ast==1.5.8.7 # via types-all -types-tzlocal==5.0.1.1 # via types-all +types-tzlocal==5.1.0.1 # via types-all types-ujson==5.8.0.1 # via types-all -types-urllib3==1.26.25.14 # via types-requests types-waitress==2.1.4.9 # via types-all types-werkzeug==1.0.9 # via types-all, types-flask types-xxhash==3.0.5.2 # via types-all -typing-extensions==4.5.0 # via black, faker, gt4py (pyproject.toml), mypy, pytest-factoryboy -urllib3==2.0.4 # via requests -virtualenv==20.24.5 # via pre-commit, tox -websockets==11.0.3 # via dace -werkzeug==2.3.7 # via flask -wheel==0.41.2 # via astunparse, pip-tools +typing-extensions==4.5.0 # via black, faker, gt4py (pyproject.toml), mypy, pydantic, pytest-factoryboy, setuptools-scm +urllib3==2.1.0 # via requests, types-requests +virtualenv==20.25.0 # via pre-commit, tox +websockets==12.0 # via dace +werkzeug==3.0.1 # via flask +wheel==0.42.0 # via astunparse, pip-tools, scikit-build xxhash==3.0.0 # via gt4py (pyproject.toml) -zipp==3.16.2 # via importlib-metadata, importlib-resources +yarg==0.1.9 # via pipreqs +zipp==3.17.0 # via importlib-metadata, importlib-resources # The following packages are considered to be unsafe in a requirements file: -pip==23.2.1 # via pip-tools -setuptools==68.2.2 # via gt4py (pyproject.toml), nodeenv, pip-tools +pip==23.3.1 # via pip-api, pip-tools, requirementslib +setuptools==69.0.2 # via gt4py (pyproject.toml), nodeenv, pip-tools, requirementslib, scikit-build, setuptools-scm diff --git a/src/gt4py/__about__.py b/src/gt4py/__about__.py index 57b914f25b..10f4607724 100644 --- a/src/gt4py/__about__.py +++ b/src/gt4py/__about__.py @@ -33,5 +33,5 @@ __license__: Final = "GPL-3.0-or-later" -__version__: Final = "1.0.1" +__version__: Final = "1.0.2" __version_info__: Final = pkg_version.parse(__version__) 
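The `__about__.py` hunk above bumps the package version to 1.0.2, and the `__init__.py` hunk just below keeps `gt4py.next` behind a Python 3.10 guard. A small sketch of what this means for user code, assuming `__version__` and `__version_info__` are re-exported at package level as the `__about__` module suggests:

```python
import sys
import gt4py

print(gt4py.__version__)       # "1.0.2" after this release
print(gt4py.__version_info__)  # parsed with packaging's version machinery

# gt4py.next is only importable on Python >= 3.10 (see the guarded import below)
if sys.version_info >= (3, 10):
    import gt4py.next as gtx  # noqa: F401
```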
diff --git a/src/gt4py/__init__.py b/src/gt4py/__init__.py index 7d255de142..c28c5cf2d6 100644 --- a/src/gt4py/__init__.py +++ b/src/gt4py/__init__.py @@ -33,6 +33,6 @@ if _sys.version_info >= (3, 10): - from . import next + from . import next # noqa: A004 __all__ += ["next"] diff --git a/src/gt4py/cartesian/backend/dace_backend.py b/src/gt4py/cartesian/backend/dace_backend.py index b1e559a41e..5dae025acb 100644 --- a/src/gt4py/cartesian/backend/dace_backend.py +++ b/src/gt4py/cartesian/backend/dace_backend.py @@ -562,12 +562,6 @@ def apply(cls, stencil_ir: gtir.Stencil, builder: "StencilBuilder", sdfg: dace.S omp_threads = "" omp_header = "" - # Backward compatible state struct name change in DaCe >=0.15.x - try: - dace_state_suffix = dace.Config.get("compiler.codegen_state_struct_suffix") - except (KeyError, TypeError): - dace_state_suffix = "_t" # old structure name - interface = cls.template.definition.render( name=sdfg.name, backend_specifics=omp_threads, @@ -575,7 +569,7 @@ def apply(cls, stencil_ir: gtir.Stencil, builder: "StencilBuilder", sdfg: dace.S functor_args=self.generate_functor_args(sdfg), tmp_allocs=self.generate_tmp_allocs(sdfg), allocator="gt::cuda_util::cuda_malloc" if is_gpu else "std::make_unique", - state_suffix=dace_state_suffix, + state_suffix=dace.Config.get("compiler.codegen_state_struct_suffix"), ) generated_code = textwrap.dedent( f"""#include diff --git a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py index db276a48b9..48b129fa87 100644 --- a/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py +++ b/src/gt4py/cartesian/gtc/dace/expansion/daceir_builder.py @@ -30,6 +30,7 @@ compute_dcir_access_infos, flatten_list, get_tasklet_symbol, + make_dace_subset, union_inout_memlets, union_node_grid_subsets, untile_memlets, @@ -458,6 +459,40 @@ def visit_HorizontalExecution( write_memlets=write_memlets, ) + for memlet in [*read_memlets, *write_memlets]: + """ + This loop handles the special case of a tasklet performing array access. + The memlet should pass the full array shape (no tiling) and + the tasklet expression for array access should use all explicit indexes. 
+ """ + array_ndims = len(global_ctx.arrays[memlet.field].shape) + field_decl = global_ctx.library_node.field_decls[memlet.field] + # calculate array subset on original memlet + memlet_subset = make_dace_subset( + global_ctx.library_node.access_infos[memlet.field], + memlet.access_info, + field_decl.data_dims, + ) + # select index values for single-point grid access + memlet_data_index = [ + dcir.Literal(value=str(dim_range[0]), dtype=common.DataType.INT32) + for dim_range, dim_size in zip(memlet_subset, memlet_subset.size()) + if dim_size == 1 + ] + if len(memlet_data_index) < array_ndims: + reshape_memlet = False + for access_node in dcir_node.walk_values().if_isinstance(dcir.IndexAccess): + if access_node.data_index and access_node.name == memlet.connector: + access_node.data_index = memlet_data_index + access_node.data_index + assert len(access_node.data_index) == array_ndims + reshape_memlet = True + if reshape_memlet: + # ensure that memlet symbols used for array indexing are defined in context + for sym in memlet.access_info.grid_subset.free_symbols: + symbol_collector.add_symbol(sym) + # set full shape on memlet + memlet.access_info = global_ctx.library_node.access_infos[memlet.field] + for item in reversed(expansion_items): iteration_ctx = iteration_ctx.pop() dcir_node = self._process_iteration_item( diff --git a/src/gt4py/cartesian/gtc/dace/nodes.py b/src/gt4py/cartesian/gtc/dace/nodes.py index ddcb719b5f..bd8c08034c 100644 --- a/src/gt4py/cartesian/gtc/dace/nodes.py +++ b/src/gt4py/cartesian/gtc/dace/nodes.py @@ -121,7 +121,7 @@ def __init__( *args, **kwargs, ): - super().__init__(name=name, *args, **kwargs) + super().__init__(*args, name=name, **kwargs) from gt4py.cartesian.gtc.dace.utils import compute_dcir_access_infos diff --git a/src/gt4py/cartesian/gtc/daceir.py b/src/gt4py/cartesian/gtc/daceir.py index 28ebc8cd8e..0366317360 100644 --- a/src/gt4py/cartesian/gtc/daceir.py +++ b/src/gt4py/cartesian/gtc/daceir.py @@ -536,7 +536,7 @@ def union(self, other): else: assert ( isinstance(interval2, (TileInterval, DomainInterval)) - and isinstance(interval1, IndexWithExtent) + and isinstance(interval1, (IndexWithExtent, DomainInterval)) ) or ( isinstance(interval1, (TileInterval, DomainInterval)) and isinstance(interval2, IndexWithExtent) @@ -573,7 +573,7 @@ def overapproximated_shape(self): def apply_iteration(self, grid_subset: GridSubset): res_intervals = dict(self.grid_subset.intervals) for axis, field_interval in self.grid_subset.intervals.items(): - if axis in grid_subset.intervals: + if axis in grid_subset.intervals and not isinstance(field_interval, DomainInterval): grid_interval = grid_subset.intervals[axis] assert isinstance(field_interval, IndexWithExtent) extent = field_interval.extent diff --git a/src/gt4py/eve/__init__.py b/src/gt4py/eve/__init__.py index 617a889e28..e726db1f1a 100644 --- a/src/gt4py/eve/__init__.py +++ b/src/gt4py/eve/__init__.py @@ -58,7 +58,12 @@ field, frozenmodel, ) -from .traits import SymbolTableTrait, ValidatedSymbolTableTrait, VisitorWithSymbolTableTrait +from .traits import ( + PreserveLocationVisitor, + SymbolTableTrait, + ValidatedSymbolTableTrait, + VisitorWithSymbolTableTrait, +) from .trees import ( bfs_walk_items, bfs_walk_values, @@ -113,6 +118,7 @@ "SymbolTableTrait", "ValidatedSymbolTableTrait", "VisitorWithSymbolTableTrait", + "PreserveLocationVisitor", # trees "bfs_walk_items", "bfs_walk_values", diff --git a/src/gt4py/eve/datamodels/core.py b/src/gt4py/eve/datamodels/core.py index fcd53d1312..bc744b3ccc 100644 --- 
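The one-line `nodes.py` change above reorders `super().__init__(name=name, *args, **kwargs)` to `super().__init__(*args, name=name, **kwargs)`. Both are legal Python, but star-unpacked positionals bind before the keyword even when written after it, which is why linters such as flake8-bugbear (already in this repo's dev dependencies) flag the old form as B026. A standalone demonstration:

```python
def show(a, b, name=None):
    return a, b, name

args = (1, 2)
# The starred positionals still fill `a` and `b` first, despite `name="x"`
# appearing earlier in the call; the two spellings are equivalent:
assert show(name="x", *args) == (1, 2, "x")  # legal but misleading (B026)
assert show(*args, name="x") == (1, 2, "x")  # conventional form used by the fix
```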
a/src/gt4py/eve/datamodels/core.py +++ b/src/gt4py/eve/datamodels/core.py @@ -814,7 +814,7 @@ def concretize( """ # noqa: RST301 # doctest conventions confuse RST validator concrete_cls: Type[DataModelT] = _make_concrete_with_cache( - datamodel_cls, *type_args, class_name=class_name, module=module + datamodel_cls, *type_args, class_name=class_name, module=module # type: ignore[arg-type] ) assert isinstance(concrete_cls, type) and is_datamodel(concrete_cls) @@ -883,17 +883,6 @@ def _substitute_typevars( return type_params_map[type_hint], True elif getattr(type_hint, "__parameters__", []): return type_hint[tuple(type_params_map[tp] for tp in type_hint.__parameters__)], True - # TODO(egparedes): WIP fix for partial specialization - # # Type hint is a generic model: replace all the concretized type vars - # noqa: e800 replaced = False - # noqa: e800 new_args = [] - # noqa: e800 for tp in type_hint.__parameters__: - # noqa: e800 if tp in type_params_map: - # noqa: e800 new_args.append(type_params_map[tp]) - # noqa: e800 replaced = True - # noqa: e800 else: - # noqa: e800 new_args.append(type_params_map[tp]) - # noqa: e800 return type_hint[tuple(new_args)], replaced else: return type_hint, False @@ -981,21 +970,14 @@ def __class_getitem__( """ type_args: Tuple[Type] = args if isinstance(args, tuple) else (args,) concrete_cls: Type[DataModelT] = concretize(cls, *type_args) - res = xtyping.StdGenericAliasType(concrete_cls, type_args) - if sys.version_info < (3, 9): - # in Python 3.8, xtyping.StdGenericAliasType (aka typing._GenericAlias) - # does not copy all required `__dict__` entries, so do it manually - for k, v in concrete_cls.__dict__.items(): - if k not in res.__dict__: - res.__dict__[k] = v - return res + return concrete_cls return classmethod(__class_getitem__) def _make_type_converter(type_annotation: TypeAnnotation, name: str) -> TypeConverter[_T]: - # TODO(egparedes): if a "typing tree" structure is implemented, refactor this code as a tree traversal. - # + # TODO(egparedes): if a "typing tree" structure is implemented, refactor this code + # as a tree traversal. if xtyping.is_actual_type(type_annotation) and not isinstance(None, type_annotation): assert not xtyping.get_args(type_annotation) assert isinstance(type_annotation, type) @@ -1316,11 +1298,7 @@ def _make_concrete_with_cache( # Replace field definitions with the new actual types for generic fields type_params_map = dict(zip(datamodel_cls.__parameters__, type_args)) model_fields = getattr(datamodel_cls, MODEL_FIELD_DEFINITIONS_ATTR) - new_annotations = { - # TODO(egparedes): ? 
- # noqa: e800 "__args__": "ClassVar[Tuple[Union[Type, TypeVar], ...]]", - # noqa: e800 "__parameters__": "ClassVar[Tuple[TypeVar, ...]]", - } + new_annotations = {} new_field_c_attrs = {} for field_name, field_type in xtyping.get_type_hints(datamodel_cls).items(): @@ -1353,8 +1331,16 @@ def _make_concrete_with_cache( "__module__": module if module else datamodel_cls.__module__, **new_field_c_attrs, } - concrete_cls = type(class_name, (datamodel_cls,), namespace) + + # Update the tuple of generic parameters in the new class, in case + # this is a partial concretization + assert hasattr(concrete_cls, "__parameters__") + concrete_cls.__parameters__ = tuple( + type_params_map[tp_var] + for tp_var in datamodel_cls.__parameters__ + if isinstance(type_params_map[tp_var], typing.TypeVar) + ) assert concrete_cls.__module__ == module or not module if MODEL_FIELD_DEFINITIONS_ATTR not in concrete_cls.__dict__: diff --git a/src/gt4py/eve/extended_typing.py b/src/gt4py/eve/extended_typing.py index 17462a37ff..3ee447ca6c 100644 --- a/src/gt4py/eve/extended_typing.py +++ b/src/gt4py/eve/extended_typing.py @@ -493,7 +493,7 @@ def _patched_proto_hook(other): # type: ignore[no-untyped-def] if isinstance(_typing.Any, type): # Python >= 3.11 _ArtefactTypes = (*_ArtefactTypes, _typing.Any) -# `Any` is a class since typing_extensions >= 4.4 +# `Any` is a class since typing_extensions >= 4.4 and Python 3.11 if (typing_exts_any := getattr(_typing_extensions, "Any", None)) is not _typing.Any and isinstance( typing_exts_any, type ): @@ -504,11 +504,13 @@ def is_actual_type(obj: Any) -> TypeGuard[Type]: """Check if an object has an actual type and instead of a typing artefact like ``GenericAlias`` or ``Any``. This is needed because since Python 3.9: - ``isinstance(types.GenericAlias(), type) is True`` + ``isinstance(types.GenericAlias(), type) is True`` and since Python 3.11: - ``isinstance(typing.Any, type) is True`` + ``isinstance(typing.Any, type) is True`` """ - return isinstance(obj, type) and type(obj) not in _ArtefactTypes + return ( + isinstance(obj, type) and (obj not in _ArtefactTypes) and (type(obj) not in _ArtefactTypes) + ) if hasattr(_typing_extensions, "Any") and _typing.Any is not _typing_extensions.Any: # type: ignore[attr-defined] # _typing_extensions.Any only from >= 4.4 @@ -641,9 +643,12 @@ def get_partial_type_hints( resolved_hints = get_type_hints( # type: ignore[call-arg] # Python 3.8 does not define `include-extras` obj, globalns=globalns, localns=localns, include_extras=include_extras ) - hints.update(resolved_hints) + hints[name] = resolved_hints[name] except NameError as error: if isinstance(hint, str): + # This conversion could be probably skipped in Python versions containing + # the fix applied in bpo-41370. 
Check: + # https://github.com/python/cpython/commit/b465b606049f6f7dd0711cb031fdaa251818741a#diff-ddb987fca5f5df0c9a2f5521ed687919d70bb3d64eaeb8021f98833a2a716887R344 hints[name] = ForwardRef(hint) elif isinstance(hint, (ForwardRef, _typing.ForwardRef)): hints[name] = hint diff --git a/src/gt4py/eve/traits.py b/src/gt4py/eve/traits.py index df556c9d7f..aacae804d8 100644 --- a/src/gt4py/eve/traits.py +++ b/src/gt4py/eve/traits.py @@ -172,3 +172,11 @@ def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: kwargs["symtable"] = kwargs["symtable"].parents return result + + +class PreserveLocationVisitor(visitors.NodeVisitor): + def visit(self, node: concepts.RootNode, **kwargs: Any) -> Any: + result = super().visit(node, **kwargs) + if hasattr(node, "location") and hasattr(result, "location"): + result.location = node.location + return result diff --git a/src/gt4py/eve/trees.py b/src/gt4py/eve/trees.py index cd7e71588f..74c5bd41bb 100644 --- a/src/gt4py/eve/trees.py +++ b/src/gt4py/eve/trees.py @@ -133,7 +133,7 @@ def _pre_walk_items( yield from _pre_walk_items(child, __key__=key) -def _pre_walk_values(node: TreeLike) -> Iterable[Tuple[Any]]: +def _pre_walk_values(node: TreeLike) -> Iterable: """Create a pre-order tree traversal iterator of values.""" yield node for child in iter_children_values(node): @@ -153,7 +153,7 @@ def _post_walk_items( yield __key__, node -def _post_walk_values(node: TreeLike) -> Iterable[Tuple[Any]]: +def _post_walk_values(node: TreeLike) -> Iterable: """Create a post-order tree traversal iterator of values.""" if (iter_children_values := getattr(node, "iter_children_values", None)) is not None: for child in iter_children_values(): diff --git a/src/gt4py/eve/utils.py b/src/gt4py/eve/utils.py index 7104f7658f..624407f319 100644 --- a/src/gt4py/eve/utils.py +++ b/src/gt4py/eve/utils.py @@ -1225,7 +1225,7 @@ def unzip(self) -> XIterable[Tuple[Any, ...]]: [('a', 'b', 'c'), (1, 2, 3)] """ - return XIterable(zip(*self.iterator)) # type: ignore # mypy gets confused with *args + return XIterable(zip(*self.iterator)) @typing.overload def islice(self, __stop: int) -> XIterable[T]: @@ -1536,7 +1536,7 @@ def reduceby( ) -> Dict[K, S]: ... 
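The `traits.py` hunk above adds `PreserveLocationVisitor`, which copies a node's `location` attribute onto whatever `visit` returns. A hedged usage sketch, assuming the usual eve mixin ordering; the translator pass and node class below are invented for illustration:

```python
from gt4py.eve import NodeTranslator, PreserveLocationVisitor

class RenameSymbols(PreserveLocationVisitor, NodeTranslator):
    # Hypothetical pass: because the mixin's `visit` runs first in the MRO,
    # every rebuilt node inherits the source `location` of the node it
    # replaced, with no manual `new_node.location = node.location` bookkeeping.
    def visit_SymRef(self, node):  # invented node class
        return self.generic_visit(node)
```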
- def reduceby( # type: ignore[misc] # signatures 2 and 4 are not satified due to inconsistencies with type variables + def reduceby( self, bin_op_func: Callable[[S, T], S], key: Union[str, List[K], Callable[[T], K]], diff --git a/src/gt4py/next/__init__.py b/src/gt4py/next/__init__.py index cbd5735949..1398af5f03 100644 --- a/src/gt4py/next/__init__.py +++ b/src/gt4py/next/__init__.py @@ -39,6 +39,11 @@ index_field, np_as_located_field, ) +from .program_processors.runners.gtfn import ( + run_gtfn_cached as gtfn_cpu, + run_gtfn_gpu_cached as gtfn_gpu, +) +from .program_processors.runners.roundtrip import backend as itir_python __all__ = [ @@ -74,5 +79,9 @@ "field_operator", "program", "scan_operator", + # from program_processor + "gtfn_cpu", + "gtfn_gpu", + "itir_python", *fbuiltins.__all__, ] diff --git a/src/gt4py/next/common.py b/src/gt4py/next/common.py index 6e0ca14cb0..840b0e8bbc 100644 --- a/src/gt4py/next/common.py +++ b/src/gt4py/next/common.py @@ -189,11 +189,12 @@ def __and__(self, other: UnitRange) -> UnitRange: return UnitRange(max(self.start, other.start), min(self.stop, other.stop)) def __contains__(self, value: Any) -> bool: - return ( - isinstance(value, core_defs.INTEGRAL_TYPES) - and value >= self.start - and value < self.stop - ) + # TODO(egparedes): use core_defs.IntegralScalar for `isinstance()` checks (see PEP 604) + # and remove int cast, once the related mypy bug (#16358) gets fixed + if isinstance(value, core_defs.INTEGRAL_TYPES): + return self.start <= cast(int, value) < self.stop + else: + return False def __le__(self, other: UnitRange) -> bool: return self.start >= other.start and self.stop <= other.stop @@ -574,38 +575,39 @@ def __call__(self, func: fbuiltins.BuiltInFunction[_R, _P], /) -> Callable[_P, _ ... -# TODO(havogt): replace this protocol with the new `GTFieldInterface` protocol -class NextGTDimsInterface(Protocol): +# TODO(havogt): we need to describe when this interface should be used instead of the `Field` protocol. +class GTFieldInterface(core_defs.GTDimsInterface, core_defs.GTOriginInterface, Protocol): """ - Protocol for objects providing the `__gt_dims__` property, naming :class:`Field` dimensions. + Protocol for object providing the `__gt_domain__` property, specifying the :class:`Domain` of a :class:`Field`. - The dimension names are objects of type :class:`Dimension`, in contrast to - :mod:`gt4py.cartesian`, where the labels are `str` s with implied semantics, - see :class:`~gt4py._core.definitions.GTDimsInterface` . + Note: + - A default implementation of the `__gt_dims__` interface from `gt4py.cartesian` is provided. + - No implementation of `__gt_origin__` is provided because of infinite fields. """ @property - def __gt_dims__(self) -> tuple[Dimension, ...]: + def __gt_domain__(self) -> Domain: + # TODO probably should be changed to `DomainLike` (with a new concept `DimensionLike`) + # to allow implementations without having to import gtx.Domain. ... - -# TODO(egparedes): add support for this new protocol in the cartesian module -class GTFieldInterface(Protocol): - """Protocol for object providing the `__gt_domain__` property, specifying the :class:`Domain` of a :class:`Field`.""" - @property - def __gt_domain__(self) -> Domain: - ... 
+ def __gt_dims__(self) -> tuple[str, ...]: + return tuple(d.value for d in self.__gt_domain__.dims) @extended_runtime_checkable -class Field(NextGTDimsInterface, core_defs.GTOriginInterface, Protocol[DimsT, core_defs.ScalarT]): +class Field(GTFieldInterface, Protocol[DimsT, core_defs.ScalarT]): __gt_builtin_func__: ClassVar[GTBuiltInFuncDispatcher] @property def domain(self) -> Domain: ... + @property + def __gt_domain__(self) -> Domain: + return self.domain + @property def codomain(self) -> type[core_defs.ScalarT] | Dimension: ... @@ -923,10 +925,6 @@ def asnumpy(self) -> Never: def domain(self) -> Domain: return Domain(dims=(self.dimension,), ranges=(UnitRange.infinite(),)) - @property - def __gt_dims__(self) -> tuple[Dimension, ...]: - return self.domain.dims - @property def __gt_origin__(self) -> Never: raise TypeError("'CartesianConnectivity' does not support this operation.") diff --git a/src/gt4py/next/embedded/nd_array_field.py b/src/gt4py/next/embedded/nd_array_field.py index e8bdfc9d7f..52a61b40bb 100644 --- a/src/gt4py/next/embedded/nd_array_field.py +++ b/src/gt4py/next/embedded/nd_array_field.py @@ -105,10 +105,6 @@ def domain(self) -> common.Domain: def shape(self) -> tuple[int, ...]: return self._ndarray.shape - @property - def __gt_dims__(self) -> tuple[common.Dimension, ...]: - return self._domain.dims - @property def __gt_origin__(self) -> tuple[int, ...]: assert common.Domain.is_finite(self._domain) diff --git a/src/gt4py/next/ffront/decorator.py b/src/gt4py/next/ffront/decorator.py index 76a0ddcde0..9f8537f59b 100644 --- a/src/gt4py/next/ffront/decorator.py +++ b/src/gt4py/next/ffront/decorator.py @@ -344,7 +344,9 @@ def _validate_args(self, *args, **kwargs) -> None: raise_exception=True, ) except ValueError as err: - raise TypeError(f"Invalid argument types in call to '{self.past_node.id}'.") from err + raise errors.DSLError( + None, f"Invalid argument types in call to '{self.past_node.id}'.\n{err}" + ) from err def _process_args(self, args: tuple, kwargs: dict) -> tuple[tuple, tuple, dict[str, Any]]: self._validate_args(*args, **kwargs) @@ -453,27 +455,32 @@ def _process_args(self, args: tuple, kwargs: dict): ) from err full_args = [*args] + full_kwargs = {**kwargs} for index, param in enumerate(self.past_node.params): if param.id in self.bound_args.keys(): - full_args.insert(index, self.bound_args[param.id]) + if index < len(full_args): + full_args.insert(index, self.bound_args[param.id]) + else: + full_kwargs[str(param.id)] = self.bound_args[param.id] - return super()._process_args(tuple(full_args), kwargs) + return super()._process_args(tuple(full_args), full_kwargs) @functools.cached_property def itir(self): new_itir = super().itir for new_clos in new_itir.closures: - for key in self.bound_args.keys(): + new_args = [ref(inp.id) for inp in new_clos.inputs] + for key, value in self.bound_args.items(): index = next( index for index, closure_input in enumerate(new_clos.inputs) if closure_input.id == key ) + new_args[new_args.index(new_clos.inputs[index])] = promote_to_const_iterator( + literal_from_value(value) + ) new_clos.inputs.pop(index) - new_args = [ref(inp.id) for inp in new_clos.inputs] params = [sym(inp.id) for inp in new_clos.inputs] - for value in self.bound_args.values(): - new_args.append(promote_to_const_iterator(literal_from_value(value))) expr = itir.FunCall( fun=new_clos.stencil, args=new_args, @@ -847,6 +854,7 @@ def scan_operator_inner(definition: types.FunctionType) -> FieldOperator: return FieldOperator.from_function( definition, DEFAULT_BACKEND 
if backend is eve.NOTHING else backend, + grid_type, operator_node_cls=foast.ScanOperator, operator_attributes={"axis": axis, "forward": forward, "init": init}, ) diff --git a/src/gt4py/next/ffront/fbuiltins.py b/src/gt4py/next/ffront/fbuiltins.py index 9a6e321f9a..493493f697 100644 --- a/src/gt4py/next/ffront/fbuiltins.py +++ b/src/gt4py/next/ffront/fbuiltins.py @@ -15,7 +15,7 @@ import dataclasses import functools import inspect -from builtins import bool, float, int, tuple +from builtins import bool, float, int, tuple # noqa: A004 from typing import Any, Callable, Generic, ParamSpec, Tuple, TypeAlias, TypeVar, Union, cast import numpy as np diff --git a/src/gt4py/next/ffront/foast_passes/type_deduction.py b/src/gt4py/next/ffront/foast_passes/type_deduction.py index 639e5ff009..5e289af664 100644 --- a/src/gt4py/next/ffront/foast_passes/type_deduction.py +++ b/src/gt4py/next/ffront/foast_passes/type_deduction.py @@ -694,7 +694,7 @@ def visit_Call(self, node: foast.Call, **kwargs) -> foast.Call: ) except ValueError as err: raise errors.DSLError( - node.location, f"Invalid argument types in call to '{new_func}'." + node.location, f"Invalid argument types in call to '{new_func}'.\n{err}" ) from err return_type = type_info.return_type(func_type, with_args=arg_types, with_kwargs=kwarg_types) diff --git a/src/gt4py/next/ffront/foast_to_itir.py b/src/gt4py/next/ffront/foast_to_itir.py index c4d518d279..0c9ab4ab27 100644 --- a/src/gt4py/next/ffront/foast_to_itir.py +++ b/src/gt4py/next/ffront/foast_to_itir.py @@ -15,7 +15,7 @@ import dataclasses from typing import Any, Callable, Optional -from gt4py.eve import NodeTranslator +from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.eve.utils import UIDGenerator from gt4py.next.ffront import ( dialect_ast_enums, @@ -39,7 +39,7 @@ def promote_to_list( @dataclasses.dataclass -class FieldOperatorLowering(NodeTranslator): +class FieldOperatorLowering(PreserveLocationVisitor, NodeTranslator): """ Lower FieldOperator AST (FOAST) to Iterator IR (ITIR). 
@@ -61,7 +61,7 @@ class FieldOperatorLowering(NodeTranslator): >>> lowered.id SymbolName('fieldop') - >>> lowered.params + >>> lowered.params # doctest: +ELLIPSIS [Sym(id=SymbolName('inp'), kind='Iterator', dtype=('float64', False))] """ @@ -142,7 +142,7 @@ def visit_IfStmt( self, node: foast.IfStmt, *, inner_expr: Optional[itir.Expr], **kwargs ) -> itir.Expr: # the lowered if call doesn't need to be lifted as the condition can only originate - # from a scalar value (and not a field) + # from a scalar value (and not a field) assert ( isinstance(node.condition.type, ts.ScalarType) and node.condition.type.kind == ts.ScalarKind.BOOL diff --git a/src/gt4py/next/ffront/past_passes/type_deduction.py b/src/gt4py/next/ffront/past_passes/type_deduction.py index fc353d64e4..af8f5e8368 100644 --- a/src/gt4py/next/ffront/past_passes/type_deduction.py +++ b/src/gt4py/next/ffront/past_passes/type_deduction.py @@ -229,7 +229,7 @@ def visit_Call(self, node: past.Call, **kwargs): ) except ValueError as ex: - raise errors.DSLError(node.location, f"Invalid call to '{node.func.id}'.") from ex + raise errors.DSLError(node.location, f"Invalid call to '{node.func.id}'.\n{ex}") from ex return past.Call( func=new_func, diff --git a/src/gt4py/next/ffront/past_to_itir.py b/src/gt4py/next/ffront/past_to_itir.py index 709912077b..ed239e0436 100644 --- a/src/gt4py/next/ffront/past_to_itir.py +++ b/src/gt4py/next/ffront/past_to_itir.py @@ -40,7 +40,9 @@ def _flatten_tuple_expr( raise ValueError("Only 'past.Name', 'past.Subscript' or 'past.TupleExpr' thereof are allowed.") -class ProgramLowering(traits.VisitorWithSymbolTableTrait, NodeTranslator): +class ProgramLowering( + traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator +): """ Lower Program AST (PAST) to Iterator IR (ITIR). 
@@ -151,6 +153,7 @@ def _visit_stencil_call(self, node: past.Call, **kwargs) -> itir.StencilClosure: stencil=itir.SymRef(id=node.func.id), inputs=[*lowered_args, *lowered_kwargs.values()], output=output, + location=node.location, ) def _visit_slice_bound( @@ -175,17 +178,22 @@ def _visit_slice_bound( lowered_bound = self.visit(slice_bound, **kwargs) else: raise AssertionError("Expected 'None' or 'past.Constant'.") + if slice_bound: + lowered_bound.location = slice_bound.location return lowered_bound def _construct_itir_out_arg(self, node: past.Expr) -> itir.Expr: if isinstance(node, past.Name): - return itir.SymRef(id=node.id) + return itir.SymRef(id=node.id, location=node.location) elif isinstance(node, past.Subscript): - return self._construct_itir_out_arg(node.value) + itir_node = self._construct_itir_out_arg(node.value) + itir_node.location = node.location + return itir_node elif isinstance(node, past.TupleExpr): return itir.FunCall( fun=itir.SymRef(id="make_tuple"), args=[self._construct_itir_out_arg(el) for el in node.elts], + location=node.location, ) else: raise ValueError( @@ -247,7 +255,11 @@ def _construct_itir_domain_arg( else: raise AssertionError() - return itir.FunCall(fun=itir.SymRef(id=domain_builtin), args=domain_args) + return itir.FunCall( + fun=itir.SymRef(id=domain_builtin), + args=domain_args, + location=(node_domain or out_field).location, + ) def _construct_itir_initialized_domain_arg( self, diff --git a/src/gt4py/next/iterator/embedded.py b/src/gt4py/next/iterator/embedded.py index 13f8722494..6d610fd136 100644 --- a/src/gt4py/next/iterator/embedded.py +++ b/src/gt4py/next/iterator/embedded.py @@ -172,7 +172,7 @@ class LocatedField(Protocol): @property @abc.abstractmethod - def __gt_dims__(self) -> tuple[common.Dimension, ...]: + def __gt_domain__(self) -> common.Domain: ... 
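The protocol change above, together with the `GTFieldInterface` rework in `common.py` earlier in this diff, makes `__gt_domain__` the single source of truth and derives the legacy string-valued `__gt_dims__` tuple from it. A toy sketch of that derivation; `Dimension`, `Domain` and `ToyField` are simplified stand-ins for the real gt4py classes:

    from __future__ import annotations

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class Dimension:
        value: str  # the dimension label, e.g. "I"

    @dataclass(frozen=True)
    class Domain:
        dims: tuple[Dimension, ...]

    class ToyField:
        def __init__(self, domain: Domain) -> None:
            self._domain = domain

        @property
        def __gt_domain__(self) -> Domain:
            return self._domain  # single source of truth

        @property
        def __gt_dims__(self) -> tuple[str, ...]:
            # derived view kept for cartesian-style consumers
            return tuple(d.value for d in self.__gt_domain__.dims)

    f = ToyField(Domain(dims=(Dimension("I"), Dimension("J"))))
    assert f.__gt_dims__ == ("I", "J")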
# TODO(havogt): define generic Protocol to provide a concrete return type @@ -182,7 +182,7 @@ def field_getitem(self, indices: NamedFieldIndices) -> Any: @property def __gt_origin__(self) -> tuple[int, ...]: - return tuple([0] * len(self.__gt_dims__)) + return tuple([0] * len(self.__gt_domain__.dims)) @runtime_checkable @@ -675,12 +675,18 @@ def _is_concrete_position(pos: Position) -> TypeGuard[ConcretePosition]: def _get_axes( field_or_tuple: LocatedField | tuple, ) -> Sequence[common.Dimension]: # arbitrary nesting of tuples of LocatedField + return _get_domain(field_or_tuple).dims + + +def _get_domain( + field_or_tuple: LocatedField | tuple, +) -> common.Domain: # arbitrary nesting of tuples of LocatedField if isinstance(field_or_tuple, tuple): - first = _get_axes(field_or_tuple[0]) - assert all(first == _get_axes(f) for f in field_or_tuple) + first = _get_domain(field_or_tuple[0]) + assert all(first == _get_domain(f) for f in field_or_tuple) return first else: - return field_or_tuple.__gt_dims__ + return field_or_tuple.__gt_domain__ def _single_vertical_idx( @@ -894,14 +900,14 @@ class NDArrayLocatedFieldWrapper(MutableLocatedField): _ndarrayfield: common.Field @property - def __gt_dims__(self) -> tuple[common.Dimension, ...]: - return self._ndarrayfield.__gt_dims__ + def __gt_domain__(self) -> common.Domain: + return self._ndarrayfield.__gt_domain__ def _translate_named_indices( self, _named_indices: NamedFieldIndices ) -> common.AbsoluteIndexSequence: named_indices: Mapping[common.Dimension, FieldIndex | SparsePositionEntry] = { - d: _named_indices[d.value] for d in self._ndarrayfield.__gt_dims__ + d: _named_indices[d.value] for d in self._ndarrayfield.__gt_domain__.dims } domain_slice: list[common.NamedRange | common.NamedIndex] = [] for d, v in named_indices.items(): @@ -1046,8 +1052,8 @@ class IndexField(common.Field): _dimension: common.Dimension @property - def __gt_dims__(self) -> tuple[common.Dimension, ...]: - return (self._dimension,) + def __gt_domain__(self) -> common.Domain: + return self.domain @property def __gt_origin__(self) -> tuple[int, ...]: @@ -1165,8 +1171,8 @@ class ConstantField(common.Field[Any, core_defs.ScalarT]): _value: core_defs.ScalarT @property - def __gt_dims__(self) -> tuple[common.Dimension, ...]: - return tuple() + def __gt_domain__(self) -> common.Domain: + return self.domain @property def __gt_origin__(self) -> tuple[int, ...]: @@ -1452,7 +1458,7 @@ def _tuple_assign(field: tuple | MutableLocatedField, value: Any, named_indices: class TupleOfFields(TupleField): def __init__(self, data): self.data = data - self.__gt_dims__ = _get_axes(data) + self.__gt_domain__ = _get_domain(data) def field_getitem(self, named_indices: NamedFieldIndices) -> Any: return _build_tuple_result(self.data, named_indices) diff --git a/src/gt4py/next/iterator/ir.py b/src/gt4py/next/iterator/ir.py index e6ee20e227..37abbec9e7 100644 --- a/src/gt4py/next/iterator/ir.py +++ b/src/gt4py/next/iterator/ir.py @@ -17,12 +17,15 @@ import gt4py.eve as eve from gt4py.eve import Coerced, SymbolName, SymbolRef, datamodels +from gt4py.eve.concepts import SourceLocation from gt4py.eve.traits import SymbolTableTrait, ValidatedSymbolTableTrait from gt4py.eve.utils import noninstantiable @noninstantiable class Node(eve.Node): + location: Optional[SourceLocation] = eve.field(default=None, repr=False, compare=False) + def __str__(self) -> str: from gt4py.next.iterator.pretty_printer import pformat diff --git a/src/gt4py/next/iterator/tracing.py b/src/gt4py/next/iterator/tracing.py index 
30fec1f9fd..05ebd02352 100644 --- a/src/gt4py/next/iterator/tracing.py +++ b/src/gt4py/next/iterator/tracing.py @@ -254,7 +254,7 @@ def _contains_tuple_dtype_field(arg): # other `np.int32`). We just ignore the error here and postpone fixing this to when # the new storages land (The implementation here works for LocatedFieldImpl). - return common.is_field(arg) and any(dim is None for dim in arg.__gt_dims__) + return common.is_field(arg) and any(dim is None for dim in arg.domain.dims) def _make_fencil_params(fun, args, *, use_arg_types: bool) -> list[Sym]: diff --git a/src/gt4py/next/iterator/transforms/collapse_list_get.py b/src/gt4py/next/iterator/transforms/collapse_list_get.py index 08cbd7313e..6acb8a79c4 100644 --- a/src/gt4py/next/iterator/transforms/collapse_list_get.py +++ b/src/gt4py/next/iterator/transforms/collapse_list_get.py @@ -16,7 +16,7 @@ from gt4py.next.iterator import ir -class CollapseListGet(eve.NodeTranslator): +class CollapseListGet(eve.PreserveLocationVisitor, eve.NodeTranslator): """Simplifies expressions containing `list_get`. Examples diff --git a/src/gt4py/next/iterator/transforms/collapse_tuple.py b/src/gt4py/next/iterator/transforms/collapse_tuple.py index 30457f2246..42bbf28909 100644 --- a/src/gt4py/next/iterator/transforms/collapse_tuple.py +++ b/src/gt4py/next/iterator/transforms/collapse_tuple.py @@ -48,7 +48,7 @@ def _get_tuple_size(elem: ir.Node, node_types: Optional[dict] = None) -> int | t @dataclass(frozen=True) -class CollapseTuple(eve.NodeTranslator): +class CollapseTuple(eve.PreserveLocationVisitor, eve.NodeTranslator): """ Simplifies `make_tuple`, `tuple_get` calls. @@ -88,13 +88,6 @@ def apply( node_types, ).visit(node) - return cls( - ignore_tuple_size, - collapse_make_tuple_tuple_get, - collapse_tuple_get_make_tuple, - use_global_type_inference, - ).visit(node) - def visit_FunCall(self, node: ir.FunCall, **kwargs) -> ir.Node: if ( self.collapse_make_tuple_tuple_get diff --git a/src/gt4py/next/iterator/transforms/constant_folding.py b/src/gt4py/next/iterator/transforms/constant_folding.py index fa326760b0..696a87a197 100644 --- a/src/gt4py/next/iterator/transforms/constant_folding.py +++ b/src/gt4py/next/iterator/transforms/constant_folding.py @@ -12,12 +12,12 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from gt4py.eve import NodeTranslator +from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import embedded, ir from gt4py.next.iterator.ir_utils import ir_makers as im -class ConstantFolding(NodeTranslator): +class ConstantFolding(PreserveLocationVisitor, NodeTranslator): @classmethod def apply(cls, node: ir.Node) -> ir.Node: return cls().visit(node) diff --git a/src/gt4py/next/iterator/transforms/cse.py b/src/gt4py/next/iterator/transforms/cse.py index 034a39d68f..f9cf272c45 100644 --- a/src/gt4py/next/iterator/transforms/cse.py +++ b/src/gt4py/next/iterator/transforms/cse.py @@ -17,14 +17,20 @@ import operator import typing -from gt4py.eve import NodeTranslator, NodeVisitor, SymbolTableTrait, VisitorWithSymbolTableTrait +from gt4py.eve import ( + NodeTranslator, + NodeVisitor, + PreserveLocationVisitor, + SymbolTableTrait, + VisitorWithSymbolTableTrait, +) from gt4py.eve.utils import UIDGenerator from gt4py.next.iterator import ir from gt4py.next.iterator.transforms.inline_lambdas import inline_lambda @dataclasses.dataclass -class _NodeReplacer(NodeTranslator): +class _NodeReplacer(PreserveLocationVisitor, NodeTranslator): PRESERVED_ANNEX_ATTRS = ("type",) expr_map: dict[int, ir.SymRef] @@ -72,7 +78,7 @@ 
def _is_collectable_expr(node: ir.Node) -> bool: @dataclasses.dataclass -class CollectSubexpressions(VisitorWithSymbolTableTrait, NodeVisitor): +class CollectSubexpressions(PreserveLocationVisitor, VisitorWithSymbolTableTrait, NodeVisitor): @dataclasses.dataclass class SubexpressionData: #: A list of node ids with equal hash and a set of collected child subexpression ids @@ -341,7 +347,7 @@ def extract_subexpression( @dataclasses.dataclass(frozen=True) -class CommonSubexpressionElimination(NodeTranslator): +class CommonSubexpressionElimination(PreserveLocationVisitor, NodeTranslator): """ Perform common subexpression elimination. diff --git a/src/gt4py/next/iterator/transforms/eta_reduction.py b/src/gt4py/next/iterator/transforms/eta_reduction.py index 55b2141499..93702a6c96 100644 --- a/src/gt4py/next/iterator/transforms/eta_reduction.py +++ b/src/gt4py/next/iterator/transforms/eta_reduction.py @@ -12,11 +12,11 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from gt4py.eve import NodeTranslator +from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import ir -class EtaReduction(NodeTranslator): +class EtaReduction(PreserveLocationVisitor, NodeTranslator): """Eta reduction: simplifies `λ(args...) → f(args...)` to `f`.""" def visit_Lambda(self, node: ir.Lambda) -> ir.Node: diff --git a/src/gt4py/next/iterator/transforms/fuse_maps.py b/src/gt4py/next/iterator/transforms/fuse_maps.py index e9fbb0f81d..694dcd6a61 100644 --- a/src/gt4py/next/iterator/transforms/fuse_maps.py +++ b/src/gt4py/next/iterator/transforms/fuse_maps.py @@ -38,7 +38,7 @@ def _is_reduce(node: ir.Node) -> TypeGuard[ir.FunCall]: @dataclasses.dataclass(frozen=True) -class FuseMaps(traits.VisitorWithSymbolTableTrait, NodeTranslator): +class FuseMaps(traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator): """ Fuses nested `map_`s. 
@@ -66,6 +66,7 @@ def _as_lambda(self, fun: ir.SymRef | ir.Lambda, param_count: int) -> ir.Lambda: return ir.Lambda( params=params, expr=ir.FunCall(fun=fun, args=[ir.SymRef(id=p.id) for p in params]), + location=fun.location, ) def visit_FunCall(self, node: ir.FunCall, **kwargs): @@ -99,6 +100,7 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs): ir.FunCall( fun=inner_op, args=[ir.SymRef(id=param.id) for param in inner_op.params], + location=node.location, ) ) ) diff --git a/src/gt4py/next/iterator/transforms/global_tmps.py b/src/gt4py/next/iterator/transforms/global_tmps.py index d9d3d18213..c423a3c277 100644 --- a/src/gt4py/next/iterator/transforms/global_tmps.py +++ b/src/gt4py/next/iterator/transforms/global_tmps.py @@ -19,9 +19,10 @@ import gt4py.eve as eve import gt4py.next as gtx -from gt4py.eve import Coerced, NodeTranslator +from gt4py.eve import Coerced, NodeTranslator, PreserveLocationVisitor from gt4py.eve.traits import SymbolTableTrait from gt4py.eve.utils import UIDGenerator +from gt4py.next import common from gt4py.next.iterator import ir, type_inference from gt4py.next.iterator.ir_utils import ir_makers as im from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift @@ -266,6 +267,7 @@ def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemp stencil=stencil, output=im.ref(tmp_sym.id), inputs=[closure_param_arg_mapping[param.id] for param in lift_expr.args], # type: ignore[attr-defined] + location=current_closure.location, ) ) @@ -293,6 +295,7 @@ def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemp output=current_closure.output, inputs=current_closure.inputs + [ir.SymRef(id=sym.id) for sym in extracted_lifts.keys()], + location=current_closure.location, ) ) else: @@ -306,6 +309,7 @@ def split_closures(node: ir.FencilDefinition, offset_provider) -> FencilWithTemp + [ir.Sym(id=tmp.id) for tmp in tmps] + [ir.Sym(id=AUTO_DOMAIN.fun.id)], # type: ignore[attr-defined] # value is a global constant closures=list(reversed(closures)), + location=node.location, ), params=node.params, tmps=[Temporary(id=tmp.id) for tmp in tmps], @@ -332,6 +336,7 @@ def prune_unused_temporaries(node: FencilWithTemporaries) -> FencilWithTemporari function_definitions=node.fencil.function_definitions, params=[p for p in node.fencil.params if p.id not in unused_tmps], closures=closures, + location=node.fencil.location, ), params=node.params, tmps=[tmp for tmp in node.tmps if tmp.id not in unused_tmps], @@ -437,9 +442,12 @@ def _group_offsets( return zip(tags, offsets, strict=True) # type: ignore[return-value] # mypy doesn't infer literal correctly -def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, Any]): +def update_domains( + node: FencilWithTemporaries, + offset_provider: Mapping[str, Any], + symbolic_sizes: Optional[dict[str, str]], +): horizontal_sizes = _max_domain_sizes_by_location_type(offset_provider) - closures: list[ir.StencilClosure] = [] domains = dict[str, ir.FunCall]() for closure in reversed(node.fencil.closures): @@ -452,6 +460,7 @@ def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, An stencil=closure.stencil, output=closure.output, inputs=closure.inputs, + location=closure.location, ) else: domain = closure.domain @@ -479,16 +488,29 @@ def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, An # cartesian shift dim = offset_provider[offset_name].value consumed_domain.ranges[dim] = consumed_domain.ranges[dim].translate(offset) - elif 
isinstance(offset_provider[offset_name], gtx.NeighborTableOffsetProvider): + elif isinstance(offset_provider[offset_name], common.Connectivity): # unstructured shift nbt_provider = offset_provider[offset_name] old_axis = nbt_provider.origin_axis.value new_axis = nbt_provider.neighbor_axis.value - consumed_domain.ranges.pop(old_axis) - assert new_axis not in consumed_domain.ranges - consumed_domain.ranges[new_axis] = SymbolicRange( - im.literal("0", ir.INTEGER_INDEX_BUILTIN), - im.literal(str(horizontal_sizes[new_axis]), ir.INTEGER_INDEX_BUILTIN), + + assert new_axis not in consumed_domain.ranges or old_axis == new_axis + + if symbolic_sizes is None: + new_range = SymbolicRange( + im.literal("0", ir.INTEGER_INDEX_BUILTIN), + im.literal( + str(horizontal_sizes[new_axis]), ir.INTEGER_INDEX_BUILTIN + ), + ) + else: + new_range = SymbolicRange( + im.literal("0", ir.INTEGER_INDEX_BUILTIN), + im.ref(symbolic_sizes[new_axis]), + ) + consumed_domain.ranges = dict( + (axis, range_) if axis != old_axis else (new_axis, new_range) + for axis, range_ in consumed_domain.ranges.items() ) else: raise NotImplementedError @@ -504,6 +526,7 @@ def update_domains(node: FencilWithTemporaries, offset_provider: Mapping[str, An function_definitions=node.fencil.function_definitions, params=node.fencil.params[:-1], # remove `_gtmp_auto_domain` param again closures=list(reversed(closures)), + location=node.fencil.location, ), params=node.params, tmps=node.tmps, @@ -563,14 +586,18 @@ def convert_type(dtype): # TODO(tehrengruber): Add support for dynamic shifts (e.g. the distance is a symbol). This can be # tricky: For every lift statement that is dynamically shifted we can not compute bounds anymore # and hence also not extract as a temporary. -class CreateGlobalTmps(NodeTranslator): +class CreateGlobalTmps(PreserveLocationVisitor, NodeTranslator): """Main entry point for introducing global temporaries. Transforms an existing iterator IR fencil into a fencil with global temporaries. 
""" def visit_FencilDefinition( - self, node: ir.FencilDefinition, *, offset_provider: Mapping[str, Any] + self, + node: ir.FencilDefinition, + *, + offset_provider: Mapping[str, Any], + symbolic_sizes: Optional[dict[str, str]], ) -> FencilWithTemporaries: # Split closures on lifted function calls and introduce temporaries res = split_closures(node, offset_provider=offset_provider) @@ -581,6 +608,6 @@ def visit_FencilDefinition( # Perform an eta-reduction which should put all calls at the highest level of a closure res = EtaReduction().visit(res) # Perform a naive extent analysis to compute domain sizes of closures and temporaries - res = update_domains(res, offset_provider) + res = update_domains(res, offset_provider, symbolic_sizes) # Use type inference to determine the data type of the temporaries return collect_tmps_info(res, offset_provider=offset_provider) diff --git a/src/gt4py/next/iterator/transforms/inline_fundefs.py b/src/gt4py/next/iterator/transforms/inline_fundefs.py index 6bf2b60592..a53232745f 100644 --- a/src/gt4py/next/iterator/transforms/inline_fundefs.py +++ b/src/gt4py/next/iterator/transforms/inline_fundefs.py @@ -14,11 +14,11 @@ from typing import Any, Dict, Set -from gt4py.eve import NOTHING, NodeTranslator +from gt4py.eve import NOTHING, NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import ir -class InlineFundefs(NodeTranslator): +class InlineFundefs(PreserveLocationVisitor, NodeTranslator): def visit_SymRef(self, node: ir.SymRef, *, symtable: Dict[str, Any]): if node.id in symtable and isinstance((symbol := symtable[node.id]), ir.FunctionDefinition): return ir.Lambda( @@ -31,7 +31,7 @@ def visit_FencilDefinition(self, node: ir.FencilDefinition): return self.generic_visit(node, symtable=node.annex.symtable) -class PruneUnreferencedFundefs(NodeTranslator): +class PruneUnreferencedFundefs(PreserveLocationVisitor, NodeTranslator): def visit_FunctionDefinition( self, node: ir.FunctionDefinition, *, referenced: Set[str], second_pass: bool ): diff --git a/src/gt4py/next/iterator/transforms/inline_into_scan.py b/src/gt4py/next/iterator/transforms/inline_into_scan.py index fe1eae6e07..a1c9a2eb5b 100644 --- a/src/gt4py/next/iterator/transforms/inline_into_scan.py +++ b/src/gt4py/next/iterator/transforms/inline_into_scan.py @@ -53,7 +53,9 @@ def _lambda_and_lift_inliner(node: ir.FunCall) -> ir.FunCall: return inlined -class InlineIntoScan(traits.VisitorWithSymbolTableTrait, NodeTranslator): +class InlineIntoScan( + traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator +): """ Inline non-SymRef arguments into the scan. 
@@ -100,6 +102,5 @@ def visit_FunCall(self, node: ir.FunCall, **kwargs): new_scan = ir.FunCall( fun=ir.SymRef(id="scan"), args=[new_scanpass, *original_scan_call.args[1:]] ) - result = ir.FunCall(fun=new_scan, args=[ir.SymRef(id=ref) for ref in refs_in_args]) - return result + return ir.FunCall(fun=new_scan, args=[ir.SymRef(id=ref) for ref in refs_in_args]) return self.generic_visit(node, **kwargs) diff --git a/src/gt4py/next/iterator/transforms/inline_lambdas.py b/src/gt4py/next/iterator/transforms/inline_lambdas.py index a56ad5cb10..0b89fe6d98 100644 --- a/src/gt4py/next/iterator/transforms/inline_lambdas.py +++ b/src/gt4py/next/iterator/transforms/inline_lambdas.py @@ -15,7 +15,7 @@ import dataclasses from typing import Optional -from gt4py.eve import NodeTranslator +from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import ir from gt4py.next.iterator.ir_utils.common_pattern_matcher import is_applied_lift from gt4py.next.iterator.transforms.remap_symbols import RemapSymbolRefs, RenameSymbols @@ -104,6 +104,7 @@ def new_name(name): new_expr = RemapSymbolRefs().visit(expr, symbol_map=symbol_map) if all(eligible_params): + new_expr.location = node.location return new_expr else: return ir.FunCall( @@ -116,11 +117,12 @@ def new_name(name): expr=new_expr, ), args=[arg for arg, eligible in zip(node.args, eligible_params) if not eligible], + location=node.location, ) @dataclasses.dataclass -class InlineLambdas(NodeTranslator): +class InlineLambdas(PreserveLocationVisitor, NodeTranslator): """Inline lambda calls by substituting every argument by its value.""" PRESERVED_ANNEX_ATTRS = ("type",) diff --git a/src/gt4py/next/iterator/transforms/inline_lifts.py b/src/gt4py/next/iterator/transforms/inline_lifts.py index d7d8e5e612..d6146d9fc8 100644 --- a/src/gt4py/next/iterator/transforms/inline_lifts.py +++ b/src/gt4py/next/iterator/transforms/inline_lifts.py @@ -103,14 +103,18 @@ def _transform_and_extract_lift_args( extracted_args[new_symbol] = arg new_args.append(ir.SymRef(id=new_symbol.id)) - return (im.lift(inner_stencil)(*new_args), extracted_args) + itir_node = im.lift(inner_stencil)(*new_args) + itir_node.location = node.location + return (itir_node, extracted_args) # TODO(tehrengruber): This pass has many different options that should be written as dedicated # passes. Due to a lack of infrastructure (e.g. no pass manager) to combine passes without # performance degradation we leave everything as one pass for now. @dataclasses.dataclass -class InlineLifts(traits.VisitorWithSymbolTableTrait, NodeTranslator): +class InlineLifts( + traits.PreserveLocationVisitor, traits.VisitorWithSymbolTableTrait, NodeTranslator +): """Inline lifted function calls. Optionally a predicate function can be passed which can enable or disable inlining of specific diff --git a/src/gt4py/next/iterator/transforms/merge_let.py b/src/gt4py/next/iterator/transforms/merge_let.py index 7426617ac8..bcfc6b2a17 100644 --- a/src/gt4py/next/iterator/transforms/merge_let.py +++ b/src/gt4py/next/iterator/transforms/merge_let.py @@ -17,7 +17,7 @@ from gt4py.next.iterator.transforms.symbol_ref_utils import CountSymbolRefs -class MergeLet(eve.NodeTranslator): +class MergeLet(eve.PreserveLocationVisitor, eve.NodeTranslator): """ Merge let-like statements. 
diff --git a/src/gt4py/next/iterator/transforms/normalize_shifts.py b/src/gt4py/next/iterator/transforms/normalize_shifts.py index efc9064612..c70dc1ccd1 100644 --- a/src/gt4py/next/iterator/transforms/normalize_shifts.py +++ b/src/gt4py/next/iterator/transforms/normalize_shifts.py @@ -12,11 +12,11 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from gt4py.eve import NodeTranslator +from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import ir -class NormalizeShifts(NodeTranslator): +class NormalizeShifts(PreserveLocationVisitor, NodeTranslator): def visit_FunCall(self, node: ir.FunCall): node = self.generic_visit(node) if ( diff --git a/src/gt4py/next/iterator/transforms/pass_manager.py b/src/gt4py/next/iterator/transforms/pass_manager.py index 2e05391634..08897861c2 100644 --- a/src/gt4py/next/iterator/transforms/pass_manager.py +++ b/src/gt4py/next/iterator/transforms/pass_manager.py @@ -13,6 +13,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later import enum +from typing import Optional from gt4py.next.iterator import ir from gt4py.next.iterator.transforms import simple_inline_heuristic @@ -81,6 +82,7 @@ def apply_common_transforms( common_subexpression_elimination=True, force_inline_lambda_args=False, unconditionally_collapse_tuples=False, + symbolic_domain_sizes: Optional[dict[str, str]] = None, ): if lift_mode is None: lift_mode = LiftMode.FORCE_INLINE @@ -147,7 +149,9 @@ def apply_common_transforms( if lift_mode != LiftMode.FORCE_INLINE: assert offset_provider is not None - ir = CreateGlobalTmps().visit(ir, offset_provider=offset_provider) + ir = CreateGlobalTmps().visit( + ir, offset_provider=offset_provider, symbolic_sizes=symbolic_domain_sizes + ) ir = InlineLifts().visit(ir) # If after creating temporaries, the scan is not at the top, we inline. # The following example doesn't have a lift around the shift, i.e. temporary pass will not extract it. diff --git a/src/gt4py/next/iterator/transforms/power_unrolling.py b/src/gt4py/next/iterator/transforms/power_unrolling.py new file mode 100644 index 0000000000..ac71f2747d --- /dev/null +++ b/src/gt4py/next/iterator/transforms/power_unrolling.py @@ -0,0 +1,84 @@ +# GT4Py - GridTools Framework +# +# Copyright (c) 2014-2023, ETH Zurich +# All rights reserved. +# +# This file is part of the GT4Py project and the GridTools framework. +# GT4Py is free software: you can redistribute it and/or modify it under +# the terms of the GNU General Public License as published by the +# Free Software Foundation, either version 3 of the License, or any later +# version. See the LICENSE.txt file at the top-level directory of this +# distribution for a copy of the license or check . 
+# +# SPDX-License-Identifier: GPL-3.0-or-later +import dataclasses +import math + +from gt4py.eve import NodeTranslator +from gt4py.next.iterator import ir +from gt4py.next.iterator.ir_utils import ir_makers as im +from gt4py.next.iterator.transforms.inline_lambdas import InlineLambdas + + +def _is_power_call( + node: ir.FunCall, +) -> bool: + """Match expressions of the form `power(base, integral_literal)`.""" + return ( + isinstance(node.fun, ir.SymRef) + and node.fun.id == "power" + and isinstance(node.args[1], ir.Literal) + and float(node.args[1].value) == int(node.args[1].value) + and node.args[1].value >= im.literal_from_value(0).value + ) + + +def _compute_integer_power_of_two(exp: int) -> int: + return math.floor(math.log2(exp)) + + +@dataclasses.dataclass +class PowerUnrolling(NodeTranslator): + max_unroll: int + + @classmethod + def apply(cls, node: ir.Node, max_unroll: int = 5) -> ir.Node: + return cls(max_unroll=max_unroll).visit(node) + + def visit_FunCall(self, node: ir.FunCall): + new_node = self.generic_visit(node) + + if _is_power_call(new_node): + assert len(new_node.args) == 2 + # Check if unroll should be performed or if exponent is too large + base, exponent = new_node.args[0], int(new_node.args[1].value) + if 1 <= exponent <= self.max_unroll: + # Calculate and store powers of two of the base as long as they are smaller than the exponent. + # Do the same (using the stored values) with the remainder and multiply computed values. + pow_cur = _compute_integer_power_of_two(exponent) + pow_max = pow_cur + remainder = exponent + + # Build target expression + ret = im.ref(f"power_{2 ** pow_max}") + remainder -= 2**pow_cur + while remainder > 0: + pow_cur = _compute_integer_power_of_two(remainder) + remainder -= 2**pow_cur + + ret = im.multiplies_(ret, f"power_{2 ** pow_cur}") + + # Nest target expression to avoid multiple redundant evaluations + for i in range(pow_max, 0, -1): + ret = im.let( + f"power_{2 ** i}", + im.multiplies_(f"power_{2**(i-1)}", f"power_{2**(i-1)}"), + )(ret) + ret = im.let("power_1", base)(ret) + + # Simplify expression in case of SymRef by resolving let statements + if isinstance(base, ir.SymRef): + return InlineLambdas.apply(ret, opcount_preserving=True) + else: + return ret + return new_node diff --git a/src/gt4py/next/iterator/transforms/propagate_deref.py b/src/gt4py/next/iterator/transforms/propagate_deref.py index 54bdafcda8..783e54ede0 100644 --- a/src/gt4py/next/iterator/transforms/propagate_deref.py +++ b/src/gt4py/next/iterator/transforms/propagate_deref.py @@ -12,7 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from gt4py.eve import NodeTranslator +from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.eve.pattern_matching import ObjectPattern as P from gt4py.next.iterator import ir @@ -22,7 +22,7 @@ # `(λ(...) → plus(multiplies(...), ...))(...)`. 
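The new `power_unrolling` pass above rewrites `power(base, N)` for small integral `N` into a chain of squarings plus multiplications selected by the binary expansion of the exponent. A plain-Python sketch of just that arithmetic; `unrolled_power` is a hypothetical helper that mirrors the loop structure, while the real pass emits nested `let` IR nodes instead of computing values:

    import math

    def unrolled_power(x: float, exponent: int) -> float:
        assert exponent >= 1
        # squares[i] plays the role of the let-bound symbol power_{2**i}
        squares = [x]
        for _ in range(int(math.log2(exponent))):
            squares.append(squares[-1] * squares[-1])
        result = 1.0
        remainder = exponent
        while remainder > 0:
            i = int(math.log2(remainder))  # largest power of two <= remainder
            result *= squares[i]
            remainder -= 2 ** i
        return result

    assert unrolled_power(3.0, 5) == 3.0 ** 5  # power_4 * power_1

Binding each square once and reusing it is what the nested `let` construction in `visit_FunCall` achieves, so `power(x, 5)` costs three multiplications instead of four.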
-class PropagateDeref(NodeTranslator): +class PropagateDeref(PreserveLocationVisitor, NodeTranslator): @classmethod def apply(cls, node: ir.Node): """ diff --git a/src/gt4py/next/iterator/transforms/prune_closure_inputs.py b/src/gt4py/next/iterator/transforms/prune_closure_inputs.py index 7fd3c50c6e..1e637a0bfb 100644 --- a/src/gt4py/next/iterator/transforms/prune_closure_inputs.py +++ b/src/gt4py/next/iterator/transforms/prune_closure_inputs.py @@ -12,11 +12,11 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from gt4py.eve import NodeTranslator +from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import ir -class PruneClosureInputs(NodeTranslator): +class PruneClosureInputs(PreserveLocationVisitor, NodeTranslator): """Removes all unused input arguments from a stencil closure.""" def visit_StencilClosure(self, node: ir.StencilClosure) -> ir.StencilClosure: diff --git a/src/gt4py/next/iterator/transforms/remap_symbols.py b/src/gt4py/next/iterator/transforms/remap_symbols.py index cdf3d76173..431dd6cd7a 100644 --- a/src/gt4py/next/iterator/transforms/remap_symbols.py +++ b/src/gt4py/next/iterator/transforms/remap_symbols.py @@ -14,11 +14,11 @@ from typing import Any, Dict, Optional, Set -from gt4py.eve import NodeTranslator, SymbolTableTrait +from gt4py.eve import NodeTranslator, PreserveLocationVisitor, SymbolTableTrait from gt4py.next.iterator import ir -class RemapSymbolRefs(NodeTranslator): +class RemapSymbolRefs(PreserveLocationVisitor, NodeTranslator): PRESERVED_ANNEX_ATTRS = ("type",) def visit_SymRef(self, node: ir.SymRef, *, symbol_map: Dict[str, ir.Node]): @@ -39,7 +39,7 @@ def generic_visit(self, node: ir.Node, **kwargs: Any): # type: ignore[override] return super().generic_visit(node, **kwargs) -class RenameSymbols(NodeTranslator): +class RenameSymbols(PreserveLocationVisitor, NodeTranslator): PRESERVED_ANNEX_ATTRS = ("type",) def visit_Sym( diff --git a/src/gt4py/next/iterator/transforms/scan_eta_reduction.py b/src/gt4py/next/iterator/transforms/scan_eta_reduction.py index 3266c25c4b..d93b4242ab 100644 --- a/src/gt4py/next/iterator/transforms/scan_eta_reduction.py +++ b/src/gt4py/next/iterator/transforms/scan_eta_reduction.py @@ -12,7 +12,7 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -from gt4py.eve import NodeTranslator +from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import ir @@ -24,7 +24,7 @@ def _is_scan(node: ir.Node): ) -class ScanEtaReduction(NodeTranslator): +class ScanEtaReduction(PreserveLocationVisitor, NodeTranslator): """Applies eta-reduction-like transformation involving scans. Simplifies `λ(x, y) → scan(λ(state, param_y, param_x) → ..., ...)(y, x)` to `scan(λ(state, param_x, param_y) → ..., ...)`. 
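For orientation, eta reduction as used by `EtaReduction` and `ScanEtaReduction` collapses a lambda that merely forwards its parameters into the wrapped callee. A toy sketch on hypothetical stand-in node classes (not the real itir types):

    from __future__ import annotations

    from dataclasses import dataclass

    @dataclass(frozen=True)
    class Sym:
        id: str

    @dataclass(frozen=True)
    class Call:
        fun: object
        args: tuple

    @dataclass(frozen=True)
    class Lam:
        params: tuple[Sym, ...]
        expr: object

    def eta_reduce(node: object) -> object:
        # λ(p...) → f(p...) simplifies to f when the arguments match the
        # parameters exactly and in order
        if (
            isinstance(node, Lam)
            and isinstance(node.expr, Call)
            and node.expr.args == tuple(Sym(p.id) for p in node.params)
        ):
            return node.expr.fun
        return node

    f = Sym("f")
    lam = Lam(params=(Sym("x"), Sym("y")), expr=Call(fun=f, args=(Sym("x"), Sym("y"))))
    assert eta_reduce(lam) == f

`ScanEtaReduction` handles the additional case where the arguments are a permutation of the parameters by permuting the scan-pass parameters accordingly, as the hunk below shows.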
@@ -55,9 +55,8 @@ def visit_Lambda(self, node: ir.Lambda) -> ir.Node: original_scanpass.params[i + 1] for i in new_scanpass_params_idx ] new_scanpass = ir.Lambda(params=new_scanpass_params, expr=original_scanpass.expr) - result = ir.FunCall( + return ir.FunCall( fun=ir.SymRef(id="scan"), args=[new_scanpass, *node.expr.fun.args[1:]] ) - return result return self.generic_visit(node) diff --git a/src/gt4py/next/iterator/transforms/symbol_ref_utils.py b/src/gt4py/next/iterator/transforms/symbol_ref_utils.py index 1c587fb9d6..05d137e8c4 100644 --- a/src/gt4py/next/iterator/transforms/symbol_ref_utils.py +++ b/src/gt4py/next/iterator/transforms/symbol_ref_utils.py @@ -21,7 +21,7 @@ @dataclasses.dataclass -class CountSymbolRefs(eve.NodeVisitor): +class CountSymbolRefs(eve.PreserveLocationVisitor, eve.NodeVisitor): ref_counts: dict[str, int] = dataclasses.field(default_factory=lambda: defaultdict(int)) @classmethod diff --git a/src/gt4py/next/iterator/transforms/trace_shifts.py b/src/gt4py/next/iterator/transforms/trace_shifts.py index 5c607e7df1..082987ac96 100644 --- a/src/gt4py/next/iterator/transforms/trace_shifts.py +++ b/src/gt4py/next/iterator/transforms/trace_shifts.py @@ -16,7 +16,7 @@ from collections.abc import Callable from typing import Any, Final, Iterable, Literal -from gt4py.eve import NodeTranslator +from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.next.iterator import ir @@ -235,7 +235,7 @@ def _tuple_get(index, tuple_val): @dataclasses.dataclass(frozen=True) -class TraceShifts(NodeTranslator): +class TraceShifts(PreserveLocationVisitor, NodeTranslator): shift_recorder: ShiftRecorder = dataclasses.field(default_factory=ShiftRecorder) def visit_Literal(self, node: ir.SymRef, *, ctx: dict[str, Any]) -> Any: diff --git a/src/gt4py/next/iterator/transforms/unroll_reduce.py b/src/gt4py/next/iterator/transforms/unroll_reduce.py index 861052bb25..3c878b2b00 100644 --- a/src/gt4py/next/iterator/transforms/unroll_reduce.py +++ b/src/gt4py/next/iterator/transforms/unroll_reduce.py @@ -16,7 +16,7 @@ from collections.abc import Iterable, Iterator from typing import TypeGuard -from gt4py.eve import NodeTranslator +from gt4py.eve import NodeTranslator, PreserveLocationVisitor from gt4py.eve.utils import UIDGenerator from gt4py.next import common from gt4py.next.iterator import ir as itir @@ -100,31 +100,36 @@ def _get_connectivity( def _make_shift(offsets: list[itir.Expr], iterator: itir.Expr) -> itir.FunCall: return itir.FunCall( - fun=itir.FunCall(fun=itir.SymRef(id="shift"), args=offsets), args=[iterator] + fun=itir.FunCall(fun=itir.SymRef(id="shift"), args=offsets), + args=[iterator], + location=iterator.location, ) def _make_deref(iterator: itir.Expr) -> itir.FunCall: - return itir.FunCall(fun=itir.SymRef(id="deref"), args=[iterator]) + return itir.FunCall(fun=itir.SymRef(id="deref"), args=[iterator], location=iterator.location) def _make_can_deref(iterator: itir.Expr) -> itir.FunCall: - return itir.FunCall(fun=itir.SymRef(id="can_deref"), args=[iterator]) + return itir.FunCall( + fun=itir.SymRef(id="can_deref"), args=[iterator], location=iterator.location + ) def _make_if(cond: itir.Expr, true_expr: itir.Expr, false_expr: itir.Expr) -> itir.FunCall: return itir.FunCall( fun=itir.SymRef(id="if_"), args=[cond, true_expr, false_expr], + location=cond.location, ) def _make_list_get(offset: itir.Expr, expr: itir.Expr) -> itir.FunCall: - return itir.FunCall(fun=itir.SymRef(id="list_get"), args=[offset, expr]) + return itir.FunCall(fun=itir.SymRef(id="list_get"), 
args=[offset, expr], location=expr.location) @dataclasses.dataclass(frozen=True) -class UnrollReduce(NodeTranslator): +class UnrollReduce(PreserveLocationVisitor, NodeTranslator): # we use one UID generator per instance such that the generated ids are # stable across multiple runs (required for caching to properly work) uids: UIDGenerator = dataclasses.field(init=False, repr=False, default_factory=UIDGenerator) diff --git a/src/gt4py/next/iterator/type_inference.py b/src/gt4py/next/iterator/type_inference.py index 68627cfd89..d65f67b266 100644 --- a/src/gt4py/next/iterator/type_inference.py +++ b/src/gt4py/next/iterator/type_inference.py @@ -567,9 +567,7 @@ def _infer_shift_location_types(shift_args, offset_provider, constraints): axis = offset_provider[offset] if isinstance(axis, gtx.Dimension): continue # Cartesian shifts don’t change the location type - elif isinstance( - axis, (gtx.NeighborTableOffsetProvider, gtx.StridedNeighborOffsetProvider) - ): + elif isinstance(axis, Connectivity): assert ( axis.origin_axis.kind == axis.neighbor_axis.kind @@ -964,7 +962,7 @@ def visit_FencilDefinition( def _save_types_to_annex(node: ir.Node, types: dict[int, Type]) -> None: for child_node in node.pre_walk_values().if_isinstance(*TYPED_IR_NODES): try: - child_node.annex.type = types[id(child_node)] # type: ignore[attr-defined] + child_node.annex.type = types[id(child_node)] except KeyError: if not ( isinstance(child_node, ir.SymRef) diff --git a/src/gt4py/next/otf/workflow.py b/src/gt4py/next/otf/workflow.py index ed8b768972..3a82f9c738 100644 --- a/src/gt4py/next/otf/workflow.py +++ b/src/gt4py/next/otf/workflow.py @@ -82,7 +82,7 @@ def replace(self, **kwargs: Any) -> Self: if not dataclasses.is_dataclass(self): raise TypeError(f"'{self.__class__}' is not a dataclass.") assert not isinstance(self, type) - return dataclasses.replace(self, **kwargs) # type: ignore[misc] # `self` is guaranteed to be a dataclass (is_dataclass) should be a `TypeGuard`? + return dataclasses.replace(self, **kwargs) class ChainableWorkflowMixin(Workflow[StartT, EndT]): diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_backend.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_backend.py deleted file mode 100644 index 4183f52550..0000000000 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_backend.py +++ /dev/null @@ -1,77 +0,0 @@ -# GT4Py - GridTools Framework -# -# Copyright (c) 2014-2023, ETH Zurich -# All rights reserved. -# -# This file is part of the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . 
-# -# SPDX-License-Identifier: GPL-3.0-or-later - -from typing import Any - -import gt4py.next.iterator.ir as itir -from gt4py.eve import codegen -from gt4py.eve.exceptions import EveValueError -from gt4py.next.iterator.transforms.pass_manager import apply_common_transforms -from gt4py.next.program_processors.codegens.gtfn.codegen import GTFNCodegen, GTFNIMCodegen -from gt4py.next.program_processors.codegens.gtfn.gtfn_ir_to_gtfn_im_ir import GTFN_IM_lowering -from gt4py.next.program_processors.codegens.gtfn.itir_to_gtfn_ir import GTFN_lowering - - -def _lower( - program: itir.FencilDefinition, enable_itir_transforms: bool, do_unroll: bool, **kwargs: Any -): - offset_provider = kwargs.get("offset_provider") - assert isinstance(offset_provider, dict) - if enable_itir_transforms: - program = apply_common_transforms( - program, - lift_mode=kwargs.get("lift_mode"), - offset_provider=offset_provider, - unroll_reduce=do_unroll, - unconditionally_collapse_tuples=True, # sid::composite (via hymap) supports assigning from tuple with more elements to tuple with fewer elements - ) - gtfn_ir = GTFN_lowering.apply( - program, - offset_provider=offset_provider, - column_axis=kwargs.get("column_axis"), - ) - return gtfn_ir - - -def generate( - program: itir.FencilDefinition, enable_itir_transforms: bool = True, **kwargs: Any -) -> str: - if kwargs.get("imperative", False): - try: - gtfn_ir = _lower( - program=program, - enable_itir_transforms=enable_itir_transforms, - do_unroll=False, - **kwargs, - ) - except EveValueError: - # if we don't unroll, there may be lifts left in the itir which can't be lowered to - # gtfn. In this case, just retry with unrolled reductions. - gtfn_ir = _lower( - program=program, - enable_itir_transforms=enable_itir_transforms, - do_unroll=True, - **kwargs, - ) - gtfn_im_ir = GTFN_IM_lowering().visit(node=gtfn_ir, **kwargs) - generated_code = GTFNIMCodegen.apply(gtfn_im_ir, **kwargs) - else: - gtfn_ir = _lower( - program=program, - enable_itir_transforms=enable_itir_transforms, - do_unroll=True, - **kwargs, - ) - generated_code = GTFNCodegen.apply(gtfn_ir, **kwargs) - return codegen.format_source("cpp", generated_code, style="LLVM") diff --git a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py index 4abdaa6eea..718fef72af 100644 --- a/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py +++ b/src/gt4py/next/program_processors/codegens/gtfn/gtfn_module.py @@ -15,21 +15,24 @@ from __future__ import annotations import dataclasses +import functools import warnings from typing import Any, Final, Optional import numpy as np from gt4py._core import definitions as core_defs -from gt4py.eve import trees, utils +from gt4py.eve import codegen, trees, utils from gt4py.next import common from gt4py.next.common import Connectivity, Dimension from gt4py.next.ffront import fbuiltins from gt4py.next.iterator import ir as itir -from gt4py.next.iterator.transforms import LiftMode +from gt4py.next.iterator.transforms import LiftMode, pass_manager from gt4py.next.otf import languages, stages, step_types, workflow from gt4py.next.otf.binding import cpp_interface, interface -from gt4py.next.program_processors.codegens.gtfn import gtfn_backend +from gt4py.next.program_processors.codegens.gtfn.codegen import GTFNCodegen, GTFNIMCodegen +from gt4py.next.program_processors.codegens.gtfn.gtfn_ir_to_gtfn_im_ir import GTFN_IM_lowering +from gt4py.next.program_processors.codegens.gtfn.itir_to_gtfn_ir import GTFN_lowering 
from gt4py.next.type_system import type_specifications as ts, type_translation @@ -54,6 +57,7 @@ class GTFNTranslationStep( use_imperative_backend: bool = False lift_mode: Optional[LiftMode] = None device_type: core_defs.DeviceType = core_defs.DeviceType.CPU + symbolic_domain_sizes: Optional[dict[str, str]] = None def _default_language_settings(self) -> languages.LanguageWithHeaderFilesSettings: match self.device_type: @@ -171,6 +175,70 @@ def _process_connectivity_args( return parameters, arg_exprs + def _preprocess_program( + self, + program: itir.FencilDefinition, + offset_provider: dict[str, Connectivity | Dimension], + runtime_lift_mode: Optional[LiftMode] = None, + ) -> itir.FencilDefinition: + # TODO(tehrengruber): Remove `lift_mode` from call interface. It has been implicitly added + # to the interface of all (or at least all of concern) backends, but instead should be + # configured in the backend itself (like it is here), until then we respect the argument + # here and warn the user if it differs from the one configured. + lift_mode = runtime_lift_mode or self.lift_mode + if lift_mode != self.lift_mode: + warnings.warn( + f"GTFN Backend was configured for LiftMode `{str(self.lift_mode)}`, but " + f"overridden to be {str(runtime_lift_mode)} at runtime." + ) + + if not self.enable_itir_transforms: + return program + + apply_common_transforms = functools.partial( + pass_manager.apply_common_transforms, + lift_mode=lift_mode, + offset_provider=offset_provider, + # sid::composite (via hymap) supports assigning from tuple with more elements to tuple with fewer elements + unconditionally_collapse_tuples=True, + symbolic_domain_sizes=self.symbolic_domain_sizes, + ) + + new_program = apply_common_transforms( + program, unroll_reduce=not self.use_imperative_backend + ) + + if self.use_imperative_backend and any( + node.id == "neighbors" + for node in new_program.pre_walk_values().if_isinstance(itir.SymRef) + ): + # if we don't unroll, there may be lifts left in the itir which can't be lowered to + # gtfn. In this case, just retry with unrolled reductions. + new_program = apply_common_transforms(program, unroll_reduce=True) + + return new_program + + def generate_stencil_source( + self, + program: itir.FencilDefinition, + offset_provider: dict[str, Connectivity | Dimension], + column_axis: Optional[common.Dimension], + runtime_lift_mode: Optional[LiftMode] = None, + ) -> str: + new_program = self._preprocess_program(program, offset_provider, runtime_lift_mode) + gtfn_ir = GTFN_lowering.apply( + new_program, + offset_provider=offset_provider, + column_axis=column_axis, + ) + + if self.use_imperative_backend: + gtfn_im_ir = GTFN_IM_lowering().visit(node=gtfn_ir) + generated_code = GTFNIMCodegen.apply(gtfn_im_ir) + else: + generated_code = GTFNCodegen.apply(gtfn_ir) + return codegen.format_source("cpp", generated_code, style="LLVM") + def __call__( self, inp: stages.ProgramCall, @@ -190,18 +258,6 @@ def __call__( inp.kwargs["offset_provider"] ) - # TODO(tehrengruber): Remove `lift_mode` from call interface. It has been implicitly added - # to the interface of all (or at least all of concern) backends, but instead should be - # configured in the backend itself (like it is here), until then we respect the argument - # here and warn the user if it differs from the one configured.
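`_preprocess_program` above follows a configure-once, override-per-call pattern for `lift_mode`: the value configured on the translation step is used unless a runtime value is passed in, and a warning is emitted when the two disagree. A minimal sketch of that pattern; `ToyStep` and the plain strings standing in for the `LiftMode` enum are hypothetical:

    import warnings
    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class ToyStep:
        lift_mode: Optional[str] = None  # configured at construction time

        def effective_lift_mode(self, runtime_lift_mode: Optional[str] = None) -> Optional[str]:
            lift_mode = runtime_lift_mode or self.lift_mode
            if lift_mode != self.lift_mode:
                warnings.warn(
                    f"Configured for lift mode {self.lift_mode!r}, "
                    f"but overridden to {runtime_lift_mode!r} at runtime."
                )
            return lift_mode

    step = ToyStep(lift_mode="FORCE_INLINE")
    assert step.effective_lift_mode() == "FORCE_INLINE"  # no warning
    assert step.effective_lift_mode("USE_TEMPORARIES") == "USE_TEMPORARIES"  # warns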
- runtime_lift_mode = inp.kwargs.pop("lift_mode", None) - lift_mode = runtime_lift_mode or self.lift_mode - if runtime_lift_mode != self.lift_mode: - warnings.warn( - f"GTFN Backend was configured for LiftMode `{str(self.lift_mode)}`, but " - "overriden to be {str(runtime_lift_mode)} at runtime." - ) - # combine into a format that is aligned with what the backend expects parameters: list[interface.Parameter] = regular_parameters + connectivity_parameters backend_arg = self._backend_type() @@ -213,12 +269,11 @@ def __call__( f"{', '.join(connectivity_args_expr)})({', '.join(args_expr)});" ) decl_src = cpp_interface.render_function_declaration(function, body=decl_body) - stencil_src = gtfn_backend.generate( + stencil_src = self.generate_stencil_source( program, - enable_itir_transforms=self.enable_itir_transforms, - lift_mode=lift_mode, - imperative=self.use_imperative_backend, - **inp.kwargs, + inp.kwargs["offset_provider"], + inp.kwargs.get("column_axis", None), + inp.kwargs.get("lift_mode", None), ) source_code = interface.format_source( self._language_settings(), diff --git a/src/gt4py/next/program_processors/formatters/gtfn.py b/src/gt4py/next/program_processors/formatters/gtfn.py index f9fa154641..27dec77ed1 100644 --- a/src/gt4py/next/program_processors/formatters/gtfn.py +++ b/src/gt4py/next/program_processors/formatters/gtfn.py @@ -15,10 +15,19 @@ from typing import Any from gt4py.next.iterator import ir as itir -from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate +from gt4py.next.program_processors.codegens.gtfn.gtfn_module import GTFNTranslationStep from gt4py.next.program_processors.processor_interface import program_formatter +from gt4py.next.program_processors.runners.gtfn import gtfn_executor @program_formatter def format_cpp(program: itir.FencilDefinition, *args: Any, **kwargs: Any) -> str: - return generate(program, **kwargs) + # TODO(tehrengruber): This is a little ugly. Revisit. 
+ gtfn_translation = gtfn_executor.otf_workflow.translation + assert isinstance(gtfn_translation, GTFNTranslationStep) + return gtfn_translation.generate_stencil_source( + program, + offset_provider=kwargs.get("offset_provider", None), + column_axis=kwargs.get("column_axis", None), + runtime_lift_mode=kwargs.get("lift_mode", None), + ) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py index 7fd4794e57..6a8b9bc9c6 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/__init__.py @@ -13,11 +13,14 @@ # SPDX-License-Identifier: GPL-3.0-or-later import hashlib import warnings +from inspect import currentframe, getframeinfo +from pathlib import Path from typing import Any, Mapping, Optional, Sequence import dace import numpy as np from dace.codegen.compiled_sdfg import CompiledSDFG +from dace.sdfg import utils as sdutils from dace.transformation.auto import auto_optimize as autoopt import gt4py.next.allocators as next_allocators @@ -25,7 +28,7 @@ import gt4py.next.program_processors.otf_compile_executor as otf_exec import gt4py.next.program_processors.processor_interface as ppi from gt4py.next import common -from gt4py.next.iterator import embedded as itir_embedded, transforms as itir_transforms +from gt4py.next.iterator import transforms as itir_transforms from gt4py.next.otf.compilation import cache as compilation_cache from gt4py.next.type_system import type_specifications as ts, type_translation @@ -108,23 +111,29 @@ def _ensure_is_on_device( def get_connectivity_args( - neighbor_tables: Sequence[tuple[str, itir_embedded.NeighborTableOffsetProvider]], + neighbor_tables: Mapping[str, common.NeighborTable], device: dace.dtypes.DeviceType, ) -> dict[str, Any]: return { - connectivity_identifier(offset): _ensure_is_on_device(table.table, device) - for offset, table in neighbor_tables + connectivity_identifier(offset): _ensure_is_on_device(offset_provider.table, device) + for offset, offset_provider in neighbor_tables.items() } def get_shape_args( arrays: Mapping[str, dace.data.Array], args: Mapping[str, Any] ) -> Mapping[str, int]: - return { - str(sym): size - for name, value in args.items() - for sym, size in zip(arrays[name].shape, value.shape) - } + shape_args: dict[str, int] = {} + for name, value in args.items(): + for sym, size in zip(arrays[name].shape, value.shape): + if isinstance(sym, dace.symbol): + assert sym.name not in shape_args + shape_args[sym.name] = size + elif sym != size: + raise RuntimeError( + f"Expected shape {arrays[name].shape} for arg {name}, got {value.shape}." 
+                )
+    return shape_args


 def get_offset_args(
@@ -157,34 +166,41 @@ def get_stride_args(
     return stride_args


-_build_cache_cpu: dict[str, CompiledSDFG] = {}
-_build_cache_gpu: dict[str, CompiledSDFG] = {}
+_build_cache: dict[str, CompiledSDFG] = {}


 def get_cache_id(
+    build_type: str,
+    build_for_gpu: bool,
     program: itir.FencilDefinition,
     arg_types: Sequence[ts.TypeSpec],
     column_axis: Optional[common.Dimension],
     offset_provider: Mapping[str, Any],
 ) -> str:
-    max_neighbors = [
-        (k, v.max_neighbors)
-        for k, v in offset_provider.items()
-        if isinstance(
-            v,
-            (
-                itir_embedded.NeighborTableOffsetProvider,
-                itir_embedded.StridedNeighborOffsetProvider,
-            ),
-        )
+    def offset_invariants(offset):
+        if isinstance(offset, common.Connectivity):
+            return (
+                offset.origin_axis,
+                offset.neighbor_axis,
+                offset.has_skip_values,
+                offset.max_neighbors,
+            )
+        if isinstance(offset, common.Dimension):
+            return (offset,)
+        return tuple()
+
+    offset_cache_keys = [
+        (name, *offset_invariants(offset)) for name, offset in offset_provider.items()
     ]
     cache_id_args = [
         str(arg)
         for arg in (
+            build_type,
+            build_for_gpu,
             program,
             *arg_types,
             column_axis,
-            *max_neighbors,
+            *offset_cache_keys,
         )
     ]
     m = hashlib.sha256()
@@ -260,20 +276,41 @@ def build_sdfg_from_itir(
     # visit ITIR and generate SDFG
     program = preprocess_program(program, offset_provider, lift_mode)
-    # TODO: According to Lex one should build the SDFG first in a general mannor.
-    # Generalisation to a particular device should happen only at the end.
-    sdfg_genenerator = ItirToSDFG(arg_types, offset_provider, column_axis, on_gpu)
-    sdfg = sdfg_genenerator.visit(program)
+    sdfg_generator = ItirToSDFG(arg_types, offset_provider, column_axis)
+    sdfg: dace.SDFG = sdfg_generator.visit(program)
+    if sdfg is None:
+        raise RuntimeError(f"Visit failed for program {program.id}.")
+
+    for nested_sdfg in sdfg.all_sdfgs_recursive():
+        if not nested_sdfg.debuginfo:
+            warnings.warn(
+                f"{nested_sdfg} does not have debuginfo. Consider adding it to the corresponding nested SDFG."
+            )
+            frameinfo = getframeinfo(currentframe())  # type: ignore
+            nested_sdfg.debuginfo = dace.dtypes.DebugInfo(
+                start_line=frameinfo.lineno,
+                end_line=frameinfo.lineno,
+                filename=frameinfo.filename,
+            )
+
+    # TODO(edopao): remove `inline_loop_blocks` when DaCe transformations support LoopRegion construct
+    sdutils.inline_loop_blocks(sdfg)
+
+    # run DaCe transformations to simplify the SDFG
     sdfg.simplify()

     # run DaCe auto-optimization heuristics
     if auto_optimize:
-        # TODO: Investigate how symbol definitions improve autoopt transformations,
-        # in which case the cache table should take the symbols map into account.
+        # TODO: Investigate the performance gain from specializing the SDFG with constant
+        # symbols for array shapes and strides, although this would imply JIT compilation.
         symbols: dict[str, int] = {}
         device = dace.DeviceType.GPU if on_gpu else dace.DeviceType.CPU
         sdfg = autoopt.auto_optimize(sdfg, device, symbols=symbols, use_gpu_storage=on_gpu)
+    if on_gpu:
+        sdfg.apply_gpu_transformations()
+
     return sdfg

@@ -283,30 +320,36 @@ def run_dace_iterator(program: itir.FencilDefinition, *args, **kwargs):
     compiler_args = kwargs.get("compiler_args", None)  # `None` will take default.
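`get_cache_id` above derives the build-cache key from everything that influences code generation: build type, target device, the program, argument types, the column axis, and per-offset invariants such as `max_neighbors` and `has_skip_values`. A minimal sketch of the hashing scheme with illustrative names (`run_dace_iterator` continues below):

    import hashlib

    def make_cache_id(build_type: str, build_for_gpu: bool, *invariants) -> str:
        # Equal invariants yield an equal digest, so a previously compiled
        # SDFG can be reused instead of being rebuilt.
        m = hashlib.sha256()
        for part in (build_type, build_for_gpu, *invariants):
            m.update(str(part).encode())
        return m.hexdigest()

    # e.g. make_cache_id("RelWithDebInfo", False, ("V2E", 4, True))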
build_type = kwargs.get("build_type", "RelWithDebInfo") on_gpu = kwargs.get("on_gpu", False) - auto_optimize = kwargs.get("auto_optimize", False) + auto_optimize = kwargs.get("auto_optimize", True) lift_mode = kwargs.get("lift_mode", itir_transforms.LiftMode.FORCE_INLINE) # ITIR parameters column_axis = kwargs.get("column_axis", None) offset_provider = kwargs["offset_provider"] + # debug option to store SDFGs on filesystem and skip lowering ITIR to SDFG at each run + skip_itir_lowering_to_sdfg = kwargs.get("skip_itir_lowering_to_sdfg", False) arg_types = [type_translation.from_value(arg) for arg in args] - cache_id = get_cache_id(program, arg_types, column_axis, offset_provider) + cache_id = get_cache_id(build_type, on_gpu, program, arg_types, column_axis, offset_provider) if build_cache is not None and cache_id in build_cache: # retrieve SDFG program from build cache sdfg_program = build_cache[cache_id] sdfg = sdfg_program.sdfg - else: - sdfg = build_sdfg_from_itir( - program, - *args, - offset_provider=offset_provider, - auto_optimize=auto_optimize, - on_gpu=on_gpu, - column_axis=column_axis, - lift_mode=lift_mode, - ) + sdfg_filename = f"_dacegraphs/gt4py/{cache_id}/{program.id}.sdfg" + if not (skip_itir_lowering_to_sdfg and Path(sdfg_filename).exists()): + sdfg = build_sdfg_from_itir( + program, + *args, + offset_provider=offset_provider, + auto_optimize=auto_optimize, + on_gpu=on_gpu, + column_axis=column_axis, + lift_mode=lift_mode, + ) + sdfg.save(sdfg_filename) + else: + sdfg = dace.SDFG.from_file(sdfg_filename) sdfg.build_folder = compilation_cache._session_cache_dir_path / ".dacecache" with dace.config.temporary_config(): @@ -342,7 +385,7 @@ def _run_dace_cpu(program: itir.FencilDefinition, *args, **kwargs) -> None: program, *args, **kwargs, - build_cache=_build_cache_cpu, + build_cache=_build_cache, build_type=_build_type, compiler_args=compiler_args, on_gpu=False, @@ -361,7 +404,7 @@ def _run_dace_gpu(program: itir.FencilDefinition, *args, **kwargs) -> None: program, *args, **kwargs, - build_cache=_build_cache_gpu, + build_cache=_build_cache, build_type=_build_type, on_gpu=True, ) diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py index e3b5ddf2ac..8a7826dae4 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_sdfg.py @@ -11,14 +11,15 @@ # distribution for a copy of the license or check . 
# # SPDX-License-Identifier: GPL-3.0-or-later -from typing import Any, Optional, cast +from typing import Any, Mapping, Optional, Sequence, cast import dace +from dace.sdfg.state import LoopRegion import gt4py.eve as eve from gt4py.next import Dimension, DimensionKind, type_inference as next_typing +from gt4py.next.common import NeighborTable from gt4py.next.iterator import ir as itir, type_inference as itir_typing -from gt4py.next.iterator.embedded import NeighborTableOffsetProvider from gt4py.next.iterator.ir import Expr, FunCall, Literal, SymRef from gt4py.next.type_system import type_specifications as ts, type_translation @@ -38,17 +39,17 @@ connectivity_identifier, create_memlet_at, create_memlet_full, + dace_debuginfo, filter_neighbor_tables, flatten_list, get_sorted_dims, map_nested_sdfg_symbols, - new_array_symbols, unique_name, unique_var_name, ) -def get_scan_args(stencil: Expr) -> tuple[bool, Literal]: +def _get_scan_args(stencil: Expr) -> tuple[bool, Literal]: """ Parse stencil expression to extract the scan arguments. @@ -67,7 +68,7 @@ def get_scan_args(stencil: Expr) -> tuple[bool, Literal]: return is_forward.value == "True", init_carry -def get_scan_dim( +def _get_scan_dim( column_axis: Dimension, storage_types: dict[str, ts.TypeSpec], output: SymRef, @@ -92,6 +93,35 @@ def get_scan_dim( ) +def _make_array_shape_and_strides( + name: str, + dims: Sequence[Dimension], + neighbor_tables: Mapping[str, NeighborTable], + sort_dims: bool, +) -> tuple[list[dace.symbol], list[dace.symbol]]: + """ + Parse field dimensions and allocate symbols for array shape and strides. + + For local dimensions, the size is known at compile-time and therefore + the corresponding array shape dimension is set to an integer literal value. + + Returns + ------- + tuple(shape, strides) + The output tuple fields are arrays of dace symbolic expressions. 
+ """ + dtype = dace.int64 + sorted_dims = [dim for _, dim in get_sorted_dims(dims)] if sort_dims else dims + shape = [ + neighbor_tables[dim.value].max_neighbors + if dim.kind == DimensionKind.LOCAL + else dace.symbol(unique_name(f"{name}_shape{i}"), dtype) + for i, dim in enumerate(sorted_dims) + ] + strides = [dace.symbol(unique_name(f"{name}_stride{i}"), dtype) for i, _ in enumerate(shape)] + return shape, strides + + class ItirToSDFG(eve.NodeVisitor): param_types: list[ts.TypeSpec] storage_types: dict[str, ts.TypeSpec] @@ -99,38 +129,38 @@ class ItirToSDFG(eve.NodeVisitor): offset_provider: dict[str, Any] node_types: dict[int, next_typing.Type] unique_id: int - use_gpu_storage: bool def __init__( self, param_types: list[ts.TypeSpec], - offset_provider: dict[str, NeighborTableOffsetProvider], + offset_provider: dict[str, NeighborTable], column_axis: Optional[Dimension] = None, - use_gpu_storage: bool = False, ): self.param_types = param_types self.column_axis = column_axis self.offset_provider = offset_provider self.storage_types = {} - self.use_gpu_storage = use_gpu_storage - def add_storage(self, sdfg: dace.SDFG, name: str, type_: ts.TypeSpec, has_offset: bool = True): + def add_storage( + self, + sdfg: dace.SDFG, + name: str, + type_: ts.TypeSpec, + neighbor_tables: Mapping[str, NeighborTable], + has_offset: bool = True, + sort_dimensions: bool = True, + ): if isinstance(type_, ts.FieldType): - shape, strides = new_array_symbols(name, len(type_.dims)) + shape, strides = _make_array_shape_and_strides( + name, type_.dims, neighbor_tables, sort_dimensions + ) offset = ( [dace.symbol(unique_name(f"{name}_offset{i}_")) for i in range(len(type_.dims))] if has_offset else None ) dtype = as_dace_type(type_.dtype) - storage = ( - dace.dtypes.StorageType.GPU_Global - if self.use_gpu_storage - else dace.dtypes.StorageType.Default - ) - sdfg.add_array( - name, shape=shape, strides=strides, offset=offset, dtype=dtype, storage=storage - ) + sdfg.add_array(name, shape=shape, strides=strides, offset=offset, dtype=dtype) elif isinstance(type_, ts.ScalarType): sdfg.add_symbol(name, as_dace_type(type_)) @@ -153,6 +183,7 @@ def get_output_nodes( def visit_FencilDefinition(self, node: itir.FencilDefinition): program_sdfg = dace.SDFG(name=node.id) + program_sdfg.debuginfo = dace_debuginfo(node) last_state = program_sdfg.add_state("program_entry", True) self.node_types = itir_typing.infer_all(node) @@ -161,14 +192,23 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): # Add program parameters as SDFG storages. for param, type_ in zip(node.params, self.param_types): - self.add_storage(program_sdfg, str(param.id), type_) + self.add_storage(program_sdfg, str(param.id), type_, neighbor_tables) # Add connectivities as SDFG storages. 
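`_make_array_shape_and_strides` above assigns local (neighbor) dimensions their compile-time size and keeps all other dimensions symbolic. A hedged sketch of that allocation with illustrative names, assuming `dace` is available (the connectivity-storage loop continues right below):

    import dace

    def shape_symbols(name: str, dims: list[str], local_sizes: dict[str, int]) -> list:
        # Local dimensions are fixed by the connectivity's max_neighbors;
        # every other dimension gets a fresh int64 shape symbol.
        return [
            local_sizes[dim] if dim in local_sizes
            else dace.symbol(f"{name}_shape{i}", dace.int64)
            for i, dim in enumerate(dims)
        ]

    shape_symbols("fld", ["Vertex", "V2E"], {"V2E": 4})  # -> [fld_shape0, 4]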
- for offset, table in neighbor_tables: - scalar_kind = type_translation.get_scalar_kind(table.table.dtype) - local_dim = Dimension("ElementDim", kind=DimensionKind.LOCAL) - type_ = ts.FieldType([table.origin_axis, local_dim], ts.ScalarType(scalar_kind)) - self.add_storage(program_sdfg, connectivity_identifier(offset), type_, has_offset=False) + for offset, offset_provider in neighbor_tables.items(): + scalar_kind = type_translation.get_scalar_kind(offset_provider.table.dtype) + local_dim = Dimension(offset, kind=DimensionKind.LOCAL) + type_ = ts.FieldType( + [offset_provider.origin_axis, local_dim], ts.ScalarType(scalar_kind) + ) + self.add_storage( + program_sdfg, + connectivity_identifier(offset), + type_, + neighbor_tables, + has_offset=False, + sort_dimensions=False, + ) # Create a nested SDFG for all stencil closures. for closure in node.closures: @@ -197,15 +237,16 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition): inputs=set(input_names), outputs=set(output_names), symbol_mapping=symbol_mapping, + debuginfo=closure_sdfg.debuginfo, ) # Add access nodes for the program parameters and connect them to the nested SDFG's inputs via edges. for inner_name, memlet in input_mapping.items(): - access_node = last_state.add_access(inner_name) + access_node = last_state.add_access(inner_name, debuginfo=nsdfg_node.debuginfo) last_state.add_edge(access_node, None, nsdfg_node, inner_name, memlet) for inner_name, memlet in output_mapping.items(): - access_node = last_state.add_access(inner_name) + access_node = last_state.add_access(inner_name, debuginfo=nsdfg_node.debuginfo) last_state.add_edge(nsdfg_node, inner_name, access_node, None, memlet) # Create the call signature for the SDFG. @@ -223,12 +264,13 @@ def visit_StencilClosure( # Create the closure's nested SDFG and single state. 
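A change that recurs through this and the following hunks is threading `debuginfo` into every constructed SDFG element, so the generated graph can be mapped back to source locations. A small self-contained sketch of the DaCe calls involved, assuming a trivial SDFG (the closure construction continues below):

    import dace

    sdfg = dace.SDFG("debuginfo_example")
    sdfg.add_array("A", shape=(10,), dtype=dace.float64)
    state = sdfg.add_state("entry", True)
    di = dace.dtypes.DebugInfo(start_line=1, end_line=1, filename="example.py")
    # Access nodes, tasklets, and nested-SDFG nodes all accept a debuginfo
    # argument, which ends up in the generated code's source maps.
    access_node = state.add_access("A", debuginfo=di)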
closure_sdfg = dace.SDFG(name="closure") + closure_sdfg.debuginfo = dace_debuginfo(node) closure_state = closure_sdfg.add_state("closure_entry") closure_init_state = closure_sdfg.add_state_before(closure_state, "closure_init", True) input_names = [str(inp.id) for inp in node.inputs] neighbor_tables = filter_neighbor_tables(self.offset_provider) - connectivity_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables] + connectivity_names = [connectivity_identifier(offset) for offset in neighbor_tables.keys()] output_nodes = self.get_output_nodes(node, closure_sdfg, closure_state) output_names = [k for k, _ in output_nodes.items()] @@ -246,12 +288,11 @@ def visit_StencilClosure( shape=array_table[name].shape, strides=array_table[name].strides, dtype=array_table[name].dtype, - storage=array_table[name].storage, transient=True, ) closure_init_state.add_nedge( - closure_init_state.add_access(name), - closure_init_state.add_access(transient_name), + closure_init_state.add_access(name, debuginfo=closure_sdfg.debuginfo), + closure_init_state.add_access(transient_name, debuginfo=closure_sdfg.debuginfo), create_memlet_full(name, closure_sdfg.arrays[name]), ) input_transients_mapping[name] = transient_name @@ -261,7 +302,6 @@ def visit_StencilClosure( shape=array_table[name].shape, strides=array_table[name].strides, dtype=array_table[name].dtype, - storage=array_table[name].storage, ) else: assert isinstance(self.storage_types[name], ts.ScalarType) @@ -288,9 +328,15 @@ def visit_StencilClosure( out_name = unique_var_name() closure_sdfg.add_scalar(out_name, dtype, transient=True) out_tasklet = closure_init_state.add_tasklet( - f"get_{name}", {}, {"__result"}, f"__result = {name}" + f"get_{name}", + {}, + {"__result"}, + f"__result = {name}", + debuginfo=closure_sdfg.debuginfo, + ) + access = closure_init_state.add_access( + out_name, debuginfo=closure_sdfg.debuginfo ) - access = closure_init_state.add_access(out_name) value = ValueExpr(access, dtype) memlet = dace.Memlet.simple(out_name, "0") closure_init_state.add_edge(out_tasklet, "__result", access, None, memlet) @@ -368,19 +414,20 @@ def visit_StencilClosure( outputs=output_mapping, symbol_mapping=symbol_mapping, output_nodes=output_nodes, + debuginfo=nsdfg.debuginfo, ) access_nodes = {edge.data.data: edge.dst for edge in closure_state.out_edges(map_exit)} for edge in closure_state.in_edges(map_exit): memlet = edge.data if memlet.data not in output_connectors_mapping: continue - transient_access = closure_state.add_access(memlet.data) + transient_access = closure_state.add_access(memlet.data, debuginfo=nsdfg.debuginfo) closure_state.add_edge( nsdfg_node, edge.src_conn, transient_access, None, - dace.Memlet.simple(memlet.data, output_subset), + dace.Memlet.simple(memlet.data, output_subset, debuginfo=nsdfg.debuginfo), ) inner_memlet = dace.Memlet.simple( memlet.data, output_subset, other_subset_str=memlet.subset @@ -401,11 +448,11 @@ def _visit_scan_stencil_closure( output_name: str, ) -> tuple[dace.SDFG, dict[str, str | dace.subsets.Subset], int]: # extract scan arguments - is_forward, init_carry_value = get_scan_args(node.stencil) + is_forward, init_carry_value = _get_scan_args(node.stencil) # select the scan dimension based on program argument for column axis assert self.column_axis assert isinstance(node.output, SymRef) - scan_dim, scan_dim_index, scan_dtype = get_scan_dim( + scan_dim, scan_dim_index, scan_dtype = _get_scan_dim( self.column_axis, self.storage_types, node.output, @@ -429,41 +476,56 @@ def 
_visit_scan_stencil_closure( # the scan operator is implemented as an SDFG to be nested in the closure SDFG scan_sdfg = dace.SDFG(name="scan") - - # create a state machine for lambda call over the scan dimension - start_state = scan_sdfg.add_state("start", True) - lambda_state = scan_sdfg.add_state("lambda_compute") - end_state = scan_sdfg.add_state("end") + scan_sdfg.debuginfo = dace_debuginfo(node) # the carry value of the scan operator exists only in the scope of the scan sdfg scan_carry_name = unique_var_name() scan_sdfg.add_scalar(scan_carry_name, dtype=as_dace_type(scan_dtype), transient=True) + # create a loop region for lambda call over the scan dimension + scan_loop_var = f"i_{scan_dim}" + if is_forward: + scan_loop = LoopRegion( + label="scan", + condition_expr=f"{scan_loop_var} < {scan_ub_str}", + loop_var=scan_loop_var, + initialize_expr=f"{scan_loop_var} = {scan_lb_str}", + update_expr=f"{scan_loop_var} = {scan_loop_var} + 1", + inverted=False, + ) + else: + scan_loop = LoopRegion( + label="scan", + condition_expr=f"{scan_loop_var} >= {scan_lb_str}", + loop_var=scan_loop_var, + initialize_expr=f"{scan_loop_var} = {scan_ub_str} - 1", + update_expr=f"{scan_loop_var} = {scan_loop_var} - 1", + inverted=False, + ) + scan_sdfg.add_node(scan_loop) + compute_state = scan_loop.add_state("lambda_compute", is_start_block=True) + update_state = scan_loop.add_state("lambda_update") + scan_loop.add_edge(compute_state, update_state, dace.InterstateEdge()) + + start_state = scan_sdfg.add_state("start", is_start_block=True) + scan_sdfg.add_edge(start_state, scan_loop, dace.InterstateEdge()) + # tasklet for initialization of carry carry_init_tasklet = start_state.add_tasklet( - "get_carry_init_value", {}, {"__result"}, f"__result = {init_carry_value}" + "get_carry_init_value", + {}, + {"__result"}, + f"__result = {init_carry_value}", + debuginfo=scan_sdfg.debuginfo, ) start_state.add_edge( carry_init_tasklet, "__result", - start_state.add_access(scan_carry_name), + start_state.add_access(scan_carry_name, debuginfo=scan_sdfg.debuginfo), None, dace.Memlet.simple(scan_carry_name, "0"), ) - # TODO(edopao): replace state machine with dace loop construct - scan_sdfg.add_loop( - start_state, - lambda_state, - end_state, - loop_var=f"i_{scan_dim}", - initialize_expr=f"{scan_lb_str}" if is_forward else f"{scan_ub_str} - 1", - condition_expr=f"i_{scan_dim} < {scan_ub_str}" - if is_forward - else f"i_{scan_dim} >= {scan_lb_str}", - increment_expr=f"i_{scan_dim} + 1" if is_forward else f"i_{scan_dim} - 1", - ) - # add storage to scan SDFG for inputs for name in [*input_names, *connectivity_names]: assert name not in scan_sdfg.arrays @@ -518,37 +580,36 @@ def _visit_scan_stencil_closure( array_mapping = {**input_mapping, **connectivity_mapping} symbol_mapping = map_nested_sdfg_symbols(scan_sdfg, lambda_context.body, array_mapping) - scan_inner_node = lambda_state.add_nested_sdfg( + scan_inner_node = compute_state.add_nested_sdfg( lambda_context.body, parent=scan_sdfg, inputs=set(lambda_input_names) | set(connectivity_names), outputs=set(lambda_output_names), symbol_mapping=symbol_mapping, + debuginfo=lambda_context.body.debuginfo, ) # connect scan SDFG to lambda inputs for name, memlet in array_mapping.items(): - access_node = lambda_state.add_access(name) - lambda_state.add_edge(access_node, None, scan_inner_node, name, memlet) + access_node = compute_state.add_access(name, debuginfo=lambda_context.body.debuginfo) + compute_state.add_edge(access_node, None, scan_inner_node, name, memlet) output_names = 
[output_name] assert len(lambda_output_names) == 1 # connect lambda output to scan SDFG for name, connector in zip(output_names, lambda_output_names): - lambda_state.add_edge( + compute_state.add_edge( scan_inner_node, connector, - lambda_state.add_access(name), + compute_state.add_access(name, debuginfo=lambda_context.body.debuginfo), None, - dace.Memlet.simple(name, f"i_{scan_dim}"), + dace.Memlet.simple(name, scan_loop_var), ) - # add state to scan SDFG to update the carry value at each loop iteration - lambda_update_state = scan_sdfg.add_state_after(lambda_state, "lambda_update") - lambda_update_state.add_memlet_path( - lambda_update_state.add_access(output_name), - lambda_update_state.add_access(scan_carry_name), - memlet=dace.Memlet.simple(output_names[0], f"i_{scan_dim}", other_subset_str="0"), + update_state.add_nedge( + update_state.add_access(output_name, debuginfo=lambda_context.body.debuginfo), + update_state.add_access(scan_carry_name, debuginfo=lambda_context.body.debuginfo), + dace.Memlet.simple(output_names[0], scan_loop_var, other_subset_str="0"), ) return scan_sdfg, map_ranges, scan_dim_index @@ -563,7 +624,7 @@ def _visit_parallel_stencil_closure( ) -> tuple[dace.SDFG, dict[str, str | dace.subsets.Subset], list[str]]: neighbor_tables = filter_neighbor_tables(self.offset_provider) input_names = [str(inp.id) for inp in node.inputs] - conn_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables] + connectivity_names = [connectivity_identifier(offset) for offset in neighbor_tables.keys()] # find the scan dimension, same as output dimension, and exclude it from the map domain map_ranges = {} @@ -576,7 +637,7 @@ def _visit_parallel_stencil_closure( index_domain = {dim: f"i_{dim}" for dim, _ in closure_domain} input_arrays = [(name, self.storage_types[name]) for name in input_names] - connectivity_arrays = [(array_table[name], name) for name in conn_names] + connectivity_arrays = [(array_table[name], name) for name in connectivity_names] context, results = closure_to_tasklet_sdfg( node, diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py index 4c202b1fe8..322a147382 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/itir_to_tasklet.py @@ -35,6 +35,7 @@ connectivity_identifier, create_memlet_at, create_memlet_full, + dace_debuginfo, filter_neighbor_tables, flatten_list, map_nested_sdfg_symbols, @@ -183,6 +184,7 @@ def __init__( def builtin_neighbors( transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] ) -> list[ValueExpr]: + di = dace_debuginfo(node, transformer.context.body.debuginfo) offset_literal, data = node_args assert isinstance(offset_literal, itir.OffsetLiteral) offset_dim = offset_literal.value @@ -214,13 +216,14 @@ def builtin_neighbors( sdfg.add_array( result_name, dtype=iterator.dtype, shape=(offset_provider.max_neighbors,), transient=True ) - result_access = state.add_access(result_name) + result_access = state.add_access(result_name, debuginfo=di) # generate unique map index name to avoid conflict with other maps inside same state neighbor_index = unique_name("neighbor_idx") me, mx = state.add_map( f"{offset_dim}_neighbors_map", ndrange={neighbor_index: f"0:{offset_provider.max_neighbors}"}, + debuginfo=di, ) table_name = connectivity_identifier(offset_dim) table_subset = 
(f"0:{sdfg.arrays[table_name].shape[0]}", neighbor_index) @@ -230,17 +233,24 @@ def builtin_neighbors( code="__result = __table[__idx]", inputs={"__table", "__idx"}, outputs={"__result"}, + debuginfo=di, ) data_access_tasklet = state.add_tasklet( "data_access", - code=f"__result = __field[{field_index}] if {neighbor_check} else {transformer.context.reduce_identity.value}", + code=f"__result = __field[{field_index}]" + + ( + f" if {neighbor_check} else {transformer.context.reduce_identity.value}" + if offset_provider.has_skip_values + else "" + ), inputs={"__field", field_index}, outputs={"__result"}, + debuginfo=di, ) idx_name = unique_var_name() sdfg.add_scalar(idx_name, _INDEX_DTYPE, transient=True) state.add_memlet_path( - state.add_access(table_name), + state.add_access(table_name, debuginfo=di), me, shift_tasklet, memlet=create_memlet_at(table_name, table_subset), @@ -250,7 +260,7 @@ def builtin_neighbors( iterator.indices[shifted_dim], me, shift_tasklet, - memlet=dace.Memlet.simple(iterator.indices[shifted_dim].data, "0"), + memlet=dace.Memlet.simple(iterator.indices[shifted_dim].data, "0", debuginfo=di), dst_conn="__idx", ) state.add_edge(shift_tasklet, "__result", data_access_tasklet, field_index, dace.Memlet()) @@ -270,7 +280,7 @@ def builtin_neighbors( data_access_tasklet, mx, result_access, - memlet=dace.Memlet.simple(result_name, neighbor_index), + memlet=dace.Memlet.simple(result_name, neighbor_index, debuginfo=di), src_conn="__result", ) @@ -280,6 +290,7 @@ def builtin_neighbors( def builtin_can_deref( transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] ) -> list[ValueExpr]: + di = dace_debuginfo(node, transformer.context.body.debuginfo) # first visit shift, to get set of indices for deref can_deref_callable = node_args[0] assert isinstance(can_deref_callable, itir.FunCall) @@ -296,13 +307,15 @@ def builtin_can_deref( # Returning a SymbolExpr would be preferable, but it requires update to type-checking. 
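In `builtin_neighbors` earlier in this hunk, the guard against skip values is appended to the tasklet code only when the connectivity actually has skip values. A sketch of that conditional code generation, with `__idx >= 0` standing in for the real neighbor-validity check:

    def data_access_code(has_skip_values: bool, reduce_identity: str = "0") -> str:
        # Without skip values the element is read unconditionally; with skip
        # values, invalid neighbors fall back to the reduction identity.
        code = "__result = __field[__idx]"
        if has_skip_values:
            code += f" if __idx >= 0 else {reduce_identity}"
        return code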
result_name = unique_var_name() transformer.context.body.add_scalar(result_name, dace.dtypes.bool, transient=True) - result_node = transformer.context.state.add_access(result_name) + result_node = transformer.context.state.add_access(result_name, debuginfo=di) transformer.context.state.add_edge( - transformer.context.state.add_tasklet("can_always_deref", {}, {"_out"}, "_out = True"), + transformer.context.state.add_tasklet( + "can_always_deref", {}, {"_out"}, "_out = True", debuginfo=di + ), "_out", result_node, None, - dace.Memlet.simple(result_name, "0"), + dace.Memlet.simple(result_name, "0", debuginfo=di), ) return [ValueExpr(result_node, dace.dtypes.bool)] @@ -313,13 +326,18 @@ def builtin_can_deref( # TODO(edopao): select-memlet could maybe allow to efficiently translate can_deref to predicative execution return transformer.add_expr_tasklet( - list(zip(args, internals)), expr_code, dace.dtypes.bool, "can_deref" + list(zip(args, internals)), + expr_code, + dace.dtypes.bool, + "can_deref", + dace_debuginfo=di, ) def builtin_if( transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] ) -> list[ValueExpr]: + di = dace_debuginfo(node, transformer.context.body.debuginfo) args = transformer.visit(node_args) assert len(args) == 3 if_node = args[0][0] if isinstance(args[0], list) else args[0] @@ -346,7 +364,7 @@ def builtin_if( for arg in (if_node, a, b) ] expr = "({1} if {0} else {2})".format(*internals) - if_expr = transformer.add_expr_tasklet(expr_args, expr, a.dtype, "if") + if_expr = transformer.add_expr_tasklet(expr_args, expr, a.dtype, "if", dace_debuginfo=di) if_expr_values.append(if_expr[0]) return if_expr_values @@ -355,26 +373,35 @@ def builtin_if( def builtin_list_get( transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] ) -> list[ValueExpr]: + di = dace_debuginfo(node, transformer.context.body.debuginfo) args = list(itertools.chain(*transformer.visit(node_args))) assert len(args) == 2 # index node - assert isinstance(args[0], (SymbolExpr, ValueExpr)) - # 1D-array node - assert isinstance(args[1], ValueExpr) - # source node should be a 1D array - assert len(transformer.context.body.arrays[args[1].value.data].shape) == 1 - - expr_args = [(arg, f"{arg.value.data}_v") for arg in args if not isinstance(arg, SymbolExpr)] - internals = [ - arg.value if isinstance(arg, SymbolExpr) else f"{arg.value.data}_v" for arg in args - ] - expr = f"{internals[1]}[{internals[0]}]" - return transformer.add_expr_tasklet(expr_args, expr, args[1].dtype, "list_get") + if isinstance(args[0], SymbolExpr): + index_value = args[0].value + result_name = unique_var_name() + transformer.context.body.add_scalar(result_name, args[1].dtype, transient=True) + result_node = transformer.context.state.add_access(result_name) + transformer.context.state.add_nedge( + args[1].value, + result_node, + dace.Memlet.simple(args[1].value.data, index_value), + ) + return [ValueExpr(result_node, args[1].dtype)] + + else: + expr_args = [(arg, f"{arg.value.data}_v") for arg in args] + internals = [f"{arg.value.data}_v" for arg in args] + expr = f"{internals[1]}[{internals[0]}]" + return transformer.add_expr_tasklet( + expr_args, expr, args[1].dtype, "list_get", dace_debuginfo=di + ) def builtin_cast( transformer: "PythonTaskletCodegen", node: itir.Expr, node_args: list[itir.Expr] ) -> list[ValueExpr]: + di = dace_debuginfo(node, transformer.context.body.debuginfo) args = transformer.visit(node_args[0]) internals = [f"{arg.value.data}_v" for arg in args] target_type = 
node_args[1] @@ -383,7 +410,13 @@ def builtin_cast( node_type = transformer.node_types[id(node)] assert isinstance(node_type, itir_typing.Val) type_ = itir_type_as_dace_type(node_type.dtype) - return transformer.add_expr_tasklet(list(zip(args, internals)), expr, type_, "cast") + return transformer.add_expr_tasklet( + list(zip(args, internals)), + expr, + type_, + "cast", + dace_debuginfo=di, + ) def builtin_make_tuple( @@ -443,7 +476,9 @@ def _add_symbol(self, param, arg): # create storage in lambda sdfg self._sdfg.add_scalar(param, dtype=arg.dtype) # update table of lambda symbol - self._symbol_map[param] = ValueExpr(self._state.add_access(param), arg.dtype) + self._symbol_map[param] = ValueExpr( + self._state.add_access(param, debuginfo=self._sdfg.debuginfo), arg.dtype + ) elif isinstance(arg, IteratorExpr): # create storage in lambda sdfg ndims = len(arg.dimensions) @@ -453,9 +488,10 @@ def _add_symbol(self, param, arg): for _, index_name in index_names.items(): self._sdfg.add_scalar(index_name, dtype=_INDEX_DTYPE) # update table of lambda symbol - field = self._state.add_access(param) + field = self._state.add_access(param, debuginfo=self._sdfg.debuginfo) indices = { - dim: self._state.add_access(index_arg) for dim, index_arg in index_names.items() + dim: self._state.add_access(index_arg, debuginfo=self._sdfg.debuginfo) + for dim, index_arg in index_names.items() } self._symbol_map[param] = IteratorExpr(field, indices, arg.dtype, arg.dimensions) else: @@ -503,7 +539,7 @@ def visit_SymRef(self, node: itir.SymRef): if param not in _GENERAL_BUILTIN_MAPPING and param not in self._symbol_map: node_type = self._node_types[id(node)] assert isinstance(node_type, Val) - access_node = self._state.add_access(param) + access_node = self._state.add_access(param, debuginfo=self._sdfg.debuginfo) self._symbol_map[param] = ValueExpr( access_node, dtype=itir_type_as_dace_type(node_type.dtype) ) @@ -536,12 +572,13 @@ def visit_Lambda( ]: func_name = f"lambda_{abs(hash(node)):x}" neighbor_tables = ( - filter_neighbor_tables(self.offset_provider) if use_neighbor_tables else [] + filter_neighbor_tables(self.offset_provider) if use_neighbor_tables else {} ) - connectivity_names = [connectivity_identifier(offset) for offset, _ in neighbor_tables] + connectivity_names = [connectivity_identifier(offset) for offset in neighbor_tables.keys()] # Create the SDFG for the lambda's body lambda_sdfg = dace.SDFG(func_name) + lambda_sdfg.debuginfo = dace_debuginfo(node) lambda_state = lambda_sdfg.add_state(f"{func_name}_entry", True) lambda_symbols_pass = GatherLambdaSymbolsPass( @@ -586,11 +623,14 @@ def visit_Lambda( results: list[ValueExpr] = [] # We are flattening the returned list of value expressions because the multiple outputs of a lambda # should be a list of nodes without tuple structure. Ideally, an ITIR transformation could do this. 
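The comment above motivates the flattening: a lambda may yield nested lists of value expressions, while downstream code expects a flat list of nodes. The helper used here is imported from `utility.py`; a plausible minimal implementation would be:

    from typing import Any

    def flatten_list(nodes: list[Any]) -> list[Any]:
        # Recursively splice nested lists into one flat list of nodes.
        flat: list[Any] = []
        for node in nodes:
            if isinstance(node, list):
                flat.extend(flatten_list(node))
            else:
                flat.append(node)
        return flat

    assert flatten_list([[1, [2]], 3]) == [1, 2, 3]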
+ node.expr.location = node.location for expr in flatten_list(lambda_taskgen.visit(node.expr)): if isinstance(expr, ValueExpr): result_name = unique_var_name() lambda_sdfg.add_scalar(result_name, expr.dtype, transient=True) - result_access = lambda_state.add_access(result_name) + result_access = lambda_state.add_access( + result_name, debuginfo=lambda_sdfg.debuginfo + ) lambda_state.add_nedge( expr.value, result_access, @@ -599,7 +639,9 @@ def visit_Lambda( result = ValueExpr(value=result_access, dtype=expr.dtype) else: # Forwarding result through a tasklet needed because empty SDFG states don't properly forward connectors - result = lambda_taskgen.add_expr_tasklet([], expr.value, expr.dtype, "forward")[0] + result = lambda_taskgen.add_expr_tasklet( + [], expr.value, expr.dtype, "forward", dace_debuginfo=lambda_sdfg.debuginfo + )[0] lambda_sdfg.arrays[result.value.data].transient = False results.append(result) @@ -624,6 +666,7 @@ def visit_Literal(self, node: itir.Literal) -> list[SymbolExpr]: return [SymbolExpr(node.value, itir_type_as_dace_type(node_type.dtype))] def visit_FunCall(self, node: itir.FunCall) -> list[ValueExpr] | IteratorExpr: + node.fun.location = node.location if isinstance(node.fun, itir.SymRef) and node.fun.id == "deref": return self._visit_deref(node) if isinstance(node.fun, itir.FunCall) and isinstance(node.fun.fun, itir.SymRef): @@ -646,7 +689,7 @@ def _visit_call(self, node: itir.FunCall): args = self.visit(node.args) args = [arg if isinstance(arg, Sequence) else [arg] for arg in args] args = list(itertools.chain(*args)) - + node.fun.location = node.location func_context, func_inputs, results = self.visit(node.fun, args=args) nsdfg_inputs = {} @@ -667,8 +710,8 @@ def _visit_call(self, node: itir.FunCall): nsdfg_inputs[var] = create_memlet_full(store, self.context.body.arrays[store]) neighbor_tables = filter_neighbor_tables(self.offset_provider) - for conn, _ in neighbor_tables: - var = connectivity_identifier(conn) + for offset in neighbor_tables.keys(): + var = connectivity_identifier(offset) nsdfg_inputs[var] = create_memlet_full(var, self.context.body.arrays[var]) symbol_mapping = map_nested_sdfg_symbols(self.context.body, func_context.body, nsdfg_inputs) @@ -679,6 +722,7 @@ def _visit_call(self, node: itir.FunCall): inputs=set(nsdfg_inputs.keys()), outputs=set(r.value.data for r in results), symbol_mapping=symbol_mapping, + debuginfo=dace_debuginfo(node, func_context.body.debuginfo), ) for name, value in func_inputs: @@ -695,17 +739,17 @@ def _visit_call(self, node: itir.FunCall): store = value.indices[dim] idx_memlet = nsdfg_inputs[var] self.context.state.add_edge(store, None, nsdfg_node, var, idx_memlet) - for conn, _ in neighbor_tables: - var = connectivity_identifier(conn) + for offset in neighbor_tables.keys(): + var = connectivity_identifier(offset) memlet = nsdfg_inputs[var] - access = self.context.state.add_access(var) + access = self.context.state.add_access(var, debuginfo=nsdfg_node.debuginfo) self.context.state.add_edge(access, None, nsdfg_node, var, memlet) result_exprs = [] for result in results: name = unique_var_name() self.context.body.add_scalar(name, result.dtype, transient=True) - result_access = self.context.state.add_access(name) + result_access = self.context.state.add_access(name, debuginfo=nsdfg_node.debuginfo) result_exprs.append(ValueExpr(result_access, result.dtype)) memlet = create_memlet_full(name, self.context.body.arrays[name]) self.context.state.add_edge(nsdfg_node, result.value.data, result_access, None, memlet) @@ -713,6 +757,7 
@@ def _visit_call(self, node: itir.FunCall): return result_exprs def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: + di = dace_debuginfo(node, self.context.body.debuginfo) iterator = self.visit(node.args[0]) if not isinstance(iterator, IteratorExpr): # already a list of ValueExpr @@ -727,7 +772,13 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: ] internals = [f"{arg.value.data}_v" for arg in args] expr = f"{internals[0]}[{', '.join(internals[1:])}]" - return self.add_expr_tasklet(list(zip(args, internals)), expr, iterator.dtype, "deref") + return self.add_expr_tasklet( + list(zip(args, internals)), + expr, + iterator.dtype, + "deref", + dace_debuginfo=di, + ) else: # Not all dimensions are included in the deref index list: @@ -741,7 +792,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: result_name = unique_var_name() self.context.body.add_array(result_name, result_shape, iterator.dtype, transient=True) result_array = self.context.body.arrays[result_name] - result_node = self.context.state.add_access(result_name) + result_node = self.context.state.add_access(result_name, debuginfo=di) deref_connectors = ["_inp"] + [ f"_i_{dim}" for dim in sorted_dims if dim in iterator.indices @@ -776,6 +827,7 @@ def _visit_deref(self, node: itir.FunCall) -> list[ValueExpr]: output_nodes={ result_name: result_node, }, + debuginfo=di, ) return [ValueExpr(result_node, iterator.dtype)] @@ -789,10 +841,13 @@ def _split_shift_args( def _make_shift_for_rest(self, rest, iterator): return itir.FunCall( - fun=itir.FunCall(fun=itir.SymRef(id="shift"), args=rest), args=[iterator] + fun=itir.FunCall(fun=itir.SymRef(id="shift"), args=rest), + args=[iterator], + location=iterator.location, ) def _visit_shift(self, node: itir.FunCall) -> IteratorExpr | list[ValueExpr]: + di = dace_debuginfo(node, self.context.body.debuginfo) shift = node.fun assert isinstance(shift, itir.FunCall) tail, rest = self._split_shift_args(shift.args) @@ -815,7 +870,9 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr | list[ValueExpr]: if isinstance(self.offset_provider[offset_dim], NeighborTableOffsetProvider): offset_provider = self.offset_provider[offset_dim] - connectivity = self.context.state.add_access(connectivity_identifier(offset_dim)) + connectivity = self.context.state.add_access( + connectivity_identifier(offset_dim), debuginfo=di + ) shifted_dim = offset_provider.origin_axis.value target_dim = offset_provider.neighbor_axis.value @@ -850,7 +907,7 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr | list[ValueExpr]: expr = f"{internals[0]} + {internals[1]}" shifted_value = self.add_expr_tasklet( - list(zip(args, internals)), expr, offset_node.dtype, "shift" + list(zip(args, internals)), expr, offset_node.dtype, "shift", dace_debuginfo=di )[0].value shifted_index = {dim: value for dim, value in iterator.indices.items()} @@ -860,13 +917,14 @@ def _visit_shift(self, node: itir.FunCall) -> IteratorExpr | list[ValueExpr]: return IteratorExpr(iterator.field, shifted_index, iterator.dtype, iterator.dimensions) def visit_OffsetLiteral(self, node: itir.OffsetLiteral) -> list[ValueExpr]: + di = dace_debuginfo(node, self.context.body.debuginfo) offset = node.value assert isinstance(offset, int) offset_var = unique_var_name() self.context.body.add_scalar(offset_var, _INDEX_DTYPE, transient=True) - offset_node = self.context.state.add_access(offset_var) + offset_node = self.context.state.add_access(offset_var, debuginfo=di) tasklet_node = self.context.state.add_tasklet( - 
"get_offset", {}, {"__out"}, f"__out = {offset}" + "get_offset", {}, {"__out"}, f"__out = {offset}", debuginfo=di ) self.context.state.add_edge( tasklet_node, "__out", offset_node, None, dace.Memlet.simple(offset_var, "0") @@ -874,6 +932,7 @@ def visit_OffsetLiteral(self, node: itir.OffsetLiteral) -> list[ValueExpr]: return [ValueExpr(offset_node, self.context.body.arrays[offset_var].dtype)] def _visit_reduce(self, node: itir.FunCall): + di = dace_debuginfo(node, self.context.body.debuginfo) node_type = self.node_types[id(node)] assert isinstance(node_type, itir_typing.Val) reduce_dtype = itir_type_as_dace_type(node_type.dtype) @@ -930,7 +989,9 @@ def _visit_reduce(self, node: itir.FunCall): reduce_input_name, nreduce_shape, reduce_dtype, transient=True ) - lambda_node = itir.Lambda(expr=fun_node.expr.args[1], params=fun_node.params[1:]) + lambda_node = itir.Lambda( + expr=fun_node.expr.args[1], params=fun_node.params[1:], location=node.location + ) lambda_context, inner_inputs, inner_outputs = self.visit( lambda_node, args=args, use_neighbor_tables=False ) @@ -946,7 +1007,7 @@ def _visit_reduce(self, node: itir.FunCall): self.context.body, lambda_context.body, input_mapping ) - reduce_input_node = self.context.state.add_access(reduce_input_name) + reduce_input_node = self.context.state.add_access(reduce_input_name, debuginfo=di) nsdfg_node, map_entry, _ = add_mapped_nested_sdfg( self.context.state, @@ -957,6 +1018,7 @@ def _visit_reduce(self, node: itir.FunCall): symbol_mapping=symbol_mapping, input_nodes={arg.value.data: arg.value for arg in args}, output_nodes={reduce_input_name: reduce_input_node}, + debuginfo=di, ) reduce_input_desc = reduce_input_node.desc(self.context.body) @@ -964,7 +1026,7 @@ def _visit_reduce(self, node: itir.FunCall): result_name = unique_var_name() # we allocate an array instead of a scalar because the reduce library node is generic and expects an array node self.context.body.add_array(result_name, (1,), reduce_dtype, transient=True) - result_access = self.context.state.add_access(result_name) + result_access = self.context.state.add_access(result_name, debuginfo=di) reduce_wcr = "lambda x, y: " + _MATH_BUILTINS_MAPPING[str(op_name)].format("x", "y") reduce_node = self.context.state.add_reduce(reduce_wcr, None, reduce_identity) @@ -997,7 +1059,13 @@ def _visit_numeric_builtin(self, node: itir.FunCall) -> list[ValueExpr]: node_type = self.node_types[id(node)] assert isinstance(node_type, itir_typing.Val) type_ = itir_type_as_dace_type(node_type.dtype) - return self.add_expr_tasklet(expr_args, expr, type_, "numeric") + return self.add_expr_tasklet( + expr_args, + expr, + type_, + "numeric", + dace_debuginfo=dace_debuginfo(node, self.context.body.debuginfo), + ) def _visit_general_builtin(self, node: itir.FunCall) -> list[ValueExpr]: assert isinstance(node.fun, itir.SymRef) @@ -1005,17 +1073,24 @@ def _visit_general_builtin(self, node: itir.FunCall) -> list[ValueExpr]: return expr_func(self, node, node.args) def add_expr_tasklet( - self, args: list[tuple[ValueExpr, str]], expr: str, result_type: Any, name: str + self, + args: list[tuple[ValueExpr, str]], + expr: str, + result_type: Any, + name: str, + dace_debuginfo: Optional[dace.dtypes.DebugInfo] = None, ) -> list[ValueExpr]: + di = dace_debuginfo if dace_debuginfo else self.context.body.debuginfo result_name = unique_var_name() self.context.body.add_scalar(result_name, result_type, transient=True) - result_access = self.context.state.add_access(result_name) + result_access = 
self.context.state.add_access(result_name, debuginfo=di) expr_tasklet = self.context.state.add_tasklet( name=name, inputs={internal for _, internal in args}, outputs={"__result"}, code=f"__result = {expr}", + debuginfo=di, ) for arg, internal in args: @@ -1033,7 +1108,7 @@ def add_expr_tasklet( ) self.context.state.add_edge(arg.value, None, expr_tasklet, internal, memlet) - memlet = dace.Memlet.simple(result_access.data, "0") + memlet = dace.Memlet.simple(result_access.data, "0", debuginfo=di) self.context.state.add_edge(expr_tasklet, "__result", result_access, None, memlet) return [ValueExpr(result_access, result_type)] @@ -1052,6 +1127,7 @@ def closure_to_tasklet_sdfg( node_types: dict[int, next_typing.Type], ) -> tuple[Context, Sequence[ValueExpr]]: body = dace.SDFG("tasklet_toplevel") + body.debuginfo = dace_debuginfo(node) state = body.add_state("tasklet_toplevel_entry", True) symbol_map: dict[str, TaskletExpr] = {} @@ -1059,8 +1135,10 @@ def closure_to_tasklet_sdfg( for dim, idx in domain.items(): name = f"{idx}_value" body.add_scalar(name, dtype=_INDEX_DTYPE, transient=True) - tasklet = state.add_tasklet(f"get_{dim}", set(), {"value"}, f"value = {idx}") - access = state.add_access(name) + tasklet = state.add_tasklet( + f"get_{dim}", set(), {"value"}, f"value = {idx}", debuginfo=body.debuginfo + ) + access = state.add_access(name, debuginfo=body.debuginfo) idx_accesses[dim] = access state.add_edge(tasklet, "value", access, None, dace.Memlet.simple(name, "0")) for name, ty in inputs: @@ -1070,14 +1148,14 @@ def closure_to_tasklet_sdfg( dims = [dim.value for dim in ty.dims] dtype = as_dace_type(ty.dtype) body.add_array(name, shape=shape, strides=strides, dtype=dtype) - field = state.add_access(name) + field = state.add_access(name, debuginfo=body.debuginfo) indices = {dim: idx_accesses[dim] for dim in domain.keys()} symbol_map[name] = IteratorExpr(field, indices, dtype, dims) else: assert isinstance(ty, ts.ScalarType) dtype = as_dace_type(ty) body.add_scalar(name, dtype=dtype) - symbol_map[name] = ValueExpr(state.add_access(name), dtype) + symbol_map[name] = ValueExpr(state.add_access(name, debuginfo=body.debuginfo), dtype) for arr, name in connectivities: shape, strides = new_array_symbols(name, ndim=2) body.add_array(name, shape=shape, strides=strides, dtype=arr.dtype) @@ -1089,10 +1167,12 @@ def closure_to_tasklet_sdfg( if is_scan(node.stencil): stencil = cast(FunCall, node.stencil) assert isinstance(stencil.args[0], Lambda) - lambda_node = itir.Lambda(expr=stencil.args[0].expr, params=stencil.args[0].params) - fun_node = itir.FunCall(fun=lambda_node, args=args) + lambda_node = itir.Lambda( + expr=stencil.args[0].expr, params=stencil.args[0].params, location=node.location + ) + fun_node = itir.FunCall(fun=lambda_node, args=args, location=node.location) else: - fun_node = itir.FunCall(fun=node.stencil, args=args) + fun_node = itir.FunCall(fun=node.stencil, args=args, location=node.location) results = translator.visit(fun_node) for r in results: diff --git a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py index 55717326a3..a66fc36b1b 100644 --- a/src/gt4py/next/program_processors/runners/dace_iterator/utility.py +++ b/src/gt4py/next/program_processors/runners/dace_iterator/utility.py @@ -12,15 +12,31 @@ # # SPDX-License-Identifier: GPL-3.0-or-later import itertools -from typing import Any, Sequence +from typing import Any, Optional, Sequence import dace from gt4py.next import Dimension -from 
gt4py.next.iterator.embedded import NeighborTableOffsetProvider +from gt4py.next.common import NeighborTable +from gt4py.next.iterator.ir import Node from gt4py.next.type_system import type_specifications as ts +def dace_debuginfo( + node: Node, debuginfo: Optional[dace.dtypes.DebugInfo] = None +) -> Optional[dace.dtypes.DebugInfo]: + location = node.location + if location: + return dace.dtypes.DebugInfo( + start_line=location.line, + start_column=location.column if location.column else 0, + end_line=location.end_line if location.end_line else -1, + end_column=location.end_column if location.end_column else 0, + filename=location.filename, + ) + return debuginfo + + def as_dace_type(type_: ts.ScalarType): if type_.kind == ts.ScalarKind.BOOL: return dace.bool_ @@ -36,11 +52,11 @@ def as_dace_type(type_: ts.ScalarType): def filter_neighbor_tables(offset_provider: dict[str, Any]): - return [ - (offset, table) + return { + offset: table for offset, table in offset_provider.items() - if isinstance(table, NeighborTableOffsetProvider) - ] + if isinstance(table, NeighborTable) + } def connectivity_identifier(name: str): @@ -119,11 +135,13 @@ def add_mapped_nested_sdfg( if input_nodes is None: input_nodes = { - memlet.data: state.add_access(memlet.data) for name, memlet in inputs.items() + memlet.data: state.add_access(memlet.data, debuginfo=debuginfo) + for name, memlet in inputs.items() } if output_nodes is None: output_nodes = { - memlet.data: state.add_access(memlet.data) for name, memlet in outputs.items() + memlet.data: state.add_access(memlet.data, debuginfo=debuginfo) + for name, memlet in outputs.items() } if not inputs: state.add_edge(map_entry, None, nsdfg_node, None, dace.Memlet()) diff --git a/src/gt4py/next/type_system/type_translation.py b/src/gt4py/next/type_system/type_translation.py index 88a8347fe4..12649bf620 100644 --- a/src/gt4py/next/type_system/type_translation.py +++ b/src/gt4py/next/type_system/type_translation.py @@ -184,7 +184,7 @@ def from_value(value: Any) -> ts.TypeSpec: elif isinstance(value, common.Dimension): symbol_type = ts.DimensionType(dim=value) elif common.is_field(value): - dims = list(value.__gt_dims__) + dims = list(value.domain.dims) dtype = from_type_hint(value.dtype.scalar_type) symbol_type = ts.FieldType(dims=dims, dtype=dtype) elif isinstance(value, tuple): diff --git a/src/gt4py/storage/cartesian/utils.py b/src/gt4py/storage/cartesian/utils.py index 0f7cf5d0ab..4e7ebb0c21 100644 --- a/src/gt4py/storage/cartesian/utils.py +++ b/src/gt4py/storage/cartesian/utils.py @@ -192,6 +192,10 @@ def cpu_copy(array: Union[np.ndarray, "cp.ndarray"]) -> np.ndarray: def asarray( array: FieldLike, *, device: Literal["cpu", "gpu", None] = None ) -> np.ndarray | cp.ndarray: + if hasattr(array, "ndarray"): + # extract the buffer from a gt4py.next.Field + # TODO(havogt): probably `Field` should provide the array interface methods when applicable + array = array.ndarray if device == "gpu" or (not device and hasattr(array, "__cuda_array_interface__")): return cp.asarray(array) if device == "cpu" or ( diff --git a/tests/cartesian_tests/unit_tests/test_gtc/test_common.py b/tests/cartesian_tests/unit_tests/test_gtc/test_common.py index e580333bc8..8cfff12df4 100644 --- a/tests/cartesian_tests/unit_tests/test_gtc/test_common.py +++ b/tests/cartesian_tests/unit_tests/test_gtc/test_common.py @@ -312,7 +312,7 @@ def test_symbolref_validation_for_valid_tree(): SymbolTableRootNode( nodes=[SymbolChildNode(name="foo"), SymbolRefChildNode(name="foo")], ) - SymbolTableRootNode( + 
SymbolTableRootNode( # noqa: B018 nodes=[ SymbolChildNode(name="foo"), SymbolRefChildNode(name="foo"), diff --git a/tests/eve_tests/unit_tests/test_datamodels.py b/tests/eve_tests/unit_tests/test_datamodels.py index 8fa9e02cb6..0abb893dd4 100644 --- a/tests/eve_tests/unit_tests/test_datamodels.py +++ b/tests/eve_tests/unit_tests/test_datamodels.py @@ -15,6 +15,7 @@ from __future__ import annotations import enum +import numbers import types import typing from typing import Set # noqa: F401 # imported but unused (used in exec() context) @@ -1150,66 +1151,80 @@ class PartialGenericModel(datamodels.GenericDataModel, Generic[T]): with pytest.raises(TypeError, match="'PartialGenericModel__int.value'"): PartialGenericModel__int(value=["1"]) - def test_partial_specialization(self): - class PartialGenericModel(datamodels.GenericDataModel, Generic[T, U]): + def test_partial_concretization(self): + class BaseGenericModel(datamodels.GenericDataModel, Generic[T, U]): value: List[Tuple[T, U]] - PartialGenericModel(value=[]) - PartialGenericModel(value=[("value", 3)]) - PartialGenericModel(value=[(1, "value")]) - PartialGenericModel(value=[(-1.0, "value")]) - with pytest.raises(TypeError, match="'PartialGenericModel.value'"): - PartialGenericModel(value=1) - with pytest.raises(TypeError, match="'PartialGenericModel.value'"): - PartialGenericModel(value=(1, 2)) - with pytest.raises(TypeError, match="'PartialGenericModel.value'"): - PartialGenericModel(value=[()]) - with pytest.raises(TypeError, match="'PartialGenericModel.value'"): - PartialGenericModel(value=[(1,)]) + assert len(BaseGenericModel.__parameters__) == 2 + + BaseGenericModel(value=[]) + BaseGenericModel(value=[("value", 3)]) + BaseGenericModel(value=[(1, "value")]) + BaseGenericModel(value=[(-1.0, "value")]) + with pytest.raises(TypeError, match="'BaseGenericModel.value'"): + BaseGenericModel(value=1) + with pytest.raises(TypeError, match="'BaseGenericModel.value'"): + BaseGenericModel(value=(1, 2)) + with pytest.raises(TypeError, match="'BaseGenericModel.value'"): + BaseGenericModel(value=[()]) + with pytest.raises(TypeError, match="'BaseGenericModel.value'"): + BaseGenericModel(value=[(1,)]) + + PartiallyConcretizedGenericModel = BaseGenericModel[int, U] + + assert len(PartiallyConcretizedGenericModel.__parameters__) == 1 + + PartiallyConcretizedGenericModel(value=[]) + PartiallyConcretizedGenericModel(value=[(1, 2)]) + PartiallyConcretizedGenericModel(value=[(1, "value")]) + PartiallyConcretizedGenericModel(value=[(1, (11, 12))]) + with pytest.raises(TypeError, match=".value'"): + PartiallyConcretizedGenericModel(value=1) + with pytest.raises(TypeError, match=".value'"): + PartiallyConcretizedGenericModel(value=(1, 2)) + with pytest.raises(TypeError, match=".value'"): + PartiallyConcretizedGenericModel(value=[1.0]) + with pytest.raises(TypeError, match=".value'"): + PartiallyConcretizedGenericModel(value=["1"]) - print(f"{PartialGenericModel.__parameters__=}") - print(f"{hasattr(PartialGenericModel ,'__args__')=}") + FullyConcretizedGenericModel = PartiallyConcretizedGenericModel[str] - PartiallySpecializedGenericModel = PartialGenericModel[int, U] - print(f"{PartiallySpecializedGenericModel.__datamodel_fields__=}") - print(f"{PartiallySpecializedGenericModel.__parameters__=}") - print(f"{PartiallySpecializedGenericModel.__args__=}") + assert len(FullyConcretizedGenericModel.__parameters__) == 0 - PartiallySpecializedGenericModel(value=[]) - PartiallySpecializedGenericModel(value=[(1, 2)]) - PartiallySpecializedGenericModel(value=[(1, 
"value")]) - PartiallySpecializedGenericModel(value=[(1, (11, 12))]) + FullyConcretizedGenericModel(value=[]) + FullyConcretizedGenericModel(value=[(1, "value")]) + with pytest.raises(TypeError, match=".value'"): + FullyConcretizedGenericModel(value=1) + with pytest.raises(TypeError, match=".value'"): + FullyConcretizedGenericModel(value=(1, 2)) with pytest.raises(TypeError, match=".value'"): - PartiallySpecializedGenericModel(value=1) + FullyConcretizedGenericModel(value=[1.0]) with pytest.raises(TypeError, match=".value'"): - PartiallySpecializedGenericModel(value=(1, 2)) + FullyConcretizedGenericModel(value=["1"]) with pytest.raises(TypeError, match=".value'"): - PartiallySpecializedGenericModel(value=[1.0]) + FullyConcretizedGenericModel(value=1) with pytest.raises(TypeError, match=".value'"): - PartiallySpecializedGenericModel(value=["1"]) - - # TODO(egparedes): after fixing partial nested datamodel specialization - # noqa: e800 FullySpecializedGenericModel = PartiallySpecializedGenericModel[str] - # noqa: e800 print(f"{FullySpecializedGenericModel.__datamodel_fields__=}") - # noqa: e800 print(f"{FullySpecializedGenericModel.__parameters__=}") - # noqa: e800 print(f"{FullySpecializedGenericModel.__args__=}") - - # noqa: e800 FullySpecializedGenericModel(value=[]) - # noqa: e800 FullySpecializedGenericModel(value=[(1, "value")]) - # noqa: e800 with pytest.raises(TypeError, match=".value'"): - # noqa: e800 FullySpecializedGenericModel(value=1) - # noqa: e800 with pytest.raises(TypeError, match=".value'"): - # noqa: e800 FullySpecializedGenericModel(value=(1, 2)) - # noqa: e800 with pytest.raises(TypeError, match=".value'"): - # noqa: e800 FullySpecializedGenericModel(value=[1.0]) - # noqa: e800 with pytest.raises(TypeError, match=".value'"): - # noqa: e800 FullySpecializedGenericModel(value=["1"]) - # noqa: e800 with pytest.raises(TypeError, match=".value'"): - # noqa: e800 FullySpecializedGenericModel(value=1) - # noqa: e800 with pytest.raises(TypeError, match=".value'"): - # noqa: e800 FullySpecializedGenericModel(value=[(1, 2)]) - # noqa: e800 with pytest.raises(TypeError, match=".value'"): - # noqa: e800 FullySpecializedGenericModel(value=[(1, (11, 12))]) + FullyConcretizedGenericModel(value=[(1, 2)]) + with pytest.raises(TypeError, match=".value'"): + FullyConcretizedGenericModel(value=[(1, (11, 12))]) + + def test_partial_concretization_with_typevar(self): + class PartialGenericModel(datamodels.GenericDataModel, Generic[T]): + a: T + values: List[T] + + B = TypeVar("B", bound=numbers.Number) + PartiallyConcretizedGenericModel = PartialGenericModel[B] + + PartiallyConcretizedGenericModel(a=1, values=[2, 3]) + PartiallyConcretizedGenericModel(a=-1.32, values=[2.2, 3j]) + + with pytest.raises(TypeError, match=".a'"): + PartiallyConcretizedGenericModel(a="1", values=[2, 3]) + with pytest.raises(TypeError, match=".values'"): + PartiallyConcretizedGenericModel(a=1, values=[1, "2"]) + with pytest.raises(TypeError, match=".values'"): + PartiallyConcretizedGenericModel(a=1, values=(1, 2)) # Reuse sample_type_data from test_field_type_hint @pytest.mark.parametrize(["type_hint", "valid_values", "wrong_values"], SAMPLE_TYPE_DATA) diff --git a/tests/eve_tests/unit_tests/test_type_validation.py b/tests/eve_tests/unit_tests/test_type_validation.py index 70ef033ff0..d9977f0d3a 100644 --- a/tests/eve_tests/unit_tests/test_type_validation.py +++ b/tests/eve_tests/unit_tests/test_type_validation.py @@ -28,6 +28,7 @@ ) from gt4py.eve.extended_typing import ( Any, + Callable, Dict, Final, ForwardRef, 
@@ -41,8 +42,8 @@
 )


-VALIDATORS: Final = [type_val.simple_type_validator]
-FACTORIES: Final = [type_val.simple_type_validator_factory]
+VALIDATORS: Final[list[Callable]] = [type_val.simple_type_validator]
+FACTORIES: Final[list[Callable]] = [type_val.simple_type_validator_factory]


 class SampleEnum(enum.Enum):
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_bound_args.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_bound_args.py
new file mode 100644
index 0000000000..0de953d85f
--- /dev/null
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_bound_args.py
@@ -0,0 +1,64 @@
+# -*- coding: utf-8 -*-
+# GT4Py - GridTools Framework
+#
+# Copyright (c) 2014-2023, ETH Zurich
+# All rights reserved.
+#
+# This file is part of the GT4Py project and the GridTools framework.
+# GT4Py is free software: you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or any later
+# version. See the LICENSE.txt file at the top-level directory of this
+# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import numpy as np
+
+import gt4py.next as gtx
+from gt4py.next import int32
+
+from next_tests.integration_tests import cases
+from next_tests.integration_tests.cases import cartesian_case
+from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import (
+    fieldview_backend,
+    reduction_setup,
+)
+
+
+def test_with_bound_args(cartesian_case):
+    @gtx.field_operator
+    def fieldop_bound_args(a: cases.IField, scalar: int32, condition: bool) -> cases.IField:
+        if not condition:
+            scalar = 0
+        return a + scalar
+
+    @gtx.program
+    def program_bound_args(a: cases.IField, scalar: int32, condition: bool, out: cases.IField):
+        fieldop_bound_args(a, scalar, condition, out=out)
+
+    a = cases.allocate(cartesian_case, program_bound_args, "a")()
+    scalar = int32(1)
+    ref = a + scalar
+    out = cases.allocate(cartesian_case, program_bound_args, "out")()
+
+    prog_bounds = program_bound_args.with_bound_args(scalar=scalar, condition=True)
+    cases.verify(cartesian_case, prog_bounds, a, out, inout=out, ref=ref)
+
+
+def test_with_bound_args_order_args(cartesian_case):
+    @gtx.field_operator
+    def fieldop_args(a: cases.IField, condition: bool, scalar: int32) -> cases.IField:
+        scalar = 0 if not condition else scalar
+        return a + scalar
+
+    @gtx.program(backend=cartesian_case.backend)
+    def program_args(a: cases.IField, condition: bool, scalar: int32, out: cases.IField):
+        fieldop_args(a, condition, scalar, out=out)
+
+    a = cases.allocate(cartesian_case, program_args, "a")()
+    out = cases.allocate(cartesian_case, program_args, "out")()
+
+    prog_bounds = program_args.with_bound_args(condition=True)
+    prog_bounds(a=a, scalar=int32(1), out=out, offset_provider={})
+    assert np.allclose(out.asnumpy(), a.asnumpy() + int32(1))
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py
index 992f13ea67..ca2e7c2932 100644
--- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_execution.py
@@ -900,26 +900,6 @@ def test_docstring(a: cases.IField):
     cases.verify(cartesian_case, test_docstring, a, inout=a, ref=a)


-def test_with_bound_args(cartesian_case):
-    @gtx.field_operator
-    def fieldop_bound_args(a: cases.IField, scalar: int32, condition: bool) -> cases.IField:
-        if not condition:
-            scalar = 0
-        return a + a + scalar
-
-    @gtx.program
-    def program_bound_args(a: cases.IField, scalar: int32, condition: bool, out: cases.IField):
-        fieldop_bound_args(a, scalar, condition, out=out)
-
-    a = cases.allocate(cartesian_case, program_bound_args, "a")()
-    scalar = int32(1)
-    ref = a + a + 1
-    out = cases.allocate(cartesian_case, program_bound_args, "out")()
-
-    prog_bounds = program_bound_args.with_bound_args(scalar=scalar, condition=True)
-    cases.verify(cartesian_case, prog_bounds, a, out, inout=out, ref=ref)
-
 
 def test_domain(cartesian_case):
     @gtx.field_operator
     def fieldop_domain(a: cases.IField) -> cases.IField:
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py
index 698dce2b5c..d100cd380c 100644
--- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_external_local_field.py
@@ -30,16 +30,6 @@
 
 
 def test_external_local_field(unstructured_case):
-    # TODO(edopao): remove try/catch after uplift of dace module to version > 0.15
-    try:
-        from gt4py.next.program_processors.runners.dace_iterator import run_dace_gpu
-
-        if unstructured_case.backend == run_dace_gpu:
-            # see https://github.com/spcl/dace/pull/1442
-            pytest.xfail("requires fix in dace module for cuda codegen")
-    except ImportError:
-        pass
-
     @gtx.field_operator
     def testee(
         inp: gtx.Field[[Vertex, V2EDim], int32], ones: gtx.Field[[Edge], int32]
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py
index e8d0c8b163..e2434d860a 100644
--- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_gt4py_builtins.py
@@ -46,16 +46,6 @@
     ids=["positive_values", "negative_values"],
 )
 def test_maxover_execution_(unstructured_case, strategy):
-    # TODO(edopao): remove try/catch after uplift of dace module to version > 0.15
-    try:
-        from gt4py.next.program_processors.runners.dace_iterator import run_dace_gpu
-
-        if unstructured_case.backend == run_dace_gpu:
-            # see https://github.com/spcl/dace/pull/1442
-            pytest.xfail("requires fix in dace module for cuda codegen")
-    except ImportError:
-        pass
-
     if unstructured_case.backend in [
         gtfn.run_gtfn,
         gtfn.run_gtfn_gpu,
@@ -79,16 +69,6 @@ def testee(edge_f: cases.EField) -> cases.VField:
 
 @pytest.mark.uses_unstructured_shift
 def test_minover_execution(unstructured_case):
-    # TODO(edopao): remove try/catch after uplift of dace module to version > 0.15
-    try:
-        from gt4py.next.program_processors.runners.dace_iterator import run_dace_gpu
-
-        if unstructured_case.backend == run_dace_gpu:
-            # see https://github.com/spcl/dace/pull/1442
-            pytest.xfail("requires fix in dace module for cuda codegen")
-    except ImportError:
-        pass
-
     @gtx.field_operator
     def minover(edge_f: cases.EField) -> cases.VField:
         out = min_over(edge_f(V2E), axis=V2EDim)
@@ -102,16 +82,6 @@ def minover(edge_f: cases.EField) -> cases.VField:
 
 @pytest.mark.uses_unstructured_shift
 def test_reduction_execution(unstructured_case):
-    # TODO(edopao): remove try/catch after uplift of dace module to version > 0.15
-    try:
-        from gt4py.next.program_processors.runners.dace_iterator import run_dace_gpu
-
-        if unstructured_case.backend == run_dace_gpu:
-            # see https://github.com/spcl/dace/pull/1442
-            pytest.xfail("requires fix in dace module for cuda codegen")
-    except ImportError:
-        pass
-
     @gtx.field_operator
     def reduction(edge_f: cases.EField) -> cases.VField:
         return neighbor_sum(edge_f(V2E), axis=V2EDim)
@@ -150,16 +120,6 @@ def fencil(edge_f: cases.EField, out: cases.VField):
 
 @pytest.mark.uses_unstructured_shift
 def test_reduction_with_common_expression(unstructured_case):
-    # TODO(edopao): remove try/catch after uplift of dace module to version > 0.15
-    try:
-        from gt4py.next.program_processors.runners.dace_iterator import run_dace_gpu
-
-        if unstructured_case.backend == run_dace_gpu:
-            # see https://github.com/spcl/dace/pull/1442
-            pytest.xfail("requires fix in dace module for cuda codegen")
-    except ImportError:
-        pass
-
     @gtx.field_operator
     def testee(flux: cases.EField) -> cases.VField:
         return neighbor_sum(flux(V2E) + flux(V2E), axis=V2EDim)
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py
index c86881ab7c..938c69fb52 100644
--- a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_program.py
@@ -20,6 +20,7 @@
 import pytest
 
 import gt4py.next as gtx
+from gt4py.next import errors
 
 from next_tests.integration_tests import cases
 from next_tests.integration_tests.cases import IDim, Ioff, JDim, cartesian_case, fieldview_backend
@@ -222,7 +223,7 @@ def test_wrong_argument_type(cartesian_case, copy_program_def):
     inp = cartesian_case.as_field([JDim], np.ones((cartesian_case.default_sizes[JDim],)))
     out = cases.allocate(cartesian_case, copy_program, "out").strategy(cases.ConstInitializer(1))()
 
-    with pytest.raises(TypeError) as exc_info:
+    with pytest.raises(errors.DSLError) as exc_info:
         # program is defined on Field[[IDim], ...], but we call with
         # Field[[JDim], ...]
         copy_program(inp, out, offset_provider={})
diff --git a/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py
new file mode 100644
index 0000000000..788081b81e
--- /dev/null
+++ b/tests/next_tests/integration_tests/feature_tests/ffront_tests/test_temporaries_with_sizes.py
@@ -0,0 +1,119 @@
+# GT4Py - GridTools Framework
+#
+# Copyright (c) 2014-2023, ETH Zurich
+# All rights reserved.
+#
+# This file is part of the GT4Py project and the GridTools framework.
+# GT4Py is free software: you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or any later
+# version. See the LICENSE.txt file at the top-level directory of this
+# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import pytest
+from numpy import int32, int64
+
+from gt4py import next as gtx
+from gt4py.next import common
+from gt4py.next.iterator.transforms import LiftMode, apply_common_transforms
+from gt4py.next.program_processors import otf_compile_executor
+from gt4py.next.program_processors.runners.gtfn import run_gtfn_with_temporaries
+
+from next_tests.integration_tests import cases
+from next_tests.integration_tests.cases import (
+    E2V,
+    Case,
+    KDim,
+    Vertex,
+    cartesian_case,
+    unstructured_case,
+)
+from next_tests.integration_tests.feature_tests.ffront_tests.ffront_test_utils import (
+    reduction_setup,
+)
+from next_tests.toy_connectivity import Cell, Edge
+
+
+@pytest.fixture
+def run_gtfn_with_temporaries_and_symbolic_sizes():
+    return otf_compile_executor.OTFBackend(
+        executor=otf_compile_executor.OTFCompileExecutor(
+            name="run_gtfn_with_temporaries_and_sizes",
+            otf_workflow=run_gtfn_with_temporaries.executor.otf_workflow.replace(
+                translation=run_gtfn_with_temporaries.executor.otf_workflow.translation.replace(
+                    symbolic_domain_sizes={
+                        "Cell": "num_cells",
+                        "Edge": "num_edges",
+                        "Vertex": "num_vertices",
+                    },
+                ),
+            ),
+        ),
+        allocator=run_gtfn_with_temporaries.allocator,
+    )
+
+
+@pytest.fixture
+def testee():
+    @gtx.field_operator
+    def testee_op(a: cases.VField) -> cases.EField:
+        amul = a * 2
+        return amul(E2V[0]) + amul(E2V[1])
+
+    @gtx.program
+    def prog(
+        a: cases.VField,
+        out: cases.EField,
+        num_vertices: int32,
+        num_edges: int64,
+        num_cells: int32,
+    ):
+        testee_op(a, out=out)
+
+    return prog
+
+
+def test_verification(testee, run_gtfn_with_temporaries_and_symbolic_sizes, reduction_setup):
+    unstructured_case = Case(
+        run_gtfn_with_temporaries_and_symbolic_sizes,
+        offset_provider=reduction_setup.offset_provider,
+        default_sizes={
+            Vertex: reduction_setup.num_vertices,
+            Edge: reduction_setup.num_edges,
+            Cell: reduction_setup.num_cells,
+            KDim: reduction_setup.k_levels,
+        },
+        grid_type=common.GridType.UNSTRUCTURED,
+    )
+
+    a = cases.allocate(unstructured_case, testee, "a")()
+    out = cases.allocate(unstructured_case, testee, "out")()
+
+    first_nbs, second_nbs = (reduction_setup.offset_provider["E2V"].table[:, i] for i in [0, 1])
+    ref = (a.ndarray * 2)[first_nbs] + (a.ndarray * 2)[second_nbs]
+
+    cases.verify(
+        unstructured_case,
+        testee,
+        a,
+        out,
+        reduction_setup.num_vertices,
+        reduction_setup.num_edges,
+        reduction_setup.num_cells,
+        inout=out,
+        ref=ref,
+    )
+
+
+def test_temporary_symbols(testee, reduction_setup):
+    itir_with_tmp = apply_common_transforms(
+        testee.itir,
+        lift_mode=LiftMode.FORCE_TEMPORARIES,
+        offset_provider=reduction_setup.offset_provider,
+    )
+
+    params = ["num_vertices", "num_edges", "num_cells"]
+    for param in params:
+        assert any([param == str(p) for p in itir_with_tmp.fencil.params])
diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap.py
index e851e7b130..5af4605988 100644
--- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap.py
+++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/anton_lap.py
@@ -18,7 +18,7 @@
 from gt4py.next.iterator.builtins import *
 from gt4py.next.iterator.runtime import closure, fundef, offset
 from gt4py.next.iterator.tracing import trace_fencil_definition
-from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate
+from gt4py.next.program_processors.runners.gtfn import run_gtfn
 
 
 @fundef
@@ -69,7 +69,9 @@ def lap_fencil(i_size, j_size, k_size, i_off, j_off, k_off, out, inp):
     output_file = sys.argv[1]
 
     prog = trace_fencil_definition(lap_fencil, [None] * 8, use_arg_types=False)
-    generated_code = generate(prog, offset_provider={"i": IDim, "j": JDim})
+    generated_code = run_gtfn.executor.otf_workflow.translation.generate_stencil_source(
+        prog, offset_provider={"i": IDim, "j": JDim}, column_axis=None
+    )
 
     with open(output_file, "w+") as output:
         output.write(generated_code)
diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil.py
index 33c7d5baa7..3e8b88ac66 100644
--- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil.py
+++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil.py
@@ -18,7 +18,7 @@
 from gt4py.next.iterator.builtins import *
 from gt4py.next.iterator.runtime import closure, fundef
 from gt4py.next.iterator.tracing import trace_fencil_definition
-from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate
+from gt4py.next.program_processors.runners.gtfn import run_gtfn
 
 
 IDim = gtx.Dimension("IDim")
@@ -48,7 +48,9 @@ def copy_fencil(isize, jsize, ksize, inp, out):
     output_file = sys.argv[1]
 
     prog = trace_fencil_definition(copy_fencil, [None] * 5, use_arg_types=False)
-    generated_code = generate(prog, offset_provider={})
+    generated_code = run_gtfn.executor.otf_workflow.translation.generate_stencil_source(
+        prog, offset_provider={}, column_axis=None
+    )
 
     with open(output_file, "w+") as output:
         output.write(generated_code)
diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view.py
index f7472d4ac3..fdc57449ee 100644
--- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view.py
+++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/copy_stencil_field_view.py
@@ -18,7 +18,7 @@
 
 import gt4py.next as gtx
 from gt4py.next import Field, field_operator, program
-from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate
+from gt4py.next.program_processors.runners.gtfn import run_gtfn
 
 
 IDim = gtx.Dimension("IDim")
@@ -47,7 +47,9 @@ def copy_program(
     output_file = sys.argv[1]
 
     prog = copy_program.itir
-    generated_code = generate(prog, offset_provider={})
+    generated_code = run_gtfn.executor.otf_workflow.translation.generate_stencil_source(
+        prog, offset_provider={}, column_axis=None
+    )
 
     with open(output_file, "w+") as output:
         output.write(generated_code)
diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py
index 1dfd74baca..abc3755dca 100644
--- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py
+++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/fvm_nabla.py
@@ -19,7 +19,7 @@
 from gt4py.next.iterator.builtins import *
 from gt4py.next.iterator.runtime import closure, fundef, offset
 from gt4py.next.iterator.tracing import trace_fencil_definition
-from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate
+from gt4py.next.program_processors.runners.gtfn import run_gtfn, run_gtfn_imperative
 
 
 E2V = offset("E2V")
@@ -92,13 +92,20 @@ def mapped_index(_, __) -> int:
     output_file = sys.argv[1]
     imperative = sys.argv[2].lower() == "true"
 
+    if imperative:
+        backend = run_gtfn_imperative
+    else:
+        backend = run_gtfn
+
     # prog = trace(zavgS_fencil, [None] * 4)  # TODO allow generating of 2 fencils
     prog = trace_fencil_definition(nabla_fencil, [None] * 7, use_arg_types=False)
     offset_provider = {
         "V2E": DummyConnectivity(max_neighbors=6, has_skip_values=True),
         "E2V": DummyConnectivity(max_neighbors=2, has_skip_values=False),
     }
-    generated_code = generate(prog, offset_provider=offset_provider, imperative=imperative)
+    generated_code = backend.executor.otf_workflow.translation.generate_stencil_source(
+        prog, offset_provider=offset_provider, column_axis=None
+    )
 
     with open(output_file, "w+") as output:
         output.write(generated_code)
diff --git a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py
index 578a19faab..9755774fd0 100644
--- a/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py
+++ b/tests/next_tests/integration_tests/multi_feature_tests/cpp_backend_tests/tridiagonal_solve.py
@@ -19,7 +19,7 @@
 from gt4py.next.iterator.runtime import closure, fundef
 from gt4py.next.iterator.tracing import trace_fencil_definition
 from gt4py.next.iterator.transforms import LiftMode
-from gt4py.next.program_processors.codegens.gtfn.gtfn_backend import generate
+from gt4py.next.program_processors.runners.gtfn import run_gtfn
 
 
 IDim = gtx.Dimension("IDim")
@@ -67,10 +67,10 @@ def tridiagonal_solve_fencil(isize, jsize, ksize, a, b, c, d, x):
 
     prog = trace_fencil_definition(tridiagonal_solve_fencil, [None] * 8, use_arg_types=False)
     offset_provider = {"I": gtx.Dimension("IDim"), "J": gtx.Dimension("JDim")}
-    generated_code = generate(
+    generated_code = run_gtfn.executor.otf_workflow.translation.generate_stencil_source(
         prog,
         offset_provider=offset_provider,
-        lift_mode=LiftMode.SIMPLE_HEURISTIC,
+        runtime_lift_mode=LiftMode.SIMPLE_HEURISTIC,
         column_axis=KDim,
     )
 
diff --git a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py
index 6eb65f4a68..39e6609879 100644
--- a/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py
+++ b/tests/next_tests/integration_tests/multi_feature_tests/ffront_tests/test_icon_like_scan.py
@@ -12,7 +12,7 @@
 #
 # SPDX-License-Identifier: GPL-3.0-or-later
 
-from dataclasses import dataclass
+import dataclasses
 
 import numpy as np
 import pytest
@@ -203,22 +203,26 @@ def test_setup(fieldview_backend):
         allocator=fieldview_backend,
     )
 
-    @dataclass(frozen=True)
+    @dataclasses.dataclass(frozen=True)
    class setup:
-        case: cases.Case = test_case
-        cell_size = case.default_sizes[Cell]
-        k_size = case.default_sizes[KDim]
-        z_alpha = case.as_field(
+        case: cases.Case = dataclasses.field(default_factory=lambda: test_case)
+        cell_size = test_case.default_sizes[Cell]
+        k_size = test_case.default_sizes[KDim]
+        z_alpha = test_case.as_field(
             [Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size + 1))
         )
-        z_beta = case.as_field(
+        z_beta = test_case.as_field(
+            [Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size))
+        )
+        z_q = test_case.as_field(
+            [Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size))
+        )
+        w = test_case.as_field(
             [Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size))
         )
-        z_q = case.as_field([Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size)))
-        w = case.as_field([Cell, KDim], np.random.default_rng().uniform(size=(cell_size, k_size)))
         z_q_ref, w_ref = reference(z_alpha.ndarray, z_beta.ndarray, z_q.ndarray, w.ndarray)
-        dummy = case.as_field([Cell, KDim], np.zeros((cell_size, k_size), dtype=bool))
-        z_q_out = case.as_field([Cell, KDim], np.zeros((cell_size, k_size)))
+        dummy = test_case.as_field([Cell, KDim], np.zeros((cell_size, k_size), dtype=bool))
+        z_q_out = test_case.as_field([Cell, KDim], np.zeros((cell_size, k_size)))
 
     return setup()
diff --git a/tests/next_tests/unit_tests/ffront_tests/ast_passes_tests/test_single_static_assign.py b/tests/next_tests/unit_tests/ffront_tests/ast_passes_tests/test_single_static_assign.py
index 052f272d22..ea1cdb82a6 100644
--- a/tests/next_tests/unit_tests/ffront_tests/ast_passes_tests/test_single_static_assign.py
+++ b/tests/next_tests/unit_tests/ffront_tests/ast_passes_tests/test_single_static_assign.py
@@ -108,7 +108,10 @@ def test_unpacking_swap():
     lines = ast.unparse(ssa_ast).split("\n")
     assert lines[0] == f"a{SEP}0 = 5"
     assert lines[1] == f"b{SEP}0 = 1"
-    assert lines[2] == f"(b{SEP}1, a{SEP}1) = (a{SEP}0, b{SEP}0)"
+    assert lines[2] in [
+        f"(b{SEP}1, a{SEP}1) = (a{SEP}0, b{SEP}0)",
+        f"b{SEP}1, a{SEP}1 = (a{SEP}0, b{SEP}0)",
+    ]  # unparse produces different parentheses in different Python versions
 
 
 def test_annotated_assign():
diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py
index 86c3c98c62..5c2802f90c 100644
--- a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py
+++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_global_tmps.py
@@ -323,7 +323,7 @@ def test_update_cartesian_domains():
             for a, s in (("JDim", "j"), ("KDim", "k"))
         ],
     )
-    actual = update_domains(testee, {"I": gtx.Dimension("IDim")})
+    actual = update_domains(testee, {"I": gtx.Dimension("IDim")}, symbolic_sizes=None)
     assert actual == expected
 
 
diff --git a/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_power_unrolling.py b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_power_unrolling.py
new file mode 100644
index 0000000000..ae23becb4c
--- /dev/null
+++ b/tests/next_tests/unit_tests/iterator_tests/transforms_tests/test_power_unrolling.py
@@ -0,0 +1,159 @@
+# GT4Py - GridTools Framework
+#
+# Copyright (c) 2014-2023, ETH Zurich
+# All rights reserved.
+#
+# This file is part of the GT4Py project and the GridTools framework.
+# GT4Py is free software: you can redistribute it and/or modify it under
+# the terms of the GNU General Public License as published by the
+# Free Software Foundation, either version 3 of the License, or any later
+# version. See the LICENSE.txt file at the top-level directory of this
+# distribution for a copy of the license or check <https://www.gnu.org/licenses/>.
+#
+# SPDX-License-Identifier: GPL-3.0-or-later
+
+import pytest
+
+from gt4py.eve import SymbolRef
+from gt4py.next.iterator import ir
+from gt4py.next.iterator.ir_utils import ir_makers as im
+from gt4py.next.iterator.transforms.power_unrolling import PowerUnrolling
+
+
+def test_power_unrolling_zero():
+    pytest.xfail(
+        "Not implemented as we don't have an easy way to determine the type of the literal `1` (type inference is too expensive)."
+    )
+    testee = im.call("power")("x", 0)
+    expected = im.literal_from_value(1)
+
+    actual = PowerUnrolling.apply(testee)
+    assert actual == expected
+
+
+def test_power_unrolling_one():
+    testee = im.call("power")("x", 1)
+    expected = ir.SymRef(id=SymbolRef("x"))
+
+    actual = PowerUnrolling.apply(testee)
+    assert actual == expected
+
+
+def test_power_unrolling_two():
+    testee = im.call("power")("x", 2)
+    expected = im.multiplies_("x", "x")
+
+    actual = PowerUnrolling.apply(testee)
+    assert actual == expected
+
+
+def test_power_unrolling_two_x_plus_two():
+    testee = im.call("power")(im.plus("x", 2), 2)
+    expected = im.let("power_1", im.plus("x", 2))(
+        im.let("power_2", im.multiplies_("power_1", "power_1"))("power_2")
+    )
+
+    actual = PowerUnrolling.apply(testee)
+    assert actual == expected
+
+
+def test_power_unrolling_two_x_plus_one_times_three():
+    testee = im.call("power")(im.multiplies_(im.plus("x", 1), 3), 2)
+    expected = im.let("power_1", im.multiplies_(im.plus("x", 1), 3))(
+        im.let("power_2", im.multiplies_("power_1", "power_1"))("power_2")
+    )
+
+    actual = PowerUnrolling.apply(testee)
+    assert actual == expected
+
+
+def test_power_unrolling_three():
+    testee = im.call("power")("x", 3)
+    expected = im.multiplies_(im.multiplies_("x", "x"), "x")
+
+    actual = PowerUnrolling.apply(testee)
+    assert actual == expected
+
+
+def test_power_unrolling_four():
+    testee = im.call("power")("x", 4)
+    expected = im.let("power_2", im.multiplies_("x", "x"))(im.multiplies_("power_2", "power_2"))
+
+    actual = PowerUnrolling.apply(testee)
+    assert actual == expected
+
+
+def test_power_unrolling_five():
+    testee = im.call("power")("x", 5)
+    expected = im.let("power_2", im.multiplies_("x", "x"))(
+        im.multiplies_(im.multiplies_("power_2", "power_2"), "x")
+    )
+
+    actual = PowerUnrolling.apply(testee)
+    assert actual == expected
+
+
+def test_power_unrolling_seven():
+    testee = im.call("power")("x", 7)
+    expected = im.call("power")("x", 7)
+
+    actual = PowerUnrolling.apply(testee, max_unroll=5)
+    assert actual == expected
+
+
+def test_power_unrolling_seven_unrolled():
+    testee = im.call("power")("x", 7)
+    expected = im.let("power_2", im.multiplies_("x", "x"))(
+        im.multiplies_(im.multiplies_(im.multiplies_("power_2", "power_2"), "power_2"), "x")
+    )
+
+    actual = PowerUnrolling.apply(testee, max_unroll=7)
+    assert actual == expected
+
+
+def test_power_unrolling_seven_x_plus_one_unrolled():
+    testee = im.call("power")(im.plus("x", 1), 7)
+    expected = im.let("power_1", im.plus("x", 1))(
+        im.let("power_2", im.multiplies_("power_1", "power_1"))(
+            im.let("power_4", im.multiplies_("power_2", "power_2"))(
+                im.multiplies_(im.multiplies_("power_4", "power_2"), "power_1")
+            )
+        )
+    )
+
+    actual = PowerUnrolling.apply(testee, max_unroll=7)
+    assert actual == expected
+
+
+def test_power_unrolling_eight():
+    testee = im.call("power")("x", 8)
+    expected = im.call("power")("x", 8)
+
+    actual = PowerUnrolling.apply(testee, max_unroll=5)
+    assert actual == expected
+
+
+def test_power_unrolling_eight_unrolled():
+    testee = im.call("power")("x", 8)
+    expected = im.let("power_2", im.multiplies_("x", "x"))(
+        im.let("power_4", im.multiplies_("power_2", "power_2"))(
+            im.multiplies_("power_4", "power_4")
+        )
+    )
+
+    actual = PowerUnrolling.apply(testee, max_unroll=8)
+    assert actual == expected
+
+
+def test_power_unrolling_eight_x_plus_one_unrolled():
+    testee = im.call("power")(im.plus("x", 1), 8)
+    expected = im.let("power_1", im.plus("x", 1))(
+        im.let("power_2", im.multiplies_("power_1", "power_1"))(
+            im.let("power_4", im.multiplies_("power_2", "power_2"))(
+                im.let("power_8", im.multiplies_("power_4", "power_4"))("power_8")
+            )
+        )
+    )
+
+    actual = PowerUnrolling.apply(testee, max_unroll=8)
+    assert actual == expected
diff --git a/tox.ini b/tox.ini
index 44dc912c8a..817f721f71 100644
--- a/tox.ini
+++ b/tox.ini
@@ -11,21 +11,24 @@ envlist =
     # docs
 labels =
     test-cartesian-cpu = cartesian-py38-internal-cpu, cartesian-py39-internal-cpu, cartesian-py310-internal-cpu, \
-        cartesian-py38-dace-cpu, cartesian-py39-dace-cpu, cartesian-py310-dace-cpu
+        cartesian-py311-internal-cpu, cartesian-py38-dace-cpu, cartesian-py39-dace-cpu, cartesian-py310-dace-cpu, \
+        cartesian-py311-dace-cpu
 
-    test-eve-cpu = eve-py38, eve-py39, eve-py310
+    test-eve-cpu = eve-py38, eve-py39, eve-py310, eve-py311
 
-    test-next-cpu = next-py310-nomesh, next-py310-atlas
+    test-next-cpu = next-py310-nomesh, next-py311-nomesh, next-py310-atlas, next-py311-atlas
 
     test-storage-cpu = storage-py38-internal-cpu, storage-py39-internal-cpu, storage-py310-internal-cpu, \
-        storage-py38-dace-cpu, storage-py39-dace-cpu, storage-py310-dace-cpu
+        storage-py311-internal-cpu, storage-py38-dace-cpu, storage-py39-dace-cpu, storage-py310-dace-cpu, \
+        storage-py311-dace-cpu
 
     test-cpu = cartesian-py38-internal-cpu, cartesian-py39-internal-cpu, cartesian-py310-internal-cpu, \
-        cartesian-py38-dace-cpu, cartesian-py39-dace-cpu, cartesian-py310-dace-cpu, \
-        eve-py38, eve-py39, eve-py310, \
-        next-py310-nomesh, next-py310-atlas, \
-        storage-py38-internal-cpu, storage-py39-internal-cpu, storage-py310-internal-cpu, \
-        storage-py38-dace-cpu, storage-py39-dace-cpu, storage-py310-dace-cpu
+        cartesian-py311-internal-cpu, cartesian-py38-dace-cpu, cartesian-py39-dace-cpu, cartesian-py310-dace-cpu, \
+        cartesian-py311-dace-cpu, \
+        eve-py38, eve-py39, eve-py310, eve-py311, \
+        next-py310-nomesh, next-py311-nomesh, next-py310-atlas, next-py311-atlas, \
+        storage-py38-internal-cpu, storage-py39-internal-cpu, storage-py310-internal-cpu, storage-py311-internal-cpu, \
+        storage-py38-dace-cpu, storage-py39-dace-cpu, storage-py310-dace-cpu, storage-py311-dace-cpu
 
 [testenv]
 deps = -r {tox_root}{/}{env:ENV_REQUIREMENTS_FILE:requirements-dev.txt}
@@ -44,7 +47,7 @@ pass_env = NUM_PROCESSES
 set_env =
     PYTHONWARNINGS = {env:PYTHONWARNINGS:ignore:Support for `[tool.setuptools]` in `pyproject.toml` is still *beta*:UserWarning}
 
-[testenv:cartesian-py{38,39,310}-{internal,dace}-{cpu,cuda,cuda11x,cuda12x}]
+[testenv:cartesian-py{38,39,310,311}-{internal,dace}-{cpu,cuda,cuda11x,cuda12x}]
 description = Run 'gt4py.cartesian' tests
 pass_env = {[testenv]pass_env}, BOOST_ROOT, BOOST_HOME, CUDA_HOME, CUDA_PATH, CXX, CC, OPENMP_CPPFLAGS, OPENMP_LDFLAGS, PIP_USER, PYTHONUSERBASE
 allowlist_externals =
@@ -65,13 +68,13 @@ commands =
     ; coverage json --rcfile=setup.cfg
     ; coverage html --rcfile=setup.cfg --show-contexts
 
-[testenv:eve-py{38,39,310}]
+[testenv:eve-py{38,39,310,311}]
 description = Run 'gt4py.eve' tests
 commands =
     python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} {posargs} tests{/}eve_tests
    python -m pytest --doctest-modules src{/}gt4py{/}eve
 
-[testenv:next-py{310}-{nomesh,atlas}-{cpu,cuda,cuda11x,cuda12x}]
+[testenv:next-py{310,311}-{nomesh,atlas}-{cpu,cuda,cuda11x,cuda12x}]
 description = Run 'gt4py.next' tests
 pass_env = {[testenv]pass_env}, BOOST_ROOT, BOOST_HOME, CUDA_HOME, CUDA_PATH
 deps =
@@ -87,14 +90,14 @@ commands =
     # atlas-{cuda,cuda11x,cuda12x}: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "requires_atlas and requires_gpu" {posargs} tests{/}next_tests  # TODO(ricoh): activate when such tests exist
     pytest --doctest-modules src{/}gt4py{/}next
 
-[testenv:storage-py{38,39,310}-{internal,dace}-{cpu,cuda,cuda11x,cuda12x}]
+[testenv:storage-py{38,39,310,311}-{internal,dace}-{cpu,cuda,cuda11x,cuda12x}]
 description = Run 'gt4py.storage' tests
 commands =
     cpu: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "not requires_gpu" {posargs} tests{/}storage_tests
     {cuda,cuda11x,cuda12x}: python -m pytest --cache-clear -v -n {env:NUM_PROCESSES:1} -m "requires_gpu" {posargs} tests{/}storage_tests
     #pytest doctest-modules {posargs} src{/}gt4py{/}storage
 
-[testenv:linters-py{38,39,310}]
+[testenv:linters-py{38,39,310,311}]
 description = Run linters
 commands =
     flake8 .{/}src
@@ -134,11 +137,13 @@ description =
     py38: Update requirements for testing a specific python version
     py39: Update requirements for testing a specific python version
    py310: Update requirements for testing a specific python version
+    py311: Update requirements for testing a specific python version
 base_python =
     common: py38
     py38: py38
     py39: py39
     py310: py310
+    py311: py311
 deps =
     cogapp>=3.3
     pip-tools>=6.10
@@ -178,7 +183,7 @@ commands =
     # Run cog to update .pre-commit-config.yaml with new versions
     common: cog -r -P .pre-commit-config.yaml
 
-[testenv:dev-py{38,39,310}{-atlas,}]
+[testenv:dev-py{38,39,310,311}{-atlas,}]
 description = Initialize development environment for gt4py
 deps =
     -r {tox_root}{/}requirements-dev.txt