Skip to content

Commit

Permalink
Resolve sampler benchmark variability with setting random seed
Browse files Browse the repository at this point in the history
  • Loading branch information
scotts committed Nov 6, 2024
1 parent 8ac81b7 commit c66afd7
Showing 1 changed file with 20 additions and 9 deletions.
29 changes: 20 additions & 9 deletions benchmarks/samplers/benchmark_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,20 @@
clips_at_regular_timestamps,
)

DEFAULT_VIDEO_PATH = Path(__file__).parent / "../../test/resources/nasa_13013.mp4"
DEFAULT_NUM_EXP = 30
DEFAULT_TORCH_SEED = 0

def bench(f, *args, num_exp=100, warmup=0, **kwargs):

def bench(f, *args, num_exp, warmup=0, seed, **kwargs):

for _ in range(warmup):
f(*args, **kwargs)

num_frames = None
times = []
for _ in range(num_exp):
torch.manual_seed(seed)
start = perf_counter_ns()
clips = f(*args, **kwargs)
end = perf_counter_ns()
Expand Down Expand Up @@ -54,8 +59,7 @@ def sample(decoder, sampler, **kwargs):
)


def run_sampler_benchmarks(device, video):
NUM_EXP = 30
def run_sampler_benchmarks(device, video, num_experiments, torch_seed):

for num_clips in (1, 50):
print("-" * 10)
Expand All @@ -68,8 +72,9 @@ def run_sampler_benchmarks(device, video):
decoder,
clips_at_random_indices,
num_clips=num_clips,
num_exp=NUM_EXP,
num_exp=num_experiments,
warmup=2,
seed=torch_seed,
)
report_stats(times, num_frames, unit="ms")

Expand All @@ -79,8 +84,9 @@ def run_sampler_benchmarks(device, video):
decoder,
clips_at_regular_indices,
num_clips=num_clips,
num_exp=NUM_EXP,
num_exp=num_experiments,
warmup=2,
seed=torch_seed,
)
report_stats(times, num_frames, unit="ms")

Expand All @@ -90,8 +96,9 @@ def run_sampler_benchmarks(device, video):
decoder,
clips_at_random_timestamps,
num_clips=num_clips,
num_exp=NUM_EXP,
num_exp=num_experiments,
warmup=2,
seed=torch_seed,
)
report_stats(times, num_frames, unit="ms")

Expand All @@ -102,19 +109,23 @@ def run_sampler_benchmarks(device, video):
decoder,
clips_at_regular_timestamps,
seconds_between_clip_starts=seconds_between_clip_starts,
num_exp=NUM_EXP,
num_exp=num_experiments,
warmup=2,
seed=torch_seed,
)
report_stats(times, num_frames, unit="ms")


def main():
DEFAULT_VIDEO_PATH = Path(__file__).parent / "../../test/resources/nasa_13013.mp4"
parser = argparse.ArgumentParser()
parser.add_argument("--device", type=str, default="cpu")
parser.add_argument("--video", type=str, default=str(DEFAULT_VIDEO_PATH))
parser.add_argument("--num_experiments", type=int, default=DEFAULT_NUM_EXP)
parser.add_argument("--torch_seed", type=int, default=DEFAULT_TORCH_SEED)
args = parser.parse_args()
run_sampler_benchmarks(args.device, args.video)
run_sampler_benchmarks(
args.device, args.video, args.num_experiments, args.torch_seed
)


if __name__ == "__main__":
Expand Down

0 comments on commit c66afd7

Please sign in to comment.