Skip to content

Commit

Permalink
Add real BS & seq_len to profiling (#601)
Browse files Browse the repository at this point in the history
Add real batch size and sequence length to high-level profiling.
  • Loading branch information
kamil-kaczor authored Feb 11, 2025
1 parent 669706e commit e8f66d5
Showing 1 changed file with 12 additions and 3 deletions.
15 changes: 12 additions & 3 deletions vllm/worker/hpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2291,7 +2291,14 @@ def try_revert_dummy_output_tokens():
self.trim_attn_metadata(
broadcast_data["attn_metadata"])
})
with self.profiler.record_event('internal', model_event_name):
profiler_args = {
'real_seq_len': model_input.seq_lens,
'real_batch_size': real_batch_size
}

with self.profiler.record_event('internal',
model_event_name,
args=profiler_args):
hidden_states = self.model.forward(
**execute_model_kwargs,
selected_token_indices=sampling_metadata.
Expand All @@ -2308,7 +2315,8 @@ def try_revert_dummy_output_tokens():
('compute_logits_'
f'{"prompt" if is_prompt else "decode"}_bs'
f'{batch_size}_'
f'seq{seq_len}')):
f'seq{seq_len}'),
args=profiler_args):
if num_steps == 1:
sampling_metadata.selected_token_indices = None
logits = self.model.compute_logits(hidden_states,
Expand All @@ -2325,7 +2333,8 @@ def try_revert_dummy_output_tokens():
'internal', ('sample_'
f'{"prompt" if is_prompt else "decode"}_'
f'bs{batch_size}_'
f'seq{seq_len}')):
f'seq{seq_len}'),
args=profiler_args):
output = self.model.sample(
logits=logits,
sampling_metadata=sampling_metadata,
Expand Down

0 comments on commit e8f66d5

Please sign in to comment.