From c4e464333eac5a46e1cc2701e095a44057c82927 Mon Sep 17 00:00:00 2001
From: Isotr0py <2037008807@qq.com>
Date: Mon, 18 Nov 2024 09:07:46 +0800
Subject: [PATCH] [Misc] Add uninitialized params tracking for
 `AutoWeightsLoader` (#10327)

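This change makes each model's `load_weights` return the set of
parameter names it actually initialized, so the model loader can
compare that set against `model.named_parameters()` and fail fast
when a checkpoint leaves parameters uninitialized. Models that
delegate to `AutoWeightsLoader` simply return
`loader.load_weights(weights)`. The strict check is skipped for
quantized models and for models whose `load_weights` does not yet
return a set.

A minimal sketch of the new contract (an illustration only, not code
from this patch; `ToyModel` is hypothetical and the lambda fallback
stands in for vLLM's `default_weight_loader`):

    from typing import Iterable, Set, Tuple

    import torch
    from torch import nn

    class ToyModel(nn.Module):
        def load_weights(
            self,
            weights: Iterable[Tuple[str, torch.Tensor]],
        ) -> Set[str]:
            params_dict = dict(self.named_parameters())
            # Track every parameter name we copy a checkpoint tensor
            # into, so the caller can diff this set against
            # named_parameters() afterwards.
            loaded_params: Set[str] = set()
            for name, loaded_weight in weights:
                param = params_dict[name]
                weight_loader = getattr(
                    param, "weight_loader",
                    lambda p, w: p.data.copy_(w))
                weight_loader(param, loaded_weight)
                loaded_params.add(name)
            return loaded_params

The loader-side check then reduces to a set difference over
`model.named_parameters()`, raising if any name is left over.
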
Signed-off-by: Isotr0py <2037008807@qq.com>
---
 vllm/model_executor/model_loader/loader.py     | 12 +++++++++++-
 vllm/model_executor/models/arctic.py           |  8 ++++++--
 vllm/model_executor/models/baichuan.py         |  8 ++++++--
 vllm/model_executor/models/bert.py             |  8 ++++++--
 vllm/model_executor/models/blip.py             | 12 ++++++++----
 vllm/model_executor/models/blip2.py            |  7 ++++---
 vllm/model_executor/models/bloom.py            |  8 ++++++--
 vllm/model_executor/models/chameleon.py        |  8 ++++++--
 vllm/model_executor/models/chatglm.py          | 10 ++++++++--
 vllm/model_executor/models/clip.py             | 11 ++++++++---
 vllm/model_executor/models/commandr.py         |  4 +++-
 vllm/model_executor/models/dbrx.py             |  8 ++++++--
 vllm/model_executor/models/decilm.py           |  8 ++++++--
 vllm/model_executor/models/deepseek.py         |  8 ++++++--
 vllm/model_executor/models/deepseek_v2.py      |  8 ++++++--
 vllm/model_executor/models/exaone.py           |  9 +++++++--
 vllm/model_executor/models/falcon.py           |  8 ++++++--
 vllm/model_executor/models/florence2.py        | 17 +++++++++++------
 vllm/model_executor/models/fuyu.py             |  8 +++++---
 vllm/model_executor/models/gemma.py            |  4 +++-
 vllm/model_executor/models/gemma2.py           |  9 ++++++---
 vllm/model_executor/models/gpt2.py             |  8 ++++++--
 vllm/model_executor/models/gpt_bigcode.py      |  8 ++++++--
 vllm/model_executor/models/gpt_j.py            |  8 ++++++--
 vllm/model_executor/models/gpt_neox.py         |  8 ++++++--
 vllm/model_executor/models/granite.py          |  9 +++++++--
 vllm/model_executor/models/granitemoe.py       |  8 +++++---
 .../models/idefics2_vision_model.py            | 11 ++++++++---
 vllm/model_executor/models/idefics3.py         |  7 ++++---
 vllm/model_executor/models/intern_vit.py       |  8 ++++++--
 vllm/model_executor/models/internlm2.py        |  8 ++++++--
 vllm/model_executor/models/internvl.py         |  7 ++++---
 vllm/model_executor/models/jais.py             |  8 ++++++--
 vllm/model_executor/models/jamba.py            |  8 ++++++--
 vllm/model_executor/models/llama.py            | 15 ++++++++++-----
 vllm/model_executor/models/llava.py            |  7 ++++---
 vllm/model_executor/models/llava_next.py       |  7 ++++---
 vllm/model_executor/models/llava_next_video.py |  7 ++++---
 vllm/model_executor/models/llava_onevision.py  |  7 ++++---
 vllm/model_executor/models/mamba.py            |  8 ++++++--
 vllm/model_executor/models/medusa.py           |  9 +++++++--
 vllm/model_executor/models/minicpm.py          |  8 ++++++--
 vllm/model_executor/models/minicpmv.py         | 14 +++++++++-----
 vllm/model_executor/models/mixtral.py          |  8 ++++++--
 vllm/model_executor/models/mixtral_quant.py    |  8 ++++++--
 vllm/model_executor/models/mllama.py           |  9 ++++++---
 vllm/model_executor/models/mlp_speculator.py   |  8 ++++++--
 vllm/model_executor/models/mpt.py              |  8 ++++++--
 vllm/model_executor/models/nemotron.py         |  8 ++++++--
 vllm/model_executor/models/olmo.py             |  8 ++++++--
 vllm/model_executor/models/olmoe.py            |  8 ++++++--
 vllm/model_executor/models/opt.py              |  8 ++++++--
 vllm/model_executor/models/orion.py            |  8 ++++++--
 vllm/model_executor/models/paligemma.py        |  7 ++++---
 vllm/model_executor/models/persimmon.py        |  8 ++++++--
 vllm/model_executor/models/phi.py              |  8 ++++++--
 vllm/model_executor/models/phi3_small.py       |  8 ++++++--
 vllm/model_executor/models/phi3v.py            |  9 ++++++---
 vllm/model_executor/models/phimoe.py           |  8 ++++++--
 vllm/model_executor/models/pixtral.py          | 12 ++++++++----
 vllm/model_executor/models/qwen.py             |  8 ++++++--
 vllm/model_executor/models/qwen2.py            | 18 ++++++++++++------
 vllm/model_executor/models/qwen2_audio.py      |  9 +++++++--
 vllm/model_executor/models/qwen2_cls.py        |  7 ++++---
 vllm/model_executor/models/qwen2_moe.py        |  8 ++++++--
 vllm/model_executor/models/qwen2_rm.py         |  7 ++++---
 vllm/model_executor/models/qwen2_vl.py         |  8 ++++++--
 vllm/model_executor/models/siglip.py           | 11 ++++++++---
 vllm/model_executor/models/solar.py            |  9 +++++++--
 vllm/model_executor/models/stablelm.py         |  8 ++++++--
 vllm/model_executor/models/starcoder2.py       |  8 ++++++--
 vllm/model_executor/models/ultravox.py         |  7 ++++---
 vllm/model_executor/models/utils.py            | 11 ++++++-----
 vllm/model_executor/models/xverse.py           |  8 ++++++--
 74 files changed, 454 insertions(+), 185 deletions(-)

diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py
index 0f8b81c3ef40c..d9ce85949e4ee 100644
--- a/vllm/model_executor/model_loader/loader.py
+++ b/vllm/model_executor/model_loader/loader.py
@@ -334,7 +334,17 @@ def load_model(self, vllm_config: VllmConfig) -> nn.Module:
             with target_device:
                 model = _initialize_model(vllm_config=vllm_config)
 
-            model.load_weights(self._get_all_weights(model_config, model))
+            weights_to_load = {name for name, _ in model.named_parameters()}
+            loaded_weights = model.load_weights(
+                self._get_all_weights(model_config, model))
+            # We currently only enable the strict check for
+            # non-quantized models that track their loaded weights.
+            if model_config.quantization is None and loaded_weights is not None:
+                weights_not_loaded = weights_to_load - loaded_weights
+                if weights_not_loaded:
+                    raise ValueError(
+                        "Following weights were not initialized from "
+                        f"checkpoint: {weights_not_loaded}")
 
             for _, module in model.named_modules():
                 quant_method = getattr(module, "quant_method", None)
diff --git a/vllm/model_executor/models/arctic.py b/vllm/model_executor/models/arctic.py
index d52418ee0f6f1..e58ad19cab54c 100644
--- a/vllm/model_executor/models/arctic.py
+++ b/vllm/model_executor/models/arctic.py
@@ -1,5 +1,5 @@
 """Inference-only Snowflake Arctic model."""
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -480,7 +480,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -518,6 +519,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                         ("ws", f"experts.{expert_id}.w3.weight", expert_id))
 
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
 
         logger.info(
             "It will take ~10 minutes loading from the 16-bit weights. "
@@ -573,3 +575,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                         weight_loader = getattr(param, "weight_loader",
                                                 default_weight_loader)
                         weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/baichuan.py b/vllm/model_executor/models/baichuan.py
index 01ce7c42cd391..3749a16a38994 100644
--- a/vllm/model_executor/models/baichuan.py
+++ b/vllm/model_executor/models/baichuan.py
@@ -18,7 +18,7 @@
 # limitations under the License.
 """Inference-only BaiChuan model compatible with HuggingFace weights."""
 import math
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -404,13 +404,15 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -449,6 +451,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
 
 
 class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
diff --git a/vllm/model_executor/models/bert.py b/vllm/model_executor/models/bert.py
index 42dd6119e76f1..d8301a36acb01 100644
--- a/vllm/model_executor/models/bert.py
+++ b/vllm/model_executor/models/bert.py
@@ -1,4 +1,4 @@
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Set, Tuple
 
 import torch
 from torch import nn
@@ -337,7 +337,8 @@ def forward(
 
         return self.encoder(hidden_states, kv_caches, attn_metadata)
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "query", "q"),
@@ -346,6 +347,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         ]
 
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "pooler" in name:
                 continue
@@ -368,6 +370,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
 
 
 class BertEmbeddingModel(nn.Module):
diff --git a/vllm/model_executor/models/blip.py b/vllm/model_executor/models/blip.py
index e612010677364..6db6462e97f3f 100644
--- a/vllm/model_executor/models/blip.py
+++ b/vllm/model_executor/models/blip.py
@@ -1,6 +1,6 @@
 """Minimal implementation of BlipVisionModel intended to be only used 
 within a vision language model."""
-from typing import Iterable, Optional, Tuple, Union
+from typing import Iterable, Optional, Set, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -415,7 +415,8 @@ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
 
         return self.post_layernorm(hidden_states)
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -423,6 +424,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("qkv_proj", "v_proj", "v"),
         ] if self.shard_weight else []
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         layer_count = len(self.encoder.layers)
 
         for name, loaded_weight in weights:
@@ -440,8 +442,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-
-                param = params_dict[name.replace(weight_name, param_name)]
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
@@ -450,3 +452,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/blip2.py b/vllm/model_executor/models/blip2.py
index 03dc1d15ab697..7d7639b4a92ce 100644
--- a/vllm/model_executor/models/blip2.py
+++ b/vllm/model_executor/models/blip2.py
@@ -1,5 +1,5 @@
 from functools import cached_property
-from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
+from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
                     TypedDict, Union)
 
 import torch
@@ -692,6 +692,7 @@ def sample(
     ) -> Optional[SamplerOutput]:
         return self.language_model.sample(logits, sampling_metadata)
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(self)
-        loader.load_weights(weights)
+        return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/bloom.py b/vllm/model_executor/models/bloom.py
index cf2eee8172769..1060d418474ef 100644
--- a/vllm/model_executor/models/bloom.py
+++ b/vllm/model_executor/models/bloom.py
@@ -16,7 +16,7 @@
 # limitations under the License.
 """Inference-only BLOOM model compatible with HuggingFace weights."""
 import math
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -341,8 +341,10 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if name == "lm_head.weight":
                 continue
@@ -371,3 +373,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/chameleon.py b/vllm/model_executor/models/chameleon.py
index 7b59c818e0b60..8f91abffaea90 100644
--- a/vllm/model_executor/models/chameleon.py
+++ b/vllm/model_executor/models/chameleon.py
@@ -1,5 +1,5 @@
 from functools import cached_property
-from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
+from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
                     Tuple, TypedDict, Union)
 
 import torch
@@ -1034,7 +1034,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             (".qkv_proj", ".q_proj", "q"),
@@ -1044,6 +1045,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             (".gate_up_proj", ".up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -1111,3 +1113,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py
index 70e9b607b0642..81e56381eabd8 100644
--- a/vllm/model_executor/models/chatglm.py
+++ b/vllm/model_executor/models/chatglm.py
@@ -3,7 +3,8 @@
 """Inference-only ChatGLM model compatible with THUDM weights."""
 from argparse import Namespace
 from array import array
-from typing import Dict, Iterable, List, Mapping, Optional, Tuple, TypedDict
+from typing import (Dict, Iterable, List, Mapping, Optional, Set, Tuple,
+                    TypedDict)
 
 import torch
 from PIL import Image
@@ -645,7 +646,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         # Merge two ColumnParallelLinear into one MergedColumnParallelLinear
         merged_weights_dict: Dict[str, Dict[str, Optional[torch.Tensor]]] = {
             "transformer.vision.linear_proj.merged_proj.weight": {
@@ -655,6 +657,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         }
 
         params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             is_weight_to_be_merge = False
             for _, merged_weight_dict in merged_weights_dict.items():
@@ -677,6 +680,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
+            loaded_params.add(name)
 
         for combined_name, merged_weight_dict in merged_weights_dict.items():
             if combined_name in params_dict:
@@ -686,3 +690,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, combined_weight)
+                loaded_params.add(combined_name)
+        return loaded_params
diff --git a/vllm/model_executor/models/clip.py b/vllm/model_executor/models/clip.py
index 2d81b9266826b..184758f4a8a45 100644
--- a/vllm/model_executor/models/clip.py
+++ b/vllm/model_executor/models/clip.py
@@ -1,6 +1,6 @@
 """Minimal implementation of CLIPVisionModel intended to be only used
 within a vision language model."""
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import numpy as np
 import torch
@@ -483,7 +483,8 @@ def device(self):
 
     # (TODO) Add prefix argument for filtering out weights to be loaded
     #        ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -491,6 +492,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("qkv_proj", "v_proj", "v"),
         ] if self.shard_weight else []
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         layer_count = len(self.vision_model.encoder.layers)
 
         for name, loaded_weight in weights:
@@ -508,8 +510,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
+                name = name.replace(weight_name, param_name)
 
-                param = params_dict[name.replace(weight_name, param_name)]
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
@@ -518,3 +521,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/commandr.py b/vllm/model_executor/models/commandr.py
index fbb09a64cde9b..9fd083e5a02a9 100644
--- a/vllm/model_executor/models/commandr.py
+++ b/vllm/model_executor/models/commandr.py
@@ -402,7 +402,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -447,3 +448,4 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
             loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/dbrx.py b/vllm/model_executor/models/dbrx.py
index 3952ff31e5cec..eab338800249e 100644
--- a/vllm/model_executor/models/dbrx.py
+++ b/vllm/model_executor/models/dbrx.py
@@ -1,4 +1,4 @@
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -417,13 +417,15 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
 
         expert_params_mapping = [(
             "w13_weight" if weight_name in ["w1", "v1"] else "w2_weight",
             f"mlp.{weight_name}",
         ) for weight_name in ["w1", "v1", "w2"]]
         params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             for param_name, weight_name in expert_params_mapping:
                 if weight_name not in name:
@@ -447,3 +449,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/decilm.py b/vllm/model_executor/models/decilm.py
index b38fd9fa49c21..c551853956b92 100644
--- a/vllm/model_executor/models/decilm.py
+++ b/vllm/model_executor/models/decilm.py
@@ -22,7 +22,7 @@
 # limitations under the License.
 """Inference-only DeciLM model compatible with HuggingFace weights."""
 
-from typing import Iterable, Tuple
+from typing import Iterable, Set, Tuple
 
 import torch
 
@@ -57,7 +57,8 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
         delattr(config, "num_key_value_heads_per_layer")
         super().__init__(vllm_config=vllm_config)
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -67,6 +68,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -97,6 +99,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
 
     def _degroup_weight(self, loaded_weight: torch.Tensor) -> torch.Tensor:
         hidden_size = self.config.hidden_size
diff --git a/vllm/model_executor/models/deepseek.py b/vllm/model_executor/models/deepseek.py
index 36dfea5a65656..8c5ad9904e925 100644
--- a/vllm/model_executor/models/deepseek.py
+++ b/vllm/model_executor/models/deepseek.py
@@ -20,7 +20,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Deepseek model."""
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -442,7 +442,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -453,6 +454,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         ]
 
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -487,3 +489,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/deepseek_v2.py b/vllm/model_executor/models/deepseek_v2.py
index 1e32fe60c7a5b..d2c4ca0bf85e9 100644
--- a/vllm/model_executor/models/deepseek_v2.py
+++ b/vllm/model_executor/models/deepseek_v2.py
@@ -20,7 +20,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only DeepseekV2 model."""
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -550,7 +550,8 @@ def make_empty_intermediate_tensors(
                         device=device),
         })
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("gate_up_proj", "gate_proj", 0),
@@ -566,6 +567,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             num_experts=self.config.n_routed_experts)
 
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -623,3 +625,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     weight_loader = getattr(param, "weight_loader",
                                             default_weight_loader)
                     weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/exaone.py b/vllm/model_executor/models/exaone.py
index 52dd603ca558d..9d739d0479548 100644
--- a/vllm/model_executor/models/exaone.py
+++ b/vllm/model_executor/models/exaone.py
@@ -22,7 +22,7 @@
 # limitations under the License.
 """Inference-only Exaone model compatible with HuggingFace weights."""
 
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -513,7 +513,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             (".qkv_proj", ".q_proj", "q"),
@@ -523,6 +524,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             (".gate_up_proj", ".c_fc_1", 1),
         ]
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -543,6 +545,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                                         default_weight_loader)
                 loaded_weight = loaded_weight[0]
                 weight_loader(param, loaded_weight)
+                loaded_params.add(scale_name)
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
@@ -576,6 +579,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
 
     # If this function is called, it should always initialize KV cache scale
     # factors (or else raise an exception). Thus, handled exceptions should
diff --git a/vllm/model_executor/models/falcon.py b/vllm/model_executor/models/falcon.py
index e97abe949ccdb..2aa4b67d99894 100644
--- a/vllm/model_executor/models/falcon.py
+++ b/vllm/model_executor/models/falcon.py
@@ -18,7 +18,7 @@
 """PyTorch Falcon model."""
 
 import math
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -473,7 +473,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         total_num_heads = self.config.num_attention_heads
         if self.config.new_decoder_architecture:
             total_num_kv_heads = self.config.num_kv_heads
@@ -483,6 +484,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             total_num_kv_heads = total_num_heads
         num_query_heads_per_kv_head = total_num_heads // total_num_kv_heads
         params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if name == "lm_head.weight" and self.tie_word_embeddings:
                 # Falcon uses tied embeddings except Falcon-11b.
@@ -519,3 +521,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/florence2.py b/vllm/model_executor/models/florence2.py
index 971a71180164b..d3a9ff6915b84 100644
--- a/vllm/model_executor/models/florence2.py
+++ b/vllm/model_executor/models/florence2.py
@@ -1,5 +1,5 @@
 import math
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Set, Tuple
 
 import torch
 import torch.nn as nn
@@ -156,7 +156,8 @@ def sample(self, logits: torch.Tensor,
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -165,12 +166,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         ]
 
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-
-                param = params_dict[name.replace(weight_name, param_name)]
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
@@ -183,6 +185,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
 
 
 class Florence2ForConditionalGeneration(nn.Module):
@@ -248,10 +252,11 @@ def sample(
     ) -> SamplerOutput:
         return self.language_model.sample(logits, sampling_metadata)
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         skip_prefixes = [
             'image_projection', "vision_tower", "image_proj_norm",
             "image_pos_embed", "visual_temporal_embed"
         ]
         loader = AutoWeightsLoader(self, skip_prefixes=skip_prefixes)
-        loader.load_weights(weights)
+        return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/fuyu.py b/vllm/model_executor/models/fuyu.py
index 31fc098a8bb3f..7b46907ac83ab 100644
--- a/vllm/model_executor/models/fuyu.py
+++ b/vllm/model_executor/models/fuyu.py
@@ -16,7 +16,8 @@
 """ PyTorch Fuyu model."""
 import math
 from array import array
-from typing import Iterable, List, Literal, Mapping, Optional, Tuple, TypedDict
+from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
+                    TypedDict)
 
 import torch
 import torch.nn as nn
@@ -354,6 +355,7 @@ def sample(
         next_tokens = self.language_model.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(self)
-        loader.load_weights(weights)
+        return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/gemma.py b/vllm/model_executor/models/gemma.py
index ace13664c6ea6..64e03b30bf2f1 100644
--- a/vllm/model_executor/models/gemma.py
+++ b/vllm/model_executor/models/gemma.py
@@ -424,7 +424,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -469,3 +470,4 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             logger.warning(
                 "Some weights are not initialized from checkpoints: %s",
                 unloaded_params)
+        return loaded_params
diff --git a/vllm/model_executor/models/gemma2.py b/vllm/model_executor/models/gemma2.py
index a60b4e73a76d4..4ba39223cc07f 100644
--- a/vllm/model_executor/models/gemma2.py
+++ b/vllm/model_executor/models/gemma2.py
@@ -312,7 +312,8 @@ def forward(
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -354,6 +355,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             logger.warning(
                 "Some weights are not initialized from checkpoints: %s",
                 unloaded_params)
+        return loaded_params
 
 
 class Gemma2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
@@ -451,13 +453,14 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(
             self,
             skip_prefixes=(["lm_head."]
                            if self.config.tie_word_embeddings else None),
         )
-        loader.load_weights(weights)
+        return loader.load_weights(weights)
 
 
 class Gemma2EmbeddingModel(nn.Module, SupportsPP):
diff --git a/vllm/model_executor/models/gpt2.py b/vllm/model_executor/models/gpt2.py
index fa0fdad28d161..1c61408ae1dd9 100644
--- a/vllm/model_executor/models/gpt2.py
+++ b/vllm/model_executor/models/gpt2.py
@@ -16,7 +16,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only GPT-2 model compatible with HuggingFace weights."""
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -298,8 +298,10 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "lm_head.weight" in name:
                 # GPT-2 ties the weights of the embedding layer and the final
@@ -328,3 +330,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/gpt_bigcode.py b/vllm/model_executor/models/gpt_bigcode.py
index b2fc79d0d36dc..50a143cb1b600 100644
--- a/vllm/model_executor/models/gpt_bigcode.py
+++ b/vllm/model_executor/models/gpt_bigcode.py
@@ -17,7 +17,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only GPTBigCode model compatible with HuggingFace weights."""
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -323,8 +323,10 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "lm_head.weight" in name:
                 continue
@@ -344,3 +346,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader(param, loaded_weight, 'v')
             else:
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/gpt_j.py b/vllm/model_executor/models/gpt_j.py
index cec3fd12a67d6..d5defc60764e6 100644
--- a/vllm/model_executor/models/gpt_j.py
+++ b/vllm/model_executor/models/gpt_j.py
@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only GPT-J model compatible with HuggingFace weights."""
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -291,7 +291,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -301,6 +302,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "attn.bias" in name or "attn.masked_bias" in name:
                 continue
@@ -330,3 +332,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/gpt_neox.py b/vllm/model_executor/models/gpt_neox.py
index 11f286d6bcba0..0bb5e2f9b95f9 100644
--- a/vllm/model_executor/models/gpt_neox.py
+++ b/vllm/model_executor/models/gpt_neox.py
@@ -15,7 +15,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only GPT-NeoX model compatible with HuggingFace weights."""
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -303,8 +303,10 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if ("attention.bias" in name or "attention.masked_bias" in name
                     or "rotary_emb.inv_freq" in name):
@@ -337,3 +339,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/granite.py b/vllm/model_executor/models/granite.py
index cb2583e69d88d..c1e2e87f08ec3 100644
--- a/vllm/model_executor/models/granite.py
+++ b/vllm/model_executor/models/granite.py
@@ -20,7 +20,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only IBM Granite model compatible with HuggingFace weights."""
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -455,7 +455,8 @@ def make_empty_intermediate_tensors(
                         device=device),
         })
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             (".qkv_proj", ".q_proj", "q"),
@@ -465,6 +466,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             (".gate_up_proj", ".up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -485,6 +487,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                                         default_weight_loader)
                 loaded_weight = loaded_weight[0]
                 weight_loader(param, loaded_weight)
+                loaded_params.add(scale_name)
                 continue
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
@@ -518,6 +521,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
 
     # If this function is called, it should always initialize KV cache scale
     # factors (or else raise an exception). Thus, handled exceptions should
diff --git a/vllm/model_executor/models/granitemoe.py b/vllm/model_executor/models/granitemoe.py
index f437dd521a7d5..a91a18816995f 100644
--- a/vllm/model_executor/models/granitemoe.py
+++ b/vllm/model_executor/models/granitemoe.py
@@ -20,7 +20,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only GraniteMoe model."""
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Set, Tuple
 
 import torch
 from torch import nn
@@ -419,7 +419,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         new_weights = {}
         for n, p in weights:
             if n.endswith('.block_sparse_moe.input_linear.weight'):
@@ -452,4 +453,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 pass
             else:
                 new_weights[n] = p
-        mixtral.MixtralForCausalLM.load_weights(self, new_weights.items())
+        return mixtral.MixtralForCausalLM.load_weights(self,
+                                                       new_weights.items())
diff --git a/vllm/model_executor/models/idefics2_vision_model.py b/vllm/model_executor/models/idefics2_vision_model.py
index b21bc2a3f9ce1..16192928beb1f 100644
--- a/vllm/model_executor/models/idefics2_vision_model.py
+++ b/vllm/model_executor/models/idefics2_vision_model.py
@@ -15,7 +15,7 @@
 # limitations under the License.
 """PyTorch Idefics2 model."""
 
-from typing import Iterable, Optional, Tuple
+from typing import Iterable, Optional, Set, Tuple
 
 import torch
 from torch import nn
@@ -331,7 +331,8 @@ def forward(
         last_hidden_state = self.post_layernorm(encoder_outputs)
         return last_hidden_state
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -339,11 +340,13 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("qkv_proj", "v_proj", "v"),
         ]
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-                param = params_dict[name.replace(weight_name, param_name)]
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
@@ -352,3 +355,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/idefics3.py b/vllm/model_executor/models/idefics3.py
index 0cecc754e916f..5d176b2a4e416 100644
--- a/vllm/model_executor/models/idefics3.py
+++ b/vllm/model_executor/models/idefics3.py
@@ -15,7 +15,7 @@
 
 import math
 from typing import (Dict, Iterable, List, Literal, Mapping, NamedTuple,
-                    Optional, Tuple, TypedDict, Union)
+                    Optional, Set, Tuple, TypedDict, Union)
 
 import torch
 import torch.utils.checkpoint
@@ -751,9 +751,10 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(self)
-        loader.load_weights(weights)
+        return loader.load_weights(weights)
 
     def get_mm_mapping(self) -> MultiModelKeys:
         """
diff --git a/vllm/model_executor/models/intern_vit.py b/vllm/model_executor/models/intern_vit.py
index 9761635d2a6c2..bd91a0806ae5c 100644
--- a/vllm/model_executor/models/intern_vit.py
+++ b/vllm/model_executor/models/intern_vit.py
@@ -5,7 +5,7 @@
 # Licensed under The MIT License [see LICENSE for details]
 # --------------------------------------------------------
 from functools import partial
-from typing import Iterable, Optional, Tuple
+from typing import Iterable, Optional, Set, Tuple
 
 import torch
 import torch.nn as nn
@@ -469,10 +469,14 @@ def forward(
 
         return encoder_outputs
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             param = params_dict[name]
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/internlm2.py b/vllm/model_executor/models/internlm2.py
index 19bfe16e4d5fc..94b819b5d9366 100644
--- a/vllm/model_executor/models/internlm2.py
+++ b/vllm/model_executor/models/internlm2.py
@@ -1,5 +1,5 @@
 from functools import partial
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -369,13 +369,15 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("gate_up_proj", "w1", 0),
             ("gate_up_proj", "w3", 1),
         ]
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -402,3 +404,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/internvl.py b/vllm/model_executor/models/internvl.py
index 92579e3aae949..7ea2f9be2191d 100644
--- a/vllm/model_executor/models/internvl.py
+++ b/vllm/model_executor/models/internvl.py
@@ -6,7 +6,7 @@
 # --------------------------------------------------------
 import re
 from functools import cached_property, partial
-from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
+from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
                     TypedDict, Union)
 
 import torch
@@ -663,6 +663,7 @@ def sample(
     ) -> Optional[SamplerOutput]:
         return self.language_model.sample(logits, sampling_metadata)
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(self)
-        loader.load_weights(weights)
+        return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/jais.py b/vllm/model_executor/models/jais.py
index ee49ffb3cd87f..41db85b678456 100644
--- a/vllm/model_executor/models/jais.py
+++ b/vllm/model_executor/models/jais.py
@@ -19,7 +19,7 @@
 """Inference-only Jais model compatible with HuggingFace weights."""
 
 import math
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -350,8 +350,10 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "lm_head.weight" in name:
                 # GPT-2 ties the weights of the embedding layer and the final
@@ -382,3 +384,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py
index 5612dd6886385..f83f0fce7275f 100644
--- a/vllm/model_executor/models/jamba.py
+++ b/vllm/model_executor/models/jamba.py
@@ -1,5 +1,5 @@
 """Inference-only Jamba model."""
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Set, Tuple
 
 import torch
 from torch import nn
@@ -462,7 +462,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -479,6 +480,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             num_experts=self.config.num_experts)
 
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -534,6 +536,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     weight_loader = getattr(param, "weight_loader",
                                             default_weight_loader)
                     weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
 
 
 def _is_moe_layer(name: str):
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index e53631ef19f31..2b40e9ec73fad 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -20,7 +20,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only LLaMA model compatible with HuggingFace weights."""
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -350,7 +350,8 @@ def forward(
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             (".qkv_proj", ".q_proj", "q"),
@@ -360,6 +361,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             (".gate_up_proj", ".up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -375,6 +377,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                                         default_weight_loader)
                 loaded_weight = loaded_weight[0]
                 weight_loader(param, loaded_weight)
+                loaded_params.add(scale_name)
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
@@ -390,7 +393,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
-
                 break
             else:
                 # Skip loading extra bias for GPTQ models.
@@ -408,6 +410,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
 
     # If this function is called, it should always initialize KV cache scale
     # factors (or else raise an exception). Thus, handled exceptions should
@@ -577,13 +581,14 @@ def sample(self, logits: torch.Tensor,
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(
             self,
             skip_prefixes=(["lm_head."]
                            if self.config.tie_word_embeddings else None),
         )
-        loader.load_weights(
+        return loader.load_weights(
             self.maybe_remap_mistral(name, loaded_weight)
             for name, loaded_weight in weights)
 
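Two details in the llama hunks are worth calling out. First, the KV-cache scale branch records scale_name, the key that actually exists in params_dict after remapping, rather than the raw checkpoint name; the tracking only stays accurate if every path adds the post-remap name. Second, with tied word embeddings the model never materializes a separate lm_head parameter, so skip_prefixes=["lm_head."] drops those checkpoint entries before they can be loaded or reported. A condensed sketch of the remap-then-record rule (remap_name is a hypothetical stand-in for maybe_remap_kv_scale_name and friends):

    for name, loaded_weight in weights:
        name = remap_name(name)          # checkpoint naming -> vLLM naming
        param = params_dict[name]        # the lookup uses the remapped key...
        weight_loader = getattr(param, "weight_loader", default_weight_loader)
        weight_loader(param, loaded_weight)
        loaded_params.add(name)          # ...and so does the tracking
    return loaded_params
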
diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py
index b13bcfa676811..e7d3161a7cb2d 100644
--- a/vllm/model_executor/models/llava.py
+++ b/vllm/model_executor/models/llava.py
@@ -1,5 +1,5 @@
 from functools import cached_property
-from typing import (Iterable, List, Literal, Mapping, Optional, Protocol,
+from typing import (Iterable, List, Literal, Mapping, Optional, Protocol, Set,
                     Tuple, TypedDict, Union)
 
 import torch
@@ -547,6 +547,7 @@ def sample(
     ) -> Optional[SamplerOutput]:
         return self.language_model.sample(logits, sampling_metadata)
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(self)
-        loader.load_weights(weights)
+        return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/llava_next.py b/vllm/model_executor/models/llava_next.py
index dd2fa6cac969f..37e2227a52dcd 100644
--- a/vllm/model_executor/models/llava_next.py
+++ b/vllm/model_executor/models/llava_next.py
@@ -1,5 +1,5 @@
 from functools import cached_property
-from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
+from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
                     TypedDict, Union)
 
 import torch
@@ -654,6 +654,7 @@ def pooler(
     ) -> Optional[PoolerOutput]:
         return self._pooler(hidden_states, pooling_metadata)
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(self)
-        loader.load_weights(weights)
+        return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/llava_next_video.py b/vllm/model_executor/models/llava_next_video.py
index 5d5598d07bfde..e2880c76cf43d 100644
--- a/vllm/model_executor/models/llava_next_video.py
+++ b/vllm/model_executor/models/llava_next_video.py
@@ -1,6 +1,6 @@
 import math
 from functools import cached_property
-from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
+from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
                     TypedDict, Union)
 
 import numpy as np
@@ -445,10 +445,11 @@ def sample(
     ) -> Optional[SamplerOutput]:
         return self.language_model.sample(logits, sampling_metadata)
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(
             self,
             # This model doesn't support images for now
             ignore_unexpected_prefixes=["image_newline"],
         )
-        loader.load_weights(weights)
+        return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/llava_onevision.py b/vllm/model_executor/models/llava_onevision.py
index a5b2108177830..705ca1e4ab6e6 100644
--- a/vllm/model_executor/models/llava_onevision.py
+++ b/vllm/model_executor/models/llava_onevision.py
@@ -1,6 +1,6 @@
 import math
 from functools import cached_property
-from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
+from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
                     TypedDict, Union)
 
 import numpy as np
@@ -887,6 +887,7 @@ def sample(
     ) -> Optional[SamplerOutput]:
         return self.language_model.sample(logits, sampling_metadata)
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(self)
-        loader.load_weights(weights)
+        return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py
index ac0d265a961f0..405b8f7787ba8 100644
--- a/vllm/model_executor/models/mamba.py
+++ b/vllm/model_executor/models/mamba.py
@@ -1,5 +1,5 @@
 """PyTorch MAMBA model."""
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Set, Tuple
 
 import torch
 from torch import nn
@@ -243,8 +243,10 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "A_log" in name:
                 name = name.replace("A_log", "A")
@@ -256,3 +258,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/medusa.py b/vllm/model_executor/models/medusa.py
index b05360b55466b..b4ed6538bddac 100644
--- a/vllm/model_executor/models/medusa.py
+++ b/vllm/model_executor/models/medusa.py
@@ -1,4 +1,4 @@
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Set, Tuple
 
 import torch
 import torch.nn as nn
@@ -156,8 +156,10 @@ def generate_proposals(
             sampling_metadata=sampling_metadata,
         )
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
 
         weights_map = {}
 
@@ -181,9 +183,12 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
+            loaded_params.add(name)
 
         if self.token_map is not None:
             self.token_map.to(device=self.lm_heads[0].weight.device)
 
         assert (self.truncated_vocab_size
                 == self.orig_vocab_size) or (self.token_map is not None)
+
+        return loaded_params
diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py
index 6b67266c53362..b92bff4d7c28c 100644
--- a/vllm/model_executor/models/minicpm.py
+++ b/vllm/model_executor/models/minicpm.py
@@ -21,7 +21,7 @@
 # limitations under the License.
 """Inference-only MiniCPM model compatible with HuggingFace weights."""
 import math
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -539,7 +539,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -556,6 +557,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             for weight_name in ["w1", "w2", "w3"]
         ]
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -606,3 +608,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     weight_loader = getattr(param, "weight_loader",
                                             default_weight_loader)
                     weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/minicpmv.py b/vllm/model_executor/models/minicpmv.py
index fd8eda997f76f..99bf1d42d0355 100644
--- a/vllm/model_executor/models/minicpmv.py
+++ b/vllm/model_executor/models/minicpmv.py
@@ -24,7 +24,7 @@
 import re
 from functools import partial
 from typing import (Any, Callable, Iterable, List, Literal, Mapping, Optional,
-                    Tuple, TypedDict, Union)
+                    Set, Tuple, TypedDict, Union)
 
 import torch
 import torch.types
@@ -602,7 +602,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -612,6 +613,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             for key_to_modify, new_key in _KEYS_TO_MODIFY_MAPPING.items():
                 if key_to_modify in name:
@@ -630,10 +632,10 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 for param_name, weight_name, shard_id in stacked_params_mapping:
                     if weight_name not in name:
                         continue
-                    if is_pp_missing_parameter(
-                            name.replace(weight_name, param_name), self):
+                    name = name.replace(weight_name, param_name)
+                    if is_pp_missing_parameter(name, self):
                         continue
-                    param = params_dict[name.replace(weight_name, param_name)]
+                    param = params_dict[name]
                     weight_loader = param.weight_loader
                     weight_loader(param, loaded_weight, shard_id)
                     break
@@ -646,6 +648,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
 
     def get_mm_mapping(self) -> MultiModelKeys:
         """
diff --git a/vllm/model_executor/models/mixtral.py b/vllm/model_executor/models/mixtral.py
index eebf5bab5a288..0faffb4f1b00c 100644
--- a/vllm/model_executor/models/mixtral.py
+++ b/vllm/model_executor/models/mixtral.py
@@ -20,7 +20,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Mixtral model."""
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -404,7 +404,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -421,6 +422,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             num_experts=self.config.num_local_experts)
 
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -478,3 +480,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     weight_loader = getattr(param, "weight_loader",
                                             default_weight_loader)
                     weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
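
The MoE variants (jamba above, mixtral here, and olmoe, qwen2_moe, and phimoe below) share one loop shape: every path that declines a weight ends in continue, so a single loaded_params.add(name) at the bottom of the loop body is reached only by weights that some branch actually consumed, whether the stacked-shard, expert-shard, or default loader. A runnable toy of that shape (stand-in skip conditions, a plain dict instead of weight loaders):

    from typing import Dict, Iterable, Set, Tuple

    def load_names(weights: Iterable[Tuple[str, int]],
                   params: Dict[str, int]) -> Set[str]:
        loaded: Set[str] = set()
        for name, w in weights:
            if name.endswith(".inv_freq"):   # stand-in for the skip branches
                continue
            if name not in params:           # e.g. extra GPTQ bias
                continue
            params[name] = w                 # stand-in for weight_loader(...)
            loaded.add(name)                 # reached only when consumed
        return loaded

    print(load_names([("rotary_emb.inv_freq", 0), ("w1", 1), ("w9", 2)],
                     {"w1": 0}))             # -> {'w1'}
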
diff --git a/vllm/model_executor/models/mixtral_quant.py b/vllm/model_executor/models/mixtral_quant.py
index af2e9586988df..ddd6afcf6a1b6 100644
--- a/vllm/model_executor/models/mixtral_quant.py
+++ b/vllm/model_executor/models/mixtral_quant.py
@@ -20,7 +20,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Mixtral model."""
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import numpy as np
 import torch
@@ -409,7 +409,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -418,6 +419,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         ]
 
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -448,3 +450,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/mllama.py b/vllm/model_executor/models/mllama.py
index db7ee7b2d8537..41f62b37f3bd9 100644
--- a/vllm/model_executor/models/mllama.py
+++ b/vllm/model_executor/models/mllama.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 """PyTorch Mllama model."""
 import math
-from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
+from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
                     TypedDict, Union)
 
 import numpy as np
@@ -1427,7 +1427,8 @@ def forward(
 
         return outputs
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             (".qkv_proj", ".q_proj", "q"),
@@ -1437,7 +1438,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             (".gate_up_proj", ".up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
-        updated_params = set()
+        updated_params: Set[str] = set()
         for name, loaded_weight in weights:
             if 'patch_embedding.weight' in name:
                 name = name.replace('patch_embedding.weight',
@@ -1457,6 +1458,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+                updated_params.add(name)
+        return updated_params
 
 
 def skip_attention_mask(sparse_mask: List[List[int]]) -> bool:
diff --git a/vllm/model_executor/models/mlp_speculator.py b/vllm/model_executor/models/mlp_speculator.py
index 4d7e82880041d..f2aa2653c4f5c 100644
--- a/vllm/model_executor/models/mlp_speculator.py
+++ b/vllm/model_executor/models/mlp_speculator.py
@@ -1,5 +1,5 @@
 import math
-from typing import Iterable, List, Tuple
+from typing import Iterable, List, Set, Tuple
 
 import torch
 import torch.nn as nn
@@ -188,11 +188,15 @@ def generate_proposals(
 
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             param = params_dict.get(name.replace("speculator.", ""))
             if param is not None:
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+                loaded_params.add(name.replace("speculator.", ""))
+        return loaded_params
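
mlp_speculator is the permissive outlier: params_dict.get silently drops checkpoint entries with no matching parameter instead of raising, and with the "speculator." prefix stripped before recording, the returned names line up with the params_dict keys. In isolation (rebinding the name up front for clarity):

    params_dict = dict(self.named_parameters())
    loaded_params: Set[str] = set()
    for name, loaded_weight in weights:
        name = name.replace("speculator.", "")  # strip the checkpoint prefix
        param = params_dict.get(name)           # None -> silently skipped
        if param is not None:
            default_weight_loader(param, loaded_weight)
            loaded_params.add(name)             # matches params_dict keys
    return loaded_params
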
diff --git a/vllm/model_executor/models/mpt.py b/vllm/model_executor/models/mpt.py
index 3c74ef2448abb..8716e92b0f1c2 100644
--- a/vllm/model_executor/models/mpt.py
+++ b/vllm/model_executor/models/mpt.py
@@ -1,6 +1,6 @@
 # Adapted from https://huggingface.co/mosaicml/mpt-7b/tree/main
 import math
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -324,8 +324,10 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             # Skip loading extra bias for GPTQ models.
             if name.endswith(".bias") and name not in params_dict:
@@ -336,3 +338,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/nemotron.py b/vllm/model_executor/models/nemotron.py
index eb45beae7d21a..ceab299a7950a 100644
--- a/vllm/model_executor/models/nemotron.py
+++ b/vllm/model_executor/models/nemotron.py
@@ -20,7 +20,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Nemotron model compatible with HuggingFace weights."""
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -474,7 +474,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             (".qkv_proj", ".q_proj", "q"),
@@ -482,6 +483,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             (".qkv_proj", ".v_proj", "v"),
         ]
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -522,3 +524,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/olmo.py b/vllm/model_executor/models/olmo.py
index 98d4e1ec320a4..dc138e2e636ad 100644
--- a/vllm/model_executor/models/olmo.py
+++ b/vllm/model_executor/models/olmo.py
@@ -20,7 +20,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only OLMo model compatible with HuggingFace weights."""
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -356,7 +356,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -366,6 +367,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -402,3 +404,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/olmoe.py b/vllm/model_executor/models/olmoe.py
index f4eebab8c98dd..ab87695d8e650 100644
--- a/vllm/model_executor/models/olmoe.py
+++ b/vllm/model_executor/models/olmoe.py
@@ -10,7 +10,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only OLMoE model compatible with HuggingFace weights."""
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -364,7 +364,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -383,6 +384,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             num_experts=self.config.num_experts)
 
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -455,3 +457,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     weight_loader = getattr(param, "weight_loader",
                                             default_weight_loader)
                     weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/opt.py b/vllm/model_executor/models/opt.py
index 997fe642439e6..db85a494980a7 100644
--- a/vllm/model_executor/models/opt.py
+++ b/vllm/model_executor/models/opt.py
@@ -16,7 +16,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only OPT model compatible with HuggingFace weights."""
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -394,7 +394,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -402,6 +403,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("qkv_proj", "v_proj", "v"),
         ]
         params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "lm_head.weight" in name and self.config.tie_word_embeddings:
                 continue
@@ -431,3 +433,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/orion.py b/vllm/model_executor/models/orion.py
index 39d659c49cbcf..b01734af8ddd8 100644
--- a/vllm/model_executor/models/orion.py
+++ b/vllm/model_executor/models/orion.py
@@ -3,7 +3,7 @@
 # Copyright (c) OrionStar Inc.
 # LICENSE: https://huggingface.co/OrionStarAI/Orion-14B-Base/blob/main/LICENSE
 """Inference-only Orion-14B model compatible with HuggingFace weights."""
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -327,7 +327,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -337,6 +338,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -368,3 +370,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/paligemma.py b/vllm/model_executor/models/paligemma.py
index eea229359255e..dd5256eb87ab3 100644
--- a/vllm/model_executor/models/paligemma.py
+++ b/vllm/model_executor/models/paligemma.py
@@ -1,4 +1,4 @@
-from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
+from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
                     TypedDict, Union)
 
 import torch
@@ -295,6 +295,7 @@ def sample(
     ) -> Optional[SamplerOutput]:
         return self.language_model.sample(logits, sampling_metadata)
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(self)
-        loader.load_weights(weights)
+        return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/persimmon.py b/vllm/model_executor/models/persimmon.py
index 62c509153a111..3b8199f4f1661 100644
--- a/vllm/model_executor/models/persimmon.py
+++ b/vllm/model_executor/models/persimmon.py
@@ -19,7 +19,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only persimmon model compatible with HuggingFace weights."""
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -324,8 +324,10 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -358,3 +360,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/phi.py b/vllm/model_executor/models/phi.py
index a2ab0d74c48db..0a117bf16c9b3 100644
--- a/vllm/model_executor/models/phi.py
+++ b/vllm/model_executor/models/phi.py
@@ -34,7 +34,7 @@
 # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 """Inference-only Phi-1.5 model compatible with HuggingFace weights."""
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -345,7 +345,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -353,6 +354,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("qkv_proj", "v_proj", "v")
         ]
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
 
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
@@ -383,3 +385,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/phi3_small.py b/vllm/model_executor/models/phi3_small.py
index 2139cec441807..a78e4d355a314 100644
--- a/vllm/model_executor/models/phi3_small.py
+++ b/vllm/model_executor/models/phi3_small.py
@@ -1,5 +1,5 @@
 import math
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -457,9 +457,11 @@ def sample(
                                    sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
 
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -471,3 +473,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             weight_loader = getattr(param, "weight_loader",
                                     default_weight_loader)
             weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/phi3v.py b/vllm/model_executor/models/phi3v.py
index 4db65edc174f1..2e583bb08e87a 100644
--- a/vllm/model_executor/models/phi3v.py
+++ b/vllm/model_executor/models/phi3v.py
@@ -15,7 +15,7 @@
 import itertools
 import re
 from functools import cached_property, lru_cache
-from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional,
+from typing import (Any, Dict, Iterable, List, Literal, Mapping, Optional, Set,
                     Tuple, TypedDict, Union)
 
 import numpy as np
@@ -744,7 +744,8 @@ def pooler(
     ) -> Optional[PoolerOutput]:
         return self._pooler(hidden_states, pooling_metadata)
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         hf_to_vllm_mapper = WeightsMapper(
             orig_to_new_prefix={
                 "model.vision_embed_tokens.wte": "embed_tokens",
@@ -759,5 +760,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
 
         # The HF config doesn't specify whether these are tied,
         # so we detect it this way
-        if "embed_tokens" not in autoloaded_weights:
+        if "embed_tokens.weight" not in autoloaded_weights:
             self.embed_tokens = self.language_model.model.embed_tokens
+            autoloaded_weights.add("embed_tokens.weight")
+        return autoloaded_weights
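
phi3v is the first hunk that consumes the return value rather than merely producing it: whether "embed_tokens.weight" came back in the set signals whether the checkpoint shipped a standalone embedding. When it did not, the embedding is aliased to the language model's and the name is re-added so a downstream uninitialized-params check stays quiet. The idiom in isolation, mirroring the code above:

    loaded = loader.load_weights(weights, mapper=hf_to_vllm_mapper)
    if "embed_tokens.weight" not in loaded:
        # No standalone embedding in the checkpoint: it is tied to the LM's.
        self.embed_tokens = self.language_model.model.embed_tokens
        loaded.add("embed_tokens.weight")  # aliased, so report it as loaded
    return loaded
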
diff --git a/vllm/model_executor/models/phimoe.py b/vllm/model_executor/models/phimoe.py
index b7e70f8fa2c6d..e475d286bd7ea 100644
--- a/vllm/model_executor/models/phimoe.py
+++ b/vllm/model_executor/models/phimoe.py
@@ -20,7 +20,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only PhiMoE model."""
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -598,7 +598,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -613,6 +614,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             num_experts=self.config.num_local_experts)
 
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -666,3 +668,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     weight_loader = getattr(param, "weight_loader",
                                             default_weight_loader)
                     weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py
index a3e30ea2dd299..307febde7eef0 100644
--- a/vllm/model_executor/models/pixtral.py
+++ b/vllm/model_executor/models/pixtral.py
@@ -1,7 +1,7 @@
 from dataclasses import dataclass, fields
 from functools import cached_property
 from itertools import tee
-from typing import Iterable, List, Mapping, Optional, Tuple, Union
+from typing import Iterable, List, Mapping, Optional, Set, Tuple, Union
 
 import numpy
 import torch
@@ -1053,7 +1053,8 @@ def forward(
 
     # (TODO) Add prefix argument for filtering out weights to be loaded
     #        ref: https://github.com/vllm-project/vllm/pull/7186#discussion_r1734163986
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             (".qkv_proj", ".q_proj", "q"),
@@ -1063,6 +1064,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             (".gate_up_proj", ".up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         layer_count = len(self.transformer.layers)
 
         for name, loaded_weight in weights:
@@ -1075,8 +1077,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
-
-                param = params_dict[name.replace(weight_name, param_name)]
+                name = name.replace(weight_name, param_name)
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
@@ -1085,3 +1087,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/qwen.py b/vllm/model_executor/models/qwen.py
index 447632cefcd9a..3978c176a2144 100644
--- a/vllm/model_executor/models/qwen.py
+++ b/vllm/model_executor/models/qwen.py
@@ -8,7 +8,7 @@
 import re
 from functools import partial
 from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
-                    Optional, Tuple, TypedDict, Union)
+                    Optional, Set, Tuple, TypedDict, Union)
 
 import numpy as np
 import torch
@@ -964,13 +964,15 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("gate_up_proj", "w2", 0),
             ("gate_up_proj", "w1", 1),
         ]
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -999,6 +1001,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
 
 
 class QWenLLM(QWenBaseModel):
diff --git a/vllm/model_executor/models/qwen2.py b/vllm/model_executor/models/qwen2.py
index 8f10df808c216..370cff5fa153f 100644
--- a/vllm/model_executor/models/qwen2.py
+++ b/vllm/model_executor/models/qwen2.py
@@ -21,7 +21,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Qwen2 model compatible with HuggingFace weights."""
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -332,7 +332,8 @@ def forward(
         hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -342,6 +343,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -372,6 +374,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
 
 
 class Qwen2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
@@ -494,13 +498,14 @@ def pooler(
     ) -> Optional[PoolerOutput]:
         return self._pooler(hidden_states, pooling_metadata)
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(
             self,
             skip_prefixes=(["lm_head."]
                            if self.config.tie_word_embeddings else None),
         )
-        loader.load_weights(weights)
+        return loader.load_weights(weights)
 
 
 class Qwen2EmbeddingModel(nn.Module, SupportsLoRA, SupportsPP):
@@ -564,7 +569,8 @@ def pooler(
     ) -> Optional[PoolerOutput]:
         return self._pooler(hidden_states, pooling_metadata)
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(self,
                                    ignore_unexpected_prefixes=["lm_head."])
-        loader.load_weights(weights)
+        return loader.load_weights(weights)
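
qwen2.py exercises both AutoWeightsLoader filters in one file: the causal LM passes skip_prefixes for the tied lm_head., while the embedding model passes ignore_unexpected_prefixes for an lm_head. it never defines. A runnable toy of the assumed distinction (a hedged reading of loader.py, not the loader itself): skipped prefixes are filtered out entirely, while ignored prefixes merely suppress the unexpected-weight error.

    from typing import Dict, Iterable, Sequence, Set, Tuple

    def load(weights: Iterable[Tuple[str, int]], params: Dict[str, int],
             skip: Sequence[str] = (), ignore: Sequence[str] = ()) -> Set[str]:
        loaded: Set[str] = set()
        for name, w in weights:
            if any(name.startswith(p) for p in skip):
                continue                     # dropped before loading
            if name not in params:
                if any(name.startswith(p) for p in ignore):
                    continue                 # tolerated, not an error
                raise ValueError(f"unexpected weight: {name}")
            params[name] = w
            loaded.add(name)
        return loaded

    stream = [("lm_head.weight", 0), ("norm.weight", 1)]
    print(load(stream, {"norm.weight": 0}, skip=["lm_head."]))    # {'norm.weight'}
    print(load(stream, {"norm.weight": 0}, ignore=["lm_head."]))  # {'norm.weight'}
    # The difference surfaces when the parameter exists: skip still drops it,
    # ignore would load it normally.
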
diff --git a/vllm/model_executor/models/qwen2_audio.py b/vllm/model_executor/models/qwen2_audio.py
index d30950361ad89..a4965f34b1ca8 100644
--- a/vllm/model_executor/models/qwen2_audio.py
+++ b/vllm/model_executor/models/qwen2_audio.py
@@ -20,7 +20,8 @@
 # limitations under the License.
 """Inference-only Qwen2-Audio model compatible with HuggingFace weights."""
 from functools import lru_cache
-from typing import Iterable, List, Mapping, Optional, Tuple, TypedDict, Union
+from typing import (Iterable, List, Mapping, Optional, Set, Tuple, TypedDict,
+                    Union)
 
 import librosa
 import numpy as np
@@ -420,7 +421,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -430,6 +432,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -463,3 +466,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/qwen2_cls.py b/vllm/model_executor/models/qwen2_cls.py
index 07eb330620a43..dc5dabf6fc38b 100644
--- a/vllm/model_executor/models/qwen2_cls.py
+++ b/vllm/model_executor/models/qwen2_cls.py
@@ -4,7 +4,7 @@
 # Copyright 2024 The Qwen team.
 # Copyright 2023 The vLLM team.
 """Inference-only Qwen2-Classification model compatible with HF weights."""
-from typing import Iterable, List, Optional, Tuple
+from typing import Iterable, List, Optional, Set, Tuple
 
 import torch
 from torch import nn
@@ -97,7 +97,8 @@ def pooler(
     ) -> Optional[PoolerOutput]:
         return self._pooler(hidden_states, pooling_metadata)
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(self,
                                    ignore_unexpected_prefixes=["lm_head."])
-        loader.load_weights(weights)
+        return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/qwen2_moe.py b/vllm/model_executor/models/qwen2_moe.py
index 249d94b5d95e9..96a9bc451f4df 100644
--- a/vllm/model_executor/models/qwen2_moe.py
+++ b/vllm/model_executor/models/qwen2_moe.py
@@ -21,7 +21,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Qwen2MoE model compatible with HuggingFace weights."""
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 import torch.nn.functional as F
@@ -436,7 +436,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -455,6 +456,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             num_experts=self.config.num_experts)
 
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -532,3 +534,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                     weight_loader = getattr(param, "weight_loader",
                                             default_weight_loader)
                     weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/qwen2_rm.py b/vllm/model_executor/models/qwen2_rm.py
index 6db467af334f5..988d682d36be3 100644
--- a/vllm/model_executor/models/qwen2_rm.py
+++ b/vllm/model_executor/models/qwen2_rm.py
@@ -3,7 +3,7 @@
 # Copyright 2024 The Qwen team.
 # Copyright 2023 The vLLM team.
 """Inference-only Qwen2-RM model compatible with HuggingFace weights."""
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -110,7 +110,8 @@ def pooler(
     ) -> Optional[PoolerOutput]:
         return self._pooler(hidden_states, pooling_metadata)
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         loader = AutoWeightsLoader(self,
                                    ignore_unexpected_prefixes=["lm_head."])
-        loader.load_weights(weights)
+        return loader.load_weights(weights)
diff --git a/vllm/model_executor/models/qwen2_vl.py b/vllm/model_executor/models/qwen2_vl.py
index 2335baf459771..ef6b52db6e17d 100644
--- a/vllm/model_executor/models/qwen2_vl.py
+++ b/vllm/model_executor/models/qwen2_vl.py
@@ -23,7 +23,7 @@
 """Inference-only Qwen2-VL model compatible with HuggingFace weights."""
 from functools import partial
 from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
-                    Optional, Tuple, Type, TypedDict, Union)
+                    Optional, Set, Tuple, Type, TypedDict, Union)
 
 import torch
 import torch.nn as nn
@@ -1333,7 +1333,8 @@ def pooler(
     ) -> Optional[PoolerOutput]:
         return self._pooler(hidden_states, pooling_metadata)
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -1343,6 +1344,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("gate_up_proj", "gate_proj", 0),
         ]
         params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -1392,3 +1394,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
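
Models with hand-written loading loops, such as Qwen2-VL above, follow a
three-step recipe: create an empty set next to params_dict, add the
(possibly remapped) name after each successful weight_loader call, and
return the set. Names that are skipped (e.g. rotary_emb.inv_freq) are
deliberately never recorded. A condensed sketch of the loop under those
assumptions, written as a free function over a generic module:

    from typing import Iterable, Set, Tuple

    import torch
    from torch import nn

    from vllm.model_executor.model_loader.weight_utils import (
        default_weight_loader)


    def load_weights(model: nn.Module,
                     weights: Iterable[Tuple[str, torch.Tensor]]
                     ) -> Set[str]:
        params_dict = dict(model.named_parameters())
        loaded_params: Set[str] = set()
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue  # skipped tensors are *not* recorded as loaded
            param = params_dict[name]
            weight_loader = getattr(param, "weight_loader",
                                    default_weight_loader)
            weight_loader(param, loaded_weight)
            loaded_params.add(name)  # record only after a successful load
        return loaded_params
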
diff --git a/vllm/model_executor/models/siglip.py b/vllm/model_executor/models/siglip.py
index acaf4afdecfe5..c9e09b879843a 100644
--- a/vllm/model_executor/models/siglip.py
+++ b/vllm/model_executor/models/siglip.py
@@ -2,7 +2,7 @@
 within a vision language model."""
 
 import math
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import numpy as np
 import torch
@@ -594,7 +594,8 @@ def forward(
             interpolate_pos_encoding=interpolate_pos_encoding,
         )
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -602,6 +603,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("qkv_proj", "v_proj", "v"),
         ] if self.shard_weight else []
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         layer_count = len(self.vision_model.encoder.layers)
 
         for name, loaded_weight in weights:
@@ -619,8 +621,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             for (param_name, weight_name, shard_id) in stacked_params_mapping:
                 if weight_name not in name:
                     continue
+                name = name.replace(weight_name, param_name)
 
-                param = params_dict[name.replace(weight_name, param_name)]
+                param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
@@ -629,3 +632,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
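
The siglip hunk also fixes a subtle naming issue: the checkpoint shard name
is rewritten to the fused vLLM name *before* the params_dict lookup, so the
same string that indexed the dict is what lands in loaded_params. Recording
the checkpoint-side name would make the returned set disagree with
named_parameters(). The solar hunks below apply the same rule to remapped
KV-cache scale names (loaded_params.add(scale_name) before the continue).
A tiny illustration of the ordering:

    loaded_params: set = set()
    checkpoint_name = "vision_model.encoder.layers.0.self_attn.q_proj.weight"

    # Rewrite first, then use the rewritten name for *both* the lookup key
    # and the tracking set, so the set matches named_parameters().
    name = checkpoint_name.replace("q_proj", "qkv_proj")
    loaded_params.add(name)

    print(loaded_params)
    # {'vision_model.encoder.layers.0.self_attn.qkv_proj.weight'}
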
diff --git a/vllm/model_executor/models/solar.py b/vllm/model_executor/models/solar.py
index affb2c975ce4a..6d6fafc5ab0eb 100644
--- a/vllm/model_executor/models/solar.py
+++ b/vllm/model_executor/models/solar.py
@@ -21,7 +21,7 @@
 # limitations under the License.
 """Inference-only Solar model compatible with HuggingFace weights."""
 
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -477,7 +477,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             (".qkv_proj", ".q_proj", "q"),
@@ -487,6 +488,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             (".gate_up_proj", ".up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -502,6 +504,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                                         default_weight_loader)
                 loaded_weight = loaded_weight[0]
                 weight_loader(param, loaded_weight)
+                loaded_params.add(scale_name)
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
@@ -535,6 +538,8 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
 
     # If this function is called, it should always initialize KV cache scale
     # factors (or else raise an exception). Thus, handled exceptions should
diff --git a/vllm/model_executor/models/stablelm.py b/vllm/model_executor/models/stablelm.py
index 99acce596602e..e11d2e916730a 100644
--- a/vllm/model_executor/models/stablelm.py
+++ b/vllm/model_executor/models/stablelm.py
@@ -18,7 +18,7 @@
 # https://huggingface.co/stabilityai/stablelm-3b-4e1t/blob/main/config.json
 """Inference-only StabeLM (https://github.com/Stability-AI/StableLM)
 model compatible with HuggingFace weights."""
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -306,7 +306,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -316,6 +317,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -347,3 +349,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
diff --git a/vllm/model_executor/models/starcoder2.py b/vllm/model_executor/models/starcoder2.py
index 0ef940acebb93..74c66042226de 100644
--- a/vllm/model_executor/models/starcoder2.py
+++ b/vllm/model_executor/models/starcoder2.py
@@ -17,7 +17,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """ PyTorch Starcoder2 model."""
-from typing import Iterable, List, Optional, Tuple, Union
+from typing import Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -314,7 +314,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
             ("qkv_proj", "q_proj", "q"),
@@ -323,6 +324,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         ]
 
         params_dict = dict(self.named_parameters(remove_duplicate=False))
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name:
                 continue
@@ -346,3 +348,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params
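
Starcoder2 (like Qwen2-VL earlier) builds params_dict with
remove_duplicate=False. That matters for tied weights: with deduplication
on, PyTorch reports a shared tensor under only the first name it meets, so
a checkpoint entry for the alias would miss the lookup and never be
tracked. A small self-contained demonstration of the PyTorch behavior this
relies on:

    from torch import nn

    # Tie an output head to an embedding, as decoder-only LMs often do.
    embed = nn.Embedding(8, 4)
    head = nn.Linear(4, 8, bias=False)
    head.weight = embed.weight

    model = nn.Sequential()
    model.add_module("embed", embed)
    model.add_module("head", head)

    dedup = dict(model.named_parameters()).keys()
    full = dict(model.named_parameters(remove_duplicate=False)).keys()
    assert "head.weight" not in dedup  # alias hidden by deduplication
    assert "head.weight" in full       # visible with remove_duplicate=False
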
diff --git a/vllm/model_executor/models/ultravox.py b/vllm/model_executor/models/ultravox.py
index 9fde22c016de0..512adbc7db35e 100644
--- a/vllm/model_executor/models/ultravox.py
+++ b/vllm/model_executor/models/ultravox.py
@@ -3,7 +3,7 @@
 
 import math
 from functools import cached_property, lru_cache
-from typing import (Iterable, List, Literal, Mapping, Optional, Tuple,
+from typing import (Iterable, List, Literal, Mapping, Optional, Set, Tuple,
                     TypedDict, Union, cast)
 
 import numpy as np
@@ -504,10 +504,11 @@ def sample(
     ) -> Optional[SamplerOutput]:
         return self.language_model.sample(logits, sampling_metadata)
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         hf_to_vllm_mapper = WeightsMapper(
             orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."})
 
         loader = AutoWeightsLoader(self,
                                    ignore_unexpected_prefixes=["audio_tower."])
-        loader.load_weights(weights, mapper=hf_to_vllm_mapper)
+        return loader.load_weights(weights, mapper=hf_to_vllm_mapper)
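
In the Ultravox case a WeightsMapper rewrites checkpoint prefixes before
any loading happens, so the names that come back are the post-mapping
(vLLM-side) ones. A toy end-to-end sketch, assuming only the utilities
shown in this patch (the Toy module and the fabricated weight are
illustrative, not part of the PR):

    import torch
    from torch import nn

    from vllm.model_executor.models.utils import (AutoWeightsLoader,
                                                  WeightsMapper)


    class Toy(nn.Module):

        def __init__(self):
            super().__init__()
            self.audio_tower = nn.Linear(4, 4, bias=False)


    model = Toy()
    hf_weights = [("audio_tower.model.encoder.weight", torch.zeros(4, 4))]

    mapper = WeightsMapper(
        orig_to_new_prefix={"audio_tower.model.encoder.": "audio_tower."})
    loader = AutoWeightsLoader(model)
    loaded = loader.load_weights(hf_weights, mapper=mapper)
    print(loaded)  # {'audio_tower.weight'} -- post-mapping names
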
diff --git a/vllm/model_executor/models/utils.py b/vllm/model_executor/models/utils.py
index 1d51885f9094a..7a4fcce95603d 100644
--- a/vllm/model_executor/models/utils.py
+++ b/vllm/model_executor/models/utils.py
@@ -1,7 +1,7 @@
 import itertools
 from dataclasses import dataclass, field
 from typing import (Any, Callable, Dict, Iterable, List, Literal, Mapping,
-                    Optional, Protocol, Tuple, Union, overload)
+                    Optional, Protocol, Set, Tuple, Union, overload)
 
 import torch
 import torch.nn as nn
@@ -172,8 +172,9 @@ def _load_module(
         if module != self.module:
             module_load_weights = getattr(module, "load_weights", None)
             if callable(module_load_weights):
-                module_load_weights(weights)
-                return
+                loaded_params = module_load_weights(weights)
+                yield from map(lambda x: self._get_qualname(base_prefix, x),
+                               loaded_params)
 
         child_modules = dict(module.named_children())
         child_params = dict(module.named_parameters(recurse=False))
@@ -222,11 +223,11 @@ def load_weights(
         weights: Iterable[Tuple[str, torch.Tensor]],
         *,
         mapper: Optional[WeightsMapper] = None,
-    ) -> List[str]:
+    ) -> Set[str]:
         if mapper is not None:
             weights = mapper.apply(weights)
 
-        autoloaded_weights = list(self._load_module("", self.module, weights))
+        autoloaded_weights = set(self._load_module("", self.module, weights))
         return autoloaded_weights
 
 
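The utils.py hunks are the heart of the patch. _load_module now yields the
qualified names reported by a submodule's own load_weights instead of
dropping them (the delegation branch no longer returns early, so execution
falls through to the per-child handling, which then sees only whatever
remains unconsumed in the weights iterator), and load_weights collects
everything into a set: membership, not order, is what callers care about,
hence Set over List. The point of returning the set is that it can be
diffed against the model's parameters to flag anything a checkpoint never
touched. The loader.py hunk at the top of this patch performs a check of
roughly this shape; a minimal sketch (the function name is illustrative,
not from this patch):

    from typing import Set

    from torch import nn


    def raise_on_uninitialized(model: nn.Module,
                               loaded_params: Set[str]) -> None:
        """Flag parameters that no checkpoint weight ever initialized."""
        all_params = {name for name, _ in model.named_parameters()}
        uninitialized = all_params - loaded_params
        if uninitialized:
            raise ValueError("Following weights were not initialized from "
                             f"checkpoint: {uninitialized}")
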
diff --git a/vllm/model_executor/models/xverse.py b/vllm/model_executor/models/xverse.py
index 51172d8782a70..bc37a997eabb5 100644
--- a/vllm/model_executor/models/xverse.py
+++ b/vllm/model_executor/models/xverse.py
@@ -19,7 +19,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only Xverse model compatible with HuggingFace weights."""
-from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
 from torch import nn
@@ -376,7 +376,8 @@ def sample(
         next_tokens = self.sampler(logits, sampling_metadata)
         return next_tokens
 
-    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+    def load_weights(self, weights: Iterable[Tuple[str,
+                                                   torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             ("qkv_proj", "q_proj", "q"),
             ("qkv_proj", "k_proj", "k"),
@@ -385,6 +386,7 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("gate_up_proj", "up_proj", 1),
         ]
         params_dict = dict(self.named_parameters())
+        loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
             if ("rotary_emb.inv_freq" in name
                     or "rotary_emb.cos_cached" in name
@@ -413,3 +415,5 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
                 weight_loader = getattr(param, "weight_loader",
                                         default_weight_loader)
                 weight_loader(param, loaded_weight)
+            loaded_params.add(name)
+        return loaded_params