Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Jax key support #30

Merged
merged 11 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ jobs:
- numba
- torch
- tf
- jax

steps:
- uses: actions/checkout@v3
Expand Down
1 change: 1 addition & 0 deletions docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,4 @@ use these functions to obtain random number generators.
.. autofunction:: numpy_rng
.. autofunction:: numpy_random_state
.. autofunction:: cupy_rng
.. autofunction:: jax_key
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"python": ("https://docs.python.org/3/", None),
"numpy": ("https://docs.scipy.org/doc/numpy/", None),
"sklearn": ("https://scikit-learn.org/stable/", None),
"jax": ("https://jax.readthedocs.io/en/latest/", None),
}

autodoc_default_options = {"members": True, "member-order": "bysource"}
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ torch = ["torch"]
cupy = ["cupy"]
cuda11x = ["cupy-cuda11x"]
tf = ["tensorflow >=2,<3"]
jax = ["jax[cpu]>=0.4.16"]

[project.urls]
Homepage = "https://seedbank.lenksit.org"
Expand Down
2 changes: 2 additions & 0 deletions seedbank/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
"numpy_rng",
"numpy_random_state",
"cupy_rng",
"jax_key",
"SeedLike",
]

Expand Down Expand Up @@ -136,5 +137,6 @@ def int_seed(

from seedbank._config import init_file # noqa: E402
from seedbank.cupy import cupy_rng # noqa: E402
from seedbank.jax import jax_key # noqa: E402
from seedbank.numpy import numpy_random_state, numpy_rng # noqa: E402
from seedbank.stdlib import std_rng # noqa: E402
48 changes: 48 additions & 0 deletions seedbank/jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""
JAX support.

Jax has no global random seeds, but we support making Jax keys.
"""

# pyright: basic, reportAttributeAccessIssue=false
from __future__ import annotations

from typing import Optional

try:
import jax

AVAILABLE = True
except ImportError:
AVAILABLE = False

from . import derive_seed
from ._keys import SeedLike, make_seed


def jax_key(
spec: Optional[SeedLike] = None,
) -> jax.Array:
"""
Get a Jax random key (see :func:`jax.random.key`). Jax does not use global
state, instead relying on explicit random state management. This function
allows you to obtain an initial key for a set of random operations from the
Seedbank key.

Args:
spec:
The spec from which to generate the key. The same spec will produce
the same key.

Returns:
A random number generator.
"""
if not AVAILABLE:
raise RuntimeError("jax not importable")

if spec is None:
seed = derive_seed()
else:
seed = make_seed(spec)
data = seed.generate_state(1, dtype="u8")[0]
return jax.random.key(data)
41 changes: 41 additions & 0 deletions tests/test_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""
stdlib python tests
"""

import random

from pytest import mark

from seedbank import jax_key

try:
import jax
import jax.numpy as jnp
except ImportError:
pytestmark = mark.skip("JAX not available")


def test_jax_key():
"""
Make sure we get an stdlib RNG.
"""
key = jax_key()
assert isinstance(key, jax.Array)


def test_jax_newkey():
"""
Test that two stdlib RNGs with fresh seeds return different numbers.
"""
k1 = jax_key()
k2 = jax_key()
assert not jnp.equal(k1, k2)


def test_jax_samekey():
"""
Test that two stdlib RNGs with fresh seeds return different numbers.
"""
k1 = jax_key("foo")
k2 = jax_key("foo")
assert jnp.equal(k1, k2)
Loading