fix_broken_links

matteoguarrera committed Jun 13, 2024
1 parent fdf2b7c commit edc36ff
Showing 4 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion docs/faq.md
@@ -126,7 +126,7 @@ eqx.filter_jit(loss_fn)(model, x, y)  # ok
 
 This error happens because a model, when treated as a PyTree, may have leaves that are not JAX types (such as functions). It only makes sense to trace arrays. Filtering is used to handle this.
 
-Instead of [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html), use [`equinox.filter_jit`][]. Likewise for [other transformations](https://docs.kidger.site/equinox/api/filtering/transformations).
+Instead of [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html), use [`equinox.filter_jit`][]. Likewise for [other transformations](https://docs.kidger.site/equinox/api/transformations/).
 
 ## How to mark arrays as non-trainable? (Like PyTorch's buffers?)
2 changes: 1 addition & 1 deletion equinox/_sharding.py
@@ -30,7 +30,7 @@ def filter_shard(
     A copy of `x` with the specified sharding constraints.
 
     !!! Example
-        See also the [autoparallelism example](../../../examples/parallelism).
+        See also the [autoparallelism example](../../examples/parallelism).
     """
     if isinstance(device_or_shardings, Device):
         shardings = jax.sharding.SingleDeviceSharding(device_or_shardings)
2 changes: 1 addition & 1 deletion equinox/_update.py
@@ -25,7 +25,7 @@ def apply_updates(model: PyTree, updates: PyTree) -> PyTree:
     This is often useful when updating a model's parameters via stochastic gradient
     descent. (This function is essentially the same as `optax.apply_updates`, except
     that it understands `None`.) For example see the
-    [Train RNN example](../../../examples/train_rnn/).
+    [Train RNN example](../../examples/train_rnn/).
 
     **Arguments:**
2 changes: 1 addition & 1 deletion examples/serialisation.ipynb
@@ -7,7 +7,7 @@
 "source": [
 "# Serialising both weights and hyperparameters\n",
 "\n",
-"Equinox has [facilities](/equinox/api/utilities/serialisation/) for the serialisation of the leaves of arbitrary PyTrees. The most basic use is to call `eqx.tree_serialise_leaves(filename, model)` to write all weights to a file. Deserialisation requires a PyTree of the correct shape to serve as a \"skeleton\" of sorts, whose weights are then read from the file with `model = eqx.tree_deserialise_leaves(filename, skeleton)`.\n",
+"Equinox has [facilities](https://docs.kidger.site/equinox/api/serialisation/) for the serialisation of the leaves of arbitrary PyTrees. The most basic use is to call `eqx.tree_serialise_leaves(filename, model)` to write all weights to a file. Deserialisation requires a PyTree of the correct shape to serve as a \"skeleton\" of sorts, whose weights are then read from the file with `model = eqx.tree_deserialise_leaves(filename, skeleton)`.\n",
 "\n",
 "However, a typical model has both weights (arrays stored as leaves in the PyTree) and hyperparameters (the size of the network, etc.). When deserialising, we would like to read the hyperparameters as well as the weights. Ideally they should be stored in the same file. We can accomplish this as follows."
 ]
