Skip to content

Commit

Permalink
[torch.compile] Adding torch compile to vision-language models (vllm-…
Browse files Browse the repository at this point in the history
  • Loading branch information
CRZbulabula authored Nov 2, 2024
1 parent 1b73ab2 commit ae5279a
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 8 deletions.
10 changes: 7 additions & 3 deletions vllm/model_executor/models/llava_next.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,6 @@ def forward(
:class:`LlavaNextImageInputs`
"""
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)
Expand All @@ -618,9 +617,14 @@ def forward(
self.language_model.model.get_input_embeddings,
lambda _: self._process_image_input(image_input),
)
input_ids = None
else:
inputs_embeds = None
inputs_embeds = self.language_model.model.get_input_embeddings(
input_ids)

# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
# for `torch.compile` integration
input_ids = None

hidden_states = self.language_model.model(input_ids,
positions,
Expand Down
7 changes: 6 additions & 1 deletion vllm/model_executor/models/minicpmv.py
Original file line number Diff line number Diff line change
Expand Up @@ -564,8 +564,13 @@ def forward(

vlm_embeddings, _ = self.get_embedding(input_ids, image_inputs)

# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
# for `torch.compile` integration
input_ids = None

output = self.llm(
input_ids=None,
input_ids=input_ids,
positions=positions,
kv_caches=kv_caches,
attn_metadata=attn_metadata,
Expand Down
12 changes: 8 additions & 4 deletions vllm/model_executor/models/molmo.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from vllm.attention import Attention, AttentionMetadata
from vllm.attention.selector import _Backend
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, MultiModalConfig
from vllm.distributed import (get_pp_group, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
Expand Down Expand Up @@ -713,6 +714,7 @@ def forward(
return image_features


@support_torch_compile
class MolmoModel(nn.Module):

def __init__(
Expand Down Expand Up @@ -1141,7 +1143,6 @@ def forward(
**kwargs: object,
) -> SamplerOutput:
if intermediate_tensors is not None:
input_ids = None
inputs_embeds = None
else:
image_input = self._parse_and_validate_image_input(**kwargs)
Expand All @@ -1156,10 +1157,13 @@ def forward(
image_input["image_input_idx"],
image_input["seq_len"],
)

input_ids = None
else:
inputs_embeds = None
inputs_embeds = self.model.embed_tokens(input_ids)

# always pass the input via `inputs_embeds`
# to make sure the computation graph is consistent
# for `torch.compile` integration
input_ids = None

hidden_states = self.model(
input_ids=input_ids,
Expand Down

0 comments on commit ae5279a

Please sign in to comment.