[misc] Layerwise profile updates (vllm-project#10242)
Signed-off-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
varun-sundar-rabindranath and Varun Sundar Rabindranath authored Dec 16, 2024
1 parent 2ca830d commit efbce85
Showing 5 changed files with 314 additions and 47 deletions.
2 changes: 1 addition & 1 deletion .buildkite/test-pipeline.yaml
@@ -201,7 +201,7 @@ steps:
- python3 offline_inference_classification.py
- python3 offline_inference_embedding.py
- python3 offline_inference_scoring.py
- python3 offline_profile.py --model facebook/opt-125m
- python3 offline_profile.py --model facebook/opt-125m run_num_steps --num-steps 2

- label: Prefix Caching Test # 9min
mirror_hardwares: [amd]
236 changes: 206 additions & 30 deletions examples/offline_profile.py
@@ -4,9 +4,10 @@
import sys
from argparse import RawTextHelpFormatter
from dataclasses import asdict, dataclass
from typing import Optional
from typing import Any, Dict, Generator, List, Optional, TypeAlias

import torch
import tqdm

from vllm import LLM, SamplingParams
from vllm.engine.arg_utils import EngineArgs
@@ -15,16 +16,21 @@

BATCH_SIZE_DEFAULT = 1
PROMPT_LEN_DEFAULT = 256
OUTPUT_LEN_DEFAULT = 2


@dataclass
class ProfileContext:
engine_args: EngineArgs
prompt_len: int
output_len: int
batch_size: int
save_chrome_traces_folder: Optional[str]

# The profiler can run in 2 modes:
# 1. Run the profiler for a user-specified number of steps
num_steps: Optional[int] = None
# 2. Run the profiler until all requests complete
complete_num_requests_per_step: Optional[int] = None

save_chrome_traces_folder: Optional[str] = None
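
For reference, a minimal sketch (not part of this commit) of how the two modes map onto the new fields; the model and the other values are placeholders:

```
# Illustrative only: the two mutually exclusive profiling configurations.
from vllm.engine.arg_utils import EngineArgs

# Mode 1: profile a fixed number of engine steps.
ctx_num_steps = ProfileContext(
    engine_args=EngineArgs(model="facebook/opt-125m"),
    prompt_len=256,
    batch_size=4,
    num_steps=2)

# Mode 2: run until all requests complete, retiring 2 requests per decode step.
ctx_to_completion = ProfileContext(
    engine_args=EngineArgs(model="facebook/opt-125m"),
    prompt_len=256,
    batch_size=4,
    complete_num_requests_per_step=2)
```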


def get_dtype(dtype: str):
@@ -34,23 +40,155 @@ def get_dtype(dtype: str):
return dtype


OutputLen_NumReqs_Map: TypeAlias = Dict[int, int]


def compute_request_output_lengths(batch_size: int, step_requests: List[int]) \
-> OutputLen_NumReqs_Map:
"""
Given the number of requests, batch_size, and the number of requests
that each engine-step should process, step_requests, determine the
output lengths of the requests such that step_requests is honoured.
Example:
if batch_size = 128 and step_requests = [128, 128, 96, 64, 32, 1]
then return,
{2 : 32, 3 : 32, 4 : 32, 5 : 31, 6 : 1}, meaning,
32 requests should have output length 2,
32 requests should have output length 3,
32 requests should have output length 4,
31 requests should have output length 5,
1 request should have output length 6.
Args:
batch_size (int): Number of requests submitted for profile. This is
args.batch_size.
step_requests (List[int]): step_requests[i] is the number of requests
that the ith engine step should process.
Returns:
OutputLen_NumReqs_Map : A dictionary with output-length as keys and the
number of requests required to have that output-length as values.
"""
ol_nr: OutputLen_NumReqs_Map = {}

# Number of requests that have been assigned an output-length
num_reqs_assigned: int = 0
num_steps: int = len(step_requests)

# Sanity check. The first step (the prefill step) must process all requests.
assert step_requests[0] == batch_size

# Begin assignments from the last step.
output_length: int = num_steps
for num_requests_at_step in reversed(step_requests):
if num_reqs_assigned == batch_size:
break

assert num_reqs_assigned < batch_size

# Exclude the requests that have already been assigned an output
# length; they participate in this step and beyond.
num_reqs_unassigned_at_step = num_requests_at_step - num_reqs_assigned
assert num_reqs_unassigned_at_step >= 0

if num_reqs_unassigned_at_step > 0:
ol_nr[output_length] = num_reqs_unassigned_at_step
num_reqs_assigned += num_reqs_unassigned_at_step

output_length -= 1

# sanity checks.
assert sum(ol_nr.values()) == batch_size, \
("Number of requests in output-length assignment does not match "
f"batch-size.\n batch size {batch_size} - "
f"step requests {step_requests} - assignments {ol_nr}")

# Check that the output-length is in [1, num-steps]. Output length must be
# at least 1 as all requests must participate in the prefill-step.
assert all(ol >= 1 and ol <= num_steps for ol in ol_nr), \
("Output lengths of requests should be in range "
f"[1, num-engine-steps].\n batch size {batch_size} - "
f"step requests {step_requests} - assignments {ol_nr}")

return ol_nr
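
As a sanity check of the docstring example above, here is a small standalone sketch (not part of the commit) that mirrors the assignment logic and reproduces the expected map:

```
# Standalone sketch: reproduce the docstring example for
# compute_request_output_lengths (illustrative, not from the commit).
from typing import Dict, List


def expected_output_lengths(batch_size: int,
                            step_requests: List[int]) -> Dict[int, int]:
    ol_nr: Dict[int, int] = {}
    assigned = 0
    output_length = len(step_requests)
    # Walk the steps backwards, assigning the longest output lengths first.
    for num_requests_at_step in reversed(step_requests):
        if assigned == batch_size:
            break
        unassigned = num_requests_at_step - assigned
        if unassigned > 0:
            ol_nr[output_length] = unassigned
            assigned += unassigned
        output_length -= 1
    return ol_nr


# 32 requests stop after 2, 3 and 4 steps, 31 after 5, and 1 runs for all 6.
assert expected_output_lengths(128, [128, 128, 96, 64, 32, 1]) == \
    {6: 1, 5: 31, 4: 32, 3: 32, 2: 32}
```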


def determine_requests_per_step(context: ProfileContext) -> List[int]:
"""
Determine number of requests each engine step should process.
If context.num_steps is set, then all engine steps process the
same number of requests and the output list is of length
context.num_steps.
If context.complete_num_requests_per_step is set, then each decode step
processes fewer and fewer requests until there are no requests to process.
In this case, the output list is as big as the number of steps
required to process all requests.
Args:
context: ProfileContext object.
Returns:
List[int]: Number of requests to process for each engine step.
output[i] contains the number of requests that the ith step
should process.
"""
if context.num_steps:
# All requests must run until num_engine_steps. This implies
# that their output lengths must be equal to num_engine_steps.
return [context.batch_size] * context.num_steps

assert context.complete_num_requests_per_step and \
context.complete_num_requests_per_step > 0, \
(f"Expected a positive complete_num_requests_per_step argument. "
f"Instead got {context.complete_num_requests_per_step}")

# We start dropping after the first decode step.
step_requests = [
context.batch_size, # prefill
context.batch_size, # decode
]

num_running_requests = context.batch_size
num_running_requests -= context.complete_num_requests_per_step
while num_running_requests > 0:
step_requests.append(num_running_requests)
num_running_requests -= context.complete_num_requests_per_step

if step_requests[-1] != 1:
# Have exactly one request running at the last step. This is
# often useful.
step_requests.append(1)

return step_requests
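
For reference, a standalone sketch (not part of the commit) of the step schedule this function produces in run_to_completion mode:

```
# Standalone sketch: the step schedule for run_to_completion mode
# (illustrative, not from the commit).
from typing import List


def expected_step_requests(batch_size: int,
                           completed_per_step: int) -> List[int]:
    steps = [batch_size, batch_size]  # prefill step, first decode step
    running = batch_size - completed_per_step
    while running > 0:
        steps.append(running)
        running -= completed_per_step
    if steps[-1] != 1:
        steps.append(1)  # tack on a final single-request step
    return steps


# Matches the example given in the CLI help text further below.
assert expected_step_requests(128, 32) == [128, 128, 96, 64, 32, 1]
```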


def run_profile(context: ProfileContext, csv_output: Optional[str],
json_output: Optional[str]):
print("Run profile with:")
for key, value in asdict(context).items():
print(f" {key} = {value}")

requests_per_step: List[int] = determine_requests_per_step(context)

ol_nr: OutputLen_NumReqs_Map = compute_request_output_lengths(
context.batch_size, requests_per_step)

num_steps_to_profile: int = len(requests_per_step)
max_output_len: int = max(ol_nr.keys())
assert max_output_len >= 1

# Create sampling params
sampling_params = SamplingParams(temperature=0.8,
top_p=0.95,
max_tokens=args.output_len,
ignore_eos=True)
sampling_params = SamplingParams(
temperature=0.8,
top_p=0.95,
# max_tokens is set on a per-request basis.
max_tokens=None,
ignore_eos=True)

# Create LLM
llm = LLM(**asdict(context.engine_args))
batch_size = context.batch_size
prompt_len = context.prompt_len
output_len = context.output_len

scheduler_config = llm.llm_engine.scheduler_config
max_model_len = llm.llm_engine.model_config.max_model_len
@@ -65,24 +203,34 @@ def run_profile(context: ProfileContext, csv_output: Optional[str],
f"choose a smaller batch size or prompt length, or increase "
f"--max-num-batched-tokens")
sys.exit(-1)
if batch_size >= max_num_seqs:
if batch_size > max_num_seqs:
print(
f"ERROR: chosen batch_size ({batch_size}) is larger than "
f"max_num_seqs ({max_num_seqs}) and therefore cannot be run in a "
f"single profile step, please choose a smaller batch size")
sys.exit(-1)
print("llm.llm_engine.model_config.max_model_len: ",
llm.llm_engine.model_config.max_model_len)
if prompt_len + output_len > llm.llm_engine.model_config.max_model_len:
print(
f"ERROR: chosen prompt_len + output_len ({prompt_len} + "
f"{output_len} = {prompt_len + output_len}) is larger than the "
f"model's max_model_len ({max_model_len}), please choose a smaller "
f"prompt_len or output_len, or increase --max-model-len")
if prompt_len + max_output_len > llm.llm_engine.model_config.max_model_len:
print(f"ERROR: chosen prompt_len + max_output_len ({prompt_len} + "
f"{max_output_len} = {prompt_len + max_output_len}) is larger "
f"than the model's max_model_len ({max_model_len}), please "
f"choose a smaller prompt_len or max_output_len, or increase "
f"--max-model-len")
sys.exit(-1)

def add_requests():

def get_output_len_generator() -> Generator[int, Any, Any]:
for output_len, num_reqs in ol_nr.items():
for _ in range(num_reqs):
yield output_len

output_len_generator = get_output_len_generator()
for i in range(batch_size):
sampling_params.max_tokens = next(output_len_generator)
assert isinstance(sampling_params.max_tokens, int)

prompt_token_ids = torch.randint(
llm.llm_engine.model_config.get_vocab_size(),
size=(prompt_len, )).tolist()
@@ -110,8 +258,11 @@ def abort_requests():
llm.llm_engine.step() # First step is prefill

decode_profs = []
for x in range(args.output_len - 1):
with layerwise_profile() as decode_prof:
for _ in tqdm.tqdm(range(num_steps_to_profile - 1)):
num_running_seqs = llm.llm_engine.scheduler[
0].get_num_unfinished_seq_groups()
with layerwise_profile(
num_running_seqs=num_running_seqs) as decode_prof:
llm.llm_engine.step()
decode_profs.append(decode_prof)

@@ -154,7 +305,8 @@ def abort_requests():
decode_results_list[0].print_summary_table()

if csv_output:
csv_filename_base = csv_output.rstrip(".csv")
csv_filename_base = csv_output[:-4] \
if csv_output.endswith('.csv') else csv_output
prefill_results.export_model_stats_table_csv(
csv_filename_base + "_prefill_model_table.csv")
prefill_results.export_summary_stats_table_csv(
@@ -187,10 +339,10 @@ def abort_requests():
for idx, dr in enumerate(decode_results_list):
json_dict[f"decode_{idx + 1}"] = dr.convert_stats_to_dict()

for idx, dr in enumerate(decode_results_list[1:]):
json_dict[f"decode_{idx + 1}"] = dr.convert_stats_to_dict()

with open(json_output.rstrip(".json") + ".json", "w+") as f:
# Add .json to json_output filename if it doesn't exist already.
json_output_file = json_output if json_output.endswith(
'.json') else json_output + '.json'
with open(json_output_file, "w+") as f:
json.dump(json_dict, f, indent=2)
pass

@@ -214,7 +366,7 @@ def abort_requests():
python examples/offline_profile.py \\
--model neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 --batch-size 4 \\
--prompt-len 512 --max-num-batched-tokens 8196 --json Llama31-8b-FP8 \\
--enforce-eager
--enforce-eager run_num_steps -n 2
```
then you can use various tools to analyze the json output
@@ -261,17 +413,41 @@ def abort_requests():
default=BATCH_SIZE_DEFAULT,
help=f"Number of requests to run as a single batch, "
f"default={BATCH_SIZE_DEFAULT}")
parser.add_argument(
"--output-len",

subparsers = parser.add_subparsers(dest="cmd")

run_num_steps_parser = subparsers.add_parser(
"run_num_steps",
help="This variation profiles n engine.step() invocations.")
run_num_steps_parser.add_argument(
'-n',
'--num-steps',
type=int,
default=OUTPUT_LEN_DEFAULT,
help="Number of llm steps to run (includes prefill and decode) "
"- default={OUTPUT_LEN_DEFAULT}")
help="Number of engine steps to profile.\n"
"Setting it to 1, profiles only the prefill step.\n"
"Setting it to 2, profiles the prefill and first decode step\n"
"Setting it to 3, profiles the prefill, 1st and 2nd decode steps\n"
"and so on ...")

run_to_completion_parser = subparsers.add_parser(
"run_to_completion",
help="This variation profiles all the engine.step() invocations"
"until the engine exhausts all submitted requests.")
run_to_completion_parser.add_argument(
'-n',
'--complete-num-requests-per-step',
type=int,
help=
"Complete complete_num_requests_per_step requests every decode step. "
"For example, with batch_size 128 and complete_num_requests_per_step 32, "
"the profiler is run for 6 engine steps, with the steps processing "
"128, 128, 96, 64, 32, 1 requests respectively.\n"
"Note that we tack on a one-request step at the end, as it is often "
"useful.")

EngineArgs.add_cli_args(parser)

args = parser.parse_args()

context = ProfileContext(
engine_args=EngineArgs.from_cli_args(args),
**{
9 changes: 7 additions & 2 deletions tools/profiler/print_layerwise_table.py
@@ -34,9 +34,10 @@ def get_entries(node, curr_depth=0):
"examples/offline_profile.py")
parser.add_argument("--phase",
type=str,
choices=["prefill", "decode_1"],
required=True,
help="The phase to print the table for.")
help="The phase to print the table for. This is either"
"prefill or decode_n, where n is the decode step "
"number")
parser.add_argument("--table",
type=str,
choices=["summary", "model"],
@@ -49,6 +50,10 @@ def get_entries(node, curr_depth=0):
with open(args.json_trace) as f:
profile_data = json.load(f)

assert args.phase in profile_data, \
(f"Cannot find phase {args.phase} in profile data. Choose one among"
f'{[x for x in profile_data.keys() if "prefill" in x or "decode" in x]}') #noqa

if args.table == "summary":
entries_and_depths = flatten_entries(
SummaryStatsEntry, profile_data[args.phase]["summary_stats"])
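
The resulting JSON trace can then be inspected with this tool. A hedged example, assuming the run_to_completion invocation above wrote its trace to Llama31-8b-FP8.json:

```
python tools/profiler/print_layerwise_table.py \
    --json-trace Llama31-8b-FP8.json --phase decode_1 --table summary
```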