diff --git a/examples/mamba.ipynb b/examples/mamba.ipynb index d3bf3826..20e57057 100644 --- a/examples/mamba.ipynb +++ b/examples/mamba.ipynb @@ -52,6 +52,8 @@ "metadata": {}, "outputs": [], "source": [ + "from typing import Optional\n", + "\n", "import equinox as eqx\n", "import jax\n", "from jaxtyping import Array, Float, Int, PRNGKeyArray" @@ -59,23 +61,10 @@ }, { "cell_type": "code", - "execution_count": 4, - "id": "dae7f1eb-ad21-4e6f-a6a1-352eea03d414", + "execution_count": 3, + "id": "91463596-114f-45f7-a2f0-b1ae8d16f25f", "metadata": {}, - "outputs": [ - { - "ename": "AttributeError", - "evalue": "module 'equinox.nn' has no attribute 'RMSNorm'", - "output_type": "error", - "traceback": [ - "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", - "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", - "Cell \u001b[0;32mIn[4], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[38;5;28;43;01mclass\u001b[39;49;00m\u001b[43m \u001b[49m\u001b[38;5;21;43;01mMamba\u001b[39;49;00m\u001b[43m(\u001b[49m\u001b[43meqx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mModule\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mstrict\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\u001b[43m:\u001b[49m\n\u001b[1;32m 2\u001b[0m \u001b[43m \u001b[49m\u001b[43mlayers\u001b[49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43meqx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mSequential\u001b[49m\n\u001b[1;32m 3\u001b[0m \u001b[43m \u001b[49m\u001b[43mnormalization\u001b[49m\u001b[43m:\u001b[49m\u001b[43m \u001b[49m\u001b[43meqx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mRMSNorm\u001b[49m\n", - "Cell \u001b[0;32mIn[4], line 3\u001b[0m, in \u001b[0;36mMamba\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mclass\u001b[39;00m \u001b[38;5;21;01mMamba\u001b[39;00m(eqx\u001b[38;5;241m.\u001b[39mModule, strict\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m):\n\u001b[1;32m 2\u001b[0m layers: eqx\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mSequential\n\u001b[0;32m----> 3\u001b[0m normalization: \u001b[43meqx\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mnn\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mRMSNorm\u001b[49m\n\u001b[1;32m 4\u001b[0m shared_emb_lm_head: eqx\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mShared\n\u001b[1;32m 6\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__init__\u001b[39m(\u001b[38;5;28mself\u001b[39m, n_layers: \u001b[38;5;28mint\u001b[39m, n_dims: \u001b[38;5;28mint\u001b[39m, n_embd: \u001b[38;5;28mint\u001b[39m, \u001b[38;5;241m*\u001b[39m, key: PRNGKeyArray):\n", - "\u001b[0;31mAttributeError\u001b[0m: module 'equinox.nn' has no attribute 'RMSNorm'" - ] - } - ], + "outputs": [], "source": [ "class Mamba(eqx.Module, strict=True):\n", " layers: eqx.nn.Sequential\n", @@ -106,7 +95,7 @@ " self,\n", " x: Int[Array, \"seq_len\"], # noqa\n", " *,\n", - " key: PRNGKeyArray = None,\n", + " key: Optional[PRNGKeyArray] = None,\n", " ) -> Float[Array, \"seq_len n_dims\"]: # noqa\n", " embedding, linear = self.shared_emb_lm_head()\n", " x = jax.vmap(embedding)(x)\n", @@ -159,21 +148,71 @@ "Let's continue with the `ResidualBlock`." ] }, + { + "cell_type": "markdown", + "id": "6b4b5de1-8300-4fc4-934e-3b0194bf8372", + "metadata": {}, + "source": [ + "Here's an overview of what the components of the `ResidualBlock` will look like.\n", + "\n", + "