Skip to content

Commit

Permalink
add profile wrapper
Browse files: browse the repository at this point in the history
  • Loading branch information
LucasWilkinson committed Oct 2, 2024
1 parent ad54e8f commit 76a5c3d
Showing 1 changed file with 19 additions and 7 deletions.
26 changes: 19 additions & 7 deletions benchmarks/kernels/benchmark_machete.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import argparse
import copy
import os
import itertools
import math
import pickle as pkl
import time
from itertools import product
from typing import Callable, Iterable, List, Optional, Tuple, Dict
from functools import partial
from dataclasses import dataclass

import pandas as pd
import torch
import torch.utils.benchmark as TBenchmark
from itertools import product
from dataclasses import dataclass
from typing import Callable, Iterable, List, Optional, Tuple

from torch.utils.benchmark import Measurement as TMeasurement
from weight_shapes import WEIGHT_SHAPES

Expand All @@ -30,6 +30,10 @@
DEFAULT_BATCH_SIZES = [1, 16, 32, 64, 128, 256, 512, 1024]
DEFAULT_TP_SIZES = [1]

def env_flag(name: str, default: bool = False) -> bool:
    """Interpret the environment variable *name* as a boolean switch.

    Args:
        name: Name of the environment variable to read.
        default: Value returned when the variable is not set at all.

    Returns:
        ``False`` when the variable is set to an empty string, ``0``,
        ``false``, ``no`` or ``off`` (case-insensitive, surrounding
        whitespace ignored); ``True`` for any other value; ``default`` when
        the variable is unset.
    """
    raw = os.environ.get(name)
    if raw is None:
        return default
    # Environment values are always strings, so plain truthiness would treat
    # "0" or "false" as enabled. Only explicit "off" spellings disable the
    # flag; every other value keeps enabling it, as it did before.
    return raw.strip().lower() not in ("", "0", "false", "no", "off")


# Opt-in switch for NVTX range annotations around the benchmarked kernels.
NVTX_PROFILE = env_flag("NVTX_PROFILE")

if NVTX_PROFILE:
    # Imported lazily so that nvtx is only required when profiling is enabled.
    import nvtx

def terse_type_name(dt):
return {
Expand Down Expand Up @@ -268,8 +272,9 @@ def machete_create_bench_fn(bt: BenchmarkTensors,

def bench_fns(label: str, sub_label: str, description: str,
fns: List[Callable]):
min_run_time = 1
return TBenchmark.Timer(

min_run_time = 1 if not NVTX_PROFILE else 0.1
res = TBenchmark.Timer(
stmt="""
for fn in fns:
fn()
Expand All @@ -282,6 +287,13 @@ def bench_fns(label: str, sub_label: str, description: str,
description=description,
).blocked_autorange(min_run_time=min_run_time)

if NVTX_PROFILE:
with nvtx.annotate("mm-bench"):
with nvtx.annotate(f"{label}|{sub_label}|{description}"):
fns[0]()

return res


_SWEEP_SCHEDULES_RESULTS: Optional[pd.DataFrame] = None
_SWEEP_SCHEDULES_RESULTS_CSV: Optional[str] = None
Expand Down

0 comments on commit 76a5c3d

Please sign in to comment.