Update JAX AI Stack Diffusion tutorial
8bitmp3 committed Jan 21, 2025
1 parent e1f7e0a commit ab4e3eb
Showing 1 changed file with 84 additions and 42 deletions: docs/source/digits_diffusion_model.md
@@ -323,7 +323,7 @@ class UNet(nnx.Module):

### Defining the diffusion model

Here, we'll define the diffusion model that encapsulates the previously defined components, such as the `UNet` class, and includes all the layers needed to perform the diffusion operations. The `DiffusionModel` class implements the diffusion process (summarized mathematically below) with:
- Forward diffusion (adding noise)
- Reverse diffusion (denoising)
- Custom noise scheduling
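
For reference, the noising and denoising steps below follow the standard DDPM formulation. With a noise schedule $\beta_t$, $\alpha_t = 1 - \beta_t$, and cumulative product $\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s$, the forward process mixes a clean image $x_0$ with Gaussian noise $\epsilon \sim \mathcal{N}(0, I)$:

$$
x_t = \sqrt{\bar{\alpha}_t}\, x_0 + \sqrt{1 - \bar{\alpha}_t}\, \epsilon
$$

and a typical reverse step (the exact update used by `reverse()` is elided in this diff) uses the U-Net's noise prediction $\epsilon_\theta(x_t, t)$ to move from $x_t$ back towards $x_0$:

$$
x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \left( x_t - \frac{\beta_t}{\sqrt{1 - \bar{\alpha}_t}}\, \epsilon_\theta(x_t, t) \right) + \sqrt{\beta_t}\, z, \qquad z \sim \mathcal{N}(0, I).
$$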
@@ -337,11 +337,18 @@ class DiffusionModel:
num_steps: int,
beta_start: float,
beta_end: float):
"""Initialize diffusion process parameters"""
"""Initialize diffusion process parameters.
Args:
model (UNet): The U-Net model for image generation.
num_steps (int): The number of diffusion steps in the process.
beta_start: The starting value for beta, controlling the noise level.
"""
self.model = model
self.num_steps = num_steps
# Noise schedule parameters:
self.beta = self._cosine_beta_schedule(num_steps, beta_start, beta_end)
self.alpha = 1 - self.beta
self.alpha_cumulative = jnp.cumprod(self.alpha)
@@ -356,7 +363,7 @@ class DiffusionModel:
num_steps: int,
beta_start: float,
beta_end: float) -> jax.Array:
"""Cosine schedule for noise levels"""
"""Cosine schedule for noise levels."""
steps = jnp.linspace(0, num_steps, num_steps + 1)
x = steps / num_steps
alphas = jnp.cos(((x + 0.008) / 1.008) * jnp.pi * 0.5) ** 2
@@ -369,7 +376,12 @@ class DiffusionModel:
x: jax.Array,
t: jax.Array,
key: jax.Array) -> Tuple[jax.Array, jax.Array]:
"""Forward diffusion process - adds noise according to schedule"""
"""Forward diffusion process - adds noise according to schedule.
Args:
x (jax.Array): The input image.
t (jax.Array): The timestep(s) at which the noise is added.
"""
noise = jax.random.normal(key, x.shape)
noisy_x = (
jnp.sqrt(self.alpha_cumulative[t])[:, None, None, None] * x +
@@ -378,7 +390,11 @@ class DiffusionModel:
return noisy_x, noise
def reverse(self, x: jax.Array, key: jax.Array) -> jax.Array:
"""Reverse diffusion process - removes noise gradually"""
"""Reverse diffusion process - removes noise gradually.
Args:
x (jax.Array): The timestep(s).
"""
x_t = x
for t in reversed(range(self.num_steps)):
t_batch = jnp.array([t] * x.shape[0])
@@ -396,11 +412,12 @@ class DiffusionModel:

+++ {"id": "wKnYRqMAI06f"}

## Defining the loss function and training step

In this section, we’ll define the components for training our diffusion model, including:

- The loss function (`loss_fn()`), which incorporates [SNR weighting](https://en.wikipedia.org/wiki/Signal-to-noise_ratio) and a gradient penalty (both summarized below); and
- The training step (`train_step()`) with [gradient clipping](https://arxiv.org/pdf/1905.11881) for stability.
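
As a quick reference for the weighting used below: with $\bar{\alpha}_t$ the cumulative product of the $\alpha$ values, the code forms $r_t = \sqrt{\bar{\alpha}_t} / \sqrt{1 - \bar{\alpha}_t}$ (the square root of the signal-to-noise ratio $\bar{\alpha}_t / (1 - \bar{\alpha}_t)$) and weights the squared noise-prediction error by

$$
w_t = \frac{r_t}{1 + r_t},
$$

so heavily noised timesteps contribute less to the loss. A small gradient penalty (0.02 times the mean squared gradient of the model's mean output with respect to the noisy input) is then added to encourage smoother predictions.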

```{code-cell}
:id: rq9Ic8WYCCJI
@@ -411,29 +428,41 @@ def loss_fn(model: UNet,
noise: jax.Array,
sqrt_alpha_cumulative: jax.Array,
sqrt_one_minus_alpha_cumulative: jax.Array) -> jax.Array:
"""Loss function with SNR weighting and adaptive noise scaling"""
# Compute noisy images
"""The loss function with SNR weighting and adaptive noise scaling.
Args:
model(UNet): The U-Net model for image generation.
images (jax.Array): Images for training.
t (jax.Array): The timestep(s).
noise (jax.Array): The noise added to the images.
sqrt_alpha_cumulative (jax.Array): Square root of cumulative alpha values.
Returns:
The vallue of the loss function (jax.Array).
"""
# Compute the noisy images.
noisy_images = (
sqrt_alpha_cumulative[t][:, None, None, None] * images +
sqrt_one_minus_alpha_cumulative[t][:, None, None, None] * noise
)
predicted = model(noisy_images, t)
# Compute the SNR-weighted loss.
snr = (sqrt_alpha_cumulative[t] / sqrt_one_minus_alpha_cumulative[t])[:, None, None, None]
loss_weights = snr / (1 + snr)
squared_error = (noise - predicted) ** 2
main_loss = jnp.mean(loss_weights * squared_error)
# Gradient penalty with a reduced coefficient.
grad = jax.grad(lambda x: model(x, t).mean())(noisy_images)
grad_penalty = 0.02 * (jnp.square(grad).mean())
return main_loss + grad_penalty
# Flax NNX JIT-compilation for performance.
@nnx.jit
def train_step(model: UNet,
optimizer: nnx.Optimizer,
@@ -442,13 +471,25 @@ def train_step(model: UNet,
noise: jax.Array,
sqrt_alpha_cumulative: jax.Array,
sqrt_one_minus_alpha_cumulative: jax.Array) -> jax.Array:
"""Single training step with gradient clipping"""
"""Performs a single training step with gradient clipping.
Args:
model (UNet): The UNet model that's being trained.
optimizer (flax.nnx.Optimizer): The optimizer for parameter updates.
images (jax.Array): The training images.
t (jax.Array): The timestep(s).
noise (jax.Array): The noise added to the images during training.
sqrt_alpha_cumulative (jax.Array): Square root of cumulative alpha values.
Returns:
jax.Array: The loss value after a single training step.
"""
loss, grads = nnx.value_and_grad(loss_fn)(
model, images, t, noise,
sqrt_alpha_cumulative, sqrt_one_minus_alpha_cumulative
)
# Conservative gradient clipping.
clip_threshold = 0.3
grads = jax.tree_util.tree_map(
lambda g: jnp.clip(g, -clip_threshold, clip_threshold),
@@ -466,26 +507,27 @@ def train_step(model: UNet,
Next, we’ll define the model configuration and the training loop implementation.

We need to set up:

- Model hyperparameters
- An optimizer with the learning rate schedule (a rough sketch of the warmup-plus-decay shape follows this list)
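
The exact schedule configuration appears a couple of cells below (parts of it are elided in this diff). As a rough, self-contained sketch of the warmup-plus-decay pattern it follows, `optax.join_schedules` can stitch a linear warmup onto a cosine decay; the values here are illustrative assumptions, not necessarily the ones used in the tutorial:

```python
import optax

# Illustrative values only; the tutorial sets its own warmup_steps,
# total_steps, and learning_rate in the cells below.
warmup_steps = 1000
total_steps = 5000
peak_lr = 1e-4

# Linear warmup from zero up to the peak learning rate.
warmup_fn = optax.linear_schedule(
    init_value=0.0, end_value=peak_lr, transition_steps=warmup_steps)

# Cosine decay from the peak learning rate over the remaining steps.
decay_fn = optax.cosine_decay_schedule(
    init_value=peak_lr, decay_steps=total_steps - warmup_steps)

# Switch from warmup to decay at step `warmup_steps`.
schedule_fn = optax.join_schedules(
    schedules=[warmup_fn, decay_fn], boundaries=[warmup_steps])

print(schedule_fn(0), schedule_fn(warmup_steps), schedule_fn(total_steps - 1))
```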

```{code-cell}
:id: w4CwR-6ivIjS
# Set the model and training hyperparameters.
key = jax.random.PRNGKey(42) # PRNG seed for reproducibility.
in_channels = 1
out_channels = 1
features = 64 # Number of features in the U-Net.
num_steps = 1000
num_epochs = 5000
batch_size = 64
learning_rate = 1e-4
beta_start = 1e-4 # The starting value for beta (noise level schedule).
beta_end = 0.02 # The end value for beta (noise level schedule).
# Initialize model components.
key, subkey = jax.random.split(key) # Split the JAX PRNG key for initialization.
model = UNet(in_channels, out_channels, features, rngs=nnx.Rngs(default=subkey))
diffusion = DiffusionModel(
@@ -503,7 +545,7 @@ colab:
id: yLjb_t026uy3
outputId: 2cda0980-ac98-4fd7-ee3a-02728a64f1f7
---
# Learning rate schedule configuration.
warmup_steps = 1000
total_steps = num_epochs
@@ -523,7 +565,7 @@ schedule_fn = optax.join_schedules(
boundaries=[warmup_steps]
)
# Optimizer configuration (AdamW).
optimizer = nnx.Optimizer(model, optax.chain(
optax.clip_by_global_norm(0.5),
optax.adamw(
Expand All @@ -535,7 +577,7 @@ optimizer = nnx.Optimizer(model, optax.chain(
)
))
# Model initialization with dummy input.
dummy_input = jnp.ones((1, 8, 8, 1))
dummy_t = jnp.zeros((1,), dtype=jnp.int32)
output = model(dummy_input, dummy_t)
@@ -550,6 +592,7 @@ print("\nModel initialized successfully")
### Implementing the training loop

Finally, we need to implement the main training loop for the diffusion model with:

- The progressive timestep sampling strategy
- [Exponential moving average (EMA)](https://en.wikipedia.org/wiki/Moving_average#Exponential_moving_average) loss tracking (see the brief note after this list)
- Adaptive noise generation
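
The EMA-smoothed loss mentioned above (and updated inside the loop below) follows the usual recurrence

$$
\bar{L} \leftarrow \beta \, \bar{L} + (1 - \beta)\, L_{\text{step}},
$$

with $\beta = 0.99$, which damps step-to-step noise in the reported loss without affecting the optimization itself.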
@@ -567,12 +610,12 @@ moving_avg_loss: Optional[float] = None # EMA of the loss value
beta: float = 0.99 # EMA decay factor for loss smoothing
for epoch in range(num_epochs + 1):
# Split the JAX PRNG key for independent random operations.
key, subkey1 = jax.random.split(key)
key, subkey2 = jax.random.split(key)
# Progressive timestep sampling - weights early steps more heavily as training progresses.
# This helps the model focus on fine details in later epochs while maintaining stability.
progress = epoch / num_epochs
t_weights = jnp.linspace(1.0, 0.1 * (1.0 - progress), num_steps)
t = jax.random.choice(
@@ -582,24 +625,24 @@ for epoch in range(num_epochs + 1):
p=t_weights/t_weights.sum()
)
# Generate the Gaussian noise for the current batch.
noise = jax.random.normal(subkey2, images_train.shape)
# Execute the training step with noise prediction and parameter updates.
loss = train_step(
model, optimizer, images_train, t, noise,
diffusion.sqrt_alpha_cumulative, diffusion.sqrt_one_minus_alpha_cumulative
)
# Update the exponential moving average of the loss for smoother tracking.
if moving_avg_loss is None:
moving_avg_loss = loss
else:
moving_avg_loss = beta * moving_avg_loss + (1 - beta) * loss
losses.append(moving_avg_loss)
# Log the training progress at regular intervals.
if epoch % 100 == 0:
print(f"Epoch {epoch}, Loss: {moving_avg_loss:.4f}")
@@ -656,7 +699,7 @@ def reverse_diffusion_batch(model: UNet,
x: jax.Array,
key: jax.Array,
num_steps: int) -> jax.Array:
"""Efficient batched reverse diffusion using scan"""
"""Efficient batched reverse diffusion using `jax.lax.scan`."""
beta = jnp.linspace(1e-4, 0.02, num_steps)
alpha = 1 - beta
alpha_cumulative = jnp.cumprod(alpha)
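# The remainder of this function is elided in this diff. As a rough sketch of
# the `jax.lax.scan` pattern it relies on (an assumption, not copied from the
# tutorial), the carry is the current noisy batch and each scan step removes
# one timestep's worth of noise:
#
#   def denoise_step(x_t, t):
#       t_batch = jnp.full((x_t.shape[0],), t)
#       predicted_noise = model(x_t, t_batch)
#       x_prev = x_t - beta[t] / jnp.sqrt(1 - alpha_cumulative[t]) * predicted_noise
#       x_prev = x_prev / jnp.sqrt(alpha[t])
#       return x_prev, None
#
#   x_denoised, _ = jax.lax.scan(denoise_step, x, jnp.arange(num_steps - 1, -1, -1))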
@@ -726,15 +769,15 @@ def compute_forward_sequence(model: UNet,
image: jax.Array,
key: jax.Array,
num_vis_steps: int) -> jax.Array:
"""Compute forward diffusion sequence efficiently."""
# Prepare image sequence and noise parameters
"""Computes the forward diffusion sequence efficiently."""
# Prepare image sequence and noise parameters.
image_repeated = jnp.repeat(image[None], num_vis_steps, axis=0)
timesteps = jnp.linspace(0, 999, num_vis_steps).astype(jnp.int32) # Assuming 1000 steps
beta = jnp.linspace(1e-4, 0.02, 1000)
alpha = 1 - beta
alpha_cumulative = jnp.cumprod(alpha)
# Generate and apply noise progressively.
noise = jax.random.normal(key, image_repeated.shape)
noisy_images = (
jnp.sqrt(alpha_cumulative[timesteps])[:, None, None, None] * image_repeated +
@@ -801,5 +844,4 @@ plot_forward_and_reverse(model, diffusion, images_test[0], subkey)

## Summary

In this tutorial, we implemented a simple diffusion model using JAX and Flax, and trained it with Optax. The model used a U-Net architecture with attention mechanisms, training relied on Flax NNX’s JIT compilation, and we also learned how to visualize the diffusion process.
