
Commit

Adding tutorial for data loaders on gpu with jax (jax-ml#109)
selamw1 committed Dec 5, 2024
1 parent 2bb53a7 commit 71614bd
Showing 6 changed files with 1,845 additions and 4 deletions.
2 changes: 2 additions & 0 deletions docs/source/conf.py
@@ -67,6 +67,7 @@
'JAX_transformer_text_classification.md',
'data_loaders_on_cpu_with_jax.md',
'data_loaders_for_multi_device_setups_with_jax.md',
+'data_loaders_on_gpu_with_jax.md',
]

suppress_warnings = [
@@ -105,4 +106,5 @@
'JAX_transformer_text_classification.ipynb',
'data_loaders_on_cpu_with_jax.ipynb',
'data_loaders_for_multi_device_setups_with_jax.ipynb',
+'data_loaders_on_gpu_with_jax.ipynb',
]
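For orientation, here is a hedged sketch of what these two lists in `docs/source/conf.py` typically are in a Sphinx project built with myst-nb. The variable names `exclude_patterns` and `nb_execution_excludepatterns` are assumptions, since the hunks above show only the list contents, not the assignments they belong to:

```python
# Hypothetical excerpt of a Sphinx conf.py (variable names assumed, not shown
# in the diff above): the Markdown half of each tutorial is kept out of the
# build, and the notebook half is kept out of execution at build time.

# Sphinx: sources matching these patterns are not built as standalone pages.
exclude_patterns = [
    'data_loaders_on_cpu_with_jax.md',
    'data_loaders_for_multi_device_setups_with_jax.md',
    'data_loaders_on_gpu_with_jax.md',      # entry added by this commit
]

# myst-nb: notebooks matching these patterns are rendered but not re-executed
# during the docs build (useful when they need hardware such as a GPU).
nb_execution_excludepatterns = [
    'data_loaders_on_cpu_with_jax.ipynb',
    'data_loaders_for_multi_device_setups_with_jax.ipynb',
    'data_loaders_on_gpu_with_jax.ipynb',   # entry added by this commit
]
```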
10 changes: 8 additions & 2 deletions docs/source/data_loaders_on_cpu_with_jax.ipynb
@@ -24,7 +24,13 @@
"- [**Grain**](https://github.com/google/grain)\n",
"- [**Hugging Face**](https://huggingface.co/docs/datasets/en/use_with_jax#data-loading)\n",
"\n",
-"You'll see how to use each of these libraries to efficiently load data for a simple image classification task using the MNIST dataset."
+"In this tutorial, you'll learn how to efficiently load data using these libraries for a simple image classification task based on the MNIST dataset.\n",
+"\n",
+"Compared to GPU or multi-device setups, CPU-based data loading is straightforward, as it avoids challenges like GPU memory management and data synchronization across devices. This makes it ideal for smaller-scale tasks or scenarios where data resides exclusively on the CPU.\n",
+"\n",
+"If you're looking for GPU-specific data loading advice, see [Data Loaders on GPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_gpu_with_jax.html).\n",
+"\n",
+"If you're looking for a multi-device data loading strategy, see [Data Loaders on Multi-Device Setups](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_for_multi_device_setups_with_jax.html)."
]
},
{
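To make the new introduction concrete (each of the listed libraries feeding a simple MNIST-style image classification task), here is a minimal, illustrative sketch built around a PyTorch `DataLoader`. It is not the tutorial's own code: `ToyMNIST` and `numpy_collate` are hypothetical stand-ins that avoid downloading the real dataset.

```python
# Minimal sketch, assuming the common pattern of collating batches as NumPy
# arrays so that JAX can consume them with jnp.asarray inside the training loop.
import numpy as np
import jax.numpy as jnp
from torch.utils.data import Dataset, DataLoader


class ToyMNIST(Dataset):
    """Stand-in for MNIST: random 28x28 float images with integer labels."""

    def __init__(self, n=1024):
        rng = np.random.default_rng(0)
        self.images = rng.random((n, 28, 28), dtype=np.float32)
        self.labels = rng.integers(0, 10, size=n)

    def __len__(self):
        return len(self.images)

    def __getitem__(self, idx):
        return self.images[idx], self.labels[idx]


def numpy_collate(batch):
    # Stack samples into NumPy arrays rather than torch.Tensors, keeping the
    # batches framework-agnostic and ready for JAX.
    images, labels = zip(*batch)
    return np.stack(images), np.asarray(labels)


loader = DataLoader(ToyMNIST(), batch_size=32, shuffle=True,
                    collate_fn=numpy_collate)

for images, labels in loader:
    images, labels = jnp.asarray(images), jnp.asarray(labels)
    # ...hand the batch to a jitted train_step(params, images, labels) here...
    break
```

TensorFlow Datasets, Grain, and Hugging Face Datasets can slot into the same loop shape, as long as they yield NumPy-compatible batches.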
@@ -1489,7 +1495,7 @@
"source": [
"## Summary\n",
"\n",
-"This notebook has guided you through efficient methods for loading data on a CPU when using JAX. You’ve learned how to leverage popular libraries such as PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets to streamline the data loading process for your machine learning tasks. Each of these methods offers unique advantages and considerations, allowing you to choose the best approach based on the specific needs of your project."
+"This notebook has introduced efficient strategies for data loading on a CPU with JAX, demonstrating how to integrate popular libraries like PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets. Each library offers distinct advantages, enabling you to streamline the data loading process for machine learning tasks. By understanding the strengths of these methods, you can select the approach that best suits your project's specific requirements."
]
}
],
10 changes: 8 additions & 2 deletions docs/source/data_loaders_on_cpu_with_jax.md
@@ -26,7 +26,13 @@ This tutorial explores different data loading strategies for using **JAX** on a
- [**Grain**](https://github.com/google/grain)
- [**Hugging Face**](https://huggingface.co/docs/datasets/en/use_with_jax#data-loading)

-You'll see how to use each of these libraries to efficiently load data for a simple image classification task using the MNIST dataset.
+In this tutorial, you'll learn how to efficiently load data using these libraries for a simple image classification task based on the MNIST dataset.
+
+Compared to GPU or multi-device setups, CPU-based data loading is straightforward, as it avoids challenges like GPU memory management and data synchronization across devices. This makes it ideal for smaller-scale tasks or scenarios where data resides exclusively on the CPU.
+
+If you're looking for GPU-specific data loading advice, see [Data Loaders on GPU](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_on_gpu_with_jax.html).
+
+If you're looking for a multi-device data loading strategy, see [Data Loaders on Multi-Device Setups](https://jax-ai-stack.readthedocs.io/en/latest/data_loaders_for_multi_device_setups_with_jax.html).

+++ {"id": "pEsb135zE-Jo"}

@@ -682,4 +688,4 @@ train_model(num_epochs, params, hf_training_generator)

## Summary

-This notebook has guided you through efficient methods for loading data on a CPU when using JAX. You’ve learned how to leverage popular libraries such as PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets to streamline the data loading process for your machine learning tasks. Each of these methods offers unique advantages and considerations, allowing you to choose the best approach based on the specific needs of your project.
+This notebook has introduced efficient strategies for data loading on a CPU with JAX, demonstrating how to integrate popular libraries like PyTorch DataLoader, TensorFlow Datasets, Grain, and Hugging Face Datasets. Each library offers distinct advantages, enabling you to streamline the data loading process for machine learning tasks. By understanding the strengths of these methods, you can select the approach that best suits your project's specific requirements.
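The new introduction (in both the notebook and its Markdown counterpart) points readers at GPU and multi-device guides. Relative to the CPU flow, the main extra step those setups involve is explicit device placement of each host batch; a minimal sketch using only core JAX APIs, not taken from either linked tutorial:

```python
# Minimal sketch: after a loader yields a host (NumPy) batch, GPU setups
# typically place it on the accelerator explicitly. On CPU this is effectively
# a no-op, since the data already lives where the computation runs.
import numpy as np
import jax

host_batch = np.random.rand(32, 28, 28).astype("float32")  # stand-in batch

device = jax.devices()[0]  # first available device: CPU here, GPU if present
device_batch = jax.device_put(host_batch, device)

print(device_batch.shape, device_batch.devices())
```

In real GPU pipelines this transfer is usually overlapped with computation, for example by prefetching the next batch, which is the kind of concern a GPU-focused data loading guide addresses.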
