diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index 45d6ad3c0efa..98fcc127b78d 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -18,6 +18,7 @@
 from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                     TypedDict, Union)
 
+import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
@@ -31,6 +32,7 @@
 
 import vllm.distributed.parallel_state as ps
 from vllm.attention import Attention, AttentionMetadata, AttentionType
+from vllm.attention.ops.paged_attn import PagedAttention
 from vllm.config import CacheConfig, MultiModalConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
@@ -72,6 +74,16 @@ class MllamaImagePixelInputs(TypedDict):
 # TODO: support LlamaImageEmbeddingInputs
 
 
+def _get_num_images_for_decode(prompt_token_ids: List[int]) -> int:
+    num_images = 0
+    for token_id in prompt_token_ids[::-1]:
+        if token_id == MLLAMA_IMAGE_TOKEN_ID:
+            num_images += 1
+        elif num_images > 0:
+            break
+    return num_images
+
+
 def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs):
     # move encoder_prompt to prompt
     if llm_inputs.get("prompt") is None:
@@ -91,12 +103,16 @@ def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs):
         llm_inputs["encoder_multi_modal_data"] = {}
         return llm_inputs
 
-    # get num_tiles
    if isinstance(multi_modal_data['image'], Image.Image):
        multi_modal_data['image'] = [multi_modal_data['image']]
+    # Since only the last group of consecutive images
+    # is attended to by the decoded tokens, we only need
+    # the number of tiles for those images.
+    num_decode_images = _get_num_images_for_decode(
+        llm_inputs["prompt_token_ids"])
    hf_config = ctx.model_config.hf_config
    num_tiles = 0
-    for image in multi_modal_data["image"]:
+    for image in multi_modal_data["image"][::-1]:
        width, height = image.size
        tile_size = hf_config.vision_config.image_size
        canvas_height, canvas_width = get_optimal_tiled_canvas(
@@ -108,8 +124,13 @@ def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs):
        num_tiles_height = canvas_height // tile_size
        num_tiles_width = canvas_width // tile_size
        num_tiles += num_tiles_height * num_tiles_width
+        num_decode_images -= 1
+        if num_decode_images == 0:
+            break
 
-    # set encoder prompt based on num_tiles
+    # Set the encoder prompt length based on the number of tiles.
+    # This tells the block manager to allocate the correct
+    # number of slots for encoder tokens.
    assert hf_config.vision_config.image_size % 14 == 0, \
        "chunk size should be multiple of 14"
    token_per_chunk = (hf_config.vision_config.image_size // 14)**2 + 1
@@ -675,6 +696,7 @@ def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor],
+        kv_range_for_decode: Optional[List[List[int]]],
        cross_attention_states: Optional[torch.Tensor],
        kv_cache: torch.Tensor,
        attn_metadata: AttentionMetadata,
@@ -697,12 +719,45 @@ def forward(
        q = q.view(-1, self.num_local_heads, self.head_dim)
        q = self.q_norm(q)
 
-        output = self.attn(q,
-                           k,
-                           v,
-                           kv_cache,
-                           attn_metadata,
-                           attn_type=AttentionType.ENCODER_DECODER)
+        if attention_mask is not None:
+            if len(kv_cache.shape) == 3:
+                key_cache, value_cache = PagedAttention.split_kv_cache(
+                    kv_cache, self.num_local_key_value_heads, self.head_dim)
+                cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
+                cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
+                PagedAttention.write_to_paged_cache(
+                    cached_k, cached_v, key_cache, value_cache,
+                    attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
+            # We have to use torch SDPA for prefill when a custom
+            # cross-attention mask is given, because the mask is neither
+            # a standard causal mask nor a block-diagonal mask that
+            # xformers.BlockDiagonalMask could optimize.
+            # The mask is computed specifically to support multiple
+            # and interleaved images.
+            q_len = q.shape[0]
+            kv_len = k.shape[0]
+            q = q.transpose(0, 1).view(self.num_key_value_groups,
+                                       self.num_local_key_value_heads, q_len,
+                                       self.head_dim)
+            k = k.transpose(0, 1).expand(self.num_key_value_groups,
+                                         self.num_local_key_value_heads,
+                                         kv_len, self.head_dim)
+            v = v.transpose(0, 1).expand(self.num_key_value_groups,
+                                         self.num_local_key_value_heads,
+                                         kv_len, self.head_dim)
+            attention_mask = attention_mask.view(1, 1, q_len, kv_len)
+            with torch.backends.cuda.sdp_kernel(enable_flash=True):
+                output = F.scaled_dot_product_attention(
+                    q, k, v, attn_mask=attention_mask)
+            output = output.permute(2, 0, 1, 3).reshape(
+                q_len, self.num_local_heads * self.head_dim)
+        else:
+            output = self.attn(q,
+                               k,
+                               v,
+                               kv_cache,
+                               attn_metadata,
+                               attn_type=AttentionType.ENCODER_DECODER)
        out, _ = self.o_proj(output)
        return out
 
@@ -741,6 +796,7 @@ def forward(
        hidden_states: torch.Tensor,
        cross_attention_states: torch.Tensor,
        cross_attention_mask: torch.Tensor,
+        kv_range_for_decode: Optional[List[List[int]]],
        full_text_row_masked_out_mask: torch.Tensor,
        kv_cache: List[torch.Tensor],
        attn_metadata: AttentionMetadata,
@@ -751,6 +807,7 @@ def forward(
        hidden_states = self.cross_attn(
            hidden_states=hidden_states,
            attention_mask=cross_attention_mask,
+            kv_range_for_decode=kv_range_for_decode,
            cross_attention_states=cross_attention_states,
            kv_cache=kv_cache,
            attn_metadata=attn_metadata,
@@ -804,6 +861,7 @@ def forward(
        positions: Optional[torch.LongTensor],
        cross_attention_states: Optional[torch.LongTensor],
        cross_attention_mask: Optional[torch.LongTensor],
+        kv_range_for_decode: Optional[List[List[int]]],
        full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor,
                                                      torch.Tensor]],
        kv_caches: List[torch.Tensor],
@@ -820,6 +878,7 @@ def forward(
                        hidden_states=hidden_states,
                        cross_attention_states=cross_attention_states,
                        cross_attention_mask=cross_attention_mask,
+                        kv_range_for_decode=kv_range_for_decode,
                        full_text_row_masked_out_mask=
                        full_text_row_masked_out_mask,
                        kv_cache=kv_caches[idx],
@@ -868,6 +927,7 @@ def forward(
        positions: Optional[torch.LongTensor],
        cross_attention_states: Optional[torch.LongTensor],
        cross_attention_mask: Optional[torch.LongTensor],
+        kv_range_for_decode: Optional[List[List[int]]],
        full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor,
                                                      torch.Tensor]],
        kv_caches: List[torch.Tensor],
@@ -879,6 +939,7 @@ def forward(
            positions=positions,
            cross_attention_states=cross_attention_states,
            cross_attention_mask=cross_attention_mask,
+            kv_range_for_decode=kv_range_for_decode,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
@@ -1026,16 +1087,17 @@ def _parse_and_validate_image_input(self, **kwargs: object):
        raise AssertionError("This line should be unreachable.")
 
    def flat_encoder_result(self, cross_attention_states: torch.Tensor,
-                            attn_metadata: AttentionMetadata):
+                            attn_metadata: AttentionMetadata,
+                            real_encoder_seq_lens: List[int]):
 
        cross_attention_states_flat = torch.zeros(
-            sum(attn_metadata.encoder_seq_lens),
+            sum(real_encoder_seq_lens),
            cross_attention_states.shape[-1],
            device=cross_attention_states.device,
            dtype=cross_attention_states.dtype)
        start_pos = 0
-        for seq_len, vision_token_in_batch in zip(
-                attn_metadata.encoder_seq_lens, cross_attention_states):
+        for seq_len, vision_token_in_batch in zip(real_encoder_seq_lens,
+                                                  cross_attention_states):
            end_pos = start_pos + seq_len
            cross_attention_states_flat[
                start_pos:end_pos] = vision_token_in_batch[:seq_len]
@@ -1045,9 +1107,8 @@ def flat_encoder_result(self, cross_attention_states: torch.Tensor,
        full_text_row_masked_out_mask = torch.ones(
            (attn_metadata.num_prefill_tokens, 1), dtype=torch.bool)
        start_pos = 0
-        for seq_len, encoder_seq_len in zip(
-                attn_metadata.seq_lens_tensor.cpu(),
-                attn_metadata.encoder_seq_lens):
+        for seq_len, encoder_seq_len in zip(attn_metadata.seq_lens,
+                                            attn_metadata.encoder_seq_lens):
            if encoder_seq_len == 0:
                full_text_row_masked_out_mask[start_pos:start_pos +
                                              seq_len] = False
@@ -1069,13 +1130,17 @@ def forward(
                attn_metadata.num_decode_tokens > 0:
            raise ValueError("Chunk prefill not supported")
        image_inputs = self._parse_and_validate_image_input(**kwargs)
+        # text-only prefill and decode
+        # image decode
+        cross_attention_states = None
+        cross_attention_mask = None
+        kv_range_for_decode = None
        if image_inputs is None:
-            cross_attention_mask = None
            full_text_row_masked_out_mask = (
                attn_metadata.encoder_seq_lens_tensor != 0).reshape(-1, 1).to(
                    input_ids.device)
-            cross_attention_states = None
            skip_cross_attention = max(attn_metadata.encoder_seq_lens) == 0
+        # image prefill
        else:
            # NOTE: llama's reference implementation runs vision model on CPU
            pixel_values = image_inputs['data']
@@ -1091,17 +1156,47 @@ def forward(
            cross_attention_states = cross_attention_states.view(
                bsz, -1, image_token_dim)
 
+            num_tiles_tensor = kwargs.pop("num_tiles")
+            if isinstance(num_tiles_tensor, list):
+                num_tiles = [t.tolist()[0] for t in num_tiles_tensor]
+            else:
+                num_tiles = [t[0] for t in num_tiles_tensor.tolist()]
+            num_tokens_per_tile = (self.image_size // 14)**2 + 1
+            real_encoder_seq_lens = [
+                sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
+            ]
+
            cross_attention_states, full_text_row_masked_out_mask = \
-                self.flat_encoder_result(cross_attention_states, attn_metadata)
+                self.flat_encoder_result(cross_attention_states, attn_metadata, real_encoder_seq_lens)
            skip_cross_attention = False
-            # TODO: support multi-image by this mask
-            cross_attention_mask = None
+
+            token_ids = input_ids.tolist()
+            start = 0
+            batch_token_ids = []
+            for seq_len in attn_metadata.seq_lens:
+                batch_token_ids.append(token_ids[start:start + seq_len])
+                start += seq_len
+            sparse_mask = [
+                create_sparse_cross_attention_mask(t, MLLAMA_IMAGE_TOKEN_ID)
+                for t in batch_token_ids
+            ]
+
+            if not all_single_leading_image(sparse_mask):
+                dense_mask, tile_range_for_decode = convert_sparse_cross_attention_mask_to_dense(
+                    sparse_mask, num_tiles, attn_metadata.seq_lens)
+                cross_attention_mask = convert_dense_cross_attention_mask_to_tensor(
+                    dense_mask, num_tokens_per_tile, input_ids.device,
+                    cross_attention_states.dtype)
+                kv_range_for_decode = [[
+                    t[0] * num_tokens_per_tile, t[1] * num_tokens_per_tile
+                ] for t in tile_range_for_decode]
 
        outputs = self.language_model(
            input_ids=input_ids,
            positions=positions,
            cross_attention_states=cross_attention_states,
            cross_attention_mask=cross_attention_mask,
+            kv_range_for_decode=kv_range_for_decode,
            full_text_row_masked_out_mask=full_text_row_masked_out_mask,
            kv_caches=kv_caches,
            attn_metadata=attn_metadata,
@@ -1140,3 +1235,101 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                weight_loader = getattr(param, "weight_loader",
                                        default_weight_loader)
                weight_loader(param, loaded_weight)
+
+
+def create_sparse_cross_attention_mask(
+    tokens: List[int],
+    vision_token: int,
+) -> List[List[int]]:
+    vision_token_locations = [
+        i for i, token in enumerate(tokens) if token == vision_token
+    ]
+    # No vision token present
+    if len(vision_token_locations) == 0:
+        return []
+    # Only one image present: unmask until the end of the sequence
+    if len(vision_token_locations) == 1:
+        return [[vision_token_locations[0], -1]]
+
+    vision_masks = [[loc1, loc2] for loc1, loc2 in zip(
+        vision_token_locations[:-1], vision_token_locations[1:])]
+    # last image will attend to all subsequent text
+    vision_masks.append([vision_token_locations[-1], len(tokens)])
+
+    # If there are two or more consecutive vision tokens,
+    # they should all attend to all of the subsequent
+    # text.
+    last_mask_end = vision_masks[-1][1]
+    for vision_mask in vision_masks[::-1]:
+        if vision_mask[0] == vision_mask[1] - 1:
+            vision_mask[1] = last_mask_end
+        last_mask_end = vision_mask[1]
+
+    # A list of ranges, where range[i] = [start, end] means the
+    # i-th image will attend to the tokens in the range [start, end)
+    return vision_masks
+
+
+def all_single_leading_image(sparse_mask: List[List[List[int]]]) -> bool:
+    for mask in sparse_mask:
+        if len(mask) != 1:
+            return False
+        if mask[0][1] != -1:
+            return False
+    return True
+
+
+def convert_sparse_cross_attention_mask_to_dense(
+    sparse_mask: List[List[List[int]]],
+    num_tiles: List[List[int]],
+    lengths: List[int],
+) -> Tuple[np.ndarray, List[List[int]]]:
+    total_length = sum(lengths)
+    total_tiles = sum([sum(tiles) for tiles in num_tiles])
+    dense_mask = np.zeros(shape=(total_length, total_tiles), dtype=np.int64)
+    # A list of ranges, where range[i] = [start, end] means that
+    # tiles[start:end] (indexed over the flattened tiles of the whole
+    # batch) are used for cross-attention when decoding the i-th sample.
+    tile_range_for_decode = []
+
+    seq_start = 0
+    tile_start = 0
+    for masks, tiles, length in zip(sparse_mask, num_tiles, lengths):
+        ts, td = -1, 0
+        for mask, tile in zip(masks, tiles):
+            if len(mask) != 2:
+                continue
+            start, end = mask
+            end = min(end, length)
+            if end == -1:
+                end = length
+            if end == length:
+                if ts == -1:
+                    ts = tile_start
+                td += tile
+            dense_mask[seq_start + start:seq_start + end,
+                       tile_start:tile_start + tile] = 1
+            tile_start += tile
+        tile_range_for_decode.append([ts, ts + td])
+        seq_start += length
+
+    return dense_mask, tile_range_for_decode
+
+
+def convert_dense_cross_attention_mask_to_tensor(
+    cross_attention_token_mask: np.ndarray,
+    num_tokens_per_tile: int,
+    device: torch.device,
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    mask = torch.tensor(cross_attention_token_mask, dtype=dtype, device=device)
+    mask = mask.repeat_interleave(num_tokens_per_tile, dim=1)
+
+    mask = 1.0 - mask
+    mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(dtype).min)
+
+    ninf = torch.finfo(dtype).min
+    full_text_mask = ((mask != ninf).any(dim=-1).type_as(mask)[..., None])
+    mask *= full_text_mask
+    # (num_decoder_length, num_encoder_length)
+    return mask
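
For reference, a small walkthrough of how the new mask helpers compose. This is a sketch only: it assumes the helpers and MLLAMA_IMAGE_TOKEN_ID are importable from vllm.model_executor.models.mllama as added by this patch, and the non-image token ids (1-7) are arbitrary placeholders.

from vllm.model_executor.models.mllama import (
    MLLAMA_IMAGE_TOKEN_ID, _get_num_images_for_decode,
    create_sparse_cross_attention_mask,
    convert_sparse_cross_attention_mask_to_dense)

IMG = MLLAMA_IMAGE_TOKEN_ID

# One prompt with two interleaved images: <img> text <img> text
tokens = [IMG, 1, 2, 3, IMG, 4, 5, 6, 7]

# Only the last consecutive group of image tokens matters for decode.
_get_num_images_for_decode(tokens)  # -> 1

sparse = create_sparse_cross_attention_mask(tokens, IMG)
# -> [[0, 4], [4, 9]]: tokens [0, 4) see image 0, tokens [4, 9) see image 1

num_tiles = [[4, 4]]     # both images were split into 4 tiles
lengths = [len(tokens)]  # per-sequence prompt lengths

dense, tile_range_for_decode = convert_sparse_cross_attention_mask_to_dense(
    [sparse], num_tiles, lengths)
print(dense.shape)            # (9, 8): one row per text token, one column per tile
print(tile_range_for_decode)  # [[4, 8]]: decode attends only to the last image's tiles

# kv_range_for_decode scales tile indices to encoder-token indices; with the
# Llama 3.2 vision config's image_size of 560 that is (560 // 14)**2 + 1 == 1601
# tokens per tile, giving [[4 * 1601, 8 * 1601]].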