diff --git a/.github/workflows/nightly.yaml b/.github/workflows/nightly.yaml index 187843e..7900f21 100644 --- a/.github/workflows/nightly.yaml +++ b/.github/workflows/nightly.yaml @@ -38,7 +38,7 @@ jobs: run: | python -m pip install --upgrade pip python -m pip install pytest pytest-xdist - python -m pip install -U jax flax grain ml_dtypes optax orbax-checkpoint orbax-export tensorflow tensorflow_datasets + python -m pip install -U chex jax flax grain ml_dtypes optax orbax-checkpoint orbax-export tensorflow tensorflow_datasets - name: Run tests run: | pytest -n auto jax_ai_stack @@ -147,3 +147,28 @@ jobs: if: failure() && github.event.pull_request == null with: github-token: ${{ secrets.GITHUB_TOKEN }} + + chex-nightly: + name: Test with chex nightly + runs-on: ubuntu-latest + timeout-minutes: 10 + strategy: + fail-fast: false + steps: + - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + - uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0 + with: + python-version: 3.12 + - name: Install dependencies with chex nightly + run: | + python -m pip install --upgrade pip + python -m pip install .[dev,tfds,grain] + python -m pip install --upgrade 'git+https://github.com/google-deepmind/chex/' + - name: Run tests + run: | + pytest -n auto jax_ai_stack + - name: Notify failed build + uses: jayqi/failed-build-issue-action@1a893bbf43ef1c2a8705e2b115cd4f0fe3c5649b # v1.2.0 + if: failure() && github.event.pull_request == null + with: + github-token: ${{ secrets.GITHUB_TOKEN }} diff --git a/README.md b/README.md index c589572..a157c37 100644 --- a/README.md +++ b/README.md @@ -43,6 +43,7 @@ together via the integration tests in this repository. Packages include: - [ml_dtypes](http://github.com/jax-ml/ml_dtypes): NumPy dtype extensions for machine learning. - [optax](https://github.com/google-deepmind/optax): gradient processing and optimization in JAX. - [orbax](https://github.com/google/orbax): checkpointing and persistence utilities for JAX. +- [chex](https://github.com/google-deepmind/chex): utilities for writing reliable JAX code. ### Optional packages diff --git a/jax_ai_stack/tests/test_chex.py b/jax_ai_stack/tests/test_chex.py new file mode 100644 index 0000000..43fd31a --- /dev/null +++ b/jax_ai_stack/tests/test_chex.py @@ -0,0 +1,41 @@ +# Copyright 2024 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +import chex +import jax +import jax.numpy as jnp + + +class ChexTest(unittest.TestCase): + + def test_chex_dataclass(self): + @chex.dataclass + class Params: + x: chex.ArrayDevice + y: chex.ArrayDevice + + params = Params( + x=jnp.arange(4), + y=jnp.ones(10), + ) + + updated = jax.tree.map(lambda x: 2.0 * x, params) + + chex.assert_trees_all_close(updated.x, jnp.arange(0, 8, 2)) + chex.assert_trees_all_close(updated.y, jnp.full(10, fill_value=2.0)) + + +if __name__ == '__main__': + unittest.main() diff --git a/pyproject.toml b/pyproject.toml index c0d2e85..9e6c61c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -15,6 +15,7 @@ keywords = [] # pip dependencies of the project dependencies = [ + "chex==0.1.88", "jax==0.4.38", "flax==0.10.2", "ml_dtypes==0.4.0",