
[misc] CUDA Time Layerwise Profiler #8337

Open · LucasWilkinson wants to merge 10 commits into main from varun/main-with-profiler

Conversation

@LucasWilkinson (Contributor) commented on Sep 10, 2024

Layerwise profiler to see how much time is spent on CUDA (GPU kernels) for each module/layer.

Example of how to run a profile:

python examples/offline_profile.py --model neuralmagic/Meta-Llama-3.1-8B-Instruct-FP8 --batch-size 4 --prompt-len 512 --json Llama31-8b-FP8 --max-num-batched-tokens 8196
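Under the hood this kind of layerwise attribution builds on torch.profiler, which can record the nn.Module hierarchy alongside each op so that CUDA kernel time can be rolled up per layer. As a rough, self-contained sketch of that mechanism (the toy model and shapes below are illustrative only, not the PR's actual code):

import torch
from torch.profiler import ProfilerActivity, profile

# Tiny stand-in model; the real script profiles a full vLLM engine.
model = torch.nn.Sequential(
    torch.nn.Linear(4096, 14336),
    torch.nn.SiLU(),
    torch.nn.Linear(14336, 4096),
).cuda().half()
x = torch.randn(4, 512, 4096, device="cuda", dtype=torch.half)

# with_modules=True records the module hierarchy for each recorded op,
# which is what makes a per-layer CUDA-time breakdown possible.
with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    with_modules=True,
) as prof:
    model(x)
    torch.cuda.synchronize()

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))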

The PR also includes utilities for inspecting the profile breakdown; e.g., to get a summary table of the prefill phase, run:

$ python tools/profiler/print_layerwise_table.py --json-trace Llama31-8b-FP8.json --phase prefill --table summary
name                                                                             | cuda_time_us | pct_cuda_... | invocations    
================================================================================================================================
LlamaForCausalLM                                                                 |     31788.89 |        97.33 |            1.00
|- LlamaModel                                                                    |     31788.89 |        97.33 |            1.00
|-- VocabParallelEmbedding(weight=bfloat16[128256, 4096])                        |        59.20 |         0.18 |            1.00
|--- void at::native::(anonymous namespace)::indexSelectLargeIndex<c10::BFloa... |        59.20 |         0.18 |            1.00
|-- LlamaDecoderLayer                                                            |     31709.72 |        97.09 |           32.00
|--- RMSNorm(weight=bfloat16[4096])                                              |      1336.33 |         4.09 |           64.00
|---- void vllm::rms_norm_kernel<c10::BFloat16>(c10::BFloat16*, c10::BFloat16... |        26.56 |         0.08 |            1.00
|---- std::enable_if<(((8)>(0)))&&vllm::_typeConvert<c10::BFloat16>::exists, ... |      1309.77 |         4.01 |           63.00
|--- LlamaAttention                                                              |      8511.74 |        26.06 |           32.00
|---- QKVParallelLinear(weight=float8_e4m3fn[4096, 6144], weight_scale=float3... |      3014.17 |         9.23 |           32.00
|----- void vllm::scaled_fp8_quant_kernel<c10::BFloat16>(c10::Float8_e4m3fn*,... |       472.03 |         1.45 |           32.00
|----- void cutlass::device_kernel<(anonymous namespace)::cutlass_3x_gemm<cut... |      2542.14 |         7.78 |           32.00
|---- Llama3RotaryEmbedding                                                      |       892.66 |         2.73 |           32.00
|----- void vllm::rotary_embedding_kernel<c10::BFloat16, true>(long const*, c... |       892.66 |         2.73 |           32.00
|---- Attention                                                                  |      2454.87 |         7.52 |           32.00
|----- void vllm::reshape_and_cache_flash_kernel<__nv_bfloat16, __nv_bfloat16... |       343.77 |         1.05 |           32.00
|----- void flash_fwd_splitkv_kernel<Flash_fwd_kernel_traits<128, 64, 128, 4,... |      1756.16 |         5.38 |           32.00
|----- Memcpy DtoD (Device -> Device)                                            |       354.94 |         1.09 |           32.00
|---- RowParallelLinear(weight=float8_e4m3fn[4096, 4096], weight_scale=float3... |      2150.05 |         6.58 |           32.00
|----- void vllm::scaled_fp8_quant_kernel<c10::BFloat16>(c10::Float8_e4m3fn*,... |       474.97 |         1.45 |           32.00
|----- void cutlass::device_kernel<(anonymous namespace)::cutlass_3x_gemm<cut... |      1675.08 |         5.13 |           32.00
|--- LlamaMLP                                                                    |     21861.65 |        66.94 |           32.00
|---- MergedColumnParallelLinear(weight=float8_e4m3fn[4096, 28672], weight_sc... |     12077.86 |        36.98 |           32.00
|----- void vllm::scaled_fp8_quant_kernel<c10::BFloat16>(c10::Float8_e4m3fn*,... |       483.19 |         1.48 |           32.00
|----- void cutlass::device_kernel<(anonymous namespace)::cutlass_3x_gemm<cut... |     11594.67 |        35.50 |           32.00
|---- SiluAndMul                                                                 |      2882.36 |         8.83 |           32.00
|----- void vllm::act_and_mul_kernel<c10::BFloat16, &(c10::BFloat16 vllm::sil... |      2882.36 |         8.83 |           32.00
|---- RowParallelLinear(weight=float8_e4m3fn[14336, 4096], weight_scale=float... |      6901.43 |        21.13 |           32.00
|----- void vllm::scaled_fp8_quant_kernel<c10::BFloat16>(c10::Float8_e4m3fn*,... |      1268.66 |         3.88 |           32.00
|----- void cutlass::device_kernel<(anonymous namespace)::cutlass_3x_gemm<cut... |      5632.78 |        17.25 |           32.00
|-- RMSNorm(weight=bfloat16[4096])                                               |        19.97 |         0.06 |            1.00
|--- std::enable_if<(((8)>(0)))&&vllm::_typeConvert<c10::BFloat16>::exists, v... |        19.97 |         0.06 |            1.00
LogitsProcessor                                                                  |       360.95 |         1.11 |            1.00
|- void at::native::(anonymous namespace)::indexSelectSmallIndex<c10::BFloat1... |         3.81 |         0.01 |            1.00
|- Memset (Device)                                                               |         1.12 |         0.00 |            1.00
|- sm90_xmma_gemm_bf16bf16_bf16f32_f32_tn_n_tilesize64x128x64_warpgroupsize1x... |       356.03 |         1.09 |            1.00
Sampler                                                                          |       510.01 |         1.56 |            1.00
|- Memcpy HtoD (Pinned -> Device)                                                |        16.67 |         0.05 |            7.00
|- void at::native::unrolled_elementwise_kernel<at::native::direct_copy_kerne... |         4.48 |         0.01 |            1.00
|- void at::native::elementwise_kernel<128, 4, at::native::gpu_kernel_impl<at... |         4.86 |         0.01 |            1.00
|- at::native::(anonymous namespace)::fill_index_and_segment_kernel(int2*, in... |         3.33 |         0.01 |            1.00
|- Memset (Device)                                                               |        11.94 |         0.04 |            9.00
|- void at_cuda_detail::cub::DeviceRadixSortHistogramKernel<at_cuda_detail::c... |         6.14 |         0.02 |            1.00
|- void at_cuda_detail::cub::DeviceRadixSortExclusiveSumKernel<at_cuda_detail... |         1.89 |         0.01 |            1.00
|- void at_cuda_detail::cub::DeviceRadixSortOnesweepKernel<at_cuda_detail::cu... |        61.31 |         0.19 |            4.00
|- void at_cuda_detail::cub::DeviceRadixSortHistogramKernel<at_cuda_detail::c... |         3.20 |         0.01 |            1.00
|- void at_cuda_detail::cub::DeviceRadixSortExclusiveSumKernel<at_cuda_detail... |         1.54 |         0.00 |            1.00
|- void at_cuda_detail::cub::DeviceRadixSortOnesweepKernel<at_cuda_detail::cu... |        11.10 |         0.03 |            1.00
|- void at::native::(anonymous namespace)::sort_postprocess_kernel<float>(flo... |         6.79 |         0.02 |            1.00
|- Memcpy DtoD (Device -> Device)                                                |         2.75 |         0.01 |            1.00
|- void at::native::unrolled_elementwise_kernel<at::native::direct_copy_kerne... |         1.82 |         0.01 |            1.00
|- void at::native::vectorized_elementwise_kernel<4, at::native::CUDAFunctorO... |         1.50 |         0.00 |            1.00
|- void at::native::_scatter_gather_elementwise_kernel<128, 4, at::native::_c... |         7.68 |         0.02 |            2.00
|- void at::native::elementwise_kernel<128, 4, at::native::gpu_kernel_impl_no... |         3.78 |         0.01 |            1.00
|- void at::native::vectorized_elementwise_kernel<4, at::native::(anonymous n... |         4.22 |         0.01 |            2.00
|- void at::native::(anonymous namespace)::cunn_SoftMaxForward<4, float, floa... |        89.31 |         0.27 |            2.00
|- void at::native::tensor_kernel_scan_innermost_dim<float, std::plus<float> ... |       169.34 |         0.52 |            1.00
|- void at::native::vectorized_elementwise_kernel<4, at::native::CUDAFunctorO... |         1.54 |         0.00 |            1.00
|- void at::native::elementwise_kernel<128, 4, at::native::gpu_kernel_impl<at... |         4.90 |         0.01 |            1.00
|- void at::native::elementwise_kernel<128, 4, at::native::gpu_kernel_impl_no... |         1.82 |         0.01 |            1.00
|- void (anonymous namespace)::elementwise_kernel_with_index<int, at::native:... |         2.46 |         0.01 |            1.00
|- void at::native::_scatter_gather_elementwise_kernel<128, 4, at::native::_c... |        10.53 |         0.03 |            1.00
|- void at::native::(anonymous namespace)::cunn_SoftMaxForward<4, float, floa... |        28.00 |         0.09 |            1.00
|- void at::native::elementwise_kernel<128, 4, at::native::gpu_kernel_impl<at... |         2.30 |         0.01 |            1.00
|- void at::native::index_elementwise_kernel<128, 4, at::native::gpu_index_ke... |         5.02 |         0.02 |            1.00
|- void at::native::(anonymous namespace)::distribution_elementwise_grid_stri... |         4.54 |         0.01 |            1.00
|- void at::native::vectorized_elementwise_kernel<4, at::native::BinaryFuncto... |         3.23 |         0.01 |            1.00
|- void at::native::reduce_kernel<512, 1, at::native::ReduceOp<float, at::nat... |        28.96 |         0.09 |            1.00
|- Memcpy DtoH (Device -> Pageable)                                              |         3.04 |         0.01 |            1.00
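If you want to consume the JSON trace directly instead of using the bundled tools, it can be walked as a nested per-module tree with a few lines of Python. The field names below ("name", "cuda_time_us", "children") and the top-level layout are assumptions for illustration; check the file emitted by offline_profile.py for the actual schema:

import json

def walk(entry, depth=0):
    # Field names here are assumed; see the real trace for the schema.
    print(f"{'|-' * depth} {entry['name'][:70]:<72} {entry['cuda_time_us']:>10.2f}")
    for child in entry.get("children", []):
        walk(child, depth + 1)

with open("Llama31-8b-FP8.json") as f:
    trace = json.load(f)

for root in trace.get("prefill", []):  # assumed top-level layout
    walk(root)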

Or to view it as a graph you can run:

python tools/profiler/visualize_layerwise_profile.py  --json-trace Llama31-8b-FP8.json --output-directory profile_breakdown --plot-metric pct_cuda_time
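The script writes plot images into the output directory. For a quick hand-rolled equivalent, something along these lines (using the pct_cuda_time values from the prefill summary table above) produces a similar per-module breakdown:

import os
import matplotlib.pyplot as plt

# Per-module pct_cuda_time values taken from the prefill summary table above.
modules = ["LlamaMLP", "LlamaAttention", "RMSNorm", "Sampler", "LogitsProcessor"]
pct_cuda_time = [66.94, 26.06, 4.09, 1.56, 1.11]

os.makedirs("profile_breakdown", exist_ok=True)
fig, ax = plt.subplots(figsize=(8, 3))
ax.barh(modules, pct_cuda_time)
ax.set_xlabel("pct_cuda_time (%)")
ax.set_title("Prefill CUDA time by module (Llama-3.1-8B FP8)")
fig.savefig("profile_breakdown/pct_cuda_time.png", bbox_inches="tight")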


👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs do not trigger a full CI run by default; only the fastcheck CI runs, covering a small, essential subset of CI tests to catch errors quickly. You can run additional CI tests on top of those by going to your fastcheck build in the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can do one of these:

  • Add the ready label to the PR
  • Enable auto-merge.

🚀

@LucasWilkinson changed the title from [WIP, misc] CUDA Time Layerwise Profiler to [misc] CUDA Time Layerwise Profiler on Sep 16, 2024
@LucasWilkinson marked this pull request as ready for review on September 16, 2024
@LucasWilkinson force-pushed the varun/main-with-profiler branch 2 times, most recently from 1a0844e to 52aafcf on September 17, 2024
@mgoin (Collaborator) left a comment:

This is great! Since the script is pretty technically involved and relies on exact attributes to exist, could you add a simple e2e test to run in CI so we can know if torch updates break it?

Two review threads on examples/offline_profile.py (outdated, resolved).
LucasWilkinson and others added 2 commits October 7, 2024 13:04
Co-authored-by: Michael Goin <michael@neuralmagic.com>
@LucasWilkinson (Contributor, Author) commented on Oct 7, 2024

> could you add a simple e2e test to run in CI so we can know if torch updates break it?

What's the easiest way to do this? Just add a pytest test, or invoke offline_profile somehow? Are there instructions on how to register something with Buildkite, or are all pytest folders already run automatically?

@mgoin I've added an examples test.
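For reference, a minimal smoke test along the lines discussed above might look like the sketch below; the small model and exact flags are assumptions, and the test actually added in this PR may differ:

import subprocess
import sys

def test_offline_profile_smoke():
    # Run the example end-to-end so that a torch upgrade that breaks the
    # profiler's attribute access fails loudly in CI.
    result = subprocess.run(
        [sys.executable, "examples/offline_profile.py",
         "--model", "facebook/opt-125m",  # assumed small test model
         "--batch-size", "1",
         "--prompt-len", "16"],
        capture_output=True, text=True, timeout=600,
    )
    assert result.returncode == 0, result.stderr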

@mgoin (Collaborator) left a comment:

LGTM and works well!

@mgoin added the ready label on Oct 7, 2024