From 9f635ea50de920aa507f486daafba26a5b837574 Mon Sep 17 00:00:00 2001 From: Mick Date: Tue, 28 Jan 2025 16:22:13 +0800 Subject: [PATCH] [Fix] Address remaining issues of supporting MiniCPMV (#2977) --- docs/references/supported_models.md | 1 + .../attention/triton_ops/prefill_attention.py | 6 + python/sglang/srt/layers/attention/vision.py | 283 +++++++++++++++--- python/sglang/srt/managers/image_processor.py | 115 ++++--- python/sglang/srt/models/minicpmv.py | 205 ++++++++----- python/sglang/srt/models/mllama.py | 72 +---- python/sglang/srt/models/qwen2.py | 5 +- python/sglang/srt/models/qwen2_vl.py | 26 +- python/sglang/srt/utils.py | 2 - test/srt/run_suite.py | 2 +- test/srt/test_vision_llm.py | 210 +++++++++++++ test/srt/test_vision_openai_server.py | 4 +- 12 files changed, 708 insertions(+), 223 deletions(-) create mode 100644 test/srt/test_vision_llm.py diff --git a/docs/references/supported_models.md b/docs/references/supported_models.md index 0a00ad0c8a1..93c4273765d 100644 --- a/docs/references/supported_models.md +++ b/docs/references/supported_models.md @@ -78,6 +78,7 @@ Another valuable resource is the [vLLM Models Directory](https://github.com/vllm To port a model from vLLM to SGLang, you can compare these two files [SGLang Llama Implementation](https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/models/llama.py) and [vLLM Llama Implementation](https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/llama.py). This comparison will help you understand how to convert a model implementation from vLLM to SGLang. The major difference is the replacement of Attention with RadixAttention. The other parts are almost identical. Specifically, - Replace vllm's `Attention` with `RadixAttention`. Note that you need to pass `layer_id` all the way to `RadixAttention`. - Replace vllm's `LogitsProcessor` with SGLang's `LogitsProcessor`. + - Replace Multi-headed `Attention` of ViT with SGLang's `VisionAttention`. - Replace other vLLM layers with SGLang layers (e.g., `RMSNorm`, `SiluAndMul`). - Remove `Sample`. - Change `forward()` functions, and add `forward_batch`. diff --git a/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py b/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py index 9163eba68de..d022b972147 100644 --- a/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py +++ b/python/sglang/srt/layers/attention/triton_ops/prefill_attention.py @@ -166,6 +166,12 @@ def _fwd_kernel( def context_attention_fwd( q, k, v, o, b_start_loc, b_seq_len, max_input_len, is_causal=True ): + """ + q, k, v: [b * s, head, head_dim] + b_start_loc: [b] + b_seq_len: [b] + out: [b * s, head, head_dim] + """ if is_cuda_available and CUDA_CAPABILITY[0] > 8: BLOCK = 128 else: diff --git a/python/sglang/srt/layers/attention/vision.py b/python/sglang/srt/layers/attention/vision.py index 4fcfaad5625..03c4cfb46a8 100644 --- a/python/sglang/srt/layers/attention/vision.py +++ b/python/sglang/srt/layers/attention/vision.py @@ -4,6 +4,7 @@ import torch import torch.nn as nn +import torch.nn.functional as F from einops import rearrange, repeat from sglang.srt.distributed import parallel_state @@ -63,7 +64,20 @@ def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.T class VisionAttention(nn.Module): - """Multi-headed attention without any cache, mostly used for ViT.""" + r""" + Multi-headed attention without any cache, mostly used for ViT. 
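    Depending on ``use_context_forward``, attention is computed either with the
    Triton context-attention kernel (``VisionTritonAttention``) or with a masked
    scaled-dot-product attention (``VisionSdpaAttention``).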
+ + + Args: + use_qkv_parallel (bool, optional): If True, use QKV-parallel attention. + use_context_forward (bool, default to True): + if ``True``, a flash_attn style attention will be applied + Otherwise, a full-sequence attention will be applied. + use_full_precision_softmax (bool, default to False): + if ``True``, the softmax will be performed in full-precision + Otherwise, it will be performed in half-precision + + """ def __init__( self, @@ -72,25 +86,39 @@ def __init__( projection_size: int, use_qkv_parallel: bool, quant_config: Optional[QuantizationConfig] = None, + dropout: float = 0.0, + use_context_forward: bool = True, + use_full_precision_softmax: bool = False, + flatten_batch: bool = False, prefix: str = "", ): super().__init__() + self.use_context_forward = use_context_forward world_size = parallel_state.get_tensor_model_parallel_world_size() - + self.dropout = dropout + self.head_size = embed_dim // num_heads self.hidden_size_per_attention_head = dist_utils.divide( projection_size, num_heads ) self.num_attention_heads_per_partition = dist_utils.divide( num_heads, world_size ) - # self.tp_size = get_tensor_model_parallel_world_size() - # num_heads = self.num_heads_per_partition + + if self.use_context_forward: + self.qkv_backend = VisionTritonAttention() + else: + self.qkv_backend = VisionSdpaAttention( + head_size=self.head_size, + dropout=dropout, + flatten_batch=flatten_batch, + use_full_precision_softmax=use_full_precision_softmax, + ) + self.use_qkv_parallel = use_qkv_parallel if use_qkv_parallel: - self.head_dim = embed_dim // num_heads self.qkv_proj = QKVParallelLinear( hidden_size=embed_dim, - head_size=self.head_dim, + head_size=self.head_size, total_num_heads=num_heads, quant_config=quant_config, prefix=f"{prefix}.qkv_proj", @@ -114,12 +142,15 @@ def forward( x: torch.Tensor, cu_seqlens: Optional[torch.Tensor] = None, rotary_pos_emb: torch.Tensor = None, + attention_mask: Optional[torch.Tensor] = None, ) -> torch.Tensor: + r""" + Args: + x: [b, s, embed_dim] + cu_seqlens: [b] + Returns: + [s, b, num_heads * head] """ - Input shape: [b, s, embed_dim] - Output shape: [s, b, num_heads * head_size] - """ - bsz, s, _ = x.shape if self.use_qkv_parallel: # [b, s, embed_dim] --> [b, s, embed_dim] @@ -136,19 +167,19 @@ def forward( else: # [b, s, embed_dim] --> [s, b, embed_dim] x = rearrange(x, "b s ... -> s b ...") - # [s, b, embed_dim] --> [s, b, head * 3 * head_dim] + # [s, b, embed_dim] --> [s, b, head * 3 * head_size] qkv, _ = self.qkv_proj(x) - # [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim] + # [s, b, head * 3 * head_size] --> [s, b, head, 3 * head_size] new_x_shape = qkv.size()[:-1] + ( self.num_attention_heads_per_partition, 3 * self.hidden_size_per_attention_head, ) qkv = qkv.view(*new_x_shape) - # [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim] + # [s, b, head, 3 * head_size] --> 3 [s, b, head, head_size] q, k, v = dist_utils.split_tensor_along_last_dim(qkv, 3) - # [s, b, head, head_dim] --> [b, s, head, head_dim] + # [s, b, head, head_size] --> [b, s, head, head_size] q, k, v = [ rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v) ] @@ -160,45 +191,217 @@ def forward( if self.use_qkv_parallel: pass else: - # [b, s, head, head_dim] --> [b * s, head, head_dim] + # [b, s, head, head_size] --> [b * s, head, head_size] q, k, v = [rearrange(x, "b s ... 
-> (b s) ...") for x in [q, k, v]] - # [b * s, num_heads, head_size] - output = torch.empty_like(q) - - seq_lens = (cu_seqlens[1:] - cu_seqlens[:-1]).cuda() - max_seqlen = seq_lens.max().item() - - context_attention_fwd( - q, - k, - v, - output, - cu_seqlens.cuda(), - seq_lens, - max_seqlen, - is_causal=False, - ) + output = self.qkv_backend.forward(q, k, v, bsz, cu_seqlens, attention_mask) if self.use_qkv_parallel: - - # [b * s, head, head_dim] --> [b, s, head * head_dim] + # [b * s, h, head_size] --> [b, s, h * head_size] output = rearrange(output, "(b s) ... h d -> b s ... (h d)", b=bsz) - # [b, s, head, head_dim] --> [b, s, head, head_dim] + # [b, s, h * head_size] --> [b, s, h * head_size] output, _ = self.proj(output) else: - # [b * s, head, head_dim] --> [b, s, head, head_dim] - context_layer = rearrange(output, "(b s) ... -> b s ...", b=bsz) - - # [s, b, num_heads * head_size] + # [b * s, h, head_size] --> [s, b, h * head_size] context_layer = rearrange( - context_layer, "b s h d -> s b (h d)" + output, "(b s) h d -> s b (h d)", b=bsz, s=s ).contiguous() - # [s, b, num_heads * head_size] --> [s, b, num_heads * head_size] + # [s, b, h * head_size] --> [s, b, h * head_size] output, _ = self.proj(context_layer) + # [s, b, h * head_size] --> [b, s, h * head_size] output = output.view(bsz, s, -1) return output + + +class VisionSdpaAttention(nn.Module): + r""" + Scaled Dot Product Attention inner product + + """ + + # TODO: Should it be released after used? + _mask_cache = {} + + def __init__( + self, + head_size: int, + dropout: float = 0.0, + flatten_batch: bool = False, + use_full_precision_softmax: bool = False, + ): + super().__init__() + self.head_size = head_size + self.flatten_batch = flatten_batch + self.use_full_precision_softmax = use_full_precision_softmax + self.dropout = dropout + + def generate_patch_attention_mask( + self, + s: int, + bsz: int, + device, + cu_seqlens: Optional[torch.Tensor], + flatten_batch: bool = False, + dtype=torch.bfloat16, + ) -> torch.Tensor: + r""" + Creates a non-causal 4D mask of shape `(b, 1, s, s)` or `(1, 1, s, s)`. + + When `flatten_batch` is True: + - All sequences in the batch are flattened into a single dimension + - `s` represents the total number of tokens across all sequences in the batch + - Returns a unified mask of shape `(1, 1, s, s)` + + When `flatten_batch` is False: + - Each sequence has its own attention mask + - `s` represents the maximum sequence length in the batch + - Returns separate masks of shape `(b, 1, s, s)` + + Args: + flatten_batch: (bool): + If True, treats all sequences in the batch as a single flattened sequence + If False, generates separate masks for each sequence + + Returns: + Tensor of shape `(b, 1, s, s)` or `(1, 1, s, s)`. 
+ """ + + cache_key = (s, bsz, flatten_batch, tuple(cu_seqlens.cpu().tolist())) + + if cache_key in VisionSdpaAttention._mask_cache: + cached_mask = VisionSdpaAttention._mask_cache[cache_key] + # print(f"cache hit for key: {cache_key}") + return cached_mask.to(device=device, dtype=dtype) + + if cu_seqlens is None: + raise ValueError("Internal Error: cu_seqlens cannot be None") + + if flatten_batch: + mask = torch.zeros([1, s, s], device=device, dtype=torch.bool) + for i in range(1, len(cu_seqlens)): + start = cu_seqlens[i - 1] + end = cu_seqlens[i] + mask[ + ..., + start:end, + start:end, + ] = True + else: + # [1, 1, 1, s] + row_indices = torch.arange(s, device=device).view(1, 1, 1, s) + # [1, 1, s, 1] + col_indices = torch.arange(s, device=device).view(1, 1, s, 1) + # [b, 1, 1, 1] + seq_lens = ( + (cu_seqlens[1:] - cu_seqlens[:-1]).to(device=device).view(-1, 1, 1, 1) + ) + + mask = (row_indices < seq_lens) & (col_indices < seq_lens) + + # Convert to attention mask format (False -> 0, True -> -inf) + mask = (~mask).to(dtype) * torch.finfo(dtype).min + + VisionSdpaAttention._mask_cache[cache_key] = mask + + return mask + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + bsz: int, + cu_seqlens: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + r""" + Args: + cu_seqlens: [b] + Returns: + [b * s, h, head_size] + """ + + s = q.shape[0] // bsz + + # [b, 1, s, s] + if attention_mask is None: + attention_mask = self.generate_patch_attention_mask( + s, bsz, q.device, cu_seqlens, self.flatten_batch, q.dtype + ) + q, k, v = [rearrange(x, "(b s) h d -> b h s d", b=bsz) for x in [q, k, v]] + # [b, 1, s] + if self.use_full_precision_softmax: + scale = self.head_size**-0.5 + k_transposed = rearrange(k, "b h s d -> b h d s") + attn_weights = torch.matmul(q, k_transposed) * scale + del k, k_transposed + attn_weights = attn_weights + attention_mask + del attention_mask + # full-precision + attn_weights = nn.functional.softmax( + attn_weights, dim=-1, dtype=torch.float32 + ).to(q.dtype) + attn_weights = nn.functional.dropout( + attn_weights, p=self.dropout, training=False + ) + output = torch.matmul(attn_weights, v) + del attn_weights, v + else: + # SDPA + # [b, h, s, head_size] + output = F.scaled_dot_product_attention( + q, k, v, attention_mask, dropout_p=self.dropout + ) + + # [b, h, s, head_size] --> [b * s, h, head_size] + output = rearrange(output, "b h s d -> (b s) h d") + + return output + + +class VisionTritonAttention(nn.Module): + """ + Triton-implemented attention without a causal mask + """ + + def __init__( + self, + ): + super().__init__() + + def forward( + self, + q: torch.Tensor, + k: torch.Tensor, + v: torch.Tensor, + _bsz: int, + cu_seqlens: Optional[torch.Tensor], + **kwargs, + ) -> torch.Tensor: + r""" + Args: + cu_seqlens: [b] + Returns: + [b * s, h, head_size] + """ + + # [b * s, head, head_size] + output = torch.empty_like(q) + seq_lens = cu_seqlens[1:] - cu_seqlens[:-1] + max_seqlen = seq_lens.max().item() + context_attention_fwd( + q, + k, + v, + output, + cu_seqlens.cuda(), + seq_lens.cuda(), + max_seqlen, + is_causal=False, + ) + + return output diff --git a/python/sglang/srt/managers/image_processor.py b/python/sglang/srt/managers/image_processor.py index c8ebbed783a..f43ecb18c16 100644 --- a/python/sglang/srt/managers/image_processor.py +++ b/python/sglang/srt/managers/image_processor.py @@ -240,6 +240,7 @@ async def process_images_async( class MiniCPMVImageProcessor(BaseImageProcessor): def 
__init__(self, hf_config, server_args, _processor): super().__init__(hf_config, server_args, _processor) + self.IMAGE_TOKEN = "(./)" @staticmethod def _process_images_task(images, input_text): @@ -271,7 +272,7 @@ async def _process_images(self, images, input_text): async def process_images_async( self, image_data: List[Union[str, bytes]], - input_text, + input_ids, request_obj, max_req_input_len, ): @@ -282,28 +283,49 @@ async def process_images_async( image_data = [image_data] image_hashes, image_sizes = [], [] - raw_images = [] - IMAGE_TOKEN = "(./)" + all_frames = [] - # roughly calculate the max number of frames - # TODO: the process should be applied to all the visual inputs + # roughly calculate the max number of frames under the max_req_input_len limit def calculate_max_num_frames() -> int: # Model-specific NUM_TOKEN_PER_FRAME = 330 - ret = (max_req_input_len - len(input_text)) // NUM_TOKEN_PER_FRAME + ret = (max_req_input_len - len(input_ids)) // NUM_TOKEN_PER_FRAME return min(ret, 100) - # if cuda OOM set a smaller number MAX_NUM_FRAMES = calculate_max_num_frames() - print(f"MAX_NUM_FRAMES: {MAX_NUM_FRAMES}") - def encode_video(video_path): + # print(f"MAX_NUM_FRAMES: {MAX_NUM_FRAMES}") + + def get_estimated_frames_list(): + """ + estimate the total frame count from all visual input + """ + # Before processing inputs + estimated_frames_list = [] + for image in image_data: + if isinstance(image, str) and image.startswith("video:"): + path = image[len("video:") :] + # Estimate frames for the video + vr = VideoReader(path, ctx=cpu(0)) + num_frames = len(vr) + else: + # For images, each contributes one frame + num_frames = 1 + estimated_frames_list.append(num_frames) + + return estimated_frames_list + + estimated_frames_list = get_estimated_frames_list() + total_frame_count = sum(estimated_frames_list) + scaling_factor = min(1.0, MAX_NUM_FRAMES / total_frame_count) + + def encode_video(video_path, frame_count_limit=None): if not os.path.exists(video_path): logger.error(f"Video {video_path} does not exist") return [] - if MAX_NUM_FRAMES == 0: + if frame_count_limit == 0: return [] def uniform_sample(l, n): @@ -314,45 +336,63 @@ def uniform_sample(l, n): vr = VideoReader(video_path, ctx=cpu(0)) sample_fps = round(vr.get_avg_fps() / 1) # FPS frame_idx = [i for i in range(0, len(vr), sample_fps)] - if len(frame_idx) > MAX_NUM_FRAMES: - frame_idx = uniform_sample(frame_idx, MAX_NUM_FRAMES) + if frame_count_limit is not None and len(frame_idx) > frame_count_limit: + frame_idx = uniform_sample(frame_idx, frame_count_limit) frames = vr.get_batch(frame_idx).asnumpy() frames = [Image.fromarray(v.astype("uint8")) for v in frames] return frames - if isinstance(input_text, list): - assert len(input_text) and isinstance(input_text[0], int) - input_text = self._processor.tokenizer.decode(input_text) - + if isinstance(input_ids, list): + assert len(input_ids) and isinstance(input_ids[0], int) + input_text = self._processor.tokenizer.decode(input_ids) + else: + input_text = input_ids # MiniCPMV requires each frame of video as a single image token - text_parts = input_text.split(IMAGE_TOKEN) + text_parts = input_text.split(self.IMAGE_TOKEN) new_text_parts = [] - for image_index, image in enumerate(image_data): - try: - if isinstance(image, str) and image.startswith("video:"): - path = image[len("video:") :] - frames = encode_video(path) - else: - raw_image, size = load_image(image) - frames = [raw_image] - if len(frames) == 0: - continue - except FileNotFoundError as e: - print(e) - return None - - 
image_sizes += frames[0].size * len(frames) - image_hashes += [hash(image)] * len(frames) - raw_images += frames + # Process each input with allocated frames + for image_index, (image, estimated_frames) in enumerate( + zip(image_data, estimated_frames_list) + ): + if len(all_frames) >= MAX_NUM_FRAMES: + frames_to_process = 0 + else: + frames_to_process = max(1, int(estimated_frames * scaling_factor)) + + if frames_to_process == 0: + frames = [] + else: + try: + if isinstance(image, str) and image.startswith("video:"): + path = image[len("video:") :] + frames = encode_video(path, frame_count_limit=frames_to_process) + else: + raw_image, _size = load_image(image) + frames = [raw_image] + if len(frames) == 0: + continue + except FileNotFoundError as e: + print(e) + return None + image_sizes += frames[0].size * len(frames) + image_hashes += [hash(image)] * len(frames) + all_frames += frames + + assert frames_to_process == len(frames) + new_text_parts.append(text_parts[image_index]) - new_text_parts.append(IMAGE_TOKEN * len(frames)) + + if frames_to_process != 0: + new_text_parts.append(self.IMAGE_TOKEN * len(frames)) new_text_parts.append(text_parts[-1]) + input_text = "".join(new_text_parts) - if len(raw_images) == 0: + + if len(all_frames) == 0: return None - res = await self._process_images(images=raw_images, input_text=input_text) + res = await self._process_images(images=all_frames, input_text=input_text) pixel_values = res["pixel_values"] tgt_sizes = res["tgt_sizes"] input_ids = res["input_ids"] @@ -364,7 +404,6 @@ def uniform_sample(l, n): if tokenizer.slice_start_id: slice_start_id = [tokenizer.slice_start_id] slice_end_id = [tokenizer.slice_end_id] - return { "input_ids": input_ids.flatten().tolist(), "pixel_values": pixel_values, diff --git a/python/sglang/srt/models/minicpmv.py b/python/sglang/srt/models/minicpmv.py index 23147529a64..7b02b4cedbb 100644 --- a/python/sglang/srt/models/minicpmv.py +++ b/python/sglang/srt/models/minicpmv.py @@ -1,6 +1,6 @@ # Adapted from # https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/llama/modeling_llama.py -# Copyright 2023 The vLLM team. +# Copyright 2023 The SGLang team. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. # # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX @@ -20,7 +20,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
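The frame-allocation policy added to MiniCPMVImageProcessor above can be summarized by the following self-contained sketch. The function name, the example budget, and the frame counts are illustrative only; NUM_TOKEN_PER_FRAME and the 100-frame cap come from the patch, and the sketch assumes each video is actually down-sampled to its allocated frame count.

NUM_TOKEN_PER_FRAME = 330  # model-specific token cost of one frame

def allocate_frames(estimated_frames, max_req_input_len, prompt_len):
    # Frame budget: how many frames still fit in the context window, capped at 100.
    max_num_frames = min((max_req_input_len - prompt_len) // NUM_TOKEN_PER_FRAME, 100)
    scaling = min(1.0, max_num_frames / sum(estimated_frames))
    allocation, used = [], 0
    for est in estimated_frames:
        # Every input keeps at least one frame until the budget runs out.
        n = 0 if used >= max_num_frames else max(1, int(est * scaling))
        allocation.append(n)
        used += n
    return allocation

# A 90-frame video plus two images under a 2048-token request limit with a
# 100-token prompt yields [4, 1, 0]: the video is down-sampled and the last
# image is dropped once the budget is exhausted.
print(allocate_frames([90, 1, 1], max_req_input_len=2048, prompt_len=100))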
"""Inference-only MiniCPM-V model compatible with HuggingFace weights.""" -from functools import cached_property, partial +from functools import partial from typing import ( Any, Callable, @@ -33,16 +33,13 @@ Union, ) +import numpy as np import torch import torch.types from PIL import Image from torch import nn from torch.nn.init import trunc_normal_ from transformers import PretrainedConfig -from vllm.model_executor.layers.resampler import get_2d_sincos_pos_embed -from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler -from vllm.model_executor.models.module_mapping import MultiModelKeys -from vllm.model_executor.sampling_metadata import SamplingMetadata from sglang.srt.distributed import divide, get_tensor_model_parallel_world_size from sglang.srt.layers.activation import get_act_fn @@ -63,6 +60,88 @@ RawImageType = Union[Image.Image, torch.Tensor] +# sin/cos positional embedding helpers are adapted from: +# https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 +def get_1d_sincos_pos_embed_from_grid( + embed_dim: int, pos: np.ndarray, version: Tuple[int, int] = (2, 0) +) -> torch.Tensor: + """ + embed_dim: output dimension for each position + pos: a list of positions to be encoded: size (M,) / (H, W) + out: (M, D) / (H, W, D) + """ + assert embed_dim % 2 == 0 + omega = np.arange(embed_dim // 2, dtype=np.float32) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + if version == (2, 0): + pos = pos.reshape(-1) # (M,) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) + else: + out = np.einsum("hw,d->hwd", pos, omega) # (H, W, D/2), outer product + emb_sin = np.sin(out) # (H, W, D/2) + emb_cos = np.cos(out) # (H, W, D/2) + emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (H, W, D) + return emb + + +def get_2d_sincos_pos_embed_from_grid( + embed_dim: int, grid: np.ndarray, version: Tuple[int, int] = (2, 0) +) -> torch.Tensor: + assert embed_dim % 2 == 0 + + # use half of dimensions to encode grid_h + emb_h = get_1d_sincos_pos_embed_from_grid( + embed_dim // 2, grid[0], version + ) # (H*W, D/2) or (H, W, D/2) + emb_w = get_1d_sincos_pos_embed_from_grid( + embed_dim // 2, grid[1], version + ) # (H*W, D/2) or (H, W, D/2) + + if version == (2, 0): + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + else: + emb = np.concatenate([emb_h, emb_w], axis=-1) # (H, W, D) + return emb + + +def get_2d_sincos_pos_embed( + embed_dim: int, + grid_size: Union[int, Tuple[int, int]], + cls_token: bool = False, + version: Tuple[int, int] = (2, 0), +) -> torch.Tensor: + """ + grid_size: int of the grid height and width + return: + pos_embed: [grid_size*grid_size, embed_dim] or + [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + if isinstance(grid_size, int): + grid_h_size, grid_w_size = grid_size, grid_size + else: + grid_h_size, grid_w_size = grid_size[0], grid_size[1] + + grid_h = np.arange(grid_h_size, dtype=np.float32) + grid_w = np.arange(grid_w_size, dtype=np.float32) + grid = np.meshgrid(grid_w, grid_h) # here w goes first + grid = np.stack(grid, axis=0) + assert isinstance(grid, np.ndarray) and grid.shape == (2, grid_h_size, grid_w_size) + + if version == (2, 0): + grid = grid.reshape([2, 1, grid_h_size, grid_w_size]) + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version) + if cls_token: + pos_embed = np.concatenate([np.zeros([1, 
embed_dim]), pos_embed], axis=0) + else: + pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid, version) + return pos_embed + + class Idefics2VisionMLP(nn.Module): def __init__( @@ -116,6 +195,10 @@ def __init__( projection_size=config.intermediate_size, use_qkv_parallel=True, quant_config=quant_config, + dropout=config.attention_dropout, + use_context_forward=False, + use_full_precision_softmax=True, + flatten_batch=False, prefix=f"{prefix}.self_attn", ) self.layer_norm1 = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps) @@ -126,7 +209,6 @@ def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, - forward_batch: ForwardBatch, ) -> torch.Tensor: """ Args: @@ -136,11 +218,8 @@ def forward( """ residual = hidden_states hidden_states = self.layer_norm1(hidden_states) - hidden_states = self.self_attn( - hidden_states, - cu_seqlens=cu_seqlens, - # , forward_batch=forward_batch - ) + hidden_states = self.self_attn(hidden_states, cu_seqlens=cu_seqlens) + hidden_states = residual + hidden_states residual = hidden_states hidden_states = self.layer_norm2(hidden_states) @@ -181,7 +260,6 @@ def forward( self, inputs_embeds: torch.Tensor, cu_seqlens: torch.Tensor, - forward_batch: ForwardBatch, ) -> torch.Tensor: r""" Args: @@ -195,7 +273,8 @@ def forward( hidden_states = inputs_embeds for encoder_layer in self.layers: layer_outputs = encoder_layer( - hidden_states, cu_seqlens=cu_seqlens, forward_batch=forward_batch + hidden_states, + cu_seqlens=cu_seqlens, ) hidden_states = layer_outputs return hidden_states @@ -232,19 +311,14 @@ def __init__(self, config: PretrainedConfig): self.num_positions = self.num_patches self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim) - def forward( + def get_position_ids( self, pixel_values: torch.FloatTensor, patch_attention_mask: torch.BoolTensor, tgt_sizes: Optional[torch.IntTensor] = None, - ) -> torch.Tensor: + ): batch_size, _, max_im_h, max_im_w = pixel_values.shape - target_dtype = self.patch_embedding.weight.dtype - pixel_values = pixel_values.to( - device=self.patch_embedding.weight.device, dtype=target_dtype - ) - patch_embeds = self.patch_embedding(pixel_values) - embeddings = patch_embeds.flatten(2).transpose(1, 2) + max_nb_patches_h, max_nb_patches_w = ( max_im_h // self.patch_size, max_im_w // self.patch_size, @@ -277,6 +351,24 @@ def forward( ).flatten() position_ids[batch_idx][p_attn_mask.view(-1).cpu()] = pos_ids position_ids = position_ids.to(self.position_embedding.weight.device) + return position_ids + + def forward( + self, + pixel_values: torch.FloatTensor, + patch_attention_mask: torch.BoolTensor, + tgt_sizes: Optional[torch.IntTensor] = None, + ) -> torch.Tensor: + target_dtype = self.patch_embedding.weight.dtype + pixel_values = pixel_values.to( + device=self.patch_embedding.weight.device, dtype=target_dtype + ) + patch_embeds = self.patch_embedding(pixel_values) + embeddings = patch_embeds.flatten(2).transpose(1, 2) + position_ids = self.get_position_ids( + pixel_values, patch_attention_mask, tgt_sizes + ) + embeddings = embeddings + self.position_embedding(position_ids) return embeddings @@ -287,7 +379,6 @@ def __init__( self, config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", ) -> None: super().__init__() @@ -302,8 +393,6 @@ def get_input_embeddings(self): def compute_cu_seqlens(self, tgt_sizes: torch.Tensor) -> torch.Tensor: patch_len = tgt_sizes[:, 0] * tgt_sizes[:, 1] # shape: (batch_size,) - - # 做 prefix sum 来得到 cu_seqlens,注意在最前面插一个 0 作为 
offset cu_seqlens = torch.cat( [ torch.tensor([0], device=patch_len.device, dtype=torch.int32), @@ -316,19 +405,18 @@ def compute_cu_seqlens(self, tgt_sizes: torch.Tensor) -> torch.Tensor: def forward( self, pixel_values, - forward_batch: ForwardBatch, patch_attention_mask: Optional[torch.BoolTensor] = None, tgt_sizes: Optional[torch.IntTensor] = None, ) -> torch.Tensor: hidden_states = self.embeddings( pixel_values=pixel_values, patch_attention_mask=patch_attention_mask, - # forward_batch=forward_batch, tgt_sizes=tgt_sizes, ) cu_seqlens = self.compute_cu_seqlens(tgt_sizes) encoder_outputs = self.encoder( - hidden_states, cu_seqlens=cu_seqlens, forward_batch=forward_batch + hidden_states, + cu_seqlens=cu_seqlens, ) last_hidden_state = self.post_layernorm(encoder_outputs) return last_hidden_state @@ -573,14 +661,12 @@ def __init__( config: PretrainedConfig, quant_config: Optional[QuantizationConfig] = None, ): - # multimodal_config = config.model_config.multimodal_config super().__init__() # All MiniCPM-V models disable `tie_word_embeddings` but # `PretrainedConfig.tie_word_embeddings` defaults to True; we cannot - # check `tie_word_embeddings` until vLLM integrate MiniCPM-V model + # check `tie_word_embeddings` until SGLang integrate MiniCPM-V model # and config class self.config = config - # self.multimodal_config = multimodal_config self.version = get_version_by_config(self.config) self.llm = self.init_llm(config=config, quant_config=quant_config) @@ -598,13 +684,6 @@ def __init__( self.logits_processor = LogitsProcessor(config) - @cached_property - def sampler(self): - if hasattr(self.llm, "sampler"): - return self.llm.sampler - - return get_sampler() - def _get_image_bounds( self, input_ids: torch.Tensor, @@ -666,7 +745,6 @@ def get_embedding( self, input_ids: torch.Tensor, image_inputs: Optional[MiniCPMVImageInputs], - forward_batch: ForwardBatch, ) -> Tuple[torch.Tensor, torch.Tensor]: vlm_embedding: torch.Tensor = self.llm.get_input_embeddings(input_ids) @@ -680,10 +758,7 @@ def get_embedding( .to(vlm_embedding.device) ) else: - vision_hidden_states = self.get_vision_hidden_states( - forward_batch, image_inputs - ) - + vision_hidden_states = self.get_vision_hidden_states(image_inputs) # See NOTE in _parse_and_validate_inputs image_bounds = image_inputs["image_bounds"] if len(image_bounds) > 0: @@ -693,6 +768,7 @@ def get_embedding( for start, end in image_bounds.tolist() ] ).to(vlm_embedding.device) + vlm_embedding.scatter_( 0, image_indices.view(-1, 1).repeat(1, vlm_embedding.shape[-1]), @@ -839,7 +915,7 @@ def forward( # There values are useless because their embeddings will be replaced by vision embeddings anyway. 
input_ids.clamp_(min=0, max=self.config.vocab_size - 1) - vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs, forward_batch) + vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs) # always pass the input via `inputs_embeds` # to make sure the computation graph is consistent @@ -857,29 +933,6 @@ def forward( input_ids, hidden_states, self.llm.lm_head, forward_batch ) - def compute_logits( - self, - hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[torch.Tensor]: - return self.llm.compute_logits(hidden_states, sampling_metadata) - - def sample( - self, - logits: torch.Tensor, - sampling_metadata: SamplingMetadata, - ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(logits, sampling_metadata) - return next_tokens - - def get_mm_mapping(self) -> MultiModelKeys: - """ - Get the module prefix in multimodal models - """ - return MultiModelKeys.from_string_field( - language_model="llm", connector="resampler", tower_model="vpm" - ) - def init_llm( self, config: Qwen2Config, @@ -910,9 +963,7 @@ def get_vision_embedding( ) -> torch.Tensor: raise NotImplementedError - def get_vision_hidden_states( - self, forward_batch: ForwardBatch, data: MiniCPMVImageInputs - ) -> torch.Tensor: + def get_vision_hidden_states(self, data: MiniCPMVImageInputs) -> torch.Tensor: raise NotImplementedError @@ -1019,7 +1070,6 @@ def get_vision_embedding( def get_vision_hidden_states( self, - forward_batch: ForwardBatch, data: MiniCPMVImageInputs, ) -> torch.Tensor: pixel_values = data["data"] @@ -1042,15 +1092,18 @@ def get_vision_hidden_states( patch_attn_mask = torch.zeros( (B, 1, max_patches), dtype=torch.bool, device=device ) - for i in range(B): - patch_attn_mask[i, 0, : tgt_sizes[i][0] * tgt_sizes[i][1]] = True + + tgt_sizes_tensor = tgt_sizes.clone().to(device=patch_attn_mask.device) + mask_shapes = tgt_sizes_tensor[:, 0] * tgt_sizes_tensor[:, 1] + patch_attn_mask[:, 0, :] = torch.arange( + patch_attn_mask.size(2), device=patch_attn_mask.device + ).unsqueeze(0) < mask_shapes.unsqueeze(1) + vision_embedding = self.vpm( all_pixel_values.type(dtype), - forward_batch=forward_batch, patch_attention_mask=patch_attn_mask, tgt_sizes=tgt_sizes, ) - return self.resampler(vision_embedding, tgt_sizes) def pad_input_ids(self, input_ids: List[int], image_inputs: ImageInputs): @@ -1138,7 +1191,7 @@ class MiniCPMV: """ Different versions of MiniCPMV use different visual encoders and LLMs, which is not conducive to the current integration logic of LoRA and - bitsandbytes in vLLM. Therefore, it is necessary to separate them. + bitsandbytes in SGLang. Therefore, it is necessary to separate them. 
""" # Ensure that the LoRA support check passes when the class is not diff --git a/python/sglang/srt/models/mllama.py b/python/sglang/srt/models/mllama.py index 43f6793e4ef..05069edb69b 100644 --- a/python/sglang/srt/models/mllama.py +++ b/python/sglang/srt/models/mllama.py @@ -17,6 +17,7 @@ import sglang.srt.distributed.parallel_state as ps from sglang.srt.distributed import get_tensor_model_parallel_world_size from sglang.srt.layers.activation import get_act_fn +from sglang.srt.layers.attention.vision import VisionAttention from sglang.srt.layers.layernorm import RMSNorm from sglang.srt.layers.linear import ( ColumnParallelLinear, @@ -145,61 +146,6 @@ def forward( return hidden_state -class MllamaVisionSdpaAttention(nn.Module): - def __init__(self, config: config_mllama.MllamaVisionConfig): - super().__init__() - - model_parallel_size = get_tensor_model_parallel_world_size() - self.embed_dim = config.hidden_size - self.num_heads = config.attention_heads - self.head_dim = config.hidden_size // config.attention_heads - self.num_local_heads = self.num_heads // model_parallel_size - self.q_size = self.num_local_heads * self.head_dim - self.kv_size = self.num_local_heads * self.head_dim - - self.qkv_proj = QKVParallelLinear( - self.embed_dim, - self.head_dim, - self.num_heads, - bias=False, - ) - self.o_proj = RowParallelLinear( - self.num_heads * self.head_dim, - self.embed_dim, - bias=False, - input_is_parallel=True, - ) - - def forward( - self, - hidden_state: torch.Tensor, - attention_mask: Optional[torch.Tensor] = None, - ) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_state) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) - q = q.view( - q.shape[0], q.shape[1], self.num_local_heads, self.head_dim - ).transpose(1, 2) - k = k.view( - k.shape[0], k.shape[1], self.num_local_heads, self.head_dim - ).transpose(1, 2) - v = v.view( - v.shape[0], v.shape[1], self.num_local_heads, self.head_dim - ).transpose(1, 2) - - # TODO: remove padding in image encoder - attn_output = F.scaled_dot_product_attention( - q, k, v, attn_mask=attention_mask, dropout_p=0.0 - ) - - attn_output = attn_output.transpose(1, 2).contiguous() - attn_output = attn_output.reshape( - attn_output.shape[0], attn_output.shape[1], -1 - ) - output, _ = self.o_proj(attn_output) - return output - - class MllamaVisionMLP(nn.Module): def __init__(self, config, quant_config: Optional[QuantizationConfig] = None): super().__init__() @@ -237,7 +183,17 @@ def __init__( self.is_gated = is_gated self.intermediate_size = config.intermediate_size - self.self_attn = MllamaVisionSdpaAttention(config) + self.self_attn = VisionAttention( + self.hidden_size, + self.num_attention_heads, + self.hidden_size, + use_qkv_parallel=True, + quant_config=None, + dropout=0.0, + use_context_forward=False, + use_full_precision_softmax=False, + flatten_batch=False, + ) self.mlp = MllamaVisionMLP(config) self.input_layernorm = nn.LayerNorm(self.hidden_size, eps=config.norm_eps) @@ -992,6 +948,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): weight_loader(param, loaded_weight, shard_id) break else: + if "vision_model" in name: + # adapt to VisionAttention + name = name.replace("self_attn.o_proj", "self_attn.proj") + param = params_dict.pop(name) weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) diff --git a/python/sglang/srt/models/qwen2.py b/python/sglang/srt/models/qwen2.py index 0c01ab9e5b4..46b62f837f6 100644 --- a/python/sglang/srt/models/qwen2.py 
+++ b/python/sglang/srt/models/qwen2.py @@ -249,7 +249,10 @@ def __init__( self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: - return self.embed_tokens(input_ids) + if hasattr(self.config, "scale_emb"): + return self.embed_tokens(input_ids) * self.config.scale_emb + else: + return self.embed_tokens(input_ids) def forward( self, diff --git a/python/sglang/srt/models/qwen2_vl.py b/python/sglang/srt/models/qwen2_vl.py index 0fb85679f7a..365891544e0 100644 --- a/python/sglang/srt/models/qwen2_vl.py +++ b/python/sglang/srt/models/qwen2_vl.py @@ -30,12 +30,10 @@ import torch import torch.nn as nn import torch.nn.functional as F -from einops import rearrange, repeat +from einops import rearrange from vllm.model_executor.layers.activation import QuickGELU from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig -from sglang.srt.distributed import parallel_state -from sglang.srt.distributed import utils as dist_utils from sglang.srt.hf_transformers_utils import get_processor from sglang.srt.layers.attention.vision import VisionAttention from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear @@ -118,6 +116,7 @@ def __init__( mlp_ratio: float, act_layer: Type[nn.Module] = QuickGELU, norm_layer: Type[nn.Module] = None, + attn_implementation: Optional[str] = "sdpa", quant_config: Optional[QuantizationConfig] = None, ) -> None: super().__init__() @@ -126,12 +125,24 @@ def __init__( self.norm1 = norm_layer(dim) self.norm2 = norm_layer(dim) mlp_hidden_dim = int(dim * mlp_ratio) + if attn_implementation == "sdpa": + use_context_forward = False + use_full_precision_softmax = False + elif attn_implementation == "flash_attention_2": + use_full_precision_softmax = False + use_context_forward = True + elif attn_implementation == "eager": + use_full_precision_softmax = True + use_context_forward = False self.attn = VisionAttention( embed_dim=dim, num_heads=num_heads, projection_size=dim, use_qkv_parallel=False, + use_context_forward=use_context_forward, + use_full_precision_softmax=use_full_precision_softmax, + flatten_batch=True, quant_config=quant_config, ) self.mlp = Qwen2VisionMLP( @@ -286,7 +297,6 @@ def __init__( norm_layer = partial(nn.LayerNorm, eps=norm_eps) head_dim = embed_dim // num_heads self.rotary_pos_emb = Qwen2VisionRotaryEmbedding(head_dim // 2) - self.blocks = nn.ModuleList( [ Qwen2VisionBlock( @@ -294,6 +304,7 @@ def __init__( num_heads=num_heads, mlp_ratio=mlp_ratio, norm_layer=norm_layer, + attn_implementation="sdpa", quant_config=quant_config, ) for _ in range(depth) @@ -482,10 +493,6 @@ def forward( opensource models), the shape will be `(3, seq_len)`, otherwise it will be `(seq_len,). (Use input_metadata.mrope_positions to replace it) - pixel_values: Pixel values to be fed to a model. - `None` if no images are passed. - image_grid_thw: Tensor `(n_images, 3)` of image 3D grid in LLM. - `None` if no images are passed. 
""" if getattr(self.config, "rope_scaling", {}).get("type", None) == "mrope": positions = forward_batch.mrope_positions @@ -540,15 +547,18 @@ def forward( num_image_tokens = self.calculate_num_image_tokens( image_grid_thws[idx] ) + left_idx = start_idx + (image_offset - prefix_len) right_idx = ( start_idx + (image_offset - prefix_len) + num_image_tokens ) + inputs_embeds[left_idx:right_idx] = image_embeds[ image_embeds_offset : image_embeds_offset + num_image_tokens ] image_embeds_offset += num_image_tokens + input_ids = None hidden_states = self.model( input_ids=input_ids, positions=positions, diff --git a/python/sglang/srt/utils.py b/python/sglang/srt/utils.py index d8d935437b2..ebb346bbc63 100644 --- a/python/sglang/srt/utils.py +++ b/python/sglang/srt/utils.py @@ -444,8 +444,6 @@ def load_image(image_file: Union[str, bytes]): else: raise ValueError(f"Invalid image: {image}") - # if image_size is None: - # image_size = image.size return image, image_size diff --git a/test/srt/run_suite.py b/test/srt/run_suite.py index f6aa356826d..603bab957bd 100644 --- a/test/srt/run_suite.py +++ b/test/srt/run_suite.py @@ -48,6 +48,7 @@ "test_update_weights_from_disk.py", "test_update_weights_from_tensor.py", "test_vision_chunked_prefill.py", + "test_vision_llm.py", "test_vision_openai_server.py", "test_w8a8_quantization.py", "test_fp8_kvcache.py", @@ -72,7 +73,6 @@ tests.remove(target_suite_name) tests.extend(target_tests) - if __name__ == "__main__": arg_parser = argparse.ArgumentParser() arg_parser.add_argument( diff --git a/test/srt/test_vision_llm.py b/test/srt/test_vision_llm.py new file mode 100644 index 00000000000..7cda64fc0c7 --- /dev/null +++ b/test/srt/test_vision_llm.py @@ -0,0 +1,210 @@ +""" +""" + +import unittest +from io import BytesIO + +import numpy as np +import requests +import torch +import torch.nn.functional as F +from PIL import Image +from transformers import AutoModel, AutoProcessor, AutoTokenizer + +from sglang.srt.configs.model_config import ModelConfig +from sglang.srt.conversation import generate_chat_conv +from sglang.srt.model_executor.model_runner import ModelRunner +from sglang.srt.openai_api.protocol import ChatCompletionRequest +from sglang.srt.server_args import ServerArgs + +MiniCPMV = "openbmb/MiniCPM-V-2_6" + + +# Test the logits output between HF and SGLang +class VisionLLMLogitsBase(unittest.IsolatedAsyncioTestCase): + @classmethod + def setUpClass(cls): + cls.image_url = "https://github.com/sgl-project/sglang/blob/main/test/lang/example_image.png?raw=true" + cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + cls.model_path = "" + cls.chat_template = "" + cls.processor = "" + response = requests.get(cls.image_url) + cls.main_image = Image.open(BytesIO(response.content)) + + def compare_outputs(self, sglang_output: torch.Tensor, hf_output: torch.Tensor): + # Convert to float32 for numerical stability if needed + hf = hf_output.float() + sg = sglang_output.float() + + # Basic shape and dtype comparison + print("\n=== Basic Properties ===") + print(f"Shapes match: {hf.shape == sg.shape}") + print(f"HF shape: {hf.shape}, SGLang shape: {sg.shape}") + print(f"HF dtype: {hf.dtype}, SGLang dtype: {sg.dtype}") + + # Move tensors to CPU for numpy operations + hf_np = hf.cpu().numpy() + sg_np = sg.cpu().numpy() + + # Statistical metrics + print("\n=== Statistical Metrics ===") + print(f"Mean absolute difference: {torch.mean(torch.abs(hf - sg)).item():.6f}") + print(f"Max absolute difference: {torch.max(torch.abs(hf - sg)).item():.6f}") + 
print(f"Mean squared error: {torch.mean((hf - sg) ** 2).item():.6f}") + print( + f"Root mean squared error: {torch.sqrt(torch.mean((hf - sg) ** 2)).item():.6f}" + ) + + # Cosine similarity (across feature dimension) + cos_sim = F.cosine_similarity(hf, sg) + print(f"Mean cosine similarity: {torch.mean(cos_sim).item():.6f}") + print(f"Min cosine similarity: {torch.min(cos_sim).item():.6f}") + + # Find largest absolute differences + print("\n=== Largest Absolute Differences ===") + diffs = torch.abs(hf - sg) + flat_diffs = diffs.flatten() + + # Get indices of top 10 differences + top_k = 10 + top_values, top_flat_indices = torch.topk(flat_diffs, top_k) + + # Convert flat indices to multidimensional indices + top_indices = np.unravel_index(top_flat_indices.cpu().numpy(), diffs.shape) + + print(f"\nTop {top_k} largest absolute differences:") + print( + "Index".ljust(30) + + "Difference".ljust(15) + + "HF Value".ljust(15) + + "SGLang Value" + ) + print("-" * 75) + + for i in range(top_k): + # Get the index tuple for this difference + idx = tuple(dim[i] for dim in top_indices) + diff_val = top_values[i].item() + hf_val = hf[idx].item() + sg_val = sg[idx].item() + + # Format the index tuple and values + idx_str = str(idx) + print(f"{idx_str:<30}{diff_val:<15.6f}{hf_val:<15.6f}{sg_val:.6f}") + + np.testing.assert_allclose(hf_np, sg_np) + + def get_processor_output(self): + json_str = f""" + {{ + "model": "{self.model_path}", + "messages": [ + {{ + "role": "user", + "content": [ + {{ + "type": "image_url", + "image_url": {{ + "url": "{self.image_url}" + }} + }}, + {{ + "type": "text", + "text": "Whats in this picture?" + }} + ] + }} + ] +}} + """ + + req = ChatCompletionRequest.model_validate_json(json_str) + + conv = generate_chat_conv(req, template_name=self.chat_template) + + text = conv.get_prompt() + + # Process inputs using processor + # FIXME: the formal arguments may differ + inputs = self.processor( + text=[text], + images=[self.main_image], + return_tensors="pt", + ).to(self.device) + + return inputs + + def get_sglang_model(self): + model_runner = ModelRunner( + model_config=ModelConfig(self.model_path, model_override_args="{}"), + mem_fraction_static=0.8, + gpu_id=0, + tp_rank=0, + tp_size=1, + nccl_port=12435, + server_args=ServerArgs( + model_path=self.model_path, + disable_cuda_graph=True, + ), + ) + return model_runner.model + + +class TestMiniCPMVLogits(VisionLLMLogitsBase): + @classmethod + def setUpClass(cls): + super().setUpClass() + cls.model_path = MiniCPMV + cls.tokenizer = AutoTokenizer.from_pretrained( + cls.model_path, trust_remote_code=True + ) + cls.processor = AutoProcessor.from_pretrained( + cls.model_path, trust_remote_code=True + ) + cls.chat_template = "minicpmv" + + cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + cls.model = AutoModel.from_pretrained( + cls.model_path, torch_dtype=torch.bfloat16, trust_remote_code=True + ).eval() + cls.model.to(cls.device) + + async def test_encode_output(self): + inputs = self.get_processor_output() + + with torch.no_grad(): + model_inputs = { + "input_ids": inputs.input_ids, + "image_bound": inputs.image_bound, + "pixel_values": inputs.pixel_values, + "tgt_sizes": inputs.tgt_sizes, + } + (hf_output, _) = self.model.get_vllm_embedding( + model_inputs, + ) + hf_output = hf_output.squeeze(0) + + with torch.no_grad(): + model = self.get_sglang_model() + input_ids = inputs["input_ids"].to(self.device).flatten() + image_inputs = model._parse_and_validate_inputs( + input_ids=input_ids, + **{ + "pixel_values": 
[inputs["pixel_values"]], + "tgt_sizes": [inputs["tgt_sizes"]], + "im_start_id": [self.tokenizer.im_start_id], + "im_end_id": [self.tokenizer.im_end_id], + "slice_start_id": [self.tokenizer.slice_start_id], + "slice_end_id": [self.tokenizer.slice_end_id], + }, + ) + (sglang_output, _) = model.get_embedding( + input_ids=input_ids, image_inputs=image_inputs + ) + + self.compare_outputs(sglang_output, hf_output) + + +if __name__ == "__main__": + unittest.main() diff --git a/test/srt/test_vision_openai_server.py b/test/srt/test_vision_openai_server.py index 5be911ab84a..01762202882 100644 --- a/test/srt/test_vision_openai_server.py +++ b/test/srt/test_vision_openai_server.py @@ -180,7 +180,9 @@ def test_multi_images_chat_completion(self): assert response.usage.total_tokens > 0 def prepare_video_messages(self, video_path): - max_frames_num = 32 + # the memory consumed by the Vision Attention varies a lot, e.g. blocked qkv vs full-sequence sdpa + # the size of the video embeds differs from the `modality` argument when preprocessed + max_frames_num = 12 vr = VideoReader(video_path, ctx=cpu(0)) total_frame_num = len(vr) uniform_sampled_frames = np.linspace(