diff --git a/docs/source/JAX_for_LLM_pretraining.md b/docs/source/JAX_for_LLM_pretraining.md
index 75ad425..bb1d249 100644
--- a/docs/source/JAX_for_LLM_pretraining.md
+++ b/docs/source/JAX_for_LLM_pretraining.md
@@ -13,19 +13,27 @@ kernelspec:

+++ {"id": "NIOXoY1xgiww"}

-# Pre-training an LLM (miniGPT)
+# Train a miniGPT language model with JAX for AI

[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_for_LLM_pretraining.ipynb)

-This tutorial demonstrates how to use JAX/Flax for LLM pretraining via data and tensor parallelism. It is originally inspired by this [Keras miniGPT tutorial](https://keras.io/examples/generative/text_generation_with_miniature_gpt/).
+This tutorial will demonstrate how to use JAX, [Flax NNX](http://flax.readthedocs.io) and [Optax](http://optax.readthedocs.io) for language model training using data and tensor [parallelism](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization) in the [Single-Program Multi-Data (SPMD)](https://en.wikipedia.org/wiki/Single_program,_multiple_data) paradigm. It was originally inspired by the [Keras miniGPT tutorial](https://keras.io/examples/generative/text_generation_with_miniature_gpt/).

-We will use Google TPUs and [SPMD](https://en.wikipedia.org/wiki/Single_program,_multiple_data) to train a language model `miniGPT`. Instead of using a GPU, you should use the free TPU on Colab or Kaggle for this tutorial.
+Here, you will learn how to:
+
+- Define the miniGPT model with Flax and JAX automatic parallelism
+- Load and preprocess the dataset
+- Create the loss and training step functions
+- Train the model on Google Colab’s Cloud TPU v2
+- Profile for hyperparameter tuning
+
+If you are new to JAX for AI, check out the [introductory tutorial](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html), which covers neural network building with [Flax NNX](https://flax.readthedocs.io/en/latest/nnx_basics.html).

+++ {"id": "hTmz5Cbco7n_"}

## Setup

-Install JAX and Flax first. We will install Tiktoken for tokenization and Grain for data loading as well.
+JAX installation is covered in [this guide](https://jax.readthedocs.io/en/latest/installation.html) on the JAX documentation site. We will use [Tiktoken](https://github.com/openai/tiktoken) for tokenization and [Grain](https://google-grain.readthedocs.io/en/latest/index.html) for data loading.

```{code-cell}
---
@@ -34,13 +42,14 @@ colab:
id: 6zMsOIc7ouCO
outputId: 037d56a9-b18f-4504-f80a-3a4fa2945068
---
-!pip install -q jax-ai-stack
!pip install -Uq tiktoken grain matplotlib
```

+++ {"id": "Rcji_799n4eA"}

-Confirm we have TPUs set up.
+**Note:** If you are using [Google Colab](https://colab.research.google.com/), select the free Google Cloud TPU v2 as the hardware accelerator.
+
+Check the available JAX devices, or [`jax.Device`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Device.html), with [`jax.devices()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.devices.html). The output of the cell below will show a list of eight devices.

```{code-cell}
---
@@ -69,19 +78,22 @@ outputId: e6eff24e-5578-4277-a0f9-24e27bd91ee0

+++ {"id": "sKE2uUafLobI"}

-Take care of the imports.
+Import the necessary modules, including JAX NumPy, Flax NNX, Optax, Grain, pandas, and Tiktoken:

```{code-cell}
:id: MKYFNOhdLq98

import jax
import jax.numpy as jnp
+
+from jax.sharding import Mesh, PartitionSpec as P, NamedSharding  # For data and model parallelism (explained in more detail later)
+from jax.experimental import mesh_utils
+
import flax.nnx as nnx
import optax
+
from dataclasses import dataclass
import grain.python as pygrain
-from jax.experimental import mesh_utils
-from jax.sharding import Mesh, PartitionSpec as P, NamedSharding
import pandas as pd
import tiktoken
import time
@@ -89,20 +101,38 @@ import time

+++ {"id": "rPyt7MV6prz1"}

-## Build the model
+## Define the miniGPT model with Flax and JAX automatic parallelism
+
+### Leveraging JAX parallelism
+
+One of the most powerful features of JAX is [device parallelism](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization) for SPMD.
+
+- Data parallelism splits the training data into multiple parts - batches (this is called sharding) - that are processed in parallel and simultaneously across different devices, such as GPUs and Google TPUs. This allows us to use larger batch sizes to speed up training.
+- Tensor parallelism allows us to split the model parameter tensors across several devices (sharding model tensors).
+- You can learn about the basics of JAX parallelism in more detail in the [Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html) on the JAX documentation site.
+
+In this example, we'll use a 4-way data parallel and 2-way tensor parallel setup. The free Google Cloud TPU v2 on Google Colab offers 4 chips, each with 2 TPU cores. The TPU v2 architecture aligns well with this setup.

-One of the biggest advantages of JAX is how easy it is to enable parallelism. To demonstrate this, we are going to use 4-way data parallel and 2-way tensor parallel. Tensor parallelism is one kind of model parallelism, which shards model tensors; there are other kinds of model parallelism, which we won't cover in this tutorial.
+### jax.sharding.Mesh

-As a background, data parallel means splitting a batch of training data into multiple parts (this is called sharding); this way you can use bigger batch sizes to accelerate training, if you have multiple devices that can run in parallel. On the other hand, you can shard not just the training data. Sometimes your model is so big that the model parameters don't fit on a single accelerator. In this case, tensor parallel helps splitting the parameter tensors within a model onto multiple accelerators so that the model can actually run. Both approaches can take advantage of modern accelerators. For example, TPU v2 on the free Colab tier offers 4 chips, each of which has 2 TPU cores. So this architeture works well with 4-way data parallel and 2-way tensor parallel.
+Earlier, we imported [`jax.sharding.Mesh`](https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.Mesh), which is a multidimensional NumPy array of JAX devices, where each axis of the mesh has a name, such as `'x'` or `'y'`. The `Mesh` will help encapsulate the information about the TPU resource organization for distributing computations across the devices.
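+
+As a quick, minimal sketch of the idea (illustrative only and not part of the tutorial flow: it assumes an 8-device runtime such as Colab's TPU v2 and uses a hypothetical `toy_mesh` separate from the `mesh` we create below), a named mesh plus a `PartitionSpec` is enough to shard an array across devices:
+
+```python
+# Illustrative sketch: shard a toy array across a hypothetical 4x2 device mesh.
+import jax
+import jax.numpy as jnp
+from jax.experimental import mesh_utils
+from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
+
+# Assumes 8 available devices (as on Colab's TPU v2); `toy_mesh` is a throwaway
+# name, separate from the `mesh` created below in the tutorial.
+toy_mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model'))
+
+x = jnp.ones((32, 128))
+# Split the leading (batch-like) dimension across the 4 devices of the 'batch'
+# axis and replicate it across the 'model' axis (`None` means "do not shard").
+x_sharded = jax.device_put(x, NamedSharding(toy_mesh, P('batch', None)))
+
+print(x_sharded.sharding)                           # NamedSharding with spec PartitionSpec('batch', None)
+print(x_sharded.addressable_shards[0].data.shape)   # (8, 128): each device holds 32 / 4 rows.
+```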

-To get a detailed understanding of how JAX automatic parallelism works, please refer to this [JAX tutorial](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html#way-batch-data-parallelism-and-2-way-model-tensor-parallelism). In our case to leverage parallelism, we first need to define a `Mesh`, which declares the TPU resources with 2 axes: `batch` axis as 4 and `model` axis as 2, which maps to the TPU v2 cores. Here, the `model` axis enables the tensor parallel for us.
+Our `Mesh` will have two arguments:
+- `devices`: This will take the value of [`jax.experimental.mesh_utils.create_device_mesh((4, 2))`](https://jax.readthedocs.io/en/latest/jax.experimental.mesh_utils.html), which builds the device mesh - a NumPy ndarray of JAX devices (arranged from the list of devices returned by the JAX backend's [`jax.devices()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.devices.html#jax.devices)).
+- `axis_names`, where:
+  - `batch`: 4 devices along the first axis - i.e. sharded into 4 - for data parallelism; and
+  - `model`: 2 devices along the second axis - i.e. sharded into 2 - for tensor parallelism, mapping to the TPU v2 cores.
+
+This matches the `(4, 2)` structure of Colab's TPU v2 setup.
+
+Let's instantiate `Mesh` as `mesh` and declare the TPU configuration to define how data and model parameters are distributed across the devices:

```{code-cell}
:id: xuMlCK3Q8WJD

mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model'))

-### Alternative 8-way data parallel with only one line of code change.
+### Alternatively, we could use 8-way data parallelism with only a one-line code change.
### JAX enables quick experimentation with different partitioning strategies
### like this. We will come back to this point at the end of this tutorial.
# mesh = Mesh(mesh_utils.create_device_mesh((8, 1)), ('batch', 'model'))
@@ -110,7 +140,7 @@ mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model'))

+++ {"id": "_ZKdhNo98NgG"}

-We are going to use the GPT-2 tokenizer via [Tiktoken](https://github.com/openai/tiktoken).
+We will use the GPT-2 tokenizer from the [Tiktoken](https://github.com/openai/tiktoken) library:

```{code-cell}
:id: iWbkk1V7-Isg
@@ -120,40 +150,55 @@ tokenizer = tiktoken.get_encoding("gpt2")

+++ {"id": "0XHQ0BQ9-KIj"}

-To use model parallel, we need to tell JAX compiler how to shard the model tensors. We first use `PartitionSpec` (shorted to `P` in the code) to describe how to shard a tensor: in our case a tensor could be either sharded along the `model` axis or be replicated on other dimensions (which is denoted by `None`). [`NamedSharding`](https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.NamedSharding) can then specify how a model tensor is sharded across the devices mesh using a pair of `Mesh` and `PartitionSpec`.
+To leverage model parallelism, we need to instruct the JAX compiler how to shard the model tensors across the TPU devices. Earlier, we also imported [`jax.sharding.PartitionSpec`](https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.PartitionSpec) and [`jax.sharding.NamedSharding`](https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.NamedSharding):
+- [`PartitionSpec`](https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.PartitionSpec) (using alias `P`) defines how tensors are sharded across the devices in our `Mesh`. Its elements describe how an input dimension is partitioned across mesh dimensions. For example, in `PartitionSpec('x', 'y')` the first dimension of the data is sharded across the `x` axis of the mesh, and the second one across the `y` axis.
+  - We'll use `PartitionSpec` to describe how a tensor is sharded along, for example, the `model` axis, or replicated on other dimensions (which is denoted by `None`).
+- [`NamedSharding`](https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.NamedSharding) is a (`Mesh`, `PartitionSpec`) pair that describes how to shard a model tensor across our `mesh`.
+- We combine `Mesh` (the TPU resources) with `PartitionSpec` and create a `NamedSharding`, which instructs how to shard each model tensor across the TPU devices.

-Finally, we use `nnx.with_partitioning` to let the layers know that their tensors need to be shared/replicated according to our spec. You need to do this for every tensor/layer in your model.
+Additionally, we'll use Flax NNX's [`flax.nnx.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.with_partitioning) to let each model layer know that the model weights or tensors need to be sharded according to our specification. We need to do this for every tensor/layer in the model.
+- `nnx.with_partitioning` takes two arguments: an `initializer` (such as [`flax.nnx.initializers.xavier_uniform`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/initializers.html#flax.nnx.initializers.xavier_uniform) or [`flax.nnx.initializers.zeros_init`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/initializers.html#flax.nnx.initializers.zeros_init)) and a `sharding` (a `NamedSharding(Mesh, PartitionSpec)` pair - e.g. `NamedSharding(mesh, P('model'))` in our case).

```{code-cell}
:id: z0p-IHurrB9i

+# Define a triangular mask for causal attention with `jax.numpy.tril` and `jax.numpy.ones`.
def causal_attention_mask(seq_len):
  return jnp.tril(jnp.ones((seq_len, seq_len)))

+# Define a single Transformer block.
class TransformerBlock(nnx.Module):
+  # Initialize layers of the Transformer block.
  def __init__(self, embed_dim: int, num_heads: int, ff_dim: int, *, rngs: nnx.Rngs, rate: float = 0.1):
+    # Multi-Head Attention (MHA) with `flax.nnx.MultiHeadAttention`.
    self.mha = nnx.MultiHeadAttention(num_heads=num_heads,
                                      in_features=embed_dim,
-                                      kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), NamedSharding(mesh, P(None, 'model'))),
-                                      bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P('model'))),
+                                      kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), NamedSharding(mesh, P(None, 'model'))), # Specify tensor sharding.
+                                      bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P('model'))), # Specify tensor sharding.
                                      rngs=rngs)
+    # The first dropout with `flax.nnx.Dropout`.
    self.dropout1 = nnx.Dropout(rate=rate)
+    # The first layer normalization with `flax.nnx.LayerNorm`.
    self.layer_norm1 = nnx.LayerNorm(epsilon=1e-6,
                                     num_features=embed_dim,
                                     scale_init=nnx.with_partitioning(nnx.initializers.ones_init(), NamedSharding(mesh, P('model'))),
                                     bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P('model'))),
                                     rngs=rngs)
+    # The first linear transformation for the feed-forward network with `flax.nnx.Linear`.
    self.linear1 = nnx.Linear(in_features=embed_dim,
                              out_features=ff_dim,
                              kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), NamedSharding(mesh, P(None, 'model'))),
                              bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P('model'))),
                              rngs=rngs)
+    # The second linear transformation for the feed-forward network with `flax.nnx.Linear`.
    self.linear2 = nnx.Linear(in_features=ff_dim,
                              out_features=embed_dim,
                              kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), NamedSharding(mesh, P(None, 'model'))),
                              bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), NamedSharding(mesh, P('model'))),
                              rngs=rngs)
+    # The second dropout with `flax.nnx.Dropout`.
    self.dropout2 = nnx.Dropout(rate=rate)
+    # The second layer normalization with `flax.nnx.LayerNorm`.
    self.layer_norm2 = nnx.LayerNorm(epsilon=1e-6,
                                     num_features=embed_dim,
                                     scale_init=nnx.with_partitioning(nnx.initializers.ones_init(), NamedSharding(mesh, P(None, 'model'))),
@@ -161,28 +206,36 @@ class TransformerBlock(nnx.Module):
                                     rngs=rngs)

+  # Apply the Transformer block to the input sequence.
  def __call__(self, inputs, training: bool = False):
    input_shape = inputs.shape
    _, seq_len, _ = input_shape

-    # Create causal mask
+    # Instantiate the causal attention mask.
    mask = causal_attention_mask(seq_len)

-    # Apply MultiHeadAttention with causal mask
+    # Apply Multi-Head Attention with the causal attention mask.
    attention_output = self.mha(
      inputs_q=inputs,
      mask=mask,
      decode=False
    )
+    # Apply the first dropout.
    attention_output = self.dropout1(attention_output, deterministic=not training)
+    # Apply the first layer normalization.
    out1 = self.layer_norm1(inputs + attention_output)

-    # Feed-forward network
+    # Feed-forward network.
+    # Apply the first linear transformation.
    ffn_output = self.linear1(out1)
+    # Apply the ReLU activation with `flax.nnx.relu`.
    ffn_output = nnx.relu(ffn_output)
+    # Apply the second linear transformation.
    ffn_output = self.linear2(ffn_output)
+    # Apply the second dropout.
    ffn_output = self.dropout2(ffn_output, deterministic=not training)
+    # Apply the second layer normalization and return the output of the Transformer block.
    return self.layer_norm2(out1 + ffn_output)

@@ -275,7 +328,7 @@ num_epochs = 1

+++ {"id": "mI1ci-HyMspJ"}

-## Prepare data
+## Loading and preprocessing the data

Data loading and preprocessing with [Grain](https://github.com/google/grain).

@@ -327,9 +380,7 @@ text_dl = load_and_preprocess_data('TinyStories-train.txt', batch_size, maxlen)

+++ {"id": "BKVSD8KSM1um"}

-## Train the model
-
-Define loss function and training step function.
+## Defining the loss function and training step function

```{code-cell}
:id: 8rRuTmABNV4b
@@ -349,6 +400,8 @@ def train_step(model: MiniGPT, optimizer: nnx.Optimizer, metrics: nnx.MultiMetri

+++ {"id": "5um2vkeUNckm"}

+## Training the model
+
-Start training. It takes ~50 minutes on Colab. Note that for data parallel, we are sharding the training data along the `batch` axis using `jax.device_put` with `NamedeSharding`.
+Start training. It takes ~50 minutes on Colab. Note that for data parallelism, we are sharding the training data along the `batch` axis using `jax.device_put` with `NamedSharding`.

@@ -441,7 +494,8 @@ As you can see, the model goes from generating completely random words at the be

+++ {"id": "soPqiR1JNmjf"}

-## Saving
+## Saving the checkpoint
+
Save the model checkpoint.

```{code-cell} @@ -462,7 +516,7 @@ checkpointer.save('/content/save', state) !ls /content/save/ ``` -## Profiling for Hyperparameter Tuning +## Profiling for hyperparameter tuning ```{code-cell} !pip install -Uq tensorboard-plugin-profile tensorflow tensorboard @@ -550,14 +604,3 @@ By looking at the Trace Viewer tool and looking under each TPU's ops, we can see ``` By changing hyperparameters and comparing profiles, we're able to gain significant insights into our bottlenecks and limitations. These are just two examples of hyperparameters to tune, but plenty more of them will have significant effects on training speed and resource utilization. - -+++ {"id": "jCApVd7671c1"} - -## Disconnect the Colab runtime - -```{code-cell} -:id: NsqYdbrDVKSq - -from google.colab import runtime -runtime.unassign() -```