
[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/digits_diffusion_model.ipynb)

This tutorial guides you through developing and training a simple diffusion model for image generation, built on the [U-Net architecture](https://en.wikipedia.org/wiki/U-Net), using JAX, [Flax NNX](http://flax.readthedocs.io) and [Optax](http://optax.readthedocs.io). The previous tutorial, [Variational autoencoder (VAE) and debugging in JAX](https://jax-ai-stack.readthedocs.io/en/latest/digits_vae.html), showed how to train a simpler generative model, the variational autoencoder (VAE). Here, you will learn how to:

- Load and preprocess the dataset
- Define the diffusion model with Flax
- Create the loss and training functions
- Train the model with Google Colab’s Cloud TPU v2
- Visualize and track the model’s progress

If you are new to JAX for AI, check out the [introductory tutorial](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html), which covers neural network building with Flax, Optax and JAX.

+++ {"id": "gwaaMmjXt7n7"}

## Setup

JAX installation is covered in [this guide](https://jax.readthedocs.io/en/latest/installation.html) on the JAX documentation site.

Start by importing JAX, JAX NumPy, Flax NNX, Optax, matplotlib and scikit-learn:

```{code-cell}
:id: dVVACvmDuDCM
from typing import Callable

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import optax
from flax import nnx
from sklearn.datasets import load_digits
from sklearn.model_selection import train_test_split
```

+++ {"id": "tQ5KGMyrYG2H"}

Check the available JAX devices ([`jax.Device`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Device.html)) with [`jax.devices()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.devices.html).

**Note:** If you are using [Google Colab](https://colab.research.google.com/), select the free Google Cloud TPU v2 as the hardware accelerator. The output of the cell below will then show a list of 8 (eight) devices:

```{code-cell}
---
colab:
id: ldmtemzPBO5z
outputId: d21720a2-65cd-4a5c-ef86-3a0912e36c34
---
# Check the available JAX devices.
jax.devices()
```

## Loading and preprocessing the data

We'll use the small, self-contained [scikit-learn `digits` dataset](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_digits.html) for ease of experimentation. This tutorial focuses on generating only the digit '1' (one):

```{code-cell}
---
colab:
id: jNizSH6uuXY4
outputId: 112723a1-fd36-46b2-946d-6d789f5a33ed
---
# Load and preprocess the `digits` dataset.
digits = load_digits()

# Filter for digit '1' (one) images.
images = digits.images[digits.target == 1]

# Normalize pixel values into floating-point arrays in the `[0, 1]` interval.
images = images / 16.0

# Convert to `jax.Array`s.
images = jnp.asarray(images)

# Reshape to `(num_images, height, width, channels)` for input to convolutional layers.
images = images.reshape(-1, 8, 8, 1)

# Split the dataset into training and test sets.
images_train, images_test = train_test_split(images, test_size=0.05, random_state=42)
print(f"Training set size: {images_train.shape[0]}")
print(f"Test set size: {images_test.shape[0]}")

# Visualize sample images.
fig, axes = plt.subplots(3, 3, figsize=(3, 3))
for i, ax in enumerate(axes.flat):
    if i < len(images_train):
        ax.imshow(images_train[i, ..., 0], cmap='gray')
    ax.axis('off')
plt.show()
```

In this section, we’ll develop various parts of the [diffusion model](https://en.wikipedia.org/wiki/Diffusion_model).

### The U-Net architecture

For this example, we’ll use the [U-Net architecture](https://en.wikipedia.org/wiki/U-Net), a convolutional neural network, as the backbone of the diffusion model. The U-Net consists of the following (a small shape check follows the list):

- An [encoder](https://en.wikipedia.org/wiki/Autoencoder) path that [downsamples](https://en.wikipedia.org/wiki/Downsampling_(signal_processing)) the input image, extracting features.
- A bridge with a (self-)[attention mechanism](https://en.wikipedia.org/wiki/Attention_(machine_learning)) that connects the encoder with the decoder.
- A [decoder](https://en.wikipedia.org/wiki/Autoencoder) path that [upsamples](https://en.wikipedia.org/wiki/Upsampling) the feature representations learned by the encoder, reconstructing the output image.
- [Skip connections](https://en.wikipedia.org/wiki/Residual_neural_network#Residual_connection) between the encoder and the decoder.
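
To build intuition for how the encoder and decoder change resolution on these 8×8 inputs, here is a small, self-contained shape check (an illustration added for this walkthrough, not part of the model code):

```{code-cell}
# Max pooling halves the spatial dimensions; nearest-neighbor resize restores them.
x = jnp.ones((1, 8, 8, 1))  # (batch, height, width, channels)
down = nnx.max_pool(x, window_shape=(2, 2), strides=(2, 2), padding='SAME')
up = jax.image.resize(down, (1, 8, 8, 1), method='nearest')
print(down.shape, up.shape)  # (1, 4, 4, 1) (1, 8, 8, 1)
```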

Let's define a class called `UNet` by subclassing [`flax.nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module) and using, among other things, [`flax.nnx.Linear`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/linear.html#flax.nnx.Linear) (linear, or dense, layers for the time embedding, time projection and self-attention layers), [`flax.nnx.LayerNorm`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/normalization.html#flax.nnx.LayerNorm) (layer normalization), and [`flax.nnx.Conv`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/linear.html#flax.nnx.Conv) (convolution layers, including the final output layer).

```{code-cell}
class UNet(nnx.Module):
    def __init__(self,
                 features: int,
                 # ...
                 *,
                 rngs: nnx.Rngs):
        """
        Initialize the U-Net architecture with time embedding.
        """
        self.features = features
        # ...

    def _create_attention_block(self, channels: int, rngs: nnx.Rngs) -> Callable:
        """Creates a self-attention block with learned query, key, value projections."""
        query_proj = nnx.Linear(in_features=channels, out_features=channels, rngs=rngs)
        key_proj = nnx.Linear(in_features=channels, out_features=channels, rngs=rngs)
        value_proj = nnx.Linear(in_features=channels, out_features=channels, rngs=rngs)
        # ...

    def _create_residual_block(self,
                               in_channels: int,
                               out_channels: int,
                               rngs: nnx.Rngs) -> Callable:
        """Creates a residual block with two convolutions and normalization."""
        conv1 = nnx.Conv(in_features=in_channels,
                         out_features=out_channels,
                         kernel_size=(3, 3),
                         # ...
                         rngs=rngs)
        norm2 = nnx.LayerNorm(out_channels, rngs=rngs)

        # Projection shortcut if dimensions change:
        shortcut = nnx.Conv(in_features=in_channels,
                            out_features=out_channels,
                            kernel_size=(1, 1),
                            # ...
                            rngs=rngs)
        # ...
        return forward

    def _pos_encoding(self, t: jax.Array, dim: int) -> jax.Array:
        """Sinusoidal positional encoding for time embedding."""
        half_dim = dim // 2
        emb = jnp.log(10000.0) / (half_dim - 1)
        emb = jnp.exp(jnp.arange(half_dim) * -emb)
        # ...
        return emb

    def _downsample(self, x: jax.Array) -> jax.Array:
        """Max pooling for downsampling."""
        return nnx.max_pool(x, window_shape=(2, 2), strides=(2, 2), padding='SAME')

    def _upsample(self, x: jax.Array, target_size: int) -> jax.Array:
        """Nearest-neighbor upsampling."""
        return jax.image.resize(x,
                                (x.shape[0], target_size, target_size, x.shape[3]),
                                method='nearest')

    def __call__(self, x: jax.Array, t: jax.Array) -> jax.Array:
        """The forward pass through the U-Net."""
        # Time embedding and projection:
        t_emb = self._pos_encoding(t, 128)
        t_emb = self.time_mlp_1(t_emb)
        t_emb = nnx.gelu(t_emb)
        t_emb = self.time_mlp_2(t_emb)

        # Project time embeddings for each scale:
        t_emb1 = self.time_proj1(t_emb)[:, None, None, :]
        t_emb2 = self.time_proj2(t_emb)[:, None, None, :]
        t_emb3 = self.time_proj3(t_emb)[:, None, None, :]
        t_emb4 = self.time_proj4(t_emb)[:, None, None, :]

        # Encoder path with time injection:
        d1 = self.down_conv1(x)
        t_emb1 = jnp.broadcast_to(t_emb1, d1.shape)
        d1 = d1 + t_emb1
        # ...
        t_emb4 = jnp.broadcast_to(t_emb4, d4.shape)
        d4 = d4 + t_emb4

        # The bridge:
        b = self._downsample(d4)
        b = self.bridge_down(b)
        b = self.bridge_attention(b)
        b = self.bridge_up(b)

        # The decoder path with skip connections:
        u4 = self.up_conv4(jnp.concatenate([self._upsample(b, d4.shape[1]), d4], axis=-1))
        u3 = self.up_conv3(jnp.concatenate([self._upsample(u4, d3.shape[1]), d3], axis=-1))
        u2 = self.up_conv2(jnp.concatenate([self._upsample(u3, d2.shape[1]), d2], axis=-1))
        u1 = self.up_conv1(jnp.concatenate([self._upsample(u2, d1.shape[1]), d1], axis=-1))

        # Final layers:
        x = self.final_norm(u1)
        x = nnx.gelu(x)
        return self.final_conv(x)
```
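
As a side note, here is a standalone sketch of a standard sinusoidal time embedding of the kind `_pos_encoding` computes. The sin/cos combination step is the conventional formulation and an assumption here, as is the function name `sinusoidal_encoding`:

```{code-cell}
def sinusoidal_encoding(t: jax.Array, dim: int) -> jax.Array:
    # Geometrically spaced frequencies, as in the class method above.
    half_dim = dim // 2
    emb = jnp.log(10000.0) / (half_dim - 1)
    emb = jnp.exp(jnp.arange(half_dim) * -emb)
    # Outer product of timesteps and frequencies, then sin/cos features (assumed).
    emb = t[:, None] * emb[None, :]
    return jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=-1)

print(sinusoidal_encoding(jnp.arange(4.0), 128).shape)  # (4, 128)
```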

+++ {"id": "XJaqiL07HD9D"}

### Defining the diffusion model

The final diffusion model will rely on the `UNet` class, and will include all the layers needed to perform the diffusion operations. The `DiffusionModel` class implements the diffusion process with:
- Forward diffusion (adding noise), as sketched below
