From 70d664ce354c5ed2ce952f830f805dd51d578409 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Sat, 6 Jul 2024 13:53:31 +0200 Subject: [PATCH] Now defaulting to providing only one frame to the on-error debugger. This is because otherwise we pretty frequently bump into a JAX bug. Better to have folks control the bugginess explicitly... --- equinox/_config.py | 4 +++- equinox/_errors.py | 8 ++++---- equinox/_jit.py | 7 +++---- 3 files changed, 10 insertions(+), 9 deletions(-) diff --git a/equinox/_config.py b/equinox/_config.py index 0d4c5c6e..25aa019a 100644 --- a/equinox/_config.py +++ b/equinox/_config.py @@ -19,7 +19,9 @@ EQX_ON_ERROR_BREAKPOINT_FRAMES = os.environ.get("EQX_ON_ERROR_BREAKPOINT_FRAMES", None) -if EQX_ON_ERROR_BREAKPOINT_FRAMES is not None: +if EQX_ON_ERROR_BREAKPOINT_FRAMES is None: + EQX_ON_ERROR_BREAKPOINT_FRAMES = 1 +else: EQX_ON_ERROR_BREAKPOINT_FRAMES = int(EQX_ON_ERROR_BREAKPOINT_FRAMES) try: diff --git a/equinox/_errors.py b/equinox/_errors.py index 0ffb4cf0..be7debe5 100644 --- a/equinox/_errors.py +++ b/equinox/_errors.py @@ -199,10 +199,10 @@ def error_if( permanently fixing this value is not recommended. - You will need to also pass the `-s` flag to `pytest`, if you are also using that. - - This will sometimes raise a trace-time error due to JAX bug - [#16732](https://github.com/google/jax/issues/16732). (Bugs whilst debugging - bugs, eek!) If this happens, then it can be worked around by additionally - setting the `EQX_ON_ERROR_BREAKPOINT_FRAMES` variable to a small integer, + - By default this only allows you to see a single frame in the debugger. This is + to work around JAX bug [#16732](https://github.com/google/jax/issues/16732). + (Bugs whilst debugging bugs, eek!) In practice you may like to set the + `EQX_ON_ERROR_BREAKPOINT_FRAMES` environment variable to a small integer, which specifies how many frames upwards the debugger should capture. The JAX bug is triggered when taking too many frames. diff --git a/equinox/_jit.py b/equinox/_jit.py index 81a4899e..42a711ed 100644 --- a/equinox/_jit.py +++ b/equinox/_jit.py @@ -114,10 +114,9 @@ class XlaRuntimeError(Exception): ------- This error occurred during the runtime of your JAX program. Setting the environment variable `EQX_ON_ERROR=breakpoint` is usually the most useful way to debug such errors. -(This can be navigated using most of the usual commands for the Python debugger: -`u` and `d` to move through stack frames, the name of a variable to print its value, -etc.) See also `https://docs.kidger.site/equinox/api/errors/#equinox.error_if` for more -information. +(This can be interacted with using most of the usual commands for the Python debugger: +the name of a variable to print its value, etc.) It is recommended to read +`https://docs.kidger.site/equinox/api/errors/#equinox.error_if` for more information. """