Commit
[Fix] Address remaining issues of supporting MiniCPMV (sgl-project#2977)
mickqian authored Jan 28, 2025
1 parent 76285fd commit 9f635ea
Showing 12 changed files with 708 additions and 223 deletions.
1 change: 1 addition & 0 deletions docs/references/supported_models.md
@@ -78,6 +78,7 @@ Another valuable resource is the [vLLM Models Directory](https://github.com/vllm
To port a model from vLLM to SGLang, you can compare these two files [SGLang Llama Implementation](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/llama.py) and [vLLM Llama Implementation](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama.py). This comparison will help you understand how to convert a model implementation from vLLM to SGLang. The major difference is the replacement of Attention with RadixAttention. The other parts are almost identical. Specifically,
- Replace vllm's `Attention` with `RadixAttention`. Note that you need to pass `layer_id` all the way to `RadixAttention`.
- Replace vllm's `LogitsProcessor` with SGLang's `LogitsProcessor`.
- Replace the multi-headed `Attention` of ViT with SGLang's `VisionAttention` (see the sketch after this list).
- Replace other vLLM layers with SGLang layers (e.g., `RMSNorm`, `SiluAndMul`).
- Remove `Sample`.
- Change `forward()` functions, and add `forward_batch`.
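For example, a hypothetical ViT block might swap its original multi-headed attention for `VisionAttention` along these lines. This is a minimal sketch, not taken from any specific model; the block name and the sizes are placeholders:

```python
import torch

from sglang.srt.layers.attention.vision import VisionAttention


class VisionBlock(torch.nn.Module):
    """Hypothetical ViT block illustrating the attention swap."""

    def __init__(self, embed_dim: int = 1152, num_heads: int = 16):
        super().__init__()
        self.attn = VisionAttention(
            embed_dim=embed_dim,
            num_heads=num_heads,
            projection_size=embed_dim,
            use_qkv_parallel=True,
        )

    def forward(self, x: torch.Tensor, cu_seqlens: torch.Tensor) -> torch.Tensor:
        # x: [b, s, embed_dim]; cu_seqlens: cumulative patch offsets per image
        return self.attn(x, cu_seqlens=cu_seqlens)
```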
@@ -166,6 +166,12 @@ def _fwd_kernel(
def context_attention_fwd(
q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True
):
"""
q, k, v: [b * s, head, head_dim]
b_start_loc: [b]
b_seq_len: [b]
o: [b * s, head, head_dim]
"""
if is_cuda_available and CUDA_CAPABILITY[0] > 8:
BLOCK = 128
else:
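The shape contract documented above can be exercised with a minimal call along these lines. This is a sketch with hypothetical sizes; the import of `context_attention_fwd` is omitted because this hunk does not name its source file, and the Triton kernel needs a CUDA device:

```python
import torch

# Hypothetical packed layout: b sequences of equal length s, flattened to [b * s, head, head_dim].
b, s, heads, head_dim = 2, 4, 8, 64
q = torch.randn(b * s, heads, head_dim, device="cuda", dtype=torch.float16)
k, v = torch.randn_like(q), torch.randn_like(q)
o = torch.empty_like(q)
b_start_loc = torch.arange(0, b * s, s, device="cuda", dtype=torch.int32)  # [b]: start offset of each sequence
b_seq_len = torch.full((b,), s, device="cuda", dtype=torch.int32)          # [b]: length of each sequence

context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len=s, is_causal=False)
```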
283 changes: 243 additions & 40 deletions python/sglang/srt/layers/attention/vision.py
@@ -4,6 +4,7 @@

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat

from sglang.srt.distributed import parallel_state
@@ -63,7 +64,20 @@ def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.T


class VisionAttention(nn.Module):
"""Multi-headed attention without any cache, mostly used for ViT."""
r"""
Multi-headed attention without any cache, mostly used for ViT.
Args:
        use_qkv_parallel (bool, optional): If ``True``, use QKV-parallel attention.
        use_context_forward (bool, defaults to True):
            If ``True``, a flash-attn-style (context) attention is applied.
            Otherwise, full-sequence attention is applied.
        use_full_precision_softmax (bool, defaults to False):
            If ``True``, the softmax is performed in full precision.
            Otherwise, it is performed in half precision.
"""

def __init__(
self,
@@ -72,25 +86,39 @@ def __init__(
projection_size: int,
use_qkv_parallel: bool,
quant_config: Optional[QuantizationConfig] = None,
dropout: float = 0.0,
use_context_forward: bool = True,
use_full_precision_softmax: bool = False,
flatten_batch: bool = False,
prefix: str = "",
):
super().__init__()
self.use_context_forward = use_context_forward
world_size = parallel_state.get_tensor_model_parallel_world_size()

self.dropout = dropout
self.head_size = embed_dim // num_heads
self.hidden_size_per_attention_head = dist_utils.divide(
projection_size, num_heads
)
self.num_attention_heads_per_partition = dist_utils.divide(
num_heads, world_size
)

if self.use_context_forward:
self.qkv_backend = VisionTritonAttention()
else:
self.qkv_backend = VisionSdpaAttention(
head_size=self.head_size,
dropout=dropout,
flatten_batch=flatten_batch,
use_full_precision_softmax=use_full_precision_softmax,
)

self.use_qkv_parallel = use_qkv_parallel
if use_qkv_parallel:
self.head_dim = embed_dim // num_heads
self.qkv_proj = QKVParallelLinear(
hidden_size=embed_dim,
-                head_size=self.head_dim,
+                head_size=self.head_size,
total_num_heads=num_heads,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj",
Expand All @@ -114,12 +142,15 @@ def forward(
x: torch.Tensor,
cu_seqlens: Optional[torch.Tensor] = None,
rotary_pos_emb: torch.Tensor = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
r"""
Args:
x: [b, s, embed_dim]
cu_seqlens: [b]
Returns:
[s, b, num_heads * head]
"""
Input shape: [b, s, embed_dim]
Output shape: [s, b, num_heads * head_size]
"""

bsz, s, _ = x.shape
if self.use_qkv_parallel:
# [b, s, embed_dim] --> [b, s, embed_dim]
@@ -136,19 +167,19 @@
else:
# [b, s, embed_dim] --> [s, b, embed_dim]
x = rearrange(x, "b s ... -> s b ...")
-            # [s, b, embed_dim] --> [s, b, head * 3 * head_dim]
+            # [s, b, embed_dim] --> [s, b, head * 3 * head_size]
            qkv, _ = self.qkv_proj(x)
-            # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
+            # [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size]
            new_x_shape = qkv.size()[:-1] + (
                self.num_attention_heads_per_partition,
                3 * self.hidden_size_per_attention_head,
            )
            qkv = qkv.view(*new_x_shape)

-            # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
+            # [s, b, head, 3 * head_size] --> 3 [s, b, head, head_size]
            q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3)

-            # [s, b, head, head_dim] --> [b, s, head, head_dim]
+            # [s, b, head, head_size] --> [b, s, head, head_size]
q, k, v = [
rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)
]
@@ -160,45 +191,217 @@
if self.use_qkv_parallel:
pass
else:
-            # [b, s, head, head_dim] --> [b * s, head, head_dim]
+            # [b, s, head, head_size] --> [b * s, head, head_size]
            q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]

-        # [b * s, num_heads, head_size]
-        output = torch.empty_like(q)
-
-        seq_lens = (cu_seqlens[1:] - cu_seqlens[:-1]).cuda()
-        max_seqlen = seq_lens.max().item()
-
-        context_attention_fwd(
-            q,
-            k,
-            v,
-            output,
-            cu_seqlens.cuda(),
-            seq_lens,
-            max_seqlen,
-            is_causal=False,
-        )
+        output = self.qkv_backend.forward(q, k, v, bsz, cu_seqlens, attention_mask)

        if self.use_qkv_parallel:
-            # [b * s, head, head_dim] --> [b, s, head * head_dim]
+            # [b * s, h, head_size] --> [b, s, h * head_size]
            output = rearrange(output, "(b s) ... h d -> b s ... (h d)", b=bsz)

-            # [b, s, head, head_dim] --> [b, s, head, head_dim]
+            # [b, s, h * head_size] --> [b, s, h * head_size]
            output, _ = self.proj(output)
        else:
-            # [b * s, head, head_dim] --> [b, s, head, head_dim]
-            context_layer = rearrange(output, "(b s) ... -> b s ...", b=bsz)
-
-            # [s, b, num_heads * head_size]
+            # [b * s, h, head_size] --> [s, b, h * head_size]
            context_layer = rearrange(
-                context_layer, "b s h d -> s b (h d)"
+                output, "(b s) h d -> s b (h d)", b=bsz, s=s
            ).contiguous()

-            # [s, b, num_heads * head_size] --> [s, b, num_heads * head_size]
+            # [s, b, h * head_size] --> [s, b, h * head_size]
            output, _ = self.proj(context_layer)

+            # [s, b, h * head_size] --> [b, s, h * head_size]
+            output = output.view(bsz, s, -1)

return output
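# Illustrative sketch only, not part of this commit: how the constructor flags documented
# above select a backend. The sizes are hypothetical placeholders, and building the module
# assumes the tensor-parallel state has already been initialized.
def _example_backend_choice() -> VisionAttention:
    attn = VisionAttention(
        embed_dim=1152,
        num_heads=16,
        projection_size=1152,
        use_qkv_parallel=False,
        use_context_forward=False,        # False -> VisionSdpaAttention below
        use_full_precision_softmax=True,  # float32 softmax instead of F.scaled_dot_product_attention
        flatten_batch=True,
    )
    assert isinstance(attn.qkv_backend, VisionSdpaAttention)
    return attn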


class VisionSdpaAttention(nn.Module):
r"""
Scaled Dot Product Attention inner product
"""

    # TODO: Should this cache be released after use?
_mask_cache = {}

def __init__(
self,
head_size: int,
dropout: float = 0.0,
flatten_batch: bool = False,
use_full_precision_softmax: bool = False,
):
super().__init__()
self.head_size = head_size
self.flatten_batch = flatten_batch
self.use_full_precision_softmax = use_full_precision_softmax
self.dropout = dropout

def generate_patch_attention_mask(
self,
s: int,
bsz: int,
device,
cu_seqlens: Optional[torch.Tensor],
flatten_batch: bool = False,
dtype=torch.bfloat16,
) -> torch.Tensor:
r"""
Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
When `flatten_batch` is True:
- All sequences in the batch are flattened into a single dimension
- `s` represents the total number of tokens across all sequences in the batch
- Returns a unified mask of shape `(1, 1, s, s)`
When `flatten_batch` is False:
- Each sequence has its own attention mask
- `s` represents the maximum sequence length in the batch
- Returns separate masks of shape `(b, 1, s, s)`
Args:
flatten_batch: (bool):
If True, treats all sequences in the batch as a single flattened sequence
If False, generates separate masks for each sequence
Returns:
Tensor of shape `(b, 1, s, s)` or `(1, 1, s, s)`.
"""

cache_key = (s, bsz, flatten_batch, tuple(cu_seqlens.cpu().tolist()))

if cache_key in VisionSdpaAttention._mask_cache:
cached_mask = VisionSdpaAttention._mask_cache[cache_key]
# print(f"cache hit for key: {cache_key}")
return cached_mask.to(device=device, dtype=dtype)

if cu_seqlens is None:
raise ValueError("Internal Error: cu_seqlens cannot be None")

if flatten_batch:
mask = torch.zeros([1, s, s], device=device, dtype=torch.bool)
for i in range(1, len(cu_seqlens)):
start = cu_seqlens[i - 1]
end = cu_seqlens[i]
mask[
...,
start:end,
start:end,
] = True
else:
# [1, 1, 1, s]
row_indices = torch.arange(s, device=device).view(1, 1, 1, s)
# [1, 1, s, 1]
col_indices = torch.arange(s, device=device).view(1, 1, s, 1)
# [b, 1, 1, 1]
seq_lens = (
(cu_seqlens[1:] - cu_seqlens[:-1]).to(device=device).view(-1, 1, 1, 1)
)

mask = (row_indices < seq_lens) & (col_indices < seq_lens)

# Convert to attention mask format (False -> 0, True -> -inf)
mask = (~mask).to(dtype) * torch.finfo(dtype).min

VisionSdpaAttention._mask_cache[cache_key] = mask

return mask

def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
bsz: int,
cu_seqlens: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
r"""
Args:
cu_seqlens: [b]
Returns:
[b * s, h, head_size]
"""

s = q.shape[0] // bsz

# [b, 1, s, s]
if attention_mask is None:
attention_mask = self.generate_patch_attention_mask(
s, bsz, q.device, cu_seqlens, self.flatten_batch, q.dtype
)
q, k, v = [rearrange(x, "(b s) h d -> b h s d", b=bsz) for x in [q, k, v]]
if self.use_full_precision_softmax:
scale = self.head_size**-0.5
k_transposed = rearrange(k, "b h s d -> b h d s")
attn_weights = torch.matmul(q, k_transposed) * scale
del k, k_transposed
attn_weights = attn_weights + attention_mask
del attention_mask
# full-precision
attn_weights = nn.functional.softmax(
attn_weights, dim=-1, dtype=torch.float32
).to(q.dtype)
attn_weights = nn.functional.dropout(
attn_weights, p=self.dropout, training=False
)
output = torch.matmul(attn_weights, v)
del attn_weights, v
else:
# SDPA
# [b, h, s, head_size]
output = F.scaled_dot_product_attention(
q, k, v, attention_mask, dropout_p=self.dropout
)

# [b, h, s, head_size] --> [b * s, h, head_size]
output = rearrange(output, "b h s d -> (b s) h d")

return output
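# Illustrative sketch only, not part of this commit: the mask shapes produced by
# generate_patch_attention_mask for two hypothetical patch sequences of lengths 3 and 2.
def _example_patch_mask_shapes() -> None:
    attn = VisionSdpaAttention(head_size=64)
    cu_seqlens = torch.tensor([0, 3, 5], dtype=torch.int32)
    # flatten_batch=True: one block-diagonal additive mask over all 5 flattened tokens.
    flat_mask = attn.generate_patch_attention_mask(
        s=5, bsz=1, device="cpu", cu_seqlens=cu_seqlens, flatten_batch=True
    )
    # flatten_batch=False: one mask per sequence, padded to the longest length (3).
    per_seq_mask = attn.generate_patch_attention_mask(
        s=3, bsz=2, device="cpu", cu_seqlens=cu_seqlens, flatten_batch=False
    )
    assert flat_mask.shape[-2:] == (5, 5)
    assert per_seq_mask.shape == (2, 1, 3, 3)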


class VisionTritonAttention(nn.Module):
"""
Triton-implemented attention without a causal mask
"""

def __init__(
self,
):
super().__init__()

def forward(
self,
q: torch.Tensor,
k: torch.Tensor,
v: torch.Tensor,
_bsz: int,
cu_seqlens: Optional[torch.Tensor],
**kwargs,
) -> torch.Tensor:
r"""
Args:
cu_seqlens: [b]
Returns:
[b * s, h, head_size]
"""

# [b * s, head, head_size]
output = torch.empty_like(q)
seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
max_seqlen = seq_lens.max().item()
context_attention_fwd(
q,
k,
v,
output,
cu_seqlens.cuda(),
seq_lens.cuda(),
max_seqlen,
is_causal=False,
)

return output
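# Illustrative sketch only, not part of this commit: both backends consume cu_seqlens as
# cumulative sequence offsets with a leading zero, e.g. lengths (3, 2) -> [0, 3, 5].
def _example_cu_seqlens(seq_lens=(3, 2)) -> torch.Tensor:
    lens = torch.tensor(seq_lens, dtype=torch.int32)
    return F.pad(lens.cumsum(0, dtype=torch.int32), (1, 0))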