Skip to content

Commit

Permalink
Now defaulting to providing only one frame to the on-error debugger.
Browse files Browse the repository at this point in the history
This is because otherwise we pretty frequently bump into a JAX bug. Better to have folks control the bugginess explicitly...
  • Loading branch information
patrick-kidger committed Jul 7, 2024
1 parent 8efb0fc commit 70d664c
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 9 deletions.
4 changes: 3 additions & 1 deletion equinox/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,9 @@


EQX_ON_ERROR_BREAKPOINT_FRAMES = os.environ.get("EQX_ON_ERROR_BREAKPOINT_FRAMES", None)
if EQX_ON_ERROR_BREAKPOINT_FRAMES is not None:
if EQX_ON_ERROR_BREAKPOINT_FRAMES is None:
EQX_ON_ERROR_BREAKPOINT_FRAMES = 1
else:
EQX_ON_ERROR_BREAKPOINT_FRAMES = int(EQX_ON_ERROR_BREAKPOINT_FRAMES)

try:
Expand Down
8 changes: 4 additions & 4 deletions equinox/_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,10 +199,10 @@ def error_if(
permanently fixing this value is not recommended.
- You will need to also pass the `-s` flag to `pytest`, if you are
also using that.
- This will sometimes raise a trace-time error due to JAX bug
[#16732](https://github.com/google/jax/issues/16732). (Bugs whilst debugging
bugs, eek!) If this happens, then it can be worked around by additionally
setting the `EQX_ON_ERROR_BREAKPOINT_FRAMES` variable to a small integer,
- By default this only allows you to see a single frame in the debugger. This is
to work around JAX bug [#16732](https://github.com/google/jax/issues/16732).
(Bugs whilst debugging bugs, eek!) In practice you may like to set the
`EQX_ON_ERROR_BREAKPOINT_FRAMES` environment variable to a small integer,
which specifies how many frames upwards the debugger should capture. The
JAX bug is triggered when taking too many frames.
Expand Down
7 changes: 3 additions & 4 deletions equinox/_jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,9 @@ class XlaRuntimeError(Exception):
-------
This error occurred during the runtime of your JAX program. Setting the environment
variable `EQX_ON_ERROR=breakpoint` is usually the most useful way to debug such errors.
(This can be navigated using most of the usual commands for the Python debugger:
`u` and `d` to move through stack frames, the name of a variable to print its value,
etc.) See also `https://docs.kidger.site/equinox/api/errors/#equinox.error_if` for more
information.
(This can be interacted with using most of the usual commands for the Python debugger:
the name of a variable to print its value, etc.) It is recommended to read
`https://docs.kidger.site/equinox/api/errors/#equinox.error_if` for more information.
"""


Expand Down

0 comments on commit 70d664c

Please sign in to comment.