Skip to content

Commit

Permalink
Add chex to dependencies
Browse files Browse the repository at this point in the history
  • Loading branch information
jakevdp committed Jan 15, 2025
1 parent d27cae2 commit 9c6b94f
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 1 deletion.
27 changes: 26 additions & 1 deletion .github/workflows/nightly.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ jobs:
run: |
python -m pip install --upgrade pip
python -m pip install pytest pytest-xdist
python -m pip install -U jax flax grain ml_dtypes optax orbax-checkpoint orbax-export tensorflow tensorflow_datasets
python -m pip install -U chex jax flax grain ml_dtypes optax orbax-checkpoint orbax-export tensorflow tensorflow_datasets
- name: Run tests
run: |
pytest -n auto jax_ai_stack
Expand Down Expand Up @@ -147,3 +147,28 @@ jobs:
if: failure() && github.event.pull_request == null
with:
github-token: ${{ secrets.GITHUB_TOKEN }}

chex-nightly:
name: Test with chex nightly
runs-on: ubuntu-latest
timeout-minutes: 10
strategy:
fail-fast: false
steps:
- uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- uses: actions/setup-python@0b93645e9fea7318ecaed2b359559ac225c90a2b # v5.3.0
with:
python-version: 3.12
- name: Install dependencies with chex nightly
run: |
python -m pip install --upgrade pip
python -m pip install .[dev,tfds,grain]
python -m pip install --upgrade 'git+https://github.com/google-deepmind/chex/'
- name: Run tests
run: |
pytest -n auto jax_ai_stack
- name: Notify failed build
uses: jayqi/failed-build-issue-action@1a893bbf43ef1c2a8705e2b115cd4f0fe3c5649b # v1.2.0
if: failure() && github.event.pull_request == null
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ together via the integration tests in this repository. Packages include:
- [ml_dtypes](http://github.com/jax-ml/ml_dtypes): NumPy dtype extensions for machine learning.
- [optax](https://github.com/google-deepmind/optax): gradient processing and optimization in JAX.
- [orbax](https://github.com/google/orbax): checkpointing and persistence utilities for JAX.
- [chex](https://github.com/google-deepmind/chex): utilities for writing reliable JAX code.

### Optional packages

Expand Down
41 changes: 41 additions & 0 deletions jax_ai_stack/tests/test_chex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright 2024 The JAX Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import chex
import jax
import jax.numpy as jnp


class ChexTest(unittest.TestCase):

def test_chex_dataclass(self):
@chex.dataclass
class Params:
x: chex.ArrayDevice
y: chex.ArrayDevice

params = Params(
x=jnp.arange(4),
y=jnp.ones(10),
)

updated = jax.tree.map(lambda x: 2.0 * x, params)

chex.assert_trees_all_close(updated.x, jnp.arange(0, 8, 2))
chex.assert_trees_all_close(updated.y, jnp.full(10, fill_value=2.0))


if __name__ == '__main__':
unittest.main()
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ keywords = []

# pip dependencies of the project
dependencies = [
"chex==0.1.88",
"jax==0.4.38",
"flax==0.10.2",
"ml_dtypes==0.4.0",
Expand Down

0 comments on commit 9c6b94f

Please sign in to comment.