Update JAX AI Stack Diffusion tutorial
8bitmp3 committed Jan 22, 2025
1 parent ab4e3eb commit 23321e2
Showing 1 changed file with 22 additions and 21 deletions.
docs/source/digits_diffusion_model.md: 43 changes (22 additions, 21 deletions)
@@ -323,7 +323,8 @@ class UNet(nnx.Module):

### Defining the diffusion model

- Here, we'll define the diffusion model that encapsulates the previously components, such as the `UNet` class, and will include all the layers needed to perform the diffusion operations. The `DiffusionModel` class implements the diffusion process with:
+ Here, we will define the diffusion model that encapsulates the previously defined components, such as the `UNet` class, and include all the layers needed to perform the diffusion operations. The `DiffusionModel` class implements the diffusion process with:

- Forward diffusion (adding noise)
- Reverse diffusion (denoising)
- Custom noise scheduling
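
The forward process has a closed form: `x_t` can be sampled directly from `x_0` without simulating the chain step by step, which is what makes training efficient. Below is a minimal sketch of that step in standard DDPM notation; the function name and signature are illustrative, not the tutorial's exact API:

```python
import jax
import jax.numpy as jnp

def forward_diffusion(x0: jax.Array,
                      t: jax.Array,
                      alpha_cumulative: jax.Array,
                      key: jax.Array) -> tuple[jax.Array, jax.Array]:
    """Sample x_t ~ q(x_t | x_0) = N(sqrt(a_bar_t) * x_0, (1 - a_bar_t) * I).

    A sketch of the standard DDPM forward step, not necessarily the
    tutorial's exact implementation.
    """
    noise = jax.random.normal(key, x0.shape)
    a_bar = alpha_cumulative[t].reshape(-1, 1, 1, 1)  # Broadcast over (H, W, C).
    xt = jnp.sqrt(a_bar) * x0 + jnp.sqrt(1.0 - a_bar) * noise
    return xt, noise
```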
@@ -343,12 +344,12 @@ class DiffusionModel:
model (UNet): The U-Net model for image generation.
num_steps (int): The number of diffusion steps in the process.
beta_start: The starting value for beta, controlling the noise level.
beta_end: The end value for beta.
"""
self.model = model
self.num_steps = num_steps
- # Noise schedule parameters:
+ # Noise schedule parameters.
self.beta = self._cosine_beta_schedule(num_steps, beta_start, beta_end)
self.alpha = 1 - self.beta
self.alpha_cumulative = jnp.cumprod(self.alpha)
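
`_cosine_beta_schedule` is defined outside this hunk. Here is a minimal sketch of a cosine schedule in the style of Nichol and Dhariwal (2021); the rescaling into `[beta_start, beta_end]` is an assumption made to honor the constructor's arguments, since the tutorial's helper is not shown:

```python
import jax
import jax.numpy as jnp

def cosine_beta_schedule(num_steps: int,
                         beta_start: float,
                         beta_end: float,
                         s: float = 0.008) -> jax.Array:
    """Cosine noise schedule (after Nichol & Dhariwal, 2021), rescaled into
    [beta_start, beta_end]. A sketch; the tutorial's helper may differ."""
    t = jnp.linspace(0, num_steps, num_steps + 1)
    f = jnp.cos(((t / num_steps) + s) / (1 + s) * jnp.pi / 2) ** 2
    alpha_bar = f / f[0]                       # Normalize so alpha_bar[0] == 1.
    betas = 1.0 - alpha_bar[1:] / alpha_bar[:-1]
    # Rescale into [beta_start, beta_end]; an assumption, to honor the arguments.
    betas = beta_start + (betas - betas.min()) * (beta_end - beta_start) / (betas.max() - betas.min())
    return betas
```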
@@ -482,7 +483,7 @@ def train_step(model: UNet,
sqrt_alpha_cumulative (jax.Array): Square root of cumulative alpha values.
Returns:
- jax.Array: The loss value after a single training step.
+ The loss value after a single training step (jax.Array).
"""
loss, grads = nnx.value_and_grad(loss_fn)(
model, images, t, noise,
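
`loss_fn` itself is defined outside this hunk. In the usual noise-prediction setup it is the mean-squared error between the noise that was added and the U-Net's estimate of it. A sketch under that assumption follows; the two trailing schedule parameters are guesses based on the docstring above, not confirmed by the diff:

```python
import jax
import jax.numpy as jnp

def loss_fn(model: UNet,
            images: jax.Array,
            t: jax.Array,
            noise: jax.Array,
            sqrt_alpha_cumulative: jax.Array,
            sqrt_one_minus_alpha_cumulative: jax.Array) -> jax.Array:
    """Noise-prediction MSE (assumed objective; a sketch, not the tutorial's code)."""
    # Diffuse the clean images to timestep t in closed form.
    noisy = (sqrt_alpha_cumulative[t].reshape(-1, 1, 1, 1) * images
             + sqrt_one_minus_alpha_cumulative[t].reshape(-1, 1, 1, 1) * noise)
    pred = model(noisy, t)  # The U-Net predicts the noise that was added.
    return jnp.mean((pred - noise) ** 2)
```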
@@ -604,10 +605,10 @@ colab:
id: ZnQqHCAoVfi1
outputId: a105e2de-ba88-44d0-bad5-3a9a69e54826
---
- # Initialize training metrics
- losses: List[float] = [] # Store EMA loss history
- moving_avg_loss: Optional[float] = None # EMA of the loss value
- beta: float = 0.99 # EMA decay factor for loss smoothing
+ # Initialize training metrics.
+ losses: List[float] = [] # Store the EMA loss history.
+ moving_avg_loss: Optional[float] = None # The EMA of the loss value.
+ beta: float = 0.99 # The EMA decay factor for loss smoothing.
for epoch in range(num_epochs + 1):
# Split the JAX PRNG key for independent random operations.
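
The smoothing update that populates `losses` happens further down the loop, outside this excerpt. The recurrence behind an exponential moving average with decay factor `beta` is simple; here is a self-contained sketch (illustrative only, not the tutorial's verbatim code):

```python
from typing import Optional

def ema_update(moving_avg: Optional[float], value: float, beta: float = 0.99) -> float:
    """One EMA step: seed with the first value, then blend each new value
    in with weight (1 - beta). A sketch, not the tutorial's verbatim code."""
    if moving_avg is None:
        return value
    return beta * moving_avg + (1 - beta) * value
```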
@@ -665,14 +666,14 @@ colab:
id: 1bjvWNCcbN24
outputId: 457fd13f-377f-4021-ddc2-e36940b42550
---
- # Plot the training loss history with logarithmic scaling
+ # Plot the training loss history with logarithmic scaling.
plt.figure(figsize=(10, 5)) # Create figure with wide aspect ratio for clarity
- plt.plot(losses) # losses: List[float] - historical EMA loss values
+ plt.plot(losses) # losses: List[float] - historical EMA loss values.
plt.title('Training Loss Over Time')
plt.xlabel('Epoch')
plt.ylabel('Loss')
- plt.yscale('log') # Use log scale to better visualize exponential decay
- plt.grid(True) # Add grid for easier value reading
+ plt.yscale('log') # Use the log scale to better visualize exponential decay.
+ plt.grid(True) # Add a grid for easier value reading.
plt.show()
```

@@ -728,7 +729,7 @@ def plot_samples(model: UNet,
images: jax.Array,
key: jax.Array,
num_samples: int = 9) -> None:
"""Visualize original vs reconstructed images"""
"""Visualize original vs reconstructed images."""
indices = jax.random.randint(key, (num_samples,), 0, len(images))
samples = images[indices]
@@ -751,7 +752,7 @@ def plot_samples(model: UNet,
plt.tight_layout()
plt.show()
- # Generate visualization
+ # Create a plot of original vs reconstructed images.
key, subkey = jax.random.split(key)
plot_samples(model, diffusion, images_test, subkey)
```
@@ -791,7 +792,7 @@ def compute_reverse_sequence(model: UNet,
key: jax.Array,
num_vis_steps: int) -> jax.Array:
"""Compute reverse diffusion sequence efficiently."""
- # Denoise completely and create interpolation sequence
+ # Denoise completely and create interpolation sequence.
final_image = reverse_diffusion_batch(model, noisy_image[None], key, 1000)[0]
alphas = jnp.linspace(0, 1, num_vis_steps)
reverse_sequence = (
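
The hunk is cut off mid-expression here and the truncated line is left as-is. For illustration only, a linear blend between the fully denoised result and the original noisy input over `num_vis_steps` frames might look like the following; this is a hypothetical completion, not the tutorial's exact expression:

```python
# Hypothetical completion: blend from the noisy input (alpha = 0) to the
# fully denoised image (alpha = 1), yielding num_vis_steps frames.
reverse_sequence = (
    alphas[:, None, None, None] * final_image[None]
    + (1.0 - alphas[:, None, None, None]) * noisy_image[None]
)
```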
@@ -811,20 +812,20 @@ def plot_forward_and_reverse(model: UNet,
forward_sequence = compute_forward_sequence(model, image, key1, num_steps)
reverse_sequence = compute_reverse_sequence(model, forward_sequence[-1], key2, num_steps)
- # Setup visualization grid
+ # Plot the grid.
fig, (ax1, ax2) = plt.subplots(2, num_steps, figsize=(8, 2))
- fig.suptitle('Forward and Reverse Diffusion Process', y=1.1)
+ fig.suptitle('Forward and reverse diffusion process', y=1.1)
timesteps = jnp.linspace(0, diffusion.num_steps-1, num_steps).astype(jnp.int32)
- # Visualize forward diffusion
+ # Visualize forward diffusion.
for i in range(num_steps):
ax1[i].imshow(forward_sequence[i, ..., 0], cmap='binary', interpolation='gaussian')
ax1[i].axis('off')
ax1[i].set_title(f't={timesteps[i]}')
ax1[0].set_ylabel('Forward', rotation=90, labelpad=10)
- # Visualize reverse diffusion
+ # Visualize reverse diffusion.
for i in range(num_steps):
ax2[i].imshow(reverse_sequence[i, ..., 0], cmap='binary', interpolation='gaussian')
ax2[i].axis('off')
@@ -834,9 +835,9 @@ def plot_forward_and_reverse(model: UNet,
plt.tight_layout()
plt.show()
- # Generate visualization
+ # Create a plot.
key, subkey = jax.random.split(key)
print("\nFull Forward and Reverse Process:")
print("\nFull forward and reverse diffusion processes:")
plot_forward_and_reverse(model, diffusion, images_test[0], subkey)
```
