diff --git a/equinox/_config.py b/equinox/_config.py index 0d4c5c6e..25aa019a 100644 --- a/equinox/_config.py +++ b/equinox/_config.py @@ -19,7 +19,9 @@ EQX_ON_ERROR_BREAKPOINT_FRAMES = os.environ.get("EQX_ON_ERROR_BREAKPOINT_FRAMES", None) -if EQX_ON_ERROR_BREAKPOINT_FRAMES is not None: +if EQX_ON_ERROR_BREAKPOINT_FRAMES is None: + EQX_ON_ERROR_BREAKPOINT_FRAMES = 1 +else: EQX_ON_ERROR_BREAKPOINT_FRAMES = int(EQX_ON_ERROR_BREAKPOINT_FRAMES) try: diff --git a/equinox/_errors.py b/equinox/_errors.py index 0ffb4cf0..be7debe5 100644 --- a/equinox/_errors.py +++ b/equinox/_errors.py @@ -199,10 +199,10 @@ def error_if( permanently fixing this value is not recommended. - You will need to also pass the `-s` flag to `pytest`, if you are also using that. - - This will sometimes raise a trace-time error due to JAX bug - [#16732](https://github.com/google/jax/issues/16732). (Bugs whilst debugging - bugs, eek!) If this happens, then it can be worked around by additionally - setting the `EQX_ON_ERROR_BREAKPOINT_FRAMES` variable to a small integer, + - By default this only allows you to see a single frame in the debugger. This is + to work around JAX bug [#16732](https://github.com/google/jax/issues/16732). + (Bugs whilst debugging bugs, eek!) In practice you may like to set the + `EQX_ON_ERROR_BREAKPOINT_FRAMES` environment variable to a small integer, which specifies how many frames upwards the debugger should capture. The JAX bug is triggered when taking too many frames. diff --git a/equinox/_jit.py b/equinox/_jit.py index 81a4899e..42a711ed 100644 --- a/equinox/_jit.py +++ b/equinox/_jit.py @@ -114,10 +114,9 @@ class XlaRuntimeError(Exception): ------- This error occurred during the runtime of your JAX program. Setting the environment variable `EQX_ON_ERROR=breakpoint` is usually the most useful way to debug such errors. -(This can be navigated using most of the usual commands for the Python debugger: -`u` and `d` to move through stack frames, the name of a variable to print its value, -etc.) See also `https://docs.kidger.site/equinox/api/errors/#equinox.error_if` for more -information. +(This can be interacted with using most of the usual commands for the Python debugger: +the name of a variable to print its value, etc.) It is recommended to read +`https://docs.kidger.site/equinox/api/errors/#equinox.error_if` for more information. """