fix qk norm for parallelized ViT attention
Co-authored-by: Isotr0py <mozf@mail2.sysu.edu.cn>
ywang96 and Isotr0py committed Oct 7, 2024
1 parent b5ea51b commit 49e3dad
Showing 2 changed files with 46 additions and 16 deletions.
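
The change: under tensor parallelism, each rank's ViT attention holds only num_heads / tp_size heads after the sharded QKV projection, but the old forward applied q_norm / k_norm (an RMSNorm built for the full hidden size) directly to that local slice, so the normalization no longer matched the full hidden dimension. The new _apply_qk_norm helper all-gathers q and k across tensor-parallel ranks, runs the norm over the full hidden size, then splits the result back to the local shard. The sketch below is a simplified single-process illustration of why per-shard normalization diverges from full-width normalization; it is plain PyTorch, not vLLM code, and rms_norm is a weightless stand-in for RMSNorm.forward_native.

import torch

def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # RMSNorm over the last dimension, scale weight omitted for brevity.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

hidden_size, tp_size = 64, 2
q = torch.randn(1, 4, hidden_size)  # (batch, seq_len, hidden_size)

# Reference: normalize over the full hidden size, then shard across ranks.
reference_shards = rms_norm(q).chunk(tp_size, dim=-1)

# Pre-fix behaviour: each rank normalizes only the slice it holds locally.
per_rank_shards = [rms_norm(part) for part in q.chunk(tp_size, dim=-1)]

print(torch.allclose(reference_shards[0], per_rank_shards[0]))  # almost always False: the statistics differ
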
vllm/model_executor/models/intern_vit.py: 23 additions & 8 deletions
@@ -4,14 +4,18 @@
 # Copyright (c) 2023 OpenGVLab
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
+from functools import partial
 from typing import Iterable, Optional, Tuple
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from transformers import PretrainedConfig
 
-from vllm.distributed import divide, get_tensor_model_parallel_world_size
+from vllm.distributed import (divide, get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              split_tensor_along_last_dim,
+                              tensor_model_parallel_all_gather)
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (ColumnParallelLinear,
@@ -104,6 +108,8 @@ def __init__(
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.embed_dim // self.num_heads
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
         if self.head_dim * self.num_heads != self.embed_dim:
             raise ValueError(
                 f'embed_dim must be divisible by num_heads '
@@ -134,22 +140,31 @@ def __init__(
         self.tp_size = get_tensor_model_parallel_world_size()
         self.num_heads_per_partition = divide(self.num_heads, self.tp_size)
 
+    def _apply_qk_norm(self, q, k):
+        if self.tp_size > 1:
+            q = tensor_model_parallel_all_gather(q.contiguous())
+            k = tensor_model_parallel_all_gather(k.contiguous())
+        q = self.q_norm.forward_native(q)
+        k = self.k_norm.forward_native(k)
+        if self.tp_size > 1:
+            splitter = partial(split_tensor_along_last_dim,
+                               num_partitions=self.tp_size)
+            q = splitter(q)[self.tp_rank]
+            k = splitter(k)[self.tp_rank]
+        return q, k
+
     def forward(self, x):
         B, N, C = x.shape
         qkv, _ = self.qkv(x)
         q, k, v = qkv.chunk(3, dim=-1)
 
+        if self.qk_normalization:
+            q, k = self._apply_qk_norm(q, k)
+
         q = q.view(B, N, self.num_heads_per_partition, self.head_dim)
         k = k.view(B, N, self.num_heads_per_partition, self.head_dim)
         v = v.view(B, N, self.num_heads_per_partition, self.head_dim)
 
-        if self.qk_normalization:
-            B_, N_, H_, D_ = q.shape
-            q = self.q_norm.forward_native(q.flatten(-2,
-                                                     -1)).view(B_, N_, H_, D_)
-            k = self.k_norm.forward_native(k.flatten(-2,
-                                                     -1)).view(B_, N_, H_, D_)
-
         x = xops.memory_efficient_attention_forward(q, k, v, scale=self.scale)
         x = x.view(B, N, -1)
 
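The same _apply_qk_norm helper is added to the NVLM-D vision encoder below. Its gather / norm / split round trip can be emulated on a single process as follows; torch.cat and chunk stand in for vLLM's tensor_model_parallel_all_gather and split_tensor_along_last_dim, and rms_norm is again a weightless stand-in for RMSNorm.forward_native, so this is only an illustrative sketch of the pattern, not the distributed implementation.

import torch

def rms_norm(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

tp_size, hidden_size = 2, 64
q_full = torch.randn(1, 4, hidden_size)
rank_shards = list(q_full.chunk(tp_size, dim=-1))  # per-rank q after the sharded QKV projection

for tp_rank in range(tp_size):
    gathered = torch.cat(rank_shards, dim=-1)         # emulates tensor_model_parallel_all_gather
    normed = rms_norm(gathered)                       # norm over the full hidden dimension
    q_local = normed.chunk(tp_size, dim=-1)[tp_rank]  # emulates split_tensor_along_last_dim + rank pick
    # Every rank recovers exactly its slice of the full-width normalization.
    assert torch.allclose(q_local, rms_norm(q_full).chunk(tp_size, dim=-1)[tp_rank])
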
vllm/model_executor/models/nvlm_d.py: 23 additions & 8 deletions
@@ -4,14 +4,18 @@
 # Copyright (c) 2024 NVIDIA
 # Licensed under Apache 2.0 License [see LICENSE for details]
 # --------------------------------------------------------
+from functools import partial
 from typing import Optional
 
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from transformers import PretrainedConfig
 
-from vllm.distributed import divide, get_tensor_model_parallel_world_size
+from vllm.distributed import (divide, get_tensor_model_parallel_rank,
+                              get_tensor_model_parallel_world_size,
+                              split_tensor_along_last_dim,
+                              tensor_model_parallel_all_gather)
 from vllm.inputs import INPUT_REGISTRY
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (QKVParallelLinear,
@@ -71,6 +75,8 @@ def __init__(
         self.embed_dim = config.hidden_size
         self.num_heads = config.num_attention_heads
         self.head_dim = self.embed_dim // self.num_heads
+        self.tp_size = get_tensor_model_parallel_world_size()
+        self.tp_rank = get_tensor_model_parallel_rank()
         if self.head_dim * self.num_heads != self.embed_dim:
             raise ValueError(
                 f'embed_dim must be divisible by num_heads '
@@ -173,22 +179,31 @@ def __init__(self, config: PretrainedConfig, num_dummy_heads: int = 7):
 
         self.proj = nn.Linear(self.dummy_dim, self.embed_dim)
 
+    def _apply_qk_norm(self, q, k):
+        if self.tp_size > 1:
+            q = tensor_model_parallel_all_gather(q.contiguous())
+            k = tensor_model_parallel_all_gather(k.contiguous())
+        q = self.q_norm.forward_native(q)
+        k = self.k_norm.forward_native(k)
+        if self.tp_size > 1:
+            splitter = partial(split_tensor_along_last_dim,
+                               num_partitions=self.tp_size)
+            q = splitter(q)[self.tp_rank]
+            k = splitter(k)[self.tp_rank]
+        return q, k
+
     def forward(self, x):
         B, N, C = x.shape
         qkv = self.qkv(x)
         q, k, v = qkv.chunk(3, dim=-1)
 
+        if self.qk_normalization:
+            q, k = self._apply_qk_norm(q, k)
+
         q = q.view(B, N, self.num_dummy_heads + self.num_heads, self.head_dim)
         k = k.view(B, N, self.num_dummy_heads + self.num_heads, self.head_dim)
         v = v.view(B, N, self.num_dummy_heads + self.num_heads, self.head_dim)
 
-        if self.qk_normalization:
-            B_, N_, H_, D_ = q.shape
-            q = self.q_norm.forward_native(q.flatten(-2,
-                                                     -1)).view(B_, N_, H_, D_)
-            k = self.k_norm.forward_native(k.flatten(-2,
-                                                     -1)).view(B_, N_, H_, D_)
-
         q = q.transpose(1, 2)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
