From 1e5e2c8a599b97021d7f42a4bc8ec05313c7e6cf Mon Sep 17 00:00:00 2001 From: selamw1 Date: Mon, 2 Dec 2024 15:34:19 -0800 Subject: [PATCH 1/7] md_and_ipynb_files_paired --- docs/data_loaders_on_cpu_with_jax.ipynb | 3570 +++++++++++++++++++++++ docs/data_loaders_on_cpu_with_jax.md | 685 +++++ 2 files changed, 4255 insertions(+) create mode 100644 docs/data_loaders_on_cpu_with_jax.ipynb create mode 100644 docs/data_loaders_on_cpu_with_jax.md diff --git a/docs/data_loaders_on_cpu_with_jax.ipynb b/docs/data_loaders_on_cpu_with_jax.ipynb new file mode 100644 index 0000000..21bd599 --- /dev/null +++ b/docs/data_loaders_on_cpu_with_jax.ipynb @@ -0,0 +1,3570 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "PUFGZggH49zp" + }, + "source": [ + "# Introduction to Data Loaders on CPU with JAX" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3ia4PKEV5Dr8" + }, + "source": [ + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/data_loaders_on_cpu_with_jax.ipynb)\n", + "\n", + "This tutorial explores different data loading strategies for using **JAX** on a single [**CPU**](https://jax.readthedocs.io/en/latest/glossary.html#term-CPU). While JAX doesn't include a built-in data loader, it seamlessly integrates with popular data loading libraries, including:\n", + "\n", + "- [**PyTorch DataLoader**](https://github.com/pytorch/data)\n", + "- [**TensorFlow Datasets (TFDS)**](https://github.com/tensorflow/datasets)\n", + "- [**Grain**](https://github.com/google/grain)\n", + "- [**Hugging Face**](https://huggingface.co/docs/datasets/en/use_with_jax#data-loading)\n", + "\n", + "You'll see how to use each of these libraries to efficiently load data for a simple image classification task using the MNIST dataset." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "pEsb135zE-Jo" + }, + "source": [ + "## Setting JAX to Use CPU Only\n", + "\n", + "First, you'll restrict JAX to use only the CPU, even if a GPU is available. This ensures consistency and allows you to focus on CPU-based data loading." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": { + "id": "vqP6xyObC0_9" + }, + "outputs": [], + "source": [ + "import os\n", + "os.environ['JAX_PLATFORM_NAME'] = 'cpu'" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-rsMgVtO6asW" + }, + "source": [ + "Import JAX API" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "id": "tDJNQ6V-Dg5g" + }, + "outputs": [], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "from jax import random, grad, jit, vmap" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TsFdlkSZKp9S" + }, + "source": [ + "### CPU Setup Verification" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "N3sqvaF3KJw1", + "outputId": "449c83d9-d050-4b15-9a8d-f71e340501f2" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[CpuDevice(id=0)]" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "jax.devices()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qyJ_WTghDnIc" + }, + "source": [ + "## Setting Hyperparameters and Initializing Parameters\n", + "\n", + "You'll define hyperparameters for your model and data loading, including layer sizes, learning rate, batch size, and the data directory. 
You'll also initialize the weights and biases for a fully-connected neural network." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": { + "id": "qLNOSloFDka_" + }, + "outputs": [], + "source": [ + "# A helper function to randomly initialize weights and biases\n", + "# for a dense neural network layer\n", + "def random_layer_params(m, n, key, scale=1e-2):\n", + " w_key, b_key = random.split(key)\n", + " return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))\n", + "\n", + "# Function to initialize network parameters for all layers based on defined sizes\n", + "def init_network_params(sizes, key):\n", + " keys = random.split(key, len(sizes))\n", + " return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]\n", + "\n", + "layer_sizes = [784, 512, 512, 10] # Layers of the network\n", + "step_size = 0.01 # Learning rate for optimization\n", + "num_epochs = 8 # Number of training epochs\n", + "batch_size = 128 # Batch size for training\n", + "n_targets = 10 # Number of classes (digits 0-9)\n", + "num_pixels = 28 * 28 # Input size (MNIST images are 28x28 pixels)\n", + "data_dir = '/tmp/mnist_dataset' # Directory for storing the dataset\n", + "\n", + "# Initialize network parameters using the defined layer sizes and a random seed\n", + "params = init_network_params(layer_sizes, random.PRNGKey(0))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "6Ci_CqW7q6XM" + }, + "source": [ + "## Model Prediction with Auto-Batching\n", + "\n", + "In this section, you'll define the `predict` function for your neural network. This function computes the output of the network for a single input image.\n", + "\n", + "To efficiently process multiple images simultaneously, you'll use [`vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax.vmap), which allows you to vectorize the `predict` function and apply it across a batch of inputs. This technique, called auto-batching, improves computational efficiency by leveraging hardware acceleration." 
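To see what auto-batching buys you, a quick shape check is helpful. This is a sketch, not part of the original notebook: it assumes `params` from the cell above and the `predict`/`batched_predict` definitions in the next cell, and the PRNG seed and batch size of 32 are arbitrary choices.

```python
# Sketch: compare per-example and batched outputs of the same network.
# Assumes `params`, `predict`, and `batched_predict` as defined in this section.
from jax import random

dummy_batch = random.normal(random.PRNGKey(1), (32, 784))  # 32 flattened images
print(predict(params, dummy_batch[0]).shape)       # (10,) for a single example
print(batched_predict(params, dummy_batch).shape)  # (32, 10) for the whole batch
```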
+ ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": { + "id": "bKIYPSkvD1QV" + }, + "outputs": [], + "source": [ + "from jax.scipy.special import logsumexp\n", + "\n", + "def relu(x):\n", + " return jnp.maximum(0, x)\n", + "\n", + "def predict(params, image):\n", + " # per-example prediction\n", + " activations = image\n", + " for w, b in params[:-1]:\n", + " outputs = jnp.dot(w, activations) + b\n", + " activations = relu(outputs)\n", + "\n", + " final_w, final_b = params[-1]\n", + " logits = jnp.dot(final_w, activations) + final_b\n", + " return logits - logsumexp(logits)\n", + "\n", + "# Make a batched version of the `predict` function\n", + "batched_predict = vmap(predict, in_axes=(None, 0))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "niTSr34_sDZi" + }, + "source": [ + "## Utility and Loss Functions\n", + "\n", + "You'll now define utility functions for:\n", + "\n", + "- One-hot encoding: Converts class indices to binary vectors.\n", + "- Accuracy calculation: Measures the performance of the model on the dataset.\n", + "- Loss computation: Calculates the difference between predictions and targets.\n", + "\n", + "To optimize performance:\n", + "\n", + "- [`grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html#jax.grad) is used to compute gradients of the loss function with respect to network parameters.\n", + "- [`jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit) compiles the update function, enabling faster execution by leveraging JAX's [XLA](https://openxla.org/xla) compilation." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": { + "id": "sA0a06raEQfS" + }, + "outputs": [], + "source": [ + "import time\n", + "\n", + "def one_hot(x, k, dtype=jnp.float32):\n", + " \"\"\"Create a one-hot encoding of x of size k.\"\"\"\n", + " return jnp.array(x[:, None] == jnp.arange(k), dtype)\n", + "\n", + "def accuracy(params, images, targets):\n", + " \"\"\"Calculate the accuracy of predictions.\"\"\"\n", + " target_class = jnp.argmax(targets, axis=1)\n", + " predicted_class = jnp.argmax(batched_predict(params, images), axis=1)\n", + " return jnp.mean(predicted_class == target_class)\n", + "\n", + "def loss(params, images, targets):\n", + " \"\"\"Calculate the loss between predictions and targets.\"\"\"\n", + " preds = batched_predict(params, images)\n", + " return -jnp.mean(preds * targets)\n", + "\n", + "@jit\n", + "def update(params, x, y):\n", + " \"\"\"Update the network parameters using gradient descent.\"\"\"\n", + " grads = grad(loss)(params, x, y)\n", + " return [(w - step_size * dw, b - step_size * db)\n", + " for (w, b), (dw, db) in zip(params, grads)]\n", + "\n", + "def reshape_and_one_hot(x, y):\n", + " \"\"\"Reshape and one-hot encode the inputs.\"\"\"\n", + " x = jnp.reshape(x, (len(x), num_pixels))\n", + " y = one_hot(y, n_targets)\n", + " return x, y\n", + "\n", + "def train_model(num_epochs, params, training_generator, data_loader_type='streamed'):\n", + " \"\"\"Train the model for a given number of epochs.\"\"\"\n", + " for epoch in range(num_epochs):\n", + " start_time = time.time()\n", + " for x, y in training_generator() if data_loader_type == 'streamed' else training_generator:\n", + " x, y = reshape_and_one_hot(x, y)\n", + " params = update(params, x, y)\n", + "\n", + " print(f\"Epoch {epoch + 1} in {time.time() - start_time:.2f} sec: \"\n", + " f\"Train Accuracy: {accuracy(params, train_images, train_labels):.4f}, \"\n", + " f\"Test Accuracy: {accuracy(params, 
test_images, test_labels):.4f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Hsionp5IYsQ9" + }, + "source": [ + "## Loading Data with PyTorch DataLoader\n", + "\n", + "This section shows how to load the MNIST dataset using PyTorch's DataLoader, convert the data to NumPy arrays, and apply transformations to flatten and cast images." + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "jmsfrWrHxIhC", + "outputId": "33dfeada-a763-4d26-f778-a27966e34d55" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.5.1+cu121)\n", + "Requirement already satisfied: torchvision in /usr/local/lib/python3.10/dist-packages (0.20.1+cu121)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch) (3.16.1)\n", + "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch) (4.12.2)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.4.2)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.4)\n", + "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch) (2024.10.0)\n", + "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.10/dist-packages (from torch) (1.13.1)\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy==1.13.1->torch) (1.3.0)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from torchvision) (1.26.4)\n", + "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.10/dist-packages (from torchvision) (11.0.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (3.0.2)\n" + ] + } + ], + "source": [ + "!pip install torch torchvision" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "id": "kO5_WzwY59gE" + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "from jax.tree_util import tree_map\n", + "from torch.utils import data\n", + "from torchvision.datasets import MNIST" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": { + "id": "6f6qU8PCc143" + }, + "outputs": [], + "source": [ + "def numpy_collate(batch):\n", + " \"\"\"Convert a batch of PyTorch data to NumPy arrays.\"\"\"\n", + " return tree_map(np.asarray, data.default_collate(batch))\n", + "\n", + "class NumpyLoader(data.DataLoader):\n", + " \"\"\"Custom DataLoader to return NumPy arrays from a PyTorch Dataset.\"\"\"\n", + " def __init__(self, dataset, batch_size=1, shuffle=False, **kwargs):\n", + " super().__init__(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=numpy_collate, **kwargs)\n", + "\n", + "class FlattenAndCast(object):\n", + " \"\"\"Transform class to flatten and cast images to float32.\"\"\"\n", + " def __call__(self, pic):\n", + " return np.ravel(np.array(pic, dtype=jnp.float32))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mfSnfJND6I8G" + }, + "source": [ + "### Load Dataset with Transformations\n", + "\n", + "Standardize the data by flattening the images, casting them to `float32`, and ensuring consistent data types." 
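Once `mnist_dataset` is created in the next cell, a small sanity check like the sketch below (assuming the `NumpyLoader` and `FlattenAndCast` defined above) confirms that batches come back as flattened `float32` NumPy arrays rather than PyTorch tensors.

```python
# Sketch: pull one small batch and inspect its types, shapes, and dtypes.
loader = NumpyLoader(mnist_dataset, batch_size=4)
x, y = next(iter(loader))
print(type(x), x.shape, x.dtype)  # expect: <class 'numpy.ndarray'> (4, 784) float32
print(type(y), y.shape)           # expect: <class 'numpy.ndarray'> (4,)
```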
+ ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Kxbl6bcx6crv", + "outputId": "372bbf4c-3ad5-4fd8-cc5d-27b50f5e4f38" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n", + "Failed to download (trying next):\n", + "HTTP Error 403: Forbidden\n", + "\n", + "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz\n", + "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to /tmp/mnist_dataset/MNIST/raw/train-images-idx3-ubyte.gz\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 9.91M/9.91M [00:00<00:00, 49.4MB/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Extracting /tmp/mnist_dataset/MNIST/raw/train-images-idx3-ubyte.gz to /tmp/mnist_dataset/MNIST/raw\n", + "\n", + "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n", + "Failed to download (trying next):\n", + "HTTP Error 403: Forbidden\n", + "\n", + "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz\n", + "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to /tmp/mnist_dataset/MNIST/raw/train-labels-idx1-ubyte.gz\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 28.9k/28.9k [00:00<00:00, 2.09MB/s]" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Extracting /tmp/mnist_dataset/MNIST/raw/train-labels-idx1-ubyte.gz to /tmp/mnist_dataset/MNIST/raw\n", + "\n", + "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Failed to download (trying next):\n", + "HTTP Error 403: Forbidden\n", + "\n", + "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz\n", + "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to /tmp/mnist_dataset/MNIST/raw/t10k-images-idx3-ubyte.gz\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 1.65M/1.65M [00:00<00:00, 13.3MB/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Extracting /tmp/mnist_dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to /tmp/mnist_dataset/MNIST/raw\n", + "\n", + "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n", + "Failed to download (trying next):\n", + "HTTP Error 403: Forbidden\n", + "\n", + "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz\n", + "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to /tmp/mnist_dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|██████████| 4.54k/4.54k [00:00<00:00, 8.81MB/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Extracting /tmp/mnist_dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to /tmp/mnist_dataset/MNIST/raw\n", + "\n" + ] + } + ], + "source": [ + "mnist_dataset = MNIST(data_dir, download=True, transform=FlattenAndCast())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "kbdsqvPZGrsa" + }, + "source": [ + "### Full Training Dataset for 
Accuracy Checks\n", + "\n", + "Convert the entire training dataset to JAX arrays." + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "id": "c9ZCJq_rzPck" + }, + "outputs": [], + "source": [ + "train_images = jnp.array(mnist_dataset.data.numpy().reshape(len(mnist_dataset.data), -1), dtype=jnp.float32)\n", + "train_labels = one_hot(np.array(mnist_dataset.targets), n_targets)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WXUh0BwvG8Ko" + }, + "source": [ + "### Get Full Test Dataset\n", + "\n", + "Load and process the full test dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": { + "id": "brlLG4SqGphm" + }, + "outputs": [], + "source": [ + "mnist_dataset_test = MNIST(data_dir, download=True, train=False)\n", + "test_images = jnp.array(mnist_dataset_test.data.numpy().reshape(len(mnist_dataset_test.data), -1), dtype=jnp.float32)\n", + "test_labels = one_hot(np.array(mnist_dataset_test.targets), n_targets)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Oz-UVnCxG5E8", + "outputId": "abbaa26d-491a-4e63-e8c9-d3c571f53a28" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train: (60000, 784) (60000, 10)\n", + "Test: (10000, 784) (10000, 10)\n" + ] + } + ], + "source": [ + "print('Train:', train_images.shape, train_labels.shape)\n", + "print('Test:', test_images.shape, test_labels.shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "m3zfxqnMiCbm" + }, + "source": [ + "### Training Data Generator\n", + "\n", + "Define a generator function using PyTorch's DataLoader for batch training. Setting `num_workers > 0` enables multi-process data loading, which can accelerate data loading for larger datasets or intensive preprocessing tasks. Experiment with different values to find the optimal setting for your hardware and workload.\n", + "\n", + "Note: When setting `num_workers > 0`, you may see the following `RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.` This warning can be safely ignored since data loaders do not use JAX within the forked processes." + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": { + "id": "B-fES82EiL6Z" + }, + "outputs": [], + "source": [ + "def pytorch_training_generator(mnist_dataset):\n", + " return NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=0)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Xzt2x9S1HC3T" + }, + "source": [ + "### Training Loop (PyTorch DataLoader)\n", + "\n", + "The training loop uses the PyTorch DataLoader to iterate through batches and update model parameters." 
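If you want to try multi-process loading as discussed above, a variant of the generator might look like the sketch below. `pytorch_training_generator_mp` is a hypothetical helper name, and the worker count of 2 is an arbitrary starting point to tune for your machine.

```python
# Sketch: same loader with two worker processes preparing batches in parallel.
# Assumes NumpyLoader, mnist_dataset, and batch_size from above.
def pytorch_training_generator_mp(dataset, workers=2):
    return NumpyLoader(dataset, batch_size=batch_size, num_workers=workers)

# Usage would mirror the training cell below:
# train_model(num_epochs, params, pytorch_training_generator_mp(mnist_dataset),
#             data_loader_type='iterable')
```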
+ ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "vtUjHsh-rJs8", + "outputId": "4766333e-4366-493b-995a-102778d1345a" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1 in 28.93 sec: Train Accuracy: 0.9158, Test Accuracy: 0.9196\n", + "Epoch 2 in 8.33 sec: Train Accuracy: 0.9372, Test Accuracy: 0.9384\n", + "Epoch 3 in 6.99 sec: Train Accuracy: 0.9492, Test Accuracy: 0.9468\n", + "Epoch 4 in 7.01 sec: Train Accuracy: 0.9569, Test Accuracy: 0.9532\n", + "Epoch 5 in 8.17 sec: Train Accuracy: 0.9630, Test Accuracy: 0.9579\n", + "Epoch 6 in 8.27 sec: Train Accuracy: 0.9674, Test Accuracy: 0.9615\n", + "Epoch 7 in 8.32 sec: Train Accuracy: 0.9708, Test Accuracy: 0.9650\n", + "Epoch 8 in 8.07 sec: Train Accuracy: 0.9737, Test Accuracy: 0.9671\n" + ] + } + ], + "source": [ + "train_model(num_epochs, params, pytorch_training_generator(mnist_dataset), data_loader_type='iterable')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Nm45ZTo6yrf5" + }, + "source": [ + "## Loading Data with TensorFlow Datasets (TFDS)\n", + "\n", + "This section demonstrates how to load the MNIST dataset using TFDS, fetch the full dataset for evaluation, and define a training generator for batch processing. GPU usage is explicitly disabled for TensorFlow." + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "id": "sGaQAk1DHMUx" + }, + "outputs": [], + "source": [ + "import tensorflow_datasets as tfds\n", + "import tensorflow as tf\n", + "\n", + "# Ensure CPU-only execution by hiding any GPUs from TensorFlow (if present)\n", + "tf.config.set_visible_devices([], device_type='GPU')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3xdQY7H6wr3n" + }, + "source": [ + "### Fetch Full Dataset for Evaluation\n", + "\n", + "Load the dataset with `tfds.load`, convert it to NumPy arrays, and process it for evaluation."
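Before fetching the data, you can optionally confirm that the `tf.config.set_visible_devices` call above took effect; a one-line check (assuming `tf` is imported as above):

```python
# Sketch: TensorFlow should report no visible GPUs after the call above.
print(tf.config.get_visible_devices('GPU'))  # expect: []
```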
+ ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 104, + "referenced_widgets": [ + "b8cdabf5c05848f38f03850cab08b56f", + "a8b76d5f93004c089676e5a2a9b3336c", + "119ac8428f9441e7a25eb0afef2fbb2a", + "76a9815e5c2b4764a13409cebaf66821", + "45ce8dd5c4b949afa957ec8ffb926060", + "05b7145fd62d4581b2123c7680f11cdd", + "b96267f014814ec5b96ad7e6165104b1", + "bce34bdbfbd64f1f8353a4e8515cee0b", + "93b8206f8c5841a692cdce985ae301d8", + "c95f592620c64da595cc787567b2c4db", + "8a97071f862c4ec3b4b4140d2e34eda2" + ] + }, + "id": "1hOamw_7C8Pb", + "outputId": "ca166490-22db-4732-b29f-866b7593e489" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Downloading and preparing dataset 11.06 MiB (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /tmp/mnist_dataset/mnist/3.0.1...\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b8cdabf5c05848f38f03850cab08b56f", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Dl Completed...: 0%| | 0/5 [00:00=9.1.0 in /usr/local/lib/python3.10/dist-packages (from grain) (10.5.0)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from grain) (1.26.4)\n", + "Requirement already satisfied: typing_extensions in /usr/local/lib/python3.10/dist-packages (from etils[epath,epy]->grain) (4.12.2)\n", + "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from etils[epath,epy]->grain) (2024.10.0)\n", + "Requirement already satisfied: importlib_resources in /usr/local/lib/python3.10/dist-packages (from etils[epath,epy]->grain) (6.4.5)\n", + "Requirement already satisfied: zipp in /usr/local/lib/python3.10/dist-packages (from etils[epath,epy]->grain) (3.21.0)\n", + "Downloading grain-0.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (418 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m419.0/419.0 kB\u001b[0m \u001b[31m7.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading jaxtyping-0.2.36-py3-none-any.whl (55 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m55.8/55.8 kB\u001b[0m \u001b[31m4.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hInstalling collected packages: jaxtyping, grain\n", + "Successfully installed grain-0.2.2 jaxtyping-0.2.36\n" + ] + } + ], + "source": [ + "!pip install grain" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "66bH3ZDJ7Iat" + }, + "source": [ + "Import Required Libraries (import MNIST dataset from torchvision)" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": { + "id": "mS62eVL9Ifmz" + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "import grain.python as pygrain\n", + "from torchvision.datasets import MNIST" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0h6mwVrspPA-" + }, + "source": [ + "### Define Dataset Class\n", + "\n", + "Create a custom dataset class to load MNIST data for Grain." 
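Grain's `DataLoader` only needs a random-access source, that is, an object exposing `__len__` and `__getitem__`. The toy source below is a hypothetical, self-contained illustration of that contract; the MNIST-backed class in the next cell follows the same shape.

```python
import numpy as np

class ToySource:
    """Sketch: the minimal random-access protocol a Grain data source needs."""
    def __len__(self):
        return 8  # arbitrary size for illustration

    def __getitem__(self, index):
        # One flat float32 feature vector plus an integer label.
        return np.full((784,), index, dtype=np.float32), index % 10
```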
+ ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": { + "id": "bnrhac5Hh7y1" + }, + "outputs": [], + "source": [ + "class Dataset:\n", + " def __init__(self, data_dir, train=True):\n", + " self.data_dir = data_dir\n", + " self.train = train\n", + " self.load_data()\n", + "\n", + " def load_data(self):\n", + " self.dataset = MNIST(self.data_dir, download=True, train=self.train)\n", + "\n", + " def __len__(self):\n", + " return len(self.dataset)\n", + "\n", + " def __getitem__(self, index):\n", + " img, label = self.dataset[index]\n", + " return np.ravel(np.array(img, dtype=np.float32)), label" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "53mf8bWEsyTr" + }, + "source": [ + "### Initialize the Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": { + "id": "pN3oF7-ostGE" + }, + "outputs": [], + "source": [ + "mnist_dataset = Dataset(data_dir)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "GqD-ycgBuwv9" + }, + "source": [ + "### Get the full train and test dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": { + "id": "f1VnTuX3u_kL" + }, + "outputs": [], + "source": [ + "# Convert training data to JAX arrays and encode labels as one-hot vectors\n", + "train_images = jnp.array([mnist_dataset[i][0] for i in range(len(mnist_dataset))], dtype=jnp.float32)\n", + "train_labels = one_hot(np.array([mnist_dataset[i][1] for i in range(len(mnist_dataset))]), n_targets)\n", + "\n", + "# Load test dataset and process it\n", + "mnist_dataset_test = MNIST(data_dir, download=True, train=False)\n", + "test_images = jnp.array([np.ravel(np.array(mnist_dataset_test[i][0], dtype=np.float32)) for i in range(len(mnist_dataset_test))], dtype=jnp.float32)\n", + "test_labels = one_hot(np.array([mnist_dataset_test[i][1] for i in range(len(mnist_dataset_test))]), n_targets)" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "a2NHlp9klrQL", + "outputId": "14be58c0-851e-4a44-dfcc-d02f0718dab5" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train: (60000, 784) (60000, 10)\n", + "Test: (10000, 784) (10000, 10)\n" + ] + } + ], + "source": [ + "print(\"Train:\", train_images.shape, train_labels.shape)\n", + "print(\"Test:\", test_images.shape, test_labels.shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "fETnWRo2crhf" + }, + "source": [ + "### Initialize PyGrain DataLoader\n", + "\n", + "Set up a PyGrain DataLoader for sequential batch sampling." + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": { + "id": "9RuFTcsCs2Ac" + }, + "outputs": [], + "source": [ + "sampler = pygrain.SequentialSampler(\n", + " num_records=len(mnist_dataset),\n", + " shard_options=pygrain.NoSharding()) # Single-device, no sharding\n", + "\n", + "def pygrain_training_generator():\n", + " \"\"\"Grain DataLoader generator for training.\"\"\"\n", + " return pygrain.DataLoader(\n", + " data_source=mnist_dataset,\n", + " sampler=sampler,\n", + " operations=[pygrain.Batch(batch_size)],\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "GvpJPHAbeuHW" + }, + "source": [ + "### Training Loop (Grain)\n", + "\n", + "Run the training loop using the Grain DataLoader." 
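Note that `SequentialSampler` visits records in order. For shuffled epochs you could swap in Grain's `IndexSampler`, roughly as sketched below; treat the argument names as an assumption to verify, since they may vary across `grain` versions, and the seed is arbitrary.

```python
# Sketch: a shuffled sampler as a drop-in for the SequentialSampler above.
shuffled_sampler = pygrain.IndexSampler(
    num_records=len(mnist_dataset),
    shard_options=pygrain.NoSharding(),
    shuffle=True,   # reshuffle record order
    num_epochs=1,
    seed=0,         # arbitrary seed for reproducibility
)
```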
+ ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "cjxJRtiTadEI", + "outputId": "3f624366-b683-4d20-9d0a-777d345b0e21" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1 in 15.39 sec: Train Accuracy: 0.9158, Test Accuracy: 0.9196\n", + "Epoch 2 in 15.27 sec: Train Accuracy: 0.9372, Test Accuracy: 0.9384\n", + "Epoch 3 in 12.61 sec: Train Accuracy: 0.9492, Test Accuracy: 0.9468\n", + "Epoch 4 in 12.62 sec: Train Accuracy: 0.9569, Test Accuracy: 0.9532\n", + "Epoch 5 in 12.39 sec: Train Accuracy: 0.9630, Test Accuracy: 0.9579\n", + "Epoch 6 in 12.19 sec: Train Accuracy: 0.9674, Test Accuracy: 0.9615\n", + "Epoch 7 in 12.56 sec: Train Accuracy: 0.9708, Test Accuracy: 0.9650\n", + "Epoch 8 in 13.04 sec: Train Accuracy: 0.9737, Test Accuracy: 0.9671\n" + ] + } + ], + "source": [ + "train_model(num_epochs, params, pygrain_training_generator)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "oixvOI816qUn" + }, + "source": [ + "## Loading Data with Hugging Face\n", + "\n", + "This section demonstrates loading MNIST data using the Hugging Face `datasets` library. You'll format the dataset for JAX compatibility, prepare flattened images and one-hot-encoded labels, and define a training generator." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "o51P6lr86wz-" + }, + "source": [ + "Install the Hugging Face `datasets` library." + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "19ipxPhI6oSN", + "outputId": "684e445f-d23e-4924-9e76-2c2c9359f0be" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Collecting datasets\n", + " Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from datasets) (3.16.1)\n", + "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (1.26.4)\n", + "Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (17.0.0)\n", + "Collecting dill<0.3.9,>=0.3.0 (from datasets)\n", + " Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)\n", + "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (2.2.2)\n", + "Requirement already satisfied: requests>=2.32.2 in /usr/local/lib/python3.10/dist-packages (from datasets) (2.32.3)\n", + "Requirement already satisfied: tqdm>=4.66.3 in /usr/local/lib/python3.10/dist-packages (from datasets) (4.66.6)\n", + "Collecting xxhash (from datasets)\n", + " Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)\n", + "Collecting multiprocess<0.70.17 (from datasets)\n", + " Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)\n", + "Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)\n", + " Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)\n", + "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.11.2)\n", + "Requirement already satisfied: huggingface-hub>=0.23.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.26.2)\n", + "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets) (24.2)\n", + "Requirement already 
satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (6.0.2)\n", + "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (2.4.3)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n", + "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (24.2.0)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.5.0)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.1.0)\n", + "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (0.2.0)\n", + "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.17.2)\n", + "Requirement already satisfied: async-timeout<6.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.23.0->datasets) (4.12.2)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (3.4.0)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (3.10)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (2.2.3)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (2024.8.30)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.2)\n", + "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.2)\n", + "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n", + "Downloading datasets-3.1.0-py3-none-any.whl (480 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m480.6/480.6 kB\u001b[0m \u001b[31m8.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m9.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading fsspec-2024.9.0-py3-none-any.whl (179 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m179.3/179.3 kB\u001b[0m \u001b[31m13.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading multiprocess-0.70.16-py310-none-any.whl (134 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m9.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hDownloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)\n", + "\u001b[2K 
\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m15.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hInstalling collected packages: xxhash, fsspec, dill, multiprocess, datasets\n", + " Attempting uninstall: fsspec\n", + " Found existing installation: fsspec 2024.10.0\n", + " Uninstalling fsspec-2024.10.0:\n", + " Successfully uninstalled fsspec-2024.10.0\n", + "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", + "gcsfs 2024.10.0 requires fsspec==2024.10.0, but you have fsspec 2024.9.0 which is incompatible.\u001b[0m\u001b[31m\n", + "\u001b[0mSuccessfully installed datasets-3.1.0 dill-0.3.8 fsspec-2024.9.0 multiprocess-0.70.16 xxhash-3.5.0\n" + ] + } + ], + "source": [ + "!pip install datasets" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "be0h_dZv0593" + }, + "source": [ + "Import Library" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": { + "id": "8v1N59p76zn0" + }, + "outputs": [], + "source": [ + "from datasets import load_dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8Gaj11tO7C86" + }, + "source": [ + "### Load and Format MNIST Dataset\n", + "\n", + "Load the MNIST dataset from Hugging Face and format it as `numpy` arrays for quick access or `jax` to get JAX arrays." + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 301, + "referenced_widgets": [ + "32f6132a31aa4c508d3c3c5ef70348bb", + "d7c2ffa6b143463c91cbf8befca6ca01", + "fd964ecd3926419d92927c67f955d5d0", + "60feca3fde7c4447ad8393b0542eb999", + "3354a0baeca94d18bc6b2a8b8b465b58", + "a0d0d052772b46deac7657ad052991a4", + "fb34783b9cba462e9b690e0979c4b07a", + "8d8170c1ed99490589969cd753c40748", + "f1ecb6db00a54e088f1e09164222d637", + "3cf5dd8d29aa4619b39dc2542df7e42e", + "2e5d42ca710441b389895f2d3b611d0a", + "5d8202da24244dc896e9a8cba6a4ed4f", + "a6d64c953631412b8bd8f0ba53ae4d32", + "69240c5cbfbb4e91961f5b49812a26f0", + "865f38532b784a7c971f5d33b87b443e", + "ceb1c004191947cdaa10af9b9c03c80d", + "64c6041037914779b5e8e9cf5a80ad04", + "562fa6a0e7b846a180ac4b423c5511c5", + "b3b922288f9c4df2a4088279ff6d1531", + "75a1a8ffda554318890cf74c345ed9a9", + "3bae06cacf394a5998c2326199da94f5", + "ff6428a3daa5496c81d5e664aba01f97", + "1ba3f86870724f55b94a35cb6b4173af", + "b3e163fd8b8a4f289d5a25611cb66d23", + "abd2daba215e4f7c9ddabde04d6eb382", + "e22ee019049144d5aba573cdf4dbe4fc", + "6ac765dac67841a69218140785f024c6", + "7b057411a54e434fb74804b90daa8d44", + "563f71b3c67d47c3ab1100f5dc1b98f3", + "d81a657361ab4bba8bcc0cf309d2ff64", + "20316312ab88471ba90cbb954be3e964", + "698fda742f834473a23fb7e5e4cf239c", + "289b52c5a38146b8b467a5f4678f6271", + "d07c2f37cf914894b1551a8104e6cb70", + "5b55c73d551d483baaa6a1411c2597b1", + "2308f77723f54ac898588f48d1853b65", + "54d2589714d04b2e928b816258cb0df4", + "f84b795348c04c7a950165301a643671", + "bc853a4a8d3c4dbda23d183f0a3b4f27", + "1012ddc0343842d8b913a7d85df8ab8f", + "771a73a8f5084a57afc5654d72e022f0", + "311a43449f074841b6df4130b0871ac9", + "cd4d29cb01134469b52d6936c35eb943", + "013cf89ee6174d29bb3f4fdff7b36049", + "9237d877d84e4b3ab69698ecf56915bb", + "337ef4d37e6b4ff6bf6e8bd4ca93383f", + "b4096d3837b84ccdb8f1186435c87281", + "7259d3b7e11b4736b4d2aa8e9c55e994", + "1ad1f8e99a864fc4a2bc532d9a4ff110", + "b2b50451eabd40978ef46db5e7dd08c4", + 
"2dad5c5541e243128e23c3dd3e420ac2", + "a3de458b61e5493081d6bb9cf7e923db", + "37760f8a7b164e6f9c1a23d621e9fe6b", + "745a2aedcfab491fb9cffba19958b0c5", + "2f6c670640d048d2af453638cfde3a1e" + ] + }, + "id": "a22kTvgk6_fJ", + "outputId": "35fc38b9-a6ab-4b02-ffa4-ab27fac69df4" + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_auth.py:94: UserWarning: \n", + "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", + "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", + "You will be able to reuse this secret in all of your notebooks.\n", + "Please note that authentication is recommended but still optional to access public models or datasets.\n", + " warnings.warn(\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "32f6132a31aa4c508d3c3c5ef70348bb", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "README.md: 0%| | 0.00/6.97k [00:00 0` enables multi-process data loading, which can accelerate data loading for larger datasets or intensive preprocessing tasks. Experiment with different values to find the optimal setting for your hardware and workload. + +Note: When setting `num_workers > 0`, you may see the following `RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.` This warning can be safely ignored since data loaders do not use JAX within the forked processes. + +```{code-cell} +:id: B-fES82EiL6Z + +def pytorch_training_generator(mnist_dataset): + return NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=0) +``` + ++++ {"id": "Xzt2x9S1HC3T"} + +### Training Loop (PyTorch DataLoader) + +The training loop uses the PyTorch DataLoader to iterate through batches and update model parameters. + +```{code-cell} +--- +colab: + base_uri: https://localhost:8080/ +id: vtUjHsh-rJs8 +outputId: 4766333e-4366-493b-995a-102778d1345a +--- +train_model(num_epochs, params, pytorch_training_generator(mnist_dataset), data_loader_type='iterable') +``` + ++++ {"id": "Nm45ZTo6yrf5"} + +## Loading Data with TensorFlow Datasets (TFDS) + +This section demonstrates how to load the MNIST dataset using TFDS, fetch the full dataset for evaluation, and define a training generator for batch processing. GPU usage is explicitly disabled for TensorFlow. + +```{code-cell} +:id: sGaQAk1DHMUx + +import tensorflow_datasets as tfds +import tensorflow as tf + +# Ensuring CPU-Only Execution, disable any GPU usage(if applicable) for TF +tf.config.set_visible_devices([], device_type='GPU') +``` + ++++ {"id": "3xdQY7H6wr3n"} + +### Fetch Full Dataset for Evaluation + +Load the dataset with `tfds.load`, convert it to NumPy arrays, and process it for evaluation. 
+ +```{code-cell} +--- +colab: + base_uri: https://localhost:8080/ + height: 104 + referenced_widgets: [b8cdabf5c05848f38f03850cab08b56f, a8b76d5f93004c089676e5a2a9b3336c, + 119ac8428f9441e7a25eb0afef2fbb2a, 76a9815e5c2b4764a13409cebaf66821, 45ce8dd5c4b949afa957ec8ffb926060, + 05b7145fd62d4581b2123c7680f11cdd, b96267f014814ec5b96ad7e6165104b1, bce34bdbfbd64f1f8353a4e8515cee0b, + 93b8206f8c5841a692cdce985ae301d8, c95f592620c64da595cc787567b2c4db, 8a97071f862c4ec3b4b4140d2e34eda2] +id: 1hOamw_7C8Pb +outputId: ca166490-22db-4732-b29f-866b7593e489 +--- +# tfds.load returns tf.Tensors (or tf.data.Datasets if batch_size != -1) +mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True) +mnist_data = tfds.as_numpy(mnist_data) +train_data, test_data = mnist_data['train'], mnist_data['test'] + +# Full train set +train_images, train_labels = train_data['image'], train_data['label'] +train_images = jnp.reshape(train_images, (len(train_images), num_pixels)) +train_labels = one_hot(train_labels, n_targets) + +# Full test set +test_images, test_labels = test_data['image'], test_data['label'] +test_images = jnp.reshape(test_images, (len(test_images), num_pixels)) +test_labels = one_hot(test_labels, n_targets) +``` + +```{code-cell} +--- +colab: + base_uri: https://localhost:8080/ +id: Td3PiLdmEf7z +outputId: 96403b0f-6079-43ce-df16-d4583f09906b +--- +print('Train:', train_images.shape, train_labels.shape) +print('Test:', test_images.shape, test_labels.shape) +``` + ++++ {"id": "UWRSaalfdyDX"} + +### Define the Training Generator + +Create a generator function to yield batches of data for training. + +```{code-cell} +:id: vX59u8CqEf4J + +def training_generator(): + # as_supervised=True gives us the (image, label) as a tuple instead of a dict + ds = tfds.load(name='mnist', split='train', as_supervised=True, data_dir=data_dir) + # You can build up an arbitrary tf.data input pipeline + ds = ds.batch(batch_size).prefetch(1) + # tfds.dataset_as_numpy converts the tf.data.Dataset into an iterable of NumPy arrays + return tfds.as_numpy(ds) +``` + ++++ {"id": "EAWeUdnuFNBY"} + +### Training Loop (TFDS) + +Use the training generator in a custom training loop. + +```{code-cell} +--- +colab: + base_uri: https://localhost:8080/ +id: h2sO13XDGvq1 +outputId: a150246e-ceb5-46ac-db71-2a8177a9d04d +--- +train_model(num_epochs, params, training_generator) +``` + ++++ {"id": "-ryVkrAITS9Z"} + +## Loading Data with Grain + +This section demonstrates how to load MNIST data using Grain, a data-loading library. You'll define a custom dataset class for Grain and set up a Grain DataLoader for efficient training. + ++++ {"id": "waYhUMUGmhH-"} + +Install Grain + +```{code-cell} +--- +colab: + base_uri: https://localhost:8080/ +id: L78o7eeyGvn5 +outputId: 76d16565-0d9e-4f5f-c6b1-4cf4a683d0e7 +--- +!pip install grain +``` + ++++ {"id": "66bH3ZDJ7Iat"} + +Import Required Libraries (import MNIST dataset from torchvision) + +```{code-cell} +:id: mS62eVL9Ifmz + +import numpy as np +import grain.python as pygrain +from torchvision.datasets import MNIST +``` + ++++ {"id": "0h6mwVrspPA-"} + +### Define Dataset Class + +Create a custom dataset class to load MNIST data for Grain. 
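Once the class in the next cell is defined, a quick spot-check of a single element verifies the flattening and dtype; the expected label `5` is the well-known first digit of the MNIST training split.

```python
# Sketch: inspect one record from the Grain data source defined below.
ds = Dataset(data_dir)
img, label = ds[0]
print(img.shape, img.dtype, label)  # expect: (784,) float32 5
```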
+ +```{code-cell} +:id: bnrhac5Hh7y1 + +class Dataset: + def __init__(self, data_dir, train=True): + self.data_dir = data_dir + self.train = train + self.load_data() + + def load_data(self): + self.dataset = MNIST(self.data_dir, download=True, train=self.train) + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, index): + img, label = self.dataset[index] + return np.ravel(np.array(img, dtype=np.float32)), label +``` + ++++ {"id": "53mf8bWEsyTr"} + +### Initialize the Dataset + +```{code-cell} +:id: pN3oF7-ostGE + +mnist_dataset = Dataset(data_dir) +``` + ++++ {"id": "GqD-ycgBuwv9"} + +### Get the full train and test dataset + +```{code-cell} +:id: f1VnTuX3u_kL + +# Convert training data to JAX arrays and encode labels as one-hot vectors +train_images = jnp.array([mnist_dataset[i][0] for i in range(len(mnist_dataset))], dtype=jnp.float32) +train_labels = one_hot(np.array([mnist_dataset[i][1] for i in range(len(mnist_dataset))]), n_targets) + +# Load test dataset and process it +mnist_dataset_test = MNIST(data_dir, download=True, train=False) +test_images = jnp.array([np.ravel(np.array(mnist_dataset_test[i][0], dtype=np.float32)) for i in range(len(mnist_dataset_test))], dtype=jnp.float32) +test_labels = one_hot(np.array([mnist_dataset_test[i][1] for i in range(len(mnist_dataset_test))]), n_targets) +``` + +```{code-cell} +--- +colab: + base_uri: https://localhost:8080/ +id: a2NHlp9klrQL +outputId: 14be58c0-851e-4a44-dfcc-d02f0718dab5 +--- +print("Train:", train_images.shape, train_labels.shape) +print("Test:", test_images.shape, test_labels.shape) +``` + ++++ {"id": "fETnWRo2crhf"} + +### Initialize PyGrain DataLoader + +Set up a PyGrain DataLoader for sequential batch sampling. + +```{code-cell} +:id: 9RuFTcsCs2Ac + +sampler = pygrain.SequentialSampler( + num_records=len(mnist_dataset), + shard_options=pygrain.NoSharding()) # Single-device, no sharding + +def pygrain_training_generator(): + """Grain DataLoader generator for training.""" + return pygrain.DataLoader( + data_source=mnist_dataset, + sampler=sampler, + operations=[pygrain.Batch(batch_size)], + ) +``` + ++++ {"id": "GvpJPHAbeuHW"} + +### Training Loop (Grain) + +Run the training loop using the Grain DataLoader. + +```{code-cell} +--- +colab: + base_uri: https://localhost:8080/ +id: cjxJRtiTadEI +outputId: 3f624366-b683-4d20-9d0a-777d345b0e21 +--- +train_model(num_epochs, params, pygrain_training_generator) +``` + ++++ {"id": "oixvOI816qUn"} + +## Loading Data with Hugging Face + +This section demonstrates loading MNIST data using the Hugging Face `datasets` library. You'll format the dataset for JAX compatibility, prepare flattened images and one-hot-encoded labels, and define a training generator. + ++++ {"id": "o51P6lr86wz-"} + +Install the Hugging Face `datasets` library. + +```{code-cell} +--- +colab: + base_uri: https://localhost:8080/ +id: 19ipxPhI6oSN +outputId: 684e445f-d23e-4924-9e76-2c2c9359f0be +--- +!pip install datasets +``` + ++++ {"id": "be0h_dZv0593"} + +Import Library + +```{code-cell} +:id: 8v1N59p76zn0 + +from datasets import load_dataset +``` + ++++ {"id": "8Gaj11tO7C86"} + +### Load and Format MNIST Dataset + +Load the MNIST dataset from Hugging Face and format it as `numpy` arrays for quick access or `jax` to get JAX arrays. 
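As an alternative to the `"numpy"` format used in the next cell, the Hugging Face `datasets` library also supports a `"jax"` format that returns JAX arrays directly; a sketch, assuming `load_dataset` is imported as above:

```python
# Sketch: request JAX arrays straight from the datasets library.
ds_jax = load_dataset("mnist").with_format("jax")
print(type(ds_jax["train"][0]["image"]))  # expect a jax.Array
```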
+ +```{code-cell} +--- +colab: + base_uri: https://localhost:8080/ + height: 301 + referenced_widgets: [32f6132a31aa4c508d3c3c5ef70348bb, d7c2ffa6b143463c91cbf8befca6ca01, + fd964ecd3926419d92927c67f955d5d0, 60feca3fde7c4447ad8393b0542eb999, 3354a0baeca94d18bc6b2a8b8b465b58, + a0d0d052772b46deac7657ad052991a4, fb34783b9cba462e9b690e0979c4b07a, 8d8170c1ed99490589969cd753c40748, + f1ecb6db00a54e088f1e09164222d637, 3cf5dd8d29aa4619b39dc2542df7e42e, 2e5d42ca710441b389895f2d3b611d0a, + 5d8202da24244dc896e9a8cba6a4ed4f, a6d64c953631412b8bd8f0ba53ae4d32, 69240c5cbfbb4e91961f5b49812a26f0, + 865f38532b784a7c971f5d33b87b443e, ceb1c004191947cdaa10af9b9c03c80d, 64c6041037914779b5e8e9cf5a80ad04, + 562fa6a0e7b846a180ac4b423c5511c5, b3b922288f9c4df2a4088279ff6d1531, 75a1a8ffda554318890cf74c345ed9a9, + 3bae06cacf394a5998c2326199da94f5, ff6428a3daa5496c81d5e664aba01f97, 1ba3f86870724f55b94a35cb6b4173af, + b3e163fd8b8a4f289d5a25611cb66d23, abd2daba215e4f7c9ddabde04d6eb382, e22ee019049144d5aba573cdf4dbe4fc, + 6ac765dac67841a69218140785f024c6, 7b057411a54e434fb74804b90daa8d44, 563f71b3c67d47c3ab1100f5dc1b98f3, + d81a657361ab4bba8bcc0cf309d2ff64, 20316312ab88471ba90cbb954be3e964, 698fda742f834473a23fb7e5e4cf239c, + 289b52c5a38146b8b467a5f4678f6271, d07c2f37cf914894b1551a8104e6cb70, 5b55c73d551d483baaa6a1411c2597b1, + 2308f77723f54ac898588f48d1853b65, 54d2589714d04b2e928b816258cb0df4, f84b795348c04c7a950165301a643671, + bc853a4a8d3c4dbda23d183f0a3b4f27, 1012ddc0343842d8b913a7d85df8ab8f, 771a73a8f5084a57afc5654d72e022f0, + 311a43449f074841b6df4130b0871ac9, cd4d29cb01134469b52d6936c35eb943, 013cf89ee6174d29bb3f4fdff7b36049, + 9237d877d84e4b3ab69698ecf56915bb, 337ef4d37e6b4ff6bf6e8bd4ca93383f, b4096d3837b84ccdb8f1186435c87281, + 7259d3b7e11b4736b4d2aa8e9c55e994, 1ad1f8e99a864fc4a2bc532d9a4ff110, b2b50451eabd40978ef46db5e7dd08c4, + 2dad5c5541e243128e23c3dd3e420ac2, a3de458b61e5493081d6bb9cf7e923db, 37760f8a7b164e6f9c1a23d621e9fe6b, + 745a2aedcfab491fb9cffba19958b0c5, 2f6c670640d048d2af453638cfde3a1e] +id: a22kTvgk6_fJ +outputId: 35fc38b9-a6ab-4b02-ffa4-ab27fac69df4 +--- +mnist_dataset = load_dataset("mnist").with_format("numpy") +``` + ++++ {"id": "IFjTyGxY19b0"} + +### Extract images and labels + +Get image shape and flatten for model input + +```{code-cell} +--- +colab: + base_uri: https://localhost:8080/ +id: NHrKatD_7HbH +outputId: deec1739-2fc0-4e71-8567-f2e0c9db198b +--- +train_images = mnist_dataset["train"]["image"] +train_labels = mnist_dataset["train"]["label"] +test_images = mnist_dataset["test"]["image"] +test_labels = mnist_dataset["test"]["label"] + +# Flatten images and one-hot encode labels +image_shape = train_images.shape[1:] +num_features = image_shape[0] * image_shape[1] + +train_images = train_images.reshape(-1, num_features) +test_images = test_images.reshape(-1, num_features) + +train_labels = one_hot(train_labels, n_targets) +test_labels = one_hot(test_labels, n_targets) + +print('Train:', train_images.shape, train_labels.shape) +print('Test:', test_images.shape, test_labels.shape) +``` + ++++ {"id": "kk_4zJlz7T1E"} + +### Define Training Generator + +Set up a generator to yield batches of images and labels for training. + +```{code-cell} +:id: -zLJhogj7RL- + +def hf_training_generator(): + """Yield batches for training.""" + for batch in mnist_dataset["train"].iter(batch_size): + x, y = batch["image"], batch["label"] + yield x, y +``` + ++++ {"id": "HIsGfkLI7dvZ"} + +### Training Loop (Hugging Face Datasets) + +Run the training loop using the Hugging Face training generator. 
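The `hf_training_generator` above yields batches in storage order. If you want reshuffling between epochs, one possible variant is sketched below; it assumes `mnist_dataset` in NumPy format as above, `hf_training_generator_shuffled` is a hypothetical helper name, and the seed is arbitrary.

```python
def hf_training_generator_shuffled(seed=0):
    """Sketch: reshuffle the train split before yielding batches."""
    shuffled = mnist_dataset["train"].shuffle(seed=seed)
    for batch in shuffled.iter(batch_size):
        yield batch["image"], batch["label"]
```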
+ +```{code-cell} +--- +colab: + base_uri: https://localhost:8080/ +id: RhloYGsw6nPf +outputId: d49c1cd2-a546-46a6-84fb-d9507c38f4ca +--- +train_model(num_epochs, params, hf_training_generator) +``` + ++++ {"id": "qXylIOwidWI3"} + +## Summary + +This notebook has guided you through efficient methods for loading data on a CPU when using JAX. You’ve learned how to leverage popular libraries such as PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets to streamline the data loading process for your machine learning tasks. Each of these methods offers unique advantages and considerations, allowing you to choose the best approach based on the specific needs of your project. From ccac6b5d030faf1881f7bf64fa22743598e4b4f2 Mon Sep 17 00:00:00 2001 From: selamw1 Date: Tue, 3 Dec 2024 14:53:11 -0800 Subject: [PATCH 2/7] =?UTF-8?q?=E2=80=9Creferece=5Ftutorial=5Flinks=5Fadde?= =?UTF-8?q?d=E2=80=9D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/data_loaders_on_cpu_with_jax.ipynb | 79 +++++++++++++------------ docs/data_loaders_on_cpu_with_jax.md | 10 +++- 2 files changed, 50 insertions(+), 39 deletions(-) diff --git a/docs/data_loaders_on_cpu_with_jax.ipynb b/docs/data_loaders_on_cpu_with_jax.ipynb index 21bd599..0ba897e 100644 --- a/docs/data_loaders_on_cpu_with_jax.ipynb +++ b/docs/data_loaders_on_cpu_with_jax.ipynb @@ -24,7 +24,13 @@ "- [**Grain**](https://github.com/google/grain)\n", "- [**Hugging Face**](https://huggingface.co/docs/datasets/en/use_with_jax#data-loading)\n", "\n", - "You'll see how to use each of these libraries to efficiently load data for a simple image classification task using the MNIST dataset." + "In this tutorial, you'll learn how to efficiently load data using these libraries for a simple image classification task based on the MNIST dataset.\n", + "\n", + "Compared to GPU or multi-device setups, CPU-based data loading is straightforward as it avoids challenges like GPU memory management and data synchronization across devices. This makes it ideal for smaller-scale tasks or scenarios where data resides exclusively on the CPU.\n", + "\n", + "If you're looking for GPU-specific data loading advice, see [Data Loaders on GPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_gpu_with_jax.html).\n", + "\n", + "If you're looking for a multi-device data loading strategy, see [Data Loaders on Multi-Device Setups](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_for_multi_device_setups_with_jax.html)." 
] }, { @@ -40,7 +46,7 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": { "id": "vqP6xyObC0_9" }, @@ -61,7 +67,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": { "id": "tDJNQ6V-Dg5g" }, @@ -83,7 +89,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -120,7 +126,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": { "id": "qLNOSloFDka_" }, @@ -164,7 +170,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": { "id": "bKIYPSkvD1QV" }, @@ -212,7 +218,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": { "id": "sA0a06raEQfS" }, @@ -274,7 +280,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -308,7 +314,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": { "id": "kO5_WzwY59gE" }, @@ -322,7 +328,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": { "id": "6f6qU8PCc143" }, @@ -356,7 +362,7 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -486,7 +492,7 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": { "id": "c9ZCJq_rzPck" }, @@ -509,7 +515,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": { "id": "brlLG4SqGphm" }, @@ -522,7 +528,7 @@ }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -560,7 +566,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "metadata": { "id": "B-fES82EiL6Z" }, @@ -583,7 +589,7 @@ }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -624,7 +630,7 @@ }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "metadata": { "id": "sGaQAk1DHMUx" }, @@ -650,7 +656,7 @@ }, { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -721,7 +727,7 @@ }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -757,7 +763,7 @@ }, { "cell_type": "code", - "execution_count": 19, + "execution_count": null, "metadata": { "id": "vX59u8CqEf4J" }, @@ -785,7 +791,7 @@ }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -835,7 +841,7 @@ }, { "cell_type": "code", - "execution_count": 21, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -887,7 +893,7 @@ }, { "cell_type": "code", - "execution_count": 22, + "execution_count": null, "metadata": { "id": "mS62eVL9Ifmz" }, @@ -911,7 +917,7 @@ }, { "cell_type": "code", - "execution_count": 23, + "execution_count": null, "metadata": { "id": "bnrhac5Hh7y1" }, @@ -945,7 +951,7 @@ }, { "cell_type": "code", - "execution_count": 24, + "execution_count": null, "metadata": { "id": "pN3oF7-ostGE" }, @@ -965,7 +971,7 @@ }, { "cell_type": "code", - "execution_count": 25, + 
"execution_count": null, "metadata": { "id": "f1VnTuX3u_kL" }, @@ -983,7 +989,7 @@ }, { "cell_type": "code", - "execution_count": 26, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -1019,7 +1025,7 @@ }, { "cell_type": "code", - "execution_count": 27, + "execution_count": null, "metadata": { "id": "9RuFTcsCs2Ac" }, @@ -1051,7 +1057,7 @@ }, { "cell_type": "code", - "execution_count": 28, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -1101,7 +1107,7 @@ }, { "cell_type": "code", - "execution_count": 29, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -1187,7 +1193,7 @@ }, { "cell_type": "code", - "execution_count": 30, + "execution_count": null, "metadata": { "id": "8v1N59p76zn0" }, @@ -1209,7 +1215,7 @@ }, { "cell_type": "code", - "execution_count": 31, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", @@ -1376,7 +1382,7 @@ }, { "cell_type": "code", - "execution_count": 32, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -1427,7 +1433,7 @@ }, { "cell_type": "code", - "execution_count": 33, + "execution_count": null, "metadata": { "id": "-zLJhogj7RL-" }, @@ -1453,7 +1459,7 @@ }, { "cell_type": "code", - "execution_count": 34, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -1489,13 +1495,12 @@ "source": [ "## Summary\n", "\n", - "This notebook has guided you through efficient methods for loading data on a CPU when using JAX. You’ve learned how to leverage popular libraries such as PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets to streamline the data loading process for your machine learning tasks. Each of these methods offers unique advantages and considerations, allowing you to choose the best approach based on the specific needs of your project." + "This notebook has introduced efficient strategies for data loading on a CPU with JAX, demonstrating how to integrate popular libraries like PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets. Each library offers distinct advantages, enabling you to streamline the data loading process for machine learning tasks. By understanding the strengths of these methods, you can select the approach that best suits your project's specific requirements." ] } ], "metadata": { "colab": { - "name": "data_loaders_on_cpu_with_jax.ipynb", "provenance": [] }, "jupytext": { diff --git a/docs/data_loaders_on_cpu_with_jax.md b/docs/data_loaders_on_cpu_with_jax.md index f565d1d..d26c687 100644 --- a/docs/data_loaders_on_cpu_with_jax.md +++ b/docs/data_loaders_on_cpu_with_jax.md @@ -26,7 +26,13 @@ This tutorial explores different data loading strategies for using **JAX** on a - [**Grain**](https://github.com/google/grain) - [**Hugging Face**](https://huggingface.co/docs/datasets/en/use_with_jax#data-loading) -You'll see how to use each of these libraries to efficiently load data for a simple image classification task using the MNIST dataset. +In this tutorial, you'll learn how to efficiently load data using these libraries for a simple image classification task based on the MNIST dataset. + +Compared to GPU or multi-device setups, CPU-based data loading is straightforward as it avoids challenges like GPU memory management and data synchronization across devices. 
This makes it ideal for smaller-scale tasks or scenarios where data resides exclusively on the CPU.
+
+If you're looking for GPU-specific data loading advice, see [Data Loaders on GPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_gpu_with_jax.html).
+
+If you're looking for a multi-device data loading strategy, see [Data Loaders on Multi-Device Setups](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_for_multi_device_setups_with_jax.html).
 
 +++ {"id": "pEsb135zE-Jo"}
 
@@ -682,4 +688,4 @@
 
 ## Summary
 
-This notebook has guided you through efficient methods for loading data on a CPU when using JAX. You’ve learned how to leverage popular libraries such as PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets to streamline the data loading process for your machine learning tasks. Each of these methods offers unique advantages and considerations, allowing you to choose the best approach based on the specific needs of your project.
+This notebook has introduced efficient strategies for data loading on a CPU with JAX, demonstrating how to integrate popular libraries like PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets. Each library offers distinct advantages, enabling you to streamline the data loading process for machine learning tasks. By understanding the strengths of these methods, you can select the approach that best suits your project's specific requirements.

From c6fcdc937bcd154b1a816339d4f09a3fe839a641 Mon Sep 17 00:00:00 2001
From: selamw1
Date: Tue, 26 Nov 2024 14:14:34 -0800
Subject: [PATCH 3/7] file_conflict_resolved

---
 docs/source/conf.py      | 2 ++
 docs/source/tutorials.md | 1 +
 2 files changed, 3 insertions(+)

diff --git a/docs/source/conf.py b/docs/source/conf.py
index c97b04a..aad2b2f 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -66,6 +66,7 @@
     'JAX_time_series_classification.md',
     'JAX_transformer_text_classification.md',
     'data_loaders_on_cpu_with_jax.md',
+    'data_loaders_on_gpu_with_jax.md',
 ]
 
 suppress_warnings = [
@@ -102,4 +103,5 @@
     'JAX_time_series_classification.ipynb',
     'JAX_transformer_text_classification.ipynb',
     'data_loaders_on_cpu_with_jax.ipynb',
+    'data_loaders_on_gpu_with_jax.ipynb',
 ]
diff --git a/docs/source/tutorials.md b/docs/source/tutorials.md
index f26a7dc..dc82a13 100644
--- a/docs/source/tutorials.md
+++ b/docs/source/tutorials.md
@@ -19,11 +19,12 @@ JAX_basic_text_classification
 JAX_examples_image_segmentation
 JAX_Vision_transformer
 JAX_machine_translation
 JAX_visualizing_models_metrics
 JAX_image_captioning
 JAX_time_series_classification
 JAX_transformer_text_classification
 data_loaders_on_cpu_with_jax
+data_loaders_on_gpu_with_jax
 ```
 
 Once you've gone through this content, you can refer to package-specific

From 3d8bf534a54a8ab5dd926da95cb158ebc9fa01f6 Mon Sep 17 00:00:00 2001
From: selamw1
Date: Wed, 27 Nov 2024 10:17:46 -0800
Subject: [PATCH 4/7] missed_notebook_files_added

---
 docs/data_loaders_on_gpu_with_jax.ipynb | 1172 +++++++++++++++++++++++
 docs/data_loaders_on_gpu_with_jax.md    |  645 +++++++++++++
 2 files changed, 1817 insertions(+)
 create mode 100644 docs/data_loaders_on_gpu_with_jax.ipynb
 create mode 100644 docs/data_loaders_on_gpu_with_jax.md

diff --git a/docs/data_loaders_on_gpu_with_jax.ipynb b/docs/data_loaders_on_gpu_with_jax.ipynb
new file mode 100644
index 0000000..f726297
--- /dev/null
+++ b/docs/data_loaders_on_gpu_with_jax.ipynb
@@ -0,0 +1,1172 @@
+{
+ "cells": [
+  {
+   "cell_type": 
"markdown", + "metadata": { + "id": "PUFGZggH49zp" + }, + "source": [ + "# Introduction to Data Loaders on GPU with JAX" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3ia4PKEV5Dr8" + }, + "source": [ + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/data_loaders_on_gpu_with_jax.ipynb)\n", + "\n", + "This tutorial explores different data loading strategies for using **JAX** on a single [**GPU**](https://jax.readthedocs.io/en/latest/glossary.html#term-GPU). While JAX doesn't include a built-in data loader, it seamlessly integrates with popular data loading libraries, including:\n", + "* [**PyTorch DataLoader**](https://github.com/pytorch/data)\n", + "* [**TensorFlow Datasets (TFDS)**](https://github.com/tensorflow/datasets)\n", + "* [**Grain**](https://github.com/google/grain)\n", + "* [**Hugging Face**](https://huggingface.co/docs/datasets/en/use_with_jax#data-loading)\n", + "\n", + "You'll see how to use each of these libraries to efficiently load data for a simple image classification task using the MNIST dataset.\n", + "\n", + "Compared to the [Data Loaders on CPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_cpu_with_jax.html), working with GPUs introduces opportunities for further optimization, such as transferring data to the GPU using `device_put`, leveraging larger batch sizes for faster processing, and addressing considerations like memory management." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-rsMgVtO6asW" + }, + "source": [ + "### Import JAX API" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": { + "id": "tDJNQ6V-Dg5g" + }, + "outputs": [], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "from jax import grad, jit, vmap, random, device_put" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TsFdlkSZKp9S" + }, + "source": [ + "### Checking GPU Availability for JAX" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "N3sqvaF3KJw1", + "outputId": "ab40f542-b8c0-422c-ca68-4ce292817889" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[CudaDevice(id=0)]" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "jax.devices()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qyJ_WTghDnIc" + }, + "source": [ + "### Setting Hyperparameters and Initializing Parameters\n", + "\n", + "You'll define hyperparameters for your model and data loading, including layer sizes, learning rate, batch size, and the data directory. You'll also initialize the weights and biases for a fully-connected neural network." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": { + "id": "qLNOSloFDka_" + }, + "outputs": [], + "source": [ + "# A helper function to randomly initialize weights and biases\n", + "# for a dense neural network layer\n", + "def random_layer_params(m, n, key, scale=1e-2):\n", + " w_key, b_key = random.split(key)\n", + " return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))\n", + "\n", + "# Function to initialize network parameters for all layers based on defined sizes\n", + "def init_network_params(sizes, key):\n", + " keys = random.split(key, len(sizes))\n", + " return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]\n", + "\n", + "layer_sizes = [784, 512, 512, 10] # Layers of the network\n", + "step_size = 0.01 # Learning rate\n", + "num_epochs = 8 # Number of training epochs\n", + "batch_size = 128 # Batch size for training\n", + "n_targets = 10 # Number of classes (digits 0-9)\n", + "num_pixels = 28 * 28 # Each MNIST image is 28x28 pixels\n", + "data_dir = '/tmp/mnist_dataset' # Directory for storing the dataset\n", + "\n", + "# Initialize network parameters using the defined layer sizes and a random seed\n", + "params = init_network_params(layer_sizes, random.PRNGKey(0))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rHLdqeI7D2WZ" + }, + "source": [ + "### Model Prediction with Auto-Batching\n", + "\n", + "In this section, you'll define the `predict` function for your neural network. This function computes the output of the network for a single input image.\n", + "\n", + "To efficiently process multiple images simultaneously, you'll use [`vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax.vmap), which allows you to vectorize the `predict` function and apply it across a batch of inputs. This technique, called auto-batching, improves computational efficiency by leveraging hardware acceleration." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": { + "id": "bKIYPSkvD1QV" + }, + "outputs": [], + "source": [ + "from jax.scipy.special import logsumexp\n", + "\n", + "def relu(x):\n", + " return jnp.maximum(0, x)\n", + "\n", + "def predict(params, image):\n", + " # per-example predictions\n", + " activations = image\n", + " for w, b in params[:-1]:\n", + " outputs = jnp.dot(w, activations) + b\n", + " activations = relu(outputs)\n", + "\n", + " final_w, final_b = params[-1]\n", + " logits = jnp.dot(final_w, activations) + final_b\n", + " return logits - logsumexp(logits)\n", + "\n", + "# Make a batched version of the `predict` function\n", + "batched_predict = vmap(predict, in_axes=(None, 0))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rLqfeORsERek" + }, + "source": [ + "### Utility and Loss Functions\n", + "\n", + "You'll now define utility functions for:\n", + "- One-hot encoding: Converts class indices to binary vectors.\n", + "- Accuracy calculation: Measures the performance of the model on the dataset.\n", + "- Loss computation: Calculates the difference between predictions and targets.\n", + "\n", + "To optimize performance:\n", + "- [`grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html#jax.grad) is used to compute gradients of the loss function with respect to network parameters.\n", + "- [`jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit) compiles the update function, enabling faster execution by leveraging JAX's [XLA](https://openxla.org/xla) compilation.\n", + "\n", + "- [`device_put`](https://jax.readthedocs.io/en/latest/_autosummary/jax.device_put.html) to transfer the dataset to the GPU." + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": { + "id": "sA0a06raEQfS" + }, + "outputs": [], + "source": [ + "import time\n", + "\n", + "def one_hot(x, k, dtype=jnp.float32):\n", + " \"\"\"Create a one-hot encoding of x of size k.\"\"\"\n", + " return jnp.array(x[:, None] == jnp.arange(k), dtype)\n", + "\n", + "def accuracy(params, images, targets):\n", + " \"\"\"Calculate the accuracy of predictions.\"\"\"\n", + " target_class = jnp.argmax(targets, axis=1)\n", + " predicted_class = jnp.argmax(batched_predict(params, images), axis=1)\n", + " return jnp.mean(predicted_class == target_class)\n", + "\n", + "def loss(params, images, targets):\n", + " \"\"\"Calculate the loss between predictions and targets.\"\"\"\n", + " preds = batched_predict(params, images)\n", + " return -jnp.mean(preds * targets)\n", + "\n", + "@jit\n", + "def update(params, x, y):\n", + " \"\"\"Update the network parameters using gradient descent.\"\"\"\n", + " grads = grad(loss)(params, x, y)\n", + " return [(w - step_size * dw, b - step_size * db)\n", + " for (w, b), (dw, db) in zip(params, grads)]\n", + "\n", + "def reshape_and_one_hot(x, y):\n", + " \"\"\"Reshape and one-hot encode the inputs.\"\"\"\n", + " x = jnp.reshape(x, (len(x), num_pixels))\n", + " y = one_hot(y, n_targets)\n", + " return x, y\n", + "\n", + "def train_model(num_epochs, params, training_generator, data_loader_type='streamed'):\n", + " \"\"\"Train the model for a given number of epochs and device_put for GPU transfer.\"\"\"\n", + " for epoch in range(num_epochs):\n", + " start_time = time.time()\n", + " for x, y in training_generator() if data_loader_type == 'streamed' else training_generator:\n", + " x, y = reshape_and_one_hot(x, y)\n", + " x, y = device_put(x), device_put(y)\n", + " params = update(params, x, y)\n", + "\n", 
+ " print(f\"Epoch {epoch + 1} in {time.time() - start_time:.2f} sec: \"\n", + " f\"Train Accuracy: {accuracy(params, train_images, train_labels):.4f}, \"\n", + " f\"Test Accuracy: {accuracy(params, test_images, test_labels):.4f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Hsionp5IYsQ9" + }, + "source": [ + "## Loading Data with PyTorch DataLoader\n", + "\n", + "This section shows how to load the MNIST dataset using PyTorch's DataLoader, convert the data to NumPy arrays, and apply transformations to flatten and cast images." + ] + }, + { + "cell_type": "code", + "execution_count": 40, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "uA7XY0OezHse", + "outputId": "4c86f455-ff1d-474e-f8e3-7111d9b56996" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.5.1+cu121)\n", + "Requirement already satisfied: torchvision in /usr/local/lib/python3.10/dist-packages (0.20.1+cu121)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch) (3.16.1)\n", + "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch) (4.12.2)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.4.2)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.4)\n", + "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch) (2024.9.0)\n", + "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.10/dist-packages (from torch) (1.13.1)\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy==1.13.1->torch) (1.3.0)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from torchvision) (1.26.4)\n", + "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.10/dist-packages (from torchvision) (11.0.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (3.0.2)\n" + ] + } + ], + "source": [ + "!pip install torch torchvision" + ] + }, + { + "cell_type": "code", + "execution_count": 41, + "metadata": { + "id": "kO5_WzwY59gE" + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "from jax.tree_util import tree_map\n", + "from torch.utils import data\n", + "from torchvision.datasets import MNIST" + ] + }, + { + "cell_type": "code", + "execution_count": 42, + "metadata": { + "id": "6f6qU8PCc143" + }, + "outputs": [], + "source": [ + "def numpy_collate(batch):\n", + " \"\"\"Collate function to convert a batch of PyTorch data into NumPy arrays.\"\"\"\n", + " return tree_map(np.asarray, data.default_collate(batch))\n", + "\n", + "class NumpyLoader(data.DataLoader):\n", + " \"\"\"Custom DataLoader to return NumPy arrays from a PyTorch Dataset.\"\"\"\n", + " def __init__(self, dataset, batch_size=1,\n", + " shuffle=False, sampler=None,\n", + " batch_sampler=None, num_workers=0,\n", + " pin_memory=False, drop_last=False,\n", + " timeout=0, worker_init_fn=None):\n", + " super(self.__class__, self).__init__(dataset,\n", + " batch_size=batch_size,\n", + " shuffle=shuffle,\n", + " sampler=sampler,\n", + " batch_sampler=batch_sampler,\n", + " num_workers=num_workers,\n", + " collate_fn=numpy_collate,\n", + " pin_memory=pin_memory,\n", + " 
drop_last=drop_last,\n", + " timeout=timeout,\n", + " worker_init_fn=worker_init_fn)\n", + "class FlattenAndCast(object):\n", + " \"\"\"Transform class to flatten and cast images to float32.\"\"\"\n", + " def __call__(self, pic):\n", + " return np.ravel(np.array(pic, dtype=jnp.float32))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mfSnfJND6I8G" + }, + "source": [ + "### Load Dataset with Transformations\n", + "\n", + "Standardize the data by flattening the images, casting them to `float32`, and ensuring consistent data types." + ] + }, + { + "cell_type": "code", + "execution_count": 43, + "metadata": { + "id": "Kxbl6bcx6crv" + }, + "outputs": [], + "source": [ + "mnist_dataset = MNIST(data_dir, download=True, transform=FlattenAndCast())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "kbdsqvPZGrsa" + }, + "source": [ + "### Full Training Dataset for Accuracy Checks\n", + "\n", + "Convert the entire training dataset to JAX arrays." + ] + }, + { + "cell_type": "code", + "execution_count": 44, + "metadata": { + "id": "c9ZCJq_rzPck" + }, + "outputs": [], + "source": [ + "train_images = np.array(mnist_dataset.data).reshape(len(mnist_dataset.data), -1)\n", + "train_labels = one_hot(np.array(mnist_dataset.targets), n_targets)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WXUh0BwvG8Ko" + }, + "source": [ + "### Get Full Test Dataset\n", + "\n", + "Load and process the full test dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 45, + "metadata": { + "id": "brlLG4SqGphm" + }, + "outputs": [], + "source": [ + "mnist_dataset_test = MNIST(data_dir, download=True, train=False)\n", + "test_images = jnp.array(mnist_dataset_test.data.numpy().reshape(len(mnist_dataset_test.data), -1), dtype=jnp.float32)\n", + "test_labels = one_hot(np.array(mnist_dataset_test.targets), n_targets)" + ] + }, + { + "cell_type": "code", + "execution_count": 46, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Oz-UVnCxG5E8", + "outputId": "53f3fb32-5096-4862-e022-3c3a1d82137a" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train: (60000, 784) (60000, 10)\n", + "Test: (10000, 784) (10000, 10)\n" + ] + } + ], + "source": [ + "print('Train:', train_images.shape, train_labels.shape)\n", + "print('Test:', test_images.shape, test_labels.shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mNjn9dMPitKL" + }, + "source": [ + "### Training Data Generator\n", + "\n", + "Define a generator function using PyTorch's DataLoader for batch training.\n", + "Setting `num_workers > 0` enables multi-process data loading, which can accelerate data loading for larger datasets or intensive preprocessing tasks. Experiment with different values to find the optimal setting for your hardware and workload.\n", + "\n", + "Note: When setting `num_workers > 0`, you may see the following `RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.`\n", + "This warning can be safely ignored since data loaders do not use JAX within the forked processes." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 47, + "metadata": { + "id": "0LdT8P8aisWF" + }, + "outputs": [], + "source": [ + "def pytorch_training_generator(mnist_dataset):\n", + " return NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=0)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Xzt2x9S1HC3T" + }, + "source": [ + "### Training Loop (PyTorch DataLoader)\n", + "\n", + "The training loop uses the PyTorch DataLoader to iterate through batches and update model parameters." + ] + }, + { + "cell_type": "code", + "execution_count": 48, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "SqweRz_98sN8", + "outputId": "bdd45256-3f5a-48f7-e45c-378078ac4279" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1 in 20.23 sec: Train Accuracy: 0.9158, Test Accuracy: 0.9195\n", + "Epoch 2 in 14.64 sec: Train Accuracy: 0.9372, Test Accuracy: 0.9385\n", + "Epoch 3 in 3.91 sec: Train Accuracy: 0.9492, Test Accuracy: 0.9467\n", + "Epoch 4 in 3.85 sec: Train Accuracy: 0.9569, Test Accuracy: 0.9532\n", + "Epoch 5 in 4.48 sec: Train Accuracy: 0.9631, Test Accuracy: 0.9577\n", + "Epoch 6 in 4.03 sec: Train Accuracy: 0.9675, Test Accuracy: 0.9617\n", + "Epoch 7 in 3.86 sec: Train Accuracy: 0.9708, Test Accuracy: 0.9652\n", + "Epoch 8 in 4.57 sec: Train Accuracy: 0.9736, Test Accuracy: 0.9671\n" + ] + } + ], + "source": [ + "train_model(num_epochs, params, pytorch_training_generator(mnist_dataset), data_loader_type='iterable')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Nm45ZTo6yrf5" + }, + "source": [ + "## Loading Data with TensorFlow Datasets (TFDS)\n", + "\n", + "This section demonstrates how to load the MNIST dataset using TFDS, fetch the full dataset for evaluation, and define a training generator for batch processing. GPU usage is explicitly disabled for TensorFlow." + ] + }, + { + "cell_type": "code", + "execution_count": 49, + "metadata": { + "id": "sGaQAk1DHMUx" + }, + "outputs": [], + "source": [ + "import tensorflow_datasets as tfds" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZSc5K0Eiwm4L" + }, + "source": [ + "### Fetch Full Dataset for Evaluation\n", + "\n", + "Load the dataset with `tfds.load`, convert it to NumPy arrays, and process it for evaluation." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 50, + "metadata": { + "id": "1hOamw_7C8Pb" + }, + "outputs": [], + "source": [ + "# tfds.load returns tf.Tensors (or tf.data.Datasets if batch_size != -1)\n", + "mnist_data, info = tfds.load(name=\"mnist\", batch_size=-1, data_dir=data_dir, with_info=True)\n", + "mnist_data = tfds.as_numpy(mnist_data)\n", + "train_data, test_data = mnist_data['train'], mnist_data['test']\n", + "\n", + "# Full train set\n", + "train_images, train_labels = train_data['image'], train_data['label']\n", + "train_images = jnp.reshape(train_images, (len(train_images), num_pixels))\n", + "train_labels = one_hot(train_labels, n_targets)\n", + "\n", + "# Full test set\n", + "test_images, test_labels = test_data['image'], test_data['label']\n", + "test_images = jnp.reshape(test_images, (len(test_images), num_pixels))\n", + "test_labels = one_hot(test_labels, n_targets)" + ] + }, + { + "cell_type": "code", + "execution_count": 51, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Td3PiLdmEf7z", + "outputId": "b8c9a32a-9cf0-4dc3-cb51-db21d32c6545" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train: (60000, 784) (60000, 10)\n", + "Test: (10000, 784) (10000, 10)\n" + ] + } + ], + "source": [ + "print('Train:', train_images.shape, train_labels.shape)\n", + "print('Test:', test_images.shape, test_labels.shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dXMvgk6sdq4j" + }, + "source": [ + "### Define the Training Generator\n", + "\n", + "Create a generator function to yield batches of data for training." + ] + }, + { + "cell_type": "code", + "execution_count": 52, + "metadata": { + "id": "vX59u8CqEf4J" + }, + "outputs": [], + "source": [ + "def training_generator():\n", + " # as_supervised=True gives us the (image, label) as a tuple instead of a dict\n", + " ds = tfds.load(name='mnist', split='train', as_supervised=True, data_dir=data_dir)\n", + " # You can build up an arbitrary tf.data input pipeline\n", + " ds = ds.batch(batch_size).prefetch(1)\n", + " # tfds.dataset_as_numpy converts the tf.data.Dataset into an iterable of NumPy arrays\n", + " return tfds.as_numpy(ds)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "EAWeUdnuFNBY" + }, + "source": [ + "### Training Loop (TFDS)\n", + "\n", + "Use the training generator in a custom training loop." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 53, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "h2sO13XDGvq1", + "outputId": "f30805bb-e725-46ee-e053-6e97f2af81c5" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1 in 20.86 sec: Train Accuracy: 0.9253, Test Accuracy: 0.9268\n", + "Epoch 2 in 8.56 sec: Train Accuracy: 0.9428, Test Accuracy: 0.9413\n", + "Epoch 3 in 5.40 sec: Train Accuracy: 0.9532, Test Accuracy: 0.9511\n", + "Epoch 4 in 3.86 sec: Train Accuracy: 0.9598, Test Accuracy: 0.9555\n", + "Epoch 5 in 3.88 sec: Train Accuracy: 0.9652, Test Accuracy: 0.9601\n", + "Epoch 6 in 10.35 sec: Train Accuracy: 0.9692, Test Accuracy: 0.9631\n", + "Epoch 7 in 4.39 sec: Train Accuracy: 0.9726, Test Accuracy: 0.9650\n", + "Epoch 8 in 4.77 sec: Train Accuracy: 0.9753, Test Accuracy: 0.9669\n" + ] + } + ], + "source": [ + "train_model(num_epochs, params, training_generator)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-ryVkrAITS9Z" + }, + "source": [ + "## Loading Data with Grain\n", + "\n", + "This section demonstrates how to load MNIST data using Grain, a data-loading library. You'll define a custom dataset class for Grain and set up a Grain DataLoader for efficient training." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "waYhUMUGmhH-" + }, + "source": [ + "Install Grain" + ] + }, + { + "cell_type": "code", + "execution_count": 54, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "L78o7eeyGvn5", + "outputId": "cb0ce6cf-243b-4183-8f63-646e00232caa" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: grain in /usr/local/lib/python3.10/dist-packages (0.2.2)\n", + "Requirement already satisfied: absl-py in /usr/local/lib/python3.10/dist-packages (from grain) (1.4.0)\n", + "Requirement already satisfied: array-record in /usr/local/lib/python3.10/dist-packages (from grain) (0.5.1)\n", + "Requirement already satisfied: cloudpickle in /usr/local/lib/python3.10/dist-packages (from grain) (3.1.0)\n", + "Requirement already satisfied: dm-tree in /usr/local/lib/python3.10/dist-packages (from grain) (0.1.8)\n", + "Requirement already satisfied: etils[epath,epy] in /usr/local/lib/python3.10/dist-packages (from grain) (1.10.0)\n", + "Requirement already satisfied: jaxtyping in /usr/local/lib/python3.10/dist-packages (from grain) (0.2.36)\n", + "Requirement already satisfied: more-itertools>=9.1.0 in /usr/local/lib/python3.10/dist-packages (from grain) (10.5.0)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from grain) (1.26.4)\n", + "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from etils[epath,epy]->grain) (2024.9.0)\n", + "Requirement already satisfied: importlib_resources in /usr/local/lib/python3.10/dist-packages (from etils[epath,epy]->grain) (6.4.5)\n", + "Requirement already satisfied: typing_extensions in /usr/local/lib/python3.10/dist-packages (from etils[epath,epy]->grain) (4.12.2)\n", + "Requirement already satisfied: zipp in /usr/local/lib/python3.10/dist-packages (from etils[epath,epy]->grain) (3.21.0)\n" + ] + } + ], + "source": [ + "!pip install grain" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "66bH3ZDJ7Iat" + }, + "source": [ + "Import Required Libraries (import MNIST dataset from torchvision)" + ] + }, + { + "cell_type": "code", + "execution_count": 55, + "metadata": { + 
"id": "mS62eVL9Ifmz" + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "import grain.python as pygrain\n", + "from torchvision.datasets import MNIST" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0h6mwVrspPA-" + }, + "source": [ + "### Define Dataset Class\n", + "\n", + "Create a custom dataset class to load MNIST data for Grain." + ] + }, + { + "cell_type": "code", + "execution_count": 56, + "metadata": { + "id": "bnrhac5Hh7y1" + }, + "outputs": [], + "source": [ + "class Dataset:\n", + " def __init__(self, data_dir, train=True):\n", + " self.data_dir = data_dir\n", + " self.train = train\n", + " self.load_data()\n", + "\n", + " def load_data(self):\n", + " self.dataset = MNIST(self.data_dir, download=True, train=self.train)\n", + "\n", + " def __len__(self):\n", + " return len(self.dataset)\n", + "\n", + " def __getitem__(self, index):\n", + " img, label = self.dataset[index]\n", + " return np.ravel(np.array(img, dtype=np.float32)), label" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "53mf8bWEsyTr" + }, + "source": [ + "### Initialize the Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 57, + "metadata": { + "id": "pN3oF7-ostGE" + }, + "outputs": [], + "source": [ + "mnist_dataset = Dataset(data_dir)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "GqD-ycgBuwv9" + }, + "source": [ + "### Get the full train and test dataset" + ] + }, + { + "cell_type": "code", + "execution_count": 58, + "metadata": { + "id": "f1VnTuX3u_kL" + }, + "outputs": [], + "source": [ + "# Convert training data to JAX arrays and encode labels as one-hot vectors\n", + "train_images = jnp.array([mnist_dataset[i][0] for i in range(len(mnist_dataset))], dtype=jnp.float32)\n", + "train_labels = one_hot(np.array([mnist_dataset[i][1] for i in range(len(mnist_dataset))]), n_targets)\n", + "\n", + "# Load test dataset and process it\n", + "mnist_dataset_test = MNIST(data_dir, download=True, train=False)\n", + "test_images = jnp.array([np.ravel(np.array(mnist_dataset_test[i][0], dtype=np.float32)) for i in range(len(mnist_dataset_test))], dtype=jnp.float32)\n", + "test_labels = one_hot(np.array([mnist_dataset_test[i][1] for i in range(len(mnist_dataset_test))]), n_targets)" + ] + }, + { + "cell_type": "code", + "execution_count": 59, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "a2NHlp9klrQL", + "outputId": "c9422190-55e9-400b-bd4e-0e7bf23dc6a1" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train: (60000, 784) (60000, 10)\n", + "Test: (10000, 784) (10000, 10)\n" + ] + } + ], + "source": [ + "print(\"Train:\", train_images.shape, train_labels.shape)\n", + "print(\"Test:\", test_images.shape, test_labels.shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1QPbXt7O0JN-" + }, + "source": [ + "### Initialize PyGrain DataLoader" + ] + }, + { + "cell_type": "code", + "execution_count": 60, + "metadata": { + "id": "2jqd1jJt25Bj" + }, + "outputs": [], + "source": [ + "sampler = pygrain.SequentialSampler(\n", + " num_records=len(mnist_dataset),\n", + " shard_options=pygrain.NoSharding()) # Single-device, no sharding\n", + "\n", + "def pygrain_training_generator():\n", + " return pygrain.DataLoader(\n", + " data_source=mnist_dataset,\n", + " sampler=sampler,\n", + " operations=[pygrain.Batch(batch_size)],\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mV5z4GLCGKlx" + }, + "source": [ + "### Training Loop (Grain)\n", + "\n", + 
"Run the training loop using the Grain DataLoader." + ] + }, + { + "cell_type": "code", + "execution_count": 61, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "9-iANQ-9CcW_", + "outputId": "b0e19da2-9e34-4183-c5d8-af66de5efa5c" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1 in 15.65 sec: Train Accuracy: 0.9158, Test Accuracy: 0.9195\n", + "Epoch 2 in 15.03 sec: Train Accuracy: 0.9372, Test Accuracy: 0.9385\n", + "Epoch 3 in 14.93 sec: Train Accuracy: 0.9492, Test Accuracy: 0.9467\n", + "Epoch 4 in 11.56 sec: Train Accuracy: 0.9569, Test Accuracy: 0.9532\n", + "Epoch 5 in 9.33 sec: Train Accuracy: 0.9631, Test Accuracy: 0.9577\n", + "Epoch 6 in 9.31 sec: Train Accuracy: 0.9675, Test Accuracy: 0.9617\n", + "Epoch 7 in 9.78 sec: Train Accuracy: 0.9708, Test Accuracy: 0.9652\n", + "Epoch 8 in 9.80 sec: Train Accuracy: 0.9736, Test Accuracy: 0.9671\n" + ] + } + ], + "source": [ + "train_model(num_epochs, params, pygrain_training_generator)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "o51P6lr86wz-" + }, + "source": [ + "## Loading Data with Hugging Face\n", + "\n", + "This section demonstrates loading MNIST data using the Hugging Face `datasets` library. You'll format the dataset for JAX compatibility, prepare flattened images and one-hot-encoded labels, and define a training generator." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "69vrihaOi4Oz" + }, + "source": [ + "Install the Hugging Face `datasets` library." + ] + }, + { + "cell_type": "code", + "execution_count": 62, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "19ipxPhI6oSN", + "outputId": "b80b80cd-fc14-4a43-f8a8-2802de4faade" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (3.1.0)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from datasets) (3.16.1)\n", + "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (1.26.4)\n", + "Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (17.0.0)\n", + "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.3.8)\n", + "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (2.2.2)\n", + "Requirement already satisfied: requests>=2.32.2 in /usr/local/lib/python3.10/dist-packages (from datasets) (2.32.3)\n", + "Requirement already satisfied: tqdm>=4.66.3 in /usr/local/lib/python3.10/dist-packages (from datasets) (4.66.6)\n", + "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets) (3.5.0)\n", + "Requirement already satisfied: multiprocess<0.70.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.70.16)\n", + "Requirement already satisfied: fsspec<=2024.9.0,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets) (2024.9.0)\n", + "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.11.2)\n", + "Requirement already satisfied: huggingface-hub>=0.23.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.26.2)\n", + "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets) (24.2)\n", + 
"Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (6.0.2)\n", + "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (2.4.3)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n", + "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (24.2.0)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.5.0)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.1.0)\n", + "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (0.2.0)\n", + "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.17.2)\n", + "Requirement already satisfied: async-timeout<6.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.23.0->datasets) (4.12.2)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (3.4.0)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (3.10)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (2.2.3)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (2024.8.30)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.2)\n", + "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.2)\n", + "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n" + ] + } + ], + "source": [ + "!pip install datasets" + ] + }, + { + "cell_type": "code", + "execution_count": 63, + "metadata": { + "id": "8v1N59p76zn0" + }, + "outputs": [], + "source": [ + "from datasets import load_dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8Gaj11tO7C86" + }, + "source": [ + "Load the MNIST dataset from Hugging Face and format it as `numpy` arrays for quick access or `jax` to get JAX arrays." + ] + }, + { + "cell_type": "code", + "execution_count": 64, + "metadata": { + "id": "a22kTvgk6_fJ" + }, + "outputs": [], + "source": [ + "mnist_dataset = load_dataset(\"mnist\", cache_dir=data_dir).with_format(\"numpy\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tgI7dIaX7JzM" + }, + "source": [ + "### Extract images and labels\n", + "\n", + "Get image shape and flatten for model input." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 65, + "metadata": { + "id": "NHrKatD_7HbH" + }, + "outputs": [], + "source": [ + "train_images = mnist_dataset[\"train\"][\"image\"]\n", + "train_labels = mnist_dataset[\"train\"][\"label\"]\n", + "test_images = mnist_dataset[\"test\"][\"image\"]\n", + "test_labels = mnist_dataset[\"test\"][\"label\"]\n", + "\n", + "# Extract image shape\n", + "image_shape = train_images.shape[1:]\n", + "num_features = image_shape[0] * image_shape[1]\n", + "\n", + "# Flatten the images\n", + "train_images = train_images.reshape(-1, num_features)\n", + "test_images = test_images.reshape(-1, num_features)\n", + "\n", + "# One-hot encode the labels\n", + "train_labels = one_hot(train_labels, n_targets)\n", + "test_labels = one_hot(test_labels, n_targets)" + ] + }, + { + "cell_type": "code", + "execution_count": 66, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "dITh435Z7Nwb", + "outputId": "cc89c1ec-6987-4f1c-90a4-c3b355ea7225" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train: (60000, 784) (60000, 10)\n", + "Test: (10000, 784) (10000, 10)\n" + ] + } + ], + "source": [ + "print('Train:', train_images.shape, train_labels.shape)\n", + "print('Test:', test_images.shape, test_labels.shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "kk_4zJlz7T1E" + }, + "source": [ + "### Define Training Generator\n", + "\n", + "Set up a generator to yield batches of images and labels for training." + ] + }, + { + "cell_type": "code", + "execution_count": 67, + "metadata": { + "id": "-zLJhogj7RL-" + }, + "outputs": [], + "source": [ + "def hf_training_generator():\n", + " \"\"\"Yield batches for training.\"\"\"\n", + " for batch in mnist_dataset[\"train\"].iter(batch_size):\n", + " x, y = batch[\"image\"], batch[\"label\"]\n", + " yield x, y" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HIsGfkLI7dvZ" + }, + "source": [ + "### Training Loop (Hugging Face Datasets)\n", + "\n", + "Run the training loop using the Hugging Face training generator." + ] + }, + { + "cell_type": "code", + "execution_count": 68, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Ui6aLiZP7aLe", + "outputId": "c51529e0-563f-4af0-9793-76b5e6f323db" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1 in 19.06 sec: Train Accuracy: 0.9158, Test Accuracy: 0.9195\n", + "Epoch 2 in 8.94 sec: Train Accuracy: 0.9372, Test Accuracy: 0.9385\n", + "Epoch 3 in 5.43 sec: Train Accuracy: 0.9492, Test Accuracy: 0.9467\n", + "Epoch 4 in 6.41 sec: Train Accuracy: 0.9569, Test Accuracy: 0.9532\n", + "Epoch 5 in 5.80 sec: Train Accuracy: 0.9631, Test Accuracy: 0.9577\n", + "Epoch 6 in 6.61 sec: Train Accuracy: 0.9675, Test Accuracy: 0.9617\n", + "Epoch 7 in 5.49 sec: Train Accuracy: 0.9708, Test Accuracy: 0.9652\n", + "Epoch 8 in 6.64 sec: Train Accuracy: 0.9736, Test Accuracy: 0.9671\n" + ] + } + ], + "source": [ + "train_model(num_epochs, params, hf_training_generator)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rCJq2rvKlKWX" + }, + "source": [ + "## Summary\n", + "\n", + "This notebook explored efficient methods for loading data on a GPU with JAX, using libraries like PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets. You also learned GPU-specific optimizations, such as `device_put` for data transfer and memory management, to enhance training efficiency. 
Each method offers unique benefits, helping you choose the best fit for your project needs."
+   ]
+  }
+ ],
+ "metadata": {
+  "accelerator": "GPU",
+  "colab": {
+   "gpuType": "T4",
+   "name": "data_loaders_on_gpu_with_jax.ipynb",
+   "provenance": []
+  },
+  "jupytext": {
+   "formats": "ipynb,md:myst"
+  },
+  "kernelspec": {
+   "display_name": "Python 3",
+   "name": "python3"
+  },
+  "language_info": {
+   "name": "python"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/docs/data_loaders_on_gpu_with_jax.md b/docs/data_loaders_on_gpu_with_jax.md
new file mode 100644
index 0000000..4ec7487
--- /dev/null
+++ b/docs/data_loaders_on_gpu_with_jax.md
@@ -0,0 +1,645 @@
+---
+jupytext:
+  formats: ipynb,md:myst
+  text_representation:
+    extension: .md
+    format_name: myst
+    format_version: 0.13
+  jupytext_version: 1.15.2
+kernelspec:
+  display_name: Python 3
+  name: python3
+---
+
++++ {"id": "PUFGZggH49zp"}
+
+# Introduction to Data Loaders on GPU with JAX
+
++++ {"id": "3ia4PKEV5Dr8"}
+
+[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/data_loaders_on_gpu_with_jax.ipynb)
+
+This tutorial explores different data loading strategies for using **JAX** on a single [**GPU**](https://jax.readthedocs.io/en/latest/glossary.html#term-GPU). While JAX doesn't include a built-in data loader, it seamlessly integrates with popular data loading libraries, including:
+* [**PyTorch DataLoader**](https://github.com/pytorch/data)
+* [**TensorFlow Datasets (TFDS)**](https://github.com/tensorflow/datasets)
+* [**Grain**](https://github.com/google/grain)
+* [**Hugging Face**](https://huggingface.co/docs/datasets/en/use_with_jax#data-loading)
+
+You'll see how to use each of these libraries to efficiently load data for a simple image classification task using the MNIST dataset.
+
+Compared to the [Data Loaders on CPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_cpu_with_jax.html), working with GPUs introduces opportunities for further optimization, such as transferring data to the GPU using `device_put`, leveraging larger batch sizes for faster processing, and addressing considerations like memory management.
+
++++ {"id": "-rsMgVtO6asW"}
+
+### Import JAX API
+
+```{code-cell}
+:id: tDJNQ6V-Dg5g
+
+import jax
+import jax.numpy as jnp
+from jax import grad, jit, vmap, random, device_put
+```
+
++++ {"id": "TsFdlkSZKp9S"}
+
+### Checking GPU Availability for JAX
+
+```{code-cell}
+---
+colab:
+  base_uri: https://localhost:8080/
+id: N3sqvaF3KJw1
+outputId: ab40f542-b8c0-422c-ca68-4ce292817889
+---
+jax.devices()
+```
+
++++ {"id": "qyJ_WTghDnIc"}
+
+### Setting Hyperparameters and Initializing Parameters
+
+You'll define hyperparameters for your model and data loading, including layer sizes, learning rate, batch size, and the data directory. You'll also initialize the weights and biases for a fully-connected neural network.
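+
+As a rough aside (not part of the original notebook), the batch size chosen here also bounds how much device memory each training step's inputs occupy. A quick back-of-the-envelope estimate for flattened float32 MNIST batches:
+
+```python
+# Hypothetical helper: estimate the memory footprint of one input batch,
+# assuming flattened 28x28 images stored as float32 (4 bytes per element).
+batch_size = 128
+num_pixels = 28 * 28
+batch_bytes = batch_size * num_pixels * 4
+print(f"One input batch is ~{batch_bytes / 1024:.0f} KiB")  # ~392 KiB
+```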
+
+```{code-cell}
+:id: qLNOSloFDka_
+
+# A helper function to randomly initialize weights and biases
+# for a dense neural network layer
+def random_layer_params(m, n, key, scale=1e-2):
+  w_key, b_key = random.split(key)
+  return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))
+
+# Function to initialize network parameters for all layers based on defined sizes
+def init_network_params(sizes, key):
+  keys = random.split(key, len(sizes))
+  return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]
+
+layer_sizes = [784, 512, 512, 10]  # Layers of the network
+step_size = 0.01                   # Learning rate
+num_epochs = 8                     # Number of training epochs
+batch_size = 128                   # Batch size for training
+n_targets = 10                     # Number of classes (digits 0-9)
+num_pixels = 28 * 28               # Each MNIST image is 28x28 pixels
+data_dir = '/tmp/mnist_dataset'    # Directory for storing the dataset
+
+# Initialize network parameters using the defined layer sizes and a random seed
+params = init_network_params(layer_sizes, random.PRNGKey(0))
+```
+
++++ {"id": "rHLdqeI7D2WZ"}
+
+### Model Prediction with Auto-Batching
+
+In this section, you'll define the `predict` function for your neural network. This function computes the output of the network for a single input image.
+
+To efficiently process multiple images simultaneously, you'll use [`vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax.vmap), which allows you to vectorize the `predict` function and apply it across a batch of inputs. This technique, called auto-batching, improves computational efficiency by leveraging hardware acceleration.
+
+```{code-cell}
+:id: bKIYPSkvD1QV
+
+from jax.scipy.special import logsumexp
+
+def relu(x):
+  return jnp.maximum(0, x)
+
+def predict(params, image):
+  # per-example predictions
+  activations = image
+  for w, b in params[:-1]:
+    outputs = jnp.dot(w, activations) + b
+    activations = relu(outputs)
+
+  final_w, final_b = params[-1]
+  logits = jnp.dot(final_w, activations) + final_b
+  return logits - logsumexp(logits)
+
+# Make a batched version of the `predict` function
+batched_predict = vmap(predict, in_axes=(None, 0))
+```
+
++++ {"id": "rLqfeORsERek"}
+
+### Utility and Loss Functions
+
+You'll now define utility functions for:
+- One-hot encoding: Converts class indices to binary vectors.
+- Accuracy calculation: Measures the performance of the model on the dataset.
+- Loss computation: Calculates the difference between predictions and targets.
+
+To optimize performance:
+- [`grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html#jax.grad) is used to compute gradients of the loss function with respect to network parameters.
+- [`jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit) compiles the update function, enabling faster execution by leveraging JAX's [XLA](https://openxla.org/xla) compilation.
+- [`device_put`](https://jax.readthedocs.io/en/latest/_autosummary/jax.device_put.html) is used to transfer the dataset to the GPU (a short sketch follows below).
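+
+The training loop defined in the next cell calls `device_put` on every batch. As a minimal sketch of that step in isolation (illustrative only, and assuming at least one accelerator is visible to JAX), an explicit transfer looks like this:
+
+```python
+import jax
+import jax.numpy as jnp
+from jax import device_put
+
+dummy_batch = jnp.ones((128, 784), dtype=jnp.float32)    # stand-in for one input batch
+first_device = jax.devices()[0]                          # the GPU when one is available
+batch_on_device = device_put(dummy_batch, first_device)  # explicit host-to-device copy
+```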
+
+```{code-cell}
+:id: sA0a06raEQfS
+
+import time
+
+def one_hot(x, k, dtype=jnp.float32):
+  """Create a one-hot encoding of x of size k."""
+  return jnp.array(x[:, None] == jnp.arange(k), dtype)
+
+def accuracy(params, images, targets):
+  """Calculate the accuracy of predictions."""
+  target_class = jnp.argmax(targets, axis=1)
+  predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
+  return jnp.mean(predicted_class == target_class)
+
+def loss(params, images, targets):
+  """Calculate the loss between predictions and targets."""
+  preds = batched_predict(params, images)
+  return -jnp.mean(preds * targets)
+
+@jit
+def update(params, x, y):
+  """Update the network parameters using gradient descent."""
+  grads = grad(loss)(params, x, y)
+  return [(w - step_size * dw, b - step_size * db)
+          for (w, b), (dw, db) in zip(params, grads)]
+
+def reshape_and_one_hot(x, y):
+  """Reshape and one-hot encode the inputs."""
+  x = jnp.reshape(x, (len(x), num_pixels))
+  y = one_hot(y, n_targets)
+  return x, y
+
+def train_model(num_epochs, params, training_generator, data_loader_type='streamed'):
+  """Train the model for a given number of epochs, using device_put to move each batch to the GPU."""
+  for epoch in range(num_epochs):
+    start_time = time.time()
+    for x, y in training_generator() if data_loader_type == 'streamed' else training_generator:
+      x, y = reshape_and_one_hot(x, y)
+      x, y = device_put(x), device_put(y)
+      params = update(params, x, y)
+
+    print(f"Epoch {epoch + 1} in {time.time() - start_time:.2f} sec: "
+          f"Train Accuracy: {accuracy(params, train_images, train_labels):.4f}, "
+          f"Test Accuracy: {accuracy(params, test_images, test_labels):.4f}")
+```
+
++++ {"id": "Hsionp5IYsQ9"}
+
+## Loading Data with PyTorch DataLoader
+
+This section shows how to load the MNIST dataset using PyTorch's DataLoader, convert the data to NumPy arrays, and apply transformations to flatten and cast images.
+
+```{code-cell}
+---
+colab:
+  base_uri: https://localhost:8080/
+id: uA7XY0OezHse
+outputId: 4c86f455-ff1d-474e-f8e3-7111d9b56996
+---
+!pip install torch torchvision
+```
+
+```{code-cell}
+:id: kO5_WzwY59gE
+
+import numpy as np
+from jax.tree_util import tree_map
+from torch.utils import data
+from torchvision.datasets import MNIST
+```
+
+```{code-cell}
+:id: 6f6qU8PCc143
+
+def numpy_collate(batch):
+  """Collate function to convert a batch of PyTorch data into NumPy arrays."""
+  return tree_map(np.asarray, data.default_collate(batch))
+
+class NumpyLoader(data.DataLoader):
+  """Custom DataLoader to return NumPy arrays from a PyTorch Dataset."""
+  def __init__(self, dataset, batch_size=1,
+               shuffle=False, sampler=None,
+               batch_sampler=None, num_workers=0,
+               pin_memory=False, drop_last=False,
+               timeout=0, worker_init_fn=None):
+    super(self.__class__, self).__init__(dataset,
+                                         batch_size=batch_size,
+                                         shuffle=shuffle,
+                                         sampler=sampler,
+                                         batch_sampler=batch_sampler,
+                                         num_workers=num_workers,
+                                         collate_fn=numpy_collate,
+                                         pin_memory=pin_memory,
+                                         drop_last=drop_last,
+                                         timeout=timeout,
+                                         worker_init_fn=worker_init_fn)
+
+class FlattenAndCast(object):
+  """Transform class to flatten and cast images to float32."""
+  def __call__(self, pic):
+    return np.ravel(np.array(pic, dtype=jnp.float32))
+```
+
++++ {"id": "mfSnfJND6I8G"}
+
+### Load Dataset with Transformations
+
+Standardize the data by flattening the images, casting them to `float32`, and ensuring consistent data types.
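+
+If you want to sanity-check the transform on its own (a hypothetical check, not part of the original notebook), you can apply `FlattenAndCast` from the previous cell to a dummy PIL image and inspect the result:
+
+```python
+import numpy as np
+from PIL import Image
+
+dummy = Image.fromarray(np.zeros((28, 28), dtype=np.uint8))  # blank 28x28 image
+flat = FlattenAndCast()(dummy)                               # flatten + cast to float32
+print(flat.shape, flat.dtype)  # expected: (784,) float32
+```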
+ +```{code-cell} +:id: Kxbl6bcx6crv + +mnist_dataset = MNIST(data_dir, download=True, transform=FlattenAndCast()) +``` + ++++ {"id": "kbdsqvPZGrsa"} + +### Full Training Dataset for Accuracy Checks + +Convert the entire training dataset to JAX arrays. + +```{code-cell} +:id: c9ZCJq_rzPck + +train_images = np.array(mnist_dataset.data).reshape(len(mnist_dataset.data), -1) +train_labels = one_hot(np.array(mnist_dataset.targets), n_targets) +``` + ++++ {"id": "WXUh0BwvG8Ko"} + +### Get Full Test Dataset + +Load and process the full test dataset. + +```{code-cell} +:id: brlLG4SqGphm + +mnist_dataset_test = MNIST(data_dir, download=True, train=False) +test_images = jnp.array(mnist_dataset_test.data.numpy().reshape(len(mnist_dataset_test.data), -1), dtype=jnp.float32) +test_labels = one_hot(np.array(mnist_dataset_test.targets), n_targets) +``` + +```{code-cell} +--- +colab: + base_uri: https://localhost:8080/ +id: Oz-UVnCxG5E8 +outputId: 53f3fb32-5096-4862-e022-3c3a1d82137a +--- +print('Train:', train_images.shape, train_labels.shape) +print('Test:', test_images.shape, test_labels.shape) +``` + ++++ {"id": "mNjn9dMPitKL"} + +### Training Data Generator + +Define a generator function using PyTorch's DataLoader for batch training. +Setting `num_workers > 0` enables multi-process data loading, which can accelerate data loading for larger datasets or intensive preprocessing tasks. Experiment with different values to find the optimal setting for your hardware and workload. + +Note: When setting `num_workers > 0`, you may see the following `RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.` +This warning can be safely ignored since data loaders do not use JAX within the forked processes. + +```{code-cell} +:id: 0LdT8P8aisWF + +def pytorch_training_generator(mnist_dataset): + return NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=0) +``` + ++++ {"id": "Xzt2x9S1HC3T"} + +### Training Loop (PyTorch DataLoader) + +The training loop uses the PyTorch DataLoader to iterate through batches and update model parameters. + +```{code-cell} +--- +colab: + base_uri: https://localhost:8080/ +id: SqweRz_98sN8 +outputId: bdd45256-3f5a-48f7-e45c-378078ac4279 +--- +train_model(num_epochs, params, pytorch_training_generator(mnist_dataset), data_loader_type='iterable') +``` + ++++ {"id": "Nm45ZTo6yrf5"} + +## Loading Data with TensorFlow Datasets (TFDS) + +This section demonstrates how to load the MNIST dataset using TFDS, fetch the full dataset for evaluation, and define a training generator for batch processing. GPU usage is explicitly disabled for TensorFlow. + +```{code-cell} +:id: sGaQAk1DHMUx + +import tensorflow_datasets as tfds +``` + ++++ {"id": "ZSc5K0Eiwm4L"} + +### Fetch Full Dataset for Evaluation + +Load the dataset with `tfds.load`, convert it to NumPy arrays, and process it for evaluation. 
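+
+The note above says TensorFlow's GPU usage is disabled; the call that does this isn't shown in the notebook, but one common way (an illustrative assumption, not the notebook's own code) is to hide the GPUs from TensorFlow before any TF operation runs:
+
+```python
+import tensorflow as tf
+
+# Hide all GPUs from TensorFlow so TFDS / tf.data pipelines stay on the CPU,
+# leaving GPU memory to JAX. Must run before TensorFlow initializes the GPU.
+tf.config.set_visible_devices([], device_type='GPU')
+```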
+ +```{code-cell} +:id: 1hOamw_7C8Pb + +# tfds.load returns tf.Tensors (or tf.data.Datasets if batch_size != -1) +mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True) +mnist_data = tfds.as_numpy(mnist_data) +train_data, test_data = mnist_data['train'], mnist_data['test'] + +# Full train set +train_images, train_labels = train_data['image'], train_data['label'] +train_images = jnp.reshape(train_images, (len(train_images), num_pixels)) +train_labels = one_hot(train_labels, n_targets) + +# Full test set +test_images, test_labels = test_data['image'], test_data['label'] +test_images = jnp.reshape(test_images, (len(test_images), num_pixels)) +test_labels = one_hot(test_labels, n_targets) +``` + +```{code-cell} +--- +colab: + base_uri: https://localhost:8080/ +id: Td3PiLdmEf7z +outputId: b8c9a32a-9cf0-4dc3-cb51-db21d32c6545 +--- +print('Train:', train_images.shape, train_labels.shape) +print('Test:', test_images.shape, test_labels.shape) +``` + ++++ {"id": "dXMvgk6sdq4j"} + +### Define the Training Generator + +Create a generator function to yield batches of data for training. + +```{code-cell} +:id: vX59u8CqEf4J + +def training_generator(): + # as_supervised=True gives us the (image, label) as a tuple instead of a dict + ds = tfds.load(name='mnist', split='train', as_supervised=True, data_dir=data_dir) + # You can build up an arbitrary tf.data input pipeline + ds = ds.batch(batch_size).prefetch(1) + # tfds.dataset_as_numpy converts the tf.data.Dataset into an iterable of NumPy arrays + return tfds.as_numpy(ds) +``` + ++++ {"id": "EAWeUdnuFNBY"} + +### Training Loop (TFDS) + +Use the training generator in a custom training loop. + +```{code-cell} +--- +colab: + base_uri: https://localhost:8080/ +id: h2sO13XDGvq1 +outputId: f30805bb-e725-46ee-e053-6e97f2af81c5 +--- +train_model(num_epochs, params, training_generator) +``` + ++++ {"id": "-ryVkrAITS9Z"} + +## Loading Data with Grain + +This section demonstrates how to load MNIST data using Grain, a data-loading library. You'll define a custom dataset class for Grain and set up a Grain DataLoader for efficient training. + ++++ {"id": "waYhUMUGmhH-"} + +Install Grain + +```{code-cell} +--- +colab: + base_uri: https://localhost:8080/ +id: L78o7eeyGvn5 +outputId: cb0ce6cf-243b-4183-8f63-646e00232caa +--- +!pip install grain +``` + ++++ {"id": "66bH3ZDJ7Iat"} + +Import Required Libraries (import MNIST dataset from torchvision) + +```{code-cell} +:id: mS62eVL9Ifmz + +import numpy as np +import grain.python as pygrain +from torchvision.datasets import MNIST +``` + ++++ {"id": "0h6mwVrspPA-"} + +### Define Dataset Class + +Create a custom dataset class to load MNIST data for Grain. 
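+
+Grain reads from any random-access source, meaning an object that exposes `__len__` and `__getitem__`. As a rough sketch of that interface (for illustration only; the actual protocol lives in `grain.python`):
+
+```python
+from typing import Any, Protocol
+
+class RandomAccessDataSource(Protocol):
+  """What a Grain data source must provide: sized + indexable."""
+  def __len__(self) -> int: ...
+  def __getitem__(self, index: int) -> Any: ...
+```
+
+The `Dataset` class in the next cell satisfies this interface by delegating to torchvision's `MNIST`.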
+ +```{code-cell} +:id: bnrhac5Hh7y1 + +class Dataset: + def __init__(self, data_dir, train=True): + self.data_dir = data_dir + self.train = train + self.load_data() + + def load_data(self): + self.dataset = MNIST(self.data_dir, download=True, train=self.train) + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, index): + img, label = self.dataset[index] + return np.ravel(np.array(img, dtype=np.float32)), label +``` + ++++ {"id": "53mf8bWEsyTr"} + +### Initialize the Dataset + +```{code-cell} +:id: pN3oF7-ostGE + +mnist_dataset = Dataset(data_dir) +``` + ++++ {"id": "GqD-ycgBuwv9"} + +### Get the full train and test dataset + +```{code-cell} +:id: f1VnTuX3u_kL + +# Convert training data to JAX arrays and encode labels as one-hot vectors +train_images = jnp.array([mnist_dataset[i][0] for i in range(len(mnist_dataset))], dtype=jnp.float32) +train_labels = one_hot(np.array([mnist_dataset[i][1] for i in range(len(mnist_dataset))]), n_targets) + +# Load test dataset and process it +mnist_dataset_test = MNIST(data_dir, download=True, train=False) +test_images = jnp.array([np.ravel(np.array(mnist_dataset_test[i][0], dtype=np.float32)) for i in range(len(mnist_dataset_test))], dtype=jnp.float32) +test_labels = one_hot(np.array([mnist_dataset_test[i][1] for i in range(len(mnist_dataset_test))]), n_targets) +``` + +```{code-cell} +--- +colab: + base_uri: https://localhost:8080/ +id: a2NHlp9klrQL +outputId: c9422190-55e9-400b-bd4e-0e7bf23dc6a1 +--- +print("Train:", train_images.shape, train_labels.shape) +print("Test:", test_images.shape, test_labels.shape) +``` + ++++ {"id": "1QPbXt7O0JN-"} + +### Initialize PyGrain DataLoader + +```{code-cell} +:id: 2jqd1jJt25Bj + +sampler = pygrain.SequentialSampler( + num_records=len(mnist_dataset), + shard_options=pygrain.NoSharding()) # Single-device, no sharding + +def pygrain_training_generator(): + return pygrain.DataLoader( + data_source=mnist_dataset, + sampler=sampler, + operations=[pygrain.Batch(batch_size)], + ) +``` + ++++ {"id": "mV5z4GLCGKlx"} + +### Training Loop (Grain) + +Run the training loop using the Grain DataLoader. + +```{code-cell} +--- +colab: + base_uri: https://localhost:8080/ +id: 9-iANQ-9CcW_ +outputId: b0e19da2-9e34-4183-c5d8-af66de5efa5c +--- +train_model(num_epochs, params, pygrain_training_generator) +``` + ++++ {"id": "o51P6lr86wz-"} + +## Loading Data with Hugging Face + +This section demonstrates loading MNIST data using the Hugging Face `datasets` library. You'll format the dataset for JAX compatibility, prepare flattened images and one-hot-encoded labels, and define a training generator. + ++++ {"id": "69vrihaOi4Oz"} + +Install the Hugging Face `datasets` library. + +```{code-cell} +--- +colab: + base_uri: https://localhost:8080/ +id: 19ipxPhI6oSN +outputId: b80b80cd-fc14-4a43-f8a8-2802de4faade +--- +!pip install datasets +``` + +```{code-cell} +:id: 8v1N59p76zn0 + +from datasets import load_dataset +``` + ++++ {"id": "8Gaj11tO7C86"} + +Load the MNIST dataset from Hugging Face and format it as `numpy` arrays for quick access or `jax` to get JAX arrays. + +```{code-cell} +:id: a22kTvgk6_fJ + +mnist_dataset = load_dataset("mnist", cache_dir=data_dir).with_format("numpy") +``` + ++++ {"id": "tgI7dIaX7JzM"} + +### Extract images and labels + +Get image shape and flatten for model input. 
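+
+As noted above, `with_format` also accepts `"jax"` if you prefer JAX arrays over NumPy. A minimal sketch of that alternative (assuming the same `data_dir`):
+
+```python
+# Hypothetical alternative: jax.numpy arrays straight from the `datasets` library
+jax_mnist = load_dataset("mnist", cache_dir=data_dir).with_format("jax")
+jax_mnist["train"][0]["image"]  # a jnp array of shape (28, 28)
+```
+
+The cells below keep the NumPy format and flatten the images manually.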
+
+```{code-cell}
+:id: NHrKatD_7HbH
+
+train_images = mnist_dataset["train"]["image"]
+train_labels = mnist_dataset["train"]["label"]
+test_images = mnist_dataset["test"]["image"]
+test_labels = mnist_dataset["test"]["label"]
+
+# Extract image shape
+image_shape = train_images.shape[1:]
+num_features = image_shape[0] * image_shape[1]
+
+# Flatten the images
+train_images = train_images.reshape(-1, num_features)
+test_images = test_images.reshape(-1, num_features)
+
+# One-hot encode the labels
+train_labels = one_hot(train_labels, n_targets)
+test_labels = one_hot(test_labels, n_targets)
+```
+
+```{code-cell}
+---
+colab:
+  base_uri: https://localhost:8080/
+id: dITh435Z7Nwb
+outputId: cc89c1ec-6987-4f1c-90a4-c3b355ea7225
+---
+print('Train:', train_images.shape, train_labels.shape)
+print('Test:', test_images.shape, test_labels.shape)
+```
+
++++ {"id": "kk_4zJlz7T1E"}
+
+### Define Training Generator
+
+Set up a generator to yield batches of images and labels for training.
+
+```{code-cell}
+:id: -zLJhogj7RL-
+
+def hf_training_generator():
+    """Yield batches for training."""
+    for batch in mnist_dataset["train"].iter(batch_size):
+        x, y = batch["image"], batch["label"]
+        yield x, y
+```
+
++++ {"id": "HIsGfkLI7dvZ"}
+
+### Training Loop (Hugging Face Datasets)
+
+Run the training loop using the Hugging Face training generator.
+
+```{code-cell}
+---
+colab:
+  base_uri: https://localhost:8080/
+id: Ui6aLiZP7aLe
+outputId: c51529e0-563f-4af0-9793-76b5e6f323db
+---
+train_model(num_epochs, params, hf_training_generator)
+```
+
++++ {"id": "rCJq2rvKlKWX"}
+
+## Summary
+
+This notebook has guided you through efficient methods for loading data on a CPU when using JAX. You’ve learned how to leverage popular libraries such as PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets to streamline the data loading process for your machine learning tasks. Each of these methods offers unique advantages and considerations, allowing you to choose the best approach based on the specific needs of your project.

From 53701d4700a58a36efc8b67c5f3b44a9238f9471 Mon Sep 17 00:00:00 2001
From: selamw1
Date: Tue, 3 Dec 2024 15:22:02 -0800
Subject: [PATCH 5/7] =?UTF-8?q?=E2=80=9Creferece=5Ftutorial=5Flinks=5Fadde?=
 =?UTF-8?q?d=E2=80=9D?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 docs/data_loaders_on_gpu_with_jax.ipynb | 88 +++++++++++++------------
 docs/data_loaders_on_gpu_with_jax.md    | 19 ++++--
 2 files changed, 58 insertions(+), 49 deletions(-)

diff --git a/docs/data_loaders_on_gpu_with_jax.ipynb b/docs/data_loaders_on_gpu_with_jax.ipynb
index f726297..40c8ddc 100644
--- a/docs/data_loaders_on_gpu_with_jax.ipynb
+++ b/docs/data_loaders_on_gpu_with_jax.ipynb
@@ -25,7 +25,12 @@
     "\n",
     "You'll see how to use each of these libraries to efficiently load data for a simple image classification task using the MNIST dataset.\n",
     "\n",
-    "Compared to the [Data Loaders on CPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_cpu_with_jax.html), working with GPUs introduces opportunities for further optimization, such as transferring data to the GPU using `device_put`, leveraging larger batch sizes for faster processing, and addressing considerations like memory management."
+    "Compared to [CPU-based loading](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_cpu_with_jax.html), working with a GPU introduces specific considerations like transferring data to the GPU using `device_put`, managing larger batch sizes for faster processing, and efficiently utilizing GPU memory. 
Unlike multi-device setups, this guide focuses on optimizing data handling for a single GPU.\n", + "\n", + "\n", + "If you're looking for CPU-specific data loading advice, see [Data Loaders on CPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_cpu_with_jax.html).\n", + "\n", + "If you're looking for a multi-device data loading strategy, see [Data Loaders on Multi-Device Setups](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_for_multi_device_setups_with_jax.html)." ] }, { @@ -34,12 +39,12 @@ "id": "-rsMgVtO6asW" }, "source": [ - "### Import JAX API" + "## Import JAX API" ] }, { "cell_type": "code", - "execution_count": 35, + "execution_count": null, "metadata": { "id": "tDJNQ6V-Dg5g" }, @@ -56,12 +61,12 @@ "id": "TsFdlkSZKp9S" }, "source": [ - "### Checking GPU Availability for JAX" + "## Checking GPU Availability for JAX" ] }, { "cell_type": "code", - "execution_count": 36, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -91,14 +96,14 @@ "id": "qyJ_WTghDnIc" }, "source": [ - "### Setting Hyperparameters and Initializing Parameters\n", + "## Setting Hyperparameters and Initializing Parameters\n", "\n", "You'll define hyperparameters for your model and data loading, including layer sizes, learning rate, batch size, and the data directory. You'll also initialize the weights and biases for a fully-connected neural network." ] }, { "cell_type": "code", - "execution_count": 37, + "execution_count": null, "metadata": { "id": "qLNOSloFDka_" }, @@ -133,7 +138,7 @@ "id": "rHLdqeI7D2WZ" }, "source": [ - "### Model Prediction with Auto-Batching\n", + "## Model Prediction with Auto-Batching\n", "\n", "In this section, you'll define the `predict` function for your neural network. This function computes the output of the network for a single input image.\n", "\n", @@ -142,7 +147,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": null, "metadata": { "id": "bKIYPSkvD1QV" }, @@ -174,7 +179,7 @@ "id": "rLqfeORsERek" }, "source": [ - "### Utility and Loss Functions\n", + "## Utility and Loss Functions\n", "\n", "You'll now define utility functions for:\n", "- One-hot encoding: Converts class indices to binary vectors.\n", @@ -190,7 +195,7 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": null, "metadata": { "id": "sA0a06raEQfS" }, @@ -253,7 +258,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -287,7 +292,7 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": null, "metadata": { "id": "kO5_WzwY59gE" }, @@ -301,7 +306,7 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": null, "metadata": { "id": "6f6qU8PCc143" }, @@ -348,7 +353,7 @@ }, { "cell_type": "code", - "execution_count": 43, + "execution_count": null, "metadata": { "id": "Kxbl6bcx6crv" }, @@ -370,7 +375,7 @@ }, { "cell_type": "code", - "execution_count": 44, + "execution_count": null, "metadata": { "id": "c9ZCJq_rzPck" }, @@ -393,7 +398,7 @@ }, { "cell_type": "code", - "execution_count": 45, + "execution_count": null, "metadata": { "id": "brlLG4SqGphm" }, @@ -406,7 +411,7 @@ }, { "cell_type": "code", - "execution_count": 46, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -446,7 +451,7 @@ }, { "cell_type": "code", - "execution_count": 47, + "execution_count": null, "metadata": { "id": "0LdT8P8aisWF" }, @@ -469,7 +474,7 @@ }, { "cell_type": 
"code", - "execution_count": 48, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -510,7 +515,7 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": null, "metadata": { "id": "sGaQAk1DHMUx" }, @@ -532,7 +537,7 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": null, "metadata": { "id": "1hOamw_7C8Pb" }, @@ -556,7 +561,7 @@ }, { "cell_type": "code", - "execution_count": 51, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -592,7 +597,7 @@ }, { "cell_type": "code", - "execution_count": 52, + "execution_count": null, "metadata": { "id": "vX59u8CqEf4J" }, @@ -620,7 +625,7 @@ }, { "cell_type": "code", - "execution_count": 53, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -670,7 +675,7 @@ }, { "cell_type": "code", - "execution_count": 54, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -714,7 +719,7 @@ }, { "cell_type": "code", - "execution_count": 55, + "execution_count": null, "metadata": { "id": "mS62eVL9Ifmz" }, @@ -738,7 +743,7 @@ }, { "cell_type": "code", - "execution_count": 56, + "execution_count": null, "metadata": { "id": "bnrhac5Hh7y1" }, @@ -772,7 +777,7 @@ }, { "cell_type": "code", - "execution_count": 57, + "execution_count": null, "metadata": { "id": "pN3oF7-ostGE" }, @@ -792,7 +797,7 @@ }, { "cell_type": "code", - "execution_count": 58, + "execution_count": null, "metadata": { "id": "f1VnTuX3u_kL" }, @@ -810,7 +815,7 @@ }, { "cell_type": "code", - "execution_count": 59, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -844,7 +849,7 @@ }, { "cell_type": "code", - "execution_count": 60, + "execution_count": null, "metadata": { "id": "2jqd1jJt25Bj" }, @@ -875,7 +880,7 @@ }, { "cell_type": "code", - "execution_count": 61, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -925,7 +930,7 @@ }, { "cell_type": "code", - "execution_count": 62, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -979,7 +984,7 @@ }, { "cell_type": "code", - "execution_count": 63, + "execution_count": null, "metadata": { "id": "8v1N59p76zn0" }, @@ -999,7 +1004,7 @@ }, { "cell_type": "code", - "execution_count": 64, + "execution_count": null, "metadata": { "id": "a22kTvgk6_fJ" }, @@ -1021,7 +1026,7 @@ }, { "cell_type": "code", - "execution_count": 65, + "execution_count": null, "metadata": { "id": "NHrKatD_7HbH" }, @@ -1047,7 +1052,7 @@ }, { "cell_type": "code", - "execution_count": 66, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -1083,7 +1088,7 @@ }, { "cell_type": "code", - "execution_count": 67, + "execution_count": null, "metadata": { "id": "-zLJhogj7RL-" }, @@ -1109,7 +1114,7 @@ }, { "cell_type": "code", - "execution_count": 68, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -1145,7 +1150,7 @@ "source": [ "## Summary\n", "\n", - "This notebook explored efficient methods for loading data on a GPU with JAX, using libraries like PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets. You also learned GPU-specific optimizations, such as `device_put` for data transfer and memory management, to enhance training efficiency. Each methods offers unique benefits, helping you choose the best fit for your project needs." 
+ "This notebook explored efficient methods for loading data on a GPU with JAX, using libraries such as PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets. You also learned GPU-specific optimizations, including using `device_put` for data transfer and managing GPU memory, to enhance training efficiency. Each method offers unique benefits, allowing you to choose the best approach based on your project requirements." ] } ], @@ -1153,7 +1158,6 @@ "accelerator": "GPU", "colab": { "gpuType": "T4", - "name": "data_loaders_on_gpu_with_jax.ipynb", "provenance": [] }, "jupytext": { diff --git a/docs/data_loaders_on_gpu_with_jax.md b/docs/data_loaders_on_gpu_with_jax.md index 4ec7487..a83ec4c 100644 --- a/docs/data_loaders_on_gpu_with_jax.md +++ b/docs/data_loaders_on_gpu_with_jax.md @@ -27,11 +27,16 @@ This tutorial explores different data loading strategies for using **JAX** on a You'll see how to use each of these libraries to efficiently load data for a simple image classification task using the MNIST dataset. -Compared to the [Data Loaders on CPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_cpu_with_jax.html), working with GPUs introduces opportunities for further optimization, such as transferring data to the GPU using `device_put`, leveraging larger batch sizes for faster processing, and addressing considerations like memory management. +Compared to [CPU-based loading](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_cpu_with_jax.html), working with a GPU introduces specific considerations like transferring data to the GPU using `device_put`, managing larger batch sizes for faster processing, and efficiently utilizing GPU memory. Unlike multi-device setups, this guide focuses on optimizing data handling for a single GPU. + + +If you're looking for CPU-specific data loading advice, see [Data Loaders on CPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_cpu_with_jax.html). + +If you're looking for a multi-device data loading strategy, see [Data Loaders on Multi-Device Setups](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_for_multi_device_setups_with_jax.html). +++ {"id": "-rsMgVtO6asW"} -### Import JAX API +## Import JAX API ```{code-cell} :id: tDJNQ6V-Dg5g @@ -43,7 +48,7 @@ from jax import grad, jit, vmap, random, device_put +++ {"id": "TsFdlkSZKp9S"} -### Checking GPU Availability for JAX +## Checking GPU Availability for JAX ```{code-cell} --- @@ -57,7 +62,7 @@ jax.devices() +++ {"id": "qyJ_WTghDnIc"} -### Setting Hyperparameters and Initializing Parameters +## Setting Hyperparameters and Initializing Parameters You'll define hyperparameters for your model and data loading, including layer sizes, learning rate, batch size, and the data directory. You'll also initialize the weights and biases for a fully-connected neural network. @@ -89,7 +94,7 @@ params = init_network_params(layer_sizes, random.PRNGKey(0)) +++ {"id": "rHLdqeI7D2WZ"} -### Model Prediction with Auto-Batching +## Model Prediction with Auto-Batching In this section, you'll define the `predict` function for your neural network. This function computes the output of the network for a single input image. @@ -120,7 +125,7 @@ batched_predict = vmap(predict, in_axes=(None, 0)) +++ {"id": "rLqfeORsERek"} -### Utility and Loss Functions +## Utility and Loss Functions You'll now define utility functions for: - One-hot encoding: Converts class indices to binary vectors. 
@@ -642,4 +647,4 @@ train_model(num_epochs, params, hf_training_generator) ## Summary -This notebook explored efficient methods for loading data on a GPU with JAX, using libraries like PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets. You also learned GPU-specific optimizations, such as `device_put` for data transfer and memory management, to enhance training efficiency. Each methods offers unique benefits, helping you choose the best fit for your project needs. +This notebook explored efficient methods for loading data on a GPU with JAX, using libraries such as PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets. You also learned GPU-specific optimizations, including using `device_put` for data transfer and managing GPU memory, to enhance training efficiency. Each method offers unique benefits, allowing you to choose the best approach based on your project requirements. From e861c180c6cfa4494d5b5b8fdc6339734064fbb3 Mon Sep 17 00:00:00 2001 From: selamw1 Date: Wed, 4 Dec 2024 13:41:36 -0800 Subject: [PATCH 6/7] rebased_to_doc_source --- .../source/data_loaders_on_cpu_with_jax.ipynb | 10 +- docs/source/data_loaders_on_cpu_with_jax.md | 10 +- .../source/data_loaders_on_gpu_with_jax.ipynb | 1176 +++++++++++++++++ docs/source/data_loaders_on_gpu_with_jax.md | 650 +++++++++ 4 files changed, 1842 insertions(+), 4 deletions(-) create mode 100644 docs/source/data_loaders_on_gpu_with_jax.ipynb create mode 100644 docs/source/data_loaders_on_gpu_with_jax.md diff --git a/docs/source/data_loaders_on_cpu_with_jax.ipynb b/docs/source/data_loaders_on_cpu_with_jax.ipynb index 21bd599..34a8445 100644 --- a/docs/source/data_loaders_on_cpu_with_jax.ipynb +++ b/docs/source/data_loaders_on_cpu_with_jax.ipynb @@ -24,7 +24,13 @@ "- [**Grain**](https://github.com/google/grain)\n", "- [**Hugging Face**](https://huggingface.co/docs/datasets/en/use_with_jax#data-loading)\n", "\n", - "You'll see how to use each of these libraries to efficiently load data for a simple image classification task using the MNIST dataset." + "In this tutorial, you'll learn how to efficiently load data using these libraries for a simple image classification task based on the MNIST dataset.\n", + "\n", + "Compared to GPU or multi-device setups, CPU-based data loading is straightforward as it avoids challenges like GPU memory management and data synchronization across devices. This makes it ideal for smaller-scale tasks or scenarios where data resides exclusively on the CPU.\n", + "\n", + "If you're looking for GPU-specific data loading advice, see [Data Loaders on GPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_gpu_with_jax.html).\n", + "\n", + "If you're looking for a multi-device data loading strategy, see [Data Loaders on Multi-Device Setups](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_for_multi_device_setups_with_jax.html)." ] }, { @@ -1489,7 +1495,7 @@ "source": [ "## Summary\n", "\n", - "This notebook has guided you through efficient methods for loading data on a CPU when using JAX. You’ve learned how to leverage popular libraries such as PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets to streamline the data loading process for your machine learning tasks. Each of these methods offers unique advantages and considerations, allowing you to choose the best approach based on the specific needs of your project." 
+ "This notebook has introduced efficient strategies for data loading on a CPU with JAX, demonstrating how to integrate popular libraries like PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets. Each library offers distinct advantages, enabling you to streamline the data loading process for machine learning tasks. By understanding the strengths of these methods, you can select the approach that best suits your project's specific requirements." ] } ], diff --git a/docs/source/data_loaders_on_cpu_with_jax.md b/docs/source/data_loaders_on_cpu_with_jax.md index f565d1d..d26c687 100644 --- a/docs/source/data_loaders_on_cpu_with_jax.md +++ b/docs/source/data_loaders_on_cpu_with_jax.md @@ -26,7 +26,13 @@ This tutorial explores different data loading strategies for using **JAX** on a - [**Grain**](https://github.com/google/grain) - [**Hugging Face**](https://huggingface.co/docs/datasets/en/use_with_jax#data-loading) -You'll see how to use each of these libraries to efficiently load data for a simple image classification task using the MNIST dataset. +In this tutorial, you'll learn how to efficiently load data using these libraries for a simple image classification task based on the MNIST dataset. + +Compared to GPU or multi-device setups, CPU-based data loading is straightforward as it avoids challenges like GPU memory management and data synchronization across devices. This makes it ideal for smaller-scale tasks or scenarios where data resides exclusively on the CPU. + +If you're looking for GPU-specific data loading advice, see [Data Loaders on GPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_gpu_with_jax.html). + +If you're looking for a multi-device data loading strategy, see [Data Loaders on Multi-Device Setups](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_for_multi_device_setups_with_jax.html). +++ {"id": "pEsb135zE-Jo"} @@ -682,4 +688,4 @@ train_model(num_epochs, params, hf_training_generator) ## Summary -This notebook has guided you through efficient methods for loading data on a CPU when using JAX. You’ve learned how to leverage popular libraries such as PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets to streamline the data loading process for your machine learning tasks. Each of these methods offers unique advantages and considerations, allowing you to choose the best approach based on the specific needs of your project. +This notebook has introduced efficient strategies for data loading on a CPU with JAX, demonstrating how to integrate popular libraries like PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets. Each library offers distinct advantages, enabling you to streamline the data loading process for machine learning tasks. By understanding the strengths of these methods, you can select the approach that best suits your project's specific requirements. 
diff --git a/docs/source/data_loaders_on_gpu_with_jax.ipynb b/docs/source/data_loaders_on_gpu_with_jax.ipynb new file mode 100644 index 0000000..40c8ddc --- /dev/null +++ b/docs/source/data_loaders_on_gpu_with_jax.ipynb @@ -0,0 +1,1176 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "PUFGZggH49zp" + }, + "source": [ + "# Introduction to Data Loaders on GPU with JAX" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "3ia4PKEV5Dr8" + }, + "source": [ + "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/data_loaders_on_gpu_with_jax.ipynb)\n", + "\n", + "This tutorial explores different data loading strategies for using **JAX** on a single [**GPU**](https://jax.readthedocs.io/en/latest/glossary.html#term-GPU). While JAX doesn't include a built-in data loader, it seamlessly integrates with popular data loading libraries, including:\n", + "* [**PyTorch DataLoader**](https://github.com/pytorch/data)\n", + "* [**TensorFlow Datasets (TFDS)**](https://github.com/tensorflow/datasets)\n", + "* [**Grain**](https://github.com/google/grain)\n", + "* [**Hugging Face**](https://huggingface.co/docs/datasets/en/use_with_jax#data-loading)\n", + "\n", + "You'll see how to use each of these libraries to efficiently load data for a simple image classification task using the MNIST dataset.\n", + "\n", + "Compared to [CPU-based loading](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_cpu_with_jax.html), working with a GPU introduces specific considerations like transferring data to the GPU using `device_put`, managing larger batch sizes for faster processing, and efficiently utilizing GPU memory. Unlike multi-device setups, this guide focuses on optimizing data handling for a single GPU.\n", + "\n", + "\n", + "If you're looking for CPU-specific data loading advice, see [Data Loaders on CPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_cpu_with_jax.html).\n", + "\n", + "If you're looking for a multi-device data loading strategy, see [Data Loaders on Multi-Device Setups](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_for_multi_device_setups_with_jax.html)." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-rsMgVtO6asW" + }, + "source": [ + "## Import JAX API" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "tDJNQ6V-Dg5g" + }, + "outputs": [], + "source": [ + "import jax\n", + "import jax.numpy as jnp\n", + "from jax import grad, jit, vmap, random, device_put" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "TsFdlkSZKp9S" + }, + "source": [ + "## Checking GPU Availability for JAX" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "N3sqvaF3KJw1", + "outputId": "ab40f542-b8c0-422c-ca68-4ce292817889" + }, + "outputs": [ + { + "data": { + "text/plain": [ + "[CudaDevice(id=0)]" + ] + }, + "execution_count": 36, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "jax.devices()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qyJ_WTghDnIc" + }, + "source": [ + "## Setting Hyperparameters and Initializing Parameters\n", + "\n", + "You'll define hyperparameters for your model and data loading, including layer sizes, learning rate, batch size, and the data directory. 
You'll also initialize the weights and biases for a fully-connected neural network." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qLNOSloFDka_" + }, + "outputs": [], + "source": [ + "# A helper function to randomly initialize weights and biases\n", + "# for a dense neural network layer\n", + "def random_layer_params(m, n, key, scale=1e-2):\n", + " w_key, b_key = random.split(key)\n", + " return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))\n", + "\n", + "# Function to initialize network parameters for all layers based on defined sizes\n", + "def init_network_params(sizes, key):\n", + " keys = random.split(key, len(sizes))\n", + " return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]\n", + "\n", + "layer_sizes = [784, 512, 512, 10] # Layers of the network\n", + "step_size = 0.01 # Learning rate\n", + "num_epochs = 8 # Number of training epochs\n", + "batch_size = 128 # Batch size for training\n", + "n_targets = 10 # Number of classes (digits 0-9)\n", + "num_pixels = 28 * 28 # Each MNIST image is 28x28 pixels\n", + "data_dir = '/tmp/mnist_dataset' # Directory for storing the dataset\n", + "\n", + "# Initialize network parameters using the defined layer sizes and a random seed\n", + "params = init_network_params(layer_sizes, random.PRNGKey(0))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rHLdqeI7D2WZ" + }, + "source": [ + "## Model Prediction with Auto-Batching\n", + "\n", + "In this section, you'll define the `predict` function for your neural network. This function computes the output of the network for a single input image.\n", + "\n", + "To efficiently process multiple images simultaneously, you'll use [`vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax.vmap), which allows you to vectorize the `predict` function and apply it across a batch of inputs. This technique, called auto-batching, improves computational efficiency by leveraging hardware acceleration." 
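+    ,"\n",
+    "\n",
+    "As a shape sketch (illustrative, using only names defined in this notebook): the `batched_predict` built below maps `predict` over axis 0 of its second argument while reusing `params`, so a `(128, 784)` batch of flattened images yields `(128, 10)` log-probabilities:\n",
+    "\n",
+    "```python\n",
+    "batch = jnp.zeros((batch_size, num_pixels))  # (128, 784)\n",
+    "logits = batched_predict(params, batch)      # (128, 10), via in_axes=(None, 0)\n",
+    "```"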
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "bKIYPSkvD1QV" + }, + "outputs": [], + "source": [ + "from jax.scipy.special import logsumexp\n", + "\n", + "def relu(x):\n", + " return jnp.maximum(0, x)\n", + "\n", + "def predict(params, image):\n", + " # per-example predictions\n", + " activations = image\n", + " for w, b in params[:-1]:\n", + " outputs = jnp.dot(w, activations) + b\n", + " activations = relu(outputs)\n", + "\n", + " final_w, final_b = params[-1]\n", + " logits = jnp.dot(final_w, activations) + final_b\n", + " return logits - logsumexp(logits)\n", + "\n", + "# Make a batched version of the `predict` function\n", + "batched_predict = vmap(predict, in_axes=(None, 0))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rLqfeORsERek" + }, + "source": [ + "## Utility and Loss Functions\n", + "\n", + "You'll now define utility functions for:\n", + "- One-hot encoding: Converts class indices to binary vectors.\n", + "- Accuracy calculation: Measures the performance of the model on the dataset.\n", + "- Loss computation: Calculates the difference between predictions and targets.\n", + "\n", + "To optimize performance:\n", + "- [`grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html#jax.grad) is used to compute gradients of the loss function with respect to network parameters.\n", + "- [`jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit) compiles the update function, enabling faster execution by leveraging JAX's [XLA](https://openxla.org/xla) compilation.\n", + "\n", + "- [`device_put`](https://jax.readthedocs.io/en/latest/_autosummary/jax.device_put.html) to transfer the dataset to the GPU." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "sA0a06raEQfS" + }, + "outputs": [], + "source": [ + "import time\n", + "\n", + "def one_hot(x, k, dtype=jnp.float32):\n", + " \"\"\"Create a one-hot encoding of x of size k.\"\"\"\n", + " return jnp.array(x[:, None] == jnp.arange(k), dtype)\n", + "\n", + "def accuracy(params, images, targets):\n", + " \"\"\"Calculate the accuracy of predictions.\"\"\"\n", + " target_class = jnp.argmax(targets, axis=1)\n", + " predicted_class = jnp.argmax(batched_predict(params, images), axis=1)\n", + " return jnp.mean(predicted_class == target_class)\n", + "\n", + "def loss(params, images, targets):\n", + " \"\"\"Calculate the loss between predictions and targets.\"\"\"\n", + " preds = batched_predict(params, images)\n", + " return -jnp.mean(preds * targets)\n", + "\n", + "@jit\n", + "def update(params, x, y):\n", + " \"\"\"Update the network parameters using gradient descent.\"\"\"\n", + " grads = grad(loss)(params, x, y)\n", + " return [(w - step_size * dw, b - step_size * db)\n", + " for (w, b), (dw, db) in zip(params, grads)]\n", + "\n", + "def reshape_and_one_hot(x, y):\n", + " \"\"\"Reshape and one-hot encode the inputs.\"\"\"\n", + " x = jnp.reshape(x, (len(x), num_pixels))\n", + " y = one_hot(y, n_targets)\n", + " return x, y\n", + "\n", + "def train_model(num_epochs, params, training_generator, data_loader_type='streamed'):\n", + " \"\"\"Train the model for a given number of epochs and device_put for GPU transfer.\"\"\"\n", + " for epoch in range(num_epochs):\n", + " start_time = time.time()\n", + " for x, y in training_generator() if data_loader_type == 'streamed' else training_generator:\n", + " x, y = reshape_and_one_hot(x, y)\n", + " x, y = device_put(x), device_put(y)\n", + " params = update(params, x, y)\n", + 
"\n", + " print(f\"Epoch {epoch + 1} in {time.time() - start_time:.2f} sec: \"\n", + " f\"Train Accuracy: {accuracy(params, train_images, train_labels):.4f}, \"\n", + " f\"Test Accuracy: {accuracy(params, test_images, test_labels):.4f}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Hsionp5IYsQ9" + }, + "source": [ + "## Loading Data with PyTorch DataLoader\n", + "\n", + "This section shows how to load the MNIST dataset using PyTorch's DataLoader, convert the data to NumPy arrays, and apply transformations to flatten and cast images." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "uA7XY0OezHse", + "outputId": "4c86f455-ff1d-474e-f8e3-7111d9b56996" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.5.1+cu121)\n", + "Requirement already satisfied: torchvision in /usr/local/lib/python3.10/dist-packages (0.20.1+cu121)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch) (3.16.1)\n", + "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch) (4.12.2)\n", + "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.4.2)\n", + "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.4)\n", + "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch) (2024.9.0)\n", + "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.10/dist-packages (from torch) (1.13.1)\n", + "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy==1.13.1->torch) (1.3.0)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from torchvision) (1.26.4)\n", + "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.10/dist-packages (from torchvision) (11.0.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (3.0.2)\n" + ] + } + ], + "source": [ + "!pip install torch torchvision" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "kO5_WzwY59gE" + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "from jax.tree_util import tree_map\n", + "from torch.utils import data\n", + "from torchvision.datasets import MNIST" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "6f6qU8PCc143" + }, + "outputs": [], + "source": [ + "def numpy_collate(batch):\n", + " \"\"\"Collate function to convert a batch of PyTorch data into NumPy arrays.\"\"\"\n", + " return tree_map(np.asarray, data.default_collate(batch))\n", + "\n", + "class NumpyLoader(data.DataLoader):\n", + " \"\"\"Custom DataLoader to return NumPy arrays from a PyTorch Dataset.\"\"\"\n", + " def __init__(self, dataset, batch_size=1,\n", + " shuffle=False, sampler=None,\n", + " batch_sampler=None, num_workers=0,\n", + " pin_memory=False, drop_last=False,\n", + " timeout=0, worker_init_fn=None):\n", + " super(self.__class__, self).__init__(dataset,\n", + " batch_size=batch_size,\n", + " shuffle=shuffle,\n", + " sampler=sampler,\n", + " batch_sampler=batch_sampler,\n", + " num_workers=num_workers,\n", + " collate_fn=numpy_collate,\n", + " 
pin_memory=pin_memory,\n", + " drop_last=drop_last,\n", + " timeout=timeout,\n", + " worker_init_fn=worker_init_fn)\n", + "class FlattenAndCast(object):\n", + " \"\"\"Transform class to flatten and cast images to float32.\"\"\"\n", + " def __call__(self, pic):\n", + " return np.ravel(np.array(pic, dtype=jnp.float32))" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mfSnfJND6I8G" + }, + "source": [ + "### Load Dataset with Transformations\n", + "\n", + "Standardize the data by flattening the images, casting them to `float32`, and ensuring consistent data types." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Kxbl6bcx6crv" + }, + "outputs": [], + "source": [ + "mnist_dataset = MNIST(data_dir, download=True, transform=FlattenAndCast())" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "kbdsqvPZGrsa" + }, + "source": [ + "### Full Training Dataset for Accuracy Checks\n", + "\n", + "Convert the entire training dataset to JAX arrays." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "c9ZCJq_rzPck" + }, + "outputs": [], + "source": [ + "train_images = np.array(mnist_dataset.data).reshape(len(mnist_dataset.data), -1)\n", + "train_labels = one_hot(np.array(mnist_dataset.targets), n_targets)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WXUh0BwvG8Ko" + }, + "source": [ + "### Get Full Test Dataset\n", + "\n", + "Load and process the full test dataset." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "brlLG4SqGphm" + }, + "outputs": [], + "source": [ + "mnist_dataset_test = MNIST(data_dir, download=True, train=False)\n", + "test_images = jnp.array(mnist_dataset_test.data.numpy().reshape(len(mnist_dataset_test.data), -1), dtype=jnp.float32)\n", + "test_labels = one_hot(np.array(mnist_dataset_test.targets), n_targets)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Oz-UVnCxG5E8", + "outputId": "53f3fb32-5096-4862-e022-3c3a1d82137a" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train: (60000, 784) (60000, 10)\n", + "Test: (10000, 784) (10000, 10)\n" + ] + } + ], + "source": [ + "print('Train:', train_images.shape, train_labels.shape)\n", + "print('Test:', test_images.shape, test_labels.shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mNjn9dMPitKL" + }, + "source": [ + "### Training Data Generator\n", + "\n", + "Define a generator function using PyTorch's DataLoader for batch training.\n", + "Setting `num_workers > 0` enables multi-process data loading, which can accelerate data loading for larger datasets or intensive preprocessing tasks. Experiment with different values to find the optimal setting for your hardware and workload.\n", + "\n", + "Note: When setting `num_workers > 0`, you may see the following `RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.`\n", + "This warning can be safely ignored since data loaders do not use JAX within the forked processes." 
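+    ,"\n",
+    "\n",
+    "For example, a multi-process variant of the loader below might look like this (a sketch; the right `num_workers` depends on your machine):\n",
+    "\n",
+    "```python\n",
+    "# Hypothetical: the same NumpyLoader, but with two worker processes\n",
+    "loader = NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=2)\n",
+    "```"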
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "0LdT8P8aisWF" + }, + "outputs": [], + "source": [ + "def pytorch_training_generator(mnist_dataset):\n", + " return NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=0)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Xzt2x9S1HC3T" + }, + "source": [ + "### Training Loop (PyTorch DataLoader)\n", + "\n", + "The training loop uses the PyTorch DataLoader to iterate through batches and update model parameters." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "SqweRz_98sN8", + "outputId": "bdd45256-3f5a-48f7-e45c-378078ac4279" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1 in 20.23 sec: Train Accuracy: 0.9158, Test Accuracy: 0.9195\n", + "Epoch 2 in 14.64 sec: Train Accuracy: 0.9372, Test Accuracy: 0.9385\n", + "Epoch 3 in 3.91 sec: Train Accuracy: 0.9492, Test Accuracy: 0.9467\n", + "Epoch 4 in 3.85 sec: Train Accuracy: 0.9569, Test Accuracy: 0.9532\n", + "Epoch 5 in 4.48 sec: Train Accuracy: 0.9631, Test Accuracy: 0.9577\n", + "Epoch 6 in 4.03 sec: Train Accuracy: 0.9675, Test Accuracy: 0.9617\n", + "Epoch 7 in 3.86 sec: Train Accuracy: 0.9708, Test Accuracy: 0.9652\n", + "Epoch 8 in 4.57 sec: Train Accuracy: 0.9736, Test Accuracy: 0.9671\n" + ] + } + ], + "source": [ + "train_model(num_epochs, params, pytorch_training_generator(mnist_dataset), data_loader_type='iterable')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Nm45ZTo6yrf5" + }, + "source": [ + "## Loading Data with TensorFlow Datasets (TFDS)\n", + "\n", + "This section demonstrates how to load the MNIST dataset using TFDS, fetch the full dataset for evaluation, and define a training generator for batch processing. GPU usage is explicitly disabled for TensorFlow." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "sGaQAk1DHMUx" + }, + "outputs": [], + "source": [ + "import tensorflow_datasets as tfds" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ZSc5K0Eiwm4L" + }, + "source": [ + "### Fetch Full Dataset for Evaluation\n", + "\n", + "Load the dataset with `tfds.load`, convert it to NumPy arrays, and process it for evaluation." 
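+    ,"\n",
+    "\n",
+    "Once the next cell has run, the `info` object returned because of `with_info=True` can be inspected, for example:\n",
+    "\n",
+    "```python\n",
+    "print(info.splits['train'].num_examples)  # 60000\n",
+    "print(info.features['image'].shape)       # (28, 28, 1)\n",
+    "```"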
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "1hOamw_7C8Pb" + }, + "outputs": [], + "source": [ + "# tfds.load returns tf.Tensors (or tf.data.Datasets if batch_size != -1)\n", + "mnist_data, info = tfds.load(name=\"mnist\", batch_size=-1, data_dir=data_dir, with_info=True)\n", + "mnist_data = tfds.as_numpy(mnist_data)\n", + "train_data, test_data = mnist_data['train'], mnist_data['test']\n", + "\n", + "# Full train set\n", + "train_images, train_labels = train_data['image'], train_data['label']\n", + "train_images = jnp.reshape(train_images, (len(train_images), num_pixels))\n", + "train_labels = one_hot(train_labels, n_targets)\n", + "\n", + "# Full test set\n", + "test_images, test_labels = test_data['image'], test_data['label']\n", + "test_images = jnp.reshape(test_images, (len(test_images), num_pixels))\n", + "test_labels = one_hot(test_labels, n_targets)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Td3PiLdmEf7z", + "outputId": "b8c9a32a-9cf0-4dc3-cb51-db21d32c6545" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train: (60000, 784) (60000, 10)\n", + "Test: (10000, 784) (10000, 10)\n" + ] + } + ], + "source": [ + "print('Train:', train_images.shape, train_labels.shape)\n", + "print('Test:', test_images.shape, test_labels.shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "dXMvgk6sdq4j" + }, + "source": [ + "### Define the Training Generator\n", + "\n", + "Create a generator function to yield batches of data for training." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "vX59u8CqEf4J" + }, + "outputs": [], + "source": [ + "def training_generator():\n", + " # as_supervised=True gives us the (image, label) as a tuple instead of a dict\n", + " ds = tfds.load(name='mnist', split='train', as_supervised=True, data_dir=data_dir)\n", + " # You can build up an arbitrary tf.data input pipeline\n", + " ds = ds.batch(batch_size).prefetch(1)\n", + " # tfds.dataset_as_numpy converts the tf.data.Dataset into an iterable of NumPy arrays\n", + " return tfds.as_numpy(ds)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "EAWeUdnuFNBY" + }, + "source": [ + "### Training Loop (TFDS)\n", + "\n", + "Use the training generator in a custom training loop." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "h2sO13XDGvq1", + "outputId": "f30805bb-e725-46ee-e053-6e97f2af81c5" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1 in 20.86 sec: Train Accuracy: 0.9253, Test Accuracy: 0.9268\n", + "Epoch 2 in 8.56 sec: Train Accuracy: 0.9428, Test Accuracy: 0.9413\n", + "Epoch 3 in 5.40 sec: Train Accuracy: 0.9532, Test Accuracy: 0.9511\n", + "Epoch 4 in 3.86 sec: Train Accuracy: 0.9598, Test Accuracy: 0.9555\n", + "Epoch 5 in 3.88 sec: Train Accuracy: 0.9652, Test Accuracy: 0.9601\n", + "Epoch 6 in 10.35 sec: Train Accuracy: 0.9692, Test Accuracy: 0.9631\n", + "Epoch 7 in 4.39 sec: Train Accuracy: 0.9726, Test Accuracy: 0.9650\n", + "Epoch 8 in 4.77 sec: Train Accuracy: 0.9753, Test Accuracy: 0.9669\n" + ] + } + ], + "source": [ + "train_model(num_epochs, params, training_generator)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "-ryVkrAITS9Z" + }, + "source": [ + "## Loading Data with Grain\n", + "\n", + "This section demonstrates how to load MNIST data using Grain, a data-loading library. You'll define a custom dataset class for Grain and set up a Grain DataLoader for efficient training." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "waYhUMUGmhH-" + }, + "source": [ + "Install Grain" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "L78o7eeyGvn5", + "outputId": "cb0ce6cf-243b-4183-8f63-646e00232caa" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: grain in /usr/local/lib/python3.10/dist-packages (0.2.2)\n", + "Requirement already satisfied: absl-py in /usr/local/lib/python3.10/dist-packages (from grain) (1.4.0)\n", + "Requirement already satisfied: array-record in /usr/local/lib/python3.10/dist-packages (from grain) (0.5.1)\n", + "Requirement already satisfied: cloudpickle in /usr/local/lib/python3.10/dist-packages (from grain) (3.1.0)\n", + "Requirement already satisfied: dm-tree in /usr/local/lib/python3.10/dist-packages (from grain) (0.1.8)\n", + "Requirement already satisfied: etils[epath,epy] in /usr/local/lib/python3.10/dist-packages (from grain) (1.10.0)\n", + "Requirement already satisfied: jaxtyping in /usr/local/lib/python3.10/dist-packages (from grain) (0.2.36)\n", + "Requirement already satisfied: more-itertools>=9.1.0 in /usr/local/lib/python3.10/dist-packages (from grain) (10.5.0)\n", + "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from grain) (1.26.4)\n", + "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from etils[epath,epy]->grain) (2024.9.0)\n", + "Requirement already satisfied: importlib_resources in /usr/local/lib/python3.10/dist-packages (from etils[epath,epy]->grain) (6.4.5)\n", + "Requirement already satisfied: typing_extensions in /usr/local/lib/python3.10/dist-packages (from etils[epath,epy]->grain) (4.12.2)\n", + "Requirement already satisfied: zipp in /usr/local/lib/python3.10/dist-packages (from etils[epath,epy]->grain) (3.21.0)\n" + ] + } + ], + "source": [ + "!pip install grain" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "66bH3ZDJ7Iat" + }, + "source": [ + "Import Required Libraries (import MNIST dataset from torchvision)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + 
"metadata": { + "id": "mS62eVL9Ifmz" + }, + "outputs": [], + "source": [ + "import numpy as np\n", + "import grain.python as pygrain\n", + "from torchvision.datasets import MNIST" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "0h6mwVrspPA-" + }, + "source": [ + "### Define Dataset Class\n", + "\n", + "Create a custom dataset class to load MNIST data for Grain." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "bnrhac5Hh7y1" + }, + "outputs": [], + "source": [ + "class Dataset:\n", + " def __init__(self, data_dir, train=True):\n", + " self.data_dir = data_dir\n", + " self.train = train\n", + " self.load_data()\n", + "\n", + " def load_data(self):\n", + " self.dataset = MNIST(self.data_dir, download=True, train=self.train)\n", + "\n", + " def __len__(self):\n", + " return len(self.dataset)\n", + "\n", + " def __getitem__(self, index):\n", + " img, label = self.dataset[index]\n", + " return np.ravel(np.array(img, dtype=np.float32)), label" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "53mf8bWEsyTr" + }, + "source": [ + "### Initialize the Dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "pN3oF7-ostGE" + }, + "outputs": [], + "source": [ + "mnist_dataset = Dataset(data_dir)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "GqD-ycgBuwv9" + }, + "source": [ + "### Get the full train and test dataset" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "f1VnTuX3u_kL" + }, + "outputs": [], + "source": [ + "# Convert training data to JAX arrays and encode labels as one-hot vectors\n", + "train_images = jnp.array([mnist_dataset[i][0] for i in range(len(mnist_dataset))], dtype=jnp.float32)\n", + "train_labels = one_hot(np.array([mnist_dataset[i][1] for i in range(len(mnist_dataset))]), n_targets)\n", + "\n", + "# Load test dataset and process it\n", + "mnist_dataset_test = MNIST(data_dir, download=True, train=False)\n", + "test_images = jnp.array([np.ravel(np.array(mnist_dataset_test[i][0], dtype=np.float32)) for i in range(len(mnist_dataset_test))], dtype=jnp.float32)\n", + "test_labels = one_hot(np.array([mnist_dataset_test[i][1] for i in range(len(mnist_dataset_test))]), n_targets)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "a2NHlp9klrQL", + "outputId": "c9422190-55e9-400b-bd4e-0e7bf23dc6a1" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train: (60000, 784) (60000, 10)\n", + "Test: (10000, 784) (10000, 10)\n" + ] + } + ], + "source": [ + "print(\"Train:\", train_images.shape, train_labels.shape)\n", + "print(\"Test:\", test_images.shape, test_labels.shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1QPbXt7O0JN-" + }, + "source": [ + "### Initialize PyGrain DataLoader" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2jqd1jJt25Bj" + }, + "outputs": [], + "source": [ + "sampler = pygrain.SequentialSampler(\n", + " num_records=len(mnist_dataset),\n", + " shard_options=pygrain.NoSharding()) # Single-device, no sharding\n", + "\n", + "def pygrain_training_generator():\n", + " return pygrain.DataLoader(\n", + " data_source=mnist_dataset,\n", + " sampler=sampler,\n", + " operations=[pygrain.Batch(batch_size)],\n", + " )" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "mV5z4GLCGKlx" + }, + "source": [ + "### Training 
Loop (Grain)\n", + "\n", + "Run the training loop using the Grain DataLoader." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "9-iANQ-9CcW_", + "outputId": "b0e19da2-9e34-4183-c5d8-af66de5efa5c" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1 in 15.65 sec: Train Accuracy: 0.9158, Test Accuracy: 0.9195\n", + "Epoch 2 in 15.03 sec: Train Accuracy: 0.9372, Test Accuracy: 0.9385\n", + "Epoch 3 in 14.93 sec: Train Accuracy: 0.9492, Test Accuracy: 0.9467\n", + "Epoch 4 in 11.56 sec: Train Accuracy: 0.9569, Test Accuracy: 0.9532\n", + "Epoch 5 in 9.33 sec: Train Accuracy: 0.9631, Test Accuracy: 0.9577\n", + "Epoch 6 in 9.31 sec: Train Accuracy: 0.9675, Test Accuracy: 0.9617\n", + "Epoch 7 in 9.78 sec: Train Accuracy: 0.9708, Test Accuracy: 0.9652\n", + "Epoch 8 in 9.80 sec: Train Accuracy: 0.9736, Test Accuracy: 0.9671\n" + ] + } + ], + "source": [ + "train_model(num_epochs, params, pygrain_training_generator)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "o51P6lr86wz-" + }, + "source": [ + "## Loading Data with Hugging Face\n", + "\n", + "This section demonstrates loading MNIST data using the Hugging Face `datasets` library. You'll format the dataset for JAX compatibility, prepare flattened images and one-hot-encoded labels, and define a training generator." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "69vrihaOi4Oz" + }, + "source": [ + "Install the Hugging Face `datasets` library." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "19ipxPhI6oSN", + "outputId": "b80b80cd-fc14-4a43-f8a8-2802de4faade" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (3.1.0)\n", + "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from datasets) (3.16.1)\n", + "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (1.26.4)\n", + "Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (17.0.0)\n", + "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.3.8)\n", + "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (2.2.2)\n", + "Requirement already satisfied: requests>=2.32.2 in /usr/local/lib/python3.10/dist-packages (from datasets) (2.32.3)\n", + "Requirement already satisfied: tqdm>=4.66.3 in /usr/local/lib/python3.10/dist-packages (from datasets) (4.66.6)\n", + "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets) (3.5.0)\n", + "Requirement already satisfied: multiprocess<0.70.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.70.16)\n", + "Requirement already satisfied: fsspec<=2024.9.0,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets) (2024.9.0)\n", + "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.11.2)\n", + "Requirement already satisfied: huggingface-hub>=0.23.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.26.2)\n", + "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages 
(from datasets) (24.2)\n", + "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (6.0.2)\n", + "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (2.4.3)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n", + "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (24.2.0)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.5.0)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.1.0)\n", + "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (0.2.0)\n", + "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.17.2)\n", + "Requirement already satisfied: async-timeout<6.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.23.0->datasets) (4.12.2)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (3.4.0)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (3.10)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (2.2.3)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (2024.8.30)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n", + "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.2)\n", + "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.2)\n", + "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n" + ] + } + ], + "source": [ + "!pip install datasets" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "8v1N59p76zn0" + }, + "outputs": [], + "source": [ + "from datasets import load_dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "8Gaj11tO7C86" + }, + "source": [ + "Load the MNIST dataset from Hugging Face and format it as `numpy` arrays for quick access or `jax` to get JAX arrays." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "a22kTvgk6_fJ" + }, + "outputs": [], + "source": [ + "mnist_dataset = load_dataset(\"mnist\", cache_dir=data_dir).with_format(\"numpy\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "tgI7dIaX7JzM" + }, + "source": [ + "### Extract images and labels\n", + "\n", + "Get image shape and flatten for model input." 
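+    ,"\n",
+    "\n",
+    "As noted above, `with_format` also accepts `\"jax\"` to get JAX arrays directly; a minimal sketch of that alternative:\n",
+    "\n",
+    "```python\n",
+    "# Hypothetical alternative: jax.numpy arrays straight from the `datasets` library\n",
+    "jax_mnist = load_dataset(\"mnist\", cache_dir=data_dir).with_format(\"jax\")\n",
+    "```"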
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "NHrKatD_7HbH" + }, + "outputs": [], + "source": [ + "train_images = mnist_dataset[\"train\"][\"image\"]\n", + "train_labels = mnist_dataset[\"train\"][\"label\"]\n", + "test_images = mnist_dataset[\"test\"][\"image\"]\n", + "test_labels = mnist_dataset[\"test\"][\"label\"]\n", + "\n", + "# Extract image shape\n", + "image_shape = train_images.shape[1:]\n", + "num_features = image_shape[0] * image_shape[1]\n", + "\n", + "# Flatten the images\n", + "train_images = train_images.reshape(-1, num_features)\n", + "test_images = test_images.reshape(-1, num_features)\n", + "\n", + "# One-hot encode the labels\n", + "train_labels = one_hot(train_labels, n_targets)\n", + "test_labels = one_hot(test_labels, n_targets)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "dITh435Z7Nwb", + "outputId": "cc89c1ec-6987-4f1c-90a4-c3b355ea7225" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train: (60000, 784) (60000, 10)\n", + "Test: (10000, 784) (10000, 10)\n" + ] + } + ], + "source": [ + "print('Train:', train_images.shape, train_labels.shape)\n", + "print('Test:', test_images.shape, test_labels.shape)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "kk_4zJlz7T1E" + }, + "source": [ + "### Define Training Generator\n", + "\n", + "Set up a generator to yield batches of images and labels for training." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "-zLJhogj7RL-" + }, + "outputs": [], + "source": [ + "def hf_training_generator():\n", + " \"\"\"Yield batches for training.\"\"\"\n", + " for batch in mnist_dataset[\"train\"].iter(batch_size):\n", + " x, y = batch[\"image\"], batch[\"label\"]\n", + " yield x, y" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "HIsGfkLI7dvZ" + }, + "source": [ + "### Training Loop (Hugging Face Datasets)\n", + "\n", + "Run the training loop using the Hugging Face training generator." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "Ui6aLiZP7aLe", + "outputId": "c51529e0-563f-4af0-9793-76b5e6f323db" + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Epoch 1 in 19.06 sec: Train Accuracy: 0.9158, Test Accuracy: 0.9195\n", + "Epoch 2 in 8.94 sec: Train Accuracy: 0.9372, Test Accuracy: 0.9385\n", + "Epoch 3 in 5.43 sec: Train Accuracy: 0.9492, Test Accuracy: 0.9467\n", + "Epoch 4 in 6.41 sec: Train Accuracy: 0.9569, Test Accuracy: 0.9532\n", + "Epoch 5 in 5.80 sec: Train Accuracy: 0.9631, Test Accuracy: 0.9577\n", + "Epoch 6 in 6.61 sec: Train Accuracy: 0.9675, Test Accuracy: 0.9617\n", + "Epoch 7 in 5.49 sec: Train Accuracy: 0.9708, Test Accuracy: 0.9652\n", + "Epoch 8 in 6.64 sec: Train Accuracy: 0.9736, Test Accuracy: 0.9671\n" + ] + } + ], + "source": [ + "train_model(num_epochs, params, hf_training_generator)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rCJq2rvKlKWX" + }, + "source": [ + "## Summary\n", + "\n", + "This notebook explored efficient methods for loading data on a GPU with JAX, using libraries such as PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets. You also learned GPU-specific optimizations, including using `device_put` for data transfer and managing GPU memory, to enhance training efficiency. 
Each method offers unique benefits, allowing you to choose the best approach based on your project requirements." + ] + } + ], + "metadata": { + "accelerator": "GPU", + "colab": { + "gpuType": "T4", + "provenance": [] + }, + "jupytext": { + "formats": "ipynb,md:myst" + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/docs/source/data_loaders_on_gpu_with_jax.md b/docs/source/data_loaders_on_gpu_with_jax.md new file mode 100644 index 0000000..a83ec4c --- /dev/null +++ b/docs/source/data_loaders_on_gpu_with_jax.md @@ -0,0 +1,650 @@ +--- +jupytext: + formats: ipynb,md:myst + text_representation: + extension: .md + format_name: myst + format_version: 0.13 + jupytext_version: 1.15.2 +kernelspec: + display_name: Python 3 + name: python3 +--- + ++++ {"id": "PUFGZggH49zp"} + +# Introduction to Data Loaders on GPU with JAX + ++++ {"id": "3ia4PKEV5Dr8"} + +[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/data_loaders_on_gpu_with_jax.ipynb) + +This tutorial explores different data loading strategies for using **JAX** on a single [**GPU**](https://jax.readthedocs.io/en/latest/glossary.html#term-GPU). While JAX doesn't include a built-in data loader, it seamlessly integrates with popular data loading libraries, including: +* [**PyTorch DataLoader**](https://github.com/pytorch/data) +* [**TensorFlow Datasets (TFDS)**](https://github.com/tensorflow/datasets) +* [**Grain**](https://github.com/google/grain) +* [**Hugging Face**](https://huggingface.co/docs/datasets/en/use_with_jax#data-loading) + +You'll see how to use each of these libraries to efficiently load data for a simple image classification task using the MNIST dataset. + +Compared to [CPU-based loading](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_cpu_with_jax.html), working with a GPU introduces specific considerations like transferring data to the GPU using `device_put`, managing larger batch sizes for faster processing, and efficiently utilizing GPU memory. Unlike multi-device setups, this guide focuses on optimizing data handling for a single GPU. + + +If you're looking for CPU-specific data loading advice, see [Data Loaders on CPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_cpu_with_jax.html). + +If you're looking for a multi-device data loading strategy, see [Data Loaders on Multi-Device Setups](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_for_multi_device_setups_with_jax.html). + ++++ {"id": "-rsMgVtO6asW"} + +## Import JAX API + +```{code-cell} +:id: tDJNQ6V-Dg5g + +import jax +import jax.numpy as jnp +from jax import grad, jit, vmap, random, device_put +``` + ++++ {"id": "TsFdlkSZKp9S"} + +## Checking GPU Availability for JAX + +```{code-cell} +--- +colab: + base_uri: https://localhost:8080/ +id: N3sqvaF3KJw1 +outputId: ab40f542-b8c0-422c-ca68-4ce292817889 +--- +jax.devices() +``` + ++++ {"id": "qyJ_WTghDnIc"} + +## Setting Hyperparameters and Initializing Parameters + +You'll define hyperparameters for your model and data loading, including layer sizes, learning rate, batch size, and the data directory. You'll also initialize the weights and biases for a fully-connected neural network. 
+
+```{code-cell}
+:id: qLNOSloFDka_
+
+# A helper function to randomly initialize weights and biases
+# for a dense neural network layer
+def random_layer_params(m, n, key, scale=1e-2):
+  w_key, b_key = random.split(key)
+  return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))
+
+# Function to initialize network parameters for all layers based on defined sizes
+def init_network_params(sizes, key):
+  keys = random.split(key, len(sizes))
+  return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]
+
+layer_sizes = [784, 512, 512, 10]  # Layers of the network
+step_size = 0.01  # Learning rate
+num_epochs = 8  # Number of training epochs
+batch_size = 128  # Batch size for training
+n_targets = 10  # Number of classes (digits 0-9)
+num_pixels = 28 * 28  # Each MNIST image is 28x28 pixels
+data_dir = '/tmp/mnist_dataset'  # Directory for storing the dataset
+
+# Initialize network parameters using the defined layer sizes and a random seed
+params = init_network_params(layer_sizes, random.PRNGKey(0))
+```
+
++++ {"id": "rHLdqeI7D2WZ"}
+
+## Model Prediction with Auto-Batching
+
+In this section, you'll define the `predict` function for your neural network. This function computes the output of the network for a single input image.
+
+To efficiently process multiple images simultaneously, you'll use [`vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax.vmap), which allows you to vectorize the `predict` function and apply it across a batch of inputs. This technique, called auto-batching, improves computational efficiency by leveraging hardware acceleration.
+
+```{code-cell}
+:id: bKIYPSkvD1QV
+
+from jax.scipy.special import logsumexp
+
+def relu(x):
+  return jnp.maximum(0, x)
+
+def predict(params, image):
+  # per-example predictions
+  activations = image
+  for w, b in params[:-1]:
+    outputs = jnp.dot(w, activations) + b
+    activations = relu(outputs)
+
+  final_w, final_b = params[-1]
+  logits = jnp.dot(final_w, activations) + final_b
+  return logits - logsumexp(logits)
+
+# Make a batched version of the `predict` function
+batched_predict = vmap(predict, in_axes=(None, 0))
+```
+
++++ {"id": "rLqfeORsERek"}
+
+## Utility and Loss Functions
+
+You'll now define utility functions for:
+- One-hot encoding: Converts class indices to binary vectors.
+- Accuracy calculation: Measures the performance of the model on the dataset.
+- Loss computation: Calculates the difference between predictions and targets.
+
+To optimize performance:
+- [`grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html#jax.grad) is used to compute gradients of the loss function with respect to network parameters.
+- [`jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit) compiles the update function, enabling faster execution by leveraging JAX's [XLA](https://openxla.org/xla) compilation.
+- [`device_put`](https://jax.readthedocs.io/en/latest/_autosummary/jax.device_put.html) is used to transfer each batch of data to the GPU, as shown in the sketch after this list.
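+
+Before defining those utilities, here is a minimal sketch of what `device_put` does. The `host_batch` array is a stand-in created just for this example; it is not part of the tutorial's data pipeline:
+
+```{code-cell}
+import numpy as np
+
+# A host-side (CPU) NumPy batch, shaped like one training batch.
+host_batch = np.zeros((batch_size, num_pixels), dtype=np.float32)
+
+# Explicitly place it on the first JAX device (the GPU in this setup).
+# The result is a jax.Array committed to that device, so later computations
+# can consume it without an implicit host-to-device copy.
+device_batch = device_put(host_batch, jax.devices()[0])
+
+# On recent JAX versions, .devices() reports where the array lives.
+print(device_batch.devices())
+```
+
+Now define the utilities themselves: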
+
+```{code-cell}
+:id: sA0a06raEQfS
+
+import time
+
+def one_hot(x, k, dtype=jnp.float32):
+  """Create a one-hot encoding of x of size k."""
+  return jnp.array(x[:, None] == jnp.arange(k), dtype)
+
+def accuracy(params, images, targets):
+  """Calculate the accuracy of predictions."""
+  target_class = jnp.argmax(targets, axis=1)
+  predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
+  return jnp.mean(predicted_class == target_class)
+
+def loss(params, images, targets):
+  """Calculate the loss between predictions and targets."""
+  preds = batched_predict(params, images)
+  return -jnp.mean(preds * targets)
+
+@jit
+def update(params, x, y):
+  """Update the network parameters using gradient descent."""
+  grads = grad(loss)(params, x, y)
+  return [(w - step_size * dw, b - step_size * db)
+          for (w, b), (dw, db) in zip(params, grads)]
+
+def reshape_and_one_hot(x, y):
+  """Reshape and one-hot encode the inputs."""
+  x = jnp.reshape(x, (len(x), num_pixels))
+  y = one_hot(y, n_targets)
+  return x, y
+
+def train_model(num_epochs, params, training_generator, data_loader_type='streamed'):
+  """Train the model for a given number of epochs, moving each batch to the GPU with device_put."""
+  for epoch in range(num_epochs):
+    start_time = time.time()
+    for x, y in training_generator() if data_loader_type == 'streamed' else training_generator:
+      x, y = reshape_and_one_hot(x, y)
+      x, y = device_put(x), device_put(y)
+      params = update(params, x, y)
+
+    print(f"Epoch {epoch + 1} in {time.time() - start_time:.2f} sec: "
+          f"Train Accuracy: {accuracy(params, train_images, train_labels):.4f}, "
+          f"Test Accuracy: {accuracy(params, test_images, test_labels):.4f}")
+```
+
++++ {"id": "Hsionp5IYsQ9"}
+
+## Loading Data with PyTorch DataLoader
+
+This section shows how to load the MNIST dataset using PyTorch's DataLoader, convert the data to NumPy arrays, and apply transformations to flatten and cast images.
+
+```{code-cell}
+---
+colab:
+  base_uri: https://localhost:8080/
+id: uA7XY0OezHse
+outputId: 4c86f455-ff1d-474e-f8e3-7111d9b56996
+---
+!pip install torch torchvision
+```
+
+```{code-cell}
+:id: kO5_WzwY59gE
+
+import numpy as np
+from jax.tree_util import tree_map
+from torch.utils import data
+from torchvision.datasets import MNIST
+```
+
+```{code-cell}
+:id: 6f6qU8PCc143
+
+def numpy_collate(batch):
+  """Collate function to convert a batch of PyTorch data into NumPy arrays."""
+  return tree_map(np.asarray, data.default_collate(batch))
+
+class NumpyLoader(data.DataLoader):
+  """Custom DataLoader to return NumPy arrays from a PyTorch Dataset."""
+  def __init__(self, dataset, batch_size=1, shuffle=False, **kwargs):
+    super().__init__(dataset,
+                     batch_size=batch_size,
+                     shuffle=shuffle,
+                     collate_fn=numpy_collate,
+                     **kwargs)
+
+class FlattenAndCast(object):
+  """Transform class to flatten and cast images to float32."""
+  def __call__(self, pic):
+    return np.ravel(np.array(pic, dtype=jnp.float32))
+```
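+
+As a quick sanity check of the collation path, you can run `numpy_collate` on a tiny made-up batch. The `toy_batch` values below are invented for illustration and are not used elsewhere:
+
+```{code-cell}
+# Two fake (image, label) pairs. `data.default_collate` stacks them into
+# torch.Tensors, and `tree_map(np.asarray, ...)` converts the result back
+# to NumPy, the format the training loop expects.
+toy_batch = [(np.zeros(4, dtype=np.float32), 0),
+             (np.ones(4, dtype=np.float32), 1)]
+images, labels = numpy_collate(toy_batch)
+print(type(images), images.shape, labels)
+```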
+
++++ {"id": "mfSnfJND6I8G"}
+
+### Load Dataset with Transformations
+
+Standardize the data by flattening the images, casting them to `float32`, and ensuring consistent data types.
+
+```{code-cell}
+:id: Kxbl6bcx6crv
+
+mnist_dataset = MNIST(data_dir, download=True, transform=FlattenAndCast())
+```
+
++++ {"id": "kbdsqvPZGrsa"}
+
+### Full Training Dataset for Accuracy Checks
+
+Convert the entire training dataset to flattened NumPy arrays; JAX transfers them to the GPU automatically the first time they are used in a computation.
+
+```{code-cell}
+:id: c9ZCJq_rzPck
+
+train_images = np.array(mnist_dataset.data).reshape(len(mnist_dataset.data), -1)
+train_labels = one_hot(np.array(mnist_dataset.targets), n_targets)
+```
+
++++ {"id": "WXUh0BwvG8Ko"}
+
+### Get Full Test Dataset
+
+Load and process the full test dataset.
+
+```{code-cell}
+:id: brlLG4SqGphm
+
+mnist_dataset_test = MNIST(data_dir, download=True, train=False)
+test_images = jnp.array(mnist_dataset_test.data.numpy().reshape(len(mnist_dataset_test.data), -1), dtype=jnp.float32)
+test_labels = one_hot(np.array(mnist_dataset_test.targets), n_targets)
+```
+
+```{code-cell}
+---
+colab:
+  base_uri: https://localhost:8080/
+id: Oz-UVnCxG5E8
+outputId: 53f3fb32-5096-4862-e022-3c3a1d82137a
+---
+print('Train:', train_images.shape, train_labels.shape)
+print('Test:', test_images.shape, test_labels.shape)
+```
+
++++ {"id": "mNjn9dMPitKL"}
+
+### Training Data Generator
+
+Define a generator function using PyTorch's DataLoader for batch training.
+Setting `num_workers > 0` enables multi-process data loading, which can accelerate data loading for larger datasets or intensive preprocessing tasks. Experiment with different values to find the optimal setting for your hardware and workload.
+
+Note: When setting `num_workers > 0`, you may see the following `RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.`
+This warning can be safely ignored since data loaders do not use JAX within the forked processes.
+
+```{code-cell}
+:id: 0LdT8P8aisWF
+
+def pytorch_training_generator(mnist_dataset):
+  return NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=0)
+```
+
++++ {"id": "Xzt2x9S1HC3T"}
+
+### Training Loop (PyTorch DataLoader)
+
+The training loop uses the PyTorch DataLoader to iterate through batches and update model parameters.
+
+```{code-cell}
+---
+colab:
+  base_uri: https://localhost:8080/
+id: SqweRz_98sN8
+outputId: bdd45256-3f5a-48f7-e45c-378078ac4279
+---
+train_model(num_epochs, params, pytorch_training_generator(mnist_dataset), data_loader_type='iterable')
+```
+
++++ {"id": "Nm45ZTo6yrf5"}
+
+## Loading Data with TensorFlow Datasets (TFDS)
+
+This section demonstrates how to load the MNIST dataset using TFDS, fetch the full dataset for evaluation, and define a training generator for batch processing. GPU usage is explicitly disabled for TensorFlow so that its runtime does not reserve GPU memory needed by JAX.
+
+```{code-cell}
+:id: sGaQAk1DHMUx
+
+import tensorflow_datasets as tfds
+import tensorflow as tf
+
+# Hide the GPU from TensorFlow: TFDS is only used for the input pipeline,
+# and this keeps TensorFlow from reserving GPU memory that JAX needs.
+tf.config.set_visible_devices([], device_type='GPU')
+```
+
++++ {"id": "ZSc5K0Eiwm4L"}
+
+### Fetch Full Dataset for Evaluation
+
+Load the dataset with `tfds.load`, convert it to NumPy arrays, and process it for evaluation.
+
+```{code-cell}
+:id: 1hOamw_7C8Pb
+
+# tfds.load returns tf.Tensors (or tf.data.Datasets if batch_size != -1)
+mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)
+mnist_data = tfds.as_numpy(mnist_data)
+train_data, test_data = mnist_data['train'], mnist_data['test']
+
+# Full train set
+train_images, train_labels = train_data['image'], train_data['label']
+train_images = jnp.reshape(train_images, (len(train_images), num_pixels))
+train_labels = one_hot(train_labels, n_targets)
+
+# Full test set
+test_images, test_labels = test_data['image'], test_data['label']
+test_images = jnp.reshape(test_images, (len(test_images), num_pixels))
+test_labels = one_hot(test_labels, n_targets)
+```
+
+```{code-cell}
+---
+colab:
+  base_uri: https://localhost:8080/
+id: Td3PiLdmEf7z
+outputId: b8c9a32a-9cf0-4dc3-cb51-db21d32c6545
+---
+print('Train:', train_images.shape, train_labels.shape)
+print('Test:', test_images.shape, test_labels.shape)
+```
+
++++ {"id": "dXMvgk6sdq4j"}
+
+### Define the Training Generator
+
+Create a generator function to yield batches of data for training.
+
+```{code-cell}
+:id: vX59u8CqEf4J
+
+def training_generator():
+  # as_supervised=True gives us the (image, label) as a tuple instead of a dict
+  ds = tfds.load(name='mnist', split='train', as_supervised=True, data_dir=data_dir)
+  # You can build up an arbitrary tf.data input pipeline
+  ds = ds.batch(batch_size).prefetch(1)
+  # tfds.as_numpy converts the tf.data.Dataset into an iterable of NumPy arrays
+  return tfds.as_numpy(ds)
+```
+
++++ {"id": "EAWeUdnuFNBY"}
+
+### Training Loop (TFDS)
+
+Use the training generator in a custom training loop.
+
+```{code-cell}
+---
+colab:
+  base_uri: https://localhost:8080/
+id: h2sO13XDGvq1
+outputId: f30805bb-e725-46ee-e053-6e97f2af81c5
+---
+train_model(num_epochs, params, training_generator)
+```
+
++++ {"id": "-ryVkrAITS9Z"}
+
+## Loading Data with Grain
+
+This section demonstrates how to load MNIST data using Grain, a data-loading library. You'll define a custom dataset class for Grain and set up a Grain DataLoader for efficient training.
+
++++ {"id": "waYhUMUGmhH-"}
+
+Install Grain.
+
+```{code-cell}
+---
+colab:
+  base_uri: https://localhost:8080/
+id: L78o7eeyGvn5
+outputId: cb0ce6cf-243b-4183-8f63-646e00232caa
+---
+!pip install grain
+```
+
++++ {"id": "66bH3ZDJ7Iat"}
+
+Import the required libraries (the MNIST dataset itself still comes from `torchvision`).
+
+```{code-cell}
+:id: mS62eVL9Ifmz
+
+import numpy as np
+import grain.python as pygrain
+from torchvision.datasets import MNIST
+```
+
++++ {"id": "0h6mwVrspPA-"}
+
+### Define Dataset Class
+
+Create a custom dataset class to load MNIST data for Grain.
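+
+Grain's `DataLoader` reads from any random-access source that provides `__len__` and `__getitem__`. As a toy sketch of that protocol (the `SquaresSource` class is invented for illustration and is not used elsewhere):
+
+```{code-cell}
+class SquaresSource:
+  """Toy random-access source: record i is the pair (i, i**2)."""
+  def __len__(self):
+    return 10
+
+  def __getitem__(self, index):
+    return index, index ** 2
+
+# Any object with this interface can serve as a Grain data source.
+print(len(SquaresSource()), SquaresSource()[3])
+```
+
+The MNIST-backed class below implements the same interface: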
+
+```{code-cell}
+:id: bnrhac5Hh7y1
+
+class Dataset:
+  def __init__(self, data_dir, train=True):
+    self.data_dir = data_dir
+    self.train = train
+    self.load_data()
+
+  def load_data(self):
+    self.dataset = MNIST(self.data_dir, download=True, train=self.train)
+
+  def __len__(self):
+    return len(self.dataset)
+
+  def __getitem__(self, index):
+    img, label = self.dataset[index]
+    return np.ravel(np.array(img, dtype=np.float32)), label
+```
+
++++ {"id": "53mf8bWEsyTr"}
+
+### Initialize the Dataset
+
+```{code-cell}
+:id: pN3oF7-ostGE
+
+mnist_dataset = Dataset(data_dir)
+```
+
++++ {"id": "GqD-ycgBuwv9"}
+
+### Get the Full Train and Test Dataset
+
+```{code-cell}
+:id: f1VnTuX3u_kL
+
+# Convert training data to JAX arrays and encode labels as one-hot vectors
+train_images = jnp.array([mnist_dataset[i][0] for i in range(len(mnist_dataset))], dtype=jnp.float32)
+train_labels = one_hot(np.array([mnist_dataset[i][1] for i in range(len(mnist_dataset))]), n_targets)
+
+# Load test dataset and process it
+mnist_dataset_test = MNIST(data_dir, download=True, train=False)
+test_images = jnp.array([np.ravel(np.array(mnist_dataset_test[i][0], dtype=np.float32)) for i in range(len(mnist_dataset_test))], dtype=jnp.float32)
+test_labels = one_hot(np.array([mnist_dataset_test[i][1] for i in range(len(mnist_dataset_test))]), n_targets)
+```
+
+```{code-cell}
+---
+colab:
+  base_uri: https://localhost:8080/
+id: a2NHlp9klrQL
+outputId: c9422190-55e9-400b-bd4e-0e7bf23dc6a1
+---
+print("Train:", train_images.shape, train_labels.shape)
+print("Test:", test_images.shape, test_labels.shape)
+```
+
++++ {"id": "1QPbXt7O0JN-"}
+
+### Initialize PyGrain DataLoader
+
+Set up a PyGrain DataLoader for sequential batch sampling.
+
+```{code-cell}
+:id: 2jqd1jJt25Bj
+
+sampler = pygrain.SequentialSampler(
+    num_records=len(mnist_dataset),
+    shard_options=pygrain.NoSharding())  # Single-device, no sharding
+
+def pygrain_training_generator():
+  """Grain DataLoader generator for training."""
+  return pygrain.DataLoader(
+      data_source=mnist_dataset,
+      sampler=sampler,
+      operations=[pygrain.Batch(batch_size)],
+  )
+```
+
++++ {"id": "mV5z4GLCGKlx"}
+
+### Training Loop (Grain)
+
+Run the training loop using the Grain DataLoader.
+
+```{code-cell}
+---
+colab:
+  base_uri: https://localhost:8080/
+id: 9-iANQ-9CcW_
+outputId: b0e19da2-9e34-4183-c5d8-af66de5efa5c
+---
+train_model(num_epochs, params, pygrain_training_generator)
+```
+
++++ {"id": "o51P6lr86wz-"}
+
+## Loading Data with Hugging Face
+
+This section demonstrates loading MNIST data using the Hugging Face `datasets` library. You'll format the dataset for JAX compatibility, prepare flattened images and one-hot-encoded labels, and define a training generator.
+
++++ {"id": "69vrihaOi4Oz"}
+
+Install the Hugging Face `datasets` library.
+
+```{code-cell}
+---
+colab:
+  base_uri: https://localhost:8080/
+id: 19ipxPhI6oSN
+outputId: b80b80cd-fc14-4a43-f8a8-2802de4faade
+---
+!pip install datasets
+```
+
+```{code-cell}
+:id: 8v1N59p76zn0
+
+from datasets import load_dataset
+```
+
++++ {"id": "8Gaj11tO7C86"}
+
+Load the MNIST dataset from Hugging Face and call `with_format("numpy")` so that indexing returns NumPy arrays; passing `"jax"` instead would return JAX arrays directly (see the short aside below).
+
+```{code-cell}
+:id: a22kTvgk6_fJ
+
+mnist_dataset = load_dataset("mnist", cache_dir=data_dir).with_format("numpy")
+```
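+
+As a brief aside, the sketch below shows what the `"jax"` format would look like; it is illustration only, and `mnist_jax` is not used in the rest of this tutorial:
+
+```{code-cell}
+# Requesting the "jax" format makes indexing return jax.Array objects
+# directly instead of NumPy arrays.
+mnist_jax = load_dataset("mnist", cache_dir=data_dir).with_format("jax")
+print(type(mnist_jax["train"][0]["image"]))
+```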
+
++++ {"id": "tgI7dIaX7JzM"}
+
+### Extract Images and Labels
+
+Determine the image shape, flatten the images for model input, and one-hot encode the labels.
+
+```{code-cell}
+:id: NHrKatD_7HbH
+
+train_images = mnist_dataset["train"]["image"]
+train_labels = mnist_dataset["train"]["label"]
+test_images = mnist_dataset["test"]["image"]
+test_labels = mnist_dataset["test"]["label"]
+
+# Extract image shape
+image_shape = train_images.shape[1:]
+num_features = image_shape[0] * image_shape[1]
+
+# Flatten the images
+train_images = train_images.reshape(-1, num_features)
+test_images = test_images.reshape(-1, num_features)
+
+# One-hot encode the labels
+train_labels = one_hot(train_labels, n_targets)
+test_labels = one_hot(test_labels, n_targets)
+```
+
+```{code-cell}
+---
+colab:
+  base_uri: https://localhost:8080/
+id: dITh435Z7Nwb
+outputId: cc89c1ec-6987-4f1c-90a4-c3b355ea7225
+---
+print('Train:', train_images.shape, train_labels.shape)
+print('Test:', test_images.shape, test_labels.shape)
+```
+
++++ {"id": "kk_4zJlz7T1E"}
+
+### Define Training Generator
+
+Set up a generator to yield batches of images and labels for training.
+
+```{code-cell}
+:id: -zLJhogj7RL-
+
+def hf_training_generator():
+  """Yield batches for training."""
+  for batch in mnist_dataset["train"].iter(batch_size):
+    x, y = batch["image"], batch["label"]
+    yield x, y
+```
+
++++ {"id": "HIsGfkLI7dvZ"}
+
+### Training Loop (Hugging Face Datasets)
+
+Run the training loop using the Hugging Face training generator.
+
+```{code-cell}
+---
+colab:
+  base_uri: https://localhost:8080/
+id: Ui6aLiZP7aLe
+outputId: c51529e0-563f-4af0-9793-76b5e6f323db
+---
+train_model(num_epochs, params, hf_training_generator)
+```
+
++++ {"id": "rCJq2rvKlKWX"}
+
+## Summary
+
+This notebook explored efficient methods for loading data on a GPU with JAX, using libraries such as PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets. You also applied GPU-specific optimizations, such as transferring each batch with `device_put` and keeping TensorFlow from reserving GPU memory, to improve training efficiency. Each method offers unique benefits, allowing you to choose the best approach based on your project requirements.
From 36a0d3e2d82b408bb4aba52623984f9fcd6bb60e Mon Sep 17 00:00:00 2001 From: selamw1 Date: Wed, 4 Dec 2024 14:00:46 -0800 Subject: [PATCH 7/7] old_files_removed_tutorial_updated --- docs/data_loaders_on_cpu_with_jax.ipynb | 3575 ----------------------- docs/data_loaders_on_cpu_with_jax.md | 691 ----- docs/data_loaders_on_gpu_with_jax.ipynb | 1176 -------- docs/data_loaders_on_gpu_with_jax.md | 650 ----- docs/source/tutorials.md | 1 - 5 files changed, 6093 deletions(-) delete mode 100644 docs/data_loaders_on_cpu_with_jax.ipynb delete mode 100644 docs/data_loaders_on_cpu_with_jax.md delete mode 100644 docs/data_loaders_on_gpu_with_jax.ipynb delete mode 100644 docs/data_loaders_on_gpu_with_jax.md diff --git a/docs/data_loaders_on_cpu_with_jax.ipynb b/docs/data_loaders_on_cpu_with_jax.ipynb deleted file mode 100644 index 0ba897e..0000000 --- a/docs/data_loaders_on_cpu_with_jax.ipynb +++ /dev/null @@ -1,3575 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "PUFGZggH49zp" - }, - "source": [ - "# Introduction to Data Loaders on CPU with JAX" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "3ia4PKEV5Dr8" - }, - "source": [ - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/data_loaders_on_cpu_with_jax.ipynb)\n", - "\n", - "This tutorial explores different data loading strategies for using **JAX** on a single [**CPU**](https://jax.readthedocs.io/en/latest/glossary.html#term-CPU). While JAX doesn't include a built-in data loader, it seamlessly integrates with popular data loading libraries, including:\n", - "\n", - "- [**PyTorch DataLoader**](https://github.com/pytorch/data)\n", - "- [**TensorFlow Datasets (TFDS)**](https://github.com/tensorflow/datasets)\n", - "- [**Grain**](https://github.com/google/grain)\n", - "- [**Hugging Face**](https://huggingface.co/docs/datasets/en/use_with_jax#data-loading)\n", - "\n", - "In this tutorial, you'll learn how to efficiently load data using these libraries for a simple image classification task based on the MNIST dataset.\n", - "\n", - "Compared to GPU or multi-device setups, CPU-based data loading is straightforward as it avoids challenges like GPU memory management and data synchronization across devices. This makes it ideal for smaller-scale tasks or scenarios where data resides exclusively on the CPU.\n", - "\n", - "If you're looking for GPU-specific data loading advice, see [Data Loaders on GPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_gpu_with_jax.html).\n", - "\n", - "If you're looking for a multi-device data loading strategy, see [Data Loaders on Multi-Device Setups](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_for_multi_device_setups_with_jax.html)." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "pEsb135zE-Jo" - }, - "source": [ - "## Setting JAX to Use CPU Only\n", - "\n", - "First, you'll restrict JAX to use only the CPU, even if a GPU is available. This ensures consistency and allows you to focus on CPU-based data loading." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "vqP6xyObC0_9" - }, - "outputs": [], - "source": [ - "import os\n", - "os.environ['JAX_PLATFORM_NAME'] = 'cpu'" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-rsMgVtO6asW" - }, - "source": [ - "Import JAX API" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "tDJNQ6V-Dg5g" - }, - "outputs": [], - "source": [ - "import jax\n", - "import jax.numpy as jnp\n", - "from jax import random, grad, jit, vmap" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "TsFdlkSZKp9S" - }, - "source": [ - "### CPU Setup Verification" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "N3sqvaF3KJw1", - "outputId": "449c83d9-d050-4b15-9a8d-f71e340501f2" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "[CpuDevice(id=0)]" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "jax.devices()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "qyJ_WTghDnIc" - }, - "source": [ - "## Setting Hyperparameters and Initializing Parameters\n", - "\n", - "You'll define hyperparameters for your model and data loading, including layer sizes, learning rate, batch size, and the data directory. You'll also initialize the weights and biases for a fully-connected neural network." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "qLNOSloFDka_" - }, - "outputs": [], - "source": [ - "# A helper function to randomly initialize weights and biases\n", - "# for a dense neural network layer\n", - "def random_layer_params(m, n, key, scale=1e-2):\n", - " w_key, b_key = random.split(key)\n", - " return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))\n", - "\n", - "# Function to initialize network parameters for all layers based on defined sizes\n", - "def init_network_params(sizes, key):\n", - " keys = random.split(key, len(sizes))\n", - " return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]\n", - "\n", - "layer_sizes = [784, 512, 512, 10] # Layers of the network\n", - "step_size = 0.01 # Learning rate for optimization\n", - "num_epochs = 8 # Number of training epochs\n", - "batch_size = 128 # Batch size for training\n", - "n_targets = 10 # Number of classes (digits 0-9)\n", - "num_pixels = 28 * 28 # Input size (MNIST images are 28x28 pixels)\n", - "data_dir = '/tmp/mnist_dataset' # Directory for storing the dataset\n", - "\n", - "# Initialize network parameters using the defined layer sizes and a random seed\n", - "params = init_network_params(layer_sizes, random.PRNGKey(0))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "6Ci_CqW7q6XM" - }, - "source": [ - "## Model Prediction with Auto-Batching\n", - "\n", - "In this section, you'll define the `predict` function for your neural network. This function computes the output of the network for a single input image.\n", - "\n", - "To efficiently process multiple images simultaneously, you'll use [`vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax.vmap), which allows you to vectorize the `predict` function and apply it across a batch of inputs. This technique, called auto-batching, improves computational efficiency by leveraging hardware acceleration." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "bKIYPSkvD1QV" - }, - "outputs": [], - "source": [ - "from jax.scipy.special import logsumexp\n", - "\n", - "def relu(x):\n", - " return jnp.maximum(0, x)\n", - "\n", - "def predict(params, image):\n", - " # per-example prediction\n", - " activations = image\n", - " for w, b in params[:-1]:\n", - " outputs = jnp.dot(w, activations) + b\n", - " activations = relu(outputs)\n", - "\n", - " final_w, final_b = params[-1]\n", - " logits = jnp.dot(final_w, activations) + final_b\n", - " return logits - logsumexp(logits)\n", - "\n", - "# Make a batched version of the `predict` function\n", - "batched_predict = vmap(predict, in_axes=(None, 0))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "niTSr34_sDZi" - }, - "source": [ - "## Utility and Loss Functions\n", - "\n", - "You'll now define utility functions for:\n", - "\n", - "- One-hot encoding: Converts class indices to binary vectors.\n", - "- Accuracy calculation: Measures the performance of the model on the dataset.\n", - "- Loss computation: Calculates the difference between predictions and targets.\n", - "\n", - "To optimize performance:\n", - "\n", - "- [`grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html#jax.grad) is used to compute gradients of the loss function with respect to network parameters.\n", - "- [`jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit) compiles the update function, enabling faster execution by leveraging JAX's [XLA](https://openxla.org/xla) compilation." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "sA0a06raEQfS" - }, - "outputs": [], - "source": [ - "import time\n", - "\n", - "def one_hot(x, k, dtype=jnp.float32):\n", - " \"\"\"Create a one-hot encoding of x of size k.\"\"\"\n", - " return jnp.array(x[:, None] == jnp.arange(k), dtype)\n", - "\n", - "def accuracy(params, images, targets):\n", - " \"\"\"Calculate the accuracy of predictions.\"\"\"\n", - " target_class = jnp.argmax(targets, axis=1)\n", - " predicted_class = jnp.argmax(batched_predict(params, images), axis=1)\n", - " return jnp.mean(predicted_class == target_class)\n", - "\n", - "def loss(params, images, targets):\n", - " \"\"\"Calculate the loss between predictions and targets.\"\"\"\n", - " preds = batched_predict(params, images)\n", - " return -jnp.mean(preds * targets)\n", - "\n", - "@jit\n", - "def update(params, x, y):\n", - " \"\"\"Update the network parameters using gradient descent.\"\"\"\n", - " grads = grad(loss)(params, x, y)\n", - " return [(w - step_size * dw, b - step_size * db)\n", - " for (w, b), (dw, db) in zip(params, grads)]\n", - "\n", - "def reshape_and_one_hot(x, y):\n", - " \"\"\"Reshape and one-hot encode the inputs.\"\"\"\n", - " x = jnp.reshape(x, (len(x), num_pixels))\n", - " y = one_hot(y, n_targets)\n", - " return x, y\n", - "\n", - "def train_model(num_epochs, params, training_generator, data_loader_type='streamed'):\n", - " \"\"\"Train the model for a given number of epochs.\"\"\"\n", - " for epoch in range(num_epochs):\n", - " start_time = time.time()\n", - " for x, y in training_generator() if data_loader_type == 'streamed' else training_generator:\n", - " x, y = reshape_and_one_hot(x, y)\n", - " params = update(params, x, y)\n", - "\n", - " print(f\"Epoch {epoch + 1} in {time.time() - start_time:.2f} sec: \"\n", - " f\"Train Accuracy: {accuracy(params, train_images, train_labels):.4f}, \"\n", - " f\"Test Accuracy: 
{accuracy(params, test_images, test_labels):.4f}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Hsionp5IYsQ9" - }, - "source": [ - "## Loading Data with PyTorch DataLoader\n", - "\n", - "This section shows how to load the MNIST dataset using PyTorch's DataLoader, convert the data to NumPy arrays, and apply transformations to flatten and cast images." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "jmsfrWrHxIhC", - "outputId": "33dfeada-a763-4d26-f778-a27966e34d55" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.5.1+cu121)\n", - "Requirement already satisfied: torchvision in /usr/local/lib/python3.10/dist-packages (0.20.1+cu121)\n", - "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch) (3.16.1)\n", - "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch) (4.12.2)\n", - "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.4.2)\n", - "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.4)\n", - "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch) (2024.10.0)\n", - "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.10/dist-packages (from torch) (1.13.1)\n", - "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy==1.13.1->torch) (1.3.0)\n", - "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from torchvision) (1.26.4)\n", - "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.10/dist-packages (from torchvision) (11.0.0)\n", - "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (3.0.2)\n" - ] - } - ], - "source": [ - "!pip install torch torchvision" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "kO5_WzwY59gE" - }, - "outputs": [], - "source": [ - "import numpy as np\n", - "from jax.tree_util import tree_map\n", - "from torch.utils import data\n", - "from torchvision.datasets import MNIST" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "6f6qU8PCc143" - }, - "outputs": [], - "source": [ - "def numpy_collate(batch):\n", - " \"\"\"Convert a batch of PyTorch data to NumPy arrays.\"\"\"\n", - " return tree_map(np.asarray, data.default_collate(batch))\n", - "\n", - "class NumpyLoader(data.DataLoader):\n", - " \"\"\"Custom DataLoader to return NumPy arrays from a PyTorch Dataset.\"\"\"\n", - " def __init__(self, dataset, batch_size=1, shuffle=False, **kwargs):\n", - " super().__init__(dataset, batch_size=batch_size, shuffle=shuffle, collate_fn=numpy_collate, **kwargs)\n", - "\n", - "class FlattenAndCast(object):\n", - " \"\"\"Transform class to flatten and cast images to float32.\"\"\"\n", - " def __call__(self, pic):\n", - " return np.ravel(np.array(pic, dtype=jnp.float32))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "mfSnfJND6I8G" - }, - "source": [ - "### Load Dataset with Transformations\n", - "\n", - "Standardize the data by flattening the images, casting them to `float32`, and ensuring consistent data types." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Kxbl6bcx6crv", - "outputId": "372bbf4c-3ad5-4fd8-cc5d-27b50f5e4f38" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n", - "Failed to download (trying next):\n", - "HTTP Error 403: Forbidden\n", - "\n", - "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz\n", - "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to /tmp/mnist_dataset/MNIST/raw/train-images-idx3-ubyte.gz\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 9.91M/9.91M [00:00<00:00, 49.4MB/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Extracting /tmp/mnist_dataset/MNIST/raw/train-images-idx3-ubyte.gz to /tmp/mnist_dataset/MNIST/raw\n", - "\n", - "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n", - "Failed to download (trying next):\n", - "HTTP Error 403: Forbidden\n", - "\n", - "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz\n", - "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to /tmp/mnist_dataset/MNIST/raw/train-labels-idx1-ubyte.gz\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 28.9k/28.9k [00:00<00:00, 2.09MB/s]" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Extracting /tmp/mnist_dataset/MNIST/raw/train-labels-idx1-ubyte.gz to /tmp/mnist_dataset/MNIST/raw\n", - "\n", - "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Failed to download (trying next):\n", - "HTTP Error 403: Forbidden\n", - "\n", - "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz\n", - "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to /tmp/mnist_dataset/MNIST/raw/t10k-images-idx3-ubyte.gz\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 1.65M/1.65M [00:00<00:00, 13.3MB/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Extracting /tmp/mnist_dataset/MNIST/raw/t10k-images-idx3-ubyte.gz to /tmp/mnist_dataset/MNIST/raw\n", - "\n", - "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n", - "Failed to download (trying next):\n", - "HTTP Error 403: Forbidden\n", - "\n", - "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz\n", - "Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to /tmp/mnist_dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "100%|██████████| 4.54k/4.54k [00:00<00:00, 8.81MB/s]\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Extracting /tmp/mnist_dataset/MNIST/raw/t10k-labels-idx1-ubyte.gz to /tmp/mnist_dataset/MNIST/raw\n", - "\n" - ] - } - ], - "source": [ - "mnist_dataset = MNIST(data_dir, download=True, transform=FlattenAndCast())" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "kbdsqvPZGrsa" - }, - "source": [ - "### Full Training Dataset for 
Accuracy Checks\n", - "\n", - "Convert the entire training dataset to JAX arrays." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "c9ZCJq_rzPck" - }, - "outputs": [], - "source": [ - "train_images = jnp.array(mnist_dataset.data.numpy().reshape(len(mnist_dataset.data), -1), dtype=jnp.float32)\n", - "train_labels = one_hot(np.array(mnist_dataset.targets), n_targets)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "WXUh0BwvG8Ko" - }, - "source": [ - "### Get Full Test Dataset\n", - "\n", - "Load and process the full test dataset." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "brlLG4SqGphm" - }, - "outputs": [], - "source": [ - "mnist_dataset_test = MNIST(data_dir, download=True, train=False)\n", - "test_images = jnp.array(mnist_dataset_test.data.numpy().reshape(len(mnist_dataset_test.data), -1), dtype=jnp.float32)\n", - "test_labels = one_hot(np.array(mnist_dataset_test.targets), n_targets)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Oz-UVnCxG5E8", - "outputId": "abbaa26d-491a-4e63-e8c9-d3c571f53a28" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Train: (60000, 784) (60000, 10)\n", - "Test: (10000, 784) (10000, 10)\n" - ] - } - ], - "source": [ - "print('Train:', train_images.shape, train_labels.shape)\n", - "print('Test:', test_images.shape, test_labels.shape)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "m3zfxqnMiCbm" - }, - "source": [ - "### Training Data Generator\n", - "\n", - "Define a generator function using PyTorch's DataLoader for batch training. Setting `num_workers > 0` enables multi-process data loading, which can accelerate data loading for larger datasets or intensive preprocessing tasks. Experiment with different values to find the optimal setting for your hardware and workload.\n", - "\n", - "Note: When setting `num_workers > 0`, you may see the following `RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.` This warning can be safely ignored since data loaders do not use JAX within the forked processes." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "B-fES82EiL6Z" - }, - "outputs": [], - "source": [ - "def pytorch_training_generator(mnist_dataset):\n", - " return NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=0)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Xzt2x9S1HC3T" - }, - "source": [ - "### Training Loop (PyTorch DataLoader)\n", - "\n", - "The training loop uses the PyTorch DataLoader to iterate through batches and update model parameters." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "vtUjHsh-rJs8", - "outputId": "4766333e-4366-493b-995a-102778d1345a" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 1 in 28.93 sec: Train Accuracy: 0.9158, Test Accuracy: 0.9196\n", - "Epoch 2 in 8.33 sec: Train Accuracy: 0.9372, Test Accuracy: 0.9384\n", - "Epoch 3 in 6.99 sec: Train Accuracy: 0.9492, Test Accuracy: 0.9468\n", - "Epoch 4 in 7.01 sec: Train Accuracy: 0.9569, Test Accuracy: 0.9532\n", - "Epoch 5 in 8.17 sec: Train Accuracy: 0.9630, Test Accuracy: 0.9579\n", - "Epoch 6 in 8.27 sec: Train Accuracy: 0.9674, Test Accuracy: 0.9615\n", - "Epoch 7 in 8.32 sec: Train Accuracy: 0.9708, Test Accuracy: 0.9650\n", - "Epoch 8 in 8.07 sec: Train Accuracy: 0.9737, Test Accuracy: 0.9671\n" - ] - } - ], - "source": [ - "train_model(num_epochs, params, pytorch_training_generator(mnist_dataset), data_loader_type='iterable')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Nm45ZTo6yrf5" - }, - "source": [ - "## Loading Data with TensorFlow Datasets (TFDS)\n", - "\n", - "This section demonstrates how to load the MNIST dataset using TFDS, fetch the full dataset for evaluation, and define a training generator for batch processing. GPU usage is explicitly disabled for TensorFlow." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "sGaQAk1DHMUx" - }, - "outputs": [], - "source": [ - "import tensorflow_datasets as tfds\n", - "import tensorflow as tf\n", - "\n", - "# Ensuring CPU-Only Execution, disable any GPU usage(if applicable) for TF\n", - "tf.config.set_visible_devices([], device_type='GPU')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "3xdQY7H6wr3n" - }, - "source": [ - "### Fetch Full Dataset for Evaluation\n", - "\n", - "Load the dataset with `tfds.load`, convert it to NumPy arrays, and process it for evaluation." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 104, - "referenced_widgets": [ - "b8cdabf5c05848f38f03850cab08b56f", - "a8b76d5f93004c089676e5a2a9b3336c", - "119ac8428f9441e7a25eb0afef2fbb2a", - "76a9815e5c2b4764a13409cebaf66821", - "45ce8dd5c4b949afa957ec8ffb926060", - "05b7145fd62d4581b2123c7680f11cdd", - "b96267f014814ec5b96ad7e6165104b1", - "bce34bdbfbd64f1f8353a4e8515cee0b", - "93b8206f8c5841a692cdce985ae301d8", - "c95f592620c64da595cc787567b2c4db", - "8a97071f862c4ec3b4b4140d2e34eda2" - ] - }, - "id": "1hOamw_7C8Pb", - "outputId": "ca166490-22db-4732-b29f-866b7593e489" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Downloading and preparing dataset 11.06 MiB (download: 11.06 MiB, generated: 21.00 MiB, total: 32.06 MiB) to /tmp/mnist_dataset/mnist/3.0.1...\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "b8cdabf5c05848f38f03850cab08b56f", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Dl Completed...: 0%| | 0/5 [00:00=9.1.0 in /usr/local/lib/python3.10/dist-packages (from grain) (10.5.0)\n", - "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from grain) (1.26.4)\n", - "Requirement already satisfied: typing_extensions in /usr/local/lib/python3.10/dist-packages (from etils[epath,epy]->grain) (4.12.2)\n", - "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from etils[epath,epy]->grain) (2024.10.0)\n", - "Requirement already satisfied: importlib_resources in /usr/local/lib/python3.10/dist-packages (from etils[epath,epy]->grain) (6.4.5)\n", - "Requirement already satisfied: zipp in /usr/local/lib/python3.10/dist-packages (from etils[epath,epy]->grain) (3.21.0)\n", - "Downloading grain-0.2.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (418 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m419.0/419.0 kB\u001b[0m \u001b[31m7.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading jaxtyping-0.2.36-py3-none-any.whl (55 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m55.8/55.8 kB\u001b[0m \u001b[31m4.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hInstalling collected packages: jaxtyping, grain\n", - "Successfully installed grain-0.2.2 jaxtyping-0.2.36\n" - ] - } - ], - "source": [ - "!pip install grain" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "66bH3ZDJ7Iat" - }, - "source": [ - "Import Required Libraries (import MNIST dataset from torchvision)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "mS62eVL9Ifmz" - }, - "outputs": [], - "source": [ - "import numpy as np\n", - "import grain.python as pygrain\n", - "from torchvision.datasets import MNIST" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "0h6mwVrspPA-" - }, - "source": [ - "### Define Dataset Class\n", - "\n", - "Create a custom dataset class to load MNIST data for Grain." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "bnrhac5Hh7y1" - }, - "outputs": [], - "source": [ - "class Dataset:\n", - " def __init__(self, data_dir, train=True):\n", - " self.data_dir = data_dir\n", - " self.train = train\n", - " self.load_data()\n", - "\n", - " def load_data(self):\n", - " self.dataset = MNIST(self.data_dir, download=True, train=self.train)\n", - "\n", - " def __len__(self):\n", - " return len(self.dataset)\n", - "\n", - " def __getitem__(self, index):\n", - " img, label = self.dataset[index]\n", - " return np.ravel(np.array(img, dtype=np.float32)), label" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "53mf8bWEsyTr" - }, - "source": [ - "### Initialize the Dataset" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "pN3oF7-ostGE" - }, - "outputs": [], - "source": [ - "mnist_dataset = Dataset(data_dir)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "GqD-ycgBuwv9" - }, - "source": [ - "### Get the full train and test dataset" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "f1VnTuX3u_kL" - }, - "outputs": [], - "source": [ - "# Convert training data to JAX arrays and encode labels as one-hot vectors\n", - "train_images = jnp.array([mnist_dataset[i][0] for i in range(len(mnist_dataset))], dtype=jnp.float32)\n", - "train_labels = one_hot(np.array([mnist_dataset[i][1] for i in range(len(mnist_dataset))]), n_targets)\n", - "\n", - "# Load test dataset and process it\n", - "mnist_dataset_test = MNIST(data_dir, download=True, train=False)\n", - "test_images = jnp.array([np.ravel(np.array(mnist_dataset_test[i][0], dtype=np.float32)) for i in range(len(mnist_dataset_test))], dtype=jnp.float32)\n", - "test_labels = one_hot(np.array([mnist_dataset_test[i][1] for i in range(len(mnist_dataset_test))]), n_targets)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "a2NHlp9klrQL", - "outputId": "14be58c0-851e-4a44-dfcc-d02f0718dab5" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Train: (60000, 784) (60000, 10)\n", - "Test: (10000, 784) (10000, 10)\n" - ] - } - ], - "source": [ - "print(\"Train:\", train_images.shape, train_labels.shape)\n", - "print(\"Test:\", test_images.shape, test_labels.shape)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "fETnWRo2crhf" - }, - "source": [ - "### Initialize PyGrain DataLoader\n", - "\n", - "Set up a PyGrain DataLoader for sequential batch sampling." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "9RuFTcsCs2Ac" - }, - "outputs": [], - "source": [ - "sampler = pygrain.SequentialSampler(\n", - " num_records=len(mnist_dataset),\n", - " shard_options=pygrain.NoSharding()) # Single-device, no sharding\n", - "\n", - "def pygrain_training_generator():\n", - " \"\"\"Grain DataLoader generator for training.\"\"\"\n", - " return pygrain.DataLoader(\n", - " data_source=mnist_dataset,\n", - " sampler=sampler,\n", - " operations=[pygrain.Batch(batch_size)],\n", - " )" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "GvpJPHAbeuHW" - }, - "source": [ - "### Training Loop (Grain)\n", - "\n", - "Run the training loop using the Grain DataLoader." 
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "cjxJRtiTadEI", - "outputId": "3f624366-b683-4d20-9d0a-777d345b0e21" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 1 in 15.39 sec: Train Accuracy: 0.9158, Test Accuracy: 0.9196\n", - "Epoch 2 in 15.27 sec: Train Accuracy: 0.9372, Test Accuracy: 0.9384\n", - "Epoch 3 in 12.61 sec: Train Accuracy: 0.9492, Test Accuracy: 0.9468\n", - "Epoch 4 in 12.62 sec: Train Accuracy: 0.9569, Test Accuracy: 0.9532\n", - "Epoch 5 in 12.39 sec: Train Accuracy: 0.9630, Test Accuracy: 0.9579\n", - "Epoch 6 in 12.19 sec: Train Accuracy: 0.9674, Test Accuracy: 0.9615\n", - "Epoch 7 in 12.56 sec: Train Accuracy: 0.9708, Test Accuracy: 0.9650\n", - "Epoch 8 in 13.04 sec: Train Accuracy: 0.9737, Test Accuracy: 0.9671\n" - ] - } - ], - "source": [ - "train_model(num_epochs, params, pygrain_training_generator)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "oixvOI816qUn" - }, - "source": [ - "## Loading Data with Hugging Face\n", - "\n", - "This section demonstrates loading MNIST data using the Hugging Face `datasets` library. You'll format the dataset for JAX compatibility, prepare flattened images and one-hot-encoded labels, and define a training generator." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "o51P6lr86wz-" - }, - "source": [ - "Install the Hugging Face `datasets` library." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "19ipxPhI6oSN", - "outputId": "684e445f-d23e-4924-9e76-2c2c9359f0be" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Collecting datasets\n", - " Downloading datasets-3.1.0-py3-none-any.whl.metadata (20 kB)\n", - "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from datasets) (3.16.1)\n", - "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (1.26.4)\n", - "Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (17.0.0)\n", - "Collecting dill<0.3.9,>=0.3.0 (from datasets)\n", - " Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)\n", - "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (2.2.2)\n", - "Requirement already satisfied: requests>=2.32.2 in /usr/local/lib/python3.10/dist-packages (from datasets) (2.32.3)\n", - "Requirement already satisfied: tqdm>=4.66.3 in /usr/local/lib/python3.10/dist-packages (from datasets) (4.66.6)\n", - "Collecting xxhash (from datasets)\n", - " Downloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)\n", - "Collecting multiprocess<0.70.17 (from datasets)\n", - " Downloading multiprocess-0.70.16-py310-none-any.whl.metadata (7.2 kB)\n", - "Collecting fsspec<=2024.9.0,>=2023.1.0 (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets)\n", - " Downloading fsspec-2024.9.0-py3-none-any.whl.metadata (11 kB)\n", - "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.11.2)\n", - "Requirement already satisfied: huggingface-hub>=0.23.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.26.2)\n", - "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets) (24.2)\n", - "Requirement 
already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (6.0.2)\n", - "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (2.4.3)\n", - "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n", - "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (24.2.0)\n", - "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.5.0)\n", - "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.1.0)\n", - "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (0.2.0)\n", - "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.17.2)\n", - "Requirement already satisfied: async-timeout<6.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n", - "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.23.0->datasets) (4.12.2)\n", - "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (3.4.0)\n", - "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (3.10)\n", - "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (2.2.3)\n", - "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (2024.8.30)\n", - "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n", - "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.2)\n", - "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.2)\n", - "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n", - "Downloading datasets-3.1.0-py3-none-any.whl (480 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m480.6/480.6 kB\u001b[0m \u001b[31m8.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading dill-0.3.8-py3-none-any.whl (116 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m116.3/116.3 kB\u001b[0m \u001b[31m9.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading fsspec-2024.9.0-py3-none-any.whl (179 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m179.3/179.3 kB\u001b[0m \u001b[31m13.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading multiprocess-0.70.16-py310-none-any.whl (134 kB)\n", - "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m134.8/134.8 kB\u001b[0m \u001b[31m9.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hDownloading xxhash-3.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (194 kB)\n", - "\u001b[2K 
\u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m194.1/194.1 kB\u001b[0m \u001b[31m15.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", - "\u001b[?25hInstalling collected packages: xxhash, fsspec, dill, multiprocess, datasets\n", - " Attempting uninstall: fsspec\n", - " Found existing installation: fsspec 2024.10.0\n", - " Uninstalling fsspec-2024.10.0:\n", - " Successfully uninstalled fsspec-2024.10.0\n", - "\u001b[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n", - "gcsfs 2024.10.0 requires fsspec==2024.10.0, but you have fsspec 2024.9.0 which is incompatible.\u001b[0m\u001b[31m\n", - "\u001b[0mSuccessfully installed datasets-3.1.0 dill-0.3.8 fsspec-2024.9.0 multiprocess-0.70.16 xxhash-3.5.0\n" - ] - } - ], - "source": [ - "!pip install datasets" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "be0h_dZv0593" - }, - "source": [ - "Import Library" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "8v1N59p76zn0" - }, - "outputs": [], - "source": [ - "from datasets import load_dataset" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8Gaj11tO7C86" - }, - "source": [ - "### Load and Format MNIST Dataset\n", - "\n", - "Load the MNIST dataset from Hugging Face and format it as `numpy` arrays for quick access or `jax` to get JAX arrays." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/", - "height": 301, - "referenced_widgets": [ - "32f6132a31aa4c508d3c3c5ef70348bb", - "d7c2ffa6b143463c91cbf8befca6ca01", - "fd964ecd3926419d92927c67f955d5d0", - "60feca3fde7c4447ad8393b0542eb999", - "3354a0baeca94d18bc6b2a8b8b465b58", - "a0d0d052772b46deac7657ad052991a4", - "fb34783b9cba462e9b690e0979c4b07a", - "8d8170c1ed99490589969cd753c40748", - "f1ecb6db00a54e088f1e09164222d637", - "3cf5dd8d29aa4619b39dc2542df7e42e", - "2e5d42ca710441b389895f2d3b611d0a", - "5d8202da24244dc896e9a8cba6a4ed4f", - "a6d64c953631412b8bd8f0ba53ae4d32", - "69240c5cbfbb4e91961f5b49812a26f0", - "865f38532b784a7c971f5d33b87b443e", - "ceb1c004191947cdaa10af9b9c03c80d", - "64c6041037914779b5e8e9cf5a80ad04", - "562fa6a0e7b846a180ac4b423c5511c5", - "b3b922288f9c4df2a4088279ff6d1531", - "75a1a8ffda554318890cf74c345ed9a9", - "3bae06cacf394a5998c2326199da94f5", - "ff6428a3daa5496c81d5e664aba01f97", - "1ba3f86870724f55b94a35cb6b4173af", - "b3e163fd8b8a4f289d5a25611cb66d23", - "abd2daba215e4f7c9ddabde04d6eb382", - "e22ee019049144d5aba573cdf4dbe4fc", - "6ac765dac67841a69218140785f024c6", - "7b057411a54e434fb74804b90daa8d44", - "563f71b3c67d47c3ab1100f5dc1b98f3", - "d81a657361ab4bba8bcc0cf309d2ff64", - "20316312ab88471ba90cbb954be3e964", - "698fda742f834473a23fb7e5e4cf239c", - "289b52c5a38146b8b467a5f4678f6271", - "d07c2f37cf914894b1551a8104e6cb70", - "5b55c73d551d483baaa6a1411c2597b1", - "2308f77723f54ac898588f48d1853b65", - "54d2589714d04b2e928b816258cb0df4", - "f84b795348c04c7a950165301a643671", - "bc853a4a8d3c4dbda23d183f0a3b4f27", - "1012ddc0343842d8b913a7d85df8ab8f", - "771a73a8f5084a57afc5654d72e022f0", - "311a43449f074841b6df4130b0871ac9", - "cd4d29cb01134469b52d6936c35eb943", - "013cf89ee6174d29bb3f4fdff7b36049", - "9237d877d84e4b3ab69698ecf56915bb", - "337ef4d37e6b4ff6bf6e8bd4ca93383f", - "b4096d3837b84ccdb8f1186435c87281", - "7259d3b7e11b4736b4d2aa8e9c55e994", - "1ad1f8e99a864fc4a2bc532d9a4ff110", - "b2b50451eabd40978ef46db5e7dd08c4", - 
"2dad5c5541e243128e23c3dd3e420ac2", - "a3de458b61e5493081d6bb9cf7e923db", - "37760f8a7b164e6f9c1a23d621e9fe6b", - "745a2aedcfab491fb9cffba19958b0c5", - "2f6c670640d048d2af453638cfde3a1e" - ] - }, - "id": "a22kTvgk6_fJ", - "outputId": "35fc38b9-a6ab-4b02-ffa4-ab27fac69df4" - }, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_auth.py:94: UserWarning: \n", - "The secret `HF_TOKEN` does not exist in your Colab secrets.\n", - "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n", - "You will be able to reuse this secret in all of your notebooks.\n", - "Please note that authentication is recommended but still optional to access public models or datasets.\n", - " warnings.warn(\n" - ] - }, - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "32f6132a31aa4c508d3c3c5ef70348bb", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "README.md: 0%| | 0.00/6.97k [00:00 0` enables multi-process data loading, which can accelerate data loading for larger datasets or intensive preprocessing tasks. Experiment with different values to find the optimal setting for your hardware and workload. - -Note: When setting `num_workers > 0`, you may see the following `RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.` This warning can be safely ignored since data loaders do not use JAX within the forked processes. - -```{code-cell} -:id: B-fES82EiL6Z - -def pytorch_training_generator(mnist_dataset): - return NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=0) -``` - -+++ {"id": "Xzt2x9S1HC3T"} - -### Training Loop (PyTorch DataLoader) - -The training loop uses the PyTorch DataLoader to iterate through batches and update model parameters. - -```{code-cell} ---- -colab: - base_uri: https://localhost:8080/ -id: vtUjHsh-rJs8 -outputId: 4766333e-4366-493b-995a-102778d1345a ---- -train_model(num_epochs, params, pytorch_training_generator(mnist_dataset), data_loader_type='iterable') -``` - -+++ {"id": "Nm45ZTo6yrf5"} - -## Loading Data with TensorFlow Datasets (TFDS) - -This section demonstrates how to load the MNIST dataset using TFDS, fetch the full dataset for evaluation, and define a training generator for batch processing. GPU usage is explicitly disabled for TensorFlow. - -```{code-cell} -:id: sGaQAk1DHMUx - -import tensorflow_datasets as tfds -import tensorflow as tf - -# Ensuring CPU-Only Execution, disable any GPU usage(if applicable) for TF -tf.config.set_visible_devices([], device_type='GPU') -``` - -+++ {"id": "3xdQY7H6wr3n"} - -### Fetch Full Dataset for Evaluation - -Load the dataset with `tfds.load`, convert it to NumPy arrays, and process it for evaluation. 
- -```{code-cell} ---- -colab: - base_uri: https://localhost:8080/ - height: 104 - referenced_widgets: [b8cdabf5c05848f38f03850cab08b56f, a8b76d5f93004c089676e5a2a9b3336c, - 119ac8428f9441e7a25eb0afef2fbb2a, 76a9815e5c2b4764a13409cebaf66821, 45ce8dd5c4b949afa957ec8ffb926060, - 05b7145fd62d4581b2123c7680f11cdd, b96267f014814ec5b96ad7e6165104b1, bce34bdbfbd64f1f8353a4e8515cee0b, - 93b8206f8c5841a692cdce985ae301d8, c95f592620c64da595cc787567b2c4db, 8a97071f862c4ec3b4b4140d2e34eda2] -id: 1hOamw_7C8Pb -outputId: ca166490-22db-4732-b29f-866b7593e489 ---- -# tfds.load returns tf.Tensors (or tf.data.Datasets if batch_size != -1) -mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True) -mnist_data = tfds.as_numpy(mnist_data) -train_data, test_data = mnist_data['train'], mnist_data['test'] - -# Full train set -train_images, train_labels = train_data['image'], train_data['label'] -train_images = jnp.reshape(train_images, (len(train_images), num_pixels)) -train_labels = one_hot(train_labels, n_targets) - -# Full test set -test_images, test_labels = test_data['image'], test_data['label'] -test_images = jnp.reshape(test_images, (len(test_images), num_pixels)) -test_labels = one_hot(test_labels, n_targets) -``` - -```{code-cell} ---- -colab: - base_uri: https://localhost:8080/ -id: Td3PiLdmEf7z -outputId: 96403b0f-6079-43ce-df16-d4583f09906b ---- -print('Train:', train_images.shape, train_labels.shape) -print('Test:', test_images.shape, test_labels.shape) -``` - -+++ {"id": "UWRSaalfdyDX"} - -### Define the Training Generator - -Create a generator function to yield batches of data for training. - -```{code-cell} -:id: vX59u8CqEf4J - -def training_generator(): - # as_supervised=True gives us the (image, label) as a tuple instead of a dict - ds = tfds.load(name='mnist', split='train', as_supervised=True, data_dir=data_dir) - # You can build up an arbitrary tf.data input pipeline - ds = ds.batch(batch_size).prefetch(1) - # tfds.dataset_as_numpy converts the tf.data.Dataset into an iterable of NumPy arrays - return tfds.as_numpy(ds) -``` - -+++ {"id": "EAWeUdnuFNBY"} - -### Training Loop (TFDS) - -Use the training generator in a custom training loop. - -```{code-cell} ---- -colab: - base_uri: https://localhost:8080/ -id: h2sO13XDGvq1 -outputId: a150246e-ceb5-46ac-db71-2a8177a9d04d ---- -train_model(num_epochs, params, training_generator) -``` - -+++ {"id": "-ryVkrAITS9Z"} - -## Loading Data with Grain - -This section demonstrates how to load MNIST data using Grain, a data-loading library. You'll define a custom dataset class for Grain and set up a Grain DataLoader for efficient training. - -+++ {"id": "waYhUMUGmhH-"} - -Install Grain - -```{code-cell} ---- -colab: - base_uri: https://localhost:8080/ -id: L78o7eeyGvn5 -outputId: 76d16565-0d9e-4f5f-c6b1-4cf4a683d0e7 ---- -!pip install grain -``` - -+++ {"id": "66bH3ZDJ7Iat"} - -Import Required Libraries (import MNIST dataset from torchvision) - -```{code-cell} -:id: mS62eVL9Ifmz - -import numpy as np -import grain.python as pygrain -from torchvision.datasets import MNIST -``` - -+++ {"id": "0h6mwVrspPA-"} - -### Define Dataset Class - -Create a custom dataset class to load MNIST data for Grain. 
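Grain does not require a special base class here: any random-access object exposing `__len__` and `__getitem__` can serve as a `DataLoader` data source. A minimal sketch with synthetic data, using only the Grain APIs that appear later in this section (`ToyDataSource` is our own illustrative name; the MNIST-backed class in the next cell follows the same protocol):

```python
import numpy as np
import grain.python as pygrain

class ToyDataSource:
    """Any object with __len__ and __getitem__ works as a Grain data source."""
    def __init__(self, n):
        self.data = np.arange(n, dtype=np.float32)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        return self.data[index]

toy_loader = pygrain.DataLoader(
    data_source=ToyDataSource(10),
    sampler=pygrain.SequentialSampler(num_records=10, shard_options=pygrain.NoSharding()),
    operations=[pygrain.Batch(batch_size=5)],
)
for batch in toy_loader:
    print(batch.shape)  # (5,) twice: individual records are stacked into batches
```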
- -```{code-cell} -:id: bnrhac5Hh7y1 - -class Dataset: - def __init__(self, data_dir, train=True): - self.data_dir = data_dir - self.train = train - self.load_data() - - def load_data(self): - self.dataset = MNIST(self.data_dir, download=True, train=self.train) - - def __len__(self): - return len(self.dataset) - - def __getitem__(self, index): - img, label = self.dataset[index] - return np.ravel(np.array(img, dtype=np.float32)), label -``` - -+++ {"id": "53mf8bWEsyTr"} - -### Initialize the Dataset - -```{code-cell} -:id: pN3oF7-ostGE - -mnist_dataset = Dataset(data_dir) -``` - -+++ {"id": "GqD-ycgBuwv9"} - -### Get the full train and test dataset - -```{code-cell} -:id: f1VnTuX3u_kL - -# Convert training data to JAX arrays and encode labels as one-hot vectors -train_images = jnp.array([mnist_dataset[i][0] for i in range(len(mnist_dataset))], dtype=jnp.float32) -train_labels = one_hot(np.array([mnist_dataset[i][1] for i in range(len(mnist_dataset))]), n_targets) - -# Load test dataset and process it -mnist_dataset_test = MNIST(data_dir, download=True, train=False) -test_images = jnp.array([np.ravel(np.array(mnist_dataset_test[i][0], dtype=np.float32)) for i in range(len(mnist_dataset_test))], dtype=jnp.float32) -test_labels = one_hot(np.array([mnist_dataset_test[i][1] for i in range(len(mnist_dataset_test))]), n_targets) -``` - -```{code-cell} ---- -colab: - base_uri: https://localhost:8080/ -id: a2NHlp9klrQL -outputId: 14be58c0-851e-4a44-dfcc-d02f0718dab5 ---- -print("Train:", train_images.shape, train_labels.shape) -print("Test:", test_images.shape, test_labels.shape) -``` - -+++ {"id": "fETnWRo2crhf"} - -### Initialize PyGrain DataLoader - -Set up a PyGrain DataLoader for sequential batch sampling. - -```{code-cell} -:id: 9RuFTcsCs2Ac - -sampler = pygrain.SequentialSampler( - num_records=len(mnist_dataset), - shard_options=pygrain.NoSharding()) # Single-device, no sharding - -def pygrain_training_generator(): - """Grain DataLoader generator for training.""" - return pygrain.DataLoader( - data_source=mnist_dataset, - sampler=sampler, - operations=[pygrain.Batch(batch_size)], - ) -``` - -+++ {"id": "GvpJPHAbeuHW"} - -### Training Loop (Grain) - -Run the training loop using the Grain DataLoader. - -```{code-cell} ---- -colab: - base_uri: https://localhost:8080/ -id: cjxJRtiTadEI -outputId: 3f624366-b683-4d20-9d0a-777d345b0e21 ---- -train_model(num_epochs, params, pygrain_training_generator) -``` - -+++ {"id": "oixvOI816qUn"} - -## Loading Data with Hugging Face - -This section demonstrates loading MNIST data using the Hugging Face `datasets` library. You'll format the dataset for JAX compatibility, prepare flattened images and one-hot-encoded labels, and define a training generator. - -+++ {"id": "o51P6lr86wz-"} - -Install the Hugging Face `datasets` library. - -```{code-cell} ---- -colab: - base_uri: https://localhost:8080/ -id: 19ipxPhI6oSN -outputId: 684e445f-d23e-4924-9e76-2c2c9359f0be ---- -!pip install datasets -``` - -+++ {"id": "be0h_dZv0593"} - -Import Library - -```{code-cell} -:id: 8v1N59p76zn0 - -from datasets import load_dataset -``` - -+++ {"id": "8Gaj11tO7C86"} - -### Load and Format MNIST Dataset - -Load the MNIST dataset from Hugging Face and format it as `numpy` arrays for quick access or `jax` to get JAX arrays. 
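As noted above, `with_format` controls the array type you get back when indexing the dataset. A hedged micro-example of the `jax` variant (the exact printed type can vary across `datasets` and JAX versions):

```python
from datasets import load_dataset

# Ask the datasets library for JAX arrays instead of NumPy arrays.
mnist_jax = load_dataset("mnist").with_format("jax")
sample = mnist_jax["train"][0]
print(type(sample["image"]))  # a jax.Array on recent versions
```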
- -```{code-cell} ---- -colab: - base_uri: https://localhost:8080/ - height: 301 - referenced_widgets: [32f6132a31aa4c508d3c3c5ef70348bb, d7c2ffa6b143463c91cbf8befca6ca01, - fd964ecd3926419d92927c67f955d5d0, 60feca3fde7c4447ad8393b0542eb999, 3354a0baeca94d18bc6b2a8b8b465b58, - a0d0d052772b46deac7657ad052991a4, fb34783b9cba462e9b690e0979c4b07a, 8d8170c1ed99490589969cd753c40748, - f1ecb6db00a54e088f1e09164222d637, 3cf5dd8d29aa4619b39dc2542df7e42e, 2e5d42ca710441b389895f2d3b611d0a, - 5d8202da24244dc896e9a8cba6a4ed4f, a6d64c953631412b8bd8f0ba53ae4d32, 69240c5cbfbb4e91961f5b49812a26f0, - 865f38532b784a7c971f5d33b87b443e, ceb1c004191947cdaa10af9b9c03c80d, 64c6041037914779b5e8e9cf5a80ad04, - 562fa6a0e7b846a180ac4b423c5511c5, b3b922288f9c4df2a4088279ff6d1531, 75a1a8ffda554318890cf74c345ed9a9, - 3bae06cacf394a5998c2326199da94f5, ff6428a3daa5496c81d5e664aba01f97, 1ba3f86870724f55b94a35cb6b4173af, - b3e163fd8b8a4f289d5a25611cb66d23, abd2daba215e4f7c9ddabde04d6eb382, e22ee019049144d5aba573cdf4dbe4fc, - 6ac765dac67841a69218140785f024c6, 7b057411a54e434fb74804b90daa8d44, 563f71b3c67d47c3ab1100f5dc1b98f3, - d81a657361ab4bba8bcc0cf309d2ff64, 20316312ab88471ba90cbb954be3e964, 698fda742f834473a23fb7e5e4cf239c, - 289b52c5a38146b8b467a5f4678f6271, d07c2f37cf914894b1551a8104e6cb70, 5b55c73d551d483baaa6a1411c2597b1, - 2308f77723f54ac898588f48d1853b65, 54d2589714d04b2e928b816258cb0df4, f84b795348c04c7a950165301a643671, - bc853a4a8d3c4dbda23d183f0a3b4f27, 1012ddc0343842d8b913a7d85df8ab8f, 771a73a8f5084a57afc5654d72e022f0, - 311a43449f074841b6df4130b0871ac9, cd4d29cb01134469b52d6936c35eb943, 013cf89ee6174d29bb3f4fdff7b36049, - 9237d877d84e4b3ab69698ecf56915bb, 337ef4d37e6b4ff6bf6e8bd4ca93383f, b4096d3837b84ccdb8f1186435c87281, - 7259d3b7e11b4736b4d2aa8e9c55e994, 1ad1f8e99a864fc4a2bc532d9a4ff110, b2b50451eabd40978ef46db5e7dd08c4, - 2dad5c5541e243128e23c3dd3e420ac2, a3de458b61e5493081d6bb9cf7e923db, 37760f8a7b164e6f9c1a23d621e9fe6b, - 745a2aedcfab491fb9cffba19958b0c5, 2f6c670640d048d2af453638cfde3a1e] -id: a22kTvgk6_fJ -outputId: 35fc38b9-a6ab-4b02-ffa4-ab27fac69df4 ---- -mnist_dataset = load_dataset("mnist").with_format("numpy") -``` - -+++ {"id": "IFjTyGxY19b0"} - -### Extract images and labels - -Get image shape and flatten for model input - -```{code-cell} ---- -colab: - base_uri: https://localhost:8080/ -id: NHrKatD_7HbH -outputId: deec1739-2fc0-4e71-8567-f2e0c9db198b ---- -train_images = mnist_dataset["train"]["image"] -train_labels = mnist_dataset["train"]["label"] -test_images = mnist_dataset["test"]["image"] -test_labels = mnist_dataset["test"]["label"] - -# Flatten images and one-hot encode labels -image_shape = train_images.shape[1:] -num_features = image_shape[0] * image_shape[1] - -train_images = train_images.reshape(-1, num_features) -test_images = test_images.reshape(-1, num_features) - -train_labels = one_hot(train_labels, n_targets) -test_labels = one_hot(test_labels, n_targets) - -print('Train:', train_images.shape, train_labels.shape) -print('Test:', test_images.shape, test_labels.shape) -``` - -+++ {"id": "kk_4zJlz7T1E"} - -### Define Training Generator - -Set up a generator to yield batches of images and labels for training. - -```{code-cell} -:id: -zLJhogj7RL- - -def hf_training_generator(): - """Yield batches for training.""" - for batch in mnist_dataset["train"].iter(batch_size): - x, y = batch["image"], batch["label"] - yield x, y -``` - -+++ {"id": "HIsGfkLI7dvZ"} - -### Training Loop (Hugging Face Datasets) - -Run the training loop using the Hugging Face training generator. 
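One thing `hf_training_generator` above does not do is shuffle: batches arrive in the dataset's stored order on every epoch. If you want a fresh ordering per epoch before running the loop below, the `datasets` library provides `shuffle`. A sketch under that assumption (`hf_shuffled_generator` is our own name, not part of the original tutorial; it reuses `mnist_dataset` and `batch_size` from earlier cells):

```python
def hf_shuffled_generator(epoch_seed):
    """Like hf_training_generator, but reshuffles the split each epoch."""
    shuffled = mnist_dataset["train"].shuffle(seed=epoch_seed)
    for batch in shuffled.iter(batch_size):
        yield batch["image"], batch["label"]
```

Note that `shuffle` builds an index mapping over the underlying Arrow table, so the first pass over a shuffled split can be somewhat slower than over the stored order.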
- -```{code-cell} ---- -colab: - base_uri: https://localhost:8080/ -id: RhloYGsw6nPf -outputId: d49c1cd2-a546-46a6-84fb-d9507c38f4ca ---- -train_model(num_epochs, params, hf_training_generator) -``` - -+++ {"id": "qXylIOwidWI3"} - -## Summary - -This notebook has introduced efficient strategies for data loading on a CPU with JAX, demonstrating how to integrate popular libraries like PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets. Each library offers distinct advantages, enabling you to streamline the data loading process for machine learning tasks. By understanding the strengths of these methods, you can select the approach that best suits your project's specific requirements. diff --git a/docs/data_loaders_on_gpu_with_jax.ipynb b/docs/data_loaders_on_gpu_with_jax.ipynb deleted file mode 100644 index 40c8ddc..0000000 --- a/docs/data_loaders_on_gpu_with_jax.ipynb +++ /dev/null @@ -1,1176 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "metadata": { - "id": "PUFGZggH49zp" - }, - "source": [ - "# Introduction to Data Loaders on GPU with JAX" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "3ia4PKEV5Dr8" - }, - "source": [ - "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/data_loaders_on_gpu_with_jax.ipynb)\n", - "\n", - "This tutorial explores different data loading strategies for using **JAX** on a single [**GPU**](https://jax.readthedocs.io/en/latest/glossary.html#term-GPU). While JAX doesn't include a built-in data loader, it seamlessly integrates with popular data loading libraries, including:\n", - "* [**PyTorch DataLoader**](https://github.com/pytorch/data)\n", - "* [**TensorFlow Datasets (TFDS)**](https://github.com/tensorflow/datasets)\n", - "* [**Grain**](https://github.com/google/grain)\n", - "* [**Hugging Face**](https://huggingface.co/docs/datasets/en/use_with_jax#data-loading)\n", - "\n", - "You'll see how to use each of these libraries to efficiently load data for a simple image classification task using the MNIST dataset.\n", - "\n", - "Compared to [CPU-based loading](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_cpu_with_jax.html), working with a GPU introduces specific considerations like transferring data to the GPU using `device_put`, managing larger batch sizes for faster processing, and efficiently utilizing GPU memory. Unlike multi-device setups, this guide focuses on optimizing data handling for a single GPU.\n", - "\n", - "\n", - "If you're looking for CPU-specific data loading advice, see [Data Loaders on CPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_cpu_with_jax.html).\n", - "\n", - "If you're looking for a multi-device data loading strategy, see [Data Loaders on Multi-Device Setups](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_for_multi_device_setups_with_jax.html)." 
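Since the introduction above highlights transferring data to the GPU with `device_put`, here is a minimal, hedged illustration of an explicit host-to-device transfer, assuming a CUDA device is visible (an editorial aside, not a cell from the original notebook):

```python
import jax
import jax.numpy as jnp

x = jnp.ones((128, 784))        # lives on the default device
gpu = jax.devices("gpu")[0]     # raises if no GPU backend is available
x_gpu = jax.device_put(x, gpu)  # explicit transfer to that device
print(x_gpu.devices())          # e.g. {CudaDevice(id=0)}; exact repr varies by JAX version
```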
- ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-rsMgVtO6asW" - }, - "source": [ - "## Import JAX API" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "tDJNQ6V-Dg5g" - }, - "outputs": [], - "source": [ - "import jax\n", - "import jax.numpy as jnp\n", - "from jax import grad, jit, vmap, random, device_put" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "TsFdlkSZKp9S" - }, - "source": [ - "## Checking GPU Availability for JAX" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "N3sqvaF3KJw1", - "outputId": "ab40f542-b8c0-422c-ca68-4ce292817889" - }, - "outputs": [ - { - "data": { - "text/plain": [ - "[CudaDevice(id=0)]" - ] - }, - "execution_count": 36, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "jax.devices()" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "qyJ_WTghDnIc" - }, - "source": [ - "## Setting Hyperparameters and Initializing Parameters\n", - "\n", - "You'll define hyperparameters for your model and data loading, including layer sizes, learning rate, batch size, and the data directory. You'll also initialize the weights and biases for a fully-connected neural network." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "qLNOSloFDka_" - }, - "outputs": [], - "source": [ - "# A helper function to randomly initialize weights and biases\n", - "# for a dense neural network layer\n", - "def random_layer_params(m, n, key, scale=1e-2):\n", - " w_key, b_key = random.split(key)\n", - " return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))\n", - "\n", - "# Function to initialize network parameters for all layers based on defined sizes\n", - "def init_network_params(sizes, key):\n", - " keys = random.split(key, len(sizes))\n", - " return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]\n", - "\n", - "layer_sizes = [784, 512, 512, 10] # Layers of the network\n", - "step_size = 0.01 # Learning rate\n", - "num_epochs = 8 # Number of training epochs\n", - "batch_size = 128 # Batch size for training\n", - "n_targets = 10 # Number of classes (digits 0-9)\n", - "num_pixels = 28 * 28 # Each MNIST image is 28x28 pixels\n", - "data_dir = '/tmp/mnist_dataset' # Directory for storing the dataset\n", - "\n", - "# Initialize network parameters using the defined layer sizes and a random seed\n", - "params = init_network_params(layer_sizes, random.PRNGKey(0))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "rHLdqeI7D2WZ" - }, - "source": [ - "## Model Prediction with Auto-Batching\n", - "\n", - "In this section, you'll define the `predict` function for your neural network. This function computes the output of the network for a single input image.\n", - "\n", - "To efficiently process multiple images simultaneously, you'll use [`vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax.vmap), which allows you to vectorize the `predict` function and apply it across a batch of inputs. This technique, called auto-batching, improves computational efficiency by leveraging hardware acceleration." 
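As a quick illustration of the `in_axes=(None, 0)` pattern that `batched_predict` uses in the next cell: the first argument (the parameters) is shared across the batch, while the second is mapped over its leading axis. A self-contained sketch:

```python
import jax.numpy as jnp
from jax import vmap

def affine(w, x):
    # Single-example computation: (3, 4) @ (4,) -> (3,)
    return jnp.dot(w, x)

w = jnp.ones((3, 4))
batch_x = jnp.ones((8, 4))

batched_affine = vmap(affine, in_axes=(None, 0))  # share w, map over batch_x
print(batched_affine(w, batch_x).shape)           # (8, 3)
```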
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "bKIYPSkvD1QV" - }, - "outputs": [], - "source": [ - "from jax.scipy.special import logsumexp\n", - "\n", - "def relu(x):\n", - " return jnp.maximum(0, x)\n", - "\n", - "def predict(params, image):\n", - " # per-example predictions\n", - " activations = image\n", - " for w, b in params[:-1]:\n", - " outputs = jnp.dot(w, activations) + b\n", - " activations = relu(outputs)\n", - "\n", - " final_w, final_b = params[-1]\n", - " logits = jnp.dot(final_w, activations) + final_b\n", - " return logits - logsumexp(logits)\n", - "\n", - "# Make a batched version of the `predict` function\n", - "batched_predict = vmap(predict, in_axes=(None, 0))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "rLqfeORsERek" - }, - "source": [ - "## Utility and Loss Functions\n", - "\n", - "You'll now define utility functions for:\n", - "- One-hot encoding: Converts class indices to binary vectors.\n", - "- Accuracy calculation: Measures the performance of the model on the dataset.\n", - "- Loss computation: Calculates the difference between predictions and targets.\n", - "\n", - "To optimize performance:\n", - "- [`grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html#jax.grad) is used to compute gradients of the loss function with respect to network parameters.\n", - "- [`jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit) compiles the update function, enabling faster execution by leveraging JAX's [XLA](https://openxla.org/xla) compilation.\n", - "\n", - "- [`device_put`](https://jax.readthedocs.io/en/latest/_autosummary/jax.device_put.html) to transfer the dataset to the GPU." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "sA0a06raEQfS" - }, - "outputs": [], - "source": [ - "import time\n", - "\n", - "def one_hot(x, k, dtype=jnp.float32):\n", - " \"\"\"Create a one-hot encoding of x of size k.\"\"\"\n", - " return jnp.array(x[:, None] == jnp.arange(k), dtype)\n", - "\n", - "def accuracy(params, images, targets):\n", - " \"\"\"Calculate the accuracy of predictions.\"\"\"\n", - " target_class = jnp.argmax(targets, axis=1)\n", - " predicted_class = jnp.argmax(batched_predict(params, images), axis=1)\n", - " return jnp.mean(predicted_class == target_class)\n", - "\n", - "def loss(params, images, targets):\n", - " \"\"\"Calculate the loss between predictions and targets.\"\"\"\n", - " preds = batched_predict(params, images)\n", - " return -jnp.mean(preds * targets)\n", - "\n", - "@jit\n", - "def update(params, x, y):\n", - " \"\"\"Update the network parameters using gradient descent.\"\"\"\n", - " grads = grad(loss)(params, x, y)\n", - " return [(w - step_size * dw, b - step_size * db)\n", - " for (w, b), (dw, db) in zip(params, grads)]\n", - "\n", - "def reshape_and_one_hot(x, y):\n", - " \"\"\"Reshape and one-hot encode the inputs.\"\"\"\n", - " x = jnp.reshape(x, (len(x), num_pixels))\n", - " y = one_hot(y, n_targets)\n", - " return x, y\n", - "\n", - "def train_model(num_epochs, params, training_generator, data_loader_type='streamed'):\n", - " \"\"\"Train the model for a given number of epochs and device_put for GPU transfer.\"\"\"\n", - " for epoch in range(num_epochs):\n", - " start_time = time.time()\n", - " for x, y in training_generator() if data_loader_type == 'streamed' else training_generator:\n", - " x, y = reshape_and_one_hot(x, y)\n", - " x, y = device_put(x), device_put(y)\n", - " params = update(params, x, y)\n", - 
"\n", - " print(f\"Epoch {epoch + 1} in {time.time() - start_time:.2f} sec: \"\n", - " f\"Train Accuracy: {accuracy(params, train_images, train_labels):.4f}, \"\n", - " f\"Test Accuracy: {accuracy(params, test_images, test_labels):.4f}\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Hsionp5IYsQ9" - }, - "source": [ - "## Loading Data with PyTorch DataLoader\n", - "\n", - "This section shows how to load the MNIST dataset using PyTorch's DataLoader, convert the data to NumPy arrays, and apply transformations to flatten and cast images." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "uA7XY0OezHse", - "outputId": "4c86f455-ff1d-474e-f8e3-7111d9b56996" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Requirement already satisfied: torch in /usr/local/lib/python3.10/dist-packages (2.5.1+cu121)\n", - "Requirement already satisfied: torchvision in /usr/local/lib/python3.10/dist-packages (0.20.1+cu121)\n", - "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from torch) (3.16.1)\n", - "Requirement already satisfied: typing-extensions>=4.8.0 in /usr/local/lib/python3.10/dist-packages (from torch) (4.12.2)\n", - "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch) (3.4.2)\n", - "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch) (3.1.4)\n", - "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from torch) (2024.9.0)\n", - "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.10/dist-packages (from torch) (1.13.1)\n", - "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.10/dist-packages (from sympy==1.13.1->torch) (1.3.0)\n", - "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from torchvision) (1.26.4)\n", - "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.10/dist-packages (from torchvision) (11.0.0)\n", - "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch) (3.0.2)\n" - ] - } - ], - "source": [ - "!pip install torch torchvision" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "kO5_WzwY59gE" - }, - "outputs": [], - "source": [ - "import numpy as np\n", - "from jax.tree_util import tree_map\n", - "from torch.utils import data\n", - "from torchvision.datasets import MNIST" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "6f6qU8PCc143" - }, - "outputs": [], - "source": [ - "def numpy_collate(batch):\n", - " \"\"\"Collate function to convert a batch of PyTorch data into NumPy arrays.\"\"\"\n", - " return tree_map(np.asarray, data.default_collate(batch))\n", - "\n", - "class NumpyLoader(data.DataLoader):\n", - " \"\"\"Custom DataLoader to return NumPy arrays from a PyTorch Dataset.\"\"\"\n", - " def __init__(self, dataset, batch_size=1,\n", - " shuffle=False, sampler=None,\n", - " batch_sampler=None, num_workers=0,\n", - " pin_memory=False, drop_last=False,\n", - " timeout=0, worker_init_fn=None):\n", - " super(self.__class__, self).__init__(dataset,\n", - " batch_size=batch_size,\n", - " shuffle=shuffle,\n", - " sampler=sampler,\n", - " batch_sampler=batch_sampler,\n", - " num_workers=num_workers,\n", - " collate_fn=numpy_collate,\n", - " 
pin_memory=pin_memory,\n", - " drop_last=drop_last,\n", - " timeout=timeout,\n", - " worker_init_fn=worker_init_fn)\n", - "class FlattenAndCast(object):\n", - " \"\"\"Transform class to flatten and cast images to float32.\"\"\"\n", - " def __call__(self, pic):\n", - " return np.ravel(np.array(pic, dtype=jnp.float32))" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "mfSnfJND6I8G" - }, - "source": [ - "### Load Dataset with Transformations\n", - "\n", - "Standardize the data by flattening the images, casting them to `float32`, and ensuring consistent data types." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "Kxbl6bcx6crv" - }, - "outputs": [], - "source": [ - "mnist_dataset = MNIST(data_dir, download=True, transform=FlattenAndCast())" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "kbdsqvPZGrsa" - }, - "source": [ - "### Full Training Dataset for Accuracy Checks\n", - "\n", - "Convert the entire training dataset to JAX arrays." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "c9ZCJq_rzPck" - }, - "outputs": [], - "source": [ - "train_images = np.array(mnist_dataset.data).reshape(len(mnist_dataset.data), -1)\n", - "train_labels = one_hot(np.array(mnist_dataset.targets), n_targets)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "WXUh0BwvG8Ko" - }, - "source": [ - "### Get Full Test Dataset\n", - "\n", - "Load and process the full test dataset." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "brlLG4SqGphm" - }, - "outputs": [], - "source": [ - "mnist_dataset_test = MNIST(data_dir, download=True, train=False)\n", - "test_images = jnp.array(mnist_dataset_test.data.numpy().reshape(len(mnist_dataset_test.data), -1), dtype=jnp.float32)\n", - "test_labels = one_hot(np.array(mnist_dataset_test.targets), n_targets)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Oz-UVnCxG5E8", - "outputId": "53f3fb32-5096-4862-e022-3c3a1d82137a" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Train: (60000, 784) (60000, 10)\n", - "Test: (10000, 784) (10000, 10)\n" - ] - } - ], - "source": [ - "print('Train:', train_images.shape, train_labels.shape)\n", - "print('Test:', test_images.shape, test_labels.shape)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "mNjn9dMPitKL" - }, - "source": [ - "### Training Data Generator\n", - "\n", - "Define a generator function using PyTorch's DataLoader for batch training.\n", - "Setting `num_workers > 0` enables multi-process data loading, which can accelerate data loading for larger datasets or intensive preprocessing tasks. Experiment with different values to find the optimal setting for your hardware and workload.\n", - "\n", - "Note: When setting `num_workers > 0`, you may see the following `RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.`\n", - "This warning can be safely ignored since data loaders do not use JAX within the forked processes." 
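If you would rather avoid the `os.fork()` warning than ignore it, PyTorch's `DataLoader` accepts a `multiprocessing_context` argument, so workers can be started with `spawn` instead of `fork`. A hedged sketch reusing `numpy_collate`, `mnist_dataset`, and `batch_size` from the cells above (spawn workers re-import the main module, so startup is slower):

```python
from torch.utils import data

spawn_loader = data.DataLoader(
    mnist_dataset,
    batch_size=batch_size,
    num_workers=2,
    collate_fn=numpy_collate,         # keep batches as NumPy arrays
    multiprocessing_context="spawn",  # no fork, so no JAX deadlock warning
)
```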
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "0LdT8P8aisWF" - }, - "outputs": [], - "source": [ - "def pytorch_training_generator(mnist_dataset):\n", - " return NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=0)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Xzt2x9S1HC3T" - }, - "source": [ - "### Training Loop (PyTorch DataLoader)\n", - "\n", - "The training loop uses the PyTorch DataLoader to iterate through batches and update model parameters." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "SqweRz_98sN8", - "outputId": "bdd45256-3f5a-48f7-e45c-378078ac4279" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 1 in 20.23 sec: Train Accuracy: 0.9158, Test Accuracy: 0.9195\n", - "Epoch 2 in 14.64 sec: Train Accuracy: 0.9372, Test Accuracy: 0.9385\n", - "Epoch 3 in 3.91 sec: Train Accuracy: 0.9492, Test Accuracy: 0.9467\n", - "Epoch 4 in 3.85 sec: Train Accuracy: 0.9569, Test Accuracy: 0.9532\n", - "Epoch 5 in 4.48 sec: Train Accuracy: 0.9631, Test Accuracy: 0.9577\n", - "Epoch 6 in 4.03 sec: Train Accuracy: 0.9675, Test Accuracy: 0.9617\n", - "Epoch 7 in 3.86 sec: Train Accuracy: 0.9708, Test Accuracy: 0.9652\n", - "Epoch 8 in 4.57 sec: Train Accuracy: 0.9736, Test Accuracy: 0.9671\n" - ] - } - ], - "source": [ - "train_model(num_epochs, params, pytorch_training_generator(mnist_dataset), data_loader_type='iterable')" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "Nm45ZTo6yrf5" - }, - "source": [ - "## Loading Data with TensorFlow Datasets (TFDS)\n", - "\n", - "This section demonstrates how to load the MNIST dataset using TFDS, fetch the full dataset for evaluation, and define a training generator for batch processing. GPU usage is explicitly disabled for TensorFlow." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "sGaQAk1DHMUx" - }, - "outputs": [], - "source": [ - "import tensorflow_datasets as tfds" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "ZSc5K0Eiwm4L" - }, - "source": [ - "### Fetch Full Dataset for Evaluation\n", - "\n", - "Load the dataset with `tfds.load`, convert it to NumPy arrays, and process it for evaluation." 
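One caveat before fetching: the section text says GPU usage is disabled for TensorFlow, but the import cell above only imports `tensorflow_datasets`. If your TensorFlow build has CUDA support, it may otherwise reserve GPU memory that JAX needs. The CPU version of this guide hides the GPUs explicitly, and the same step likely belongs here:

```python
import tensorflow as tf

# Hide all GPUs from TensorFlow so the tf.data pipeline stays on the host
# and doesn't compete with JAX for GPU memory.
tf.config.set_visible_devices([], device_type='GPU')
```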
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "1hOamw_7C8Pb" - }, - "outputs": [], - "source": [ - "# tfds.load returns tf.Tensors (or tf.data.Datasets if batch_size != -1)\n", - "mnist_data, info = tfds.load(name=\"mnist\", batch_size=-1, data_dir=data_dir, with_info=True)\n", - "mnist_data = tfds.as_numpy(mnist_data)\n", - "train_data, test_data = mnist_data['train'], mnist_data['test']\n", - "\n", - "# Full train set\n", - "train_images, train_labels = train_data['image'], train_data['label']\n", - "train_images = jnp.reshape(train_images, (len(train_images), num_pixels))\n", - "train_labels = one_hot(train_labels, n_targets)\n", - "\n", - "# Full test set\n", - "test_images, test_labels = test_data['image'], test_data['label']\n", - "test_images = jnp.reshape(test_images, (len(test_images), num_pixels))\n", - "test_labels = one_hot(test_labels, n_targets)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Td3PiLdmEf7z", - "outputId": "b8c9a32a-9cf0-4dc3-cb51-db21d32c6545" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Train: (60000, 784) (60000, 10)\n", - "Test: (10000, 784) (10000, 10)\n" - ] - } - ], - "source": [ - "print('Train:', train_images.shape, train_labels.shape)\n", - "print('Test:', test_images.shape, test_labels.shape)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "dXMvgk6sdq4j" - }, - "source": [ - "### Define the Training Generator\n", - "\n", - "Create a generator function to yield batches of data for training." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "vX59u8CqEf4J" - }, - "outputs": [], - "source": [ - "def training_generator():\n", - " # as_supervised=True gives us the (image, label) as a tuple instead of a dict\n", - " ds = tfds.load(name='mnist', split='train', as_supervised=True, data_dir=data_dir)\n", - " # You can build up an arbitrary tf.data input pipeline\n", - " ds = ds.batch(batch_size).prefetch(1)\n", - " # tfds.dataset_as_numpy converts the tf.data.Dataset into an iterable of NumPy arrays\n", - " return tfds.as_numpy(ds)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "EAWeUdnuFNBY" - }, - "source": [ - "### Training Loop (TFDS)\n", - "\n", - "Use the training generator in a custom training loop." 
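A small correction to the comment in `training_generator` above: the function actually used (and the current API name) is `tfds.as_numpy`, not `tfds.dataset_as_numpy`. Separately, that pipeline batches in stored order; if you want per-epoch shuffling and autotuned prefetching, a hedged variant might look like this (`shuffled_training_generator` is our own name, not part of the original notebook):

```python
import tensorflow as tf
import tensorflow_datasets as tfds

def shuffled_training_generator():
    ds = tfds.load(name='mnist', split='train', as_supervised=True, data_dir=data_dir)
    ds = ds.shuffle(10_000)  # reshuffles on each pass by default
    ds = ds.batch(batch_size).prefetch(tf.data.AUTOTUNE)
    return tfds.as_numpy(ds)
```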
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "h2sO13XDGvq1", - "outputId": "f30805bb-e725-46ee-e053-6e97f2af81c5" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 1 in 20.86 sec: Train Accuracy: 0.9253, Test Accuracy: 0.9268\n", - "Epoch 2 in 8.56 sec: Train Accuracy: 0.9428, Test Accuracy: 0.9413\n", - "Epoch 3 in 5.40 sec: Train Accuracy: 0.9532, Test Accuracy: 0.9511\n", - "Epoch 4 in 3.86 sec: Train Accuracy: 0.9598, Test Accuracy: 0.9555\n", - "Epoch 5 in 3.88 sec: Train Accuracy: 0.9652, Test Accuracy: 0.9601\n", - "Epoch 6 in 10.35 sec: Train Accuracy: 0.9692, Test Accuracy: 0.9631\n", - "Epoch 7 in 4.39 sec: Train Accuracy: 0.9726, Test Accuracy: 0.9650\n", - "Epoch 8 in 4.77 sec: Train Accuracy: 0.9753, Test Accuracy: 0.9669\n" - ] - } - ], - "source": [ - "train_model(num_epochs, params, training_generator)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "-ryVkrAITS9Z" - }, - "source": [ - "## Loading Data with Grain\n", - "\n", - "This section demonstrates how to load MNIST data using Grain, a data-loading library. You'll define a custom dataset class for Grain and set up a Grain DataLoader for efficient training." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "waYhUMUGmhH-" - }, - "source": [ - "Install Grain" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "L78o7eeyGvn5", - "outputId": "cb0ce6cf-243b-4183-8f63-646e00232caa" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Requirement already satisfied: grain in /usr/local/lib/python3.10/dist-packages (0.2.2)\n", - "Requirement already satisfied: absl-py in /usr/local/lib/python3.10/dist-packages (from grain) (1.4.0)\n", - "Requirement already satisfied: array-record in /usr/local/lib/python3.10/dist-packages (from grain) (0.5.1)\n", - "Requirement already satisfied: cloudpickle in /usr/local/lib/python3.10/dist-packages (from grain) (3.1.0)\n", - "Requirement already satisfied: dm-tree in /usr/local/lib/python3.10/dist-packages (from grain) (0.1.8)\n", - "Requirement already satisfied: etils[epath,epy] in /usr/local/lib/python3.10/dist-packages (from grain) (1.10.0)\n", - "Requirement already satisfied: jaxtyping in /usr/local/lib/python3.10/dist-packages (from grain) (0.2.36)\n", - "Requirement already satisfied: more-itertools>=9.1.0 in /usr/local/lib/python3.10/dist-packages (from grain) (10.5.0)\n", - "Requirement already satisfied: numpy in /usr/local/lib/python3.10/dist-packages (from grain) (1.26.4)\n", - "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from etils[epath,epy]->grain) (2024.9.0)\n", - "Requirement already satisfied: importlib_resources in /usr/local/lib/python3.10/dist-packages (from etils[epath,epy]->grain) (6.4.5)\n", - "Requirement already satisfied: typing_extensions in /usr/local/lib/python3.10/dist-packages (from etils[epath,epy]->grain) (4.12.2)\n", - "Requirement already satisfied: zipp in /usr/local/lib/python3.10/dist-packages (from etils[epath,epy]->grain) (3.21.0)\n" - ] - } - ], - "source": [ - "!pip install grain" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "66bH3ZDJ7Iat" - }, - "source": [ - "Import Required Libraries (import MNIST dataset from torchvision)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - 
"metadata": { - "id": "mS62eVL9Ifmz" - }, - "outputs": [], - "source": [ - "import numpy as np\n", - "import grain.python as pygrain\n", - "from torchvision.datasets import MNIST" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "0h6mwVrspPA-" - }, - "source": [ - "### Define Dataset Class\n", - "\n", - "Create a custom dataset class to load MNIST data for Grain." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "bnrhac5Hh7y1" - }, - "outputs": [], - "source": [ - "class Dataset:\n", - " def __init__(self, data_dir, train=True):\n", - " self.data_dir = data_dir\n", - " self.train = train\n", - " self.load_data()\n", - "\n", - " def load_data(self):\n", - " self.dataset = MNIST(self.data_dir, download=True, train=self.train)\n", - "\n", - " def __len__(self):\n", - " return len(self.dataset)\n", - "\n", - " def __getitem__(self, index):\n", - " img, label = self.dataset[index]\n", - " return np.ravel(np.array(img, dtype=np.float32)), label" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "53mf8bWEsyTr" - }, - "source": [ - "### Initialize the Dataset" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "pN3oF7-ostGE" - }, - "outputs": [], - "source": [ - "mnist_dataset = Dataset(data_dir)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "GqD-ycgBuwv9" - }, - "source": [ - "### Get the full train and test dataset" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "f1VnTuX3u_kL" - }, - "outputs": [], - "source": [ - "# Convert training data to JAX arrays and encode labels as one-hot vectors\n", - "train_images = jnp.array([mnist_dataset[i][0] for i in range(len(mnist_dataset))], dtype=jnp.float32)\n", - "train_labels = one_hot(np.array([mnist_dataset[i][1] for i in range(len(mnist_dataset))]), n_targets)\n", - "\n", - "# Load test dataset and process it\n", - "mnist_dataset_test = MNIST(data_dir, download=True, train=False)\n", - "test_images = jnp.array([np.ravel(np.array(mnist_dataset_test[i][0], dtype=np.float32)) for i in range(len(mnist_dataset_test))], dtype=jnp.float32)\n", - "test_labels = one_hot(np.array([mnist_dataset_test[i][1] for i in range(len(mnist_dataset_test))]), n_targets)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "a2NHlp9klrQL", - "outputId": "c9422190-55e9-400b-bd4e-0e7bf23dc6a1" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Train: (60000, 784) (60000, 10)\n", - "Test: (10000, 784) (10000, 10)\n" - ] - } - ], - "source": [ - "print(\"Train:\", train_images.shape, train_labels.shape)\n", - "print(\"Test:\", test_images.shape, test_labels.shape)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "1QPbXt7O0JN-" - }, - "source": [ - "### Initialize PyGrain DataLoader" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "2jqd1jJt25Bj" - }, - "outputs": [], - "source": [ - "sampler = pygrain.SequentialSampler(\n", - " num_records=len(mnist_dataset),\n", - " shard_options=pygrain.NoSharding()) # Single-device, no sharding\n", - "\n", - "def pygrain_training_generator():\n", - " return pygrain.DataLoader(\n", - " data_source=mnist_dataset,\n", - " sampler=sampler,\n", - " operations=[pygrain.Batch(batch_size)],\n", - " )" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "mV5z4GLCGKlx" - }, - "source": [ - "### Training 
Loop (Grain)\n", - "\n", - "Run the training loop using the Grain DataLoader." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "9-iANQ-9CcW_", - "outputId": "b0e19da2-9e34-4183-c5d8-af66de5efa5c" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 1 in 15.65 sec: Train Accuracy: 0.9158, Test Accuracy: 0.9195\n", - "Epoch 2 in 15.03 sec: Train Accuracy: 0.9372, Test Accuracy: 0.9385\n", - "Epoch 3 in 14.93 sec: Train Accuracy: 0.9492, Test Accuracy: 0.9467\n", - "Epoch 4 in 11.56 sec: Train Accuracy: 0.9569, Test Accuracy: 0.9532\n", - "Epoch 5 in 9.33 sec: Train Accuracy: 0.9631, Test Accuracy: 0.9577\n", - "Epoch 6 in 9.31 sec: Train Accuracy: 0.9675, Test Accuracy: 0.9617\n", - "Epoch 7 in 9.78 sec: Train Accuracy: 0.9708, Test Accuracy: 0.9652\n", - "Epoch 8 in 9.80 sec: Train Accuracy: 0.9736, Test Accuracy: 0.9671\n" - ] - } - ], - "source": [ - "train_model(num_epochs, params, pygrain_training_generator)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "o51P6lr86wz-" - }, - "source": [ - "## Loading Data with Hugging Face\n", - "\n", - "This section demonstrates loading MNIST data using the Hugging Face `datasets` library. You'll format the dataset for JAX compatibility, prepare flattened images and one-hot-encoded labels, and define a training generator." - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "69vrihaOi4Oz" - }, - "source": [ - "Install the Hugging Face `datasets` library." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "19ipxPhI6oSN", - "outputId": "b80b80cd-fc14-4a43-f8a8-2802de4faade" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Requirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (3.1.0)\n", - "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from datasets) (3.16.1)\n", - "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (1.26.4)\n", - "Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (17.0.0)\n", - "Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.3.8)\n", - "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (2.2.2)\n", - "Requirement already satisfied: requests>=2.32.2 in /usr/local/lib/python3.10/dist-packages (from datasets) (2.32.3)\n", - "Requirement already satisfied: tqdm>=4.66.3 in /usr/local/lib/python3.10/dist-packages (from datasets) (4.66.6)\n", - "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets) (3.5.0)\n", - "Requirement already satisfied: multiprocess<0.70.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.70.16)\n", - "Requirement already satisfied: fsspec<=2024.9.0,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from fsspec[http]<=2024.9.0,>=2023.1.0->datasets) (2024.9.0)\n", - "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.11.2)\n", - "Requirement already satisfied: huggingface-hub>=0.23.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.26.2)\n", - "Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages 
(from datasets) (24.2)\n", - "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (6.0.2)\n", - "Requirement already satisfied: aiohappyeyeballs>=2.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (2.4.3)\n", - "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n", - "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (24.2.0)\n", - "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.5.0)\n", - "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.1.0)\n", - "Requirement already satisfied: propcache>=0.2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (0.2.0)\n", - "Requirement already satisfied: yarl<2.0,>=1.17.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.17.2)\n", - "Requirement already satisfied: async-timeout<6.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n", - "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.23.0->datasets) (4.12.2)\n", - "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (3.4.0)\n", - "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (3.10)\n", - "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (2.2.3)\n", - "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (2024.8.30)\n", - "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n", - "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.2)\n", - "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.2)\n", - "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n" - ] - } - ], - "source": [ - "!pip install datasets" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "8v1N59p76zn0" - }, - "outputs": [], - "source": [ - "from datasets import load_dataset" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "8Gaj11tO7C86" - }, - "source": [ - "Load the MNIST dataset from Hugging Face and format it as `numpy` arrays for quick access or `jax` to get JAX arrays." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "a22kTvgk6_fJ" - }, - "outputs": [], - "source": [ - "mnist_dataset = load_dataset(\"mnist\", cache_dir=data_dir).with_format(\"numpy\")" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "tgI7dIaX7JzM" - }, - "source": [ - "### Extract images and labels\n", - "\n", - "Get image shape and flatten for model input." 
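The extraction cell below also one-hot encodes the labels with the `one_hot` helper defined earlier. A quick worked example of what that helper produces:

```python
import jax.numpy as jnp

# one_hot(x, k) compares each label in x against jnp.arange(k).
labels = jnp.array([0, 2, 1])
print(one_hot(labels, 3))
# [[1. 0. 0.]
#  [0. 0. 1.]
#  [0. 1. 0.]]
```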
- ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "NHrKatD_7HbH" - }, - "outputs": [], - "source": [ - "train_images = mnist_dataset[\"train\"][\"image\"]\n", - "train_labels = mnist_dataset[\"train\"][\"label\"]\n", - "test_images = mnist_dataset[\"test\"][\"image\"]\n", - "test_labels = mnist_dataset[\"test\"][\"label\"]\n", - "\n", - "# Extract image shape\n", - "image_shape = train_images.shape[1:]\n", - "num_features = image_shape[0] * image_shape[1]\n", - "\n", - "# Flatten the images\n", - "train_images = train_images.reshape(-1, num_features)\n", - "test_images = test_images.reshape(-1, num_features)\n", - "\n", - "# One-hot encode the labels\n", - "train_labels = one_hot(train_labels, n_targets)\n", - "test_labels = one_hot(test_labels, n_targets)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "dITh435Z7Nwb", - "outputId": "cc89c1ec-6987-4f1c-90a4-c3b355ea7225" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Train: (60000, 784) (60000, 10)\n", - "Test: (10000, 784) (10000, 10)\n" - ] - } - ], - "source": [ - "print('Train:', train_images.shape, train_labels.shape)\n", - "print('Test:', test_images.shape, test_labels.shape)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "kk_4zJlz7T1E" - }, - "source": [ - "### Define Training Generator\n", - "\n", - "Set up a generator to yield batches of images and labels for training." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "-zLJhogj7RL-" - }, - "outputs": [], - "source": [ - "def hf_training_generator():\n", - " \"\"\"Yield batches for training.\"\"\"\n", - " for batch in mnist_dataset[\"train\"].iter(batch_size):\n", - " x, y = batch[\"image\"], batch[\"label\"]\n", - " yield x, y" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "HIsGfkLI7dvZ" - }, - "source": [ - "### Training Loop (Hugging Face Datasets)\n", - "\n", - "Run the training loop using the Hugging Face training generator." - ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "colab": { - "base_uri": "https://localhost:8080/" - }, - "id": "Ui6aLiZP7aLe", - "outputId": "c51529e0-563f-4af0-9793-76b5e6f323db" - }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Epoch 1 in 19.06 sec: Train Accuracy: 0.9158, Test Accuracy: 0.9195\n", - "Epoch 2 in 8.94 sec: Train Accuracy: 0.9372, Test Accuracy: 0.9385\n", - "Epoch 3 in 5.43 sec: Train Accuracy: 0.9492, Test Accuracy: 0.9467\n", - "Epoch 4 in 6.41 sec: Train Accuracy: 0.9569, Test Accuracy: 0.9532\n", - "Epoch 5 in 5.80 sec: Train Accuracy: 0.9631, Test Accuracy: 0.9577\n", - "Epoch 6 in 6.61 sec: Train Accuracy: 0.9675, Test Accuracy: 0.9617\n", - "Epoch 7 in 5.49 sec: Train Accuracy: 0.9708, Test Accuracy: 0.9652\n", - "Epoch 8 in 6.64 sec: Train Accuracy: 0.9736, Test Accuracy: 0.9671\n" - ] - } - ], - "source": [ - "train_model(num_epochs, params, hf_training_generator)" - ] - }, - { - "cell_type": "markdown", - "metadata": { - "id": "rCJq2rvKlKWX" - }, - "source": [ - "## Summary\n", - "\n", - "This notebook explored efficient methods for loading data on a GPU with JAX, using libraries such as PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets. You also learned GPU-specific optimizations, including using `device_put` for data transfer and managing GPU memory, to enhance training efficiency. 
Each method offers unique benefits, allowing you to choose the best approach based on your project requirements." - ] - } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "gpuType": "T4", - "provenance": [] - }, - "jupytext": { - "formats": "ipynb,md:myst" - }, - "kernelspec": { - "display_name": "Python 3", - "name": "python3" - }, - "language_info": { - "name": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} diff --git a/docs/data_loaders_on_gpu_with_jax.md b/docs/data_loaders_on_gpu_with_jax.md deleted file mode 100644 index a83ec4c..0000000 --- a/docs/data_loaders_on_gpu_with_jax.md +++ /dev/null @@ -1,650 +0,0 @@ ---- -jupytext: - formats: ipynb,md:myst - text_representation: - extension: .md - format_name: myst - format_version: 0.13 - jupytext_version: 1.15.2 -kernelspec: - display_name: Python 3 - name: python3 ---- - -+++ {"id": "PUFGZggH49zp"} - -# Introduction to Data Loaders on GPU with JAX - -+++ {"id": "3ia4PKEV5Dr8"} - -[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/data_loaders_on_gpu_with_jax.ipynb) - -This tutorial explores different data loading strategies for using **JAX** on a single [**GPU**](https://jax.readthedocs.io/en/latest/glossary.html#term-GPU). While JAX doesn't include a built-in data loader, it seamlessly integrates with popular data loading libraries, including: -* [**PyTorch DataLoader**](https://github.com/pytorch/data) -* [**TensorFlow Datasets (TFDS)**](https://github.com/tensorflow/datasets) -* [**Grain**](https://github.com/google/grain) -* [**Hugging Face**](https://huggingface.co/docs/datasets/en/use_with_jax#data-loading) - -You'll see how to use each of these libraries to efficiently load data for a simple image classification task using the MNIST dataset. - -Compared to [CPU-based loading](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_cpu_with_jax.html), working with a GPU introduces specific considerations like transferring data to the GPU using `device_put`, managing larger batch sizes for faster processing, and efficiently utilizing GPU memory. Unlike multi-device setups, this guide focuses on optimizing data handling for a single GPU. - - -If you're looking for CPU-specific data loading advice, see [Data Loaders on CPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_cpu_with_jax.html). - -If you're looking for a multi-device data loading strategy, see [Data Loaders on Multi-Device Setups](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_for_multi_device_setups_with_jax.html). - -+++ {"id": "-rsMgVtO6asW"} - -## Import JAX API - -```{code-cell} -:id: tDJNQ6V-Dg5g - -import jax -import jax.numpy as jnp -from jax import grad, jit, vmap, random, device_put -``` - -+++ {"id": "TsFdlkSZKp9S"} - -## Checking GPU Availability for JAX - -```{code-cell} ---- -colab: - base_uri: https://localhost:8080/ -id: N3sqvaF3KJw1 -outputId: ab40f542-b8c0-422c-ca68-4ce292817889 ---- -jax.devices() -``` - -+++ {"id": "qyJ_WTghDnIc"} - -## Setting Hyperparameters and Initializing Parameters - -You'll define hyperparameters for your model and data loading, including layer sizes, learning rate, batch size, and the data directory. You'll also initialize the weights and biases for a fully-connected neural network. 
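A note on the key handling in the cell below: JAX's `random` module has no global seed state, so every random call takes an explicit key, and `random.split` derives fresh, independent keys from a parent key. A tiny illustration:

```python
from jax import random

key = random.PRNGKey(0)
w_key, b_key = random.split(key)   # two independent child keys
layer_keys = random.split(key, 4)  # or any number of keys at once
print(layer_keys.shape)            # (4, 2) with the classic uint32 key format
```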
-
-```{code-cell}
-:id: qLNOSloFDka_
-
-# A helper function to randomly initialize weights and biases
-# for a dense neural network layer
-def random_layer_params(m, n, key, scale=1e-2):
-  w_key, b_key = random.split(key)
-  return scale * random.normal(w_key, (n, m)), scale * random.normal(b_key, (n,))
-
-# Function to initialize network parameters for all layers based on defined sizes
-def init_network_params(sizes, key):
-  keys = random.split(key, len(sizes))
-  return [random_layer_params(m, n, k) for m, n, k in zip(sizes[:-1], sizes[1:], keys)]
-
-layer_sizes = [784, 512, 512, 10]  # Layers of the network
-step_size = 0.01  # Learning rate
-num_epochs = 8  # Number of training epochs
-batch_size = 128  # Batch size for training
-n_targets = 10  # Number of classes (digits 0-9)
-num_pixels = 28 * 28  # Each MNIST image is 28x28 pixels
-data_dir = '/tmp/mnist_dataset'  # Directory for storing the dataset
-
-# Initialize network parameters using the defined layer sizes and a random seed
-params = init_network_params(layer_sizes, random.PRNGKey(0))
-```
-
-+++ {"id": "rHLdqeI7D2WZ"}
-
-## Model Prediction with Auto-Batching
-
-In this section, you'll define the `predict` function for your neural network. This function computes the output of the network for a single input image.
-
-To efficiently process multiple images simultaneously, you'll use [`vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax.vmap), which allows you to vectorize the `predict` function and apply it across a batch of inputs. This technique, called auto-batching, improves computational efficiency by leveraging hardware acceleration.
-
-```{code-cell}
-:id: bKIYPSkvD1QV
-
-from jax.scipy.special import logsumexp
-
-def relu(x):
-  return jnp.maximum(0, x)
-
-def predict(params, image):
-  # per-example predictions
-  activations = image
-  for w, b in params[:-1]:
-    outputs = jnp.dot(w, activations) + b
-    activations = relu(outputs)
-
-  final_w, final_b = params[-1]
-  logits = jnp.dot(final_w, activations) + final_b
-  return logits - logsumexp(logits)
-
-# Make a batched version of the `predict` function
-batched_predict = vmap(predict, in_axes=(None, 0))
-```
-
-+++ {"id": "rLqfeORsERek"}
-
-## Utility and Loss Functions
-
-You'll now define utility functions for:
-- One-hot encoding: Converts class indices to binary vectors.
-- Accuracy calculation: Measures the performance of the model on the dataset.
-- Loss computation: Calculates the difference between predictions and targets.
-
-To optimize performance:
-- [`grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html#jax.grad) computes gradients of the loss function with respect to network parameters.
-- [`jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html#jax.jit) compiles the update function, enabling faster execution by leveraging JAX's [XLA](https://openxla.org/xla) compilation.
-- [`device_put`](https://jax.readthedocs.io/en/latest/_autosummary/jax.device_put.html) transfers each batch to the GPU, since the data loaders in this tutorial yield host (CPU) arrays.
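-
-The `device_put` pattern is worth seeing in isolation before it appears inside the training loop below. A minimal sketch — the array here is a placeholder standing in for a loader batch, not tutorial data:
-
-```{code-cell}
-import numpy as np
-from jax import device_put
-
-# A batch as a data loader would yield it: a NumPy array in host memory.
-host_batch = np.zeros((128, 784), dtype=np.float32)
-
-# Explicitly transfer it to the default device (the GPU in this tutorial).
-device_batch = device_put(host_batch)
-print(device_batch.devices())  # the device(s) now holding the array
-```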
-
-```{code-cell}
-:id: sA0a06raEQfS
-
-import time
-
-def one_hot(x, k, dtype=jnp.float32):
-  """Create a one-hot encoding of x of size k."""
-  return jnp.array(x[:, None] == jnp.arange(k), dtype)
-
-def accuracy(params, images, targets):
-  """Calculate the accuracy of predictions."""
-  target_class = jnp.argmax(targets, axis=1)
-  predicted_class = jnp.argmax(batched_predict(params, images), axis=1)
-  return jnp.mean(predicted_class == target_class)
-
-def loss(params, images, targets):
-  """Calculate the loss between predictions and targets."""
-  preds = batched_predict(params, images)
-  return -jnp.mean(preds * targets)
-
-@jit
-def update(params, x, y):
-  """Update the network parameters using gradient descent."""
-  grads = grad(loss)(params, x, y)
-  return [(w - step_size * dw, b - step_size * db)
-          for (w, b), (dw, db) in zip(params, grads)]
-
-def reshape_and_one_hot(x, y):
-  """Reshape and one-hot encode the inputs."""
-  x = jnp.reshape(x, (len(x), num_pixels))
-  y = one_hot(y, n_targets)
-  return x, y
-
-def train_model(num_epochs, params, training_generator, data_loader_type='streamed'):
-  """Train the model for the given number of epochs, using `device_put` to move each batch to the GPU."""
-  for epoch in range(num_epochs):
-    start_time = time.time()
-    for x, y in training_generator() if data_loader_type == 'streamed' else training_generator:
-      x, y = reshape_and_one_hot(x, y)
-      x, y = device_put(x), device_put(y)
-      params = update(params, x, y)
-
-    print(f"Epoch {epoch + 1} in {time.time() - start_time:.2f} sec: "
-          f"Train Accuracy: {accuracy(params, train_images, train_labels):.4f}, "
-          f"Test Accuracy: {accuracy(params, test_images, test_labels):.4f}")
-```
-
-+++ {"id": "Hsionp5IYsQ9"}
-
-## Loading Data with PyTorch DataLoader
-
-This section shows how to load the MNIST dataset using PyTorch's DataLoader, convert the data to NumPy arrays, and apply transformations to flatten and cast images.
-
-```{code-cell}
----
-colab:
-  base_uri: https://localhost:8080/
-id: uA7XY0OezHse
-outputId: 4c86f455-ff1d-474e-f8e3-7111d9b56996
----
-!pip install torch torchvision
-```
-
-```{code-cell}
-:id: kO5_WzwY59gE
-
-import numpy as np
-from jax.tree_util import tree_map
-from torch.utils import data
-from torchvision.datasets import MNIST
-```
-
-```{code-cell}
-:id: 6f6qU8PCc143
-
-def numpy_collate(batch):
-  """Collate function to convert a batch of PyTorch data into NumPy arrays."""
-  return tree_map(np.asarray, data.default_collate(batch))
-
-class NumpyLoader(data.DataLoader):
-  """Custom DataLoader to return NumPy arrays from a PyTorch Dataset."""
-  def __init__(self, dataset, batch_size=1,
-               shuffle=False, sampler=None,
-               batch_sampler=None, num_workers=0,
-               pin_memory=False, drop_last=False,
-               timeout=0, worker_init_fn=None):
-    super().__init__(dataset,
-                     batch_size=batch_size,
-                     shuffle=shuffle,
-                     sampler=sampler,
-                     batch_sampler=batch_sampler,
-                     num_workers=num_workers,
-                     collate_fn=numpy_collate,
-                     pin_memory=pin_memory,
-                     drop_last=drop_last,
-                     timeout=timeout,
-                     worker_init_fn=worker_init_fn)
-
-class FlattenAndCast(object):
-  """Transform class to flatten and cast images to float32."""
-  def __call__(self, pic):
-    return np.ravel(np.array(pic, dtype=jnp.float32))
-```
-
-+++ {"id": "mfSnfJND6I8G"}
-
-### Load Dataset with Transformations
-
-Standardize the data by flattening the images, casting them to `float32`, and ensuring consistent data types.
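-
-To see exactly what the transform produces, you can apply `FlattenAndCast` to a synthetic image first — a minimal sketch, assuming Pillow is available (it is installed alongside torchvision):
-
-```{code-cell}
-from PIL import Image
-
-# Apply the transform to a blank 28x28 grayscale image and inspect the result.
-sample = Image.new('L', (28, 28))
-flat = FlattenAndCast()(sample)
-print(flat.shape, flat.dtype)  # expected: (784,) float32
-```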
- -```{code-cell} -:id: Kxbl6bcx6crv - -mnist_dataset = MNIST(data_dir, download=True, transform=FlattenAndCast()) -``` - -+++ {"id": "kbdsqvPZGrsa"} - -### Full Training Dataset for Accuracy Checks - -Convert the entire training dataset to JAX arrays. - -```{code-cell} -:id: c9ZCJq_rzPck - -train_images = np.array(mnist_dataset.data).reshape(len(mnist_dataset.data), -1) -train_labels = one_hot(np.array(mnist_dataset.targets), n_targets) -``` - -+++ {"id": "WXUh0BwvG8Ko"} - -### Get Full Test Dataset - -Load and process the full test dataset. - -```{code-cell} -:id: brlLG4SqGphm - -mnist_dataset_test = MNIST(data_dir, download=True, train=False) -test_images = jnp.array(mnist_dataset_test.data.numpy().reshape(len(mnist_dataset_test.data), -1), dtype=jnp.float32) -test_labels = one_hot(np.array(mnist_dataset_test.targets), n_targets) -``` - -```{code-cell} ---- -colab: - base_uri: https://localhost:8080/ -id: Oz-UVnCxG5E8 -outputId: 53f3fb32-5096-4862-e022-3c3a1d82137a ---- -print('Train:', train_images.shape, train_labels.shape) -print('Test:', test_images.shape, test_labels.shape) -``` - -+++ {"id": "mNjn9dMPitKL"} - -### Training Data Generator - -Define a generator function using PyTorch's DataLoader for batch training. -Setting `num_workers > 0` enables multi-process data loading, which can accelerate data loading for larger datasets or intensive preprocessing tasks. Experiment with different values to find the optimal setting for your hardware and workload. - -Note: When setting `num_workers > 0`, you may see the following `RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.` -This warning can be safely ignored since data loaders do not use JAX within the forked processes. - -```{code-cell} -:id: 0LdT8P8aisWF - -def pytorch_training_generator(mnist_dataset): - return NumpyLoader(mnist_dataset, batch_size=batch_size, num_workers=0) -``` - -+++ {"id": "Xzt2x9S1HC3T"} - -### Training Loop (PyTorch DataLoader) - -The training loop uses the PyTorch DataLoader to iterate through batches and update model parameters. - -```{code-cell} ---- -colab: - base_uri: https://localhost:8080/ -id: SqweRz_98sN8 -outputId: bdd45256-3f5a-48f7-e45c-378078ac4279 ---- -train_model(num_epochs, params, pytorch_training_generator(mnist_dataset), data_loader_type='iterable') -``` - -+++ {"id": "Nm45ZTo6yrf5"} - -## Loading Data with TensorFlow Datasets (TFDS) - -This section demonstrates how to load the MNIST dataset using TFDS, fetch the full dataset for evaluation, and define a training generator for batch processing. GPU usage is explicitly disabled for TensorFlow. - -```{code-cell} -:id: sGaQAk1DHMUx - -import tensorflow_datasets as tfds -``` - -+++ {"id": "ZSc5K0Eiwm4L"} - -### Fetch Full Dataset for Evaluation - -Load the dataset with `tfds.load`, convert it to NumPy arrays, and process it for evaluation. 
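-
-To keep TensorFlow itself from claiming the GPU (as the section intro mentions), you can hide the GPU from it before loading any data — a minimal sketch, assuming the `tensorflow` package that TFDS depends on is installed:
-
-```{code-cell}
-import tensorflow as tf
-
-# Hide all GPUs from TensorFlow so its input pipeline runs on the CPU,
-# leaving GPU memory free for JAX.
-tf.config.set_visible_devices([], device_type='GPU')
-print(tf.config.get_visible_devices('GPU'))  # expected: []
-```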
-
-```{code-cell}
-:id: 1hOamw_7C8Pb
-
-# tfds.load returns tf.Tensors (or tf.data.Datasets if batch_size != -1)
-mnist_data, info = tfds.load(name="mnist", batch_size=-1, data_dir=data_dir, with_info=True)
-mnist_data = tfds.as_numpy(mnist_data)
-train_data, test_data = mnist_data['train'], mnist_data['test']
-
-# Full train set
-train_images, train_labels = train_data['image'], train_data['label']
-train_images = jnp.reshape(train_images, (len(train_images), num_pixels))
-train_labels = one_hot(train_labels, n_targets)
-
-# Full test set
-test_images, test_labels = test_data['image'], test_data['label']
-test_images = jnp.reshape(test_images, (len(test_images), num_pixels))
-test_labels = one_hot(test_labels, n_targets)
-```
-
-```{code-cell}
----
-colab:
-  base_uri: https://localhost:8080/
-id: Td3PiLdmEf7z
-outputId: b8c9a32a-9cf0-4dc3-cb51-db21d32c6545
----
-print('Train:', train_images.shape, train_labels.shape)
-print('Test:', test_images.shape, test_labels.shape)
-```
-
-+++ {"id": "dXMvgk6sdq4j"}
-
-### Define the Training Generator
-
-Create a generator function to yield batches of data for training.
-
-```{code-cell}
-:id: vX59u8CqEf4J
-
-def training_generator():
-  # as_supervised=True gives us the (image, label) as a tuple instead of a dict
-  ds = tfds.load(name='mnist', split='train', as_supervised=True, data_dir=data_dir)
-  # You can build up an arbitrary tf.data input pipeline
-  ds = ds.batch(batch_size).prefetch(1)
-  # tfds.as_numpy converts the tf.data.Dataset into an iterable of NumPy arrays
-  return tfds.as_numpy(ds)
-```
-
-+++ {"id": "EAWeUdnuFNBY"}
-
-### Training Loop (TFDS)
-
-Use the training generator in a custom training loop.
-
-```{code-cell}
----
-colab:
-  base_uri: https://localhost:8080/
-id: h2sO13XDGvq1
-outputId: f30805bb-e725-46ee-e053-6e97f2af81c5
----
-train_model(num_epochs, params, training_generator)
-```
-
-+++ {"id": "-ryVkrAITS9Z"}
-
-## Loading Data with Grain
-
-This section demonstrates how to load MNIST data using Grain, a data-loading library. You'll define a custom dataset class for Grain and set up a Grain DataLoader for efficient training.
-
-+++ {"id": "waYhUMUGmhH-"}
-
-Install Grain
-
-```{code-cell}
----
-colab:
-  base_uri: https://localhost:8080/
-id: L78o7eeyGvn5
-outputId: cb0ce6cf-243b-4183-8f63-646e00232caa
----
-!pip install grain
-```
-
-+++ {"id": "66bH3ZDJ7Iat"}
-
-Import Required Libraries (the MNIST dataset itself still comes from torchvision)
-
-```{code-cell}
-:id: mS62eVL9Ifmz
-
-import numpy as np
-import grain.python as pygrain
-from torchvision.datasets import MNIST
-```
-
-+++ {"id": "0h6mwVrspPA-"}
-
-### Define Dataset Class
-
-Create a custom dataset class to load MNIST data for Grain.
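-
-Grain doesn't require a particular base class here: any object that supports `len()` and integer indexing satisfies its random-access data source protocol. A minimal skeleton of that protocol (illustrative only — the actual class used in this tutorial follows below):
-
-```{code-cell}
-class RandomAccessSource:
-  """Smallest interface a pygrain data source needs: length and indexing."""
-  def __init__(self, records):
-    self._records = records
-
-  def __len__(self):
-    return len(self._records)
-
-  def __getitem__(self, index):
-    return self._records[index]
-```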
-
-```{code-cell}
-:id: bnrhac5Hh7y1
-
-class Dataset:
-  def __init__(self, data_dir, train=True):
-    self.data_dir = data_dir
-    self.train = train
-    self.load_data()
-
-  def load_data(self):
-    self.dataset = MNIST(self.data_dir, download=True, train=self.train)
-
-  def __len__(self):
-    return len(self.dataset)
-
-  def __getitem__(self, index):
-    img, label = self.dataset[index]
-    return np.ravel(np.array(img, dtype=np.float32)), label
-```
-
-+++ {"id": "53mf8bWEsyTr"}
-
-### Initialize the Dataset
-
-```{code-cell}
-:id: pN3oF7-ostGE
-
-mnist_dataset = Dataset(data_dir)
-```
-
-+++ {"id": "GqD-ycgBuwv9"}
-
-### Get the Full Train and Test Datasets
-
-```{code-cell}
-:id: f1VnTuX3u_kL
-
-# Convert training data to JAX arrays and encode labels as one-hot vectors
-train_images = jnp.array([mnist_dataset[i][0] for i in range(len(mnist_dataset))], dtype=jnp.float32)
-train_labels = one_hot(np.array([mnist_dataset[i][1] for i in range(len(mnist_dataset))]), n_targets)
-
-# Load test dataset and process it
-mnist_dataset_test = MNIST(data_dir, download=True, train=False)
-test_images = jnp.array([np.ravel(np.array(mnist_dataset_test[i][0], dtype=np.float32)) for i in range(len(mnist_dataset_test))], dtype=jnp.float32)
-test_labels = one_hot(np.array([mnist_dataset_test[i][1] for i in range(len(mnist_dataset_test))]), n_targets)
-```
-
-```{code-cell}
----
-colab:
-  base_uri: https://localhost:8080/
-id: a2NHlp9klrQL
-outputId: c9422190-55e9-400b-bd4e-0e7bf23dc6a1
----
-print("Train:", train_images.shape, train_labels.shape)
-print("Test:", test_images.shape, test_labels.shape)
-```
-
-+++ {"id": "1QPbXt7O0JN-"}
-
-### Initialize PyGrain DataLoader
-
-```{code-cell}
-:id: 2jqd1jJt25Bj
-
-sampler = pygrain.SequentialSampler(
-    num_records=len(mnist_dataset),
-    shard_options=pygrain.NoSharding())  # Single-device, no sharding
-
-def pygrain_training_generator():
-  return pygrain.DataLoader(
-      data_source=mnist_dataset,
-      sampler=sampler,
-      operations=[pygrain.Batch(batch_size)],
-  )
-```
-
-+++ {"id": "mV5z4GLCGKlx"}
-
-### Training Loop (Grain)
-
-Run the training loop using the Grain DataLoader.
-
-```{code-cell}
----
-colab:
-  base_uri: https://localhost:8080/
-id: 9-iANQ-9CcW_
-outputId: b0e19da2-9e34-4183-c5d8-af66de5efa5c
----
-train_model(num_epochs, params, pygrain_training_generator)
-```
-
-+++ {"id": "o51P6lr86wz-"}
-
-## Loading Data with Hugging Face
-
-This section demonstrates loading MNIST data using the Hugging Face `datasets` library. You'll format the dataset for JAX compatibility, prepare flattened images and one-hot-encoded labels, and define a training generator.
-
-+++ {"id": "69vrihaOi4Oz"}
-
-Install the Hugging Face `datasets` library.
-
-```{code-cell}
----
-colab:
-  base_uri: https://localhost:8080/
-id: 19ipxPhI6oSN
-outputId: b80b80cd-fc14-4a43-f8a8-2802de4faade
----
-!pip install datasets
-```
-
-```{code-cell}
-:id: 8v1N59p76zn0
-
-from datasets import load_dataset
-```
-
-+++ {"id": "8Gaj11tO7C86"}
-
-Load the MNIST dataset from Hugging Face and set the format to `numpy` for quick array access (use `"jax"` instead to get JAX arrays).
-
-```{code-cell}
-:id: a22kTvgk6_fJ
-
-mnist_dataset = load_dataset("mnist", cache_dir=data_dir).with_format("numpy")
-```
-
-+++ {"id": "tgI7dIaX7JzM"}
-
-### Extract images and labels
-
-Get the image shape and flatten the images for model input.
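-
-Because the dataset was loaded with `.with_format("numpy")`, column access returns NumPy arrays directly. A quick sanity check of the raw image shape before flattening:
-
-```{code-cell}
-# Columns come back as NumPy arrays thanks to with_format("numpy").
-images = mnist_dataset["train"]["image"]
-print(type(images), images.shape)  # expected: a numpy.ndarray of shape (60000, 28, 28)
-```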
- -```{code-cell} -:id: NHrKatD_7HbH - -train_images = mnist_dataset["train"]["image"] -train_labels = mnist_dataset["train"]["label"] -test_images = mnist_dataset["test"]["image"] -test_labels = mnist_dataset["test"]["label"] - -# Extract image shape -image_shape = train_images.shape[1:] -num_features = image_shape[0] * image_shape[1] - -# Flatten the images -train_images = train_images.reshape(-1, num_features) -test_images = test_images.reshape(-1, num_features) - -# One-hot encode the labels -train_labels = one_hot(train_labels, n_targets) -test_labels = one_hot(test_labels, n_targets) -``` - -```{code-cell} ---- -colab: - base_uri: https://localhost:8080/ -id: dITh435Z7Nwb -outputId: cc89c1ec-6987-4f1c-90a4-c3b355ea7225 ---- -print('Train:', train_images.shape, train_labels.shape) -print('Test:', test_images.shape, test_labels.shape) -``` - -+++ {"id": "kk_4zJlz7T1E"} - -### Define Training Generator - -Set up a generator to yield batches of images and labels for training. - -```{code-cell} -:id: -zLJhogj7RL- - -def hf_training_generator(): - """Yield batches for training.""" - for batch in mnist_dataset["train"].iter(batch_size): - x, y = batch["image"], batch["label"] - yield x, y -``` - -+++ {"id": "HIsGfkLI7dvZ"} - -### Training Loop (Hugging Face Datasets) - -Run the training loop using the Hugging Face training generator. - -```{code-cell} ---- -colab: - base_uri: https://localhost:8080/ -id: Ui6aLiZP7aLe -outputId: c51529e0-563f-4af0-9793-76b5e6f323db ---- -train_model(num_epochs, params, hf_training_generator) -``` - -+++ {"id": "rCJq2rvKlKWX"} - -## Summary - -This notebook explored efficient methods for loading data on a GPU with JAX, using libraries such as PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets. You also learned GPU-specific optimizations, including using `device_put` for data transfer and managing GPU memory, to enhance training efficiency. Each method offers unique benefits, allowing you to choose the best approach based on your project requirements. diff --git a/docs/source/tutorials.md b/docs/source/tutorials.md index dc82a13..23cdd58 100644 --- a/docs/source/tutorials.md +++ b/docs/source/tutorials.md @@ -19,7 +19,6 @@ JAX_basic_text_classification JAX_examples_image_segmentation JAX_Vision_transformer JAX_machine_translation -<<<<<<< HEAD JAX_visualizing_models_metrics JAX_image_captioning JAX_time_series_classification