
Commit

Apply review comments
mtsokol committed Nov 25, 2024
1 parent fcc6b33 commit 78135a9
Showing 2 changed files with 205 additions and 91 deletions.
193 changes: 132 additions & 61 deletions docs/JAX_time_series_classification.ipynb
@@ -7,12 +7,32 @@
"# Time series classification with JAX\n",
"\n",
"In this tutorial, we're going to perform time series classification with a Convolutional Neural Network.\n",
"We're going to use FordA dataset from the [UCR archive](https://www.cs.ucr.edu/%7Eeamonn/time_series_data_2018/).\n",
"We will use the FordA dataset from the [UCR archive](https://www.cs.ucr.edu/%7Eeamonn/time_series_data_2018/),\n",
"which contains measurements of engine noise captured by a motor sensor.\n",
"\n",
"The problem we're facing is to assess if an engine is malfunctioning based on recorded noises it generates.\n",
"Each sample is comprised of noise measurements across time, together with a \"yes/no\" label, so it's a binary classification problem.\n",
"We need to assess if an engine is malfunctioning based on the recorded noises it generates.\n",
"Each sample comprises of noise measurements across time, together with a \"yes/no\" label,\n",
"so this is a binary classification problem.\n",
"\n",
"Although convolution models are mainly associated with image processing, they are useful also for time series data as they're able to extract temporal structures."
"Although convolution models are mainly associated with image processing, they are also useful\n",
"for time series data because they can extract temporal structures."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tools overview and setup\n",
"\n",
"Here's a list of key packages that belong to the JAX AI stack required for this tutorial:\n",
"\n",
"- [JAX](https://github.com/jax-ml/jax) for array computations.\n",
"- [Flax](https://github.com/google/flax) for constructing neural networks.\n",
"- [Optax](https://github.com/google-deepmind/optax) for gradient processing and optimization.\n",
"- [Grain](https://github.com/google/grain/) to define data sources.\n",
"- [tqdm](https://tqdm.github.io/) for a progress bar to monitor the training progress.\n",
"\n",
"We'll start by installing and importing these packages."
]
},
{
@@ -26,21 +46,6 @@
"# !pip install -U grain tqdm requests matplotlib"
]
},
{
"cell_type": "code",
"execution_count": 2,
@@ -62,15 +67,15 @@
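The import cell itself is collapsed in this view; a plausible set of imports for the packages listed above, assuming standard aliases (the exact cell contents are hidden in the diff):

```python
import numpy as np
import matplotlib.pyplot as plt

import jax
import jax.numpy as jnp
from flax import nnx          # Flax NNX module system
import optax                  # optimizers and losses
import grain.python as grain  # data sources, samplers, loaders
from tqdm.auto import tqdm    # progress bars
```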
"cell_type": "markdown",
"metadata": {},
"source": [
"## Dataset\n",
"## Load the dataset\n",
"\n",
"We load dataset files into NumPy arrays, add singleton dimention to take into\n",
"the account convolution features, and change `-1` label to `0` value:"
"We load dataset files into NumPy arrays, add singleton dimension to take convolution features\n",
"into account, and change `-1` label to `0` (so that the expected values are `0` and `1`):"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
@@ -106,9 +111,16 @@
"(x_train, y_train), (x_test, y_test) = prepare_ucr_dataset()"
]
},
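The body of `prepare_ucr_dataset` is collapsed in this diff. A minimal sketch of what it likely does, assuming the FordA files are tab-separated text with the label in the first column (the file names and the download step are assumptions):

```python
import numpy as np

def prepare_ucr_dataset():
    def load_split(path):
        data = np.loadtxt(path, delimiter="\t")
        x, y = data[:, 1:], data[:, 0].astype(int)
        x = x[..., np.newaxis]       # singleton channel dimension for Conv layers
        y = np.where(y == -1, 0, y)  # map labels {-1, 1} -> {0, 1}
        return x, y

    # Assumed local file names; the real cell fetches them with `requests`.
    return load_split("FordA_TRAIN.tsv"), load_split("FordA_TEST.tsv")
```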
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's visualize example samples from each class."
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 5,
"metadata": {},
"outputs": [
{
@@ -123,7 +135,6 @@
}
],
"source": [
"# Here are exemplary samples from each class\n",
"classes = np.unique(np.concatenate((y_train, y_test), axis=0))\n",
"for c in classes:\n",
" c_x_train = x_train[y_train == c]\n",
@@ -136,13 +147,17 @@
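The rest of this plotting cell is collapsed; a sketch of how the loop might continue, with the label text assumed:

```python
classes = np.unique(np.concatenate((y_train, y_test), axis=0))
for c in classes:
    c_x_train = x_train[y_train == c]
    # Plot the first sample of each class; squeeze() drops the channel dimension.
    plt.plot(c_x_train[0].squeeze(), label=f"class {c}")
plt.legend(loc="best")
plt.show()
```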
"cell_type": "markdown",
"metadata": {},
"source": [
"For handling input data we're going to use Grain, a pure Python package developed\n",
"for JAX and Flax models. Grain supports custom setups where data sources might come\n",
"in different forms, but they all need to implement the `grain.RandomAccessDataSource`\n",
"### Create a Data Loader using Grain\n",
"\n",
"For handling input data we're going to use Grain, a pure Python package developed for JAX and\n",
"Flax models.\n",
"\n",
"Grain follows the source-sampler-loader paradigm. Grain supports custom setups where data sources\n",
"might come in different forms, but they all need to implement the `grain.RandomAccessDataSource`\n",
"interface. See [PyGrain Data Sources](https://github.com/google/grain/blob/main/docs/data_sources.md)\n",
"for more details.\n",
"\n",
"Our dataset is comprised of relatively small NumPy arrays so our DataSource is uncomplicated:"
"Our dataset is comprised of relatively small NumPy arrays so our `DataSource` is uncomplicated:"
]
},
{
@@ -173,6 +188,18 @@
"test_source = DataSource(x_test, y_test)"
]
},
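The `DataSource` definition is partially collapsed above (only the last two lines are visible); a minimal sketch of such a class, assuming records are returned as `(measurement, label)` pairs — the record format is an assumption:

```python
import grain.python as grain

class DataSource(grain.RandomAccessDataSource):
    """Random-access wrapper around in-memory NumPy arrays."""

    def __init__(self, x, y):
        self._x = x
        self._y = y

    def __len__(self) -> int:
        return len(self._x)

    def __getitem__(self, idx):
        return self._x[idx], self._y[idx]

train_source = DataSource(x_train, y_train)
test_source = DataSource(x_test, y_test)
```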
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Samplers determine the order in which records are processed, and we'll use the\n",
"[`IndexSmapler`](https://github.com/google/grain/blob/main/docs/data_loader/samplers.md#index-sampler)\n",
"recommended by Grain.\n",
"\n",
"Finally, we'll create `DataLoader`s that handle orchestration of loading.\n",
"We'll leverage Grain's multiprocessing capabilities to scale processing up to 4 workers."
]
},
{
"cell_type": "code",
"execution_count": 8,
@@ -225,10 +252,13 @@
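The cell contents are collapsed here; a sketch of how the sampler and loader might be constructed for the training split, assuming a single-host setup (the batch size and seed are illustrative assumptions, not the collapsed values):

```python
batch_size = 128  # assumed value; the real one is in the collapsed cell

train_sampler = grain.IndexSampler(
    len(train_source),
    shuffle=True,
    seed=12,                           # assumed seed
    shard_options=grain.NoSharding(),  # no sharding on a single host
    num_epochs=1,                      # iterate the source once per epoch
)

train_loader = grain.DataLoader(
    data_source=train_source,
    sampler=train_sampler,
    worker_count=4,  # scale processing up to 4 workers
    operations=[grain.Batch(batch_size, drop_remainder=True)],
)
```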
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model\n",
"## Define the Model\n",
"\n",
"Here we construct the model with three convolution and dense layers. We use ReLU activation\n",
"function for middle layers and softmax in the final layer for binary classification output:"
"Let's now construct the Convolutional Neural Network with Flax by subclassing `nnx.Module`.\n",
"You can learn more about the [Flax NNX module system in the Flax documentation](https://flax.readthedocs.io/en/latest/nnx_basics.html#the-flax-nnx-module-system).\n",
"\n",
"Let's have three convolution and dense layers, and use ReLU activation function for middle\n",
"layers and softmax in the final layer for binary classification output."
]
},
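The model definition cell is collapsed in this diff, but the `nnx.display` output below reveals the layer names (`conv_1`/`layer_norm_1` through `conv_3`/`layer_norm_3`, plus `dense`). A sketch of such a module, with the class name, feature sizes, and kernel widths as assumptions:

```python
from flax import nnx
import jax.numpy as jnp

class CNN(nnx.Module):  # hypothetical class name; the real one is collapsed
    def __init__(self, rngs: nnx.Rngs):
        self.conv_1 = nnx.Conv(1, 64, kernel_size=(3,), rngs=rngs)
        self.layer_norm_1 = nnx.LayerNorm(num_features=64, rngs=rngs)
        self.conv_2 = nnx.Conv(64, 64, kernel_size=(3,), rngs=rngs)
        self.layer_norm_2 = nnx.LayerNorm(num_features=64, rngs=rngs)
        self.conv_3 = nnx.Conv(64, 64, kernel_size=(3,), rngs=rngs)
        self.layer_norm_3 = nnx.LayerNorm(num_features=64, rngs=rngs)
        self.dense = nnx.Linear(64, 2, rngs=rngs)

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        x = nnx.relu(self.layer_norm_1(self.conv_1(x)))
        x = nnx.relu(self.layer_norm_2(self.conv_2(x)))
        x = nnx.relu(self.layer_norm_3(self.conv_3(x)))
        x = jnp.mean(x, axis=1)  # global average pooling over the time axis
        return nnx.softmax(self.dense(x))

model = CNN(rngs=nnx.Rngs(0))
```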
{
@@ -313,9 +343,9 @@
" dtype=None,\n",
" param_dtype=<class 'jax.numpy.float32'>,\n",
" precision=None,\n",
" kernel_init=<function variance_scaling.<locals>.init at 0x7fcf74afdbd0>,\n",
" bias_init=<function zeros at 0x7fcf755c2cb0>,\n",
" conv_general_dilated=<function conv_general_dilated at 0x7fcf75ab3be0>\n",
" kernel_init=<function variance_scaling.<locals>.init at 0x7fec9a939bd0>,\n",
" bias_init=<function zeros at 0x7fec9b3cecb0>,\n",
" conv_general_dilated=<function conv_general_dilated at 0x7fec9b897be0>\n",
" ),\n",
" layer_norm_1=LayerNorm(\n",
" scale=Param(\n",
@@ -330,8 +360,8 @@
" param_dtype=<class 'jax.numpy.float32'>,\n",
" use_bias=True,\n",
" use_scale=True,\n",
" bias_init=<function zeros at 0x7fcf755c2cb0>,\n",
" scale_init=<function ones at 0x7fcf755c2e60>,\n",
" bias_init=<function zeros at 0x7fec9b3cecb0>,\n",
" scale_init=<function ones at 0x7fec9b3cee60>,\n",
" reduction_axes=-1,\n",
" feature_axes=-1,\n",
" axis_name=None,\n",
@@ -359,9 +389,9 @@
" dtype=None,\n",
" param_dtype=<class 'jax.numpy.float32'>,\n",
" precision=None,\n",
" kernel_init=<function variance_scaling.<locals>.init at 0x7fcf74afdbd0>,\n",
" bias_init=<function zeros at 0x7fcf755c2cb0>,\n",
" conv_general_dilated=<function conv_general_dilated at 0x7fcf75ab3be0>\n",
" kernel_init=<function variance_scaling.<locals>.init at 0x7fec9a939bd0>,\n",
" bias_init=<function zeros at 0x7fec9b3cecb0>,\n",
" conv_general_dilated=<function conv_general_dilated at 0x7fec9b897be0>\n",
" ),\n",
" layer_norm_2=LayerNorm(\n",
" scale=Param(\n",
@@ -376,8 +406,8 @@
" param_dtype=<class 'jax.numpy.float32'>,\n",
" use_bias=True,\n",
" use_scale=True,\n",
" bias_init=<function zeros at 0x7fcf755c2cb0>,\n",
" scale_init=<function ones at 0x7fcf755c2e60>,\n",
" bias_init=<function zeros at 0x7fec9b3cecb0>,\n",
" scale_init=<function ones at 0x7fec9b3cee60>,\n",
" reduction_axes=-1,\n",
" feature_axes=-1,\n",
" axis_name=None,\n",
@@ -405,9 +435,9 @@
" dtype=None,\n",
" param_dtype=<class 'jax.numpy.float32'>,\n",
" precision=None,\n",
" kernel_init=<function variance_scaling.<locals>.init at 0x7fcf74afdbd0>,\n",
" bias_init=<function zeros at 0x7fcf755c2cb0>,\n",
" conv_general_dilated=<function conv_general_dilated at 0x7fcf75ab3be0>\n",
" kernel_init=<function variance_scaling.<locals>.init at 0x7fec9a939bd0>,\n",
" bias_init=<function zeros at 0x7fec9b3cecb0>,\n",
" conv_general_dilated=<function conv_general_dilated at 0x7fec9b897be0>\n",
" ),\n",
" layer_norm_3=LayerNorm(\n",
" scale=Param(\n",
@@ -422,8 +452,8 @@
" param_dtype=<class 'jax.numpy.float32'>,\n",
" use_bias=True,\n",
" use_scale=True,\n",
" bias_init=<function zeros at 0x7fcf755c2cb0>,\n",
" scale_init=<function ones at 0x7fcf755c2e60>,\n",
" bias_init=<function zeros at 0x7fec9b3cecb0>,\n",
" scale_init=<function ones at 0x7fec9b3cee60>,\n",
" reduction_axes=-1,\n",
" feature_axes=-1,\n",
" axis_name=None,\n",
@@ -443,9 +473,9 @@
" dtype=None,\n",
" param_dtype=<class 'jax.numpy.float32'>,\n",
" precision=None,\n",
" kernel_init=<function variance_scaling.<locals>.init at 0x7fcf74afdbd0>,\n",
" bias_init=<function zeros at 0x7fcf755c2cb0>,\n",
" dot_general=<function dot_general at 0x7fcf75c30820>\n",
" kernel_init=<function variance_scaling.<locals>.init at 0x7fec9a939bd0>,\n",
" bias_init=<function zeros at 0x7fec9b3cecb0>,\n",
" dot_general=<function dot_general at 0x7fec9ba14820>\n",
" )\n",
")\n"
]
Expand All @@ -460,11 +490,15 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"## Training\n",
"## Train the Model\n",
"\n",
"To train our Flax model we need to construct an `nnx.Optimizer` object with our model and\n",
"a selected optimization algorithm. The optimizer object manages the model’s parameters and\n",
"applies gradients during training.\n",
"\n",
"To train our model we construct an `nnx.Optimizer` object with our model and a selected\n",
"optimization algorithm. We're going to use Adam optimizer, which is a popular choice\n",
"for Deep Learning models:"
"We're going to use [Adam optimizer](https://optax.readthedocs.io/en/latest/api/optimizers.html#adam),\n",
"a popular choice for Deep Learning models. We'll use it through\n",
"[Optax](https://optax.readthedocs.io/en/latest/index.html), an optimization library developed for JAX."
]
},
{
@@ -480,6 +514,14 @@
"optimizer = nnx.Optimizer(model, optax.adam(learning_rate, momentum))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We'll define a loss and logits computation function using Optax's\n",
"[`losses.softmax_cross_entropy_with_integer_labels`](https://optax.readthedocs.io/en/latest/api/losses.html#optax.losses.softmax_cross_entropy_with_integer_labels)."
]
},
{
"cell_type": "code",
"execution_count": 12,
@@ -495,6 +537,20 @@
" return loss, logits"
]
},
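Only the return statement of that cell survives the collapse; a sketch of the full function under the signature suggested by its use below (the argument names are assumptions):

```python
def compute_losses_and_logits(model, batch, labels):
    logits = model(batch)
    # Mean cross-entropy over the batch, with integer class labels.
    loss = optax.softmax_cross_entropy_with_integer_labels(
        logits=logits, labels=labels
    ).mean()
    return loss, logits
```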
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We'll now define the training and evaluation step functions. The loss and logits from both\n",
"functions will be used for calculating accuracy metrics.\n",
"\n",
"For training, we'll use `nnx.value_and_grad` to compute the gradients, and then update\n",
"the model’s parameters using our optimizer.\n",
"\n",
"Notice the use of [`nnx.jit`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/transforms.html#flax.nnx.jit). This sets up the functions for just-in-time (JIT) compilation with [XLA](https://openxla.org/xla)\n",
"for performant execution across different hardware accelerators like GPUs and TPUs."
]
},
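The step definitions are collapsed below; a sketch of what such functions typically look like in Flax NNX, assuming the loss helper defined above:

```python
@nnx.jit
def train_step(model, optimizer: nnx.Optimizer, batch, labels):
    # Compute loss, logits, and gradients w.r.t. the model parameters.
    grad_fn = nnx.value_and_grad(compute_losses_and_logits, has_aux=True)
    (loss, logits), grads = grad_fn(model, batch, labels)
    optimizer.update(grads)  # apply gradients in place
    return loss, logits

@nnx.jit
def eval_step(model, batch, labels):
    return compute_losses_and_logits(model, batch, labels)
```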
{
"cell_type": "code",
"execution_count": 13,
@@ -551,6 +607,14 @@
"}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We can now train the CNN model. We'll evaluate the model’s performance on the test set\n",
"after each epoch, and print the metrics: total loss and accuracy."
]
},
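The loop body is collapsed below (only `evaluate_model(epoch)` is visible in the diff); a sketch of the driver loop, assuming a `train_one_epoch` helper that iterates the Grain loader, with `num_epochs` matching the `300` seen in the logs:

```python
num_epochs = 300  # matches the epoch count in the logs below

def train_one_epoch(epoch):  # assumed helper name
    model.train()
    for batch, labels in tqdm(train_loader, desc=f"[train] epoch: {epoch}/{num_epochs}"):
        train_step(model, optimizer, batch, labels)

for epoch in range(1, num_epochs + 1):
    train_one_epoch(epoch)
    evaluate_model(epoch)  # visible in the cell; defined in a collapsed cell
```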
{
"cell_type": "code",
"execution_count": 15,
@@ -948,7 +1012,7 @@
"[train] epoch: 134/300, [28/28], loss=0.504 [00:03<00:00]\n",
"[train] epoch: 135/300, [28/28], loss=0.505 [00:03<00:00]\n",
"[train] epoch: 136/300, [28/28], loss=0.504 [00:03<00:00]\n",
"[train] epoch: 137/300, [28/28], loss=0.505 [00:04<00:00]\n",
"[train] epoch: 137/300, [28/28], loss=0.505 [00:03<00:00]\n",
"[train] epoch: 138/300, [28/28], loss=0.504 [00:03<00:00]\n",
"[train] epoch: 139/300, [28/28], loss=0.503 [00:03<00:00]\n",
"[train] epoch: 140/300, [28/28], loss=0.502 [00:03<00:00]\n"
@@ -1357,8 +1421,8 @@
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 3h 38min 30s, sys: 47min 32s, total: 4h 26min 3s\n",
"Wall time: 23min\n"
"CPU times: user 3h 32min 15s, sys: 44min 47s, total: 4h 17min 2s\n",
"Wall time: 22min 33s\n"
]
}
],
@@ -1369,6 +1433,13 @@
" evaluate_model(epoch)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Finally, let's visualize the loss and accuracy with Matplotlib."
]
},
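The plotting code in the next two cells is collapsed; a minimal sketch, assuming loss and accuracy histories were accumulated into lists during training (`train_loss_history` and `test_accuracy_history` are assumed names):

```python
# Assumed names: lists filled in train_one_epoch / evaluate_model.
plt.plot(train_loss_history, label="loss value during the training")
plt.legend()
plt.show()

plt.plot(test_accuracy_history, label="accuracy on the test set")
plt.legend()
plt.show()
```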
{
"cell_type": "code",
"execution_count": 17,
@@ -1377,7 +1448,7 @@
{
"data": {
"text/plain": [
"<matplotlib.legend.Legend at 0x7fcf6830cbb0>"
"<matplotlib.legend.Legend at 0x7fec902d3730>"
]
},
"execution_count": 17,
@@ -1408,7 +1479,7 @@
{
"data": {
"text/plain": [
"[<matplotlib.lines.Line2D at 0x7fcf103e5180>]"
"[<matplotlib.lines.Line2D at 0x7fec681dd060>]"
]
},
"execution_count": 18,
@@ -1443,8 +1514,8 @@
"but we also need to pay attention to validation accuracy so as to spot when the model starts\n",
"overfitting.\n",
"\n",
"For model early stopping and selecting best model there's [Orbax](https://github.com/google/orbax)\n",
"library which provides checkpointing and persistence utilities."
"For model early stopping and selecting best model, you can check out [Orbax](https://github.com/google/orbax),\n",
"a library which provides checkpointing and persistence utilities."
]
}
],
Expand Down
Loading

0 comments on commit 78135a9

Please sign in to comment.