Skip to content

Commit

Permalink
add reduce kernel grouping
Browse files Browse the repository at this point in the history
  • Loading branch information
Varun Sundar Rabindranath authored and LucasWilkinson committed Sep 17, 2024
1 parent a1271fa commit 52aafcf
Showing 1 changed file with 91 additions and 4 deletions.
95 changes: 91 additions & 4 deletions vllm/profiler/visualize_layerwise_profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,36 @@ def is_mem_op(op_name: str):
def is_vocab_embedding_op(op_name: str):
    """Return True if ``op_name`` contains "vocabparallelembed".

    The comparison ignores case: the name is lower-cased before the
    substring test.
    """
    lowered_name = op_name.lower()
    return "vocabparallelembed" in lowered_name

# nccl ops
def is_nccl_op(op_name: str):
    """Return True if ``op_name`` contains "nccl" in any letter case."""
    lowered_name = op_name.lower()
    return "nccl" in lowered_name

def is_nccl_all_reduce(op_name: str):
    """Return True if ``op_name`` is an NCCL op that mentions all-reduce.

    The name must pass ``is_nccl_op`` and, compared case-insensitively,
    contain either "all_reduce" or "allreduce".
    """
    if not is_nccl_op(op_name):
        return False
    lowered_name = op_name.lower()
    return any(tag in lowered_name for tag in ("all_reduce", "allreduce"))

def is_nccl_gather(op_name: str):
    """Return True if ``op_name`` is an NCCL op whose name mentions "gather".

    The name must pass ``is_nccl_op``; the "gather" substring test is
    done on the lower-cased name.
    """
    if not is_nccl_op(op_name):
        return False
    return "gather" in op_name.lower()

def is_nccl_broadcast(op_name: str):
    """Return True if ``op_name`` is an NCCL op whose name mentions "broadcast".

    The name must pass ``is_nccl_op``; the "broadcast" substring test is
    done on the lower-cased name.
    """
    if not is_nccl_op(op_name):
        return False
    return "broadcast" in op_name.lower()

# Reduce op types
def is_cross_device_reduce_1stage(op_name: str):
    """Return True if ``op_name`` contains "cross_device_reduce_1stage".

    The substring match is case-sensitive.
    """
    marker = "cross_device_reduce_1stage"
    return marker in op_name

def is_cross_device_reduce_2stage(op_name: str):
    """Return True if ``op_name`` contains "cross_device_reduce_2stage".

    The substring match is case-sensitive.
    """
    marker = "cross_device_reduce_2stage"
    return marker in op_name

def is_custom_ar_all_reduce_unreg(op_name: str):
    """Return True if ``op_name`` contains "_C_custom_ar::all_reduce_unreg".

    The substring match is case-sensitive.
    """
    marker = "_C_custom_ar::all_reduce_unreg"
    return marker in op_name

def is_reduce_kernel(op_name: str):
    """Return True if ``op_name`` contains "reduce_kernel".

    The substring match is case-sensitive.
    """
    marker = "reduce_kernel"
    return marker in op_name

# Column labels of the trace table. `ops` starts as an independent copy of
# them; the grouping steps that follow remove labels from `ops` as each
# category claims its columns.
headers = list(trace_df.columns)
ops = copy.deepcopy(headers)

Expand All @@ -196,6 +226,33 @@ def is_vocab_embedding_op(op_name: str):
# Each step below claims the still-unclassified columns that match one
# predicate and then removes them from `ops`. The order matters: the specific
# NCCL categories run before the catch-all `is_nccl_op` bucket.
elementwise_ops = [op for op in ops if is_elementwise_op(op)]
ops = [op for op in ops if op not in elementwise_ops]

# NCCL collectives, most specific first.
nccl_all_reduce_ops = [op for op in ops if is_nccl_all_reduce(op)]
ops = [op for op in ops if op not in nccl_all_reduce_ops]

nccl_gather_ops = [op for op in ops if is_nccl_gather(op)]
ops = [op for op in ops if op not in nccl_gather_ops]

nccl_broadcast_ops = [op for op in ops if is_nccl_broadcast(op)]
ops = [op for op in ops if op not in nccl_broadcast_ops]

# Whatever NCCL columns are left after the specific categories above.
nccl_other_ops = [op for op in ops if is_nccl_op(op)]
ops = [op for op in ops if op not in nccl_other_ops]

# Reduce op categories.
cross_device_reduce_1stage_ops = [
    op for op in ops if is_cross_device_reduce_1stage(op)
]
ops = [op for op in ops if op not in cross_device_reduce_1stage_ops]

cross_device_reduce_2stage_ops = [
    op for op in ops if is_cross_device_reduce_2stage(op)
]
ops = [op for op in ops if op not in cross_device_reduce_2stage_ops]

custom_ar_all_reduce_unreg_ops = [
    op for op in ops if is_custom_ar_all_reduce_unreg(op)
]
ops = [op for op in ops if op not in custom_ar_all_reduce_unreg_ops]

reduce_kernel_ops = [op for op in ops if is_reduce_kernel(op)]
ops = [op for op in ops if op not in reduce_kernel_ops]

if len(attention_ops):
trace_df['attention'] = trace_df[attention_ops].agg("sum", axis=1)
if len(quant_ops):
Expand All @@ -213,10 +270,40 @@ def is_vocab_embedding_op(op_name: str):
trace_df['elementwise_ops'] = trace_df[elementwise_ops].agg("sum",
axis=1)

# Fold each NCCL / reduce column group into a single summed column named
# after the group. A group that matched no columns is skipped, so no column
# is created for it. Insertion order of the dict fixes the order in which the
# new columns are added.
comm_groups = {
    'nccl_all_reduce_ops': nccl_all_reduce_ops,
    'nccl_gather_ops': nccl_gather_ops,
    'nccl_broadcast_ops': nccl_broadcast_ops,
    'nccl_other_ops': nccl_other_ops,
    'cross_device_reduce_1stage_ops': cross_device_reduce_1stage_ops,
    'cross_device_reduce_2stage_ops': cross_device_reduce_2stage_ops,
    'custom_ar_all_reduce_unreg_ops': custom_ar_all_reduce_unreg_ops,
    'reduce_kernel_ops': reduce_kernel_ops,
}
for group_name, group_cols in comm_groups.items():
    if group_cols:
        trace_df[group_name] = trace_df[group_cols].agg("sum", axis=1)

# Remove the raw per-op columns exactly once, after every group has been
# summed. `DataFrame.drop` raises KeyError for labels that are no longer
# present, so the same columns must not be dropped a second time.
trace_df.drop(
    attention_ops + quant_ops + gemm_ops + rms_norm_ops + vocab_embed_ops +
    mem_ops + elementwise_ops + nccl_all_reduce_ops + nccl_gather_ops +
    nccl_broadcast_ops + nccl_other_ops + cross_device_reduce_1stage_ops +
    cross_device_reduce_2stage_ops + custom_ar_all_reduce_unreg_ops +
    reduce_kernel_ops,
    axis=1,
    inplace=True)
return trace_df


Expand Down

0 comments on commit 52aafcf

Please sign in to comment.