From cf73f0c95e09836efff876d5bfd9b9c6cc1ba06e Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Tue, 26 Nov 2024 02:14:33 +0800 Subject: [PATCH] [Model] Enable optional prefix when loading embedding models (#10639) Signed-off-by: DarkLight1337 --- vllm/model_executor/models/bert.py | 9 +++++---- vllm/model_executor/models/gemma2.py | 4 +++- vllm/model_executor/models/llama.py | 5 ++++- vllm/model_executor/models/qwen2.py | 12 ++++++------ vllm/model_executor/models/roberta.py | 3 ++- 5 files changed, 20 insertions(+), 13 deletions(-) diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py index f570d6d3c12b3..1fff72b3490e9 100644 --- a/vllm/model_executor/models/bert.py +++ b/vllm/model_executor/models/bert.py @@ -14,18 +14,17 @@ RowParallelLinear) from vllm.model_executor.layers.pooler import (CrossEncodingPooler, Pooler, PoolingType) -from vllm.model_executor.layers.quantization.base_config import ( - QuantizationConfig) +from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.vocab_parallel_embedding import ( VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader -from vllm.model_executor.models.interfaces import SupportsCrossEncoding from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.transformers_utils.config import ( get_cross_encoder_activation_function) -from .utils import maybe_prefix +from .interfaces import SupportsCrossEncoding +from .utils import WeightsMapper, maybe_prefix class BertEmbedding(nn.Module): @@ -442,6 +441,8 @@ def pooler( return self._pooler(hidden_states, pooling_metadata) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) + weights = hf_to_vllm_mapper.apply(weights) self.model.load_weights(weights) def _build_model(self, diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py index fd8223dd9be1b..d229eb74669ee 100644 --- a/vllm/model_executor/models/gemma2.py +++ b/vllm/model_executor/models/gemma2.py @@ -42,7 +42,7 @@ from vllm.sequence import IntermediateTensors, PoolerOutput from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, extract_layer_index, +from .utils import (AutoWeightsLoader, WeightsMapper, extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -511,4 +511,6 @@ def pooler( return self._pooler(hidden_states, pooling_metadata) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) + weights = hf_to_vllm_mapper.apply(weights) self.model.load_weights(weights) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 66b29e72cfa89..33d78d74129c8 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -53,7 +53,8 @@ from vllm.sequence import IntermediateTensors, PoolerOutput from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, + is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -689,6 +690,8 @@ def pooler( return self._pooler(hidden_states, pooling_metadata) def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) + weights = hf_to_vllm_mapper.apply(weights) self.model.load_weights(weights) def load_kv_cache_scales(self, quantization_param_path: str) -> None: diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py index 8da75c9935a13..46640226d4cf8 100644 --- a/vllm/model_executor/models/qwen2.py +++ b/vllm/model_executor/models/qwen2.py @@ -50,7 +50,8 @@ from vllm.sequence import IntermediateTensors, PoolerOutput from .interfaces import SupportsLoRA, SupportsPP -from .utils import (AutoWeightsLoader, PPMissingLayer, is_pp_missing_parameter, +from .utils import (AutoWeightsLoader, PPMissingLayer, WeightsMapper, + is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix) @@ -585,8 +586,7 @@ def pooler( ) -> Optional[PoolerOutput]: return self._pooler(hidden_states, pooling_metadata) - def load_weights(self, weights: Iterable[Tuple[str, - torch.Tensor]]) -> Set[str]: - loader = AutoWeightsLoader(self, - ignore_unexpected_prefixes=["lm_head."]) - return loader.load_weights(weights) + def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]): + hf_to_vllm_mapper = WeightsMapper(orig_to_new_prefix={"model.": ""}) + weights = hf_to_vllm_mapper.apply(weights) + self.model.load_weights(weights) diff --git a/vllm/model_executor/models/roberta.py b/vllm/model_executor/models/roberta.py index 5a296e311f079..ba1a78ac640fd 100644 --- a/vllm/model_executor/models/roberta.py +++ b/vllm/model_executor/models/roberta.py @@ -11,13 +11,14 @@ VocabParallelEmbedding) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.models.bert import BertEmbeddingModel, BertModel -from vllm.model_executor.models.interfaces import SupportsCrossEncoding from vllm.model_executor.models.utils import maybe_prefix from vllm.model_executor.pooling_metadata import PoolingMetadata from vllm.sequence import IntermediateTensors, PoolerOutput from vllm.transformers_utils.config import ( get_cross_encoder_activation_function) +from .interfaces import SupportsCrossEncoding + class RobertaEmbedding(nn.Module):