From 76a5c3d15046de7c0e274c7b920c6edf1759c282 Mon Sep 17 00:00:00 2001 From: Lucas Wilkinson Date: Wed, 4 Sep 2024 17:07:01 +0000 Subject: [PATCH] add profile wrapper --- benchmarks/kernels/benchmark_machete.py | 26 ++++++++++++++++++------- 1 file changed, 19 insertions(+), 7 deletions(-) diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py index 0b9eefc3f66d..590ed5ad82fb 100644 --- a/benchmarks/kernels/benchmark_machete.py +++ b/benchmarks/kernels/benchmark_machete.py @@ -1,17 +1,17 @@ import argparse import copy +import os import itertools import math import pickle as pkl import time -from itertools import product -from typing import Callable, Iterable, List, Optional, Tuple, Dict -from functools import partial -from dataclasses import dataclass - import pandas as pd import torch import torch.utils.benchmark as TBenchmark +from itertools import product +from dataclasses import dataclass +from typing import Callable, Iterable, List, Optional, Tuple + from torch.utils.benchmark import Measurement as TMeasurement from weight_shapes import WEIGHT_SHAPES @@ -30,6 +30,10 @@ DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024] DEFAULT_TP_SIZES = [1] +NVTX_PROFILE = os.environ.get("NVTX_PROFILE", False) + +if NVTX_PROFILE: + import nvtx def terse_type_name(dt): return { @@ -268,8 +272,9 @@ def machete_create_bench_fn(bt: BenchmarkTensors, def bench_fns(label: str, sub_label: str, description: str, fns: List[Callable]): - min_run_time = 1 - return TBenchmark.Timer( + + min_run_time = 1 if not NVTX_PROFILE else 0.1 + res = TBenchmark.Timer( stmt=""" for fn in fns: fn() @@ -282,6 +287,13 @@ def bench_fns(label: str, sub_label: str, description: str, description=description, ).blocked_autorange(min_run_time=min_run_time) + if NVTX_PROFILE: + with nvtx.annotate("mm-bench"): + with nvtx.annotate(f"{label}|{sub_label}|{description}"): + fns[0]() + + return res + _SWEEP_SCHEDULES_RESULTS: Optional[pd.DataFrame] = None _SWEEP_SCHEDULES_RESULTS_CSV: Optional[str] = None