Skip to content

Commit

Permalink
add profile wrapper
Browse files Browse the repository at this point in the history
  • Loading branch information
LucasWilkinson committed Sep 4, 2024
1 parent c1ee58b commit 358a67a
Showing 1 changed file with 18 additions and 5 deletions.
23 changes: 18 additions & 5 deletions benchmarks/kernels/benchmark_machete.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,15 @@
import argparse
import copy
import os
import itertools
import math
import pickle as pkl
import time
from dataclasses import dataclass
from typing import Callable, Dict, Iterable, List, Optional, Tuple

import torch
import torch.utils.benchmark as TBenchmark
from dataclasses import dataclass
from typing import Callable, Iterable, List, Optional, Tuple

from torch.utils.benchmark import Measurement as TMeasurement
from weight_shapes import WEIGHT_SHAPES

Expand All @@ -27,6 +28,10 @@
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024]
DEFAULT_TP_SIZES = [1]

# Opt-in switch for NVTX range annotations, read once at import time from the
# NVTX_PROFILE environment variable. Environment values are always strings,
# so a bare truthiness check would treat "0" or "false" as enabled; parse the
# value explicitly instead. Unset, empty, "0", "false", "no" and "off"
# (case-insensitive) disable profiling; any other value enables it.
NVTX_PROFILE = os.environ.get("NVTX_PROFILE", "").strip().lower() not in (
    "", "0", "false", "no", "off")

if NVTX_PROFILE:
    # Imported conditionally so nvtx is only required when profiling is
    # requested.
    import nvtx

def terse_type_name(dt):
return {
Expand Down Expand Up @@ -265,8 +270,9 @@ def machete_create_bench_fn(bt: BenchmarkTensors,

def bench_fns(label: str, sub_label: str, description: str,
fns: List[Callable]):
min_run_time = 1
return TBenchmark.Timer(

min_run_time = 1 if not NVTX_PROFILE else 0.1
res = TBenchmark.Timer(
stmt="""
for fn in fns:
fn()
Expand All @@ -279,6 +285,13 @@ def bench_fns(label: str, sub_label: str, description: str,
description=description,
).blocked_autorange(min_run_time=min_run_time)

if NVTX_PROFILE:
with nvtx.annotate("mm-bench"):
with nvtx.annotate(f"{label}|{sub_label}|{description}"):
fns[0]()

return res


def bench(types: TypeConfig,
group_size: int,
Expand Down

0 comments on commit 358a67a

Please sign in to comment.