Commit 8465ec2

Switch to ruff-format and ruff for ipynb

patrick-kidger committed Dec 10, 2023
1 parent 8434401 commit 8465ec2
Showing 58 changed files with 135 additions and 176 deletions.
25 changes: 8 additions & 17 deletions .pre-commit-config.yaml
@@ -1,23 +1,14 @@
 repos:
-  - repo: https://github.com/ambv/black
-    rev: 22.3.0
-    hooks:
-      - id: black
-  - repo: https://github.com/charliermarsh/ruff-pre-commit
-    rev: 'v0.0.255'
-    hooks:
-      - id: ruff
-        args: ["--fix"]
+  - repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.1.7
+    hooks:
+      - id: ruff # linter
+        types_or: [ python, pyi, jupyter ]
+        args: [ --fix ]
+      - id: ruff-format # formatter
+        types_or: [ python, pyi, jupyter ]
   - repo: https://github.com/RobertCraigie/pyright-python
     rev: v1.1.315
     hooks:
       - id: pyright
         additional_dependencies: [beartype, einops, jax, jaxtyping, optax, pytest, tensorflow, tf2onnx, typing_extensions]
-  - repo: https://github.com/nbQA-dev/nbQA
-    rev: 1.6.3
-    hooks:
-      - id: nbqa-black
-        additional_dependencies: [ipython==8.12, black]
-      - id: nbqa-ruff
-        args: ["--ignore=I001"]
-        additional_dependencies: [ipython==8.12, ruff]
12 changes: 8 additions & 4 deletions equinox/_ad.py
@@ -112,7 +112,8 @@ def __get__(self, instance, owner):
 
 @overload
 def filter_value_and_grad(
-    *, has_aux: Literal[False] = False
+    *,
+    has_aux: Literal[False] = False,
 ) -> Callable[[Callable[_P, _ScalarTy]], Callable[_P, tuple[_ScalarTy, PyTree]]]:
     ...
 
@@ -126,7 +127,8 @@ def filter_value_and_grad(
 
 @overload
 def filter_value_and_grad(
-    *, has_aux: Literal[True] = True
+    *,
+    has_aux: Literal[True] = True,
 ) -> Callable[
     [Callable[_P, tuple[_ScalarTy, _T]]],
     Callable[_P, tuple[tuple[_ScalarTy, _T], PyTree]],
@@ -193,7 +195,8 @@ def filter_value_and_grad(
 
 @overload
 def filter_grad(
-    *, has_aux: Literal[False] = False
+    *,
+    has_aux: Literal[False] = False,
 ) -> Callable[[Callable[_P, _Scalar]], Callable[_P, PyTree[Float[Array, "..."]]]]:
     ...
 
@@ -207,7 +210,8 @@ def filter_grad(
 
 @overload
 def filter_grad(
-    *, has_aux: Literal[True] = True
+    *,
+    has_aux: Literal[True] = True,
 ) -> Callable[
     [Callable[_P, tuple[_Scalar, _T]]],
     Callable[_P, tuple[PyTree[Float[Array, "..."]], _T]],
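For context on what these overloads encode: without `has_aux` the decorated function must return a scalar loss, while with `has_aux=True` it returns a `(loss, aux)` pair and the transformed function returns `(grads, aux)`. A minimal usage sketch (the model and data here are illustrative, not part of this commit):

    import equinox as eqx
    import jax
    import jax.numpy as jnp
    import jax.random as jr

    model = eqx.nn.Linear(2, "scalar", key=jr.PRNGKey(0))
    x = jnp.ones((8, 2))
    y = jnp.zeros(8)


    @eqx.filter_grad(has_aux=True)
    def loss_fn(model, x, y):
        pred = jax.vmap(model)(x)
        # The gradient is taken with respect to the scalar loss only;
        # `pred` is passed through unchanged as the auxiliary output.
        return jnp.mean((pred - y) ** 2), pred


    grads, pred = loss_fn(model, x, y)  # grads is a pytree matching `model`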
6 changes: 2 additions & 4 deletions equinox/_make_jaxpr.py
@@ -40,16 +40,14 @@ def _fn(*_dynamic_flat):
         _out_dynamic, _out_static = partition(_out, is_array)
         return _out_dynamic, Static(_out_static)
 
-    jaxpr, out_struct = jax.make_jaxpr(_fn, return_shape=True)(
-        *dynamic_flat
-    )  # pyright: ignore
+    jaxpr, out_struct = jax.make_jaxpr(_fn, return_shape=True)(*dynamic_flat)  # pyright: ignore
     dynamic_out_struct, static_out = out_struct
     static_out = static_out.value
     return jaxpr, dynamic_out_struct, static_out
 
 
 def filter_make_jaxpr(
-    fun: Callable[_P, Any]
+    fun: Callable[_P, Any],
 ) -> Callable[
     _P, tuple[jax.core.ClosedJaxpr, PyTree[jax.ShapeDtypeStruct], PyTree[Any]]
 ]:
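As the signature above indicates, the returned callable produces the jaxpr, the shape/dtype structure of the dynamic (array) outputs, and any static outputs. A short sketch of calling it (illustrative, not from this commit):

    import equinox as eqx
    import jax.numpy as jnp


    def f(x):
        return x * 2, "metadata"  # one array output, one static output


    jaxpr, dynamic_struct, static_out = eqx.filter_make_jaxpr(f)(jnp.arange(3.0))
    # `jaxpr` is a ClosedJaxpr; `dynamic_struct` is a pytree of
    # jax.ShapeDtypeStruct for the array outputs; `static_out` holds "metadata".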
12 changes: 3 additions & 9 deletions equinox/_module.py
@@ -421,12 +421,8 @@ def __init__(self, *args, **kwargs):
         # [Step 6] Register as a pytree.
         jtu.register_pytree_with_keys(
             cls,
-            flatten_with_keys=ft.partial(
-                _flatten_module, with_keys=True
-            ),  # pyright: ignore
-            flatten_func=ft.partial(
-                _flatten_module, with_keys=False
-            ),  # pyright: ignore
+            flatten_with_keys=ft.partial(_flatten_module, with_keys=True),  # pyright: ignore
+            flatten_func=ft.partial(_flatten_module, with_keys=False),  # pyright: ignore
             unflatten_func=ft.partial(_unflatten_module, cls),  # pyright: ignore
         )
         # Done!
@@ -584,9 +580,7 @@ def _make_initable(cls: _ModuleMeta, init, post_init, wraps: bool) -> _ModuleMeta:
     if wraps:
         field_names = _wrapper_field_names
     else:
-        field_names = {
-            field.name for field in dataclasses.fields(cls)  # pyright: ignore
-        }
+        field_names = {field.name for field in dataclasses.fields(cls)}  # pyright: ignore
 
     class _InitableModule(cls, _Initable):  # pyright: ignore
         pass
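The registration above is what makes every `eqx.Module` subclass a pytree, so instances can be passed straight through `jax.jit`, `jax.grad`, and so on. A quick illustration with a hypothetical module (not part of this commit):

    import equinox as eqx
    import jax
    import jax.numpy as jnp
    import jax.tree_util as jtu


    class Affine(eqx.Module):
        weight: jax.Array
        bias: jax.Array

        def __call__(self, x):
            return self.weight * x + self.bias


    module = Affine(weight=jnp.array(2.0), bias=jnp.array(1.0))
    # Flattening works both with and without key paths, matching the two
    # flatten functions registered above.
    leaves = jtu.tree_leaves(module)
    keyed_leaves, treedef = jtu.tree_flatten_with_path(module)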
8 changes: 2 additions & 6 deletions equinox/_unvmap.py
@@ -23,9 +23,7 @@ def _unvmap_all_impl(x):
 
 
 def _unvmap_all_abstract_eval(x):
-    return jax.core.ShapedArray(
-        shape=(), dtype=jax.numpy.bool_.dtype  # pyright: ignore
-    )
+    return jax.core.ShapedArray(shape=(), dtype=jax.numpy.bool_.dtype)  # pyright: ignore
 
 
 def _unvmap_all_batch(x, batch_axes):
@@ -56,9 +54,7 @@ def _unvmap_any_impl(x):
 
 
 def _unvmap_any_abstract_eval(x):
-    return jax.core.ShapedArray(
-        shape=(), dtype=jax.numpy.bool_.dtype  # pyright: ignore
-    )
+    return jax.core.ShapedArray(shape=(), dtype=jax.numpy.bool_.dtype)  # pyright: ignore
 
 
 def _unvmap_any_batch(x, batch_axes):
3 changes: 2 additions & 1 deletion equinox/debug/_max_traces.py
@@ -92,7 +92,8 @@ def assert_max_traces(
 
 @overload
 def assert_max_traces(
-    *, max_traces: Optional[int]
+    *,
+    max_traces: Optional[int],
 ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
     ...
 
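This overload is the decorator-factory form: calling `assert_max_traces(max_traces=...)` returns a decorator that raises if the wrapped function is traced more often than expected. A small sketch of the intended usage (illustrative):

    import equinox as eqx
    import jax.numpy as jnp


    @eqx.filter_jit
    @eqx.debug.assert_max_traces(max_traces=1)
    def f(x):
        return x + 1


    f(jnp.arange(3))  # first trace: fine
    f(jnp.arange(4))  # new shape forces a second trace, so this raises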
4 changes: 2 additions & 2 deletions equinox/internal/_loop/common.py
@@ -169,7 +169,7 @@ def _maybe_set_transpose(
     i_static,
     i_treedef,
     kwargs,
-    makes_false_steps
+    makes_false_steps,
 ):
     assert not ad.is_undefined_primal(pred)
     for z in i_dynamic_leaves:
@@ -242,7 +242,7 @@ def _maybe_set(pred, xs, x, i, *, kwargs, makes_false_steps):
         i_static=i_static,
         i_treedef=i_treedef,
         kwargs=kwargs,
-        makes_false_steps=makes_false_steps
+        makes_false_steps=makes_false_steps,
     )
     return out
 
4 changes: 1 addition & 3 deletions equinox/internal/_noinline.py
@@ -403,9 +403,7 @@ def __call__(self, *args, **kwargs):
         )
 
 
-def noinline(
-    fn: Callable, abstract_fn: Optional[Callable] = None  # pyright: ignore
-) -> Callable:
+def noinline(fn: Callable, abstract_fn: Optional[Callable] = None) -> Callable:  # pyright: ignore
     """Marks a function as not being inlined into a larger computation.
 
     This can help to reduce compile time at the expense of increased runtime.
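A hedged sketch of what this helper is for, with usage inferred from the signature and docstring above (`equinox.internal` is not public API):

    import equinox.internal as eqxi
    import jax
    import jax.numpy as jnp


    @eqxi.noinline
    def helper(x):
        # Compiled once and called as a subroutine, rather than being
        # inlined into every computation that uses it.
        return jnp.sin(x) ** 2


    @jax.jit
    def f(x):
        return helper(x) + helper(x + 1)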
6 changes: 3 additions & 3 deletions equinox/internal/_omega.py
@@ -132,7 +132,7 @@ def __rev(x, y):
     return __rev
 
 
-for (name, op) in [
+for name, op in [
     ("__add__", operator.add),
     ("__sub__", operator.sub),
     ("__mul__", operator.mul),
@@ -169,7 +169,7 @@ def __rev(x, y):
     _set_binary(ω, name, op)
 
 
-for (name, op) in [
+for name, op in [
     ("__neg__", operator.neg),
     ("__pos__", operator.pos),
     ("__abs__", operator.abs),
@@ -225,7 +225,7 @@ def fn(self, other):
     setattr(base, name, fn)
 
 
-for (name, op) in [
+for name, op in [
     ("set", lambda x, y, z, **kwargs: x.at[y].set(z, **kwargs)),
     ("add", lambda x, y, z, **kwargs: x.at[y].add(z, **kwargs)),
     ("multiply", lambda x, y, z, **kwargs: x.at[y].multiply(z, **kwargs)),
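These loops attach the arithmetic operators (and `.at[].set()`-style methods) to equinox's internal `ω` wrapper, which lifts operations over whole pytrees. A hedged sketch of how it reads in use (`ω` is internal API; the trees here are illustrative):

    import jax.numpy as jnp
    from equinox.internal import ω

    tree_a = {"x": jnp.array([1.0, 2.0]), "y": jnp.array(3.0)}
    tree_b = {"x": jnp.array([4.0, 5.0]), "y": jnp.array(6.0)}

    # `tree**ω` wraps a pytree, the arithmetic applies leafwise, and the
    # final `.ω` unwraps the result back into a plain pytree.
    result = (tree_a**ω + 2 * tree_b**ω).ω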
2 changes: 1 addition & 1 deletion equinox/internal/_primitive.py
@@ -302,7 +302,7 @@ def batch_rule(axis_size, axis_name, trace_type, inputs, batch_axes, **params):
         __axis_size=axis_size,
         __axis_name=axis_name,
         __batch_axes=batch_axes,
-        params=params
+        params=params,
     )
     batch_axes_out = jtu.tree_map(lambda _: 0, out)
     return out, batch_axes_out
4 changes: 2 additions & 2 deletions equinox/nn/_dropout.py
@@ -28,7 +28,7 @@ def __init__(
         p: float = 0.5,
         inference: bool = False,
         *,
-        deterministic: Optional[bool] = None
+        deterministic: Optional[bool] = None,
     ):
         """**Arguments:**
@@ -61,7 +61,7 @@ def __call__(
         *,
         key: Optional[PRNGKeyArray] = None,
         inference: Optional[bool] = None,
-        deterministic: Optional[bool] = None
+        deterministic: Optional[bool] = None,
     ) -> Array:
         """**Arguments:**
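As the signature shows, `key` and `inference` are keyword-only at call time. A brief usage sketch (illustrative):

    import equinox as eqx
    import jax.numpy as jnp
    import jax.random as jr

    dropout = eqx.nn.Dropout(p=0.5)
    x = jnp.ones(10)

    y_train = dropout(x, key=jr.PRNGKey(0))  # training mode: needs a key
    y_eval = dropout(x, inference=True)      # inference mode: identity, no key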
2 changes: 1 addition & 1 deletion equinox/nn/_linear.py
@@ -24,7 +24,7 @@ def __init__(
         out_features: Union[int, Literal["scalar"]],
         use_bias: bool = True,
         *,
-        key: PRNGKeyArray
+        key: PRNGKeyArray,
     ):
         """**Arguments:**
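For reference, the keyword-only `key` supplies the PRNG state used to initialise the weight (and bias). A brief sketch:

    import equinox as eqx
    import jax.numpy as jnp
    import jax.random as jr

    linear = eqx.nn.Linear(in_features=3, out_features=4, key=jr.PRNGKey(0))
    y = linear(jnp.ones(3))  # shape (4,)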
2 changes: 1 addition & 1 deletion equinox/nn/_misc.py
@@ -10,7 +10,7 @@
 # the `else` branch as well:
 # https://github.com/microsoft/pyright/issues/3450
 def all_sequences(
-    x: Union[Sequence[Any], Sequence[_T]]
+    x: Union[Sequence[Any], Sequence[_T]],
 ) -> "te.StrictTypeGuard[Sequence[_T]]":
     ...
 
4 changes: 2 additions & 2 deletions equinox/nn/_rnn.py
@@ -47,7 +47,7 @@ def __init__(
         use_bias: bool = True,
         *,
         key: PRNGKeyArray,
-        **kwargs
+        **kwargs,
     ):
         """**Arguments:**
@@ -151,7 +151,7 @@ def __init__(
         use_bias: bool = True,
         *,
         key: PRNGKeyArray,
-        **kwargs
+        **kwargs,
     ):
         """**Arguments:**
4 changes: 2 additions & 2 deletions equinox/nn/_spectral_norm.py
@@ -68,7 +68,7 @@ def __init__(
         inference: bool = False,
         *,
         key: PRNGKeyArray,
-        **kwargs
+        **kwargs,
     ):
         """**Arguments:**
@@ -112,7 +112,7 @@ def __call__(
         state: State,
         *,
         key: Optional[PRNGKeyArray] = None,
-        inference: Optional[bool] = None
+        inference: Optional[bool] = None,
     ) -> tuple[Array, State]:
         """**Arguments:**
5 changes: 2 additions & 3 deletions examples/bert.ipynb
@@ -62,16 +62,15 @@
     "from typing import Dict, List, Mapping, Optional\n",
     "\n",
     "import einops  # https://github.com/arogozhnikov/einops\n",
+    "import equinox as eqx\n",
     "import jax\n",
     "import jax.numpy as jnp\n",
     "import numpy as np\n",
     "import optax  # https://github.com/deepmind/optax\n",
     "from datasets import load_dataset  # https://github.com/huggingface/datasets\n",
     "from jaxtyping import Array, Float, Int  # https://github.com/google/jaxtyping\n",
     "from tqdm import notebook as tqdm  # https://github.com/tqdm/tqdm\n",
-    "from transformers import AutoTokenizer  # https://github.com/huggingface/transformers\n",
-    "\n",
-    "import equinox as eqx"
+    "from transformers import AutoTokenizer  # https://github.com/huggingface/transformers"
    ]
   },
   {
25 changes: 15 additions & 10 deletions examples/deep_convolutional_gan.ipynb
@@ -40,18 +40,17 @@
     "from collections.abc import Callable\n",
     "from typing import Union\n",
     "\n",
+    "import equinox as eqx\n",
     "import jax\n",
     "import jax.numpy as jnp\n",
     "import jax.random as jr\n",
+    "import matplotlib.pyplot as plt\n",
     "import optax\n",
     "\n",
     "# We'll use PyTorch to load the dataset.\n",
     "import torch\n",
     "import torchvision\n",
-    "import torchvision.transforms as transforms\n",
-    "import matplotlib.pyplot as plt\n",
-    "\n",
-    "import equinox as eqx"
+    "import torchvision.transforms as transforms"
    ]
   },
   {
@@ -416,9 +415,12 @@
     "    real_labels = jnp.ones(batch_size)\n",
     "\n",
     "    (\n",
-    "        loss,\n",
-    "        (discriminator_state, generator_state, key),\n",
-    "    ), grads = compute_grads_discriminator(\n",
+    "        (\n",
+    "            loss,\n",
+    "            (discriminator_state, generator_state, key),\n",
+    "        ),\n",
+    "        grads,\n",
+    "    ) = compute_grads_discriminator(\n",
     "        discriminator,\n",
     "        generator,\n",
     "        fake_labels,\n",
@@ -450,9 +452,12 @@
     "    real_labels = jnp.ones(batch_size)\n",
     "\n",
     "    (\n",
-    "        loss,\n",
-    "        (discriminator_state, generator_state, key),\n",
-    "    ), grads = compute_grads_generator(\n",
+    "        (\n",
+    "            loss,\n",
+    "            (discriminator_state, generator_state, key),\n",
+    "        ),\n",
+    "        grads,\n",
+    "    ) = compute_grads_generator(\n",
     "        generator, discriminator, real_labels, discriminator_state, generator_state, key\n",
     "    )\n",
     "\n",
5 changes: 2 additions & 3 deletions examples/frozen_layer.ipynb
@@ -19,13 +19,12 @@
   "metadata": {},
   "outputs": [],
   "source": [
+    "import equinox as eqx\n",
     "import jax\n",
     "import jax.numpy as jnp\n",
     "import jax.random as jrandom\n",
     "import jax.tree_util as jtu\n",
-    "import optax  # https://github.com/deepmind/optax\n",
-    "\n",
-    "import equinox as eqx"
+    "import optax  # https://github.com/deepmind/optax"
    ]
   },
   {
5 changes: 2 additions & 3 deletions examples/mnist.ipynb
@@ -33,14 +33,13 @@
   },
   "outputs": [],
   "source": [
+    "import equinox as eqx\n",
     "import jax\n",
     "import jax.numpy as jnp\n",
     "import optax  # https://github.com/deepmind/optax\n",
     "import torch  # https://pytorch.org\n",
     "import torchvision  # https://pytorch.org\n",
-    "from jaxtyping import Array, Float, Int, PyTree  # https://github.com/google/jaxtyping\n",
-    "\n",
-    "import equinox as eqx"
+    "from jaxtyping import Array, Float, Int, PyTree  # https://github.com/google/jaxtyping"
    ]
   },
   {
1 change: 1 addition & 0 deletions examples/parallelism.ipynb
@@ -38,6 +38,7 @@
     "import numpy as np\n",
     "import optax  # https://github.com/deepmind/optax\n",
     "\n",
+    "\n",
     "# Hyperparameters\n",
     "dataset_size = 64\n",
     "channel_size = 4\n",