From b7edd5739c823271434e8d7e1466496ac085cdb8 Mon Sep 17 00:00:00 2001 From: "Artur A. Galstyan" Date: Sun, 25 Feb 2024 17:47:00 +0100 Subject: [PATCH] included more graphs in the example --- examples/mamba.ipynb | 79 +++++++++++++++++++++++++++++++----------- imgs/Mamba2.drawio.svg | 4 +++ imgs/Mamba3.drawio.svg | 4 +++ 3 files changed, 67 insertions(+), 20 deletions(-) create mode 100644 imgs/Mamba2.drawio.svg create mode 100644 imgs/Mamba3.drawio.svg diff --git a/examples/mamba.ipynb b/examples/mamba.ipynb index d3bf3826..20e57057 100644 --- a/examples/mamba.ipynb +++ b/examples/mamba.ipynb @@ -52,6 +52,8 @@ "metadata": {}, "outputs": [], "source": [ + "from typing import Optional\n", + "\n", "import equinox as eqx\n", "import jax\n", "from jaxtyping import Array, Float, Int, PRNGKeyArray" @@ -59,23 +61,10 @@ }, { "cell_type": "code", - "execution_count": 4, - "id": "dae7f1eb-ad21-4e6f-a6a1-352eea03d414", + "execution_count": 3, + "id": "91463596-114f-45f7-a2f0-b1ae8d16f25f", "metadata": {}, - "outputs": [ - { - "ename": "AttributeError", - "evalue": "module 'equinox.nn' has no attribute 'RMSNorm'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[4], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;43;01mclass\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;21;43;01mMamba\u001b[39;49;00m\u001b[43m(\u001b[49m\u001b[43meqx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mModule\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstrict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m:\u001b[49m\n\u001b[1;32m 2\u001b[0m \u001b[43m \u001b[49m\u001b[43mlayers\u001b[49m\u001b[43m:\u001b[49m\u001b[43m 
\u001b[49m\u001b[43meqx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mSequential\u001b[49m\n\u001b[1;32m 3\u001b[0m \u001b[43m \u001b[49m\u001b[43mnormalization\u001b[49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43meqx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mRMSNorm\u001b[49m\n", - "Cell \u001b[0;32mIn[4], line 3\u001b[0m, in \u001b[0;36mMamba\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mMamba\u001b[39;00m(eqx\u001b[38;5;241m.\u001b[39mModule, strict\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m):\n\u001b[1;32m 2\u001b[0m layers: eqx\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mSequential\n\u001b[0;32m----> 3\u001b[0m normalization: \u001b[43meqx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mRMSNorm\u001b[49m\n\u001b[1;32m 4\u001b[0m shared_emb_lm_head: eqx\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mShared\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, n_layers: \u001b[38;5;28mint\u001b[39m, n_dims: \u001b[38;5;28mint\u001b[39m, n_embd: \u001b[38;5;28mint\u001b[39m, \u001b[38;5;241m*\u001b[39m, key: PRNGKeyArray):\n", - "\u001b[0;31mAttributeError\u001b[0m: module 'equinox.nn' has no attribute 'RMSNorm'" - ] - } - ], + "outputs": [], "source": [ "class Mamba(eqx.Module, strict=True):\n", " layers: eqx.nn.Sequential\n", @@ -106,7 +95,7 @@ " self,\n", " x: Int[Array, \"seq_len\"], # noqa\n", " *,\n", - " key: PRNGKeyArray = None,\n", + " key: Optional[PRNGKeyArray] = None,\n", " ) -> Float[Array, \"seq_len n_dims\"]: # noqa\n", " embedding, linear = self.shared_emb_lm_head()\n", " x = jax.vmap(embedding)(x)\n", @@ -159,21 +148,71 @@ "Let's continue with the `ResidualBlock`." 
] }, + { + "cell_type": "markdown", + "id": "6b4b5de1-8300-4fc4-934e-3b0194bf8372", + "metadata": {}, + "source": [ + "Here's an overview of what the components of the `ResidualBlock` will look like.\n", + "\n", + "
\n", + " \n", + "
\n", + "\n", + "As you can see, we keep diving further into the model. Let's implement this `ResidualBlock` now." + ] + }, { "cell_type": "code", - "execution_count": null, + "execution_count": 6, "id": "d1ed52a9-abea-43e8-822a-9d75ac8ae480", "metadata": {}, "outputs": [], "source": [ - "class ResidualBlock(eqx.Module):\n", + "class ResidualBlock(eqx.Module, strict=True):\n", + " mamba_block: MambaBlock\n", + " rns_norm: eqx.nn.RMSNorm\n", + "\n", + " def __init__(self, n_embd: int, *, key: PRNGKeyArray):\n", + " \n", + " self.mamba_block = MambaBlock(\n", + " key=key,\n", + " )\n", + " self.rns_norm = eqx.nn.RMSNorm(n_embd)\n", + "\n", + " def __call__(\n", + " self, x: Float[Array, \"seq_len n_embd\"], *, key: Optional[PRNGKeyArray] = None\n", + " ) -> Array:\n", + " return self.mamba_block(jax.vmap(self.rns_norm)(x)) + x" + ] + }, + { + "cell_type": "markdown", + "id": "d8562a8b-3bea-4997-8c53-be33d66893f1", + "metadata": {}, + "source": [ + "We're getting closer and closer to the heart of the Mamba model. Let's look at what the `MambaBlock` looks like. This time, I've included the shapes of the matrices as they traverse through all kinds of transformations. \n", + "\n", + "
\n", + " \n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "6b57d346-18d0-4d29-b35d-e1a6e99c6a75", + "metadata": {}, + "outputs": [], + "source": [ + "class MambaBlock(eqx.Module):\n", " pass" ] }, { "cell_type": "code", "execution_count": null, - "id": "8aa05c2c-d8d0-484f-9d29-430b41bf4181", + "id": "e84076b2-9808-4916-8580-d789cdc49c05", "metadata": {}, "outputs": [], "source": [] diff --git a/imgs/Mamba2.drawio.svg b/imgs/Mamba2.drawio.svg new file mode 100644 index 00000000..758b254e --- /dev/null +++ b/imgs/Mamba2.drawio.svg @@ -0,0 +1,4 @@ + + + +
(seq_len, n_embd)
Normalisation
MambaBlock
(seq_len, n_embd)
\ No newline at end of file diff --git a/imgs/Mamba3.drawio.svg b/imgs/Mamba3.drawio.svg new file mode 100644 index 00000000..1edc7487 --- /dev/null +++ b/imgs/Mamba3.drawio.svg @@ -0,0 +1,4 @@ + + + +
Input Projection
(n_embd -> 2 * d_inner)
(seq_len, n_embd)
(seq_len, 2 * d_inner)
Split
x
(seq_len, d_inner)
residual
(seq_len, d_inner)
Conv1d
x
(seq_len, d_inner)
Truncate to seq_len
Silu
SSM
x
(seq_len, d_inner)
Silu
x
(seq_len, d_inner)
Output Projection
(d_inner -> n_embd)
x * residual
(seq_len, d_inner)
x
(seq_len, n_embd)
\ No newline at end of file