diff --git a/scico/flax/autoencoders/autoencoders.py b/scico/flax/autoencoders/autoencoders.py
index 590cc25f..65f5440d 100644
--- a/scico/flax/autoencoders/autoencoders.py
+++ b/scico/flax/autoencoders/autoencoders.py
@@ -34,15 +34,16 @@ class AE(nn.Module):
     Args:
         encoder: Encoder module in Flax.
         decoder: Decoder module in Flax.
+        latent_dim: Latent dimension of encoder.
+        channels: Number of channels of signal to decode.
+        dtype: Output dtype. Default: :attr:`~numpy.float32`.
     """

     encoder: Callable
     decoder: Callable
-
-    def setup(self):
-        """Setup of encoder and decoder modules for autoencoder (AE)."""
-        nn.share_scope(self, self.encoder)
-        nn.share_scope(self, self.decoder)
+    latent_dim: int
+    channels: int
+    dtype: Any = jnp.float32

     def encode(self, x: ArrayLike) -> ArrayLike:
         """Apply encoder module.
@@ -110,8 +111,7 @@ def __call__(self, x: ArrayLike) -> ArrayLike:
         Returns:
             The encoded array.
         """
-        x = x.reshape((x.shape[0], -1))
-        x = MLP(self.encoder_widths, self.activation_fn, activate_final=True)(x)
+        x = MLP(self.encoder_widths, self.activation_fn, activate_final=True, flatten_first=True)(x)
         x = nn.Dense(self.latent_dim)(x)
         return x

@@ -159,16 +159,24 @@ def __call__(self, x: ArrayLike) -> ArrayLike:
         return x


-class DenseAE(AE):
-    """Definition of autoencoder network using multi layer perceptron
-    (MLP), i.e. dense layers.
+def DenseAE(
+    out_shape: Tuple[int],
+    channels: int,
+    encoder_widths: Tuple[int],
+    latent_dim: int,
+    decoder_widths: Tuple[int],
+    activation_fn: Callable = nn.leaky_relu,
+):
+    """Function to construct autoencoder network using multi layer
+    perceptron (MLP), i.e. dense layers.

     Output is reshaped to given output shape via a properly sized
     layer added automatically to the tuple of the decoder widths.

     Args:
-        out_shape: Tuple (height, width, channels) of signal to decode
+        out_shape: Tuple (height, width) of signal to decode
             (if reshape requested).
+        channels: Number of channels of signal to decode.
         encoder_widths: List with number of neurons per layer in the
             MLP encoder.
         latent_dim: Latent dimension of encoder.
@@ -176,40 +184,23 @@ class DenseAE(AE):
             MLP decoder.
         activation_fn: Flax function defining the activation
             operation to apply after each layer (except output layer).
-        dtype: Output dtype. Default: :attr:`~numpy.float32`.
-    """
-
-    out_shape: Tuple[int]
-    encoder_widths: Tuple[int]
-    latent_dim: int
-    decoder_widths: Tuple[int]
-    activation_fn: Callable = nn.leaky_relu
-    dtype: Any = jnp.float32
-
-    @nn.compact
-    def __call__(self, x: ArrayLike) -> Tuple[ArrayLike, ArrayLike]:
-        """Apply sequence of encoder and decoder modules.
-
-        Args:
-            x: The array to be autoencoded.
-
-        Returns:
-            The output of the autoencoder module and the encoded
-            representation.
-        """
-        encoder = DenseEncoder(
-            self.encoder_widths,
-            self.latent_dim,
-            self.activation_fn,
-        )
+
+    Returns:
+        Autoencoder model with the specified architecture.
+    """
+    encoder = DenseEncoder(
+        encoder_widths,
+        latent_dim,
+        activation_fn,
+    )

-        decoder = DenseDecoder(
-            self.out_shape,
-            self.decoder_widths,
-            self.activation_fn,
-            reshape_final=True,
-        )
-        return AE(encoder, decoder)(x)
+    decoder = DenseDecoder(
+        out_shape + (channels,),
+        decoder_widths,
+        activation_fn,
+        reshape_final=True,
+    )
+    return AE(encoder, decoder, latent_dim, channels)


 class ConvEncoder(nn.Module):
@@ -308,8 +299,20 @@ def __call__(self, x: ArrayLike) -> ArrayLike:
         return x


-class ConvAE(AE):
-    """Definition of autoencoder network using convolutional layers.
+def ConvAE(
+    out_shape: Tuple[int],
+    channels: int,
+    encoder_filters: Sequence[int],
+    latent_dim: int,
+    decoder_filters: Sequence[int],
+    encoder_kernel_size: Tuple[int, int] = (3, 3),
+    encoder_strides: Tuple[int, int] = (1, 1),
+    encoder_activation_fn: Callable = nn.leaky_relu,
+    decoder_kernel_size: Tuple[int, int] = (3, 3),
+    decoder_strides: Tuple[int, int] = (1, 1),
+    decoder_activation_fn: Callable = nn.leaky_relu,
+):
+    """Function to construct autoencoder network using convolutional layers.

     Args:
         out_shape: Tuple (height, width) of signal to decode.
@@ -333,47 +336,24 @@ class ConvAE(AE):
         decoder_activation_fn: Flax function defining the activation
             operation to apply after each layer in decoder (except
            output layer).
-        dtype: Output dtype. Default: :attr:`~numpy.float32`.
-    """
-
-    out_shape: Tuple[int]
-    channels: int
-    encoder_filters: Sequence[int]
-    latent_dim: int
-    decoder_filters: Sequence[int]
-    encoder_kernel_size: Tuple[int, int] = (3, 3)
-    encoder_strides: Tuple[int, int] = (1, 1)
-    encoder_activation_fn: Callable = nn.leaky_relu
-    decoder_kernel_size: Tuple[int, int] = (3, 3)
-    decoder_strides: Tuple[int, int] = (1, 1)
-    decoder_activation_fn: Callable = nn.leaky_relu
-    dtype: Any = jnp.float32
-
-    @nn.compact
-    def __call__(self, x: ArrayLike) -> Tuple[ArrayLike, ArrayLike]:
-        """Apply sequence of encoder and decoder modules.
-
-        Args:
-            x: The array to be autoencoded.
-
-        Returns:
-            The output of the autoencoder module and the encoded
-            representation.
-        """
-        encoder = ConvEncoder(
-            self.encoder_filters,
-            self.latent_dim,
-            self.encoder_kernel_size,
-            self.encoder_strides,
-            activation_fn=self.encoder_activation_fn,
-        )
-
-        decoder = ConvDecoder(
-            self.out_shape,
-            self.channels,
-            self.decoder_filters,
-            self.decoder_kernel_size,
-            self.decoder_strides,
-            activation_fn=self.decoder_activation_fn,
-        )
-        return AE(encoder, decoder)(x)
+
+    Returns:
+        Autoencoder model with the specified architecture.
+    """
+    encoder = ConvEncoder(
+        encoder_filters,
+        latent_dim,
+        encoder_kernel_size,
+        encoder_strides,
+        activation_fn=encoder_activation_fn,
+    )
+
+    decoder = ConvDecoder(
+        out_shape,
+        channels,
+        decoder_filters,
+        decoder_kernel_size,
+        decoder_strides,
+        activation_fn=decoder_activation_fn,
+    )
+    return AE(encoder, decoder, latent_dim, channels)
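With this change, DenseAE and ConvAE become plain constructor functions that assemble and return an AE module instead of Flax Module subclasses. A minimal usage sketch follows; it assumes the rest of scico.flax.autoencoders is unchanged and that AE's default __call__ still runs the full encode/decode pipeline. The shapes, filter counts, and latent dimension below are illustrative, not taken from the patch.

    import jax
    import jax.numpy as jnp

    from scico.flax.autoencoders.autoencoders import ConvAE

    # Construct an autoencoder for 32x32 single-channel signals with a
    # 16-dimensional latent space (example values, not from the patch).
    model = ConvAE(
        out_shape=(32, 32),
        channels=1,
        encoder_filters=(16, 32),
        latent_dim=16,
        decoder_filters=(32, 16),
    )

    x = jnp.zeros((4, 32, 32, 1))  # batch of 4 signals
    variables = model.init(jax.random.PRNGKey(0), x)  # standard Flax initialization
    y = model.apply(variables, x)  # full autoencoding pass
    z = model.apply(variables, x, method=model.encode)  # latent codes via the encode method kept on AE

Because AE now carries latent_dim and channels as fields, both constructor functions forward them explicitly when building the returned module, as seen in the final lines of each hunk above.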