diff --git a/docs/data_loaders_on_gpu_with_jax.ipynb b/docs/data_loaders_on_gpu_with_jax.ipynb index f726297..95db3d7 100644 --- a/docs/data_loaders_on_gpu_with_jax.ipynb +++ b/docs/data_loaders_on_gpu_with_jax.ipynb @@ -34,7 +34,7 @@ "id": "-rsMgVtO6asW" }, "source": [ - "### Import JAX API" + "## Import JAX API" ] }, { @@ -56,7 +56,7 @@ "id": "TsFdlkSZKp9S" }, "source": [ - "### Checking GPU Availability for JAX" + "## Checking GPU Availability for JAX" ] }, { @@ -91,7 +91,7 @@ "id": "qyJ_WTghDnIc" }, "source": [ - "### Setting Hyperparameters and Initializing Parameters\n", + "## Setting Hyperparameters and Initializing Parameters\n", "\n", "You'll define hyperparameters for your model and data loading, including layer sizes, learning rate, batch size, and the data directory. You'll also initialize the weights and biases for a fully-connected neural network." ] @@ -133,7 +133,7 @@ "id": "rHLdqeI7D2WZ" }, "source": [ - "### Model Prediction with Auto-Batching\n", + "## Model Prediction with Auto-Batching\n", "\n", "In this section, you'll define the `predict` function for your neural network. This function computes the output of the network for a single input image.\n", "\n", @@ -174,7 +174,7 @@ "id": "rLqfeORsERek" }, "source": [ - "### Utility and Loss Functions\n", + "## Utility and Loss Functions\n", "\n", "You'll now define utility functions for:\n", "- One-hot encoding: Converts class indices to binary vectors.\n", diff --git a/docs/data_loaders_on_gpu_with_jax.md b/docs/data_loaders_on_gpu_with_jax.md index 4ec7487..901aa6a 100644 --- a/docs/data_loaders_on_gpu_with_jax.md +++ b/docs/data_loaders_on_gpu_with_jax.md @@ -31,7 +31,7 @@ Compared to the [Data Loaders on CPU](https://jax-ai-stack.readthedocs.io/en/lat +++ {"id": "-rsMgVtO6asW"} -### Import JAX API +## Import JAX API ```{code-cell} :id: tDJNQ6V-Dg5g @@ -43,7 +43,7 @@ from jax import grad, jit, vmap, random, device_put +++ {"id": "TsFdlkSZKp9S"} -### Checking GPU Availability for JAX +## Checking GPU Availability for JAX ```{code-cell} --- @@ -57,7 +57,7 @@ jax.devices() +++ {"id": "qyJ_WTghDnIc"} -### Setting Hyperparameters and Initializing Parameters +## Setting Hyperparameters and Initializing Parameters You'll define hyperparameters for your model and data loading, including layer sizes, learning rate, batch size, and the data directory. You'll also initialize the weights and biases for a fully-connected neural network. @@ -89,7 +89,7 @@ params = init_network_params(layer_sizes, random.PRNGKey(0)) +++ {"id": "rHLdqeI7D2WZ"} -### Model Prediction with Auto-Batching +## Model Prediction with Auto-Batching In this section, you'll define the `predict` function for your neural network. This function computes the output of the network for a single input image. @@ -120,7 +120,7 @@ batched_predict = vmap(predict, in_axes=(None, 0)) +++ {"id": "rLqfeORsERek"} -### Utility and Loss Functions +## Utility and Loss Functions You'll now define utility functions for: - One-hot encoding: Converts class indices to binary vectors.