Skip to content

Commit

Permalink
Merge pull request #3 from trallard/trallard/patch-pr-109
Browse files Browse the repository at this point in the history
Fix headings hierarchy
  • Loading branch information
selamw1 authored Nov 27, 2024
2 parents d4d7cf4 + 09cbb8e commit 95f4333
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 10 deletions.
10 changes: 5 additions & 5 deletions docs/data_loaders_on_gpu_with_jax.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@
"id": "-rsMgVtO6asW"
},
"source": [
"### Import JAX API"
"## Import JAX API"
]
},
{
Expand All @@ -56,7 +56,7 @@
"id": "TsFdlkSZKp9S"
},
"source": [
"### Checking GPU Availability for JAX"
"## Checking GPU Availability for JAX"
]
},
{
Expand Down Expand Up @@ -91,7 +91,7 @@
"id": "qyJ_WTghDnIc"
},
"source": [
"### Setting Hyperparameters and Initializing Parameters\n",
"## Setting Hyperparameters and Initializing Parameters\n",
"\n",
"You'll define hyperparameters for your model and data loading, including layer sizes, learning rate, batch size, and the data directory. You'll also initialize the weights and biases for a fully-connected neural network."
]
Expand Down Expand Up @@ -133,7 +133,7 @@
"id": "rHLdqeI7D2WZ"
},
"source": [
"### Model Prediction with Auto-Batching\n",
"## Model Prediction with Auto-Batching\n",
"\n",
"In this section, you'll define the `predict` function for your neural network. This function computes the output of the network for a single input image.\n",
"\n",
Expand Down Expand Up @@ -174,7 +174,7 @@
"id": "rLqfeORsERek"
},
"source": [
"### Utility and Loss Functions\n",
"## Utility and Loss Functions\n",
"\n",
"You'll now define utility functions for:\n",
"- One-hot encoding: Converts class indices to binary vectors.\n",
Expand Down
10 changes: 5 additions & 5 deletions docs/data_loaders_on_gpu_with_jax.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ Compared to the [Data Loaders on CPU](https://jax-ai-stack.readthedocs.io/en/lat

+++ {"id": "-rsMgVtO6asW"}

### Import JAX API
## Import JAX API

```{code-cell}
:id: tDJNQ6V-Dg5g
Expand All @@ -43,7 +43,7 @@ from jax import grad, jit, vmap, random, device_put

+++ {"id": "TsFdlkSZKp9S"}

### Checking GPU Availability for JAX
## Checking GPU Availability for JAX

```{code-cell}
---
Expand All @@ -57,7 +57,7 @@ jax.devices()

+++ {"id": "qyJ_WTghDnIc"}

### Setting Hyperparameters and Initializing Parameters
## Setting Hyperparameters and Initializing Parameters

You'll define hyperparameters for your model and data loading, including layer sizes, learning rate, batch size, and the data directory. You'll also initialize the weights and biases for a fully-connected neural network.

Expand Down Expand Up @@ -89,7 +89,7 @@ params = init_network_params(layer_sizes, random.PRNGKey(0))

+++ {"id": "rHLdqeI7D2WZ"}

### Model Prediction with Auto-Batching
## Model Prediction with Auto-Batching

In this section, you'll define the `predict` function for your neural network. This function computes the output of the network for a single input image.

Expand Down Expand Up @@ -120,7 +120,7 @@ batched_predict = vmap(predict, in_axes=(None, 0))

+++ {"id": "rLqfeORsERek"}

### Utility and Loss Functions
## Utility and Loss Functions

You'll now define utility functions for:
- One-hot encoding: Converts class indices to binary vectors.
Expand Down

0 comments on commit 95f4333

Please sign in to comment.