diff --git a/python/sglang/srt/models/olmo2.py b/python/sglang/srt/models/olmo2.py
index f3e1979f849..a8af7bc1a4e 100644
--- a/python/sglang/srt/models/olmo2.py
+++ b/python/sglang/srt/models/olmo2.py
@@ -64,24 +64,24 @@ def __init__(
         super().__init__()
         self.config = config
         self.hidden_size = config.hidden_size
-        tp_size = get_tensor_model_parallel_world_size()
+        self.tp_size = get_tensor_model_parallel_world_size()
         self.total_num_heads = config.num_attention_heads

         assert self.hidden_size % self.total_num_heads == 0
-        assert self.total_num_heads % tp_size == 0
+        assert self.total_num_heads % self.tp_size == 0

-        self.num_heads = self.total_num_heads // tp_size
+        self.num_heads = self.total_num_heads // self.tp_size
         self.total_num_kv_heads = self.config.num_key_value_heads
-        if self.total_num_kv_heads >= tp_size:
+        if self.total_num_kv_heads >= self.tp_size:
             # Number of KV heads is greater than TP size, so we partition
             # the KV heads across multiple tensor parallel GPUs.
-            assert self.total_num_kv_heads % tp_size == 0
+            assert self.total_num_kv_heads % self.tp_size == 0
         else:
             # Number of KV heads is less than TP size, so we replicate
             # the KV heads across multiple tensor parallel GPUs.
-            assert tp_size % self.total_num_kv_heads == 0
-        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+            assert self.tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // self.tp_size)
         self.head_dim = self.hidden_size // self.total_num_heads
         self.max_position_embeddings = config.max_position_embeddings
@@ -343,7 +343,7 @@ def forward(
             input_embeds=input_embeds,
         )
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head.weight, forward_batch
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
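For context on the first hunk, the KV-head partition/replication arithmetic it touches can be sketched standalone. This is a minimal illustration, not part of the patch; the helper name `split_heads` and the example head counts are hypothetical:

```python
# Standalone sketch of the KV-head arithmetic from the hunk above.
# Function name and example values are hypothetical, for illustration only.
def split_heads(total_num_heads: int, total_num_kv_heads: int, tp_size: int):
    assert total_num_heads % tp_size == 0
    num_heads = total_num_heads // tp_size
    if total_num_kv_heads >= tp_size:
        # More KV heads than TP ranks: partition the KV heads across GPUs.
        assert total_num_kv_heads % tp_size == 0
    else:
        # Fewer KV heads than TP ranks: replicate each KV head on
        # tp_size // total_num_kv_heads GPUs.
        assert tp_size % total_num_kv_heads == 0
    # When total_num_kv_heads < tp_size the floor division yields 0,
    # so max(1, ...) gives each rank one (replicated) KV head.
    num_kv_heads = max(1, total_num_kv_heads // tp_size)
    return num_heads, num_kv_heads

# e.g. 32 query heads and 8 KV heads on 16 TP ranks:
# each rank gets 2 query heads and one replicated KV head.
print(split_heads(32, 8, 16))  # -> (2, 1)
```

The second hunk passes the `lm_head` module itself to `logits_processor`, rather than its `.weight` tensor, which appears to align this model with the module-based `LogitsProcessor` call used by other models in the repository.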