diff --git a/tests/conftest.py b/tests/conftest.py index d565da5a1019c..ba764223a29e4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -4,7 +4,8 @@ import sys from collections import UserList from enum import Enum -from typing import Any, Dict, List, Optional, Tuple, TypedDict, TypeVar, Union +from typing import (Any, Callable, Dict, List, Optional, Tuple, TypedDict, + TypeVar, Union) import pytest import torch @@ -27,7 +28,7 @@ from vllm.outputs import RequestOutput from vllm.sequence import SampleLogprobs from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, cuda_device_count_stateless, - is_cpu) + identity, is_cpu) logger = init_logger(__name__) @@ -197,6 +198,8 @@ def __init__( is_embedding_model: bool = False, is_vision_model: bool = False, is_encoder_decoder_model: bool = False, + postprocess_inputs: Callable[[BatchEncoding], + BatchEncoding] = identity, ) -> None: torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype] @@ -242,12 +245,14 @@ def __init__( torch_dtype=torch_dtype, trust_remote_code=True, ) - except Exception: + except Exception as exc: logger.warning( - "Unable to auto-load processor from HuggingFace for " - "model %s. Using tokenizer instead.", model_name) + "Unable to auto-load HuggingFace processor for model (%s). " + "Using tokenizer instead. Reason: %s", model_name, exc) self.processor = self.tokenizer + self.postprocess_inputs = postprocess_inputs + def generate( self, prompts: List[str], @@ -267,6 +272,7 @@ def generate( processor_kwargs["images"] = images[i] inputs = self.processor(**processor_kwargs) + inputs = self.postprocess_inputs(inputs) output_ids = self.model.generate( **self.wrap_device(inputs), @@ -336,6 +342,7 @@ def generate_greedy_logprobs( processor_kwargs["images"] = images[i] inputs = self.processor(**processor_kwargs) + inputs = self.postprocess_inputs(inputs) output = self.model.generate( **self.wrap_device(inputs), @@ -420,6 +427,7 @@ def generate_greedy_logprobs_limit( processor_kwargs["images"] = images[i] inputs = self.processor(**processor_kwargs) + inputs = self.postprocess_inputs(inputs) output = self.model.generate( **self.wrap_device(inputs), @@ -552,7 +560,8 @@ def generate( self, prompts: List[str], sampling_params: SamplingParams, - images: Optional[List[Image.Image]] = None, + images: Optional[Union[List[Image.Image], + List[List[Image.Image]]]] = None, ) -> List[Tuple[List[List[int]], List[str]]]: if images is not None: assert len(prompts) == len(images) @@ -587,7 +596,7 @@ def _final_steps_generate_w_logprobs( for req_output in req_outputs: for sample in req_output.outputs: output_str = sample.text - output_ids = sample.token_ids + output_ids = list(sample.token_ids) output_logprobs = sample.logprobs outputs.append((output_ids, output_str, output_logprobs)) return outputs @@ -596,7 +605,8 @@ def generate_w_logprobs( self, prompts: List[str], sampling_params: SamplingParams, - images: Optional[List[Image.Image]] = None, + images: Optional[Union[List[Image.Image], + List[List[Image.Image]]]] = None, ) -> List[Tuple[List[int], str, Optional[SampleLogprobs]]]: assert sampling_params.logprobs is not None diff --git a/tests/distributed/test_multimodal_broadcast.py b/tests/distributed/test_multimodal_broadcast.py index 2c96358e2e6f2..e7723a7ae2480 100644 --- a/tests/distributed/test_multimodal_broadcast.py +++ b/tests/distributed/test_multimodal_broadcast.py @@ -18,8 +18,10 @@ @pytest.mark.parametrize("model, distributed_executor_backend", [ ("llava-hf/llava-1.5-7b-hf", "ray"), ("llava-hf/llava-v1.6-mistral-7b-hf", "ray"), + ("facebook/chameleon-7b", "ray"), ("llava-hf/llava-1.5-7b-hf", "mp"), ("llava-hf/llava-v1.6-mistral-7b-hf", "mp"), + ("facebook/chameleon-7b", "mp"), ]) @fork_new_process_for_each_test def test_models(hf_runner, vllm_runner, image_assets, model: str, @@ -34,6 +36,8 @@ def test_models(hf_runner, vllm_runner, image_assets, model: str, from ..models.test_llava import models, run_test elif model.startswith("llava-hf/llava-v1.6"): from ..models.test_llava_next import models, run_test + elif model.startswith("facebook/chameleon"): + from ..models.test_chameleon import models, run_test else: raise NotImplementedError(f"Unsupported model: {model}") diff --git a/tests/entrypoints/openai/test_oot_registration.py b/tests/entrypoints/openai/test_oot_registration.py index 9f9a4cd972c51..3e1c7a1456697 100644 --- a/tests/entrypoints/openai/test_oot_registration.py +++ b/tests/entrypoints/openai/test_oot_registration.py @@ -1,5 +1,6 @@ import sys import time +from typing import Optional import torch from openai import OpenAI, OpenAIError @@ -17,8 +18,11 @@ class MyOPTForCausalLM(OPTForCausalLM): - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: # this dummy model always predicts the first token logits = super().compute_logits(hidden_states, sampling_metadata) logits.zero_() diff --git a/tests/models/test_chameleon.py b/tests/models/test_chameleon.py index 6e775da24d14e..5e7e0e6258f8a 100644 --- a/tests/models/test_chameleon.py +++ b/tests/models/test_chameleon.py @@ -1,11 +1,13 @@ -import re from typing import List, Optional, Type import pytest +from transformers import BatchEncoding from vllm.multimodal.utils import rescale_image_size +from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE -from ..conftest import IMAGE_ASSETS, VllmRunner, _ImageAssets +from ..conftest import IMAGE_ASSETS, HfRunner, VllmRunner, _ImageAssets +from .utils import check_outputs_equal pytestmark = pytest.mark.vlm @@ -19,9 +21,8 @@ models = ["facebook/chameleon-7b"] -#TODO (ywang96): Add correctness test when chameleon is -# available on transformers. def run_test( + hf_runner: Type[HfRunner], vllm_runner: Type[VllmRunner], image_assets: _ImageAssets, model: str, @@ -29,13 +30,20 @@ def run_test( size_factors: List[float], dtype: str, max_tokens: int, + num_logprobs: int, tensor_parallel_size: int, distributed_executor_backend: Optional[str] = None, ): - """Test if the model can generate text given - a batch of images and prompts. - + """Inference result should be the same between hf and vllm. + + All the image fixtures for the test is under tests/images. + For huggingface runner, we provide the PIL images as input. + For vllm runner, we provide MultiModalDataDict objects + and corresponding vision language config as input. + Note, the text input is also adjusted to abide by vllm contract. + The text output is sanitized to be able to compare with hf. """ + torch_dtype = STR_DTYPE_TO_TORCH_DTYPE[dtype] images = [asset.pil_image for asset in image_assets] inputs_per_image = [( @@ -50,35 +58,49 @@ def run_test( distributed_executor_backend=distributed_executor_backend, enforce_eager=True) as vllm_model: - for prompts, images in inputs_per_image: - vllm_outputs = vllm_model.generate_greedy(prompts, - max_tokens, - images=images) - for i in range(len(vllm_outputs)): - - # format prompt back to original - replacements = { - "": "", - "": "", - "": "" - } - pattern = '|'.join(replacements.keys()) - vllm_result = re.sub( - pattern, - lambda match: replacements[match.group(0)], #noqa B023 - vllm_outputs[i][1]) - vllm_result = vllm_result.replace("", "", 1023) - assert vllm_result[:len(prompts[i])] == prompts[i] - - # assert at least 10 new characters are generated - # (to take stop token into account) - assert len(vllm_outputs[i][1]) - len(prompts[i]) > 10 + vllm_outputs_per_image = [ + vllm_model.generate_greedy_logprobs(prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images) + for prompts, images in inputs_per_image + ] + + def process(hf_inputs: BatchEncoding): + hf_inputs["pixel_values"] = hf_inputs["pixel_values"] \ + .to(torch_dtype) # type: ignore + return hf_inputs + + with hf_runner(model, + dtype=dtype, + postprocess_inputs=process, + is_vision_model=True) as hf_model: + hf_outputs_per_image = [ + hf_model.generate_greedy_logprobs_limit(prompts, + max_tokens, + num_logprobs=num_logprobs, + images=images) + for prompts, images in inputs_per_image + ] + + for hf_outputs, vllm_outputs in zip(hf_outputs_per_image, + vllm_outputs_per_image): + # HF Logprobs include image tokens, unlike vLLM, so we don't directly + # compare them + check_outputs_equal( + outputs_0_lst=[outputs[:2] for outputs in hf_outputs], + outputs_1_lst=[outputs[:2] for outputs in vllm_outputs], + name_0="hf", + name_1="vllm", + ) @pytest.mark.parametrize("model", models) @pytest.mark.parametrize( "size_factors", [ + # No image + [], # Single-scale [1.0], # Single-scale, batched @@ -88,15 +110,18 @@ def run_test( ], ) @pytest.mark.parametrize("dtype", ["bfloat16"]) -@pytest.mark.parametrize("max_tokens", [128]) -def test_models(vllm_runner, image_assets, model, size_factors, dtype: str, - max_tokens: int) -> None: +@pytest.mark.parametrize("max_tokens", [8]) +@pytest.mark.parametrize("num_logprobs", [5]) +def test_models(hf_runner, vllm_runner, image_assets, model, size_factors, + dtype, max_tokens, num_logprobs) -> None: run_test( + hf_runner, vllm_runner, image_assets, model, size_factors=size_factors, dtype=dtype, max_tokens=max_tokens, + num_logprobs=num_logprobs, tensor_parallel_size=1, ) diff --git a/tests/models/test_llava.py b/tests/models/test_llava.py index 749d3353717e2..2724a0855117e 100644 --- a/tests/models/test_llava.py +++ b/tests/models/test_llava.py @@ -1,7 +1,7 @@ from typing import List, Optional, Tuple, Type import pytest -from transformers import AutoConfig, AutoTokenizer +from transformers import AutoConfig, AutoTokenizer, BatchEncoding from vllm.multimodal.utils import rescale_image_size from vllm.sequence import SampleLogprobs @@ -110,16 +110,21 @@ def run_test( for prompts, images in inputs_per_image ] - with hf_runner(model, dtype=dtype, is_vision_model=True) as hf_model: - if mantis_processor is not None: + if mantis_processor is not None: - def process(*args, **kwargs): - output = mantis_processor(*args, **kwargs) - output["pixel_values"] = output["pixel_values"].to(torch_dtype) - return output + def process(hf_inputs: BatchEncoding): + hf_inputs["pixel_values"] = hf_inputs["pixel_values"] \ + .to(torch_dtype) # type: ignore + return hf_inputs + else: - hf_model.processor = process + def process(hf_inputs: BatchEncoding): + return hf_inputs + with hf_runner(model, + dtype=dtype, + postprocess_inputs=process, + is_vision_model=True) as hf_model: hf_outputs_per_image = [ hf_model.generate_greedy_logprobs_limit(prompts, max_tokens, diff --git a/tests/models/test_minicpmv.py b/tests/models/test_minicpmv.py index c3b2a7bcbaafd..32f1cb2c2ed33 100644 --- a/tests/models/test_minicpmv.py +++ b/tests/models/test_minicpmv.py @@ -1,10 +1,9 @@ -from collections import UserDict from typing import List, Optional, Tuple, Type import pytest import torch import torch.types -from transformers import BatchFeature +from transformers import BatchEncoding from vllm.multimodal.utils import rescale_image_size from vllm.sequence import SampleLogprobs @@ -14,18 +13,6 @@ pytestmark = pytest.mark.vlm - -class NestedInputs(UserDict): - - def __init__(self, model_inputs: BatchFeature): - super().__init__({"model_inputs": model_inputs}) - - self.model_inputs = model_inputs - - def to(self, device: torch.types.Device): - return NestedInputs(self.model_inputs.to(device)) - - # The image token is placed before "user" on purpose so that the test can pass HF_IMAGE_PROMPTS = IMAGE_ASSETS.prompts({ "stop_sign": @@ -41,6 +28,10 @@ def to(self, device: torch.types.Device): models = ["openbmb/MiniCPM-Llama3-V-2_5"] +def _wrap_inputs(hf_inputs: BatchEncoding) -> BatchEncoding: + return BatchEncoding({"model_inputs": hf_inputs}) + + def trunc_hf_output(hf_output: Tuple[List[int], str, Optional[SampleLogprobs]]): output_ids, output_str, out_logprobs = hf_output @@ -105,11 +96,8 @@ def run_test( for prompts, images in inputs_per_image ] - with hf_runner(model, dtype=dtype) as hf_model, torch.no_grad(): - hf_processor = hf_model.processor - hf_model.processor = lambda **kw: NestedInputs( - hf_processor(**kw) # type: ignore - ) + hf_model = hf_runner(model, dtype=dtype, postprocess_inputs=_wrap_inputs) + with hf_model, torch.no_grad(): hf_outputs_per_image = [ hf_model.generate_greedy_logprobs_limit(prompts, max_tokens, @@ -224,11 +212,8 @@ def run_multi_image_test( for prompts, images in inputs_per_case ] - with hf_runner(model, dtype=dtype) as hf_model, torch.no_grad(): - hf_processor = hf_model.processor - hf_model.processor = lambda **kw: NestedInputs( - hf_processor(**kw) # type: ignore - ) + hf_model = hf_runner(model, dtype=dtype, postprocess_inputs=_wrap_inputs) + with hf_model, torch.no_grad(): hf_outputs_per_case = [ hf_model.generate_greedy_logprobs_limit(prompts, max_tokens, diff --git a/tests/models/test_oot_registration.py b/tests/models/test_oot_registration.py index 50ab06631500b..4918593ff0f98 100644 --- a/tests/models/test_oot_registration.py +++ b/tests/models/test_oot_registration.py @@ -1,3 +1,5 @@ +from typing import Optional + import torch from vllm import LLM, ModelRegistry, SamplingParams @@ -7,8 +9,11 @@ class MyOPTForCausalLM(OPTForCausalLM): - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: # this dummy model always predicts the first token logits = super().compute_logits(hidden_states, sampling_metadata) logits.zero_() diff --git a/vllm/distributed/communication_op.py b/vllm/distributed/communication_op.py index 32394a07b00b9..e13505dc37bb0 100644 --- a/vllm/distributed/communication_op.py +++ b/vllm/distributed/communication_op.py @@ -19,7 +19,7 @@ def tensor_model_parallel_all_gather(input_: torch.Tensor, def tensor_model_parallel_gather(input_: torch.Tensor, dst: int = 0, - dim: int = -1) -> torch.Tensor: + dim: int = -1) -> Optional[torch.Tensor]: """Gather the input tensor across model parallel group.""" return get_tp_group().gather(input_, dst, dim) diff --git a/vllm/distributed/parallel_state.py b/vllm/distributed/parallel_state.py index a20b92de81cda..6755b20eec9bb 100644 --- a/vllm/distributed/parallel_state.py +++ b/vllm/distributed/parallel_state.py @@ -329,7 +329,7 @@ def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor: def gather(self, input_: torch.Tensor, dst: int = 0, - dim: int = -1) -> torch.Tensor: + dim: int = -1) -> Optional[torch.Tensor]: """ NOTE: We assume that the input tensor is on the same device across all the ranks. diff --git a/vllm/model_executor/layers/logits_processor.py b/vllm/model_executor/layers/logits_processor.py index 80534acdc1a6a..1d5b6fad2e160 100644 --- a/vllm/model_executor/layers/logits_processor.py +++ b/vllm/model_executor/layers/logits_processor.py @@ -50,7 +50,7 @@ def forward( hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, embedding_bias: Optional[torch.Tensor] = None, - ) -> torch.Tensor: + ) -> Optional[torch.Tensor]: if self.logits_as_input: logits = hidden_states else: @@ -73,14 +73,18 @@ def forward( return logits - def _get_logits(self, hidden_states: torch.Tensor, - lm_head: VocabParallelEmbedding, - embedding_bias: Optional[torch.Tensor]) -> torch.Tensor: + def _get_logits( + self, + hidden_states: torch.Tensor, + lm_head: VocabParallelEmbedding, + embedding_bias: Optional[torch.Tensor], + ) -> Optional[torch.Tensor]: # Get the logits for the next tokens. logits = lm_head.linear_method.apply(lm_head, hidden_states, bias=embedding_bias) if self.use_gather: + # None may be returned for rank > 0 logits = tensor_model_parallel_gather(logits) else: # Gather is not supported for some devices such as TPUs. diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index a9a04b4263ae2..bd8ff8ba8d8c0 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -19,6 +19,7 @@ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME from vllm.config import LoadConfig, ModelConfig +from vllm.distributed import get_tensor_model_parallel_rank from vllm.logger import init_logger from vllm.model_executor.layers.quantization import (QuantizationConfig, get_quantization_config) @@ -514,8 +515,30 @@ def convert_pyslice_to_tensor(x: Any) -> torch.Tensor: def default_weight_loader(param: torch.Tensor, loaded_weight: torch.Tensor) -> None: """Default weight loader.""" - assert param.size() == loaded_weight.size() - param.data.copy_(loaded_weight) + try: + assert param.size() == loaded_weight.size(), ( + f"Attempted to load weight ({loaded_weight.size()}) " + f"into parameter ({param.size()})") + + param.data.copy_(loaded_weight) + except Exception: + # NOTE: This exception is added for the purpose of setting breakpoint to + # debug weight loading issues. + raise + + +def row_parallel_weight_loader(param: torch.Tensor, + loaded_weight: torch.Tensor) -> None: + """Load weights that are row-parallelized.""" + tp_rank = get_tensor_model_parallel_rank() + shard_dim = 0 if param.dim() != 1 else None + + if shard_dim is not None: + shard_size = param.data.shape[shard_dim] + start_idx = tp_rank * shard_size + loaded_weight = loaded_weight.narrow(shard_dim, start_idx, shard_size) + + return default_weight_loader(param, loaded_weight) def initialize_dummy_weights( diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py index 49e57a847e847..74e534aa76a9d 100644 --- a/vllm/model_executor/models/arctic.py +++ b/vllm/model_executor/models/arctic.py @@ -433,8 +433,11 @@ def forward( attn_metadata) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py index e1ea8bfcac655..a11c7663263c6 100644 --- a/vllm/model_executor/models/baichuan.py +++ b/vllm/model_executor/models/baichuan.py @@ -346,8 +346,11 @@ def forward( attn_metadata) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/bart.py b/vllm/model_executor/models/bart.py index 5066e991f9003..ef988532ce126 100644 --- a/vllm/model_executor/models/bart.py +++ b/vllm/model_executor/models/bart.py @@ -872,8 +872,11 @@ def forward( return self.model(input_ids, positions, encoder_input_ids, encoder_positions, kv_caches, attn_metadata) - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py index 084cbf35533bc..4968d6d900ac2 100644 --- a/vllm/model_executor/models/blip2.py +++ b/vllm/model_executor/models/blip2.py @@ -637,8 +637,11 @@ def forward( return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.get_lm_head(), hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py index 86ae32e0cb01f..282a0f84eacb1 100644 --- a/vllm/model_executor/models/bloom.py +++ b/vllm/model_executor/models/bloom.py @@ -292,8 +292,11 @@ def forward( attn_metadata) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py index 10a82207d90ef..2b6e5ee975172 100644 --- a/vllm/model_executor/models/chameleon.py +++ b/vllm/model_executor/models/chameleon.py @@ -25,8 +25,10 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, row_parallel_weight_loader) from vllm.model_executor.sampling_metadata import SamplingMetadata +from vllm.model_executor.utils import set_weight_attrs from vllm.multimodal import MULTIMODAL_REGISTRY from vllm.multimodal.image import (cached_get_tokenizer, repeat_and_pad_image_tokens) @@ -141,6 +143,11 @@ def __init__(self, hidden_size, *args, **kwargs): super().__init__(hidden_size, *args, **kwargs) self.normalized_shape = (hidden_size[-1], ) + set_weight_attrs(self.weight, + {"weight_loader": row_parallel_weight_loader}) + set_weight_attrs(self.bias, + {"weight_loader": row_parallel_weight_loader}) + def forward(self, hidden_states): hidden_states = F.layer_norm(hidden_states, self.normalized_shape, @@ -697,6 +704,8 @@ def __init__(self, config: ChameleonVQVAEConfig): ) def forward(self, pixel_values: torch.Tensor): + pixel_values = pixel_values.to(self.conv_in.weight.dtype) + # downsampling hidden_states = [self.conv_in(pixel_values)] for i_level in range(self.num_resolutions): @@ -959,15 +968,19 @@ def forward( attn_metadata) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) # Disallow image tokens which does not include special # begin-image and end-image tokens - image_tokens = self.model.vocabulary_mapping.image_tokens - logits[:, image_tokens] = torch.finfo(logits.dtype).min + if logits is not None: + image_tokens = self.model.vocabulary_mapping.image_tokens + logits[:, image_tokens] = torch.finfo(logits.dtype).min return logits diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py index 553ddf90475b4..b29ebe2f59e7b 100644 --- a/vllm/model_executor/models/chatglm.py +++ b/vllm/model_executor/models/chatglm.py @@ -372,8 +372,11 @@ def forward( attn_metadata) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py index 5f6e3a134f408..0894f750e5fbf 100644 --- a/vllm/model_executor/models/commandr.py +++ b/vllm/model_executor/models/commandr.py @@ -25,13 +25,11 @@ import torch import torch.utils.checkpoint from torch import nn -from torch.nn.parameter import Parameter from transformers import CohereConfig from vllm.attention import Attention, AttentionMetadata from vllm.config import CacheConfig, LoRAConfig -from vllm.distributed import (get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) +from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, @@ -43,7 +41,8 @@ from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) -from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.model_loader.weight_utils import ( + default_weight_loader, row_parallel_weight_loader) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.model_executor.utils import set_weight_attrs from vllm.sequence import IntermediateTensors, SamplerOutput @@ -67,25 +66,14 @@ def __init__(self, param_shape=None, eps=1e-5): super().__init__() self.weight = nn.Parameter(torch.ones(param_shape)) self.variance_epsilon = eps - set_weight_attrs(self.weight, {"weight_loader": self.weight_loader}) + set_weight_attrs(self.weight, + {"weight_loader": row_parallel_weight_loader}) def forward(self, hidden_states, residuals=None): hidden_states = layer_norm_func(hidden_states, self.weight, self.variance_epsilon) return hidden_states, residuals - def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor): - tp_rank = get_tensor_model_parallel_rank() - shard_dim = 0 if param.dim() != 1 else None - param_data = param.data - if shard_dim is not None: - shard_size = param_data.shape[shard_dim] - start_idx = tp_rank * shard_size - loaded_weight = loaded_weight.narrow(shard_dim, start_idx, - shard_size) - assert param_data.shape == loaded_weight.shape - param_data.copy_(loaded_weight) - # Copied from transformers.models.llama.modeling_llama.LlamaMLP Llama->Cohere class CohereMLP(nn.Module): @@ -359,8 +347,11 @@ def forward( attn_metadata) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: is_not_lora = hasattr(self.model.embed_tokens, 'weight') if is_not_lora: logits = self.logits_processor(self.model.embed_tokens, diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py index d758333b22388..7ebeca1a359ef 100644 --- a/vllm/model_executor/models/dbrx.py +++ b/vllm/model_executor/models/dbrx.py @@ -388,8 +388,11 @@ def forward( attn_metadata) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py index 3fd6f2218f3eb..f10977ed2c90d 100644 --- a/vllm/model_executor/models/deepseek.py +++ b/vllm/model_executor/models/deepseek.py @@ -395,8 +395,11 @@ def forward( attn_metadata) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py index 2e3e9b6f2792e..1ac15cefb5e29 100644 --- a/vllm/model_executor/models/deepseek_v2.py +++ b/vllm/model_executor/models/deepseek_v2.py @@ -505,8 +505,11 @@ def forward( attn_metadata, intermediate_tensors) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py index 93f07327eaa26..7b97b3d255dfa 100644 --- a/vllm/model_executor/models/falcon.py +++ b/vllm/model_executor/models/falcon.py @@ -420,8 +420,11 @@ def forward( ) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py index bb49349e7954d..41e8b13990e81 100644 --- a/vllm/model_executor/models/fuyu.py +++ b/vllm/model_executor/models/fuyu.py @@ -287,8 +287,11 @@ def forward( ) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: logits = self.language_model.logits_processor( self.language_model.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py index 64aef1024a1a5..14d1578863e5e 100644 --- a/vllm/model_executor/models/gemma.py +++ b/vllm/model_executor/models/gemma.py @@ -352,8 +352,11 @@ def forward( attn_metadata) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.model.embed_tokens, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index 7bad2626fec6a..aa9cff02283c0 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -343,8 +343,11 @@ def forward( attn_metadata) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.model.embed_tokens, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py index 94cd67e75336a..4f2fe0c42a3ff 100644 --- a/vllm/model_executor/models/gpt2.py +++ b/vllm/model_executor/models/gpt2.py @@ -265,8 +265,11 @@ def forward( attn_metadata, intermediate_tensors) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py index fc4e13bbb0e68..b30af3599aa4d 100644 --- a/vllm/model_executor/models/gpt_bigcode.py +++ b/vllm/model_executor/models/gpt_bigcode.py @@ -279,8 +279,11 @@ def forward( attn_metadata) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py index 4bb9debe7ae81..4d52b448049b4 100644 --- a/vllm/model_executor/models/gpt_j.py +++ b/vllm/model_executor/models/gpt_j.py @@ -246,8 +246,11 @@ def forward( attn_metadata) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata, self.lm_head.bias) return logits diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py index b306574b2ed92..e61b4448981e8 100644 --- a/vllm/model_executor/models/gpt_neox.py +++ b/vllm/model_executor/models/gpt_neox.py @@ -258,8 +258,11 @@ def forward( attn_metadata) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.embed_out, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py index 745fbf99a902d..216458465513a 100644 --- a/vllm/model_executor/models/internlm2.py +++ b/vllm/model_executor/models/internlm2.py @@ -279,8 +279,11 @@ def forward( attn_metadata) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.output, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py index 3a8e4baccc6fa..e34a486f56e38 100644 --- a/vllm/model_executor/models/internvl.py +++ b/vllm/model_executor/models/internvl.py @@ -466,8 +466,11 @@ def forward( inputs_embeds=inputs_embeds) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: return self.language_model.compute_logits(hidden_states, sampling_metadata) diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py index 0030c761d34db..8c606916dfb5a 100644 --- a/vllm/model_executor/models/jais.py +++ b/vllm/model_executor/models/jais.py @@ -295,8 +295,11 @@ def forward( attn_metadata) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index ededf9c533f01..6296cd502b1e1 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -861,8 +861,11 @@ def _prepare_mamba_cache(self): dtype=dtype, device="cuda")) - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 023ae2a18d41c..0c67a9b8e198b 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -430,8 +430,11 @@ def forward( attn_metadata, intermediate_tensors) return model_output - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index 0ff68943b5103..71a46256040c6 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -355,8 +355,11 @@ def forward( return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: return self.language_model.compute_logits(hidden_states, sampling_metadata) diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py index d94af966162f7..8331cbe8bcd1e 100644 --- a/vllm/model_executor/models/llava_next.py +++ b/vllm/model_executor/models/llava_next.py @@ -588,8 +588,11 @@ def forward( return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: return self.language_model.compute_logits(hidden_states, sampling_metadata) diff --git a/vllm/model_executor/models/medusa.py b/vllm/model_executor/models/medusa.py index 6453d0cb25c91..c2a61ca52011e 100644 --- a/vllm/model_executor/models/medusa.py +++ b/vllm/model_executor/models/medusa.py @@ -65,22 +65,28 @@ def forward(self, hidden_states: torch.Tensor) -> List[torch.Tensor]: def compute_logits( self, hidden_states: List[torch.Tensor], sampling_metadata: SamplingMetadata) -> List[torch.Tensor]: - logits = [] + logits_lst: List[torch.Tensor] = [] for hs, lm_head in zip(hidden_states, self.lm_heads): _logits = self.logits_processor(lm_head, hs, sampling_metadata) + if _logits is None: + # _logits should only be None on rank > 0, in which case + # it should remain true for every lm_head + assert len(logits_lst) == 0 + continue + if self.token_map is None: - logits.append(_logits) + logits_lst.append(_logits) else: - logits.append(-torch.inf * torch.ones( + logits_lst.append(-torch.inf * torch.ones( size=(*_logits.shape[:-1], self.orig_vocab_size), device=_logits.device, dtype=_logits.dtype)) - logits[-1][..., self.token_map] = _logits + logits_lst[-1][..., self.token_map] = _logits - return logits + return logits_lst def sample( self, diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 7f8f38fe8439a..ff42bdefe0269 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -470,8 +470,11 @@ def forward( attn_metadata, intermediate_tensors) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: hidden_states = hidden_states / self.scale_width if self.config.tie_word_embeddings: lm_head = self.model.embed_tokens diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py index 85522beb0f204..ab2b2c81ef4db 100644 --- a/vllm/model_executor/models/minicpmv.py +++ b/vllm/model_executor/models/minicpmv.py @@ -630,8 +630,11 @@ def forward( ) return output - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py index 8fbd537a2c031..34c21350dbc60 100644 --- a/vllm/model_executor/models/mixtral.py +++ b/vllm/model_executor/models/mixtral.py @@ -375,8 +375,11 @@ def forward( attn_metadata, intermediate_tensors) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py index 10faa5cc6b6cc..812dce5d04771 100644 --- a/vllm/model_executor/models/mixtral_quant.py +++ b/vllm/model_executor/models/mixtral_quant.py @@ -362,8 +362,11 @@ def forward( attn_metadata) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py index 7d658b39e6794..1a8e514a7ae83 100644 --- a/vllm/model_executor/models/mpt.py +++ b/vllm/model_executor/models/mpt.py @@ -279,8 +279,11 @@ def forward( attn_metadata) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py index bb85f20ab9802..57598b49bcca9 100644 --- a/vllm/model_executor/models/nemotron.py +++ b/vllm/model_executor/models/nemotron.py @@ -453,8 +453,11 @@ def forward( attn_metadata, intermediate_tensors) return model_output - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py index 1a0a3774dc8fb..8de124cd034dc 100644 --- a/vllm/model_executor/models/olmo.py +++ b/vllm/model_executor/models/olmo.py @@ -311,8 +311,11 @@ def forward( ) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py index a05090cd46648..b05f799e4dd2b 100644 --- a/vllm/model_executor/models/opt.py +++ b/vllm/model_executor/models/opt.py @@ -323,8 +323,11 @@ def forward( attn_metadata) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py index 8159cc13fba0b..6923e11e288be 100644 --- a/vllm/model_executor/models/orion.py +++ b/vllm/model_executor/models/orion.py @@ -277,8 +277,11 @@ def forward( attn_metadata) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py index c6d59db643bbd..3b9470774f843 100644 --- a/vllm/model_executor/models/paligemma.py +++ b/vllm/model_executor/models/paligemma.py @@ -262,8 +262,11 @@ def forward(self, return hidden_states # Copied from vllm/model_executor/models/gemma.py - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.language_model.embed_tokens, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py index bc38d4421b79e..3300939c7b102 100644 --- a/vllm/model_executor/models/persimmon.py +++ b/vllm/model_executor/models/persimmon.py @@ -285,8 +285,11 @@ def forward( ) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py index ac7496f68fd99..54f4dd2fcde0a 100644 --- a/vllm/model_executor/models/phi.py +++ b/vllm/model_executor/models/phi.py @@ -286,8 +286,11 @@ def forward( return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata, self.lm_head.bias) return logits diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py index cc06929fefab4..98e344d483e29 100644 --- a/vllm/model_executor/models/phi3_small.py +++ b/vllm/model_executor/models/phi3_small.py @@ -399,8 +399,11 @@ def set_decoder(self, decoder): def get_decoder(self): return self.model - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) if self.dummy_token_indices is not None and logits is not None: diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py index e0e427218bdd4..dd921c6af0538 100644 --- a/vllm/model_executor/models/phi3v.py +++ b/vllm/model_executor/models/phi3v.py @@ -584,8 +584,11 @@ def forward(self, return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py index eb61adf34e9a7..a7485bcb489a0 100644 --- a/vllm/model_executor/models/qwen.py +++ b/vllm/model_executor/models/qwen.py @@ -281,8 +281,11 @@ def make_empty_intermediate_tensors( device=device), }) - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index a66a1eee7c160..b95987c16ebca 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -362,8 +362,11 @@ def forward( attn_metadata, intermediate_tensors) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py index b895788206d10..b85512095622f 100644 --- a/vllm/model_executor/models/qwen2_moe.py +++ b/vllm/model_executor/models/qwen2_moe.py @@ -400,8 +400,11 @@ def forward( attn_metadata, intermediate_tensors) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py index 5451b56ed05f7..c98226d61a8a0 100644 --- a/vllm/model_executor/models/stablelm.py +++ b/vllm/model_executor/models/stablelm.py @@ -258,8 +258,11 @@ def forward( attn_metadata) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py index 1752bfd473b88..d1b1d210b727c 100644 --- a/vllm/model_executor/models/starcoder2.py +++ b/vllm/model_executor/models/starcoder2.py @@ -268,8 +268,11 @@ def forward( attn_metadata) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py index 84f0ffc376d65..e9bf67d314d0a 100644 --- a/vllm/model_executor/models/xverse.py +++ b/vllm/model_executor/models/xverse.py @@ -328,8 +328,11 @@ def forward( attn_metadata) return hidden_states - def compute_logits(self, hidden_states: torch.Tensor, - sampling_metadata: SamplingMetadata) -> torch.Tensor: + def compute_logits( + self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> Optional[torch.Tensor]: logits = self.logits_processor(self.lm_head, hidden_states, sampling_metadata) return logits diff --git a/vllm/outputs.py b/vllm/outputs.py index 6e11ff841c62e..e091b576f5972 100644 --- a/vllm/outputs.py +++ b/vllm/outputs.py @@ -1,6 +1,8 @@ import time from dataclasses import dataclass -from typing import List, Optional, Tuple, Union +from typing import List, Optional +from typing import Sequence as GenericSequence +from typing import Union from vllm.lora.request import LoRARequest from vllm.sequence import (PromptLogprobs, RequestMetrics, SampleLogprobs, @@ -28,7 +30,7 @@ class CompletionOutput: index: int text: str - token_ids: Tuple[int, ...] + token_ids: GenericSequence[int] cumulative_logprob: Optional[float] logprobs: Optional[SampleLogprobs] finish_reason: Optional[str] = None @@ -139,7 +141,7 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": CompletionOutput( seqs.index(seq), seq.get_output_text_to_return(text_buffer_length), - seq.data._output_token_ids, # type: ignore + seq.data._output_token_ids, seq.get_cumulative_logprob() if include_logprobs else None, seq.output_logprobs if include_logprobs else None, SequenceStatus.get_finished_reason(seq.status),