From 1446292659cc059dbb738371e74d4791339a4811 Mon Sep 17 00:00:00 2001
From: Xiang Xu
Date: Sat, 5 Oct 2024 11:42:14 -0700
Subject: [PATCH 1/2] Make llama3.2 support multiple and interleaved images

---
 vllm/model_executor/models/mllama.py | 288 +++++++++++++++++++++++----
 1 file changed, 253 insertions(+), 35 deletions(-)

diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index 45d6ad3c0efa..c369096f3655 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -18,6 +18,7 @@
 from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
                     TypedDict, Union)
 
+import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
@@ -28,9 +29,12 @@
                                          CausalLMOutputWithPast)
 from transformers.models.mllama.image_processing_mllama import (
     get_optimal_tiled_canvas)
+from transformers.models.mllama.processing_mllama import (
+    get_cross_attention_token_mask)
 
 import vllm.distributed.parallel_state as ps
 from vllm.attention import Attention, AttentionMetadata, AttentionType
+from vllm.attention.ops.paged_attn import PagedAttention
 from vllm.config import CacheConfig, MultiModalConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
@@ -72,6 +76,16 @@ class MllamaImagePixelInputs(TypedDict):
 # TODO: support LlamaImageEmbeddingInputs
 
 
+def _get_num_image_in_last_group(prompt_token_ids: List[int]) -> int:
+    num_images = 0
+    for token_id in prompt_token_ids[::-1]:
+        if token_id == MLLAMA_IMAGE_TOKEN_ID:
+            num_images += 1
+        elif num_images > 0:
+            break
+    return num_images
+
+
 def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs):
     # move encoder_prompt to prompt
     if llm_inputs.get("prompt") is None:
@@ -91,12 +105,16 @@ def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs):
         llm_inputs["encoder_multi_modal_data"] = {}
         return llm_inputs
 
-    # get num_tiles
     if isinstance(multi_modal_data['image'], Image.Image):
         multi_modal_data['image'] = [multi_modal_data['image']]
+    # Since only the last group of consecutive images is attended to by
+    # the decoded tokens, we only need to get the number of tiles for
+    # those images.
+    num_decode_images = _get_num_image_in_last_group(
+        llm_inputs["prompt_token_ids"])
     hf_config = ctx.model_config.hf_config
     num_tiles = 0
-    for image in multi_modal_data["image"]:
+    for image in multi_modal_data["image"][::-1]:
         width, height = image.size
         tile_size = hf_config.vision_config.image_size
         canvas_height, canvas_width = get_optimal_tiled_canvas(
@@ -108,8 +126,13 @@ def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs):
         num_tiles_height = canvas_height // tile_size
         num_tiles_width = canvas_width // tile_size
         num_tiles += num_tiles_height * num_tiles_width
+        num_decode_images -= 1
+        if num_decode_images == 0:
+            break
 
-    # set encoder prompt based on num_tiles
+    # Set the encoder prompt length based on the number of tiles.
+    # This tells the block manager to allocate the correct number
+    # of slots for the encoder tokens.
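# ---- Editor's illustrative sketch (not part of this patch) -----------------
# How the encoder prompt length set below comes together: only the trailing
# run of consecutive image tokens is counted, and every tile contributes
# (image_size // 14) ** 2 + 1 placeholder tokens.  The ceil-division tiling is
# a simplified stand-in for HF's get_optimal_tiled_canvas, and the token id /
# tile size / max-tile values are assumed here purely for illustration.
from math import ceil
from typing import List, Tuple

IMAGE_TOKEN_ID = 128256   # assumed value of MLLAMA_IMAGE_TOKEN_ID
TILE_SIZE = 560           # assumed vision_config.image_size
MAX_TILES = 4             # assumed vision_config.max_num_tiles


def images_in_last_group(prompt_token_ids: List[int]) -> int:
    """Count the trailing run of consecutive image placeholder tokens."""
    count = 0
    for tok in reversed(prompt_token_ids):
        if tok == IMAGE_TOKEN_ID:
            count += 1
        elif count > 0:
            break
    return count


def encoder_prompt_len(prompt_token_ids: List[int],
                       image_sizes: List[Tuple[int, int]]) -> int:
    """Encoder tokens the block manager must reserve for one request."""
    tokens_per_tile = (TILE_SIZE // 14) ** 2 + 1
    remaining = images_in_last_group(prompt_token_ids)
    num_tiles = 0
    # Walk the images back to front and stop once the last group is covered.
    for width, height in reversed(image_sizes):
        num_tiles += min(MAX_TILES,
                         ceil(width / TILE_SIZE) * ceil(height / TILE_SIZE))
        remaining -= 1
        if remaining == 0:
            break
    return num_tiles * tokens_per_tile


# "<img> describe <img><img> compare them": only the last two images are
# reserved for -> 2 images x 1 tile x 1601 tokens = 3202 encoder tokens.
print(encoder_prompt_len(
    [IMAGE_TOKEN_ID, 10, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 11, 12],
    [(500, 500), (500, 500), (500, 500)]))
# ---------------------------------------------------------------------------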
     assert hf_config.vision_config.image_size % 14 == 0, \
         "chunk size should be multiple of 14"
     token_per_chunk = (hf_config.vision_config.image_size // 14)**2 + 1
@@ -675,6 +698,7 @@ def forward(
         self,
         hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor],
+        kv_range_for_decode: Optional[List[Tuple[int, int]]],
         cross_attention_states: Optional[torch.Tensor],
         kv_cache: torch.Tensor,
         attn_metadata: AttentionMetadata,
@@ -697,15 +721,65 @@ def forward(
         q = q.view(-1, self.num_local_heads, self.head_dim)
         q = self.q_norm(q)
 
-        output = self.attn(q,
-                           k,
-                           v,
-                           kv_cache,
-                           attn_metadata,
-                           attn_type=AttentionType.ENCODER_DECODER)
+        if attention_mask is not None:
+            output = self.attention_with_mask(q, k, v, kv_cache,
+                                              attention_mask,
+                                              kv_range_for_decode,
+                                              attn_metadata)
+        else:
+            output = self.attn(q,
+                               k,
+                               v,
+                               kv_cache,
+                               attn_metadata,
+                               attn_type=AttentionType.ENCODER_DECODER)
         out, _ = self.o_proj(output)
         return out
 
+    def attention_with_mask(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        kv_cache: torch.Tensor,
+        attention_mask: torch.Tensor,
+        kv_range_for_decode: List[Tuple[int, int]],
+        attn_metadata: AttentionMetadata,
+    ) -> torch.Tensor:
+        if len(kv_cache.shape) == 3:
+            key_cache, value_cache = PagedAttention.split_kv_cache(
+                kv_cache, self.num_local_key_value_heads, self.head_dim)
+            cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
+            cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
+            PagedAttention.write_to_paged_cache(
+                cached_k, cached_v, key_cache, value_cache,
+                attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
+        # We have to call torch.sdpa for prefill when using a custom
+        # cross-attention mask, because the mask is neither a standard
+        # causal mask nor a block-diagonal mask that could be optimized
+        # with xformers.BlockDiagonalMask.
+        # The mask is computed specially to support multiple
+        # images and interleaved images.
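# ---- Editor's illustrative sketch (not part of this patch) -----------------
# Minimal, self-contained version of the idea described above: cross-attention
# over image-tile keys with an arbitrary additive float mask handed to
# F.scaled_dot_product_attention (allowed positions hold 0, masked positions
# hold the dtype minimum).  GQA is handled here by simply repeating the KV
# heads; the shapes and head counts are toy values, not the model's.
import torch
import torch.nn.functional as F

q_len, kv_len, head_dim = 6, 8, 16
num_heads, num_kv_heads = 4, 2
groups = num_heads // num_kv_heads

q = torch.randn(q_len, num_heads, head_dim)
k = torch.randn(kv_len, num_kv_heads, head_dim)
v = torch.randn(kv_len, num_kv_heads, head_dim)

# Additive bias: text tokens 0-2 may only see the tiles of the first image
# (kv positions 0-3), tokens 3-5 may see every tile -- a pattern that is
# neither causal nor block-diagonal.
neg = torch.finfo(q.dtype).min
bias = torch.full((q_len, kv_len), neg)
bias[0:3, 0:4] = 0.0
bias[3:6, :] = 0.0

# Give every query head a matching key/value head, then run SDPA.
k = k.repeat_interleave(groups, dim=1)
v = v.repeat_interleave(groups, dim=1)
out = F.scaled_dot_product_attention(
    q.transpose(0, 1).unsqueeze(0),             # (1, heads, q_len, head_dim)
    k.transpose(0, 1).unsqueeze(0),
    v.transpose(0, 1).unsqueeze(0),
    attn_mask=bias.view(1, 1, q_len, kv_len),   # broadcast over heads
    is_causal=False)
out = out.squeeze(0).transpose(0, 1).reshape(q_len, num_heads * head_dim)
print(out.shape)   # torch.Size([6, 64])
# ---------------------------------------------------------------------------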
+ q_len = q.shape[0] + kv_len = k.shape[0] + q = q.transpose(0, 1).view(self.num_key_value_groups, + self.num_local_key_value_heads, q_len, + self.head_dim) + k = k.transpose(0, 1).expand(self.num_key_value_groups, + self.num_local_key_value_heads, kv_len, + self.head_dim) + v = v.transpose(0, 1).expand(self.num_key_value_groups, + self.num_local_key_value_heads, kv_len, + self.head_dim) + attention_mask = attention_mask.view(1, 1, q_len, kv_len) + output = F.scaled_dot_product_attention(q, + k, + v, + attn_mask=attention_mask) + output = output.permute(2, 0, 1, 3).reshape( + q_len, self.num_local_heads * self.head_dim) + return output + class MllamaCrossAttentionDecoderLayer(torch.nn.Module): """Cross-attention transformer block with tanh-gated attention @@ -741,6 +815,7 @@ def forward( hidden_states: torch.Tensor, cross_attention_states: torch.Tensor, cross_attention_mask: torch.Tensor, + kv_range_for_decode: Optional[List[Tuple[int, int]]], full_text_row_masked_out_mask: torch.Tensor, kv_cache: List[torch.Tensor], attn_metadata: AttentionMetadata, @@ -751,6 +826,7 @@ def forward( hidden_states = self.cross_attn( hidden_states=hidden_states, attention_mask=cross_attention_mask, + kv_range_for_decode=kv_range_for_decode, cross_attention_states=cross_attention_states, kv_cache=kv_cache, attn_metadata=attn_metadata, @@ -804,6 +880,7 @@ def forward( positions: Optional[torch.LongTensor], cross_attention_states: Optional[torch.LongTensor], cross_attention_mask: Optional[torch.LongTensor], + kv_range_for_decode: Optional[List[Tuple[int, int]]], full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]], kv_caches: List[torch.Tensor], @@ -820,6 +897,7 @@ def forward( hidden_states=hidden_states, cross_attention_states=cross_attention_states, cross_attention_mask=cross_attention_mask, + kv_range_for_decode=kv_range_for_decode, full_text_row_masked_out_mask= full_text_row_masked_out_mask, kv_cache=kv_caches[idx], @@ -868,6 +946,7 @@ def forward( positions: Optional[torch.LongTensor], cross_attention_states: Optional[torch.LongTensor], cross_attention_mask: Optional[torch.LongTensor], + kv_range_for_decode: Optional[List[Tuple[int, int]]], full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor, torch.Tensor]], kv_caches: List[torch.Tensor], @@ -879,6 +958,7 @@ def forward( positions=positions, cross_attention_states=cross_attention_states, cross_attention_mask=cross_attention_mask, + kv_range_for_decode=kv_range_for_decode, full_text_row_masked_out_mask=full_text_row_masked_out_mask, kv_caches=kv_caches, attn_metadata=attn_metadata, @@ -1026,16 +1106,17 @@ def _parse_and_validate_image_input(self, **kwargs: object): raise AssertionError("This line should be unreachable.") def flat_encoder_result(self, cross_attention_states: torch.Tensor, - attn_metadata: AttentionMetadata): + attn_metadata: AttentionMetadata, + actual_encoder_seq_lens: List[int]): cross_attention_states_flat = torch.zeros( - sum(attn_metadata.encoder_seq_lens), + sum(actual_encoder_seq_lens), cross_attention_states.shape[-1], device=cross_attention_states.device, dtype=cross_attention_states.dtype) start_pos = 0 - for seq_len, vision_token_in_batch in zip( - attn_metadata.encoder_seq_lens, cross_attention_states): + for seq_len, vision_token_in_batch in zip(actual_encoder_seq_lens, + cross_attention_states): end_pos = start_pos + seq_len cross_attention_states_flat[ start_pos:end_pos] = vision_token_in_batch[:seq_len] @@ -1045,9 +1126,8 @@ def flat_encoder_result(self, cross_attention_states: torch.Tensor, 
full_text_row_masked_out_mask = torch.ones( (attn_metadata.num_prefill_tokens, 1), dtype=torch.bool) start_pos = 0 - for seq_len, encoder_seq_len in zip( - attn_metadata.seq_lens_tensor.cpu(), - attn_metadata.encoder_seq_lens): + for seq_len, encoder_seq_len in zip(attn_metadata.seq_lens, + attn_metadata.encoder_seq_lens): if encoder_seq_len == 0: full_text_row_masked_out_mask[start_pos:start_pos + seq_len] = False @@ -1057,6 +1137,67 @@ def flat_encoder_result(self, cross_attention_states: torch.Tensor, return cross_attention_states, full_text_row_masked_out_mask + def get_cross_attention_states( + self, + image_inputs: MllamaImagePixelInputs, + attn_metadata: AttentionMetadata, + actual_encoder_seq_lens: List[int], + ) -> Tuple[torch.Tensor, torch.Tensor]: + # NOTE: llama's reference implementation runs vision model on CPU + pixel_values = image_inputs['data'] + aspect_ratio_ids = image_inputs['aspect_ratio_ids'] + aspect_ratio_mask = image_inputs['aspect_ratio_mask'] + cross_attention_states = self.vision_model(pixel_values, + aspect_ratio_ids, + aspect_ratio_mask) + cross_attention_states = self.multi_modal_projector( + cross_attention_states) + + bsz, _, _, _, image_token_dim = tuple(cross_attention_states.shape) + cross_attention_states = cross_attention_states.view( + bsz, -1, image_token_dim) + + cross_attention_states, full_text_row_masked_out_mask = \ + self.flat_encoder_result( + cross_attention_states, attn_metadata, + actual_encoder_seq_lens) + + return cross_attention_states, full_text_row_masked_out_mask + + def get_cross_attention_mask( + self, + input_ids: torch.Tensor, + attn_metadata: AttentionMetadata, + num_tiles: List[List[int]], + num_tokens_per_tile: int, + dtype: torch.dtype, + ) -> Tuple[torch.Tensor, torch.Tensor]: + token_ids = input_ids.tolist() + start = 0 + batch_token_ids = [] + for seq_len in attn_metadata.seq_lens: + batch_token_ids.append(token_ids[start:start + seq_len]) + start += seq_len + sparse_mask = [ + get_cross_attention_token_mask(t, MLLAMA_IMAGE_TOKEN_ID) + for t in batch_token_ids + ] + + if skip_attention_mask(sparse_mask): + return None, None + + dense_mask, tile_range_for_decode = \ + convert_sparse_cross_attention_mask_to_dense( + sparse_mask, num_tiles, attn_metadata.seq_lens) + cross_attention_mask = \ + convert_dense_cross_attention_mask_to_tensor( + dense_mask, num_tokens_per_tile, input_ids.device, dtype) + kv_range_for_decode = [[ + t[0] * num_tokens_per_tile, t[1] * num_tokens_per_tile + ] for t in tile_range_for_decode] + + return cross_attention_mask, kv_range_for_decode + def forward( self, input_ids: torch.Tensor, @@ -1069,39 +1210,51 @@ def forward( attn_metadata.num_decode_tokens > 0: raise ValueError("Chunk prefill not supported") image_inputs = self._parse_and_validate_image_input(**kwargs) + cross_attention_states = None + cross_attention_mask = None + kv_range_for_decode = None + + # For 1) text-only prefill and decode, 2) image-present decode. if image_inputs is None: - cross_attention_mask = None full_text_row_masked_out_mask = ( attn_metadata.encoder_seq_lens_tensor != 0).reshape(-1, 1).to( input_ids.device) - cross_attention_states = None skip_cross_attention = max(attn_metadata.encoder_seq_lens) == 0 + + # For image-present prefill. 
         else:
-            # NOTE: llama's reference implementation runs vision model on CPU
-            pixel_values = image_inputs['data']
-            aspect_ratio_ids = image_inputs['aspect_ratio_ids']
-            aspect_ratio_mask = image_inputs['aspect_ratio_mask']
-            cross_attention_states = self.vision_model(pixel_values,
-                                                        aspect_ratio_ids,
-                                                        aspect_ratio_mask)
-            cross_attention_states = self.multi_modal_projector(
-                cross_attention_states)
-
-            bsz, _, _, _, image_token_dim = tuple(cross_attention_states.shape)
-            cross_attention_states = cross_attention_states.view(
-                bsz, -1, image_token_dim)
+            skip_cross_attention = False
+
+            # Get the actual number of encoder tokens for each sample.
+            # attn_metadata.encoder_seq_lens only counts the last group of
+            # images for each sample; that value is only used to make the
+            # block manager allocate blocks for those images.
+            # See input_processor_for_mllama() for more details.
+            num_tiles_tensor = kwargs.pop("num_tiles")
+            num_tiles = [t[0].tolist() for t in num_tiles_tensor]
+            num_tokens_per_tile = (self.image_size // 14)**2 + 1
+            actual_encoder_seq_lens = [
+                sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
+            ]
+            for actual_len, last_group_len in zip(
+                    actual_encoder_seq_lens, attn_metadata.encoder_seq_lens):
+                assert actual_len >= last_group_len
 
             cross_attention_states, full_text_row_masked_out_mask = \
-                self.flat_encoder_result(cross_attention_states, attn_metadata)
-            skip_cross_attention = False
-            # TODO: support multi-image by this mask
-            cross_attention_mask = None
+                self.get_cross_attention_states(
+                    image_inputs, attn_metadata, actual_encoder_seq_lens)
+
+            cross_attention_mask, kv_range_for_decode = \
+                self.get_cross_attention_mask(
+                    input_ids, attn_metadata, num_tiles,
+                    num_tokens_per_tile, cross_attention_states.dtype)
 
         outputs = self.language_model(
             input_ids=input_ids,
             positions=positions,
             cross_attention_states=cross_attention_states,
             cross_attention_mask=cross_attention_mask,
+            kv_range_for_decode=kv_range_for_decode,
             full_text_row_masked_out_mask=full_text_row_masked_out_mask,
             kv_caches=kv_caches,
             attn_metadata=attn_metadata,
@@ -1140,3 +1293,68 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
+
+
+def skip_attention_mask(sparse_mask: List[List[List[int]]]) -> bool:
+    for mask in sparse_mask:
+        if len(mask) != 1:
+            return False
+        if mask[0][1] != -1:
+            return False
+    return True
+
+
+def convert_sparse_cross_attention_mask_to_dense(
+        sparse_mask: List[List[List[int]]],
+        num_tiles: List[List[int]],
+        lengths: List[int],
+) -> Tuple[np.ndarray, List[Tuple[int, int]]]:
+    total_length = sum(lengths)
+    total_tiles = sum([sum(tiles) for tiles in num_tiles])
+    dense_mask = np.zeros(shape=(total_length, total_tiles), dtype=np.int64)
+    # A list of ranges; range[i] = (start, end) means that, out of the
+    # i-th sample's total tiles, tiles[start:end] (end-exclusive) will
+    # be used for cross-attention decoding.
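# ---- Editor's illustrative sketch (not part of this patch) -----------------
# Worked example, with toy numbers, of what this conversion produces.  The
# sparse mask uses get_cross_attention_token_mask's [start, end] entries
# (end == -1 meaning "until the end of the sequence"); the result is a
# token x tile 0/1 matrix plus, per sample, the half-open tile range that
# keeps being attended to while decoding.
import numpy as np

# One sample, 7 prompt tokens, two images with 2 and 1 tiles respectively:
# image 0 sits at token 1 and covers tokens [1, 4); image 1 sits at token 4
# and covers tokens [4, end).
toy_sparse_mask = [[[1, 4], [4, -1]]]
toy_num_tiles = [[2, 1]]
toy_lengths = [7]

expected_dense = np.array([
    # tiles:  0  1  2     (tiles 0-1 belong to image 0, tile 2 to image 1)
    [0, 0, 0],    # token 0: before any image
    [1, 1, 0],    # tokens 1-3: attend to image 0 only
    [1, 1, 0],
    [1, 1, 0],
    [0, 0, 1],    # tokens 4-6: attend to image 1, the last group
    [0, 0, 1],
    [0, 0, 1],
], dtype=np.int64)
# Only image 1 reaches the end of the prompt, so decoding keeps attending to
# tile range [2, 3), i.e. kv range [2 * tokens_per_tile, 3 * tokens_per_tile).
expected_tile_range_for_decode = [(2, 3)]

# Usage against the function defined here:
#   dense, ranges = convert_sparse_cross_attention_mask_to_dense(
#       toy_sparse_mask, toy_num_tiles, toy_lengths)
#   assert (dense == expected_dense).all()
#   assert ranges == expected_tile_range_for_decode
# ---------------------------------------------------------------------------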
+    tile_range_for_decode = []
+
+    seq_start = 0
+    tile_start = 0
+    for masks, tiles, length in zip(sparse_mask, num_tiles, lengths):
+        ts, td = 0, 0
+        for mask, tile in zip(masks, tiles):
+            if len(mask) != 2:
+                continue
+            start, end = mask
+            end = min(end, length)
+            if end == -1:
+                end = length
+            if end == length:
+                if ts == 0:
+                    ts = tile_start
+                td += tile
+            dense_mask[seq_start + start:seq_start + end,
+                       tile_start:tile_start + tile] = 1
+            tile_start += tile
+        tile_range_for_decode.append((ts, ts + td))
+        seq_start += length
+
+    return dense_mask, tile_range_for_decode
+
+
+def convert_dense_cross_attention_mask_to_tensor(
+        cross_attention_token_mask: np.ndarray,
+        num_tokens_per_tile: int,
+        device: torch.device,
+        dtype: torch.dtype,
+) -> torch.Tensor:
+    mask = torch.tensor(cross_attention_token_mask, dtype=dtype, device=device)
+    mask = mask.repeat_interleave(num_tokens_per_tile, dim=1)
+
+    mask = 1.0 - mask
+    mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(dtype).min)
+
+    ninf = torch.finfo(dtype).min
+    full_text_mask = ((mask != ninf).any(dim=-1).type_as(mask)[..., None])
+    mask *= full_text_mask
+    # (num_prompt_tokens, num_encoder_tokens)
+    return mask

From 61d6a25f03a29265a16e798c1040fe27a02c0531 Mon Sep 17 00:00:00 2001
From: Xiang Xu
Date: Mon, 7 Oct 2024 22:59:38 -0700
Subject: [PATCH 2/2] Fix attention and mask

---
 vllm/model_executor/models/mllama.py | 39 +++++++++++++++++++--------
 1 file changed, 27 insertions(+), 12 deletions(-)

diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index c369096f3655..c68bbfd79cdc 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -762,20 +762,25 @@ def attention_with_mask(
         # images and interleaved images.
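# ---- Editor's illustrative sketch (not part of this patch) -----------------
# Why the fix below reorders the reshape to (num_kv_heads, num_groups, ...):
# under the usual GQA convention, consecutive query heads share one KV head
# (query head h uses KV head h // num_groups).  Viewing the head axis as
# (kv_heads, groups) keeps each group of query heads next to the KV head it
# shares, and broadcasting the KV tensor with [:, None, :, :] then matches
# repeat_interleave; the old (groups, kv_heads) order paired query heads with
# the wrong KV heads.  Head counts below are toy values for illustration.
import torch

kv_len, num_kv_heads, groups, head_dim = 5, 2, 3, 4
num_heads = num_kv_heads * groups
k = torch.randn(kv_len, num_kv_heads, head_dim)

# Reference: KV head i serves query heads [i * groups, (i + 1) * groups).
reference = k.repeat_interleave(groups, dim=1)             # (kv_len, 6, dim)

# Broadcast + expand as in the fixed code path, flattened for comparison.
fixed = (k.transpose(0, 1)[:, None, :, :]                  # (kv, 1, len, dim)
          .expand(num_kv_heads, groups, kv_len, head_dim)
          .reshape(num_heads, kv_len, head_dim)
          .transpose(0, 1))                                # (kv_len, 6, dim)
print(torch.equal(reference, fixed))    # True

# The pre-fix (groups, kv_heads) order interleaves the heads incorrectly.
wrong = (k.transpose(0, 1).expand(groups, num_kv_heads, kv_len, head_dim)
          .reshape(num_heads, kv_len, head_dim)
          .transpose(0, 1))
print(torch.equal(reference, wrong))    # False
# ---------------------------------------------------------------------------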
         q_len = q.shape[0]
         kv_len = k.shape[0]
-        q = q.transpose(0, 1).view(self.num_key_value_groups,
-                                   self.num_local_key_value_heads, q_len,
+        q = q.transpose(0, 1).view(self.num_local_key_value_heads,
+                                   self.num_key_value_groups, q_len,
                                    self.head_dim)
-        k = k.transpose(0, 1).expand(self.num_key_value_groups,
-                                     self.num_local_key_value_heads, kv_len,
-                                     self.head_dim)
-        v = v.transpose(0, 1).expand(self.num_key_value_groups,
-                                     self.num_local_key_value_heads, kv_len,
-                                     self.head_dim)
+        k = k.transpose(0,
+                        1)[:,
+                           None, :, :].expand(self.num_local_key_value_heads,
+                                              self.num_key_value_groups,
+                                              kv_len, self.head_dim)
+        v = v.transpose(0,
+                        1)[:,
+                           None, :, :].expand(self.num_local_key_value_heads,
+                                              self.num_key_value_groups,
+                                              kv_len, self.head_dim)
         attention_mask = attention_mask.view(1, 1, q_len, kv_len)
         output = F.scaled_dot_product_attention(q,
                                                 k,
                                                 v,
-                                                attn_mask=attention_mask)
+                                                attn_mask=attention_mask,
+                                                is_causal=False)
         output = output.permute(2, 0, 1, 3).reshape(
             q_len, self.num_local_heads * self.head_dim)
         return output
@@ -1182,7 +1187,9 @@ def get_cross_attention_mask(
         get_cross_attention_token_mask(t, MLLAMA_IMAGE_TOKEN_ID)
         for t in batch_token_ids
     ]
 
+    # Skip generating the cross-attention mask if all samples are
+    # text-only or have only one leading image.
     if skip_attention_mask(sparse_mask):
         return None, None
@@ -1297,9 +1304,17 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
 
 def skip_attention_mask(sparse_mask: List[List[List[int]]]) -> bool:
     for mask in sparse_mask:
+        # Skip text-only samples.
+        if len(mask) == 0:
+            continue
+        # If the sample contains more than one image,
+        # we can't skip the mask.
         if len(mask) != 1:
             return False
+        # If the sample contains only one image,
+        # but the image is not the leading one,
+        # we can't skip the mask.
-        if mask[0][1] != -1:
+        if mask[0][0] != 0 or mask[0][1] != -1:
             return False
     return True
 
@@ -1320,7 +1335,7 @@ def convert_sparse_cross_attention_mask_to_dense(
     seq_start = 0
     tile_start = 0
     for masks, tiles, length in zip(sparse_mask, num_tiles, lengths):
-        ts, td = 0, 0
+        ts, td = -1, 0
         for mask, tile in zip(masks, tiles):
             if len(mask) != 2:
                 continue
@@ -1329,7 +1344,7 @@ def convert_sparse_cross_attention_mask_to_dense(
             if end == -1:
                 end = length
             if end == length:
-                if ts == 0:
+                if ts == -1:
                     ts = tile_start
                 td += tile
             dense_mask[seq_start + start:seq_start + end,