From f97b3d952662a2b5c40960cb383bfcb46d18aa3c Mon Sep 17 00:00:00 2001 From: selamw1 Date: Tue, 3 Dec 2024 15:22:02 -0800 Subject: [PATCH] =?UTF-8?q?=E2=80=9Creferece=5Ftutorial=5Flinks=5Fadded?= =?UTF-8?q?=E2=80=9D?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/data_loaders_on_gpu_with_jax.ipynb | 88 +++++++++++++------------ docs/data_loaders_on_gpu_with_jax.md | 19 ++++-- 2 files changed, 58 insertions(+), 49 deletions(-) diff --git a/docs/data_loaders_on_gpu_with_jax.ipynb b/docs/data_loaders_on_gpu_with_jax.ipynb index f726297..40c8ddc 100644 --- a/docs/data_loaders_on_gpu_with_jax.ipynb +++ b/docs/data_loaders_on_gpu_with_jax.ipynb @@ -25,7 +25,12 @@ "\n", "You'll see how to use each of these libraries to efficiently load data for a simple image classification task using the MNIST dataset.\n", "\n", - "Compared to the [Data Loaders on CPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_cpu_with_jax.html), working with GPUs introduces opportunities for further optimization, such as transferring data to the GPU using `device_put`, leveraging larger batch sizes for faster processing, and addressing considerations like memory management." + "Compared to [CPU-based loading](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_cpu_with_jax.html), working with a GPU introduces specific considerations like transferring data to the GPU using `device_put`, managing larger batch sizes for faster processing, and efficiently utilizing GPU memory. Unlike multi-device setups, this guide focuses on optimizing data handling for a single GPU.\n", + "\n", + "\n", + "If you're looking for CPU-specific data loading advice, see [Data Loaders on CPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_cpu_with_jax.html).\n", + "\n", + "If you're looking for a multi-device data loading strategy, see [Data Loaders on Multi-Device Setups](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_for_multi_device_setups_with_jax.html)." ] }, { @@ -34,12 +39,12 @@ "id": "-rsMgVtO6asW" }, "source": [ - "### Import JAX API" + "## Import JAX API" ] }, { "cell_type": "code", - "execution_count": 35, + "execution_count": null, "metadata": { "id": "tDJNQ6V-Dg5g" }, @@ -56,12 +61,12 @@ "id": "TsFdlkSZKp9S" }, "source": [ - "### Checking GPU Availability for JAX" + "## Checking GPU Availability for JAX" ] }, { "cell_type": "code", - "execution_count": 36, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -91,14 +96,14 @@ "id": "qyJ_WTghDnIc" }, "source": [ - "### Setting Hyperparameters and Initializing Parameters\n", + "## Setting Hyperparameters and Initializing Parameters\n", "\n", "You'll define hyperparameters for your model and data loading, including layer sizes, learning rate, batch size, and the data directory. You'll also initialize the weights and biases for a fully-connected neural network." ] }, { "cell_type": "code", - "execution_count": 37, + "execution_count": null, "metadata": { "id": "qLNOSloFDka_" }, @@ -133,7 +138,7 @@ "id": "rHLdqeI7D2WZ" }, "source": [ - "### Model Prediction with Auto-Batching\n", + "## Model Prediction with Auto-Batching\n", "\n", "In this section, you'll define the `predict` function for your neural network. This function computes the output of the network for a single input image.\n", "\n", @@ -142,7 +147,7 @@ }, { "cell_type": "code", - "execution_count": 38, + "execution_count": null, "metadata": { "id": "bKIYPSkvD1QV" }, @@ -174,7 +179,7 @@ "id": "rLqfeORsERek" }, "source": [ - "### Utility and Loss Functions\n", + "## Utility and Loss Functions\n", "\n", "You'll now define utility functions for:\n", "- One-hot encoding: Converts class indices to binary vectors.\n", @@ -190,7 +195,7 @@ }, { "cell_type": "code", - "execution_count": 39, + "execution_count": null, "metadata": { "id": "sA0a06raEQfS" }, @@ -253,7 +258,7 @@ }, { "cell_type": "code", - "execution_count": 40, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -287,7 +292,7 @@ }, { "cell_type": "code", - "execution_count": 41, + "execution_count": null, "metadata": { "id": "kO5_WzwY59gE" }, @@ -301,7 +306,7 @@ }, { "cell_type": "code", - "execution_count": 42, + "execution_count": null, "metadata": { "id": "6f6qU8PCc143" }, @@ -348,7 +353,7 @@ }, { "cell_type": "code", - "execution_count": 43, + "execution_count": null, "metadata": { "id": "Kxbl6bcx6crv" }, @@ -370,7 +375,7 @@ }, { "cell_type": "code", - "execution_count": 44, + "execution_count": null, "metadata": { "id": "c9ZCJq_rzPck" }, @@ -393,7 +398,7 @@ }, { "cell_type": "code", - "execution_count": 45, + "execution_count": null, "metadata": { "id": "brlLG4SqGphm" }, @@ -406,7 +411,7 @@ }, { "cell_type": "code", - "execution_count": 46, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -446,7 +451,7 @@ }, { "cell_type": "code", - "execution_count": 47, + "execution_count": null, "metadata": { "id": "0LdT8P8aisWF" }, @@ -469,7 +474,7 @@ }, { "cell_type": "code", - "execution_count": 48, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -510,7 +515,7 @@ }, { "cell_type": "code", - "execution_count": 49, + "execution_count": null, "metadata": { "id": "sGaQAk1DHMUx" }, @@ -532,7 +537,7 @@ }, { "cell_type": "code", - "execution_count": 50, + "execution_count": null, "metadata": { "id": "1hOamw_7C8Pb" }, @@ -556,7 +561,7 @@ }, { "cell_type": "code", - "execution_count": 51, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -592,7 +597,7 @@ }, { "cell_type": "code", - "execution_count": 52, + "execution_count": null, "metadata": { "id": "vX59u8CqEf4J" }, @@ -620,7 +625,7 @@ }, { "cell_type": "code", - "execution_count": 53, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -670,7 +675,7 @@ }, { "cell_type": "code", - "execution_count": 54, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -714,7 +719,7 @@ }, { "cell_type": "code", - "execution_count": 55, + "execution_count": null, "metadata": { "id": "mS62eVL9Ifmz" }, @@ -738,7 +743,7 @@ }, { "cell_type": "code", - "execution_count": 56, + "execution_count": null, "metadata": { "id": "bnrhac5Hh7y1" }, @@ -772,7 +777,7 @@ }, { "cell_type": "code", - "execution_count": 57, + "execution_count": null, "metadata": { "id": "pN3oF7-ostGE" }, @@ -792,7 +797,7 @@ }, { "cell_type": "code", - "execution_count": 58, + "execution_count": null, "metadata": { "id": "f1VnTuX3u_kL" }, @@ -810,7 +815,7 @@ }, { "cell_type": "code", - "execution_count": 59, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -844,7 +849,7 @@ }, { "cell_type": "code", - "execution_count": 60, + "execution_count": null, "metadata": { "id": "2jqd1jJt25Bj" }, @@ -875,7 +880,7 @@ }, { "cell_type": "code", - "execution_count": 61, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -925,7 +930,7 @@ }, { "cell_type": "code", - "execution_count": 62, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -979,7 +984,7 @@ }, { "cell_type": "code", - "execution_count": 63, + "execution_count": null, "metadata": { "id": "8v1N59p76zn0" }, @@ -999,7 +1004,7 @@ }, { "cell_type": "code", - "execution_count": 64, + "execution_count": null, "metadata": { "id": "a22kTvgk6_fJ" }, @@ -1021,7 +1026,7 @@ }, { "cell_type": "code", - "execution_count": 65, + "execution_count": null, "metadata": { "id": "NHrKatD_7HbH" }, @@ -1047,7 +1052,7 @@ }, { "cell_type": "code", - "execution_count": 66, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -1083,7 +1088,7 @@ }, { "cell_type": "code", - "execution_count": 67, + "execution_count": null, "metadata": { "id": "-zLJhogj7RL-" }, @@ -1109,7 +1114,7 @@ }, { "cell_type": "code", - "execution_count": 68, + "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" @@ -1145,7 +1150,7 @@ "source": [ "## Summary\n", "\n", - "This notebook explored efficient methods for loading data on a GPU with JAX, using libraries like PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets. You also learned GPU-specific optimizations, such as `device_put` for data transfer and memory management, to enhance training efficiency. Each methods offers unique benefits, helping you choose the best fit for your project needs." + "This notebook explored efficient methods for loading data on a GPU with JAX, using libraries such as PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets. You also learned GPU-specific optimizations, including using `device_put` for data transfer and managing GPU memory, to enhance training efficiency. Each method offers unique benefits, allowing you to choose the best approach based on your project requirements." ] } ], @@ -1153,7 +1158,6 @@ "accelerator": "GPU", "colab": { "gpuType": "T4", - "name": "data_loaders_on_gpu_with_jax.ipynb", "provenance": [] }, "jupytext": { diff --git a/docs/data_loaders_on_gpu_with_jax.md b/docs/data_loaders_on_gpu_with_jax.md index 4ec7487..a83ec4c 100644 --- a/docs/data_loaders_on_gpu_with_jax.md +++ b/docs/data_loaders_on_gpu_with_jax.md @@ -27,11 +27,16 @@ This tutorial explores different data loading strategies for using **JAX** on a You'll see how to use each of these libraries to efficiently load data for a simple image classification task using the MNIST dataset. -Compared to the [Data Loaders on CPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_cpu_with_jax.html), working with GPUs introduces opportunities for further optimization, such as transferring data to the GPU using `device_put`, leveraging larger batch sizes for faster processing, and addressing considerations like memory management. +Compared to [CPU-based loading](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_cpu_with_jax.html), working with a GPU introduces specific considerations like transferring data to the GPU using `device_put`, managing larger batch sizes for faster processing, and efficiently utilizing GPU memory. Unlike multi-device setups, this guide focuses on optimizing data handling for a single GPU. + + +If you're looking for CPU-specific data loading advice, see [Data Loaders on CPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_cpu_with_jax.html). + +If you're looking for a multi-device data loading strategy, see [Data Loaders on Multi-Device Setups](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_for_multi_device_setups_with_jax.html). +++ {"id": "-rsMgVtO6asW"} -### Import JAX API +## Import JAX API ```{code-cell} :id: tDJNQ6V-Dg5g @@ -43,7 +48,7 @@ from jax import grad, jit, vmap, random, device_put +++ {"id": "TsFdlkSZKp9S"} -### Checking GPU Availability for JAX +## Checking GPU Availability for JAX ```{code-cell} --- @@ -57,7 +62,7 @@ jax.devices() +++ {"id": "qyJ_WTghDnIc"} -### Setting Hyperparameters and Initializing Parameters +## Setting Hyperparameters and Initializing Parameters You'll define hyperparameters for your model and data loading, including layer sizes, learning rate, batch size, and the data directory. You'll also initialize the weights and biases for a fully-connected neural network. @@ -89,7 +94,7 @@ params = init_network_params(layer_sizes, random.PRNGKey(0)) +++ {"id": "rHLdqeI7D2WZ"} -### Model Prediction with Auto-Batching +## Model Prediction with Auto-Batching In this section, you'll define the `predict` function for your neural network. This function computes the output of the network for a single input image. @@ -120,7 +125,7 @@ batched_predict = vmap(predict, in_axes=(None, 0)) +++ {"id": "rLqfeORsERek"} -### Utility and Loss Functions +## Utility and Loss Functions You'll now define utility functions for: - One-hot encoding: Converts class indices to binary vectors. @@ -642,4 +647,4 @@ train_model(num_epochs, params, hf_training_generator) ## Summary -This notebook explored efficient methods for loading data on a GPU with JAX, using libraries like PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets. You also learned GPU-specific optimizations, such as `device_put` for data transfer and memory management, to enhance training efficiency. Each methods offers unique benefits, helping you choose the best fit for your project needs. +This notebook explored efficient methods for loading data on a GPU with JAX, using libraries such as PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets. You also learned GPU-specific optimizations, including using `device_put` for data transfer and managing GPU memory, to enhance training efficiency. Each method offers unique benefits, allowing you to choose the best approach based on your project requirements.