Skip to content

Commit

Permalink
Fix headings hierarchy
Browse files Browse the repository at this point in the history
  • Loading branch information
trallard committed Nov 27, 2024
1 parent d4d7cf4 commit 09cbb8e
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
10 changes: 5 additions & 5 deletions docs/data_loaders_on_gpu_with_jax.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
"id": "-rsMgVtO6asW"
},
"source": [
"### Import JAX API"
"## Import JAX API"
]
},
{
Expand All @@ -56,7 +56,7 @@
"id": "TsFdlkSZKp9S"
},
"source": [
"### Checking GPU Availability for JAX"
"## Checking GPU Availability for JAX"
]
},
{
Expand Down Expand Up @@ -91,7 +91,7 @@
"id": "qyJ_WTghDnIc"
},
"source": [
"### Setting Hyperparameters and Initializing Parameters\n",
"## Setting Hyperparameters and Initializing Parameters\n",
"\n",
"You'll define hyperparameters for your model and data loading, including layer sizes, learning rate, batch size, and the data directory. You'll also initialize the weights and biases for a fully-connected neural network."
]
Expand Down Expand Up @@ -133,7 +133,7 @@
"id": "rHLdqeI7D2WZ"
},
"source": [
"### Model Prediction with Auto-Batching\n",
"## Model Prediction with Auto-Batching\n",
"\n",
"In this section, you'll define the `predict` function for your neural network. This function computes the output of the network for a single input image.\n",
"\n",
Expand Down Expand Up @@ -174,7 +174,7 @@
"id": "rLqfeORsERek"
},
"source": [
"### Utility and Loss Functions\n",
"## Utility and Loss Functions\n",
"\n",
"You'll now define utility functions for:\n",
"- One-hot encoding: Converts class indices to binary vectors.\n",
Expand Down
10 changes: 5 additions & 5 deletions docs/data_loaders_on_gpu_with_jax.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ Compared to the [Data Loaders on CPU](https://jax-ai-stack.readthedocs.io/en/lat

+++ {"id": "-rsMgVtO6asW"}

### Import JAX API
## Import JAX API

```{code-cell}
:id: tDJNQ6V-Dg5g
Expand All @@ -43,7 +43,7 @@ from jax import grad, jit, vmap, random, device_put

+++ {"id": "TsFdlkSZKp9S"}

### Checking GPU Availability for JAX
## Checking GPU Availability for JAX

```{code-cell}
---
Expand All @@ -57,7 +57,7 @@ jax.devices()

+++ {"id": "qyJ_WTghDnIc"}

### Setting Hyperparameters and Initializing Parameters
## Setting Hyperparameters and Initializing Parameters

You'll define hyperparameters for your model and data loading, including layer sizes, learning rate, batch size, and the data directory. You'll also initialize the weights and biases for a fully-connected neural network.

Expand Down Expand Up @@ -89,7 +89,7 @@ params = init_network_params(layer_sizes, random.PRNGKey(0))

+++ {"id": "rHLdqeI7D2WZ"}

### Model Prediction with Auto-Batching
## Model Prediction with Auto-Batching

In this section, you'll define the `predict` function for your neural network. This function computes the output of the network for a single input image.

Expand Down Expand Up @@ -120,7 +120,7 @@ batched_predict = vmap(predict, in_axes=(None, 0))

+++ {"id": "rLqfeORsERek"}

### Utility and Loss Functions
## Utility and Loss Functions

You'll now define utility functions for:
- One-hot encoding: Converts class indices to binary vectors.
Expand Down

0 comments on commit 09cbb8e

Please sign in to comment.