Skip to content

Commit

Permalink
included more graphs in the example
Browse files Browse the repository at this point in the history
  • Loading branch information
Artur-Galstyan committed Feb 25, 2024
1 parent 572fda4 commit b7edd57
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 20 deletions.
79 changes: 59 additions & 20 deletions examples/mamba.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -52,30 +52,19 @@
"metadata": {},
"outputs": [],
"source": [
"from typing import Optional\n",
"\n",
"import equinox as eqx\n",
"import jax\n",
"from jaxtyping import Array, Float, Int, PRNGKeyArray"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "dae7f1eb-ad21-4e6f-a6a1-352eea03d414",
"execution_count": 3,
"id": "91463596-114f-45f7-a2f0-b1ae8d16f25f",
"metadata": {},
"outputs": [
{
"ename": "AttributeError",
"evalue": "module 'equinox.nn' has no attribute 'RMSNorm'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[4], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;43;01mclass\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;21;43;01mMamba\u001b[39;49;00m\u001b[43m(\u001b[49m\u001b[43meqx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mModule\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstrict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m:\u001b[49m\n\u001b[1;32m 2\u001b[0m \u001b[43m \u001b[49m\u001b[43mlayers\u001b[49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43meqx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mSequential\u001b[49m\n\u001b[1;32m 3\u001b[0m \u001b[43m \u001b[49m\u001b[43mnormalization\u001b[49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43meqx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mRMSNorm\u001b[49m\n",
"Cell \u001b[0;32mIn[4], line 3\u001b[0m, in \u001b[0;36mMamba\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mMamba\u001b[39;00m(eqx\u001b[38;5;241m.\u001b[39mModule, strict\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m):\n\u001b[1;32m 2\u001b[0m layers: eqx\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mSequential\n\u001b[0;32m----> 3\u001b[0m normalization: \u001b[43meqx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mRMSNorm\u001b[49m\n\u001b[1;32m 4\u001b[0m shared_emb_lm_head: eqx\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mShared\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, n_layers: \u001b[38;5;28mint\u001b[39m, n_dims: \u001b[38;5;28mint\u001b[39m, n_embd: \u001b[38;5;28mint\u001b[39m, \u001b[38;5;241m*\u001b[39m, key: PRNGKeyArray):\n",
"\u001b[0;31mAttributeError\u001b[0m: module 'equinox.nn' has no attribute 'RMSNorm'"
]
}
],
"outputs": [],
"source": [
"class Mamba(eqx.Module, strict=True):\n",
" layers: eqx.nn.Sequential\n",
Expand Down Expand Up @@ -106,7 +95,7 @@
" self,\n",
" x: Int[Array, \"seq_len\"], # noqa\n",
" *,\n",
" key: PRNGKeyArray = None,\n",
" key: Optional[PRNGKeyArray] = None,\n",
" ) -> Float[Array, \"seq_len n_dims\"]: # noqa\n",
" embedding, linear = self.shared_emb_lm_head()\n",
" x = jax.vmap(embedding)(x)\n",
Expand Down Expand Up @@ -159,21 +148,71 @@
"Let's continue with the `ResidualBlock`."
]
},
{
"cell_type": "markdown",
"id": "6b4b5de1-8300-4fc4-934e-3b0194bf8372",
"metadata": {},
"source": [
"Here's an overview of what the components of the `ResidualBlock` will look like.\n",
"\n",
"<div style=\"display: flex; justify-content: center; margin-left: auto; width: 100%\">\n",
" <img src=\"../imgs/Mamba2.drawio.svg\" width=\"30%\">\n",
"</div>\n",
"\n",
"As you can see, we keep diving further into the model. Let's implement this `ResidualBlock` now."
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"id": "d1ed52a9-abea-43e8-822a-9d75ac8ae480",
"metadata": {},
"outputs": [],
"source": [
"class ResidualBlock(eqx.Module):\n",
"class ResidualBlock(eqx.Module, strict=True):\n",
" mamba_block: MambaBlock\n",
" rns_norm: eqx.nn.RMSNorm\n",
"\n",
" def __init__(self, n_embd: int, *, key: PRNGKeyArray):\n",
" \n",
" self.mamba_block = MambaBlock(\n",
" key=key,\n",
" )\n",
" self.rns_norm = eqx.nn.RMSNorm(n_embd)\n",
"\n",
" def __call__(\n",
" self, x: Float[Array, \"seq_len n_embd\"], *, key: Optional[PRNGKeyArray] = None\n",
" ) -> Array:\n",
" return self.mamba_block(jax.vmap(self.rns_norm)(x)) + x"
]
},
{
"cell_type": "markdown",
"id": "d8562a8b-3bea-4997-8c53-be33d66893f1",
"metadata": {},
"source": [
"We're getting closer and closer to the heart of the Mamba model. Let's look at what the `MambaBlock` looks like. This time, I've included the shapes of the matrices as they traverse through all kinds of transformations. \n",
"\n",
"<div style=\"display: flex; justify-content: center; margin-left: auto; width: 100%\">\n",
" <img src=\"../imgs/Mamba3.drawio.svg\" width=\"60%\">\n",
"</div>"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "6b57d346-18d0-4d29-b35d-e1a6e99c6a75",
"metadata": {},
"outputs": [],
"source": [
"class MambaBlock(eqx.Module):\n",
" pass"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8aa05c2c-d8d0-484f-9d29-430b41bf4181",
"id": "e84076b2-9808-4916-8580-d789cdc49c05",
"metadata": {},
"outputs": [],
"source": []
Expand Down
4 changes: 4 additions & 0 deletions imgs/Mamba2.drawio.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
4 changes: 4 additions & 0 deletions imgs/Mamba3.drawio.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit b7edd57

Please sign in to comment.