Skip to content

Commit

Permalink
Restructure of AE models
Browse files Browse the repository at this point in the history
  • Loading branch information
crstngc committed Jan 26, 2025
1 parent e2973d1 commit 0003bbd
Showing 1 changed file with 68 additions and 88 deletions.
156 changes: 68 additions & 88 deletions scico/flax/autoencoders/autoencoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,16 @@ class AE(nn.Module):
Args:
encoder: Encoder module in Flax.
decoder: Decoder module in Flax.
latent_dim: Latent dimension of encoder.
channels: Number of channels of signal to decode.
dtype: Output dtype. Default: :attr:`~numpy.float32`.
"""

encoder: Callable
decoder: Callable

def setup(self):
"""Setup of encoder and decoder modules for autoencoder (AE)."""
nn.share_scope(self, self.encoder)
nn.share_scope(self, self.decoder)
latent_dim: int
channels: int
dtype: Any = jnp.float32

def encode(self, x: ArrayLike) -> ArrayLike:
"""Apply encoder module.
Expand Down Expand Up @@ -110,8 +111,7 @@ def __call__(self, x: ArrayLike) -> ArrayLike:
Returns:
The encoded array.
"""
x = x.reshape((x.shape[0], -1))
x = MLP(self.encoder_widths, self.activation_fn, activate_final=True)(x)
x = MLP(self.encoder_widths, self.activation_fn, activate_final=True, flatten_first=True)(x)
x = nn.Dense(self.latent_dim)(x)
return x

Expand Down Expand Up @@ -159,57 +159,48 @@ def __call__(self, x: ArrayLike) -> ArrayLike:
return x


class DenseAE(AE):
"""Definition of autoencoder network using multi layer perceptron
(MLP), i.e. dense layers.
def DenseAE(
out_shape: Tuple[int],
channels: int,
encoder_widths: Tuple[int],
latent_dim: int,
decoder_widths: Tuple[int],
activation_fn: Callable = nn.leaky_relu,
):
"""Function to construct autoencoder network using multi layer
perceptron (MLP), i.e. dense layers.
The output is reshaped to the given output shape via a properly sized
layer that is added automatically to the tuple of decoder widths.
Args:
out_shape: Tuple (height, width, channels) of signal to decode
out_shape: Tuple (height, width) of signal to decode
(if reshape requested).
channels: Number of channels of signal to decode.
encoder_widths: Tuple with the number of neurons per layer in the
MLP encoder.
latent_dim: Latent dimension of encoder.
decoder_widths: Tuple with the number of neurons per layer in the
MLP decoder.
activation_fn: Flax function defining the activation operation
to apply after each layer (except output layer).
dtype: Output dtype. Default: :attr:`~numpy.float32`.
"""
out_shape: Tuple[int]
encoder_widths: Tuple[int]
latent_dim: int
decoder_widths: Tuple[int]
activation_fn: Callable = nn.leaky_relu
dtype: Any = jnp.float32

@nn.compact
def __call__(self, x: ArrayLike) -> Tuple[ArrayLike, ArrayLike]:
"""Apply sequence of encoder and decoder modules.
Args:
x: The array to be autoencoded.
Returns:
The output of the autoencoder module and the encoded
representation.
"""
encoder = DenseEncoder(
self.encoder_widths,
self.latent_dim,
self.activation_fn,
)
Returns:
Autoencoder model with the specified architecture.
"""
encoder = DenseEncoder(
encoder_widths,
latent_dim,
activation_fn,
)

decoder = DenseDecoder(
self.out_shape,
self.decoder_widths,
self.activation_fn,
reshape_final=True,
)
return AE(encoder, decoder)(x)
decoder = DenseDecoder(
out_shape + (channels,),
decoder_widths,
activation_fn,
reshape_final=True,
)
return AE(encoder, decoder, latent_dim, channels)


class ConvEncoder(nn.Module):
Expand Down Expand Up @@ -308,8 +299,20 @@ def __call__(self, x: ArrayLike) -> ArrayLike:
return x


class ConvAE(AE):
"""Definition of autoencoder network using convolutional layers.
def ConvAE(
out_shape: Tuple[int],
channels: int,
encoder_filters: Sequence[int],
latent_dim: int,
decoder_filters: Sequence[int],
encoder_kernel_size: Tuple[int, int] = (3, 3),
encoder_strides: Tuple[int, int] = (1, 1),
encoder_activation_fn: Callable = nn.leaky_relu,
decoder_kernel_size: Tuple[int, int] = (3, 3),
decoder_strides: Tuple[int, int] = (1, 1),
decoder_activation_fn: Callable = nn.leaky_relu,
):
"""Function to construct autoencoder network using convolutional layers.
Args:
out_shape: Tuple (height, width) of signal to decode.
Expand All @@ -333,47 +336,24 @@ class ConvAE(AE):
decoder_activation_fn: Flax function defining the activation
operation to apply after each layer in decoder (except
output layer).
dtype: Output dtype. Default: :attr:`~numpy.float32`.
"""
out_shape: Tuple[int]
channels: int
encoder_filters: Sequence[int]
latent_dim: int
decoder_filters: Sequence[int]
encoder_kernel_size: Tuple[int, int] = (3, 3)
encoder_strides: Tuple[int, int] = (1, 1)
encoder_activation_fn: Callable = nn.leaky_relu
decoder_kernel_size: Tuple[int, int] = (3, 3)
decoder_strides: Tuple[int, int] = (1, 1)
decoder_activation_fn: Callable = nn.leaky_relu
dtype: Any = jnp.float32

@nn.compact
def __call__(self, x: ArrayLike) -> Tuple[ArrayLike, ArrayLike]:
"""Apply sequence of encoder and decoder modules.
Args:
x: The array to be autoencoded.
Returns:
The output of the autoencoder module and the encoded
representation.
"""
encoder = ConvEncoder(
self.encoder_filters,
self.latent_dim,
self.encoder_kernel_size,
self.encoder_strides,
activation_fn=self.encoder_activation_fn,
)

decoder = ConvDecoder(
self.out_shape,
self.channels,
self.decoder_filters,
self.decoder_kernel_size,
self.decoder_strides,
activation_fn=self.decoder_activation_fn,
)
return AE(encoder, decoder)(x)
Returns:
Autoencoder model with the specified architecture.
"""
encoder = ConvEncoder(
encoder_filters,
latent_dim,
encoder_kernel_size,
encoder_strides,
activation_fn=encoder_activation_fn,
)

decoder = ConvDecoder(
out_shape,
channels,
decoder_filters,
decoder_kernel_size,
decoder_strides,
activation_fn=decoder_activation_fn,
)
return AE(encoder, decoder, latent_dim, channels)

0 comments on commit 0003bbd

Please sign in to comment.