From 5aa131d40912b66ff917570916b8587026bf2801 Mon Sep 17 00:00:00 2001 From: Hailey Schoelkopf <65563625+haileyschoelkopf@users.noreply.github.com> Date: Fri, 1 Mar 2024 15:52:42 -0500 Subject: [PATCH 1/8] Don't hardcode tied embed+unembed (#211) * Add tie_embeddings config option * only tie embeddings if config specifies it --- mamba_ssm/models/config_mamba.py | 1 + mamba_ssm/models/mixer_seq_simple.py | 5 +++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/mamba_ssm/models/config_mamba.py b/mamba_ssm/models/config_mamba.py index ffd31abc..2aa1e5a6 100644 --- a/mamba_ssm/models/config_mamba.py +++ b/mamba_ssm/models/config_mamba.py @@ -12,3 +12,4 @@ class MambaConfig: residual_in_fp32: bool = True fused_add_norm: bool = True pad_vocab_size_multiple: int = 8 + tie_embeddings: bool = True diff --git a/mamba_ssm/models/mixer_seq_simple.py b/mamba_ssm/models/mixer_seq_simple.py index 5b3ddfcf..2f1d97fd 100644 --- a/mamba_ssm/models/mixer_seq_simple.py +++ b/mamba_ssm/models/mixer_seq_simple.py @@ -220,8 +220,9 @@ def __init__( self.tie_weights() def tie_weights(self): - self.lm_head.weight = self.backbone.embedding.weight - + if self.config.tie_embeddings: + self.lm_head.weight = self.backbone.embedding.weight + def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs): return self.backbone.allocate_inference_cache(batch_size, max_seqlen, dtype=dtype, **kwargs) From 2066c2115cd85fb95c7c8d5123f45802b165d7b3 Mon Sep 17 00:00:00 2001 From: Haobo Yuan <yuanhaobo.cs@outlook.com> Date: Sat, 2 Mar 2024 08:44:17 +0800 Subject: [PATCH 2/8] Bugfix `causal_conv1d_fn` interface (#168) --- mamba_ssm/ops/selective_scan_interface.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py index b8f14dd0..52e7eba0 100644 --- a/mamba_ssm/ops/selective_scan_interface.py +++ b/mamba_ssm/ops/selective_scan_interface.py @@ -318,7 +318,7 @@ def 
mamba_inner_ref( delta_rank = delta_proj_weight.shape[1] d_state = A.shape[-1] * (1 if not A.is_complex() else 2) x, z = xz.chunk(2, dim=1) - x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, "silu") + x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, None, "silu") # We're being very careful here about the layout, to avoid extra transposes. # We want delta to have d as the slowest moving dimension # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. From 9583c56b741d28f691690e833cacce768938adce Mon Sep 17 00:00:00 2001 From: Wongboo <44860323+Wongboo@users.noreply.github.com> Date: Sat, 2 Mar 2024 09:24:11 +0800 Subject: [PATCH 3/8] add support for python 3.12 (#153) --- .github/workflows/publish.yaml | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml index 64951b34..62904299 100644 --- a/.github/workflows/publish.yaml +++ b/.github/workflows/publish.yaml @@ -43,7 +43,7 @@ jobs: # Using ubuntu-20.04 instead of 22.04 for more compatibility (glibc). Ideally we'd use the # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. os: [ubuntu-20.04] - python-version: ['3.7', '3.8', '3.9', '3.10', '3.11'] + python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', '3.12'] torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.2', '2.2.0'] cuda-version: ['11.8.0', '12.2.2'] # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. @@ -52,6 +52,15 @@ jobs: # when building without C++11 ABI and using it on nvcr images. 
cxx11_abi: ['FALSE', 'TRUE'] exclude: + # Pytorch < 2.2 does not support Python 3.12 + - torch-version: '1.12.1' + python-version: '3.12' + - torch-version: '1.13.1' + python-version: '3.12' + - torch-version: '2.0.1' + python-version: '3.12' + - torch-version: '2.1.2' + python-version: '3.12' # Pytorch <= 1.12 does not support Python 3.11 - torch-version: '1.12.1' python-version: '3.11' @@ -119,6 +128,8 @@ jobs: # If we don't install before installing Pytorch, we get error for torch 2.0.1 # ERROR: Could not find a version that satisfies the requirement setuptools>=40.8.0 (from versions: none) pip install lit + # For some reason torch 2.2.0 on python 3.12 errors saying no setuptools + pip install setuptools # We want to figure out the CUDA version to download pytorch # e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116 # This code is ugly, maybe there's a better way to do this. From afd5fb5b1b89d2d353e9316809903267cbe70971 Mon Sep 17 00:00:00 2001 From: Tri Dao <tridpq@gmail.com> Date: Fri, 1 Mar 2024 17:23:39 -0800 Subject: [PATCH 4/8] Update causal-conv1d to 1.2.0, make it optional --- README.md | 2 +- mamba_ssm/modules/mamba_simple.py | 4 +-- mamba_ssm/ops/selective_scan_interface.py | 26 ++++++++++++++----- setup.py | 2 +- .../ops/triton/test_selective_state_update.py | 2 +- 5 files changed, 24 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index ff825ef9..643be899 100644 --- a/README.md +++ b/README.md @@ -13,7 +13,7 @@ with an efficient hardware-aware design and implementation in the spirit of [Fla ## Installation -- `pip install causal-conv1d>=1.1.0,<1.2.0`: an efficient implementation of a simple causal Conv1d layer used inside the Mamba block. +- [Option] `pip install causal-conv1d>=1.2.0`: an efficient implementation of a simple causal Conv1d layer used inside the Mamba block. - `pip install mamba-ssm`: the core Mamba package. 
It can also be built from source with `pip install .` from this repository. diff --git a/mamba_ssm/modules/mamba_simple.py b/mamba_ssm/modules/mamba_simple.py index 98d97a57..91cb9798 100644 --- a/mamba_ssm/modules/mamba_simple.py +++ b/mamba_ssm/modules/mamba_simple.py @@ -15,7 +15,7 @@ try: from causal_conv1d import causal_conv1d_fn, causal_conv1d_update except ImportError: - causal_conv1d_fn, causal_conv1d_update = None + causal_conv1d_fn, causal_conv1d_update = None, None try: from mamba_ssm.ops.triton.selective_state_update import selective_state_update @@ -142,7 +142,7 @@ def forward(self, hidden_states, inference_params=None): A = -torch.exp(self.A_log.float()) # (d_inner, d_state) # In the backward pass we write dx and dz next to each other to avoid torch.cat - if self.use_fast_path and inference_params is None: # Doesn't support outputting the states + if self.use_fast_path and causal_conv1d_fn is not None and inference_params is None: # Doesn't support outputting the states out = mamba_inner_fn( xz, self.conv1d.weight, diff --git a/mamba_ssm/ops/selective_scan_interface.py b/mamba_ssm/ops/selective_scan_interface.py index 52e7eba0..c3596bfe 100644 --- a/mamba_ssm/ops/selective_scan_interface.py +++ b/mamba_ssm/ops/selective_scan_interface.py @@ -6,8 +6,13 @@ from einops import rearrange, repeat -from causal_conv1d import causal_conv1d_fn -import causal_conv1d_cuda +try: + from causal_conv1d import causal_conv1d_fn + import causal_conv1d_cuda +except ImportError: + causal_conv1d_fn = None + causal_conv1d_cuda = None + import selective_scan_cuda @@ -163,6 +168,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh """ xz: (batch, dim, seqlen) """ + assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d." 
assert checkpoint_lvl in [0, 1] L = xz.shape[-1] delta_rank = delta_proj_weight.shape[1] @@ -178,7 +184,9 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w") x, z = xz.chunk(2, dim=1) conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None - conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True) + conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd( + x, conv1d_weight, conv1d_bias, None, None, None, True + ) # We're being very careful here about the layout, to avoid extra transposes. # We want delta to have d as the slowest moving dimension # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. @@ -231,6 +239,7 @@ def forward(ctx, xz, conv1d_weight, conv1d_bias, x_proj_weight, delta_proj_weigh @custom_bwd def backward(ctx, dout): # dout: (batch, seqlen, dim) + assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d." 
(xz, conv1d_weight, conv1d_bias, x_dbl, x_proj_weight, delta_proj_weight, out_proj_weight, conv1d_out, delta, A, B, C, D, delta_bias, scan_intermediates, out) = ctx.saved_tensors L = xz.shape[-1] @@ -240,7 +249,9 @@ def backward(ctx, dout): if dout.stride(-1) != 1: dout = dout.contiguous() if ctx.checkpoint_lvl == 1: - conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, True) + conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd( + x, conv1d_weight, conv1d_bias, None, None, None, True + ) delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l = L) # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the @@ -285,8 +296,8 @@ def backward(ctx, dout): dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1]) # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the # backward of conv1d with the backward of chunk). - dx, dconv1d_weight, dconv1d_bias = causal_conv1d_cuda.causal_conv1d_bwd( - x, conv1d_weight, conv1d_bias, dconv1d_out, None, dx, True + dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd( + x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True ) dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w") @@ -314,11 +325,12 @@ def mamba_inner_ref( A, B=None, C=None, D=None, delta_bias=None, B_proj_bias=None, C_proj_bias=None, delta_softplus=True ): + assert causal_conv1d_fn is not None, "causal_conv1d_fn is not available. Please install causal-conv1d." 
L = xz.shape[-1] delta_rank = delta_proj_weight.shape[1] d_state = A.shape[-1] * (1 if not A.is_complex() else 2) x, z = xz.chunk(2, dim=1) - x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, None, "silu") + x = causal_conv1d_fn(x, rearrange(conv1d_weight, "d 1 w -> d w"), conv1d_bias, activation="silu") # We're being very careful here about the layout, to avoid extra transposes. # We want delta to have d as the slowest moving dimension # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects. diff --git a/setup.py b/setup.py index 9e1023a9..d794a9ce 100644 --- a/setup.py +++ b/setup.py @@ -271,6 +271,6 @@ def run(self): "einops", "triton", "transformers", - "causal_conv1d>=1.1.0,<1.2.0", + # "causal_conv1d>=1.2.0", ], ) diff --git a/tests/ops/triton/test_selective_state_update.py b/tests/ops/triton/test_selective_state_update.py index 70a8d79d..3e4cc6ba 100644 --- a/tests/ops/triton/test_selective_state_update.py +++ b/tests/ops/triton/test_selective_state_update.py @@ -19,7 +19,7 @@ # @pytest.mark.parametrize("dstate", [16]) @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096]) # @pytest.mark.parametrize("dim", [2048]) -def test_causal_conv1d_update(dim, dstate, has_z, itype): +def test_selective_state_update(dim, dstate, has_z, itype): device = "cuda" rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (5e-3, 1e-2) if itype == torch.bfloat16: From a1c397b5837f59c44fdefae065149a7af3938fa8 Mon Sep 17 00:00:00 2001 From: Tri Dao <tridpq@gmail.com> Date: Fri, 1 Mar 2024 17:25:36 -0800 Subject: [PATCH 5/8] [CI] Compile for torch 2.3.0.dev20240126 --- .github/workflows/publish.yaml | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml index 62904299..163847b7 100644 --- a/.github/workflows/publish.yaml +++ b/.github/workflows/publish.yaml @@ -44,7 +44,7 @@ jobs: # manylinux docker image, but I haven't 
figured out how to install CUDA on manylinux. os: [ubuntu-20.04] python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', '3.12'] - torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.2', '2.2.0'] + torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.2', '2.2.0', '2.3.0.dev20240126'] cuda-version: ['11.8.0', '12.2.2'] # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI. @@ -71,6 +71,8 @@ jobs: python-version: '3.7' - torch-version: '2.2.0' python-version: '3.7' + - torch-version: '2.3.0.dev20240126' + python-version: '3.7' # Pytorch <= 2.0 only supports CUDA <= 11.8 - torch-version: '1.12.1' cuda-version: '12.2.2' @@ -133,9 +135,19 @@ jobs: # We want to figure out the CUDA version to download pytorch # e.g. we can have system CUDA version being 11.7 but if torch==1.12 then we need to download the wheel from cu116 # This code is ugly, maybe there's a better way to do this. 
- export TORCH_CUDA_VERSION=$(python -c "import os; minv = {'1.12': 113, '1.13': 116, '2.0': 117, '2.1': 118, '2.2': 118}[os.environ['MATRIX_TORCH_VERSION']]; maxv = {'1.12': 116, '1.13': 117, '2.0': 118, '2.1': 121, '2.2': 121}[os.environ['MATRIX_TORCH_VERSION']]; print(max(min(int(os.environ['MATRIX_CUDA_VERSION']), maxv), minv))") + export TORCH_CUDA_VERSION=$(python -c "from os import environ as env; \ + minv = {'1.12': 113, '1.13': 116, '2.0': 117, '2.1': 118, '2.2': 118, '2.3': 118}[env['MATRIX_TORCH_VERSION']]; \ + maxv = {'1.12': 116, '1.13': 117, '2.0': 118, '2.1': 121, '2.2': 121, '2.3': 121}[env['MATRIX_TORCH_VERSION']]; \ + print(max(min(int(env['MATRIX_CUDA_VERSION']), maxv), minv))" \ + ) if [[ ${{ matrix.torch-version }} == *"dev"* ]]; then - pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION} + if [[ ${MATRIX_TORCH_VERSION} == "2.2" ]]; then + # --no-deps because we can't install old versions of pytorch-triton + pip install typing-extensions jinja2 + pip install --no-cache-dir --no-deps --pre https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION}/torch-${{ matrix.torch-version }}%2Bcu${TORCH_CUDA_VERSION}-cp${MATRIX_PYTHON_VERSION}-cp${MATRIX_PYTHON_VERSION}-linux_x86_64.whl + else + pip install --no-cache-dir --pre torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/nightly/cu${TORCH_CUDA_VERSION} + fi else pip install --no-cache-dir torch==${{ matrix.torch-version }} --index-url https://download.pytorch.org/whl/cu${TORCH_CUDA_VERSION} fi From 2845255a95f7e4a76fe649ec0532178fe3015dd0 Mon Sep 17 00:00:00 2001 From: Tri Dao <tridpq@gmail.com> Date: Fri, 1 Mar 2024 17:26:15 -0800 Subject: [PATCH 6/8] Bump to v1.2.0 --- mamba_ssm/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mamba_ssm/__init__.py b/mamba_ssm/__init__.py index 63694144..c585a996 100644 --- a/mamba_ssm/__init__.py +++ 
b/mamba_ssm/__init__.py @@ -1,4 +1,4 @@ -__version__ = "1.1.4" +__version__ = "1.2.0" from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn from mamba_ssm.modules.mamba_simple import Mamba From 34076d664838588a3c97727b263478ab9f621a07 Mon Sep 17 00:00:00 2001 From: Tri Dao <tridpq@gmail.com> Date: Fri, 1 Mar 2024 22:06:49 -0800 Subject: [PATCH 7/8] [CI] Change torch 2.3.0.dev20240126 to 20240105 for nvcr 24.02 --- .github/workflows/publish.yaml | 4 ++-- mamba_ssm/__init__.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml index 163847b7..facb5e07 100644 --- a/.github/workflows/publish.yaml +++ b/.github/workflows/publish.yaml @@ -44,7 +44,7 @@ jobs: # manylinux docker image, but I haven't figured out how to install CUDA on manylinux. os: [ubuntu-20.04] python-version: ['3.7', '3.8', '3.9', '3.10', '3.11', '3.12'] - torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.2', '2.2.0', '2.3.0.dev20240126'] + torch-version: ['1.12.1', '1.13.1', '2.0.1', '2.1.2', '2.2.0', '2.3.0.dev20240105'] cuda-version: ['11.8.0', '12.2.2'] # We need separate wheels that either uses C++11 ABI (-D_GLIBCXX_USE_CXX11_ABI) or not. # Pytorch wheels currently don't use it, but nvcr images have Pytorch compiled with C++11 ABI.
@@ -71,7 +71,7 @@ jobs: python-version: '3.7' - torch-version: '2.2.0' python-version: '3.7' - - torch-version: '2.3.0.dev20240126' + - torch-version: '2.3.0.dev20240105' python-version: '3.7' # Pytorch <= 2.0 only supports CUDA <= 11.8 - torch-version: '1.12.1' diff --git a/mamba_ssm/__init__.py b/mamba_ssm/__init__.py index c585a996..9a28b68d 100644 --- a/mamba_ssm/__init__.py +++ b/mamba_ssm/__init__.py @@ -1,4 +1,4 @@ -__version__ = "1.2.0" +__version__ = "1.2.0.post1" from mamba_ssm.ops.selective_scan_interface import selective_scan_fn, mamba_inner_fn from mamba_ssm.modules.mamba_simple import Mamba From 9127d1f47f367f5c9cc49c73ad73557089d02cb8 Mon Sep 17 00:00:00 2001 From: deroholic <105595360+deroholic@users.noreply.github.com> Date: Mon, 11 Mar 2024 09:58:59 +0100 Subject: [PATCH 8/8] Fix directory creating for saving with multiprocessing (#185) --- mamba_ssm/models/mixer_seq_simple.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/mamba_ssm/models/mixer_seq_simple.py b/mamba_ssm/models/mixer_seq_simple.py index 2f1d97fd..cd224738 100644 --- a/mamba_ssm/models/mixer_seq_simple.py +++ b/mamba_ssm/models/mixer_seq_simple.py @@ -252,8 +252,7 @@ def save_pretrained(self, save_directory): Save the model and its configuration file to a directory. """ # Ensure save_directory exists - if not os.path.exists(save_directory): - os.makedirs(save_directory) + os.makedirs(save_directory, exist_ok=True) # Save the model's state_dict model_path = os.path.join(save_directory, 'pytorch_model.bin')