From d37d3a561cdd9fc36d7d03fc64a7d9f29aceac71 Mon Sep 17 00:00:00 2001
From: Scott Schneider
Date: Wed, 18 Dec 2024 11:21:00 -0500
Subject: [PATCH] Add public, nonbatch option for benchmarking (#438)

---
 benchmarks/decoders/benchmark_decoders.py    |  4 ++
 .../decoders/benchmark_decoders_library.py   | 59 +++++++++++++++++++
 2 files changed, 63 insertions(+)

diff --git a/benchmarks/decoders/benchmark_decoders.py b/benchmarks/decoders/benchmark_decoders.py
index 10539901..c01fd6fa 100644
--- a/benchmarks/decoders/benchmark_decoders.py
+++ b/benchmarks/decoders/benchmark_decoders.py
@@ -26,6 +26,7 @@
     TorchCodecCoreCompiled,
     TorchCodecCoreNonBatch,
     TorchCodecPublic,
+    TorchCodecPublicNonBatch,
     TorchVision,
 )
 
@@ -49,6 +50,9 @@ class DecoderKind:
         "TorchCodecCoreCompiled", TorchCodecCoreCompiled
     ),
     "torchcodec_public": DecoderKind("TorchCodecPublic", TorchCodecPublic),
+    "torchcodec_public_nonbatch": DecoderKind(
+        "TorchCodecPublicNonBatch", TorchCodecPublicNonBatch
+    ),
     "torchvision": DecoderKind(
         # We don't compare against TorchVision's "pyav" backend because it doesn't support
         # accurate seeks.
diff --git a/benchmarks/decoders/benchmark_decoders_library.py b/benchmarks/decoders/benchmark_decoders_library.py
index ced55ab6..be60d8e8 100644
--- a/benchmarks/decoders/benchmark_decoders_library.py
+++ b/benchmarks/decoders/benchmark_decoders_library.py
@@ -342,6 +342,65 @@ def decode_and_resize(self, video_file, pts_list, height, width, device):
         return frames
 
 
+class TorchCodecPublicNonBatch(AbstractDecoder):
+    def __init__(self, num_ffmpeg_threads=None, device="cpu"):
+        self._num_ffmpeg_threads = num_ffmpeg_threads
+        self._device = device
+
+        from torchvision.transforms import v2 as transforms_v2
+
+        self.transforms_v2 = transforms_v2
+
+    def decode_frames(self, video_file, pts_list):
+        num_ffmpeg_threads = (
+            int(self._num_ffmpeg_threads) if self._num_ffmpeg_threads else 0
+        )
+        decoder = VideoDecoder(
+            video_file, num_ffmpeg_threads=num_ffmpeg_threads, device=self._device
+        )
+
+        frames = []
+        for pts in pts_list:
+            frame = decoder.get_frame_played_at(pts)
+            frames.append(frame)
+        return frames
+
+    def decode_first_n_frames(self, video_file, n):
+        num_ffmpeg_threads = (
+            int(self._num_ffmpeg_threads) if self._num_ffmpeg_threads else 0
+        )
+        decoder = VideoDecoder(
+            video_file, num_ffmpeg_threads=num_ffmpeg_threads, device=self._device
+        )
+        frames = []
+        count = 0
+        for frame in decoder:
+            frames.append(frame)
+            count += 1
+            if count == n:
+                break
+        return frames
+
+    def decode_and_resize(self, video_file, pts_list, height, width, device):
+        num_ffmpeg_threads = (
+            int(self._num_ffmpeg_threads) if self._num_ffmpeg_threads else 1
+        )
+        decoder = VideoDecoder(
+            video_file, num_ffmpeg_threads=num_ffmpeg_threads, device=self._device
+        )
+
+        frames = []
+        for pts in pts_list:
+            frame = decoder.get_frame_played_at(pts)
+            frames.append(frame)
+
+        frames = [
+            self.transforms_v2.functional.resize(frame.to(device), (height, width))
+            for frame in frames
+        ]
+        return frames
+
+
 @torch.compile(fullgraph=True, backend="eager")
 def compiled_seek_and_next(decoder, pts):
     seek_to_pts(decoder, pts)
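
Usage sketch (not part of the patch, illustrative only): the new TorchCodecPublicNonBatch decoder added to benchmarks/decoders/benchmark_decoders_library.py can also be exercised directly, outside the benchmark harness. The video path below is a placeholder, and the import assumes the interpreter is started from benchmarks/decoders/ so that benchmark_decoders_library is importable.

    # Hypothetical driver snippet; "video.mp4" is a placeholder path.
    from benchmark_decoders_library import TorchCodecPublicNonBatch

    decoder = TorchCodecPublicNonBatch(num_ffmpeg_threads=1, device="cpu")

    # Decode frames one at a time by presentation timestamp (in seconds),
    # rather than through a single batched call -- this is what makes the
    # benchmark "nonbatch".
    frames = decoder.decode_frames("video.mp4", pts_list=[0.0, 1.0, 2.0])

    # Decode the first N frames by iterating the underlying VideoDecoder.
    first_ten = decoder.decode_first_n_frames("video.mp4", n=10)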