Skip to content

Commit

Permalink
Added EQX_GETKEY_SEED for use in reproducing tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Dec 10, 2023
1 parent 5926a4e commit 6841212
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 1 deletion.
5 changes: 5 additions & 0 deletions equinox/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,8 @@
EQX_ON_ERROR_BREAKPOINT_FRAMES = os.environ.get("EQX_ON_ERROR_BREAKPOINT_FRAMES", None)
if EQX_ON_ERROR_BREAKPOINT_FRAMES is not None:
EQX_ON_ERROR_BREAKPOINT_FRAMES = int(EQX_ON_ERROR_BREAKPOINT_FRAMES)

try:
EQX_GETKEY_SEED = int(os.environ["EQX_GETKEY_SEED"])
except KeyError:
EQX_GETKEY_SEED = None
4 changes: 3 additions & 1 deletion equinox/internal/_getkey.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import jax.random as jr
from jaxtyping import PRNGKeyArray

from .._config import EQX_GETKEY_SEED


# This offers reproducability -- the initial seed is printed in the repr so we can see
# it when a test fails.
Expand Down Expand Up @@ -32,7 +34,7 @@ def getkey():
call: int
key: PRNGKeyArray

def __init__(self, seed: Optional[int] = None):
def __init__(self, seed: Optional[int] = EQX_GETKEY_SEED):
if seed is None:
seed = random.randint(0, 2**31 - 1)
self.seed = seed
Expand Down

0 comments on commit 6841212

Please sign in to comment.