diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index ae987ce..142dc48 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -59,6 +59,7 @@ jobs: - numba - torch - tf + - jax steps: - uses: actions/checkout@v3 diff --git a/docs/api.rst b/docs/api.rst index f26ef23..b7fbd87 100644 --- a/docs/api.rst +++ b/docs/api.rst @@ -69,3 +69,4 @@ use these functions to obtain random number generators. .. autofunction:: numpy_rng .. autofunction:: numpy_random_state .. autofunction:: cupy_rng +.. autofunction:: jax_key diff --git a/docs/conf.py b/docs/conf.py index 12c2f8f..6d0d58a 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -32,6 +32,7 @@ "python": ("https://docs.python.org/3/", None), "numpy": ("https://docs.scipy.org/doc/numpy/", None), "sklearn": ("https://scikit-learn.org/stable/", None), + "jax": ("https://jax.readthedocs.io/en/latest/", None), } autodoc_default_options = {"members": True, "member-order": "bysource"} diff --git a/pyproject.toml b/pyproject.toml index dfa0475..572a590 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,6 +57,7 @@ torch = ["torch"] cupy = ["cupy"] cuda11x = ["cupy-cuda11x"] tf = ["tensorflow >=2,<3"] +jax = ["jax[cpu]>=0.4.16"] [project.urls] Homepage = "https://seedbank.lenksit.org" diff --git a/seedbank/__init__.py b/seedbank/__init__.py index 8b8c4e5..8ebca0e 100644 --- a/seedbank/__init__.py +++ b/seedbank/__init__.py @@ -35,6 +35,7 @@ "numpy_rng", "numpy_random_state", "cupy_rng", + "jax_key", "SeedLike", ] @@ -136,5 +137,6 @@ def int_seed( from seedbank._config import init_file # noqa: E402 from seedbank.cupy import cupy_rng # noqa: E402 +from seedbank.jax import jax_key # noqa: E402 from seedbank.numpy import numpy_random_state, numpy_rng # noqa: E402 from seedbank.stdlib import std_rng # noqa: E402 diff --git a/seedbank/jax.py b/seedbank/jax.py new file mode 100644 index 0000000..ba62487 --- /dev/null +++ b/seedbank/jax.py @@ -0,0 +1,48 @@ +""" +JAX support. + +Jax has no global random seeds, but we support making Jax keys. +""" + +# pyright: basic, reportAttributeAccessIssue=false +from __future__ import annotations + +from typing import Optional + +try: + import jax + + AVAILABLE = True +except ImportError: + AVAILABLE = False + +from . import derive_seed +from ._keys import SeedLike, make_seed + + +def jax_key( + spec: Optional[SeedLike] = None, +) -> jax.Array: + """ + Get a Jax random key (see :func:`jax.random.key`). Jax does not use global + state, instead relying on explicit random state management. This function + allows you to obtain an initial key for a set of random operations from the + Seedbank key. + + Args: + spec: + The spec from which to generate the key. The same spec will produce + the same key. + + Returns: + A random number generator. + """ + if not AVAILABLE: + raise RuntimeError("jax not importable") + + if spec is None: + seed = derive_seed() + else: + seed = make_seed(spec) + data = seed.generate_state(1, dtype="u8")[0] + return jax.random.key(data) diff --git a/tests/test_jax.py b/tests/test_jax.py new file mode 100644 index 0000000..82037e6 --- /dev/null +++ b/tests/test_jax.py @@ -0,0 +1,41 @@ +""" +stdlib python tests +""" + +import random + +from pytest import mark + +from seedbank import jax_key + +try: + import jax + import jax.numpy as jnp +except ImportError: + pytestmark = mark.skip("JAX not available") + + +def test_jax_key(): + """ + Make sure we get an stdlib RNG. + """ + key = jax_key() + assert isinstance(key, jax.Array) + + +def test_jax_newkey(): + """ + Test that two stdlib RNGs with fresh seeds return different numbers. + """ + k1 = jax_key() + k2 = jax_key() + assert not jnp.equal(k1, k2) + + +def test_jax_samekey(): + """ + Test that two stdlib RNGs with fresh seeds return different numbers. + """ + k1 = jax_key("foo") + k2 = jax_key("foo") + assert jnp.equal(k1, k2)