diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py
index e6669b17e266..3808c6479061 100644
--- a/vllm/model_executor/models/intern_vit.py
+++ b/vllm/model_executor/models/intern_vit.py
@@ -4,6 +4,7 @@
 # Copyright (c) 2023 OpenGVLab
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
+from functools import partial
 from typing import Iterable, Optional, Tuple
 
 import torch
@@ -11,7 +12,10 @@
 import torch.nn.functional as F
 from transformers import PretrainedConfig
 
-from vllm.distributed import divide, get_tensor_model_parallel_world_size
+from vllm.distributed import (divide, get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              split_tensor_along_last_dim,
+                              tensor_model_parallel_all_gather)
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -104,6 +108,8 @@ def __init__(
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.embed_dim // self.num_heads
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
         if self.head_dim * self.num_heads != self.embed_dim:
             raise ValueError(
                 f'embed_dim must be divisible by num_heads '
@@ -134,22 +140,31 @@ def __init__(
         self.tp_size = get_tensor_model_parallel_world_size()
         self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
 
+    def _apply_qk_norm(self, q, k):
+        if self.tp_size > 1:
+            q = tensor_model_parallel_all_gather(q.contiguous())
+            k = tensor_model_parallel_all_gather(k.contiguous())
+        q = self.q_norm.forward_native(q)
+        k = self.k_norm.forward_native(k)
+        if self.tp_size > 1:
+            splitter = partial(split_tensor_along_last_dim,
+                               num_partitions=self.tp_size)
+            q = splitter(q)[self.tp_rank]
+            k = splitter(k)[self.tp_rank]
+        return q, k
+
     def forward(self, x):
         B, N, C = x.shape
         qkv, _ = self.qkv(x)
         q, k, v = qkv.chunk(3, dim=-1)
 
+        if self.qk_normalization:
+            q, k = self._apply_qk_norm(q, k)
+
         q = q.view(B, N, self.num_heads_per_partition, self.head_dim)
         k = k.view(B, N, self.num_heads_per_partition, self.head_dim)
         v = v.view(B, N, self.num_heads_per_partition, self.head_dim)
 
-        if self.qk_normalization:
-            B_, N_, H_, D_ = q.shape
-            q = self.q_norm.forward_native(q.flatten(-2,
-                                                     -1)).view(B_, N_, H_, D_)
-            k = self.k_norm.forward_native(k.flatten(-2,
-                                                     -1)).view(B_, N_, H_, D_)
-
         x = xops.memory_efficient_attention_forward(q, k, v, scale=self.scale)
         x = x.view(B, N, -1)
 
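Both files receive the same fix, so it is worth spelling out once: q_norm/k_norm are RMSNorm layers parameterized over the full hidden size, but under tensor parallelism each rank's qkv output carries only embed_dim // tp_size of the q/k features, so the old per-shard normalization computed the RMS statistic over the wrong feature set (and against a mismatched weight slice). A minimal single-process sketch of the discrepancy, assuming unit norm weights and illustrative shapes (not the vLLM API):

import torch

def rms_norm(x, eps=1e-6):
    # Same math as vLLM's RMSNorm.forward_native, with unit weight.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

embed_dim, tp_size, rank = 64, 4, 0
q = torch.randn(2, 5, embed_dim)               # (batch, tokens, embed_dim)

# Correct: normalize over all embed_dim features, then take this rank's slice.
ref = rms_norm(q).chunk(tp_size, dim=-1)[rank]

# Broken pre-fix behavior under TP: this rank's shard normalized in
# isolation, so the RMS statistic sees only embed_dim // tp_size features.
bad = rms_norm(q.chunk(tp_size, dim=-1)[rank])

print(torch.allclose(ref, bad))  # False: the per-shard statistics differ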
diff --git a/vllm/model_executor/models/nvlm_d.py b/vllm/model_executor/models/nvlm_d.py
index c0b0b2600cc9..dd05c266b35c 100644
--- a/vllm/model_executor/models/nvlm_d.py
+++ b/vllm/model_executor/models/nvlm_d.py
@@ -4,6 +4,7 @@
 # Copyright (c) 2024 NVIDIA
 # Licensed under Apache 2.0 License [see LICENSE for details]
 # --------------------------------------------------------
+from functools import partial
 from typing import Optional
 
 import torch
@@ -11,7 +12,10 @@
 import torch.nn.functional as F
 from transformers import PretrainedConfig
 
-from vllm.distributed import divide, get_tensor_model_parallel_world_size
+from vllm.distributed import (divide, get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              split_tensor_along_last_dim,
+                              tensor_model_parallel_all_gather)
 from vllm.inputs import INPUT_REGISTRY
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
@@ -71,6 +75,8 @@ def __init__(
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.embed_dim // self.num_heads
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
         if self.head_dim * self.num_heads != self.embed_dim:
             raise ValueError(
                 f'embed_dim must be divisible by num_heads '
@@ -173,22 +179,31 @@ def __init__(self, config: PretrainedConfig, num_dummy_heads: int = 7):
 
         self.proj = nn.Linear(self.dummy_dim, self.embed_dim)
 
+    def _apply_qk_norm(self, q, k):
+        if self.tp_size > 1:
+            q = tensor_model_parallel_all_gather(q.contiguous())
+            k = tensor_model_parallel_all_gather(k.contiguous())
+        q = self.q_norm.forward_native(q)
+        k = self.k_norm.forward_native(k)
+        if self.tp_size > 1:
+            splitter = partial(split_tensor_along_last_dim,
+                               num_partitions=self.tp_size)
+            q = splitter(q)[self.tp_rank]
+            k = splitter(k)[self.tp_rank]
+        return q, k
+
     def forward(self, x):
         B, N, C = x.shape
         qkv = self.qkv(x)
         q, k, v = qkv.chunk(3, dim=-1)
 
+        if self.qk_normalization:
+            q, k = self._apply_qk_norm(q, k)
+
         q = q.view(B, N, self.num_dummy_heads + self.num_heads, self.head_dim)
         k = k.view(B, N, self.num_dummy_heads + self.num_heads, self.head_dim)
         v = v.view(B, N, self.num_dummy_heads + self.num_heads, self.head_dim)
 
-        if self.qk_normalization:
-            B_, N_, H_, D_ = q.shape
-            q = self.q_norm.forward_native(q.flatten(-2,
-                                                     -1)).view(B_, N_, H_, D_)
-            k = self.k_norm.forward_native(k.flatten(-2,
-                                                     -1)).view(B_, N_, H_, D_)
-
         q = q.transpose(1, 2)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
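For a sense of the data movement in _apply_qk_norm, below is a single-process mock of the round trip, with the two collectives stubbed by plain tensor ops. The stubs and shapes are assumptions for illustration only; the real vLLM collectives communicate across ranks, and in the NVLM-D class the normalized width is the padded dummy_dim rather than embed_dim.

from functools import partial

import torch

TP_SIZE, TP_RANK, EMBED_DIM = 4, 1, 64

full_q = torch.randn(2, 5, EMBED_DIM)              # all shards, combined
local_q = full_q.chunk(TP_SIZE, dim=-1)[TP_RANK]   # this rank's q shard

def tensor_model_parallel_all_gather(t):
    # Stub: a real all-gather exchanges shards across TP ranks; here the
    # full tensor is already known, so just return it.
    return full_q

def split_tensor_along_last_dim(t, num_partitions):
    # Mirrors the helper imported in the diff: chunk along the last dim.
    return torch.chunk(t, num_partitions, dim=-1)

def rms_norm(x, eps=1e-6):
    # vLLM's RMSNorm.forward_native math, with unit weight.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

# The gather -> norm -> split round trip from _apply_qk_norm:
q = tensor_model_parallel_all_gather(local_q.contiguous())  # (B, N, EMBED_DIM)
q = rms_norm(q)                          # statistics over the full last dim
splitter = partial(split_tensor_along_last_dim, num_partitions=TP_SIZE)
q = splitter(q)[TP_RANK]                 # back to this rank's slice

assert q.shape == local_q.shape          # layout unchanged, values corrected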