Make llama3.2 support multiple and interleaved images
xiangxu-google committed Oct 5, 2024
1 parent 5df1834 commit 64acbc5
Showing 1 changed file with 218 additions and 21 deletions.
239 changes: 218 additions & 21 deletions vllm/model_executor/models/mllama.py
@@ -18,6 +18,7 @@
from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
TypedDict, Union)

import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
@@ -31,6 +32,7 @@

import vllm.distributed.parallel_state as ps
from vllm.attention import Attention, AttentionMetadata, AttentionType
from vllm.attention.ops.paged_attn import PagedAttention
from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.inputs import INPUT_REGISTRY, InputContext, LLMInputs
@@ -72,6 +74,16 @@ class MllamaImagePixelInputs(TypedDict):
# TODO: support LlamaImageEmbeddingInputs


def _get_num_images_for_decode(prompt_token_ids: List[int]) -> int:
num_images = 0
for token_id in prompt_token_ids[::-1]:
if token_id == MLLAMA_IMAGE_TOKEN_ID:
num_images += 1
elif num_images > 0:
break
return num_images
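# Illustrative example (IMG stands for MLLAMA_IMAGE_TOKEN_ID):
#   _get_num_images_for_decode([IMG, 1, 2, IMG, IMG, 3]) == 2
# Only the last run of consecutive image tokens is counted, because those
# are the only images the decoded tokens attend to.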


def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs):
# move encoder_prompt to prompt
if llm_inputs.get("prompt") is None:
@@ -91,12 +103,16 @@ def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs):
llm_inputs["encoder_multi_modal_data"] = {}
return llm_inputs

# get num_tiles
if isinstance(multi_modal_data['image'], Image.Image):
multi_modal_data['image'] = [multi_modal_data['image']]
# Since only the last group of consecutive images
# is attended to by the decoded tokens, we only need
# the number of tiles for those images.
num_decode_images = _get_num_images_for_decode(
llm_inputs["prompt_token_ids"])
hf_config = ctx.model_config.hf_config
num_tiles = 0
for image in multi_modal_data["image"]:
for image in multi_modal_data["image"][::-1]:
width, height = image.size
tile_size = hf_config.vision_config.image_size
canvas_height, canvas_width = get_optimal_tiled_canvas(
@@ -108,8 +124,13 @@ def input_processor_for_mllama(ctx: InputContext, llm_inputs: LLMInputs):
num_tiles_height = canvas_height // tile_size
num_tiles_width = canvas_width // tile_size
num_tiles += num_tiles_height * num_tiles_width
num_decode_images -= 1
if num_decode_images == 0:
break

# set encoder prompt based on num_tiles
# Set encoder prompt length based on the number of tiles.
# This tells the block manager to allocate the correct
# number of slots for the encoder tokens.
assert hf_config.vision_config.image_size % 14 == 0, \
"chunk size should be multiple of 14"
token_per_chunk = (hf_config.vision_config.image_size // 14)**2 + 1
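# Worked example (illustrative, assuming the Llama 3.2 Vision default
# vision_config.image_size == 560): token_per_chunk = (560 // 14)**2 + 1
# = 1601, so a prompt whose attended images span 4 tiles reserves
# 4 * 1601 = 6404 encoder token slots.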
@@ -675,6 +696,7 @@ def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor],
kv_range_for_decode: Optional[List[List[int]]],
cross_attention_states: Optional[torch.Tensor],
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
@@ -697,12 +719,45 @@ def forward(
q = q.view(-1, self.num_local_heads, self.head_dim)
q = self.q_norm(q)

output = self.attn(q,
k,
v,
kv_cache,
attn_metadata,
attn_type=AttentionType.ENCODER_DECODER)
if attention_mask is not None:
if len(kv_cache.shape) == 3:
key_cache, value_cache = PagedAttention.split_kv_cache(
kv_cache, self.num_local_key_value_heads, self.head_dim)
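# Only the KV ranges that decode attends to (the tiles of each
# sequence's last consecutive group of images) are written into the
# paged cross-attention cache; the block manager allocated slots for
# exactly those tokens.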
cached_k = torch.cat([k[s:e] for s, e in kv_range_for_decode])
cached_v = torch.cat([v[s:e] for s, e in kv_range_for_decode])
PagedAttention.write_to_paged_cache(
cached_k, cached_v, key_cache, value_cache,
attn_metadata.cross_slot_mapping, "auto", 1.0, 1.0)
# We have to call torch.sdpa for prefill when a custom
# cross-attention mask is used, because the mask is neither
# a standard causal mask nor a block-diagonal mask that
# could be optimized by xformers.BlockDiagonalMask.
# The mask is computed specifically to support multiple
# and interleaved images.
q_len = q.shape[0]
kv_len = k.shape[0]
q = q.transpose(0, 1).view(self.num_key_value_groups,
self.num_local_key_value_heads, q_len,
self.head_dim)
k = k.transpose(0, 1).expand(self.num_key_value_groups,
self.num_local_key_value_heads,
kv_len, self.head_dim)
v = v.transpose(0, 1).expand(self.num_key_value_groups,
self.num_local_key_value_heads,
kv_len, self.head_dim)
attention_mask = attention_mask.view(1, 1, q_len, kv_len)
with torch.backends.cuda.sdp_kernel(enable_flash=True):
output = F.scaled_dot_product_attention(
q, k, v, attn_mask=attention_mask)
output = output.permute(2, 0, 1, 3).reshape(
q_len, self.num_local_heads * self.head_dim)
else:
output = self.attn(q,
k,
v,
kv_cache,
attn_metadata,
attn_type=AttentionType.ENCODER_DECODER)
out, _ = self.o_proj(output)
return out

@@ -741,6 +796,7 @@ def forward(
hidden_states: torch.Tensor,
cross_attention_states: torch.Tensor,
cross_attention_mask: torch.Tensor,
kv_range_for_decode: Optional[List[List[int]]],
full_text_row_masked_out_mask: torch.Tensor,
kv_cache: List[torch.Tensor],
attn_metadata: AttentionMetadata,
@@ -751,6 +807,7 @@ def forward(
hidden_states = self.cross_attn(
hidden_states=hidden_states,
attention_mask=cross_attention_mask,
kv_range_for_decode=kv_range_for_decode,
cross_attention_states=cross_attention_states,
kv_cache=kv_cache,
attn_metadata=attn_metadata,
@@ -804,6 +861,7 @@ def forward(
positions: Optional[torch.LongTensor],
cross_attention_states: Optional[torch.LongTensor],
cross_attention_mask: Optional[torch.LongTensor],
kv_range_for_decode: Optional[List[List[int]]],
full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor,
torch.Tensor]],
kv_caches: List[torch.Tensor],
@@ -820,6 +878,7 @@ def forward(
hidden_states=hidden_states,
cross_attention_states=cross_attention_states,
cross_attention_mask=cross_attention_mask,
kv_range_for_decode=kv_range_for_decode,
full_text_row_masked_out_mask=
full_text_row_masked_out_mask,
kv_cache=kv_caches[idx],
@@ -868,6 +927,7 @@ def forward(
positions: Optional[torch.LongTensor],
cross_attention_states: Optional[torch.LongTensor],
cross_attention_mask: Optional[torch.LongTensor],
kv_range_for_decode: Optional[List[List[int]]],
full_text_row_masked_out_mask: Optional[Tuple[torch.Tensor,
torch.Tensor]],
kv_caches: List[torch.Tensor],
@@ -879,6 +939,7 @@ def forward(
positions=positions,
cross_attention_states=cross_attention_states,
cross_attention_mask=cross_attention_mask,
kv_range_for_decode=kv_range_for_decode,
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
@@ -1026,16 +1087,17 @@ def _parse_and_validate_image_input(self, **kwargs: object):
raise AssertionError("This line should be unreachable.")

def flat_encoder_result(self, cross_attention_states: torch.Tensor,
attn_metadata: AttentionMetadata):
attn_metadata: AttentionMetadata,
real_encoder_seq_lens: List[int]):

cross_attention_states_flat = torch.zeros(
sum(attn_metadata.encoder_seq_lens),
sum(real_encoder_seq_lens),
cross_attention_states.shape[-1],
device=cross_attention_states.device,
dtype=cross_attention_states.dtype)
start_pos = 0
for seq_len, vision_token_in_batch in zip(
attn_metadata.encoder_seq_lens, cross_attention_states):
for seq_len, vision_token_in_batch in zip(real_encoder_seq_lens,
cross_attention_states):
end_pos = start_pos + seq_len
cross_attention_states_flat[
start_pos:end_pos] = vision_token_in_batch[:seq_len]
@@ -1045,9 +1107,8 @@ def flat_encoder_result(self, cross_attention_states: torch.Tensor,
full_text_row_masked_out_mask = torch.ones(
(attn_metadata.num_prefill_tokens, 1), dtype=torch.bool)
start_pos = 0
for seq_len, encoder_seq_len in zip(
attn_metadata.seq_lens_tensor.cpu(),
attn_metadata.encoder_seq_lens):
for seq_len, encoder_seq_len in zip(attn_metadata.seq_lens,
attn_metadata.encoder_seq_lens):
if encoder_seq_len == 0:
full_text_row_masked_out_mask[start_pos:start_pos +
seq_len] = False
@@ -1069,13 +1130,17 @@ def forward(
attn_metadata.num_decode_tokens > 0:
raise ValueError("Chunk prefill not supported")
image_inputs = self._parse_and_validate_image_input(**kwargs)
# text-only prefill and decode
# image decode
cross_attention_states = None
cross_attention_mask = None
kv_range_for_decode = None
if image_inputs is None:
cross_attention_mask = None
full_text_row_masked_out_mask = (
attn_metadata.encoder_seq_lens_tensor != 0).reshape(-1, 1).to(
input_ids.device)
cross_attention_states = None
skip_cross_attention = max(attn_metadata.encoder_seq_lens) == 0
# image prefill
else:
# NOTE: llama's reference implementation runs vision model on CPU
pixel_values = image_inputs['data']
@@ -1091,17 +1156,51 @@ def forward(
cross_attention_states = cross_attention_states.view(
bsz, -1, image_token_dim)

num_tiles_tensor = kwargs.pop("num_tiles")
if isinstance(num_tiles_tensor, list):
num_tiles = [t.tolist()[0] for t in num_tiles_tensor]
else:
num_tiles = [t[0] for t in num_tiles_tensor.tolist()]
num_tokens_per_tile = (self.image_size // 14)**2 + 1
real_encoder_seq_lens = [
sum(num_tile) * num_tokens_per_tile for num_tile in num_tiles
]
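# Unlike attn_metadata.encoder_seq_lens, which only covers the images
# attended to during decode (see input_processor_for_mllama),
# real_encoder_seq_lens counts the tiles of every image in each sequence.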

cross_attention_states, full_text_row_masked_out_mask = \
self.flat_encoder_result(cross_attention_states, attn_metadata)
self.flat_encoder_result(
cross_attention_states, attn_metadata,
real_encoder_seq_lens)
skip_cross_attention = False
# TODO: support multi-image by this mask
cross_attention_mask = None

token_ids = input_ids.tolist()
start = 0
batch_token_ids = []
for seq_len in attn_metadata.seq_lens:
batch_token_ids.append(token_ids[start:start + seq_len])
start += seq_len
sparse_mask = [
create_sparse_cross_attention_mask(t, MLLAMA_IMAGE_TOKEN_ID)
for t in batch_token_ids
]

if not all_single_leading_image(sparse_mask):
dense_mask, tile_range_for_decode = \
convert_sparse_cross_attention_mask_to_dense(
sparse_mask, num_tiles, attn_metadata.seq_lens)
cross_attention_mask = \
convert_dense_cross_attention_mask_to_tensor(
dense_mask, num_tokens_per_tile, input_ids.device,
cross_attention_states.dtype)
kv_range_for_decode = [[
t[0] * num_tokens_per_tile, t[1] * num_tokens_per_tile
] for t in tile_range_for_decode]
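# Illustrative example: with num_tokens_per_tile == 1601, a tile range
# of [2, 7] maps to the encoder-token range [3202, 11207] in the
# flattened cross-attention states.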

outputs = self.language_model(
input_ids=input_ids,
positions=positions,
cross_attention_states=cross_attention_states,
cross_attention_mask=cross_attention_mask,
kv_range_for_decode=kv_range_for_decode,
full_text_row_masked_out_mask=full_text_row_masked_out_mask,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
@@ -1140,3 +1239,101 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
weight_loader = getattr(param, "weight_loader",
default_weight_loader)
weight_loader(param, loaded_weight)


def create_sparse_cross_attention_mask(
tokens: List[int],
vision_token: int,
) -> List[List[int]]:
vision_token_locations = [
i for i, token in enumerate(tokens) if token == vision_token
]
# No vision token present
if len(vision_token_locations) == 0:
return []
# only one image present, unmask until end of sequence
if len(vision_token_locations) == 1:
return [[vision_token_locations[0], -1]]

vision_masks = [[loc1, loc2] for loc1, loc2 in zip(
vision_token_locations[:-1], vision_token_locations[1:])]
# last image will attend to all subsequent text
vision_masks.append([vision_token_locations[-1], len(tokens)])

# If there are two or more consecutive vision tokens,
# they should all attend to all subsequent text.
last_mask_end = vision_masks[-1][1]
for vision_mask in vision_masks[::-1]:
if vision_mask[0] == vision_mask[1] - 1:
vision_mask[1] = last_mask_end
last_mask_end = vision_mask[1]

# A list of ranges, where ranges[i] = [start, end] means the
# i-th image attends to the tokens in the range [start, end)
return vision_masks
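# Worked example (illustrative, with V standing for the vision token id):
#   create_sparse_cross_attention_mask([V, 7, 8, V, V, 9], V)
#   returns [[0, 3], [3, 6], [4, 6]]
# The first image attends to tokens [0, 3); the two consecutive trailing
# images both attend through the end of the sequence. A single-image
# prompt such as [V, 7, 8] returns [[0, -1]].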


def all_single_leading_image(sparse_mask: List[List[int]]) -> bool:
for mask in sparse_mask:
if len(mask) != 1:
return False
if mask[0][1] != -1:
return False
return True


def convert_sparse_cross_attention_mask_to_dense(
sparse_mask: List[List[List[int]]],
num_tiles: List[List[int]],
lengths: List[int],
) -> Tuple[np.ndarray, List[List[int]]]:
total_length = sum(lengths)
total_tiles = sum([sum(tiles) for tiles in num_tiles])
dense_mask = np.zeros(shape=(total_length, total_tiles), dtype=np.int64)
# A list of ranges, where ranges[i] = [start, end] means that,
# if the i-th sample has N tiles in total, tiles[start:end]
# will be used for cross-attention during decode.
tile_range_for_decode = []

seq_start = 0
tile_start = 0
for masks, tiles, length in zip(sparse_mask, num_tiles, lengths):
ts, td = 0, 0
for mask, tile in zip(masks, tiles):
if len(mask) != 2:
continue
start, end = mask
end = min(end, length)
if end == -1:
end = length
if end == length:
if ts == 0:
ts = tile_start
td += tile
dense_mask[seq_start + start:seq_start + end,
tile_start:tile_start + tile] = 1
tile_start += tile
tile_range_for_decode.append([ts, ts + td])
seq_start += length

return dense_mask, tile_range_for_decode
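# Worked example (illustrative), continuing the example above for a
# single sequence:
#   sparse_mask = [[[0, 3], [3, 6], [4, 6]]], num_tiles = [[2, 3, 2]],
#   lengths = [6]
# gives a dense mask of shape (6, 7) and tile_range_for_decode == [[2, 7]]:
# decode attends only to tiles 2..6, i.e. the last two images, whose masks
# extend to the end of the sequence.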


def convert_dense_cross_attention_mask_to_tensor(
cross_attention_token_mask: np.ndarray,
num_tokens_per_tile: int,
device: torch.device,
dtype: torch.dtype,
) -> torch.Tensor:
mask = torch.tensor(cross_attention_token_mask, dtype=dtype, device=device)
mask = mask.repeat_interleave(num_tokens_per_tile, dim=1)

mask = 1.0 - mask
mask = mask.masked_fill(mask.to(torch.bool), torch.finfo(dtype).min)

ninf = torch.finfo(dtype).min
full_text_mask = ((mask != ninf).any(dim=-1).type_as(mask)[..., None])
mask *= full_text_mask
# (num_decoder_length, num_encoder_length)
return mask
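# Descriptive note: each token-level column of the dense mask is expanded
# to num_tokens_per_tile columns; allowed positions become 0.0 and blocked
# positions become the dtype's finite minimum, yielding an additive mask
# for scaled_dot_product_attention. Rows that attend to no encoder
# position at all are reset to zeros rather than left at the minimum value.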
