diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index 97aae233db105..44f47fac1c1b3 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -201,7 +201,7 @@ steps: - python3 offline_inference_classification.py - python3 offline_inference_embedding.py - python3 offline_inference_scoring.py - - python3 offline_profile.py --model facebook/opt-125m + - python3 offline_profile.py --model facebook/opt-125m run_num_steps --num-steps 2 - label: Prefix Caching Test # 9min mirror_hardwares: [amd] diff --git a/examples/offline_profile.py b/examples/offline_profile.py index 1d415b82cddb6..46afe8aa2604b 100644 --- a/examples/offline_profile.py +++ b/examples/offline_profile.py @@ -4,9 +4,10 @@ import sys from argparse import RawTextHelpFormatter from dataclasses import asdict, dataclass -from typing import Optional +from typing import Any, Dict, Generator, List, Optional, TypeAlias import torch +import tqdm from vllm import LLM, SamplingParams from vllm.engine.arg_utils import EngineArgs @@ -15,16 +16,21 @@ BATCH_SIZE_DEFAULT = 1 PROMPT_LEN_DEFAULT = 256 -OUTPUT_LEN_DEFAULT = 2 @dataclass class ProfileContext: engine_args: EngineArgs prompt_len: int - output_len: int batch_size: int - save_chrome_traces_folder: Optional[str] + + # The profiler can run in 2 modes, + # 1. Run profiler for user specified num_steps + num_steps: Optional[int] = None + # 2. Run profiler until all requests complete + complete_num_requests_per_step: Optional[int] = None + + save_chrome_traces_folder: Optional[str] = None def get_dtype(dtype: str): @@ -34,23 +40,155 @@ def get_dtype(dtype: str): return dtype +OutputLen_NumReqs_Map: TypeAlias = Dict[int, int] +def compute_request_output_lengths(batch_size: int, step_requests: List[int]) \ + -> OutputLen_NumReqs_Map: + """ + Given the number of requests, batch_size, and the number of requests + that each engine-step should process, step_requests, determine the + output lengths of the requests such that step_request is honoured. + + Example: + if batch size = 128 and step_request = [128, 128, 96, 64, 32, 1] + then return, + {2 : 32, 3 : 32, 4 : 32, 5 : 31, 6 : 1}, meaning, + 32 requests should have output length 2, + 32 requests should have output length 3, + 32 requests should have output length 4, + 31 requests should have output length 5, + 1 request should have output length 6. + + Args: + batch_size (int): Number of requests submitted for profile. This is + args.batch_size. + step_requests (List[int]): step_requests[i] is the number of requests + that the ith engine step should process. + + Returns: + OutputLen_NumReqs_Map : A dictionary with output-length as keys and the + number of requests required to have that output-length as values. + """ + ol_nr: OutputLen_NumReqs_Map = {} + + # Number of request that are assigned an output-length + num_reqs_assigned: int = 0 + num_steps: int = len(step_requests) + + # sanity check. The first step (prefill-step), must process all requests. + assert step_requests[0] == batch_size + + # Begin assignments from the last step. + output_length: int = num_steps + for num_requests_at_step in reversed(step_requests): + if num_reqs_assigned == batch_size: + break + + assert num_reqs_assigned < batch_size + + # Remove the number of requests that have been determined + # to participate in this step and beyond. + num_reqs_unassigned_at_step = num_requests_at_step - num_reqs_assigned + assert num_reqs_unassigned_at_step >= 0 + + if num_reqs_unassigned_at_step > 0: + ol_nr[output_length] = num_reqs_unassigned_at_step + num_reqs_assigned += num_reqs_unassigned_at_step + + output_length -= 1 + + # sanity checks. + assert sum(ol_nr.values()) == batch_size, \ + ("Number of requests in output-length assignment does not match " + f"batch-size.\n batch size {batch_size} - " + f"step requests {step_requests} - assignments {ol_nr}") + + # Check that the output-length is in [1, num-steps]. Output length must be + # at least 1 as all requests must participate in the prefill-step. + assert all(ol >= 1 and ol <= num_steps for ol in ol_nr), \ + ("Output lengths of requests should be in range " + f"[1, num-engine-steps].\n batch size {batch_size} - " + f"step requests {step_requests} - assignments {ol_nr}") + + return ol_nr + + +def determine_requests_per_step(context: ProfileContext) -> List[int]: + """ + Determine number of requests each engine step should process. + If context.num_steps is set, then all engine steps process the + same number of requests and the output list is of length + context.num_steps. + + If context.complete_num_requests_per_step is set, then each decode step + processes fewer and fewer requests until there are no requests to process. + In this case, the output list is as big as the number of steps + required to process all requests. + + Args: + context: ProfileContext object. + + Returns: + List[int]: Number of requests to process for all engine-steps. + output[i], contains the number of requests that the ith step + should process. + """ + if context.num_steps: + # All requests must run until num_engine_steps. This implies + # that their output lengths must be equal to num_engine_steps. + return [context.batch_size] * context.num_steps + + assert context.complete_num_requests_per_step and \ + context.complete_num_requests_per_step > 0, \ + (f"Expected a positive complete_num_requests_per_step argument." + f"Instead got {context.complete_num_requests_per_step}") + + # We start dropping after the first decode step. + step_requests = [ + context.batch_size, # prefill + context.batch_size, # decode + ] + + num_running_requests = context.batch_size + num_running_requests -= context.complete_num_requests_per_step + while num_running_requests > 0: + step_requests.append(num_running_requests) + num_running_requests -= context.complete_num_requests_per_step + + if step_requests[-1] != 1: + # have 1 request running at the last step. This is often + # useful + step_requests.append(1) + + return step_requests + + def run_profile(context: ProfileContext, csv_output: Optional[str], json_output: Optional[str]): print("Run profile with:") for key, value in asdict(context).items(): print(f" {key} = {value}") + requests_per_step: List[int] = determine_requests_per_step(context) + + ol_nr: OutputLen_NumReqs_Map = compute_request_output_lengths( + context.batch_size, requests_per_step) + + num_steps_to_profile: int = len(requests_per_step) + max_output_len: int = max(ol_nr.keys()) + assert max_output_len >= 1 + # Create sampling params - sampling_params = SamplingParams(temperature=0.8, - top_p=0.95, - max_tokens=args.output_len, - ignore_eos=True) + sampling_params = SamplingParams( + temperature=0.8, + top_p=0.95, + # max_tokens is set on a per-request basis. + max_tokens=None, + ignore_eos=True) # Create LLM llm = LLM(**asdict(context.engine_args)) batch_size = context.batch_size prompt_len = context.prompt_len - output_len = context.output_len scheduler_config = llm.llm_engine.scheduler_config max_model_len = llm.llm_engine.model_config.max_model_len @@ -65,7 +203,7 @@ def run_profile(context: ProfileContext, csv_output: Optional[str], f"choose a smaller batch size or prompt length, or increase " f"--max-num-batched-tokens") sys.exit(-1) - if batch_size >= max_num_seqs: + if batch_size > max_num_seqs: print( f"ERROR: chosen batch_size ({batch_size}) is larger than " f"max_num_seqs ({max_num_seqs}) and therefore cannot be run in a " @@ -73,16 +211,26 @@ def run_profile(context: ProfileContext, csv_output: Optional[str], sys.exit(-1) print("llm.llm_engine.model_config.max_model_len: ", llm.llm_engine.model_config.max_model_len) - if prompt_len + output_len > llm.llm_engine.model_config.max_model_len: - print( - f"ERROR: chosen prompt_len + output_len ({prompt_len} + " - f"{output_len} = {prompt_len + output_len}) is larger than the " - f"model's max_model_len ({max_model_len}), please choose a smaller " - f"prompt_len or output_len, or increase --max-model-len") + if prompt_len + max_output_len > llm.llm_engine.model_config.max_model_len: + print(f"ERROR: chosen prompt_len + max_output_len ({prompt_len} + " + f"{max_output_len} = {prompt_len + max_output_len}) is larger " + f"than the model's max_model_len ({max_model_len}), please " + f"choose a smaller prompt_len or max_output_len, or increase " + f"--max-model-len") sys.exit(-1) def add_requests(): + + def get_output_len_generator() -> Generator[int, Any, Any]: + for output_len, num_reqs in ol_nr.items(): + for _ in range(num_reqs): + yield output_len + + output_len_generator = get_output_len_generator() for i in range(batch_size): + sampling_params.max_tokens = next(output_len_generator) + assert isinstance(sampling_params.max_tokens, int) + prompt_token_ids = torch.randint( llm.llm_engine.model_config.get_vocab_size(), size=(prompt_len, )).tolist() @@ -110,8 +258,11 @@ def abort_requests(): llm.llm_engine.step() # First step is prefill decode_profs = [] - for x in range(args.output_len - 1): - with layerwise_profile() as decode_prof: + for _ in tqdm.tqdm(range(num_steps_to_profile - 1)): + num_running_seqs = llm.llm_engine.scheduler[ + 0].get_num_unfinished_seq_groups() + with layerwise_profile( + num_running_seqs=num_running_seqs) as decode_prof: llm.llm_engine.step() decode_profs.append(decode_prof) @@ -154,7 +305,8 @@ def abort_requests(): decode_results_list[0].print_summary_table() if csv_output: - csv_filename_base = csv_output.rstrip(".csv") + csv_filename_base = csv_output[:-4] \ + if csv_output.endswith('.csv') else csv_output prefill_results.export_model_stats_table_csv( csv_filename_base + "_prefill_model_table.csv") prefill_results.export_summary_stats_table_csv( @@ -187,10 +339,10 @@ def abort_requests(): for idx, dr in enumerate(decode_results_list): json_dict[f"decode_{idx + 1}"] = dr.convert_stats_to_dict() - for idx, dr in enumerate(decode_results_list[1:]): - json_dict[f"decode_{idx + 1}"] = dr.convert_stats_to_dict() - - with open(json_output.rstrip(".json") + ".json", "w+") as f: + # Add .json to json_output filename if it doesn't exist already. + json_output_file = json_output if json_output.endswith( + '.json') else json_output + '.json' + with open(json_output_file, "w+") as f: json.dump(json_dict, f, indent=2) pass @@ -214,7 +366,7 @@ def abort_requests(): python examples/offline_profile.py \\ --model neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 --batch-size 4 \\ --prompt-len 512 --max-num-batched-tokens 8196 --json Llama31-8b-FP8 \\ - --enforce-eager + --enforce-eager run_num_steps -n 2 ``` then you can use various tools to analyze the json output @@ -261,17 +413,41 @@ def abort_requests(): default=BATCH_SIZE_DEFAULT, help=f"Number of requests to run as a single batch, " f"default={BATCH_SIZE_DEFAULT}") - parser.add_argument( - "--output-len", + + subparsers = parser.add_subparsers(dest="cmd") + + run_num_steps_parser = subparsers.add_parser( + "run_num_steps", + help="This variation profiles n engine.step() invocations.") + run_num_steps_parser.add_argument( + '-n', + '--num-steps', type=int, - default=OUTPUT_LEN_DEFAULT, - help="Number of llm steps to run (includes prefill and decode) " - "- default={OUTPUT_LEN_DEFAULT}") + help="Number of engine steps to profile.\n" + "Setting it to 1, profiles only the prefill step.\n" + "Setting it to 2, profiles the prefill and first decode step\n" + "Setting it to 3, profiles the prefill, 1st and 2nd decode steps\n" + "and so on ...") + + run_to_completion_parser = subparsers.add_parser( + "run_to_completion", + help="This variation profiles all the engine.step() invocations" + "until the engine exhausts all submitted requests.") + run_to_completion_parser.add_argument( + '-n', + '--complete-num-requests-per-step', + type=int, + help= + "Complete complete_num_requests_per_step requests every decode step." + "For e.g., with batch_size 128 and complete_num_requests_per_step 32," + "the profiler is run for 6 engine steps, with the steps processing, " + "128, 128, 96, 64, 32, 1 requests respectively.\n" + "Note that we tack-on a one-request step at the end as it is often " + "useful.") EngineArgs.add_cli_args(parser) args = parser.parse_args() - context = ProfileContext( engine_args=EngineArgs.from_cli_args(args), **{ diff --git a/tools/profiler/print_layerwise_table.py b/tools/profiler/print_layerwise_table.py index 081076ad7dbdc..394ca8663e189 100644 --- a/tools/profiler/print_layerwise_table.py +++ b/tools/profiler/print_layerwise_table.py @@ -34,9 +34,10 @@ def get_entries(node, curr_depth=0): "examples/offline_profile.py") parser.add_argument("--phase", type=str, - choices=["prefill", "decode_1"], required=True, - help="The phase to print the table for.") + help="The phase to print the table for. This is either" + "prefill or decode_n, where n is the decode step " + "number") parser.add_argument("--table", type=str, choices=["summary", "model"], @@ -49,6 +50,10 @@ def get_entries(node, curr_depth=0): with open(args.json_trace) as f: profile_data = json.load(f) + assert args.phase in profile_data, \ + (f"Cannot find phase {args.phase} in profile data. Choose one among" + f'{[x for x in profile_data.keys() if "prefill" in x or "decode" in x]}') #noqa + if args.table == "summary": entries_and_depths = flatten_entries( SummaryStatsEntry, profile_data[args.phase]["summary_stats"]) diff --git a/tools/profiler/visualize_layerwise_profile.py b/tools/profiler/visualize_layerwise_profile.py index adc44474aa4c1..da7a28da15c19 100644 --- a/tools/profiler/visualize_layerwise_profile.py +++ b/tools/profiler/visualize_layerwise_profile.py @@ -151,16 +151,31 @@ def is_quant(op_name: str): "scaled_int8_quant" in op_name: return True + # LoRA ops + def is_sgmv_shrink(op_name: str): + return "sgmv_shrink" in op_name + + def is_sgmv_expand(op_name: str): + return "sgmv_expand" in op_name + + def is_bgmv_shrink(op_name: str): + return "bgmv_shrink" in op_name + + def is_bgmv_expand(op_name: str): + return "bgmv_expand" in op_name + + def is_cutlass_gemm_op(op_name: str): + return "void cutlass::Kernel" in op_name or \ + "void cutlass::device_kernel" in op_name + def is_gemm_op(op_name: str): if is_quant(op_name): return False - if "xmma_gemm" in op_name or \ + return is_cutlass_gemm_op(op_name) or \ + "xmma_gemm" in op_name or \ "gemv2T_kernel" in op_name or \ "splitKreduce" in op_name or \ - "void cutlass::Kernel" in op_name or \ - "void cutlass::device_kernel" in op_name or \ - "s16816gemm" in op_name: - return True + "s16816gemm" in op_name def is_elementwise_op(op_name: str): return "elementwise_kernel" in op_name @@ -211,6 +226,18 @@ def is_reduce_kernel(op_name: str): quant_ops = list(filter(lambda x: is_quant(x), ops)) ops = list(filter(lambda x: x not in quant_ops, ops)) + sgmv_shrink_ops = list(filter(lambda x: is_sgmv_shrink(x), ops)) + ops = list(filter(lambda x: x not in sgmv_shrink_ops, ops)) + sgmv_expand_ops = list(filter(lambda x: is_sgmv_expand(x), ops)) + ops = list(filter(lambda x: x not in sgmv_expand_ops, ops)) + bgmv_shrink_ops = list(filter(lambda x: is_bgmv_shrink(x), ops)) + ops = list(filter(lambda x: x not in bgmv_shrink_ops, ops)) + bgmv_expand_ops = list(filter(lambda x: is_bgmv_expand(x), ops)) + ops = list(filter(lambda x: x not in bgmv_expand_ops, ops)) + + cutlass_gemm_ops = list(filter(lambda x: is_cutlass_gemm_op(x), ops)) + ops = list(filter(lambda x: x not in cutlass_gemm_ops, ops)) + gemm_ops = list(filter(lambda x: is_gemm_op(x), ops)) ops = list(filter(lambda x: x not in gemm_ops, ops)) @@ -257,6 +284,24 @@ def is_reduce_kernel(op_name: str): trace_df['attention'] = trace_df[attention_ops].agg("sum", axis=1) if len(quant_ops): trace_df['quant_ops'] = trace_df[quant_ops].agg("sum", axis=1) + + if len(sgmv_shrink_ops): + trace_df['sgmv_shrink_ops'] = trace_df[sgmv_shrink_ops].agg("sum", + axis=1) + if len(sgmv_expand_ops): + trace_df['sgmv_expand_ops'] = trace_df[sgmv_expand_ops].agg("sum", + axis=1) + if len(bgmv_shrink_ops): + trace_df['bgmv_shrink_ops'] = trace_df[bgmv_shrink_ops].agg("sum", + axis=1) + if len(bgmv_expand_ops): + trace_df['bgmv_expand_ops'] = trace_df[bgmv_expand_ops].agg("sum", + axis=1) + + if len(cutlass_gemm_ops): + trace_df['cutlass_gemm_ops'] = trace_df[cutlass_gemm_ops].agg("sum", + axis=1) + if len(gemm_ops): trace_df['gemm_ops'] = trace_df[gemm_ops].agg("sum", axis=1) if len(rms_norm_ops): @@ -296,7 +341,9 @@ def is_reduce_kernel(op_name: str): trace_df['reduce_kernel_ops'] = trace_df[reduce_kernel_ops].agg("sum", axis=1) - trace_df.drop(attention_ops + quant_ops + gemm_ops + rms_norm_ops + + trace_df.drop(attention_ops + quant_ops + sgmv_shrink_ops + + sgmv_expand_ops + bgmv_shrink_ops + bgmv_expand_ops + + cutlass_gemm_ops + gemm_ops + rms_norm_ops + vocab_embed_ops + mem_ops + elementwise_ops + nccl_all_reduce_ops + nccl_gather_ops + nccl_broadcast_ops + nccl_other_ops + cross_device_reduce_1stage_ops + @@ -315,7 +362,14 @@ def plot_trace_df(traces_df: pd.DataFrame, plot_title: str, output: Optional[Path] = None): + def get_phase_description(traces_df: pd.DataFrame, phase: str) -> str: + phase_df = traces_df.query(f'phase == "{phase}"') + descs = phase_df['phase_desc'].to_list() + assert all([desc == descs[0] for desc in descs]) + return descs[0] + phases = traces_df['phase'].unique() + phase_descs = [get_phase_description(traces_df, p) for p in phases] traces_df = traces_df.pivot_table(index="phase", columns="name", values=plot_metric, @@ -324,7 +378,8 @@ def plot_trace_df(traces_df: pd.DataFrame, traces_df = group_trace_by_operations(traces_df) # Make the figure - fig, ax = plt.subplots(1, figsize=(5, 8), sharex=True) + fig_size_x = max(5, len(phases)) + fig, ax = plt.subplots(1, figsize=(fig_size_x, 8), sharex=True) # Draw the stacked bars ops = list(traces_df) @@ -332,7 +387,7 @@ def plot_trace_df(traces_df: pd.DataFrame, for op in ops: values = [traces_df[op][phase] for phase in phases] values = list(map(lambda x: 0.0 if math.isnan(x) else x, values)) - ax.bar(phases, values, label=op, bottom=bottom) + ax.bar(phase_descs, values, label=op, bottom=bottom) bottom = [bottom[j] + values[j] for j in range(len(phases))] # Write the values as text on the bars @@ -390,6 +445,14 @@ def keep_only_top_entries(df: pd.DataFrame, ["name"]] = "others" return df + def get_phase_description(key: str) -> str: + num_running_seqs = profile_json[key]['metadata'][ + 'num_running_seqs'] + if num_running_seqs is not None: + return f"{key}-seqs-{num_running_seqs}" + else: + return key + # Get data for each key traces = list(map(lambda x: get_entries_and_traces(x), step_keys)) @@ -413,6 +476,7 @@ def keep_only_top_entries(df: pd.DataFrame, # Fill in information about the step-keys for trace_df, step_key in zip(trace_dfs, step_keys): trace_df['phase'] = step_key + trace_df['phase_desc'] = get_phase_description(step_key) # Combine all data frames so they can be put in a single plot traces_df = pd.concat(trace_dfs) @@ -426,12 +490,16 @@ def keep_only_top_entries(df: pd.DataFrame, def make_plot_title_suffix(profile_json: dict) -> str: context = profile_json["context"] sparsity = context.get('sparsity', None) - return (f"{context['model']}\n" + run_type = \ + f'Run {context["num_steps"]} steps' if context['num_steps'] else \ + (f'Complete {context["complete_num_requests_per_step"]} per ' + f'step; Run till completion') + return (f"{context['engine_args']['model']}\n" f"Batch={context['batch_size']}, " f"PromptLen={context['prompt_len']}, " - f"OutputLen={context['output_len']}," - f"NumGpus={context['tensor_parallel_size']}" - f"{', Sparsity ' + sparsity if sparsity else ''}") + f"NumGpus={context['engine_args']['tensor_parallel_size']}" + f"{', Sparsity ' + sparsity if sparsity else ''}\n" + f"Run Type: {run_type}") profile_json = None with open(json_trace) as f: diff --git a/vllm/profiler/layerwise_profile.py b/vllm/profiler/layerwise_profile.py index 9d9f427e807f6..33babfebdca1e 100644 --- a/vllm/profiler/layerwise_profile.py +++ b/vllm/profiler/layerwise_profile.py @@ -72,6 +72,9 @@ class LayerwiseProfileResults(profile): _model_stats_tree: List[_StatsTreeNode] = field(init=False) _summary_stats_tree: List[_StatsTreeNode] = field(init=False) + # profile metadata + num_running_seqs: Optional[int] = None + def __post_init__(self): self._build_correlation_map() self._build_module_tree() @@ -127,6 +130,9 @@ def export_summary_stats_table_csv(self, filename: str): def convert_stats_to_dict(self) -> str: return { + "metadata": { + "num_running_seqs": self.num_running_seqs + }, "summary_stats": self._convert_stats_tree_to_dict(self._summary_stats_tree), "model_stats": @@ -338,7 +344,15 @@ def df_traversal(node: _StatsTreeNode, curr_json_list: List[Dict]): class layerwise_profile(profile): - def __init__(self): + def __init__(self, num_running_seqs: Optional[int] = None): + """ + layerwise profile constructor. + + Args: + num_running_seqs (Optional[int], optional): When given, + num_running_seqs will be passed to LayerProfileResults for metadata + update. Defaults to None. + """ super().__init__( activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA], record_shapes=True, @@ -346,9 +360,13 @@ def __init__(self): with_modules=True, experimental_config=_ExperimentalConfig(verbose=True)) + self.num_running_seqs = num_running_seqs + def __enter__(self): return super().__enter__() def __exit__(self, exc_type, exc_val, exc_tb): super().__exit__(exc_type, exc_val, exc_tb) - self.results = LayerwiseProfileResults(self.profiler.kineto_results) + self.results = LayerwiseProfileResults( + self.profiler.kineto_results, + num_running_seqs=self.num_running_seqs)