From 3818ac3b2424cb75be0ce8fdad2351f6c3b8f066 Mon Sep 17 00:00:00 2001 From: 8bitmp3 <19637339+8bitmp3@users.noreply.github.com> Date: Fri, 8 Nov 2024 22:05:36 +0000 Subject: [PATCH] Update JAX AI Stack Getting Started (#54) --- docs/getting_started_with_jax_for_AI.ipynb | 152 +++++++++++---------- docs/getting_started_with_jax_for_AI.md | 110 ++++++++------- 2 files changed, 135 insertions(+), 127 deletions(-) diff --git a/docs/getting_started_with_jax_for_AI.ipynb b/docs/getting_started_with_jax_for_AI.ipynb index 0e09ab1..8cf1f72 100644 --- a/docs/getting_started_with_jax_for_AI.ipynb +++ b/docs/getting_started_with_jax_for_AI.ipynb @@ -6,7 +6,7 @@ "id": "AEQPh3NtawWA" }, "source": [ - "# Getting started with JAX\n", + "# Getting started with JAX for AI\n", "\n", "[![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/getting_started_with_jax_for_AI.ipynb)\n", "\n", @@ -21,7 +21,7 @@ "source": [ "## Who is this tutorial for?\n", "\n", - "This tutorial is for those who want to get started using the JAX AI stack to build and train neural network models. It assumes some familiarity with numerical computing in Python with [NumPy](http://numpy.org), and assumes some conceptual familiarity with defining, training, and evaluating machine learning models." + "This tutorial is for those who want to get started using JAX and JAX-based AI libraries - the JAX AI stack - to build and train a simple neural network model. [JAX](http://jax.readthedocs.io) is a Python library for hardware accelerator-oriented array computation and program transformation, and is the engine behind cutting-edge AI research and production models at Google, Google DeepMind, and beyond. This tutorial assumes some familiarity with numerical computing in Python with [NumPy](http://numpy.org), and assumes some conceptual familiarity with defining, training, and evaluating machine learning models." ] }, { @@ -32,14 +32,12 @@ "source": [ "## What does this tutorial cover?\n", "\n", - "JAX itself focuses on array-based computation, and is at the core of a growing ecosystem of domain-specific tools. This tutorial introduces part of that ecosystem that is useful for defininig and training AI models, including:\n", + "JAX focuses on [array-based](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) computation, and is at the core of a growing ecosystem of domain-specific tools. This tutorial introduces part of that JAX ecosystem designed for AI-related tasks, including:\n", "\n", - "- [flax](http://flax.readthedocs.io): a tool designed for defining and building\n", - " scalable neural networks using JAX.\n", - "- [optax](http://optax.readthedocs.io): a tool designed for high-performance\n", - " optimization of functions in JAX, including the loss functions used in neural network training.\n", + "- [Flax NNX](http://flax.readthedocs.io): A machine learning library designed for defining and building scalable neural networks using JAX.\n", + "- [Optax](http://optax.readthedocs.io): A high-performance function optimization library that comes with built-in optimizers and loss functions.\n", "\n", - "Once you've worked through this content, you may wish to visit http://jax.readthedocs.io/ for a deeper dive into the JAX library itself." + "Once you've worked through this content, you may wish to visit the [JAX documentation site](http://jax.readthedocs.io/) for a deeper dive into the core JAX concepts." ] }, { @@ -61,12 +59,12 @@ "source": [ "### Loading the data\n", "\n", - "JAX can work with a variety of data loaders, but for simplicity here we can use the well-known [scikit-learn `digits`](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_digits.html) dataset." + "JAX can work with a variety of data loaders, including [Grain](https://github.com/google/grain), [TensorFlow Datasets](https://github.com/tensorflow/datasets) and [TorchData](https://github.com/pytorch/data), but for simplicity this example uses the well-known [scikit-learn `digits`](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_digits.html) dataset." ] }, { "cell_type": "code", - "execution_count": 1, + "execution_count": null, "metadata": { "id": "hKhPLnNxfOHU", "outputId": "ac3508f0-ccc6-409b-c719-99a4b8f94bd6" @@ -95,12 +93,12 @@ "id": "lst3E34dgrLc" }, "source": [ - "This dataset consists of 8x8 pixelized images of hand-written digits along with labels, and we can visualize a handful of them this way:" + "This dataset consists of `8x8` pixelated images of hand-written digits and their corresponding labels. Let’s visualize a handful of them with [`matplotlib`](https://matplotlib.org/stable/tutorials/index):" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": null, "metadata": { "id": "Y8cMntSdfyyT", "outputId": "9343a558-cd8c-473c-c109-aa8015c7ae7e" @@ -135,13 +133,13 @@ "id": "Z3l45KgtfUUo" }, "source": [ - "Let's split these into a training and testing set, and convert these splits into JAX arrays which will be ready to feed into our model.\n", - "We'll make use of the `jax.numpy` module, which provides a familiar NumPy-style API around JAX operations:" + "Next, split the dataset into a training and testing set, and convert these splits into [`jax.Array`s](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) before you feed them into the model.\n", + "You’ll use the `jax.numpy` module, which provides a familiar NumPy-style API around JAX operations:" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": { "id": "6jrYisoPh6TL" }, @@ -153,7 +151,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": null, "metadata": { "id": "oMRcwKd4hqOo", "outputId": "0ad36290-397b-431d-eba2-ef114daf5ea6" @@ -181,16 +179,17 @@ "id": "JzrixENjifiq" }, "source": [ - "### Defining the flax model\n", + "### Defining the Flax model\n", "\n", - "We can now use the [Flax](http://flax.readthedocs.io) package to create a simple [Feedforward](https://en.wikipedia.org/wiki/Feedforward_neural_network) neural network with one hidden layer, and use a *scaled exponential linear unit* (SELU) activation function." + "You can now use [Flax NNX](http://flax.readthedocs.io) to create a simple [feed-forward](https://en.wikipedia.org/wiki/Feedforward_neural_network) neural network - subclassing [`flax.nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module) - with [`flax.nnx.Linear`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/linear.html#flax.nnx.Linear) layers with *scaled exponential linear unit* (SELU) activation function using the built-in [`flax.nnx.selu`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/activations.html#flax.nnx.selu):" ] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": null, "metadata": { - "id": "U77VMQwRjTfH" + "id": "U77VMQwRjTfH", + "outputId": "345fed7a-4455-4036-85ed-57e673a4de01" }, "outputs": [ { @@ -231,7 +230,7 @@ " self.layer3 = nnx.Linear(n_hidden, n_targets, rngs=rngs)\n", "\n", " def __call__(self, x):\n", - " x = x.reshape(x.shape[0], self.n_features) # flatten images\n", + " x = x.reshape(x.shape[0], self.n_features) # Flatten images.\n", " x = nnx.selu(self.layer1(x))\n", " x = nnx.selu(self.layer2(x))\n", " x = self.layer3(x)\n", @@ -250,14 +249,15 @@ "source": [ "### Training the model\n", "\n", - "With this model defined, we can now use the [optax](http://optax.readthedocs.io) package to train it.\n", - "First we need to decide on a loss function: since we have an output layer with each node corresponding to an integer label, an appropriate metric is [`optax.softmax_cross_entropy_with_integer_labels`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.losses.softmax_cross_entropy_with_integer_labels).\n", - "We can then define a training step based on this optimizer:" + "With the `SimpleNN` model created and instantiated, you can now choose the loss function and the optimizer with the [Optax](http://optax.readthedocs.io) package, and then define the training step function. Use:\n", + "- [`optax.softmax_cross_entropy_with_integer_labels`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.losses.softmax_cross_entropy_with_integer_labels) as the loss, as the output layer will have nodes corresponding to a handwritten integer label.\n", + "- [`optax.sgd`](https://optax.readthedocs.io/en/latest/api/optimizers.html#optax.sgd) for the stochastic gradient descent.\n", + "- [`flax.nnx.Optimizer`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/training/optimizer.html) to set the train state." ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "metadata": { "id": "QwRvFPkYl5b2" }, @@ -295,15 +295,19 @@ "id": "K2Tp-ym6sXEl" }, "source": [ - "Notice here the use of `nnx.jit` and `nnx.grad`, which are transformations built on `jax.jit` and `jax.grad`: `jit` is a Just In Time compilation transformation, and will cause the function to be passed to the XLA compiler for fast repeated execution. `grad` is a gradient transformation, and uses JAX's automatic differentiation for fast optimization of large networks.\n", - "We'll return to these transformations further below.\n", + "Notice here the use of [`flax.nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) and [`flax.nnx.grad`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.grad), which are [Flax NNX transformations](https://flax.readthedocs.io/en/latest/guides/transforms.html) built on `jax.jit` and `jax.grad` [transformations](https://jax.readthedocs.io/en/latest/key-concepts.html#transformations).\n", "\n", - "Now we define a training loop to repeatedly perform this train step over our training data, periodically printing the loss against the test set to monitor convergence:" + "- `jax.jit` is a [Just-In-Time compilation transformation](https://jax.readthedocs.io/en/latest/jit-compilation.html#jit-compilation), and will cause the function to be passed to the [XLA](https://openxla.org/xla) compiler for fast repeated execution.\n", + "- `jax.grad` is a [gradient transformation](https://jax.readthedocs.io/en/latest/automatic-differentiation.html) that uses JAX's automatic differentiation for fast optimization of large networks.\n", + "\n", + "You will return to these transformations later in the tutorial.\n", + "\n", + "Now that you have a training step function, define a training loop to repeatedly perform this training step over the training data, periodically printing the loss against the test set to monitor convergence:" ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": null, "metadata": { "id": "l9mukT0eqmsr", "outputId": "c6c7b2d6-8706-4bc3-d5a6-0396d7cfbf56" @@ -337,12 +341,12 @@ "id": "3sjOKxLDv8SS" }, "source": [ - "After 300 training epochs, our model appears to have converged to a target loss of `0.10`; lets look at what this implies for the accuracy of the labels for each image:" + "After 300 training epochs, your model should have converged to a target loss of around `0.10`. You can check what this implies for the accuracy of the labels for each image:" ] }, { "cell_type": "code", - "execution_count": 8, + "execution_count": null, "metadata": { "id": "6OmW0lVlsvJ1", "outputId": "f8d7849b-4242-48e7-8120-82e5574b18f3" @@ -371,13 +375,13 @@ "id": "vTKF3-CFwY50" }, "source": [ - "Our simple feed-forward network has achieved approximately 98% accuracy on the test set.\n", - "We can do a similar visualization as above to see some examples that the model predicted correctly (green) and incorrectly (red):" + "The simple feed-forward network has achieved approximately 98% accuracy on the test set.\n", + "You can do a similar visualization as above to review some examples that the model predicted correctly (in green) and incorrectly (in red):" ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": null, "metadata": { "id": "uinijfm-qXsP", "outputId": "632f6e98-1779-4492-c2f7-125499c5b55f" @@ -411,7 +415,7 @@ "id": "x7IIiVymuTRa" }, "source": [ - "We've just scraped the surface with Flax here; the package includes a number of useful APIs for tracking metrics during training; you can see these in action in the [MNIST tutorial](https://flax.readthedocs.io/en/latest/nnx/mnist_tutorial.html) on the Flax website." + "In this tutorial, you have just scraped the surface with JAX, Flax NNX, and Optax here. The Flax NNX package includes a number of useful APIs for tracking metrics during training, which are features in the [Flax MNIST tutorial](https://flax.readthedocs.io/en/latest/nnx/mnist_tutorial.html) on the Flax website." ] }, { @@ -420,23 +424,23 @@ "id": "5ZfGvXAiy2yr" }, "source": [ - "## JAX key features\n", + "## Key JAX features\n", "\n", - "The Flax neural network API demonstrated above takes advantage of a number of key JAX features, designed into the library from the ground up. In particular:\n", + "The Flax NNX neural network API demonstrated above takes advantage of a number of [key JAX features](https://jax.readthedocs.io/en/latest/key-concepts.html), designed into the library from the ground up. In particular:\n", "\n", - "- **JAX provides a familiar NumPy-like API for array computing**.\n", - " This means that when processing data and outputs, we can reach for APIs like [`jax.numpy.count_nonzero`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.count_nonzero.html), which mirror the familiar APIs of the NumPy package; in this case [`numpy.count_nonzero`](https://numpy.org/doc/stable/reference/generated/numpy.count_nonzero.html).\n", + "- **JAX provides a familiar NumPy-like API for array computing.**\n", + " This means that when processing data and outputs, you can reach for APIs like [`jax.numpy.count_nonzero`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.count_nonzero.html), which mirror the familiar APIs of the NumPy package; in this case [`numpy.count_nonzero`](https://numpy.org/doc/stable/reference/generated/numpy.count_nonzero.html).\n", "\n", "- **JAX provides just-in-time (JIT) compilation.**\n", - " This means that you can implement your code easily in Python, but count on fast compiled execution on CPU, GPU, and TPU backends via the [XLA](https://openxla.org/xla) compiler by wrapping your code with a simple [`jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) transformation.\n", + " This means that you can implement your code easily in Python, but count on fast compiled execution on CPU, GPU, and TPU backends via the [XLA](https://openxla.org/xla) compiler by wrapping your code with a simple [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) [transformation](https://jax.readthedocs.io/en/latest/jit-compilation.html).\n", "\n", "- **JAX provides automatic differentiation (autodiff).**\n", - " This means that when fitting models, `optax` and `flax` can compute closed-form gradient functions for fast optimization of models, using the [`grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html) transformation.\n", + " This means that when fitting models, `optax` and `flax` can compute closed-form gradient functions for fast optimization of models, using the [`jax.grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html) [transformation](https://jax.readthedocs.io/en/latest/automatic-differentiation.html).\n", "\n", "- **JAX provides automatic vectorization.**\n", - " We didn't use it directly above, but under the hood flax takes advantage of JAX's vectorized map ([`vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html)) to automatically convert loss and gradient functions to efficient batch-aware functions that are just as fast as hand-written versions. This makes JAX implementations simpler and less error-prone.\n", + " While you didn't get to use this directly in the code before, but under the hood flax takes advantage of [JAX's vectorized map](https://jax.readthedocs.io/en/latest/automatic-vectorization.html) ([`jax.vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html)) to automatically convert loss and gradient functions to efficient batch-aware functions that are just as fast as hand-written versions. This makes JAX implementations simpler and less error-prone.\n", "\n", - "We'll briefly show further examples of these in the following sections." + "You will learn more about these features through brief examples in the following sections." ] }, { @@ -447,15 +451,15 @@ "source": [ "### JAX NumPy interface\n", "\n", - "The foundational array computing package in Python is NumPy, and JAX provides a matching API via the `jax.numpy` subpackage.\n", - "Additionally, JAX arrays behave much like NumPy arrays in their attributes, and in terms of [indexing](https://numpy.org/doc/stable/user/basics.indexing.html) and [broadcasting](https://numpy.org/doc/stable/user/basics.broadcasting.html) semantics.\n", + "The foundational array computing package in Python is NumPy, and [JAX provides](https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#jax-vs-numpy) [a matching API](https://jax.readthedocs.io/en/latest/quickstart.html#jax-as-numpy) via the [`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html) subpackage.\n", + "Additionally, [JAX arrays](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) ([`jax.Array`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html#jax.Array)) behave much like NumPy arrays in their attributes, and in terms of [indexing](https://numpy.org/doc/stable/user/basics.indexing.html) and [broadcasting](https://numpy.org/doc/stable/user/basics.broadcasting.html) semantics.\n", "\n", - "Above we used the built-in `flax.nnx.selu` implementation, but we could instead implement our own version with JAX's NumPy API:" + "In the previous example, you used Flax's built-in `flax.nnx.selu` implementation. You can also implement SeLU using JAX's NumPy API as follows:" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": null, "metadata": { "id": "2u2femxe2EzA", "outputId": "89b9f9b0-5631-405c-f4d8-2198593d0d50" @@ -485,7 +489,7 @@ "id": "H9o_a859JLY9" }, "source": [ - "Despite the broad similarities, be aware that JAX does have some well-motivated differences from NumPy that you can read about in [JAX – the sharp bits](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html)." + "Despite the broad similarities, be aware that JAX does have some well-motivated differences from NumPy that you can read about in [🔪 JAX – The Sharp Bits 🔪](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) on the JAX site." ] }, { @@ -496,15 +500,15 @@ "source": [ "### Just-in-time compilation\n", "\n", - "JAX is built on the XLA compiler, and allows sequences of operations to be Just-in-time (JIT) compiled using the `jax.jit` transformation.\n", - "In the example above, we used the similar `nnx.jit` API, which has some special handling for Flax objects, for speed in our neural network training.\n", + "As mentioned before, JAX is built on the [XLA](https://openxla.org/xla) compiler, and allows sequences of operations to be just-in-time (JIT) compiled using the [`jax.jit` transformation](https://jax.readthedocs.io/en/latest/jit-compilation.html).\n", + "In the neural network example above, you used the [similar](https://flax.readthedocs.io/en/latest/guides/transforms.html) [`flax.nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) transform, which has some special handling for Flax NNX objects for speed in neural network training.\n", "\n", - "Returning to our `selu` function, we can create a JIT-compiled version this way:" + "Returning to the previously defined `selu` function in JAX, you can create a `jax.jit`-compiled version this way:" ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": null, "metadata": { "id": "-Chp8yCjQaFY" }, @@ -525,7 +529,7 @@ }, { "cell_type": "code", - "execution_count": 12, + "execution_count": null, "metadata": { "id": "uHeJXgKURL6q", "outputId": "dfc5a602-2b28-4863-a852-38f8fe6aaab4" @@ -553,12 +557,12 @@ "id": "WWwD0NmzRLP8" }, "source": [ - "We can use IPython's `%timeit` magic to see the speedup (note the use of `.block_until_ready()`, which we use to account for JAX's [asynchronous dispatch](https://jax.readthedocs.io/en/latest/async_dispatch.html)):" + "You can use IPython's `%timeit` magic to observe the speedup (note the use of [`jax.block_until_ready()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.block_until_ready.html#jax.block_until_ready), which you need to use to account for JAX's [asynchronous dispatch](https://jax.readthedocs.io/en/latest/async_dispatch.html)):" ] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": null, "metadata": { "id": "SzU_0NU5Jq_W", "outputId": "dba1ee6b-32f8-4429-a147-b6d4f4e6f0ff" @@ -578,7 +582,7 @@ }, { "cell_type": "code", - "execution_count": 14, + "execution_count": null, "metadata": { "id": "QOu7wo7UQ07v", "outputId": "bd91aaa2-d367-47e0-eb17-a90658de2d14" @@ -602,7 +606,7 @@ "id": "1ST-uLL9JqzB" }, "source": [ - "For this computation, running on CPU, JIT compilation gives an order of magnitude speedup.\n", + "For this computation, running on CPU, `jax.jit` compilation gives an order of magnitude speedup.\n", "JAX's documentation has more discussion of JIT compilation at [Just-in-time compilation](https://jax.readthedocs.io/en/latest/jit-compilation.html)." ] }, @@ -612,16 +616,16 @@ "id": "XFWR0tYjLYcj" }, "source": [ - "### Automatic differentiation\n", + "### Automatic differentiation (autodiff)\n", "\n", - "For efficient optimization of neural network models, fast gradient computations are essential. JAX enables this via its automatic differentiation transformations like [`jax.grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html), which computes a closed-form gradient of a JAX function. In the example above, we used the similar `nnx.grad` function, which has special handling for `flax.nnx` objects.\n", + "For efficient optimization of neural network models, fast gradient computations are essential. JAX enables this via its [automatic differentiation](https://jax.readthedocs.io/en/latest/automatic-differentiation.html) transformations like [`jax.grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html), which computes a closed-form gradient of a JAX function. In the neural network example, you used the [similar](https://flax.readthedocs.io/en/latest/guides/transforms.html) [`flax.nnx.grad`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.grad) function, which has special handling for [`flax.nnx`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/index.html) objects.\n", "\n", "Here's how to compute the gradient of a function with `jax.grad`:" ] }, { "cell_type": "code", - "execution_count": 15, + "execution_count": null, "metadata": { "id": "JtsPYnKbOtZt", "outputId": "834c31f8-ed1f-46ae-a827-e0b7faa52181" @@ -649,12 +653,12 @@ "id": "1P-UEh9VO94k" }, "source": [ - "We can briefly check with a finite-difference approximation that this is giving the expected value:" + "You can briefly check with a finite-difference approximation that this is giving the expected value:" ] }, { "cell_type": "code", - "execution_count": 16, + "execution_count": null, "metadata": { "id": "1gOc4FyzPDUC", "outputId": "95053e89-048d-4331-b898-079818e23dae" @@ -682,7 +686,7 @@ "id": "pkQW2Hd_bPSd" }, "source": [ - "Importantly, the automatic differentiation approach is both more accurate and more efficient than computing numerical gradients. JAX's documentation has more discussion of autodiff at [Automatic differentiation](https://jax.readthedocs.io/en/latest/automatic-differentiation.html)." + "Importantly, the automatic differentiation approach is both more accurate and efficient than computing numerical gradients. JAX's documentation has more discussion of autodiff at [Automatic differentiation](https://jax.readthedocs.io/en/latest/automatic-differentiation.html) and [Advanced automatic differentiation](https://jax.readthedocs.io/en/latest/advanced-autodiff.html)." ] }, { @@ -693,14 +697,14 @@ "source": [ "### Automatic vectorization\n", "\n", - "In our training loop above, we defined our loss function in terms of a single input data vector of shape `n_features`, but trained our model by passing batches of data (of shape `[n_samples, n_features]`). Rather than requiring a naive and slow loop over batches in Flax & Optax internals, they instead use JAX's automatic vectorization via the `jax.vmap` transformation to construct a batched version of the kernel automatically.\n", + "In the training loop example earlier, you defined the loss function in terms of a single input data vector of shape `n_features` but trained the model by passing batches of data (of shape `[n_samples, n_features]`). Rather than requiring a naive and slow loop over batches in Flax and Optax internals, they instead use JAX's [automatic vectorization](https://jax.readthedocs.io/en/latest/automatic-vectorization.html) via the [`jax.vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax.vmap) transformation to construct a batched version of the kernel automatically.\n", "\n", "Consider a simple loss function that looks like this:" ] }, { "cell_type": "code", - "execution_count": 17, + "execution_count": null, "metadata": { "id": "OuSSCpxzdWw_" }, @@ -716,12 +720,12 @@ "id": "lOg9IWlPddfE" }, "source": [ - "We can evaluate it on a single data vector this way:" + "You can evaluate it on a single data vector this way:" ] }, { "cell_type": "code", - "execution_count": 18, + "execution_count": null, "metadata": { "id": "sYlEtbxedngb", "outputId": "39030fb7-feee-4da1-ef5d-54cd86ad8dfb" @@ -750,12 +754,12 @@ "id": "STit-syzk59F" }, "source": [ - "But if we attempt to evaluate it on a batch of vectors, it does not correctly return a batch of 4 losses:" + "But if you attempt to evaluate it on a batch of vectors, it does not correctly return a batch of 4 losses:" ] }, { "cell_type": "code", - "execution_count": 19, + "execution_count": null, "metadata": { "id": "1LFQX3zGlCil", "outputId": "a12c4d75-2d94-4341-e9ca-915a33f1278e" @@ -783,17 +787,17 @@ "id": "Qc3Kwe2HlhpA" }, "source": [ - "The problem is that our loss function is not batch-aware. Without automatic vectorization, there are two ways you can address this:\n", + "The problem is that this loss function is not batch-aware. Without automatic vectorization, there are two ways you can address this:\n", "\n", "1. Re-write your loss function by hand to operate on batched data; however, as functions become more complicated, this becomes difficult and error-prone.\n", "2. Naively loop over unbatched calls to your original function; however, this is easy to code, but can be slow because it doesn't take advantage of vectorized compute.\n", "\n", - "The `jax.vmap` transformation offers a third way: it automatically transforms your original function into a batch-aware version, so you get the speed of option 1 with the ease of option 2:" + "The [`jax.vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax.vmap) [transformation](https://jax.readthedocs.io/en/latest/automatic-vectorization.html) offers a third way: it automatically transforms your original function into a batch-aware version, so you get the speed of option 1 with the ease of option 2:" ] }, { "cell_type": "code", - "execution_count": 20, + "execution_count": null, "metadata": { "id": "Y2Sa458OoRVL", "outputId": "d1d8295b-40d3-477a-e5d8-b2d6f28ad803" @@ -821,9 +825,9 @@ "id": "6A8L1QDFogKd" }, "source": [ - "In our neural network example above, both `flax` and `optax` make use of JAX's `vmap` to allow for efficient batched computations over our unbatched loss function.\n", + "In the neural network example earlier, both `flax` and `optax` make use of JAX's `vmap` to allow for efficient batched computations over our unbatched loss function.\n", "\n", - "JAX's documentation has more discussion of autodiff at [Automatic vectorization](https://jax.readthedocs.io/en/latest/automatic-vectorization.html)." + "JAX's documentation has more discussion of automatic vectorization at [Automatic vectorization](https://jax.readthedocs.io/en/latest/automatic-vectorization.html)." ] } ], diff --git a/docs/getting_started_with_jax_for_AI.md b/docs/getting_started_with_jax_for_AI.md index 75ba100..cb08145 100644 --- a/docs/getting_started_with_jax_for_AI.md +++ b/docs/getting_started_with_jax_for_AI.md @@ -13,7 +13,7 @@ kernelspec: +++ {"id": "AEQPh3NtawWA"} -# Getting started with JAX +# Getting started with JAX for AI [![Open in Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/getting_started_with_jax_for_AI.ipynb) @@ -23,20 +23,18 @@ kernelspec: ## Who is this tutorial for? -This tutorial is for those who want to get started using the JAX AI stack to build and train neural network models. It assumes some familiarity with numerical computing in Python with [NumPy](http://numpy.org), and assumes some conceptual familiarity with defining, training, and evaluating machine learning models. +This tutorial is for those who want to get started using JAX and JAX-based AI libraries - the JAX AI stack - to build and train a simple neural network model. [JAX](http://jax.readthedocs.io) is a Python library for hardware accelerator-oriented array computation and program transformation, and is the engine behind cutting-edge AI research and production models at Google, Google DeepMind, and beyond. This tutorial assumes some familiarity with numerical computing in Python with [NumPy](http://numpy.org), and assumes some conceptual familiarity with defining, training, and evaluating machine learning models. +++ {"id": "1Y92oUSGeoRz"} ## What does this tutorial cover? -JAX itself focuses on array-based computation, and is at the core of a growing ecosystem of domain-specific tools. This tutorial introduces part of that ecosystem that is useful for defininig and training AI models, including: +JAX focuses on [array-based](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) computation, and is at the core of a growing ecosystem of domain-specific tools. This tutorial introduces part of that JAX ecosystem designed for AI-related tasks, including: -- [flax](http://flax.readthedocs.io): a tool designed for defining and building - scalable neural networks using JAX. -- [optax](http://optax.readthedocs.io): a tool designed for high-performance - optimization of functions in JAX, including the loss functions used in neural network training. +- [Flax NNX](http://flax.readthedocs.io): A machine learning library designed for defining and building scalable neural networks using JAX. +- [Optax](http://optax.readthedocs.io): A high-performance function optimization library that comes with built-in optimizers and loss functions. -Once you've worked through this content, you may wish to visit http://jax.readthedocs.io/ for a deeper dive into the JAX library itself. +Once you've worked through this content, you may wish to visit the [JAX documentation site](http://jax.readthedocs.io/) for a deeper dive into the core JAX concepts. +++ {"id": "z7sAr0sderhh"} @@ -48,7 +46,7 @@ We'll start with a very quick example of what it looks like to use JAX with the ### Loading the data -JAX can work with a variety of data loaders, but for simplicity here we can use the well-known [scikit-learn `digits`](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_digits.html) dataset. +JAX can work with a variety of data loaders, including [Grain](https://github.com/google/grain), [TensorFlow Datasets](https://github.com/tensorflow/datasets) and [TorchData](https://github.com/pytorch/data), but for simplicity this example uses the well-known [scikit-learn `digits`](https://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_digits.html) dataset. ```{code-cell} :id: hKhPLnNxfOHU @@ -63,7 +61,7 @@ print(f"{digits.target.shape=}") +++ {"id": "lst3E34dgrLc"} -This dataset consists of 8x8 pixelized images of hand-written digits along with labels, and we can visualize a handful of them this way: +This dataset consists of `8x8` pixelated images of hand-written digits and their corresponding labels. Let’s visualize a handful of them with [`matplotlib`](https://matplotlib.org/stable/tutorials/index): ```{code-cell} :id: Y8cMntSdfyyT @@ -82,8 +80,8 @@ for i, ax in enumerate(axes.flat): +++ {"id": "Z3l45KgtfUUo"} -Let's split these into a training and testing set, and convert these splits into JAX arrays which will be ready to feed into our model. -We'll make use of the `jax.numpy` module, which provides a familiar NumPy-style API around JAX operations: +Next, split the dataset into a training and testing set, and convert these splits into [`jax.Array`s](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) before you feed them into the model. +You’ll use the `jax.numpy` module, which provides a familiar NumPy-style API around JAX operations: ```{code-cell} :id: 6jrYisoPh6TL @@ -104,12 +102,13 @@ print(f"{images_test.shape=} {label_test.shape=}") +++ {"id": "JzrixENjifiq"} -### Defining the flax model +### Defining the Flax model -We can now use the [Flax](http://flax.readthedocs.io) package to create a simple [Feedforward](https://en.wikipedia.org/wiki/Feedforward_neural_network) neural network with one hidden layer, and use a *scaled exponential linear unit* (SELU) activation function. +You can now use [Flax NNX](http://flax.readthedocs.io) to create a simple [feed-forward](https://en.wikipedia.org/wiki/Feedforward_neural_network) neural network - subclassing [`flax.nnx.Module`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/module.html#flax.nnx.Module) - with [`flax.nnx.Linear`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/linear.html#flax.nnx.Linear) layers with *scaled exponential linear unit* (SELU) activation function using the built-in [`flax.nnx.selu`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/activations.html#flax.nnx.selu): ```{code-cell} :id: U77VMQwRjTfH +:outputId: 345fed7a-4455-4036-85ed-57e673a4de01 from flax import nnx @@ -123,7 +122,7 @@ class SimpleNN(nnx.Module): self.layer3 = nnx.Linear(n_hidden, n_targets, rngs=rngs) def __call__(self, x): - x = x.reshape(x.shape[0], self.n_features) # flatten images + x = x.reshape(x.shape[0], self.n_features) # Flatten images. x = nnx.selu(self.layer1(x)) x = nnx.selu(self.layer2(x)) x = self.layer3(x) @@ -138,9 +137,10 @@ nnx.display(model) # Interactive display if penzai is installed. ### Training the model -With this model defined, we can now use the [optax](http://optax.readthedocs.io) package to train it. -First we need to decide on a loss function: since we have an output layer with each node corresponding to an integer label, an appropriate metric is [`optax.softmax_cross_entropy_with_integer_labels`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.losses.softmax_cross_entropy_with_integer_labels). -We can then define a training step based on this optimizer: +With the `SimpleNN` model created and instantiated, you can now choose the loss function and the optimizer with the [Optax](http://optax.readthedocs.io) package, and then define the training step function. Use: +- [`optax.softmax_cross_entropy_with_integer_labels`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.losses.softmax_cross_entropy_with_integer_labels) as the loss, as the output layer will have nodes corresponding to a handwritten integer label. +- [`optax.sgd`](https://optax.readthedocs.io/en/latest/api/optimizers.html#optax.sgd) for the stochastic gradient descent. +- [`flax.nnx.Optimizer`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/training/optimizer.html) to set the train state. ```{code-cell} :id: QwRvFPkYl5b2 @@ -173,10 +173,14 @@ def train_step( +++ {"id": "K2Tp-ym6sXEl"} -Notice here the use of `nnx.jit` and `nnx.grad`, which are transformations built on `jax.jit` and `jax.grad`: `jit` is a Just In Time compilation transformation, and will cause the function to be passed to the XLA compiler for fast repeated execution. `grad` is a gradient transformation, and uses JAX's automatic differentiation for fast optimization of large networks. -We'll return to these transformations further below. +Notice here the use of [`flax.nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) and [`flax.nnx.grad`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.grad), which are [Flax NNX transformations](https://flax.readthedocs.io/en/latest/guides/transforms.html) built on `jax.jit` and `jax.grad` [transformations](https://jax.readthedocs.io/en/latest/key-concepts.html#transformations). -Now we define a training loop to repeatedly perform this train step over our training data, periodically printing the loss against the test set to monitor convergence: +- `jax.jit` is a [Just-In-Time compilation transformation](https://jax.readthedocs.io/en/latest/jit-compilation.html#jit-compilation), and will cause the function to be passed to the [XLA](https://openxla.org/xla) compiler for fast repeated execution. +- `jax.grad` is a [gradient transformation](https://jax.readthedocs.io/en/latest/automatic-differentiation.html) that uses JAX's automatic differentiation for fast optimization of large networks. + +You will return to these transformations later in the tutorial. + +Now that you have a training step function, define a training loop to repeatedly perform this training step over the training data, periodically printing the loss against the test set to monitor convergence: ```{code-cell} :id: l9mukT0eqmsr @@ -191,7 +195,7 @@ for i in range(301): # 300 training epochs +++ {"id": "3sjOKxLDv8SS"} -After 300 training epochs, our model appears to have converged to a target loss of `0.10`; lets look at what this implies for the accuracy of the labels for each image: +After 300 training epochs, your model should have converged to a target loss of around `0.10`. You can check what this implies for the accuracy of the labels for each image: ```{code-cell} :id: 6OmW0lVlsvJ1 @@ -207,8 +211,8 @@ print(f"{num_matches} labels match out of {num_total}:" +++ {"id": "vTKF3-CFwY50"} -Our simple feed-forward network has achieved approximately 98% accuracy on the test set. -We can do a similar visualization as above to see some examples that the model predicted correctly (green) and incorrectly (red): +The simple feed-forward network has achieved approximately 98% accuracy on the test set. +You can do a similar visualization as above to review some examples that the model predicted correctly (in green) and incorrectly (in red): ```{code-cell} :id: uinijfm-qXsP @@ -226,36 +230,36 @@ for i, ax in enumerate(axes.flat): +++ {"id": "x7IIiVymuTRa"} -We've just scraped the surface with Flax here; the package includes a number of useful APIs for tracking metrics during training; you can see these in action in the [MNIST tutorial](https://flax.readthedocs.io/en/latest/nnx/mnist_tutorial.html) on the Flax website. +In this tutorial, you have just scraped the surface with JAX, Flax NNX, and Optax here. The Flax NNX package includes a number of useful APIs for tracking metrics during training, which are features in the [Flax MNIST tutorial](https://flax.readthedocs.io/en/latest/nnx/mnist_tutorial.html) on the Flax website. +++ {"id": "5ZfGvXAiy2yr"} -## JAX key features +## Key JAX features -The Flax neural network API demonstrated above takes advantage of a number of key JAX features, designed into the library from the ground up. In particular: +The Flax NNX neural network API demonstrated above takes advantage of a number of [key JAX features](https://jax.readthedocs.io/en/latest/key-concepts.html), designed into the library from the ground up. In particular: -- **JAX provides a familiar NumPy-like API for array computing**. - This means that when processing data and outputs, we can reach for APIs like [`jax.numpy.count_nonzero`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.count_nonzero.html), which mirror the familiar APIs of the NumPy package; in this case [`numpy.count_nonzero`](https://numpy.org/doc/stable/reference/generated/numpy.count_nonzero.html). +- **JAX provides a familiar NumPy-like API for array computing.** + This means that when processing data and outputs, you can reach for APIs like [`jax.numpy.count_nonzero`](https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.count_nonzero.html), which mirror the familiar APIs of the NumPy package; in this case [`numpy.count_nonzero`](https://numpy.org/doc/stable/reference/generated/numpy.count_nonzero.html). - **JAX provides just-in-time (JIT) compilation.** - This means that you can implement your code easily in Python, but count on fast compiled execution on CPU, GPU, and TPU backends via the [XLA](https://openxla.org/xla) compiler by wrapping your code with a simple [`jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) transformation. + This means that you can implement your code easily in Python, but count on fast compiled execution on CPU, GPU, and TPU backends via the [XLA](https://openxla.org/xla) compiler by wrapping your code with a simple [`jax.jit`](https://jax.readthedocs.io/en/latest/_autosummary/jax.jit.html) [transformation](https://jax.readthedocs.io/en/latest/jit-compilation.html). - **JAX provides automatic differentiation (autodiff).** - This means that when fitting models, `optax` and `flax` can compute closed-form gradient functions for fast optimization of models, using the [`grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html) transformation. + This means that when fitting models, `optax` and `flax` can compute closed-form gradient functions for fast optimization of models, using the [`jax.grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html) [transformation](https://jax.readthedocs.io/en/latest/automatic-differentiation.html). - **JAX provides automatic vectorization.** - We didn't use it directly above, but under the hood flax takes advantage of JAX's vectorized map ([`vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html)) to automatically convert loss and gradient functions to efficient batch-aware functions that are just as fast as hand-written versions. This makes JAX implementations simpler and less error-prone. + While you didn't get to use this directly in the code before, but under the hood flax takes advantage of [JAX's vectorized map](https://jax.readthedocs.io/en/latest/automatic-vectorization.html) ([`jax.vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html)) to automatically convert loss and gradient functions to efficient batch-aware functions that are just as fast as hand-written versions. This makes JAX implementations simpler and less error-prone. -We'll briefly show further examples of these in the following sections. +You will learn more about these features through brief examples in the following sections. +++ {"id": "ZjneGfjy2Ef1"} ### JAX NumPy interface -The foundational array computing package in Python is NumPy, and JAX provides a matching API via the `jax.numpy` subpackage. -Additionally, JAX arrays behave much like NumPy arrays in their attributes, and in terms of [indexing](https://numpy.org/doc/stable/user/basics.indexing.html) and [broadcasting](https://numpy.org/doc/stable/user/basics.broadcasting.html) semantics. +The foundational array computing package in Python is NumPy, and [JAX provides](https://jax.readthedocs.io/en/latest/notebooks/thinking_in_jax.html#jax-vs-numpy) [a matching API](https://jax.readthedocs.io/en/latest/quickstart.html#jax-as-numpy) via the [`jax.numpy`](https://jax.readthedocs.io/en/latest/jax.numpy.html) subpackage. +Additionally, [JAX arrays](https://jax.readthedocs.io/en/latest/key-concepts.html#jax-arrays-jax-array) ([`jax.Array`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Array.html#jax.Array)) behave much like NumPy arrays in their attributes, and in terms of [indexing](https://numpy.org/doc/stable/user/basics.indexing.html) and [broadcasting](https://numpy.org/doc/stable/user/basics.broadcasting.html) semantics. -Above we used the built-in `flax.nnx.selu` implementation, but we could instead implement our own version with JAX's NumPy API: +In the previous example, you used Flax's built-in `flax.nnx.selu` implementation. You can also implement SeLU using JAX's NumPy API as follows: ```{code-cell} :id: 2u2femxe2EzA @@ -272,16 +276,16 @@ print(selu(x)) +++ {"id": "H9o_a859JLY9"} -Despite the broad similarities, be aware that JAX does have some well-motivated differences from NumPy that you can read about in [JAX – the sharp bits](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html). +Despite the broad similarities, be aware that JAX does have some well-motivated differences from NumPy that you can read about in [🔪 JAX – The Sharp Bits 🔪](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html) on the JAX site. +++ {"id": "LnDgHRBsJrYL"} ### Just-in-time compilation -JAX is built on the XLA compiler, and allows sequences of operations to be Just-in-time (JIT) compiled using the `jax.jit` transformation. -In the example above, we used the similar `nnx.jit` API, which has some special handling for Flax objects, for speed in our neural network training. +As mentioned before, JAX is built on the [XLA](https://openxla.org/xla) compiler, and allows sequences of operations to be just-in-time (JIT) compiled using the [`jax.jit` transformation](https://jax.readthedocs.io/en/latest/jit-compilation.html). +In the neural network example above, you used the [similar](https://flax.readthedocs.io/en/latest/guides/transforms.html) [`flax.nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit) transform, which has some special handling for Flax NNX objects for speed in neural network training. -Returning to our `selu` function, we can create a JIT-compiled version this way: +Returning to the previously defined `selu` function in JAX, you can create a `jax.jit`-compiled version this way: ```{code-cell} :id: -Chp8yCjQaFY @@ -304,7 +308,7 @@ jnp.allclose(selu(x), selu_jit(x)) # results match +++ {"id": "WWwD0NmzRLP8"} -We can use IPython's `%timeit` magic to see the speedup (note the use of `.block_until_ready()`, which we use to account for JAX's [asynchronous dispatch](https://jax.readthedocs.io/en/latest/async_dispatch.html)): +You can use IPython's `%timeit` magic to observe the speedup (note the use of [`jax.block_until_ready()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.block_until_ready.html#jax.block_until_ready), which you need to use to account for JAX's [asynchronous dispatch](https://jax.readthedocs.io/en/latest/async_dispatch.html)): ```{code-cell} :id: SzU_0NU5Jq_W @@ -322,14 +326,14 @@ We can use IPython's `%timeit` magic to see the speedup (note the use of `.block +++ {"id": "1ST-uLL9JqzB"} -For this computation, running on CPU, JIT compilation gives an order of magnitude speedup. +For this computation, running on CPU, `jax.jit` compilation gives an order of magnitude speedup. JAX's documentation has more discussion of JIT compilation at [Just-in-time compilation](https://jax.readthedocs.io/en/latest/jit-compilation.html). +++ {"id": "XFWR0tYjLYcj"} -### Automatic differentiation +### Automatic differentiation (autodiff) -For efficient optimization of neural network models, fast gradient computations are essential. JAX enables this via its automatic differentiation transformations like [`jax.grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html), which computes a closed-form gradient of a JAX function. In the example above, we used the similar `nnx.grad` function, which has special handling for `flax.nnx` objects. +For efficient optimization of neural network models, fast gradient computations are essential. JAX enables this via its [automatic differentiation](https://jax.readthedocs.io/en/latest/automatic-differentiation.html) transformations like [`jax.grad`](https://jax.readthedocs.io/en/latest/_autosummary/jax.grad.html), which computes a closed-form gradient of a JAX function. In the neural network example, you used the [similar](https://flax.readthedocs.io/en/latest/guides/transforms.html) [`flax.nnx.grad`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.grad) function, which has special handling for [`flax.nnx`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/index.html) objects. Here's how to compute the gradient of a function with `jax.grad`: @@ -343,7 +347,7 @@ jax.grad(selu)(x) +++ {"id": "1P-UEh9VO94k"} -We can briefly check with a finite-difference approximation that this is giving the expected value: +You can briefly check with a finite-difference approximation that this is giving the expected value: ```{code-cell} :id: 1gOc4FyzPDUC @@ -355,13 +359,13 @@ eps = 1E-3 +++ {"id": "pkQW2Hd_bPSd"} -Importantly, the automatic differentiation approach is both more accurate and more efficient than computing numerical gradients. JAX's documentation has more discussion of autodiff at [Automatic differentiation](https://jax.readthedocs.io/en/latest/automatic-differentiation.html). +Importantly, the automatic differentiation approach is both more accurate and efficient than computing numerical gradients. JAX's documentation has more discussion of autodiff at [Automatic differentiation](https://jax.readthedocs.io/en/latest/automatic-differentiation.html) and [Advanced automatic differentiation](https://jax.readthedocs.io/en/latest/advanced-autodiff.html). +++ {"id": "xsKyfRDNbj2y"} ### Automatic vectorization -In our training loop above, we defined our loss function in terms of a single input data vector of shape `n_features`, but trained our model by passing batches of data (of shape `[n_samples, n_features]`). Rather than requiring a naive and slow loop over batches in Flax & Optax internals, they instead use JAX's automatic vectorization via the `jax.vmap` transformation to construct a batched version of the kernel automatically. +In the training loop example earlier, you defined the loss function in terms of a single input data vector of shape `n_features` but trained the model by passing batches of data (of shape `[n_samples, n_features]`). Rather than requiring a naive and slow loop over batches in Flax and Optax internals, they instead use JAX's [automatic vectorization](https://jax.readthedocs.io/en/latest/automatic-vectorization.html) via the [`jax.vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax.vmap) transformation to construct a batched version of the kernel automatically. Consider a simple loss function that looks like this: @@ -374,7 +378,7 @@ def loss(x: jax.Array, x0: jax.Array): +++ {"id": "lOg9IWlPddfE"} -We can evaluate it on a single data vector this way: +You can evaluate it on a single data vector this way: ```{code-cell} :id: sYlEtbxedngb @@ -387,7 +391,7 @@ loss(x, x0) +++ {"id": "STit-syzk59F"} -But if we attempt to evaluate it on a batch of vectors, it does not correctly return a batch of 4 losses: +But if you attempt to evaluate it on a batch of vectors, it does not correctly return a batch of 4 losses: ```{code-cell} :id: 1LFQX3zGlCil @@ -399,12 +403,12 @@ loss(batched_x, x0) # wrong! +++ {"id": "Qc3Kwe2HlhpA"} -The problem is that our loss function is not batch-aware. Without automatic vectorization, there are two ways you can address this: +The problem is that this loss function is not batch-aware. Without automatic vectorization, there are two ways you can address this: 1. Re-write your loss function by hand to operate on batched data; however, as functions become more complicated, this becomes difficult and error-prone. 2. Naively loop over unbatched calls to your original function; however, this is easy to code, but can be slow because it doesn't take advantage of vectorized compute. -The `jax.vmap` transformation offers a third way: it automatically transforms your original function into a batch-aware version, so you get the speed of option 1 with the ease of option 2: +The [`jax.vmap`](https://jax.readthedocs.io/en/latest/_autosummary/jax.vmap.html#jax.vmap) [transformation](https://jax.readthedocs.io/en/latest/automatic-vectorization.html) offers a third way: it automatically transforms your original function into a batch-aware version, so you get the speed of option 1 with the ease of option 2: ```{code-cell} :id: Y2Sa458OoRVL @@ -416,6 +420,6 @@ loss_batched(batched_x, x0) +++ {"id": "6A8L1QDFogKd"} -In our neural network example above, both `flax` and `optax` make use of JAX's `vmap` to allow for efficient batched computations over our unbatched loss function. +In the neural network example earlier, both `flax` and `optax` make use of JAX's `vmap` to allow for efficient batched computations over our unbatched loss function. -JAX's documentation has more discussion of autodiff at [Automatic vectorization](https://jax.readthedocs.io/en/latest/automatic-vectorization.html). +JAX's documentation has more discussion of automatic vectorization at [Automatic vectorization](https://jax.readthedocs.io/en/latest/automatic-vectorization.html).