Skip to content

Commit

Permalink
Tweaked documentation for new padding options to render correctly.
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Feb 20, 2024
1 parent 23d983e commit 418777c
Showing 1 changed file with 75 additions and 44 deletions.
119 changes: 75 additions & 44 deletions equinox/nn/_conv.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,29 +117,36 @@ def __init__(
`in_channels` must be divisible by `groups`.
- `use_bias`: Whether to add on a bias after the convolution.
- `padding_mode`: One of the following strings specifying the padding values.
`'ZEROS'` (default): pads with zeros, 1234 -> 00123400.
`'REFLECT'`: pads with the reflection on boundary, 1234 -> 32123432.
`'REPLICATE'`: pads with the replication of edge values, 1234 -> 11123444.
`'CIRCULAR'`: pads with circular values, 1234 -> 34123412.
- `'ZEROS'` (default): pads with zeros, `1234 -> 00123400`.
- `'REFLECT'`: pads with the reflection on boundary, `1234 -> 32123432`.
- `'REPLICATE'`: pads with the replication of edge values,
`1234 -> 11123444`.
- `'CIRCULAR'`: pads with circular values, `1234 -> 34123412`.
- `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
initialisation. (Keyword only argument.)
!!! info
All of `kernel_size`, `stride`, `padding`, `dilation` can be either an
integer or a sequence of integers. If they are a sequence then the sequence
should be of length equal to `num_spatial_dims`, and specify the value of
each property down each spatial dimension in turn.
integer or a sequence of integers.
If they are an integer then the same kernel size / stride / padding /
dilation will be used along every spatial dimension.
`padding` can alternatively be a sequence of 2-element tuples, each
representing the padding to apply before and after each spatial dimension.
`padding` can also be a string `'SAME'`, `'SAME_LOWER'`, or `'VALID'` (see
[jax.lax.conv_general_dilated](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv_general_dilated.html)
for details).
If they are a sequence then the sequence should be of length equal to
`num_spatial_dims`, and specify the value of each property down each spatial
dimension in turn.
In addition, `padding` can be:
- a sequence of 2-element tuples, each representing the padding to apply
before and after each spatial dimension.
- the string `'VALID'`, which is the same as zero padding.
- one of the strings `'SAME'` or `'SAME_LOWER'`. This will apply padding to
produce an output with the same size spatial dimensions as the input.
The padding is split between the two sides equally or almost equally. In
case the padding is an odd number, then the extra padding is added at
the end for `'SAME'` and at the beginning for `'SAME_LOWER'`.
"""
wkey, bkey = jrandom.split(key, 2)

Expand Down Expand Up @@ -189,7 +196,7 @@ def _nonzero_pad(self, x: Array) -> Array:
d * (k - 1) + 1 for k, d in zip(self.kernel_size, self.dilation)
)
padding = lax.padtype_to_pads(
x.shape[2:], rhs_shape, self.stride, self.padding
x.shape[1:], rhs_shape, self.stride, self.padding
)
else:
padding = list(self.padding)
Expand All @@ -203,7 +210,7 @@ def _nonzero_pad(self, x: Array) -> Array:
else:
raise ValueError("Invalid padding mode")

x = jnp.pad(x, [(0, 0), (0, 0)] + padding, mode)
x = jnp.pad(x, [(0, 0)] + padding, mode)
return x

@jax.named_scope("eqx.nn.Conv")
Expand All @@ -226,13 +233,15 @@ def __call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array:
f"Input to `Conv` needs to have rank {unbatched_rank},",
f" but input has shape {x.shape}.",
)
x = jnp.expand_dims(x, axis=0)
_, *input_shape = x.shape

if self.padding_mode != "ZEROS":
x = self._nonzero_pad(x)
padding = "VALID"
else:
padding = self.padding

x = jnp.expand_dims(x, axis=0)
x = lax.conv_general_dilated(
lhs=x,
rhs=self.weight,
Expand All @@ -242,6 +251,10 @@ def __call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array:
feature_group_count=self.groups,
)
x = jnp.squeeze(x, axis=0)

if self.padding in ("SAME", "SAME_LOWER"):
assert tuple(x.shape[1:]) == tuple(input_shape)

if self.use_bias:
x = x + self.bias
return x
Expand Down Expand Up @@ -397,28 +410,34 @@ def __init__(
- `use_bias`: Whether to add on a bias after the transposed convolution.
- `padding_mode`: One of the following strings specifying the padding values
used on the equivalent [`equinox.nn.Conv`][].
`'ZEROS'` (default): pads with zeros, no extra connectivity.
`'CIRCULAR'`: pads with circular values, extra connectivity (see Tip).
- `'ZEROS'` (default): pads with zeros, no extra connectivity.
- `'CIRCULAR'`: pads with circular values, extra connectivity (see the Tip
below).
- `key`: A `jax.random.PRNGKey` used to provide randomness for parameter
initialisation. (Keyword only argument.)
!!! info
All of `kernel_size`, `stride`, `padding`, `output_padding`, `dilation` can
be either an integer or a sequence of integers. If they are a sequence then
the sequence should be of length equal to `num_spatial_dims`, and specify
the value of each property down each spatial dimension in turn.
All of `kernel_size`, `stride`, `padding`, `dilation` can be either an
integer or a sequence of integers.
If they are an integer then the same kernel size / stride / padding /
dilation will be used along every spatial dimension.
`padding` can alternatively be a sequence of 2-element tuples, each
representing the padding to apply before and after each spatial dimension.
`padding` can also be a string `'SAME'`, `'SAME_LOWER'`, or `'VALID'` (see
[jax.lax.conv_general_dilated](https://jax.readthedocs.io/en/latest/_autosummary/jax.lax.conv_general_dilated.html)
for details).
If they are a sequence then the sequence should be of length equal to
`num_spatial_dims`, and specify the value of each property down each spatial
dimension in turn.
In addition, `padding` can be:
- a sequence of 2-element tuples, each representing the padding to apply
before and after each spatial dimension.
- the string `'VALID'`, which is the same as zero padding.
- one of the strings `'SAME'` or `'SAME_LOWER'`. This will apply padding to
produce an output with the same size spatial dimensions as the input.
The padding is split between the two sides equally or almost equally. In
case the padding is an odd number, then the extra padding is added at
the end for `'SAME'` and at the beginning for `'SAME_LOWER'`.
!!! tip
Expand All @@ -432,24 +451,32 @@ def __init__(
switching the values of `in_channels` and `out_channels`, whilst keeping
`kernel_size`, `stride`, `padding`, `dilation`, and `groups` the same.
When `stride > 1` then [`equinox.nn.Conv`][] maps multiple input shapes to the
same output shape. `output_padding` is provided to resolve this ambiguity.
For `'SAME'` or `'SAME_LOWER'` padding, it reduces the calculated input shape
In other cases, it adds a little extra padding to just the bottom/right
edges of the input. See [this discussion](https://github.com/patrick-kidger/equinox/issues/638)
for details.
`padding_mode = 'CIRCULAR'` is only implemented for `output_padding = 0` and
`padding = 'SAME'` or `'SMAE_LOWER'`. Extra connectivity created in 'CIRCULAR'
padding is taken into account. For instance, consider the equivalent
[`equinox.nn.Conv`][] with kernel size 3,\n
Input 1234 --(zero padding)--> 012340 --(conv)--> Output abcd \n
Input 1234 --(circular padding)--> 412341 --(conv)--> Output abcd \n
then a is connected with 1, 2 for zero padding, while connected with 1, 2, 4
for circular padding.
When `stride > 1` then [`equinox.nn.Conv`][] maps multiple input shapes to
the same output shape. `output_padding` is provided to resolve this
ambiguity.
- For `'SAME'` or `'SAME_LOWER'` padding, it reduces the calculated input
shape.
- For other cases, it adds a little extra padding to the bottom or right
edges of the input.
The extra connectivity created in 'CIRCULAR' padding is correctly taken into
account. For instance, consider the equivalent
[`equinox.nn.Conv`][] with kernel size 3. Then:
- `Input 1234 --(zero padding)--> 012340 --(conv)--> Output abcd`
- `Input 1234 --(circular padding)--> 412341 --(conv)--> Output abcd`
so that `a` is connected with `1, 2` for zero padding, while connected with
`1, 2, 4` for circular padding.
See [these animations](https://github.com/vdumoulin/conv_arithmetic/blob/af6f818b0bb396c26da79899554682a8a499101d/README.md#transposed-convolution-animations)
and [this report](https://arxiv.org/abs/1603.07285) for a nice reference.
!!! warning
`padding_mode='CIRCULAR'` is only implemented for `output_padding=0` and
`padding='SAME'` or `'SAME_LOWER'`.
""" # noqa: E501

wkey, bkey = jrandom.split(key, 2)
Expand Down Expand Up @@ -544,7 +571,7 @@ def _circular_pad(
self, x: Array, padding_t: tuple[tuple[int, int], ...]
) -> tuple[Array, tuple[tuple[int, int], ...]]:
stride = np.expand_dims(self.stride, axis=1)
pad_width = np.insert(padding_t // stride, (0, 0), 0, axis=0)
pad_width = np.insert(padding_t // stride, 0, 0, axis=0)
x = jnp.pad(x, pad_width, mode="wrap")
padding_t = tuple((p[0].item(), p[1].item()) for p in padding_t % stride)
return x, padding_t
Expand All @@ -568,11 +595,13 @@ def __call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array:
f"Input to `ConvTranspose` needs to have rank {unbatched_rank},",
f" but input has shape {x.shape}.",
)
x = jnp.expand_dims(x, axis=0)
_, *input_shape = x.shape

padding_t = self._padding_transpose()
if self.padding_mode == "CIRCULAR":
x, padding_t = self._circular_pad(x, padding_t)

x = jnp.expand_dims(x, axis=0)
x = lax.conv_general_dilated(
lhs=x,
rhs=self.weight,
Expand All @@ -582,8 +611,10 @@ def __call__(self, x: Array, *, key: Optional[PRNGKeyArray] = None) -> Array:
rhs_dilation=self.dilation,
feature_group_count=self.groups,
)

x = jnp.squeeze(x, axis=0)

if self.padding in ("SAME", "SAME_LOWER"):
assert x.shape[1:] == input_shape
if self.use_bias:
x = x + self.bias
return x
Expand Down

0 comments on commit 418777c

Please sign in to comment.