-
-
Notifications
You must be signed in to change notification settings - Fork 150
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'main' of https://github.com/lockwo/equinox into filter-…
…hessian
- Loading branch information
Showing
147 changed files
with
11,199 additions
and
2,445 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,23 +1,14 @@ | ||
repos: | ||
- repo: https://github.com/ambv/black | ||
rev: 22.3.0 | ||
hooks: | ||
- id: black | ||
- repo: https://github.com/charliermarsh/ruff-pre-commit | ||
rev: 'v0.0.255' | ||
hooks: | ||
- id: ruff | ||
args: ["--fix"] | ||
- repo: https://github.com/astral-sh/ruff-pre-commit | ||
rev: v0.1.7 | ||
hooks: | ||
- id: ruff # linter | ||
types_or: [ python, pyi, jupyter ] | ||
args: [ --fix ] | ||
- id: ruff-format # formatter | ||
types_or: [ python, pyi, jupyter ] | ||
- repo: https://github.com/RobertCraigie/pyright-python | ||
rev: v1.1.302 | ||
rev: v1.1.315 | ||
hooks: | ||
- id: pyright | ||
additional_dependencies: [beartype, einops, jax, jaxtyping, pytest, tensorflow, tf2onnx, typing_extensions] | ||
- repo: https://github.com/nbQA-dev/nbQA | ||
rev: 1.6.3 | ||
hooks: | ||
- id: nbqa-black | ||
additional_dependencies: [ipython==8.12, black] | ||
- id: nbqa-ruff | ||
args: ["--ignore=I001"] | ||
additional_dependencies: [ipython==8.12, ruff] | ||
additional_dependencies: [beartype, einops, jax, jaxtyping, optax, pytest, tensorflow, tf2onnx, typing_extensions] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
# Clear caches | ||
|
||
::: equinox.clear_caches |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
# Debugging tools | ||
|
||
Both Equinox and JAX provide a number of debugging tools. | ||
|
||
## Common sources of NaNs | ||
|
||
A common source of NaNs on the forward pass is calling `jnp.log` or `jnp.sqrt` on a negative number, or when dividing by zero. If you get a NaN whilst using these operations, check their inputs carefully (e.g. with `jax.debug.print`). | ||
|
||
A common source of NaNs when backpropagating is when using one of the above operations with a `jnp.where`, for example `y = jnp.where(x > 0, jnp.log(x), 0)`. In this case the NaN is created on the forward pass, but is then masked by the `jnp.where`. Unfortunately, when backpropagating, the order of the `log` and the `where` is flipped -- and the NaN is no longer masked! The solution is to use the "double where" trick: bracket your computation by a `where` on both sides. For this example, `safe_x = jnp.where(x > 0, x, 1); y = jnp.where(x > 0, jnp.log(safe_x), 0)`. This ensures that the NaN is never created in the first place at all. | ||
|
||
## Debugging runtime errors | ||
|
||
If you are getting a runtime error from [`equinox.error_if`][], then you can control the on-error behaviour via the environment variable `EQX_ON_ERROR`. In particular, setting `EQX_ON_ERROR=breakpoint` will open a `jax.debug.breakpoint` where the error arises. See the [runtime errors](./errors.md) for more information and for other values of this environment variable. | ||
|
||
If ran from `jax.jit`, then the [`equinox.error_if`][] error will be a long error message starting `INTERNAL: Generated function failed: CpuCallback error: RuntimeError: ...`. You may prefer to use `eqx.filter_jit`, which will remove some of the extra boilerplate from the error message. | ||
|
||
## JAX tools | ||
|
||
JAX itself provides the following tools: | ||
|
||
- the `jax.debug.print` function, for printing results under JIT. | ||
- the `jax.debug.breakpoint` function, for opening a debugger under JIT. | ||
- the `JAX_DEBUG_NANS=1` environment variable, for halting the computation once a NaN is encountered. This works best for NaNs encountered on the forward pass and outside of loops. If your NaN occurs on the backward pass only, then try [`equinox.debug.backward_nan`][] below. If the NaN occurs inside of a loop, then consider pairing this with `JAX_DISABLE_JIT=1`. (Many loops are implicitly jit'd.) | ||
- the `JAX_DISABLE_JIT=1` environment variable, for running the computation without JIT. This will be *much* slower, so this isn't always practical. | ||
- the `JAX_TRACEBACK_FILTERING=off` environment variable, which means errors and debuggers will include JAX and Equinox internals. (Which by default are filtered out.) | ||
|
||
## Equinox tools | ||
|
||
::: equinox.debug.announce_transform | ||
|
||
--- | ||
|
||
::: equinox.debug.backward_nan | ||
|
||
--- | ||
|
||
::: equinox.debug.breakpoint_if | ||
|
||
--- | ||
|
||
::: equinox.debug.store_dce | ||
|
||
::: equinox.debug.inspect_dce | ||
|
||
--- | ||
|
||
::: equinox.debug.assert_max_traces | ||
|
||
::: equinox.debug.get_num_traces |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,7 @@ | ||
# Enumerations | ||
|
||
::: equinox.Enumeration | ||
selection: | ||
members: | ||
- where | ||
- promote |
Oops, something went wrong.