From 78135a92c7f40b49d49b270eae42bf03d3b5c388 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Mateusz=20Sok=C3=B3=C5=82?= Date: Mon, 25 Nov 2024 16:22:29 +0000 Subject: [PATCH] Apply review comments --- docs/JAX_time_series_classification.ipynb | 193 +++++++++++++++------- docs/JAX_time_series_classification.md | 103 ++++++++---- 2 files changed, 205 insertions(+), 91 deletions(-) diff --git a/docs/JAX_time_series_classification.ipynb b/docs/JAX_time_series_classification.ipynb index 4d65164..db437f0 100644 --- a/docs/JAX_time_series_classification.ipynb +++ b/docs/JAX_time_series_classification.ipynb @@ -7,12 +7,32 @@ "# Time series classification with JAX\n", "\n", "In this tutorial, we're going to perform time series classification with a Convolutional Neural Network.\n", - "We're going to use FordA dataset from the [UCR archive](https://www.cs.ucr.edu/%7Eeamonn/time_series_data_2018/).\n", + "We will use the FordA dataset from the [UCR archive](https://www.cs.ucr.edu/%7Eeamonn/time_series_data_2018/),\n", + "which contains measurements of engine noise captured by a motor sensor.\n", "\n", - "The problem we're facing is to assess if an engine is malfunctioning based on recorded noises it generates.\n", - "Each sample is comprised of noise measurements across time, together with a \"yes/no\" label, so it's a binary classification problem.\n", + "We need to assess if an engine is malfunctioning based on the recorded noises it generates.\n", + "Each sample comprises noise measurements across time, together with a \"yes/no\" label,\n", + "so this is a binary classification problem.\n", "\n", - "Although convolution models are mainly associated with image processing, they are useful also for time series data as they're able to extract temporal structures." + "Although convolution models are mainly associated with image processing, they are also useful\n", + "for time series data because they can extract temporal structures." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Tools overview and setup\n", "\n", "Here's a list of key packages from the JAX AI stack that are required for this tutorial:\n", "\n", "- [JAX](https://github.com/jax-ml/jax) for array computations.\n", "- [Flax](https://github.com/google/flax) for constructing neural networks.\n", "- [Optax](https://github.com/google-deepmind/optax) for gradient processing and optimization.\n", "- [Grain](https://github.com/google/grain/) to define data sources.\n", "- [tqdm](https://tqdm.github.io/) for a progress bar to monitor the training progress.\n", "\n", "We'll start by installing and importing these packages." ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# Required packages\n", "# !pip install -U jax flax optax\n", "# !pip install -U grain tqdm requests matplotlib" ] }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Tools overview\n", - "\n", - "Here's a list of key packages that belong to JAX AI stack:\n", - "\n", - "- [JAX](https://github.com/jax-ml/jax) will be used for array computations.\n", - "- [Flax](https://github.com/google/flax) for constructing neural networks.\n", - "- [Optax](https://github.com/google-deepmind/optax) for gradient processing and optimization.\n", - "- [Grain](https://github.com/google/grain/) will be be used to define data sources.\n", - "- [tqdm](https://tqdm.github.io/) for a progress bar to monitor the training progress."
- ] - }, { "cell_type": "code", "execution_count": 2, @@ -62,15 +67,15 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Dataset\n", + "## Load the dataset\n", "\n", - "We load dataset files into NumPy arrays, add singleton dimention to take into\n", - "the account convolution features, and change `-1` label to `0` value:" + "We load the dataset files into NumPy arrays, add a singleton feature dimension expected by\n", + "the convolution layers, and change the `-1` labels to `0` (so that the expected values are `0` and `1`):" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 3, "metadata": {}, "outputs": [], "source": [ @@ -106,9 +111,16 @@ "(x_train, y_train), (x_test, y_test) = prepare_ucr_dataset()" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's visualize example samples from each class." + ] + }, { "cell_type": "code", - "execution_count": null, + "execution_count": 5, "metadata": {}, "outputs": [ { @@ -123,7 +135,6 @@ } ], "source": [ - "# Here are exemplary samples from each class\n", "classes = np.unique(np.concatenate((y_train, y_test), axis=0))\n", "for c in classes:\n", " c_x_train = x_train[y_train == c]\n", @@ -136,13 +147,17 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "For handling input data we're going to use Grain, a pure Python package developed\n", - "for JAX and Flax models. Grain supports custom setups where data sources might come\n", - "in different forms, but they all need to implement the `grain.RandomAccessDataSource`\n", + "### Create a Data Loader using Grain\n", + "\n", + "For handling input data we're going to use Grain, a pure Python package developed for JAX and\n", + "Flax models.\n", + "\n", + "Grain follows the source-sampler-loader paradigm. Grain supports custom setups where data sources\n", + "might come in different forms, but they all need to implement the `grain.RandomAccessDataSource`\n", "interface. See [PyGrain Data Sources](https://github.com/google/grain/blob/main/docs/data_sources.md)\n", "for more details.\n", "\n", - "Our dataset is comprised of relatively small NumPy arrays so our DataSource is uncomplicated:" + "Our dataset consists of relatively small NumPy arrays, so our `DataSource` is uncomplicated:" ] }, { @@ -173,6 +188,18 @@ "test_source = DataSource(x_test, y_test)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Samplers determine the order in which records are processed, and we'll use the\n", + "[`IndexSampler`](https://github.com/google/grain/blob/main/docs/data_loader/samplers.md#index-sampler)\n", + "recommended by Grain.\n", + "\n", + "Finally, we'll create `DataLoader`s that handle orchestration of loading.\n", + "We'll leverage Grain's multiprocessing capabilities to scale processing up to 4 workers." + ] + }, { "cell_type": "code", "execution_count": 8, @@ -225,10 +252,13 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Model\n", + "## Define the Model\n", "\n", - "Here we construct the model with three convolution and dense layers. 
We use ReLU activation\n", - "function for middle layers and softmax in the final layer for binary classification output:" + "Let's now construct the Convolutional Neural Network with Flax by subclassing `nnx.Module`.\n", + "You can learn more about the [Flax NNX module system in the Flax documentation](https://flax.readthedocs.io/en/latest/nnx_basics.html#the-flax-nnx-module-system).\n", + "\n", + "Let's have three convolution and dense layers, and use ReLU activation function for middle\n", + "layers and softmax in the final layer for binary classification output." ] }, { @@ -313,9 +343,9 @@ " dtype=None,\n", " param_dtype=,\n", " precision=None,\n", - " kernel_init=.init at 0x7fcf74afdbd0>,\n", - " bias_init=,\n", - " conv_general_dilated=\n", + " kernel_init=.init at 0x7fec9a939bd0>,\n", + " bias_init=,\n", + " conv_general_dilated=\n", " ),\n", " layer_norm_1=LayerNorm(\n", " scale=Param(\n", @@ -330,8 +360,8 @@ " param_dtype=,\n", " use_bias=True,\n", " use_scale=True,\n", - " bias_init=,\n", - " scale_init=,\n", + " bias_init=,\n", + " scale_init=,\n", " reduction_axes=-1,\n", " feature_axes=-1,\n", " axis_name=None,\n", @@ -359,9 +389,9 @@ " dtype=None,\n", " param_dtype=,\n", " precision=None,\n", - " kernel_init=.init at 0x7fcf74afdbd0>,\n", - " bias_init=,\n", - " conv_general_dilated=\n", + " kernel_init=.init at 0x7fec9a939bd0>,\n", + " bias_init=,\n", + " conv_general_dilated=\n", " ),\n", " layer_norm_2=LayerNorm(\n", " scale=Param(\n", @@ -376,8 +406,8 @@ " param_dtype=,\n", " use_bias=True,\n", " use_scale=True,\n", - " bias_init=,\n", - " scale_init=,\n", + " bias_init=,\n", + " scale_init=,\n", " reduction_axes=-1,\n", " feature_axes=-1,\n", " axis_name=None,\n", @@ -405,9 +435,9 @@ " dtype=None,\n", " param_dtype=,\n", " precision=None,\n", - " kernel_init=.init at 0x7fcf74afdbd0>,\n", - " bias_init=,\n", - " conv_general_dilated=\n", + " kernel_init=.init at 0x7fec9a939bd0>,\n", + " bias_init=,\n", + " conv_general_dilated=\n", " ),\n", " layer_norm_3=LayerNorm(\n", " scale=Param(\n", @@ -422,8 +452,8 @@ " param_dtype=,\n", " use_bias=True,\n", " use_scale=True,\n", - " bias_init=,\n", - " scale_init=,\n", + " bias_init=,\n", + " scale_init=,\n", " reduction_axes=-1,\n", " feature_axes=-1,\n", " axis_name=None,\n", @@ -443,9 +473,9 @@ " dtype=None,\n", " param_dtype=,\n", " precision=None,\n", - " kernel_init=.init at 0x7fcf74afdbd0>,\n", - " bias_init=,\n", - " dot_general=\n", + " kernel_init=.init at 0x7fec9a939bd0>,\n", + " bias_init=,\n", + " dot_general=\n", " )\n", ")\n" ] @@ -460,11 +490,15 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## Training\n", + "## Train the Model\n", + "\n", + "To train our Flax model we need to construct an `nnx.Optimizer` object with our model and\n", + "a selected optimization algorithm. The optimizer object manages the model’s parameters and\n", + "applies gradients during training.\n", "\n", - "To train our model we construct an `nnx.Optimizer` object with our model and a selected\n", - "optimization algorithm. We're going to use Adam optimizer, which is a popular choice\n", - "for Deep Learning models:" + "We're going to use [Adam optimizer](https://optax.readthedocs.io/en/latest/api/optimizers.html#adam),\n", + "a popular choice for Deep Learning models. We'll use it through\n", + "[Optax](https://optax.readthedocs.io/en/latest/index.html), an optimization library developed for JAX." 
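To make the optimizer's role concrete, here is a standalone sketch (not a cell from this notebook) of the init/update/apply cycle that an Optax gradient transformation performs; the parameter and gradient values are made up purely for illustration:

```python
# Standalone sketch of Optax's init/update/apply cycle on toy values
# (illustrative only; the notebook uses nnx.Optimizer to do this for the model).
import jax.numpy as jnp
import optax

params = {"w": jnp.ones((3,)), "b": jnp.zeros(())}                # toy parameters
grads = {"w": jnp.array([0.1, -0.2, 0.3]), "b": jnp.array(0.5)}   # pretend gradients

tx = optax.adam(learning_rate=0.0005, b1=0.9)  # same hyperparameters as below
opt_state = tx.init(params)                    # optimizer state (moment estimates)

updates, opt_state = tx.update(grads, opt_state, params)  # turn gradients into updates
params = optax.apply_updates(params, updates)             # apply them to the parameters
```

`nnx.Optimizer` keeps this optimizer state alongside the model and performs an equivalent update when `optimizer.update(grads)` is called in the training step below.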
] }, { @@ -480,6 +514,14 @@ "optimizer = nnx.Optimizer(model, optax.adam(learning_rate, momentum))" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We'll define a loss and logits computation function using Optax's\n", + "[`losses.softmax_cross_entropy_with_integer_labels`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.losses.softmax_cross_entropy_with_integer_labels)." + ] + }, { "cell_type": "code", "execution_count": 12, @@ -495,6 +537,20 @@ " return loss, logits" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We'll now define the training and evaluation step functions. The loss and logits from both\n", + "functions will be used for calculating accuracy metrics.\n", + "\n", + "For training, we'll use `nnx.value_and_grad` to compute the gradients, and then update\n", + "the model’s parameters using our optimizer.\n", + "\n", + "Notice the use of [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit). This sets up the functions for just-in-time (JIT) compilation with [XLA](https://openxla.org/xla)\n", + "for performant execution across different hardware accelerators like GPUs and TPUs." + ] + }, { "cell_type": "code", "execution_count": 13, @@ -551,6 +607,14 @@ "}" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can now train the CNN model. We'll evaluate the model’s performance on the test set\n", + "after each epoch, and print the metrics: total loss and accuracy." + ] + }, { "cell_type": "code", "execution_count": 15, @@ -948,7 +1012,7 @@ "[train] epoch: 134/300, [28/28], loss=0.504 [00:03<00:00]\n", "[train] epoch: 135/300, [28/28], loss=0.505 [00:03<00:00]\n", "[train] epoch: 136/300, [28/28], loss=0.504 [00:03<00:00]\n", - "[train] epoch: 137/300, [28/28], loss=0.505 [00:04<00:00]\n", + "[train] epoch: 137/300, [28/28], loss=0.505 [00:03<00:00]\n", "[train] epoch: 138/300, [28/28], loss=0.504 [00:03<00:00]\n", "[train] epoch: 139/300, [28/28], loss=0.503 [00:03<00:00]\n", "[train] epoch: 140/300, [28/28], loss=0.502 [00:03<00:00]\n" @@ -1357,8 +1421,8 @@ "name": "stdout", "output_type": "stream", "text": [ - "CPU times: user 3h 38min 30s, sys: 47min 32s, total: 4h 26min 3s\n", - "Wall time: 23min\n" + "CPU times: user 3h 32min 15s, sys: 44min 47s, total: 4h 17min 2s\n", + "Wall time: 22min 33s\n" ] } ], @@ -1369,6 +1433,13 @@ " evaluate_model(epoch)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Finally, let's visualize the loss and accuracy with Matplotlib." + ] + }, { "cell_type": "code", "execution_count": 17, @@ -1377,7 +1448,7 @@ { "data": { "text/plain": [ - "" + "" ] }, "execution_count": 17, @@ -1408,7 +1479,7 @@ { "data": { "text/plain": [ - "[]" + "[]" ] }, "execution_count": 18, @@ -1443,8 +1514,8 @@ "but we also need to pay attention to validation accuracy so as to spot when the model starts\n", "overfitting.\n", "\n", - "For model early stopping and selecting best model there's [Orbax](https://github.com/google/orbax)\n", - "library which provides checkpointing and persistence utilities." + "For model early stopping and selecting best model, you can check out [Orbax](https://github.com/google/orbax),\n", + "a library which provides checkpointing and persistence utilities." 
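As a rough illustration of the early-stopping idea mentioned above, here is a sketch (not part of the notebook) of how the epoch loop could be wrapped with a patience counter. It assumes `evaluate_model` appends each epoch's accuracy to `eval_metrics_history["test_accuracy"]`, as in the tutorial; `patience` and `best_acc` are names introduced only for this example:

```python
# Hypothetical early-stopping wrapper around the tutorial's training loop.
patience = 20                  # stop after 20 epochs with no improvement (arbitrary choice)
best_acc = float("-inf")
epochs_without_improvement = 0

for epoch in range(num_epochs):
    train_one_epoch(epoch)
    evaluate_model(epoch)
    current_acc = float(eval_metrics_history["test_accuracy"][-1])
    if current_acc > best_acc:
        best_acc = current_acc
        epochs_without_improvement = 0
        # a checkpoint of the current (best) parameters would be saved here
    else:
        epochs_without_improvement += 1
        if epochs_without_improvement >= patience:
            print(f"Stopping early at epoch {epoch}: no improvement in {patience} epochs")
            break
```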
] } ], diff --git a/docs/JAX_time_series_classification.md b/docs/JAX_time_series_classification.md index 1ba7bb0..4fbd21b 100644 --- a/docs/JAX_time_series_classification.md +++ b/docs/JAX_time_series_classification.md @@ -15,29 +15,36 @@ kernelspec: # Time series classification with JAX In this tutorial, we're going to perform time series classification with a Convolutional Neural Network. -We're going to use FordA dataset from the [UCR archive](https://www.cs.ucr.edu/%7Eeamonn/time_series_data_2018/). +We will use the FordA dataset from the [UCR archive](https://www.cs.ucr.edu/%7Eeamonn/time_series_data_2018/), +which contains measurements of engine noise captured by a motor sensor. -The problem we're facing is to assess if an engine is malfunctioning based on recorded noises it generates. -Each sample is comprised of noise measurements across time, together with a "yes/no" label, so it's a binary classification problem. +We need to assess if an engine is malfunctioning based on the recorded noises it generates. +Each sample comprises noise measurements across time, together with a "yes/no" label, +so this is a binary classification problem. -Although convolution models are mainly associated with image processing, they are useful also for time series data as they're able to extract temporal structures. +Although convolution models are mainly associated with image processing, they are also useful +for time series data because they can extract temporal structures. -```{code-cell} ipython3 -# Required packages -# !pip install -U jax flax optax -# !pip install -U grain tqdm requests matplotlib -``` ++++ -## Tools overview +## Tools overview and setup -Here's a list of key packages that belong to JAX AI stack: +Here's a list of key packages from the JAX AI stack that are required for this tutorial: -- [JAX](https://github.com/jax-ml/jax) will be used for array computations. +- [JAX](https://github.com/jax-ml/jax) for array computations. - [Flax](https://github.com/google/flax) for constructing neural networks. - [Optax](https://github.com/google-deepmind/optax) for gradient processing and optimization. -- [Grain](https://github.com/google/grain/) will be be used to define data sources. +- [Grain](https://github.com/google/grain/) to define data sources. - [tqdm](https://tqdm.github.io/) for a progress bar to monitor the training progress. +We'll start by installing and importing these packages. + +```{code-cell} ipython3 +# Required packages +# !pip install -U jax flax optax +# !pip install -U grain tqdm requests matplotlib +``` + ```{code-cell} ipython3 import jax import jax.numpy as jnp @@ -50,10 +57,10 @@ import grain.python as grain import tqdm ``` -## Dataset +## Load the dataset -We load dataset files into NumPy arrays, add singleton dimention to take into -the account convolution features, and change `-1` label to `0` value: +We load the dataset files into NumPy arrays, add a singleton feature dimension expected by +the convolution layers, and change the `-1` labels to `0` (so that the expected values are `0` and `1`): ```{code-cell} ipython3 def prepare_ucr_dataset() -> tuple: @@ -83,8 +90,9 @@ def prepare_ucr_dataset() -> tuple: (x_train, y_train), (x_test, y_test) = prepare_ucr_dataset() ``` +Let's visualize example samples from each class. 
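Before plotting, a quick sanity check of the prepared arrays can catch preprocessing mistakes early (an optional step, not in the original tutorial):

```python
# Optional sanity check: shapes and label values after prepare_ucr_dataset().
print(x_train.shape, x_test.shape)            # (n_samples, series_length, 1) after adding the feature axis
print(np.unique(y_train), np.unique(y_test))  # should be [0 1] after remapping -1 to 0
```

The next cell plots one example series per class.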
+ ```{code-cell} ipython3 -# Here are exemplary samples from each class classes = np.unique(np.concatenate((y_train, y_test), axis=0)) for c in classes: c_x_train = x_train[y_train == c] @@ -93,13 +101,17 @@ plt.legend() plt.show() ``` -For handling input data we're going to use Grain, a pure Python package developed -for JAX and Flax models. Grain supports custom setups where data sources might come -in different forms, but they all need to implement the `grain.RandomAccessDataSource` +### Create a Data Loader using Grain + +For handling input data we're going to use Grain, a pure Python package developed for JAX and +Flax models. + +Grain follows the source-sampler-loader paradigm. Grain supports custom setups where data sources +might come in different forms, but they all need to implement the `grain.RandomAccessDataSource` interface. See [PyGrain Data Sources](https://github.com/google/grain/blob/main/docs/data_sources.md) for more details. -Our dataset is comprised of relatively small NumPy arrays so our DataSource is uncomplicated: +Our dataset consists of relatively small NumPy arrays, so our `DataSource` is uncomplicated: ```{code-cell} ipython3 class DataSource(grain.RandomAccessDataSource): @@ -119,6 +131,13 @@ train_source = DataSource(x_train, y_train) test_source = DataSource(x_test, y_test) ``` +Samplers determine the order in which records are processed, and we'll use the +[`IndexSampler`](https://github.com/google/grain/blob/main/docs/data_loader/samplers.md#index-sampler) +recommended by Grain. + +Finally, we'll create `DataLoader`s that handle orchestration of loading. +We'll leverage Grain's multiprocessing capabilities to scale processing up to 4 workers. + ```{code-cell} ipython3 seed = 12 train_batch_size = 128 @@ -162,10 +181,13 @@ test_loader = grain.DataLoader( ) ``` -## Model +## Define the Model + +Let's now construct the Convolutional Neural Network with Flax by subclassing `nnx.Module`. +You can learn more about the [Flax NNX module system in the Flax documentation](https://flax.readthedocs.io/en/latest/nnx_basics.html#the-flax-nnx-module-system). -Here we construct the model with three convolution and dense layers. We use ReLU activation -function for middle layers and softmax in the final layer for binary classification output: +Let's have three convolution and dense layers, and use ReLU activation function for middle +layers and softmax in the final layer for binary classification output. ```{code-cell} ipython3 class MyModel(nnx.Module): @@ -211,11 +233,15 @@ model = MyModel(rngs=nnx.Rngs(0)) nnx.display(model) ``` -## Training +## Train the Model + +To train our Flax model we need to construct an `nnx.Optimizer` object with our model and +a selected optimization algorithm. The optimizer object manages the model’s parameters and +applies gradients during training. -To train our model we construct an `nnx.Optimizer` object with our model and a selected -optimization algorithm. We're going to use Adam optimizer, which is a popular choice -for Deep Learning models: +We're going to use [Adam optimizer](https://optax.readthedocs.io/en/latest/api/optimizers.html#adam), +a popular choice for Deep Learning models. We'll use it through +[Optax](https://optax.readthedocs.io/en/latest/index.html), an optimization library developed for JAX. 
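As a side note, Optax optimizers also accept a schedule in place of a constant learning rate. The sketch below uses illustrative values only and is not part of the tutorial; it shows how a cosine-decayed learning rate could be plugged into the same Adam construction:

```python
# Hypothetical variation: cosine-decay the learning rate over the whole run.
# The step count assumes 300 epochs at 28 steps per epoch, matching this tutorial's settings.
import optax

schedule = optax.cosine_decay_schedule(init_value=0.0005, decay_steps=300 * 28)
tx = optax.adam(learning_rate=schedule, b1=0.9)
# optimizer = nnx.Optimizer(model, tx)   # would replace the constant-rate optimizer below
```

The next cell sets up the constant-learning-rate Adam optimizer that the rest of the tutorial actually uses.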
```{code-cell} ipython3 num_epochs = 300 @@ -225,6 +251,9 @@ momentum = 0.9 optimizer = nnx.Optimizer(model, optax.adam(learning_rate, momentum)) ``` +We'll define a loss and logits computation function using Optax's +[`losses.softmax_cross_entropy_with_integer_labels`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.losses.softmax_cross_entropy_with_integer_labels). + ```{code-cell} ipython3 def compute_losses_and_logits(model: nnx.Module, batch_tokens: jax.Array, labels: jax.Array): logits = model(batch_tokens) @@ -235,6 +264,15 @@ def compute_losses_and_logits(model: nnx.Module, batch_tokens: jax.Array, labels return loss, logits ``` +We'll now define the training and evaluation step functions. The loss and logits from both +functions will be used for calculating accuracy metrics. + +For training, we'll use `nnx.value_and_grad` to compute the gradients, and then update +the model’s parameters using our optimizer. + +Notice the use of [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit). This sets up the functions for just-in-time (JIT) compilation with [XLA](https://openxla.org/xla) +for performant execution across different hardware accelerators like GPUs and TPUs. + ```{code-cell} ipython3 @nnx.jit def train_step( @@ -281,6 +319,9 @@ eval_metrics_history = { } ``` +We can now train the CNN model. We'll evaluate the model’s performance on the test set +after each epoch, and print the metrics: total loss and accuracy. + ```{code-cell} ipython3 bar_format = "{desc}[{n_fmt}/{total_fmt}]{postfix} [{elapsed}<{remaining}]" train_total_steps = len(x_train) // train_batch_size @@ -324,6 +365,8 @@ for epoch in range(num_epochs): evaluate_model(epoch) ``` +Finally, let's visualize the loss and accuracy with Matplotlib. + ```{code-cell} ipython3 plt.plot(train_metrics_history["train_loss"], label="Loss value during the training") plt.legend() @@ -342,5 +385,5 @@ that the loss function isn't completely flat yet. We could continue until the cu but we also need to pay attention to validation accuracy so as to spot when the model starts overfitting. -For model early stopping and selecting best model there's [Orbax](https://github.com/google/orbax) -library which provides checkpointing and persistence utilities. +For model early stopping and selecting best model, you can check out [Orbax](https://github.com/google/orbax), +a library which provides checkpointing and persistence utilities.
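To make the Orbax suggestion concrete, here is a minimal sketch of saving and restoring the model's parameters, loosely following the Flax NNX checkpointing guide. It is not part of the tutorial; `ckpt_dir` is a hypothetical location of your choosing, and exact APIs may differ between `orbax-checkpoint` versions:

```python
# Hedged sketch: persist the trained parameters with Orbax and load them back.
import pathlib
import orbax.checkpoint as ocp
from flax import nnx

ckpt_dir = pathlib.Path("/tmp/forda_checkpoints")  # hypothetical location of your choosing
checkpointer = ocp.StandardCheckpointer()

# Save: split the model into its static definition and its parameter state.
graphdef, state = nnx.split(model)
checkpointer.save(ckpt_dir / "best_model", state)

# Restore: rebuild an abstract model of the same structure, then merge the loaded state in.
abstract_model = nnx.eval_shape(lambda: MyModel(rngs=nnx.Rngs(0)))
graphdef, abstract_state = nnx.split(abstract_model)
restored_state = checkpointer.restore(ckpt_dir / "best_model", abstract_state)
restored_model = nnx.merge(graphdef, restored_state)
```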