From 5aa131d40912b66ff917570916b8587026bf2801 Mon Sep 17 00:00:00 2001
From: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com>
Date: Fri, 1 Mar 2024 15:52:42 -0500
Subject: [PATCH 1/8] Don't hardcode tied embed+unembed (#211)

* Add tie_embeddings config option

* Only tie embeddings if the config specifies it
---
 mamba_ssm/models/config_mamba.py     | 1 +
 mamba_ssm/models/mixer_seq_simple.py | 5 +++--
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/mamba_ssm/models/config_mamba.py b/mamba_ssm/models/config_mamba.py
index ffd31abc..2aa1e5a6 100644
--- a/mamba_ssm/models/config_mamba.py
+++ b/mamba_ssm/models/config_mamba.py
@@ -12,3 +12,4 @@ class MambaConfig:
     residual_in_fp32: bool = True
     fused_add_norm: bool = True
     pad_vocab_size_multiple: int = 8
+    tie_embeddings: bool = True
diff --git a/mamba_ssm/models/mixer_seq_simple.py b/mamba_ssm/models/mixer_seq_simple.py
index 5b3ddfcf..2f1d97fd 100644
--- a/mamba_ssm/models/mixer_seq_simple.py
+++ b/mamba_ssm/models/mixer_seq_simple.py
@@ -220,8 +220,9 @@ def __init__(
         self.tie_weights()
 
     def tie_weights(self):
-        self.lm_head.weight = self.backbone.embedding.weight
-
+        if self.config.tie_embeddings:
+            self.lm_head.weight = self.backbone.embedding.weight
+
     def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
         return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs)
 
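Note: a minimal usage sketch of the new flag (the model sizes below are
illustrative, not a recommended configuration):

    from mamba_ssm.models.config_mamba import MambaConfig
    from mamba_ssm.models.mixer_seq_simple import MambaLMHeadModel

    # With tie_embeddings=False the LM head keeps its own weight matrix
    # instead of aliasing the input embedding.
    config = MambaConfig(d_model=256, n_layer=4, vocab_size=1000,
                         tie_embeddings=False)
    model = MambaLMHeadModel(config)
    assert model.lm_head.weight is not model.backbone.embedding.weight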

From 2066c2115cd85fb95c7c8d5123f45802b165d7b3 Mon Sep 17 00:00:00 2001
From: Haobo Yuan <yuanhaobo.cs@outlook.com>
Date: Sat, 2 Mar 2024 08:44:17 +0800
Subject: [PATCH 2/8] Bugfix `causal_conv1d_fn` interface (#168)

---
 mamba_ssm/ops/selective_scan_interface.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py
index b8f14dd0..52e7eba0 100644
--- a/mamba_ssm/ops/selective_scan_interface.py
+++ b/mamba_ssm/ops/selective_scan_interface.py
@@ -318,7 +318,7 @@ def mamba_inner_ref(
     delta_rank = delta_proj_weight.shape[1]
     d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
     x, z = xz.chunk(2, dim=1)
-    x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, "silu")
+    x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, None, "silu")
     # We're being very careful here about the layout, to avoid extra transposes.
     # We want delta to have d as the slowest moving dimension
     # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
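
Note: in causal-conv1d >= 1.1 the function signature gained a seq_idx
parameter ahead of activation (roughly causal_conv1d_fn(x, weight, bias=None,
seq_idx=None, activation=None)), so the old positional "silu" was silently
consumed as seq_idx. Passing None explicitly restores the intent; a later
patch in this series switches the call to the keyword form, which is robust
to such signature shifts:

    x = causal_conv1d_fn(
        x,
        rearrange(conv1d_weight, "d 1 w -> d w"),
        conv1d_bias,
        activation="silu",  # keyword form survives positional changes
    )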

From 9583c56b741d28f691690e833cacce768938adce Mon Sep 17 00:00:00 2001
From: Wongboo <44860323+Wongboo@users.noreply.github.com>
Date: Sat, 2 Mar 2024 09:24:11 +0800
Subject: [PATCH 3/8] Add support for Python 3.12 (#153)

---
 .github/workflows/publish.yaml | 13 ++++++++++++-
 1 file changed, 12 insertions(+), 1 deletion(-)

diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml
index 64951b34..62904299 100644
--- a/.github/workflows/publish.yaml
+++ b/.github/workflows/publish.yaml
@@ -43,7 +43,7 @@ jobs:
           # Using ubuntu-20.04 instead of 22.04 for more compatibility (glibc). Ideally we'd use the
           # manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
           os: [ubuntu-20.04]
-          python-version: ['3.7', '3.8', '3.9', '3.10', '3.11']
+          python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', '3.12']
           torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.2', '2.2.0']
           cuda-version: ['11.8.0', '12.2.2']
           # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
@@ -52,6 +52,15 @@ jobs:
           # when building without C++11 ABI and using it on nvcr images.
           cxx11_abi: ['FALSE', 'TRUE']
           exclude:
+            # Pytorch < 2.2 does not support Python 3.12
+            - torch-version: '1.12.1'
+              python-version: '3.12'
+            - torch-version: '1.13.1'
+              python-version: '3.12'
+            - torch-version: '2.0.1'
+              python-version: '3.12'
+            - torch-version: '2.1.2'
+              python-version: '3.12'
             # Pytorch <= 1.12 does not support Python 3.11
             - torch-version: '1.12.1'
               python-version: '3.11'
@@ -119,6 +128,8 @@ jobs:
           # If we don't install before installing Pytorch, we get error for torch 2.0.1
           # ERROR: Could not find a version that satisfies the requirement setuptools>=40.8.0 (from versions: none)
           pip install lit
+          # For some reason torch 2.2.0 on Python 3.12 errors out complaining that setuptools is missing
+          pip install setuptools
           # We want to figure out the CUDA version to download pytorch
           # e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116
           # This code is ugly, maybe there's a better way to do this.
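
Note: a sketch of the compatibility rule the new exclude entries encode
(cutoffs hardcoded for illustration; only the rules touched by this patch
are modeled), filtering the build matrix the way the workflow does:

    def supported(torch_version: str, python_version: str) -> bool:
        torch_mm = tuple(map(int, torch_version.split(".")[:2]))
        py_minor = int(python_version.split(".")[1])
        if py_minor >= 12:        # Python 3.12 needs torch >= 2.2
            return torch_mm >= (2, 2)
        if py_minor == 11:        # Python 3.11 needs torch > 1.12
            return torch_mm >= (1, 13)
        return True

    torch_versions = ["1.12.1", "1.13.1", "2.0.1", "2.1.2", "2.2.0"]
    python_versions = ["3.7", "3.8", "3.9", "3.10", "3.11", "3.12"]
    matrix = [(t, p) for t in torch_versions for p in python_versions
              if supported(t, p)]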

From afd5fb5b1b89d2d353e9316809903267cbe70971 Mon Sep 17 00:00:00 2001
From: Tri Dao <tridpq@gmail.com>
Date: Fri, 1 Mar 2024 17:23:39 -0800
Subject: [PATCH 4/8] Update causal-conv1d to 1.2.0, make it optional

---
 README.md                                     |  2 +-
 mamba_ssm/modules/mamba_simple.py             |  4 +--
 mamba_ssm/ops/selective_scan_interface.py     | 26 ++++++++++++++-----
 setup.py                                      |  2 +-
 .../ops/triton/test_selective_state_update.py |  2 +-
 5 files changed, 24 insertions(+), 12 deletions(-)

diff --git a/README.md b/README.md
index ff825ef9..643be899 100644
--- a/README.md
+++ b/README.md
@@ -13,7 +13,7 @@ with an efficient hardware-aware design and implementation in the spirit of [Fla
 
 ## Installation
 
-- `pip install causal-conv1d>=1.1.0,<1.2.0`: an efficient implementation of a simple causal Conv1d layer used inside the Mamba block.
+- [Optional] `pip install causal-conv1d>=1.2.0`: an efficient implementation of a simple causal Conv1d layer used inside the Mamba block.
 - `pip install mamba-ssm`: the core Mamba package.
 
 It can also be built from source with `pip install .` from this repository.
diff --git a/mamba_ssm/modules/mamba_simple.py b/mamba_ssm/modules/mamba_simple.py
index 98d97a57..91cb9798 100644
--- a/mamba_ssm/modules/mamba_simple.py
+++ b/mamba_ssm/modules/mamba_simple.py
@@ -15,7 +15,7 @@
 try:
     from causal_conv1d import causal_conv1d_fn, causal_conv1d_update
 except ImportError:
-    causal_conv1d_fn, causal_conv1d_update = None
+    causal_conv1d_fn, causal_conv1d_update = None, None
 
 try:
     from mamba_ssm.ops.triton.selective_state_update import selective_state_update
@@ -142,7 +142,7 @@ def forward(self, hidden_states, inference_params=None):
 
         A = -torch.exp(self.A_log.float())  # (d_inner, d_state)
         # In the backward pass we write dx and dz next to each other to avoid torch.cat
-        if self.use_fast_path and inference_params is None:  # Doesn't support outputting the states
+        if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None:  # Doesn't support outputting the states
             out = mamba_inner_fn(
                 xz,
                 self.conv1d.weight,
diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py
index 52e7eba0..c3596bfe 100644
--- a/mamba_ssm/ops/selective_scan_interface.py
+++ b/mamba_ssm/ops/selective_scan_interface.py
@@ -6,8 +6,13 @@
 
 from einops import rearrange, repeat
 
-from causal_conv1d import causal_conv1d_fn
-import causal_conv1d_cuda
+try:
+    from causal_conv1d import causal_conv1d_fn
+    import causal_conv1d_cuda
+except ImportError:
+    causal_conv1d_fn = None
+    causal_conv1d_cuda = None
+
 import selective_scan_cuda
 
 
@@ -163,6 +168,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh
         """
              xz: (batch, dim, seqlen)
         """
+        assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
         assert checkpoint_lvl in [0, 1]
         L = xz.shape[-1]
         delta_rank = delta_proj_weight.shape[1]
@@ -178,7 +184,9 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh
         conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
         x, z = xz.chunk(2, dim=1)
         conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
-        conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True)
+        conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
+            x, conv1d_weight, conv1d_bias, None, None, None, True
+        )
         # We're being very careful here about the layout, to avoid extra transposes.
         # We want delta to have d as the slowest moving dimension
         # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
@@ -231,6 +239,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh
     @custom_bwd
     def backward(ctx, dout):
         # dout: (batch, seqlen, dim)
+        assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
         (xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight,
          conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors
         L = xz.shape[-1]
@@ -240,7 +249,9 @@ def backward(ctx, dout):
         if dout.stride(-1) != 1:
             dout = dout.contiguous()
         if ctx.checkpoint_lvl == 1:
-            conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True)
+            conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(
+                x, conv1d_weight, conv1d_bias, None, None, None, True
+            )
             delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(),
                               "d (b l) -> b d l", l = L)
         # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
@@ -285,8 +296,8 @@ def backward(ctx, dout):
         dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
         # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
         # backward of conv1d with the backward of chunk).
-        dx, dconv1d_weight, dconv1d_bias = causal_conv1d_cuda.causal_conv1d_bwd(
-            x, conv1d_weight, conv1d_bias, dconv1d_out, None, dx, True
+        dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
+            x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
         )
         dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
         dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
@@ -314,11 +325,12 @@ def mamba_inner_ref(
     A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None,
     C_proj_bias=None, delta_softplus=True
 ):
+    assert causal_conv1d_fn is not None, "causal_conv1d_fn is not available. Please install causal-conv1d."
     L = xz.shape[-1]
     delta_rank = delta_proj_weight.shape[1]
     d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
     x, z = xz.chunk(2, dim=1)
-    x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, None, "silu")
+    x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu")
     # We're being very careful here about the layout, to avoid extra transposes.
     # We want delta to have d as the slowest moving dimension
     # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
diff --git a/setup.py b/setup.py
index 9e1023a9..d794a9ce 100644
--- a/setup.py
+++ b/setup.py
@@ -271,6 +271,6 @@ def run(self):
         "einops",
         "triton",
         "transformers",
-        "causal_conv1d>=1.1.0,<1.2.0",
+        # "causal_conv1d>=1.2.0",
     ],
 )
diff --git a/tests/ops/triton/test_selective_state_update.py b/tests/ops/triton/test_selective_state_update.py
index 70a8d79d..3e4cc6ba 100644
--- a/tests/ops/triton/test_selective_state_update.py
+++ b/tests/ops/triton/test_selective_state_update.py
@@ -19,7 +19,7 @@
 # @pytest.mark.parametrize("dstate", [16])
 @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
 # @pytest.mark.parametrize("dim", [2048])
-def test_causal_conv1d_update(dim, dstate, has_z, itype):
+def test_selective_state_update(dim, dstate, has_z, itype):
     device = "cuda"
     rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2)
     if itype == torch.bfloat16:
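
Note: a sketch of the fallback this patch makes possible (hypothetical
helper; the actual slow path lives inside Mamba.forward). When the fused
kernel is absent, a depthwise conv1d with left padding plus SiLU computes
the same causal convolution, just slower:

    import torch.nn.functional as F

    try:
        from causal_conv1d import causal_conv1d_fn
    except ImportError:
        causal_conv1d_fn = None

    def causal_conv1d(x, weight, bias=None):  # x: (batch, dim, seqlen)
        if causal_conv1d_fn is not None:      # fused CUDA kernel
            return causal_conv1d_fn(x, weight, bias, activation="silu")
        seqlen = x.shape[-1]
        # padding=width-1 pads both sides; keeping the first seqlen outputs
        # retains exactly the causal (left-padded) positions.
        out = F.conv1d(x, weight.unsqueeze(1), bias,
                       padding=weight.shape[-1] - 1, groups=x.shape[1])
        return F.silu(out[..., :seqlen])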

From a1c397b5837f59c44fdefae065149a7af3938fa8 Mon Sep 17 00:00:00 2001
From: Tri Dao <tridpq@gmail.com>
Date: Fri, 1 Mar 2024 17:25:36 -0800
Subject: [PATCH 5/8] [CI] Compile for torch 2.3.0.dev20240126

---
 .github/workflows/publish.yaml | 18 +++++++++++++++---
 1 file changed, 15 insertions(+), 3 deletions(-)

diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml
index 62904299..163847b7 100644
--- a/.github/workflows/publish.yaml
+++ b/.github/workflows/publish.yaml
@@ -44,7 +44,7 @@ jobs:
           # manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
           os: [ubuntu-20.04]
           python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', '3.12']
-          torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.2', '2.2.0']
+          torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.2', '2.2.0', '2.3.0.dev20240126']
           cuda-version: ['11.8.0', '12.2.2']
           # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
           # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
@@ -71,6 +71,8 @@ jobs:
               python-version: '3.7'
             - torch-version: '2.2.0'
               python-version: '3.7'
+            - torch-version: '2.3.0.dev20240126'
+              python-version: '3.7'
             # Pytorch <= 2.0 only supports CUDA <= 11.8
             - torch-version: '1.12.1'
               cuda-version: '12.2.2'
@@ -133,9 +135,19 @@ jobs:
           # We want to figure out the CUDA version to download pytorch
           # e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116
           # This code is ugly, maybe there's a better way to do this.
-          export TORCH_CUDA_VERSION=$(python -c "import os; minv = {'1.12': 113, '1.13': 116, '2.0': 117, '2.1': 118, '2.2': 118}[os.environ['MATRIX_TORCH_VERSION']]; maxv = {'1.12': 116, '1.13': 117, '2.0': 118, '2.1': 121, '2.2': 121}[os.environ['MATRIX_TORCH_VERSION']]; print(max(min(int(os.environ['MATRIX_CUDA_VERSION']), maxv), minv))")
+          export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \
+            minv = {'1.12': 113, '1.13': 116, '2.0': 117, '2.1': 118, '2.2': 118, '2.3': 118}[env['MATRIX_TORCH_VERSION']]; \
+            maxv = {'1.12': 116, '1.13': 117, '2.0': 118, '2.1': 121, '2.2': 121, '2.3': 121}[env['MATRIX_TORCH_VERSION']]; \
+            print(max(min(int(env['MATRIX_CUDA_VERSION']), maxv), minv))" \
+          )
           if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then
-            pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
+            if [[ ${MATRIX_TORCH_VERSION} == "2.2" ]]; then
+              # --no-deps because we can't install old versions of pytorch-triton
+              pip install typing-extensions jinja2
+              pip install --no-cache-dir --no-deps --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ matrix.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl
+            else
+              pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}
+            fi
           else
             pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION}
           fi
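
Note: a sketch of the clamp the TORCH_CUDA_VERSION one-liner performs (the
min/max tables are copied from the workflow; 118 denotes the cu118 wheel
tag). The system CUDA version is clamped into the range of wheels each
torch release actually publishes:

    minv = {"1.12": 113, "1.13": 116, "2.0": 117,
            "2.1": 118, "2.2": 118, "2.3": 118}
    maxv = {"1.12": 116, "1.13": 117, "2.0": 118,
            "2.1": 121, "2.2": 121, "2.3": 121}

    def torch_cuda_tag(torch_version: str, system_cuda: int) -> int:
        return max(min(system_cuda, maxv[torch_version]), minv[torch_version])

    assert torch_cuda_tag("2.2", 122) == 121   # CUDA 12.2 -> cu121 wheel
    assert torch_cuda_tag("1.12", 118) == 116  # torch 1.12 tops out at cu116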

From 2845255a95f7e4a76fe649ec0532178fe3015dd0 Mon Sep 17 00:00:00 2001
From: Tri Dao <tridpq@gmail.com>
Date: Fri, 1 Mar 2024 17:26:15 -0800
Subject: [PATCH 6/8] Bump to v1.2.0

---
 mamba_ssm/__init__.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mamba_ssm/__init__.py b/mamba_ssm/__init__.py
index 63694144..c585a996 100644
--- a/mamba_ssm/__init__.py
+++ b/mamba_ssm/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "1.1.4"
+__version__ = "1.2.0"
 
 from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
 from mamba_ssm.modules.mamba_simple import Mamba

From 34076d664838588a3c97727b263478ab9f621a07 Mon Sep 17 00:00:00 2001
From: Tri Dao <tridpq@gmail.com>
Date: Fri, 1 Mar 2024 22:06:49 -0800
Subject: [PATCH 7/8] [CI] Change torch 2.3.0.dev20240126 to 20240105 for nvcr
 24.02

---
 .github/workflows/publish.yaml | 4 ++--
 mamba_ssm/__init__.py          | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml
index 163847b7..facb5e07 100644
--- a/.github/workflows/publish.yaml
+++ b/.github/workflows/publish.yaml
@@ -44,7 +44,7 @@ jobs:
           # manylinux docker image, but I haven't figured out how to install CUDA on manylinux.
           os: [ubuntu-20.04]
           python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', '3.12']
-          torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.2', '2.2.0', '2.3.0.dev20240126']
+          torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.2', '2.2.0', '2.3.0.dev20240105']
           cuda-version: ['11.8.0', '12.2.2']
           # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not.
           # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
@@ -71,7 +71,7 @@ jobs:
               python-version: '3.7'
             - torch-version: '2.2.0'
               python-version: '3.7'
-            - torch-version: '2.3.0.dev20240126'
+            - torch-version: '2.3.0.dev20240105'
               python-version: '3.7'
             # Pytorch <= 2.0 only supports CUDA <= 11.8
             - torch-version: '1.12.1'
diff --git a/mamba_ssm/__init__.py b/mamba_ssm/__init__.py
index c585a996..9a28b68d 100644
--- a/mamba_ssm/__init__.py
+++ b/mamba_ssm/__init__.py
@@ -1,4 +1,4 @@
-__version__ = "1.2.0"
+__version__ = "1.2.0.post1"
 
 from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn
 from mamba_ssm.modules.mamba_simple import Mamba

From 9127d1f47f367f5c9cc49c73ad73557089d02cb8 Mon Sep 17 00:00:00 2001
From: deroholic <105595360+deroholic@users.noreply.github.com>
Date: Mon, 11 Mar 2024 09:58:59 +0100
Subject: [PATCH 8/8] Fix directory creation when saving with multiprocessing
 (#185)

---
 mamba_ssm/models/mixer_seq_simple.py | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mamba_ssm/models/mixer_seq_simple.py b/mamba_ssm/models/mixer_seq_simple.py
index 2f1d97fd..cd224738 100644
--- a/mamba_ssm/models/mixer_seq_simple.py
+++ b/mamba_ssm/models/mixer_seq_simple.py
@@ -252,8 +252,7 @@ def save_pretrained(self, save_directory):
         Save the model and its configuration file to a directory.
         """
         # Ensure save_directory exists
-        if not os.path.exists(save_directory):
-            os.makedirs(save_directory)
+        os.makedirs(save_directory, exist_ok=True)
 
         # Save the model's state_dict
         model_path = os.path.join(save_directory, 'pytorch_model.bin')
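
Note: a sketch of the race this closes (hypothetical reproduction). With
several processes saving at once, more than one can pass the exists() check
before any of them has created the directory, and the slower ones crash:

    import os

    def save_racy(save_directory):
        if not os.path.exists(save_directory):
            # Another process may create the directory between the check
            # above and the call below, raising FileExistsError here.
            os.makedirs(save_directory)

    def save_safe(save_directory):
        # exist_ok=True makes an already-existing directory a no-op,
        # so concurrent savers no longer step on each other.
        os.makedirs(save_directory, exist_ok=True)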