diff --git a/docs/source/JAX_visualizing_models_metrics.ipynb b/docs/source/JAX_visualizing_models_metrics.ipynb index 6ece7e0..c8589fc 100644 --- a/docs/source/JAX_visualizing_models_metrics.ipynb +++ b/docs/source/JAX_visualizing_models_metrics.ipynb @@ -13,45 +13,68 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "To keep things straightforward and familiar, we reuse the model and data from [Getting started with JAX for AI](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html) - if you haven't read that yet and want the primer, start there before returning.\n", + "Measuring and visualizing experiment metrics is an essential part of the machine learning workflow.\n", + "In this tutorial, we will track metrics for a JAX model using [TensorBoard](https://www.tensorflow.org/tensorboard) - a visualization tool that allows tracking loss and accuracy, visualizing model graphs, and more.\n", "\n", - "All of the modeling and training code is the same here. What we have added are the tensorboard connections and the discussion around them." + "We'll instrument the model defined in [Getting started with JAX for AI](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html). Go through that tutorial before continuing: we'll reuse the same modeling and training code and add TensorBoard connections to it.\n", + "\n", + "## Setup\n", + "\n", + "TensorBoard is part of the TensorFlow library. We'll install TensorFlow, load the TensorBoard extension for use within Jupyter Notebooks, and import the required libraries." ] }, { "cell_type": "code", - "execution_count": 12, + "execution_count": 1, + "id": "e5c72190", "metadata": {}, "outputs": [], "source": [ - "import tensorflow as tf\n", - "import io\n", - "from datetime import datetime" + "# python -m pip install tensorflow-cpu" ] }, { "cell_type": "code", "execution_count": 2, - "metadata": { - "id": "hKhPLnNxfOHU", - "outputId": "ac3508f0-ccc6-409b-c719-99a4b8f94bd6" - }, + "id": "686c835c", + "metadata": {}, "outputs": [], "source": [ - "from sklearn.datasets import load_digits\n", - "digits = load_digits()" + "%load_ext tensorboard" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ + "import tensorflow as tf\n", + "import io\n", + "from datetime import datetime" ] }, { "cell_type": "markdown", + "id": "b5808c1d", "metadata": {}, "source": [ - "Here we set the location of the tensorflow writer - the organization is somewhat arbitrary, though keeping a folder for each training run can make later navigation more straightforward." + "In TensorFlow, a `SummaryWriter` object handles outputs and logs. Let's create this object using [`tf.summary.create_file_writer`](https://www.tensorflow.org/api_docs/python/tf/summary/create_file_writer) and set the directory where the outputs should be stored. The following organization structure is arbitrary, but keeping a folder for each training run can make future navigation more straightforward."
] }, { "cell_type": "code", - "execution_count": 13, + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "# Optional - Clear any logs from previous runs\n", + "# !rm -rf ./runs/test/" + ] + }, + { + "cell_type": "code", + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -63,19 +86,47 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Pulled from the official tensorboard examples, this convert function makes it simple to drop matplotlib figures directly into tensorboard" + "## Load the dataset\n", + "\n", + "In the [Getting Started tutorial](https://docs.jaxstack.ai/en/latest/getting_started_with_jax_for_AI.html), we loaded the scikit-learn digits dataset and used matplotlib to display a few images in the notebook.\n", + "\n", + "We can also stash these images in TensorBoard. If a training needs to be repeated, it's more space efficient to stash the training data information and skip this step for subsequent trainings, provided the input is static." ] }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 6, + "id": "367798fb", + "metadata": { + "id": "hKhPLnNxfOHU", + "outputId": "ac3508f0-ccc6-409b-c719-99a4b8f94bd6" + }, + "outputs": [], + "source": [ + "from sklearn.datasets import load_digits\n", + "digits = load_digits()" + ] + }, + { + "cell_type": "markdown", + "id": "3b854512", + "metadata": {}, + "source": [ + "Taken from the [TensorBoard example on displaying image data](https://www.tensorflow.org/tensorboard/image_summaries), the following convert function makes it easier to view matplotlib figures (which are in images) directly in TensorBoard." + ] + }, + { + "cell_type": "code", + "execution_count": 7, "metadata": {}, "outputs": [], "source": [ + "# Source: https://www.tensorflow.org/tensorboard/image_summaries#logging_arbitrary_image_data\n", "def plot_to_image(figure):\n", - " \"\"\"Sourced from https://www.tensorflow.org/tensorboard/image_summaries\n", + " \"\"\"\n", " Converts the matplotlib plot specified by 'figure' to a PNG image and\n", - " returns it. The supplied figure is closed and inaccessible after this call.\"\"\"\n", + " returns it. The supplied figure is closed and inaccessible after this call.\n", + " \"\"\"\n", " # Save the plot to a PNG in memory.\n", " buf = io.BytesIO()\n", " plt.savefig(buf, format='png')\n", @@ -94,12 +145,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Whereas previously the example displays the training data snapshot in the notebook, here we stash it in the tensorboard images. If a given training is to be repeated many, many times it can save space to stash the training data information as its own run and skip this step for each subsequent training, provided the input is static. Note that this pattern uses the writer in a `with` context manager. We are able to step into and out of this type of context through the run without losing the same file/folder experiment." + "We'll use the `SummaryWriter` in a `with` context manager, to step in and out of this type of context through the run.\n", + "\n", + "[tf.summary](https://www.tensorflow.org/api_docs/python/tf/summary) has several functions to log different types of information. Here, use [`tf.summary.image`](https://www.tensorflow.org/api_docs/python/tf/summary/image) to write the image." 
] }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 8, "metadata": { "id": "Y8cMntSdfyyT", "outputId": "9343a558-cd8c-473c-c109-aa8015c7ae7e" @@ -115,6 +168,7 @@ "for i, ax in enumerate(axes.flat):\n", " ax.imshow(digits.images[i], cmap='binary', interpolation='gaussian')\n", " ax.text(0.05, 0.05, str(digits.target[i]), transform=ax.transAxes, color='green')\n", + "\n", "with test_summary_writer.as_default():\n", " tf.summary.image(\"Training Data\", plot_to_image(fig), step=0)" ] @@ -123,29 +177,75 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "After running all above and launching `tensorboard --logdir runs/test` from the same folder, you should see the following in the supplied URL:\n", - "\n", - "![image.png](./_static/images/training_data_example.png)" + "We can now launch TensorBoard within the notebook. Notice the stored training data image." ] }, { "cell_type": "code", - "execution_count": 6, - "metadata": { - "id": "6jrYisoPh6TL" - }, - "outputs": [], + "execution_count": 10, + "id": "2214671c", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Reusing TensorBoard on port 6006 (pid 31393), started 0:00:14 ago. (Use '!kill 31393' to kill it.)" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], "source": [ - "from sklearn.model_selection import train_test_split\n", - "splits = train_test_split(digits.images, digits.target, random_state=0)" + "%tensorboard --logdir runs/test" + ] + }, + { + "cell_type": "markdown", + "id": "72e12830", + "metadata": {}, + "source": [ + "## Define and train the model" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We can now create a simple neural network using Flax." ] }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 11, "metadata": { - "id": "oMRcwKd4hqOo", - "outputId": "0ad36290-397b-431d-eba2-ef114daf5ea6" + "id": "6jrYisoPh6TL" }, "outputs": [ { @@ -155,27 +255,11 @@ "images_train.shape=(1347, 8, 8) label_train.shape=(1347,)\n", "images_test.shape=(450, 8, 8) label_test.shape=(450,)\n" ] - } - ], - "source": [ - "import jax.numpy as jnp\n", - "images_train, images_test, label_train, label_test = map(jnp.asarray, splits)\n", - "print(f\"{images_train.shape=} {label_train.shape=}\")\n", - "print(f\"{images_test.shape=} {label_test.shape=}\")" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "metadata": { - "id": "U77VMQwRjTfH", - "outputId": "345fed7a-4455-4036-85ed-57e673a4de01" - }, - "outputs": [ + }, { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -187,7 +271,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -198,8 +282,16 @@ } ], "source": [ + "from sklearn.model_selection import train_test_split\n", + "import jax.numpy as jnp\n", "from flax import nnx\n", "\n", + "splits = train_test_split(digits.images, digits.target, random_state=0)\n", + "\n", + "images_train, images_test, label_train, label_test = map(jnp.asarray, splits)\n", + "print(f\"{images_train.shape=} {label_train.shape=}\")\n", + "print(f\"{images_test.shape=} {label_test.shape=}\")\n", + "\n", "class SimpleNN(nnx.Module):\n", "\n", " def __init__(self, n_features: int = 64, n_hidden: int = 100, n_targets: int = 10,\n", @@ -225,21 +317,14 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "We've now created the basic model - the above cell will render an interactive view of the model. Which, when fully expanded, should look something like this:\n", + "To track loss across our training run, we'll calculate loss in the training step.\n", "\n", - "![image.png](./_static/images/nnx_display_example.png)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "In order to track loss across our training run, we've collected the loss function call inside the training step:" + "Note that in the [Getting Started tutorial](https://docs.jaxstack.ai/en/latest/getting_started_with_jax_for_AI.html), this metric was computed once at the end of training, and called within the `for` loop." ] }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 12, "metadata": { "id": "QwRvFPkYl5b2" }, @@ -279,14 +364,17 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Now, we've collected the metrics that were previously computed once at the end of training and called them throughout the `for` loop, as you would in an eval stage.\n", - "\n", - "With the summary_writer context in place, we write out the `Loss` scalar every epoch, test the model accuracy every 10, and stash a accuracy test sheet every 500. Any custom metric can be added this way, through the tf.summary API." + "With the summary writer context in place, we can write the following to TensorBoard:\n", + "- the `Loss` scalar every epoch,\n", + "- model accuracy every 10 epochs\n", + "- accuracy test sheet every 500 epochs\n", + " \n", + "Any custom metric can be added this through the `tf.summary` API." ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 13, "metadata": { "id": "l9mukT0eqmsr", "outputId": "c6c7b2d6-8706-4bc3-d5a6-0396d7cfbf56" @@ -326,44 +414,103 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "During the training has run, and after, the added `Loss` and `Accuracy` scalars are available in the tensorboard UI under the run folder we've dynamically created by the datetime.\n", - "\n", - "The output there should look something like the following:\n", - "\n", - "![image.png](./_static/images/loss_acc_example.png)" + "## View metrics on TensorBoard" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Since we've stored the example test sheet every 500 epochs, it's easy to go back and step through the progress. With each training step using all of the training data the steps and epochs are essentially the same here.\n", + "On TensorBoard UI, the added `Loss` and `Accuracy` metrics are available in the `Scalars` tab under the `runs/test/` folder created dynamically using datetime." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Reusing TensorBoard on port 6006 (pid 31393), started 0:00:49 ago. (Use '!kill 31393' to kill it.)" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "\n", + " \n", + " \n", + " " + ], + "text/plain": [ + "" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "%tensorboard --logdir runs/test" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Since we've stored the example test sheet every 500 epochs, we can go back and step through the progress. With each training step using all of the training data, the steps and epochs are essentially the same here.\n", + "\n", + "Navigate to the `Images` tab. \n", "\n", - "At step 1, we see poor accuracy, as you would expect\n", + "At step 1, we see poor accuracy, as expected:\n", "\n", - "![image.png](./_static/images/testsheet_start_example.png)\n", + "\"TensorBoard\n", "\n", - "By 500, the model is essentially done, but we see the bottom row `7` get lost and recovered at higher epochs as we go far into an overfitting regime. This kind of stored data can be very useful when the training routines become automated and a human is potentially only looking when something has gone wrong.\n", + "By 500, the model is essentially done. However, in the bottom row `7` gets lost and recovered at higher epochs as we go far into an overfitting regime. This kind of stored data can be very useful when the training routines become automated, and a human is potentially only checking when something has gone wrong.\n", "\n", - "![image.png](./_static/images/testsheets_500_3000.png)" + "![Accuracy testsheets at Step 500, 2500, and 3000](./_static/images/testsheets_500_3000.png)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "Finally, it can be useful to use nnx.display's ability to visualize networks and model output. Here we feed the top 35 test images into the model and display the final output vector for each - in the top plot, each row is an individual image prediction result: each column corresponds to a class, in this case the digits (0-9). Since we're calling the highest value in a given row the class prediction (`.argmax(axis=1)`), the final image predictions (bottom plot) simply match the largest value in each row in the upper plot." + "## Visualize model output" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "In addition to the TensorBoard visualization, Flax [`nnx.display`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/visualization.html#flax.nnx.display)'s interactive visualizations of networks and model outputs are also helpful. \n", + " \n", + "We can feed the top 35 test images into the model and display the final output vector for each. In the following plot, each row is an individual image prediction result, and each column corresponds to a class, in this case the digits (0-9)." ] }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -375,7 +522,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -383,11 +530,28 @@ }, "metadata": {}, "output_type": "display_data" - }, + } + ], + "source": [ + "nnx.display(model(images_test[:35]))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The highest value in a given row is the class prediction (`.argmax(axis=1)`). The following plot shows image predictions matching the largest value in each row in the previous(above) plot." + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -399,7 +563,7 @@ { "data": { "text/html": [ - "
" + "
" ], "text/plain": [ "" @@ -407,40 +571,17 @@ }, "metadata": {}, "output_type": "display_data" - }, - { - "data": { - "text/plain": [ - "(None, None)" - ] - }, - "execution_count": 11, - "metadata": {}, - "output_type": "execute_result" } ], "source": [ - "nnx.display(model(images_test[:35])), nnx.display(model(images_test[:35]).argmax(axis=1))" + "nnx.display(model(images_test[:35]).argmax(axis=1))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ - "The above cell output will give you an interactive plot that looks like this image below, where here we've 'clicked' in the bottom plot for entry `7` and hover over the corresponding value in the top plot.\n", - "\n", - "![image.png](./_static/images/model_display_example.png)" - ] - }, - { - "cell_type": "markdown", - "metadata": {}, - "source": [ - "## Extra Resources\n", - "\n", - "For further information about `TensorBoard` see [https://www.tensorflow.org/tensorboard/get_started](https://www.tensorflow.org/tensorboard/get_started)\n", - "\n", - "For more about `nnx.display()`, which calls Treescope under the hood, see [https://treescope.readthedocs.io/en/stable/](https://treescope.readthedocs.io/en/stable/)" + "For more information about these tools, check out the [TensorBoard documentation](https://www.tensorflow.org/tensorboard/get_started) and [Treescope documentation ](https://treescope.readthedocs.io/en/stable/) (library behind `nnx.display`)." ] } ], @@ -452,7 +593,7 @@ "formats": "ipynb,md:myst" }, "kernelspec": { - "display_name": "Python 3 (ipykernel)", + "display_name": "jax-ai-stack", "language": "python", "name": "python3" }, @@ -466,7 +607,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.10" + "version": "3.11.11" } }, "nbformat": 4, diff --git a/docs/source/JAX_visualizing_models_metrics.md b/docs/source/JAX_visualizing_models_metrics.md index 2e27bc3..e65f911 100644 --- a/docs/source/JAX_visualizing_models_metrics.md +++ b/docs/source/JAX_visualizing_models_metrics.md @@ -7,7 +7,7 @@ jupytext: format_version: 0.13 jupytext_version: 1.15.2 kernelspec: - display_name: Python 3 (ipykernel) + display_name: jax-ai-stack language: python name: python3 --- @@ -18,9 +18,22 @@ kernelspec: +++ -To keep things straightforward and familiar, we reuse the model and data from [Getting started with JAX for AI](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html) - if you haven't read that yet and want the primer, start there before returning. +Measuring and visualizing experiment metrics is an essential part of the machine learning workflow. +In this tutorial, we will measure a JAX model using [TensorBoard](https://www.tensorflow.org/tensorboard) - a visualization tool that allows tracking loss and accuracy, visualizing model graphs, and more. -All of the modeling and training code is the same here. What we have added are the tensorboard connections and the discussion around them. +We'll measure the model defined in [Getting started with JAX for AI](https://jax-ai-stack.readthedocs.io/en/latest/getting_started_with_jax_for_AI.html). Go through that tutorial before continuing because we'll use the same modeling and training code, and add TensorBoard connections to it. + +## Setup + +TensorBoard is a part of the TensorFlow library. We'll install TensorFlow, load the TensorBoard extension for use within Jupyter Notebooks, and import the required libraries. 
+ +```{code-cell} ipython3 +# python -m pip install tensorflow-cpu +``` + +```{code-cell} ipython3 +%load_ext tensorboard +``` ```{code-cell} ipython3 import tensorflow as tf @@ -28,28 +41,41 @@ import io from datetime import datetime ``` -```{code-cell} ipython3 -:id: hKhPLnNxfOHU -:outputId: ac3508f0-ccc6-409b-c719-99a4b8f94bd6 +In TensorFlow, a `SummaryWriter` object handles outputs and logs. Let's create this object using [`tf.summary.create_file_writer`](https://www.tensorflow.org/api_docs/python/tf/summary/create_file_writer) and set the directory where the outputs should be stored. The following organization structure is arbitrary, but keeping a folder for each training run can make future navigation more straightforward. -from sklearn.datasets import load_digits -digits = load_digits() +```{code-cell} ipython3 +# Optional - Clear any logs from previous runs +# !rm -rf ./runs/test/ ``` -Here we set the location of the tensorflow writer - the organization is somewhat arbitrary, though keeping a folder for each training run can make later navigation more straightforward. - ```{code-cell} ipython3 file_path = "runs/test/" + datetime.now().strftime("%Y%m%d-%H%M%S") test_summary_writer = tf.summary.create_file_writer(file_path) ``` -Pulled from the official tensorboard examples, this convert function makes it simple to drop matplotlib figures directly into tensorboard +## Load the dataset + +In the [Getting Started tutorial](https://docs.jaxstack.ai/en/latest/getting_started_with_jax_for_AI.html), we loaded the scikit-learn digits dataset and used matplotlib to display a few images in the notebook. + +We can also stash these images in TensorBoard. If a training run needs to be repeated many times, it's more space-efficient to stash the training data information as its own run and skip this step for subsequent trainings, provided the input is static. +```{code-cell} ipython3 :id: hKhPLnNxfOHU :outputId: ac3508f0-ccc6-409b-c719-99a4b8f94bd6 +from sklearn.datasets import load_digits +digits = load_digits() +``` + +Taken from the [TensorBoard example on displaying image data](https://www.tensorflow.org/tensorboard/image_summaries), the following convert function makes it easier to view matplotlib figures directly in TensorBoard as images. ```{code-cell} ipython3 +# Source: https://www.tensorflow.org/tensorboard/image_summaries#logging_arbitrary_image_data def plot_to_image(figure): - """Sourced from https://www.tensorflow.org/tensorboard/image_summaries + """ Converts the matplotlib plot specified by 'figure' to a PNG image and - returns it. The supplied figure is closed and inaccessible after this call.""" + returns it. The supplied figure is closed and inaccessible after this call. + """ # Save the plot to a PNG in memory. buf = io.BytesIO() plt.savefig(buf, format='png') @@ -64,7 +90,9 @@ def plot_to_image(figure): return image ``` -Whereas previously the example displays the training data snapshot in the notebook, here we stash it in the tensorboard images. If a given training is to be repeated many, many times it can save space to stash the training data information as its own run and skip this step for each subsequent training, provided the input is static. Note that this pattern uses the writer in a `with` context manager. We are able to step into and out of this type of context through the run without losing the same file/folder experiment. +We'll use the `SummaryWriter` in a `with` context manager, which lets us step in and out of the writer's context throughout the run without losing the same file/folder experiment.
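+
+A minimal sketch of that pattern (the `Example/demo` tag and its values are placeholders, not one of the tutorial's metrics):
+
+```{code-cell} ipython3
+# Placeholder scalars, just to show re-entering the same writer's context:
+# both points land in the run folder created above.
+with test_summary_writer.as_default():
+    tf.summary.scalar("Example/demo", 0.0, step=0)
+
+with test_summary_writer.as_default():
+    tf.summary.scalar("Example/demo", 1.0, step=1)
+```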
+ +[tf.summary](https://www.tensorflow.org/api_docs/python/tf/summary) has several functions to log different types of information. Here, we use [`tf.summary.image`](https://www.tensorflow.org/api_docs/python/tf/summary/image) to write the image. ```{code-cell} ipython3 :id: Y8cMntSdfyyT @@ -79,36 +107,35 @@ fig, axes = plt.subplots(10, 10, figsize=(6, 6), for i, ax in enumerate(axes.flat): ax.imshow(digits.images[i], cmap='binary', interpolation='gaussian') ax.text(0.05, 0.05, str(digits.target[i]), transform=ax.transAxes, color='green') + with test_summary_writer.as_default(): tf.summary.image("Training Data", plot_to_image(fig), step=0) ``` -After running all above and launching `tensorboard --logdir runs/test` from the same folder, you should see the following in the supplied URL: +We can now launch TensorBoard within the notebook. Notice the stored training data image. + +```{code-cell} ipython3 +%tensorboard --logdir runs/test +``` + +## Define and train the model -![image.png](./_static/images/training_data_example.png) ++++ + +We can now create a simple neural network using Flax. ```{code-cell} ipython3 :id: 6jrYisoPh6TL from sklearn.model_selection import train_test_split -splits = train_test_split(digits.images, digits.target, random_state=0) -``` +import jax.numpy as jnp +from flax import nnx -```{code-cell} ipython3 -:id: oMRcwKd4hqOo -:outputId: 0ad36290-397b-431d-eba2-ef114daf5ea6 +splits = train_test_split(digits.images, digits.target, random_state=0) -import jax.numpy as jnp images_train, images_test, label_train, label_test = map(jnp.asarray, splits) print(f"{images_train.shape=} {label_train.shape=}") print(f"{images_test.shape=} {label_test.shape=}") -``` - -```{code-cell} ipython3 -:id: U77VMQwRjTfH -:outputId: 345fed7a-4455-4036-85ed-57e673a4de01 - -from flax import nnx class SimpleNN(nnx.Module): @@ -131,13 +158,9 @@ model = SimpleNN(rngs=nnx.Rngs(0)) nnx.display(model) # Interactive display if penzai is installed. ``` -We've now created the basic model - the above cell will render an interactive view of the model. Which, when fully expanded, should look something like this: - -![image.png](./_static/images/nnx_display_example.png) - -+++ +To track loss across our training run, we'll calculate loss in the training step. -In order to track loss across our training run, we've collected the loss function call inside the training step: +Note that in the [Getting Started tutorial](https://docs.jaxstack.ai/en/latest/getting_started_with_jax_for_AI.html), this metric was computed only once, at the end of training; here we call it throughout the training `for` loop, as you would in an eval stage. ```{code-cell} ipython3 :id: QwRvFPkYl5b2 @@ -172,9 +195,12 @@ def train_step( return loss ``` -Now, we've collected the metrics that were previously computed once at the end of training and called them throughout the `for` loop, as you would in an eval stage. - -With the summary_writer context in place, we write out the `Loss` scalar every epoch, test the model accuracy every 10, and stash a accuracy test sheet every 500. Any custom metric can be added this way, through the tf.summary API. +With the summary writer context in place, we can write the following to TensorBoard: +- the `Loss` scalar every epoch +- the model accuracy every 10 epochs +- an accuracy test sheet every 500 epochs + +Any custom metric can be added this way through the `tf.summary` API.
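+
+For instance, `tf.summary` can also log histograms. A minimal sketch (the tag and the random data below are placeholders for illustration; the actual training loop that follows logs only the metrics listed above):
+
+```{code-cell} ipython3
+# Hypothetical example: log a histogram of placeholder data through the
+# same writer. Not part of the training loop that follows.
+import numpy as np
+
+with test_summary_writer.as_default():
+    tf.summary.histogram("Example/random_values", np.random.normal(size=1000), step=0)
+```
+
+Now, the actual training loop: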
```{code-cell} ipython3 :id: l9mukT0eqmsr :outputId: c6c7b2d6-8706-4bc3-d5a6-0396d7cfbf56 @@ -209,40 +235,46 @@ with test_summary_writer.as_default(): tf.summary.image(f"Step {i+1} Accuracy Testsheet", plot_to_image(fig), step=i+1) ``` -During the training has run, and after, the added `Loss` and `Accuracy` scalars are available in the tensorboard UI under the run folder we've dynamically created by the datetime. - -The output there should look something like the following: - -![image.png](./_static/images/loss_acc_example.png) +## View metrics on TensorBoard +++ -Since we've stored the example test sheet every 500 epochs, it's easy to go back and step through the progress. With each training step using all of the training data the steps and epochs are essentially the same here. +In the TensorBoard UI, the added `Loss` and `Accuracy` metrics are available in the `Scalars` tab under the `runs/test/` folder created dynamically using the datetime. -At step 1, we see poor accuracy, as you would expect +```{code-cell} ipython3 +%tensorboard --logdir runs/test +``` -![image.png](./_static/images/testsheet_start_example.png) +Since we've stored the example test sheet every 500 epochs, we can go back and step through the progress. With each training step using all of the training data, the steps and epochs are essentially the same here. -By 500, the model is essentially done, but we see the bottom row `7` get lost and recovered at higher epochs as we go far into an overfitting regime. This kind of stored data can be very useful when the training routines become automated and a human is potentially only looking when something has gone wrong. +Navigate to the `Images` tab. -![image.png](./_static/images/testsheets_500_3000.png) +At step 1, we see poor accuracy, as expected: -+++ +TensorBoard UI with Images Tab showing the Accuracy testsheet at Step 1 -Finally, it can be useful to use nnx.display's ability to visualize networks and model output. Here we feed the top 35 test images into the model and display the final output vector for each - in the top plot, each row is an individual image prediction result: each column corresponds to a class, in this case the digits (0-9). Since we're calling the highest value in a given row the class prediction (`.argmax(axis=1)`), the final image predictions (bottom plot) simply match the largest value in each row in the upper plot. +By 500, the model is essentially done. However, the `7` in the bottom row gets lost and recovered at higher epochs as we go far into an overfitting regime. This kind of stored data can be very useful when the training routines become automated, and a human is potentially only checking when something has gone wrong. -```{code-cell} ipython3 -nnx.display(model(images_test[:35])), nnx.display(model(images_test[:35]).argmax(axis=1)) -``` +![Accuracy testsheets at Step 500, 2500, and 3000](./_static/images/testsheets_500_3000.png) -The above cell output will give you an interactive plot that looks like this image below, where here we've 'clicked' in the bottom plot for entry `7` and hover over the corresponding value in the top plot. ++++ -![image.png](./_static/images/model_display_example.png) +## Visualize model output +++ -## Extra Resources +In addition to the TensorBoard visualization, Flax [`nnx.display`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/visualization.html#flax.nnx.display)'s interactive visualizations of networks and model outputs are also helpful. + +We can feed the top 35 test images into the model and display the final output vector for each.
In the following plot, each row is an individual image prediction result, and each column corresponds to a class, in this case the digits (0-9). + +```{code-cell} ipython3 +nnx.display(model(images_test[:35])) +``` + +The highest value in a given row is the class prediction (`.argmax(axis=1)`). The following plot shows the final image predictions, which simply match the largest value in each row of the plot above. -For further information about `TensorBoard` see [https://www.tensorflow.org/tensorboard/get_started](https://www.tensorflow.org/tensorboard/get_started) +```{code-cell} ipython3 +nnx.display(model(images_test[:35]).argmax(axis=1)) +``` -For more about `nnx.display()`, which calls Treescope under the hood, see [https://treescope.readthedocs.io/en/stable/](https://treescope.readthedocs.io/en/stable/) +For more information about these tools, check out the [TensorBoard documentation](https://www.tensorflow.org/tensorboard/get_started) and the [Treescope documentation](https://treescope.readthedocs.io/en/stable/) (the library behind `nnx.display`). diff --git a/docs/source/_static/images/loss_acc_example.png b/docs/source/_static/images/loss_acc_example.png deleted file mode 100644 index fd73060..0000000 Binary files a/docs/source/_static/images/loss_acc_example.png and /dev/null differ diff --git a/docs/source/_static/images/nnx_display_example.png b/docs/source/_static/images/nnx_display_example.png deleted file mode 100644 index 6841dfb..0000000 Binary files a/docs/source/_static/images/nnx_display_example.png and /dev/null differ diff --git a/docs/source/_static/images/training_data_example.png b/docs/source/_static/images/training_data_example.png deleted file mode 100644 index 699f7cf..0000000 Binary files a/docs/source/_static/images/training_data_example.png and /dev/null differ