diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 6676dd16e005f..f3d66c2313198 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -45,8 +45,12 @@ except ImportError: USE_XFORMERS_OPS = False -PIXTRAL_IMAGE_BREAK_ID = 12 -PIXTRAL_IMAGE_END_ID = 13 +# These token ids cannot be retrieved from model config +# so we hardcode them here. +PIXTRAL_12B_IMAGE_BREAK_ID = 12 +PIXTRAL_12B_IMAGE_END_ID = 13 +PIXTRAL_LARGE_IMAGE_BREAK_ID = 14 +PIXTRAL_LARGE_IMAGE_END_ID = 15 def get_max_pixtral_image_tokens(ctx: InputContext): @@ -118,8 +122,7 @@ def input_mapper_for_pixtral(ctx: InputContext, for image_data in data_list: image = ImageChunk(image=image_data) encoding = tokenizer.instruct.mm_encoder(image) - image = torch.from_numpy(encoding.image).to(device="cuda", - dtype=torch.float16) + image = torch.from_numpy(encoding.image).to(dtype=torch.float16) images.append(image) image_tokens_list.append(encoding.tokens) @@ -237,8 +240,9 @@ def get_multimodal_embeddings(self, **kwargs) -> Optional[NestedTensors]: # NOTE: Image embeddings are split into separate tensors for each image # by the indices of `[IMG_END]` token. - split_indices = torch.where( - image_tokens == PIXTRAL_IMAGE_END_ID)[0] + 1 + image_end_condition = (image_tokens == PIXTRAL_12B_IMAGE_END_ID) | ( + image_tokens == PIXTRAL_LARGE_IMAGE_END_ID) + split_indices = torch.where(image_end_condition)[0] + 1 if len(split_indices) <= 1: # Do not split, return as tensor of shape [1, fs, hs] return image_embeds.unsqueeze(0) @@ -260,8 +264,11 @@ def get_input_embeddings( if multimodal_embeddings is not None: inputs_embeds = merge_multimodal_embeddings( input_ids, inputs_embeds, multimodal_embeddings, [ - self.vision_args.image_token_id, PIXTRAL_IMAGE_END_ID, - PIXTRAL_IMAGE_BREAK_ID + self.vision_args.image_token_id, + PIXTRAL_12B_IMAGE_END_ID, + PIXTRAL_12B_IMAGE_BREAK_ID, + PIXTRAL_LARGE_IMAGE_BREAK_ID, + PIXTRAL_LARGE_IMAGE_END_ID, ]) return inputs_embeds