diff --git a/benchmarks/samplers/benchmark_samplers.py b/benchmarks/samplers/benchmark_samplers.py index b9b27bc2..bf1a780f 100644 --- a/benchmarks/samplers/benchmark_samplers.py +++ b/benchmarks/samplers/benchmark_samplers.py @@ -11,8 +11,12 @@ clips_at_regular_timestamps, ) +DEFAULT_VIDEO_PATH = Path(__file__).parent / "../../test/resources/nasa_13013.mp4" +DEFAULT_NUM_EXP = 30 +DEFAULT_TORCH_SEED = 0 -def bench(f, *args, num_exp=100, warmup=0, **kwargs): + +def bench(f, *args, num_exp, warmup=0, seed, **kwargs): for _ in range(warmup): f(*args, **kwargs) @@ -20,6 +24,7 @@ def bench(f, *args, num_exp=100, warmup=0, **kwargs): num_frames = None times = [] for _ in range(num_exp): + torch.manual_seed(seed) start = perf_counter_ns() clips = f(*args, **kwargs) end = perf_counter_ns() @@ -54,8 +59,7 @@ def sample(decoder, sampler, **kwargs): ) -def run_sampler_benchmarks(device, video): - NUM_EXP = 30 +def run_sampler_benchmarks(device, video, num_experiments, torch_seed): for num_clips in (1, 50): print("-" * 10) @@ -68,8 +72,9 @@ def run_sampler_benchmarks(device, video): decoder, clips_at_random_indices, num_clips=num_clips, - num_exp=NUM_EXP, + num_exp=num_experiments, warmup=2, + seed=torch_seed, ) report_stats(times, num_frames, unit="ms") @@ -79,8 +84,9 @@ def run_sampler_benchmarks(device, video): decoder, clips_at_regular_indices, num_clips=num_clips, - num_exp=NUM_EXP, + num_exp=num_experiments, warmup=2, + seed=torch_seed, ) report_stats(times, num_frames, unit="ms") @@ -90,8 +96,9 @@ def run_sampler_benchmarks(device, video): decoder, clips_at_random_timestamps, num_clips=num_clips, - num_exp=NUM_EXP, + num_exp=num_experiments, warmup=2, + seed=torch_seed, ) report_stats(times, num_frames, unit="ms") @@ -102,19 +109,23 @@ def run_sampler_benchmarks(device, video): decoder, clips_at_regular_timestamps, seconds_between_clip_starts=seconds_between_clip_starts, - num_exp=NUM_EXP, + num_exp=num_experiments, warmup=2, + seed=torch_seed, ) report_stats(times, num_frames, unit="ms") def main(): - DEFAULT_VIDEO_PATH = Path(__file__).parent / "../../test/resources/nasa_13013.mp4" parser = argparse.ArgumentParser() parser.add_argument("--device", type=str, default="cpu") parser.add_argument("--video", type=str, default=str(DEFAULT_VIDEO_PATH)) + parser.add_argument("--num_experiments", type=int, default=DEFAULT_NUM_EXP) + parser.add_argument("--torch_seed", type=int, default=DEFAULT_TORCH_SEED) args = parser.parse_args() - run_sampler_benchmarks(args.device, args.video) + run_sampler_benchmarks( + args.device, args.video, args.num_experiments, args.torch_seed + ) if __name__ == "__main__":