Update ViT tutorial #116

Merged · 1 commit · Dec 2, 2024
1,339 changes: 411 additions & 928 deletions docs/JAX_Vision_transformer.ipynb

Large diffs are not rendered by default.

200 changes: 174 additions & 26 deletions docs/JAX_Vision_transformer.md
@@ -14,8 +14,7 @@ kernelspec:

# Vision Transformer with JAX & FLAX


In this tutorial we implement from scratch Vision Transformer (ViT) model based on the paper by Dosovitskiy et al: [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929). We will train this model on [Food 101](https://huggingface.co/datasets/ethz/food101) dataset.
In this tutorial we implement the Vision Transformer (ViT) model from scratch, based on the paper by Dosovitskiy et al.: [An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale](https://arxiv.org/abs/2010.11929). We load the ImageNet-pretrained weights and finetune the model on the [Food 101](https://huggingface.co/datasets/ethz/food101) dataset.
This tutorial was originally inspired by the [HuggingFace image classification tutorial](https://huggingface.co/docs/transformers/tasks/image_classification).

+++
@@ -27,9 +26,10 @@ We will need to install the following Python packages:
- [TorchVision](https://pytorch.org/vision) will be used for image augmentations
- [grain](https://github.com/google/grain/) will be used for efficient data loading
- [tqdm](https://tqdm.github.io/) for a progress bar to monitor the training progress.
- [Matplotlib](https://matplotlib.org/stable/) will be used for visualization purposes

```{code-cell} ipython3
# !pip install -U datasets grain torchvision tqdm
# !pip install -U datasets grain torchvision tqdm matplotlib
# !pip install -U flax optax
```

@@ -98,7 +98,7 @@ class VisionTransformer(nnx.Module):
TransformerEncoder(hidden_size, mlp_dim, num_heads, dropout_rate, rngs=rngs)
for i in range(num_layers)
])
self.lnorm = nnx.LayerNorm(hidden_size, rngs=rngs)
self.final_norm = nnx.LayerNorm(hidden_size, rngs=rngs)

# Classification head
self.classifier = nnx.Linear(hidden_size, num_classes, rngs=rngs)
@@ -116,7 +116,7 @@ class VisionTransformer(nnx.Module):

# Encoder blocks
x = self.encoder(embeddings)
x = self.lnorm(x)
x = self.final_norm(x)

# fetch the first token
x = x[:, 0]
@@ -162,9 +162,155 @@ class TransformerEncoder(nnx.Module):
return x


# We use a configuration to make smaller model to reduce the training time
x = jnp.ones((4, 120, 120, 3))
model = VisionTransformer(num_classes=10, num_layers=4, num_heads=4, img_size=120, patch_size=8)
x = jnp.ones((4, 224, 224, 3))
model = VisionTransformer(num_classes=1000)
y = model(x)
print("Predictions shape: ", y.shape)
```
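
As a quick sanity check on the sequence length behind these shapes, here is a small worked example. It assumes the ViT-Base defaults implied by the pretrained checkpoint used below (`img_size=224`, `patch_size=16`); the tutorial's constructor defaults are not shown in this diff:

```python
# Worked example (assumed ViT-Base defaults: img_size=224, patch_size=16).
img_size, patch_size = 224, 16
num_patches = (img_size // patch_size) ** 2   # 14 * 14 = 196 patch embeddings
seq_len = num_patches + 1                     # +1 for the class token -> 197
print(num_patches, seq_len)                   # 196 197
```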

Let's now load the weights pretrained on the ImageNet dataset using HuggingFace Transformers. We copy all weights into our model and check that it produces results consistent with the reference model.

```{code-cell} ipython3
from transformers import FlaxViTForImageClassification

tf_model = FlaxViTForImageClassification.from_pretrained('google/vit-base-patch16-224')
```

```{code-cell} ipython3
def vit_inplace_copy_weights(*, src_model, dst_model):
assert isinstance(src_model, FlaxViTForImageClassification)
assert isinstance(dst_model, VisionTransformer)

tf_model_params = src_model.params
tf_model_params_fstate = nnx.traversals.flatten_mapping(tf_model_params)

flax_model_params = nnx.state(dst_model, nnx.Param)
flax_model_params_fstate = flax_model_params.flat_state()

params_name_mapping = {
("cls_token",): ("vit", "embeddings", "cls_token"),
("position_embeddings",): ("vit", "embeddings", "position_embeddings"),
**{
("patch_embeddings", x): ("vit", "embeddings", "patch_embeddings", "projection", x)
for x in ["kernel", "bias"]
},
**{
("encoder", "layers", i, "attn", y, x): (
"vit", "encoder", "layer", str(i), "attention", "attention", y, x
)
for x in ["kernel", "bias"]
for y in ["key", "value", "query"]
for i in range(12)
},
**{
("encoder", "layers", i, "attn", "out", x): (
"vit", "encoder", "layer", str(i), "attention", "output", "dense", x
)
for x in ["kernel", "bias"]
for i in range(12)
},
**{
("encoder", "layers", i, "mlp", "layers", y1, x): (
"vit", "encoder", "layer", str(i), y2, "dense", x
)
for x in ["kernel", "bias"]
for y1, y2 in [(0, "intermediate"), (3, "output")]
for i in range(12)
},
**{
("encoder", "layers", i, y1, x): (
"vit", "encoder", "layer", str(i), y2, x
)
for x in ["scale", "bias"]
for y1, y2 in [("norm1", "layernorm_before"), ("norm2", "layernorm_after")]
for i in range(12)
},
**{
("final_norm", x): ("vit", "layernorm", x)
for x in ["scale", "bias"]
},
**{
("classifier", x): ("classifier", x)
for x in ["kernel", "bias"]
}
}

nonvisited = set(flax_model_params_fstate.keys())

for key1, key2 in params_name_mapping.items():
assert key1 in flax_model_params_fstate, key1
assert key2 in tf_model_params_fstate, (key1, key2)

nonvisited.remove(key1)

src_value = tf_model_params_fstate[key2]
if key2[-1] == "kernel" and key2[-2] in ("key", "value", "query"):
shape = src_value.shape
src_value = src_value.reshape((shape[0], 12, 64))

if key2[-1] == "bias" and key2[-2] in ("key", "value", "query"):
src_value = src_value.reshape((12, 64))

if key2[-4:] == ("attention", "output", "dense", "kernel"):
shape = src_value.shape
src_value = src_value.reshape((12, 64, shape[-1]))

dst_value = flax_model_params_fstate[key1]
assert src_value.shape == dst_value.value.shape, (key2, src_value.shape, key1, dst_value.value.shape)
dst_value.value = src_value.copy()
assert dst_value.value.mean() == src_value.mean(), (dst_value.value, src_value.mean())

assert len(nonvisited) == 0, nonvisited
nnx.update(dst_model, nnx.State.from_flat_path(flax_model_params_fstate))


vit_inplace_copy_weights(src_model=tf_model, dst_model=model)
```
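
The reshapes in `vit_inplace_copy_weights` handle the fact that HuggingFace stores the attention projections as fused 2D kernels, while the Flax NNX attention layers keep a separate per-head axis. A small shape-arithmetic sketch, assuming ViT-Base's hidden size of 768 split into 12 heads of 64 dimensions:

```python
import numpy as np

# Assumed ViT-Base dimensions: hidden_size = 768 = 12 heads x 64 dims per head.
hidden_size, num_heads, head_dim = 768, 12, 64

# Fused query/key/value kernel as stored by HuggingFace: (768, 768).
fused_qkv_kernel = np.zeros((hidden_size, hidden_size), dtype=np.float32)
print(fused_qkv_kernel.reshape(hidden_size, num_heads, head_dim).shape)  # (768, 12, 64)

# The attention output projection is reshaped the other way: (768, 768) -> (12, 64, 768).
out_kernel = np.zeros((hidden_size, hidden_size), dtype=np.float32)
print(out_kernel.reshape(num_heads, head_dim, hidden_size).shape)        # (12, 64, 768)
```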

Let's run our model with the pretrained weights and compare its output with the reference model's:

```{code-cell} ipython3
import matplotlib.pyplot as plt
from transformers import ViTImageProcessor
from PIL import Image
import requests

url = "https://farm2.staticflickr.com/1152/1151216944_1525126615_z.jpg"
image = Image.open(requests.get(url, stream=True).raw)

processor = ViTImageProcessor.from_pretrained('google/vit-base-patch16-224')

inputs = processor(images=image, return_tensors="np")
outputs = tf_model(**inputs)
logits = outputs.logits


model.eval()
x = jnp.transpose(inputs["pixel_values"], axes=(0, 2, 3, 1))
output = model(x)

# model predicts one of the 1000 ImageNet classes
ref_class_idx = logits.argmax(-1).item()
pred_class_idx = output.argmax(-1).item()
assert jnp.abs(logits[0, :] - output[0, :]).max() < 0.1

fig, axs = plt.subplots(1, 2, figsize=(12, 8))
axs[0].set_title(
f"Reference model:\n{tf_model.config.id2label[ref_class_idx]}\nP={nnx.softmax(logits, axis=-1)[0, ref_class_idx]:.3f}"
)
axs[0].imshow(image)
axs[1].set_title(
f"Our model:\n{tf_model.config.id2label[pred_class_idx]}\nP={nnx.softmax(output, axis=-1)[0, pred_class_idx]:.3f}"
)
axs[1].imshow(image)
```

Now let's replace the classifier with a smaller fully-connected layer returning 20 classes instead of 1000:

```{code-cell} ipython3
model.classifier = nnx.Linear(model.classifier.in_features, 20, rngs=nnx.Rngs(0))

x = jnp.ones((4, 224, 224, 3))
y = model(x)
print("Predictions shape: ", y.shape)
```
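
Only the new head is randomly initialized here; the backbone keeps its pretrained weights. A quick illustrative check (assuming ViT-Base's hidden size of 768):

```python
# Illustrative check: the new head maps the 768-dim class-token embedding to 20 classes.
print(model.classifier.kernel.value.shape)  # expected: (768, 20)
```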
@@ -177,19 +323,19 @@ In the following sections we set up an image classification dataset and train this

In this tutorial we use the [Food 101](https://huggingface.co/datasets/ethz/food101) dataset, which consists of 101 food categories with 101,000 images. For each class, 250 manually reviewed test images and 750 training images are provided. The training images were deliberately not cleaned and thus still contain some noise, mostly in the form of intense colors and occasionally wrong labels. All images were rescaled to have a maximum side length of 512 pixels.

We will download the data using [HuggingFace Datasets](https://huggingface.co/docs/datasets/) and select 10 classes to reduce the dataset size and the model training time. We will be using [TorchVision](https://pytorch.org/vision) to transform input images and [`grain`](https://github.com/google/grain/) for efficient data loading.
We will download the data using [HuggingFace Datasets](https://huggingface.co/docs/datasets/) and select 20 classes to reduce the dataset size and the model training time. We will be using [TorchVision](https://pytorch.org/vision) to transform input images and [`grain`](https://github.com/google/grain/) for efficient data loading.

```{code-cell} ipython3
from datasets import load_dataset

# Select first 10 classes to reduce the dataset size and the training time.
train_size = 10 * 750
val_size = 10 * 250
# Select first 20 classes to reduce the dataset size and the training time.
train_size = 20 * 750
val_size = 20 * 250

train_dataset = load_dataset("food101", split=f"train[:{train_size}]")
val_dataset = load_dataset("food101", split=f"validation[:{val_size}]")

# Let's create labels mapping where we map current labels between 0 and 9
# Create a labels mapping that remaps the selected classes to labels 0-19
labels_mapping = {}
index = 0
for i in range(0, len(val_dataset), 250):
@@ -198,6 +344,7 @@ for i in range(0, len(val_dataset), 250):
labels_mapping[label] = index
index += 1

inv_labels_mapping = {v: k for k, v in labels_mapping.items()}

print("Training dataset size:", len(train_dataset))
print("Validation dataset size:", len(val_dataset))
@@ -248,18 +395,19 @@ import numpy as np
from torchvision.transforms import v2 as T


img_size = 120
img_size = 224


def to_np_array(pil_image):
return np.asarray(pil_image.convert("RGB"))


def normalize(image):
mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
std = np.array([0.229, 0.224, 0.225], dtype=np.float32)
image = image.astype(np.float32) / 255.0
return (image - mean) / std
# Image preprocessing matches that of the pretrained ViT
mean = np.array([0.5, 0.5, 0.5], dtype=np.float32)
std = np.array([0.5, 0.5, 0.5], dtype=np.float32)
image = image.astype(np.float32) / 255.0
return (image - mean) / std


tv_train_transforms = T.Compose([
@@ -283,7 +431,7 @@ def get_transform(fn):
batch["image"] = [
fn(pil_image) for pil_image in batch["image"]
]
# map label index between 0 - 9
# map label index between 0 - 19
batch["label"] = [
labels_mapping[label] for label in batch["label"]
]
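
How these batch-level transforms are attached to the datasets is collapsed in this diff; a typical wiring with HuggingFace Datasets' `with_transform` might look like the sketch below (`tv_test_transforms` is a hypothetical name for the validation-side torchvision pipeline, which is not shown here):

```python
# Sketch only: attach the batch-level transforms lazily via datasets' with_transform.
train_transforms = get_transform(tv_train_transforms)
val_transforms = get_transform(tv_test_transforms)  # hypothetical validation pipeline

train_dataset = train_dataset.with_transform(train_transforms)
val_dataset = val_dataset.with_transform(val_transforms)
```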
@@ -303,7 +451,7 @@ import grain.python as grain


seed = 12
train_batch_size = 64
train_batch_size = 32
val_batch_size = 2 * train_batch_size
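
The loader construction itself is collapsed in this diff; a minimal sketch of a typical `grain` setup under these settings (the exact options below are assumptions, not necessarily the tutorial's code):

```python
# Sketch only: assumed grain.python API (IndexSampler, NoSharding, DataLoader, Batch).
train_sampler = grain.IndexSampler(
    len(train_dataset),
    shuffle=True,
    seed=seed,
    shard_options=grain.NoSharding(),  # single-host run, no sharding
    num_epochs=1,
)

train_loader = grain.DataLoader(
    data_source=train_dataset,
    sampler=train_sampler,
    worker_count=4,
    operations=[grain.Batch(train_batch_size, drop_remainder=True)],
)
```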


@@ -363,15 +511,15 @@ print("Validation batch info:", val_batch["image"].shape, val_batch["image"].dty
display_datapoints(
*[(train_batch["image"][i], train_batch["label"][i]) for i in range(5)],
tag="(Training) ",
names_map=train_dataset.features["label"].names
names_map={k: train_dataset.features["label"].names[v] for k, v in inv_labels_mapping.items()}
)
```

```{code-cell} ipython3
display_datapoints(
*[(val_batch["image"][i], val_batch["label"][i]) for i in range(5)],
tag="(Validation) ",
names_map=val_dataset.features["label"].names
names_map={k: val_dataset.features["label"].names[v] for k, v in inv_labels_mapping.items()}
)
```

@@ -382,8 +530,8 @@ We defined training and validation datasets and the model. In this section we wi
```{code-cell} ipython3
import optax

num_epochs = 50
learning_rate = 0.005
num_epochs = 3
learning_rate = 0.001
momentum = 0.8
total_steps = len(train_dataset) // train_batch_size

@@ -544,7 +692,6 @@ preds = model(test_images)
```{code-cell} ipython3
num_samples = len(test_indices)
names_map = train_dataset.features["label"].names
inv_labels_mapping = {v: k for k, v in labels_mapping.items()}

probas = nnx.softmax(preds, axis=1)
pred_labels = probas.argmax(axis=1)
@@ -567,10 +714,11 @@ for i in range(num_samples):

## Further reading

In this tutorial we implemented from scratch Vision Transformer model and trained it on a subset of Food 101 dataset. Trained model shows 67% classification accuracy. Next steps could be to finetune hyperparameters like the learning rate and train for more epochs.
In this tutorial we implemented the Vision Transformer model from scratch and finetuned it on a subset of the Food 101 dataset. The finetuned model reaches roughly 95% classification accuracy.

- Model checkpointing and exporting using [Orbax](https://orbax.readthedocs.io/en/latest/).
- Optimizers and the learning rate scheduling using [Optax](https://optax.readthedocs.io/en/latest/).
- Freezing a model's parameters using trainable-parameter filtering: [example 1](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/training/optimizer.html#flax.nnx.optimizer.Optimizer.update) and [example 2](https://github.com/google/flax/issues/4167#issuecomment-2324245208).
- Other Computer Vision tutorials in [jax-ai-stack](https://jax-ai-stack.readthedocs.io/en/latest/tutorials.html).

```{code-cell} ipython3