From 911a3bcff650dadcbbee69fcfe7c8c14e4814bf3 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Fri, 20 Dec 2024 13:43:51 -0800 Subject: [PATCH] Consolidated logic for timestamp batch. Big perf win. --- .../decoders/benchmark_decoders_library.py | 2 +- .../decoders/_core/VideoDecoder.cpp | 95 +++++++++---------- src/torchcodec/decoders/_core/VideoDecoder.h | 3 + 3 files changed, 50 insertions(+), 50 deletions(-) diff --git a/benchmarks/decoders/benchmark_decoders_library.py b/benchmarks/decoders/benchmark_decoders_library.py index e8fcd7d3..b0b69dfd 100644 --- a/benchmarks/decoders/benchmark_decoders_library.py +++ b/benchmarks/decoders/benchmark_decoders_library.py @@ -349,7 +349,7 @@ def decode_and_resize(self, video_file, pts_list, height, width, device): class TorchCodecPublicNonBatch(AbstractDecoder): - def __init__(self, num_ffmpeg_threads=None, device="cpu", seek_mode="exact"): + def __init__(self, num_ffmpeg_threads=None, device="cpu", seek_mode="approximate"): self._num_ffmpeg_threads = num_ffmpeg_threads self._device = device self._seek_mode = seek_mode diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 9cc1dd46..07e25739 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1157,6 +1157,28 @@ int64_t VideoDecoder::getFramesSize( } } +double VideoDecoder::getMinSeconds(const StreamMetadata& streamMetadata) { + switch (seekMode_) { + case SeekMode::exact: + return streamMetadata.minPtsSecondsFromScan.value(); + case SeekMode::approximate: + return 0; + default: + throw std::runtime_error("Unknown SeekMode"); + } +} + +double VideoDecoder::getMaxSeconds(const StreamMetadata& streamMetadata) { + switch (seekMode_) { + case SeekMode::exact: + return streamMetadata.maxPtsSecondsFromScan.value(); + case SeekMode::approximate: + return streamMetadata.durationSeconds.value(); + default: + throw std::runtime_error("Unknown SeekMode"); + } +} + VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndexInternal( int streamIndex, int64_t frameIndex, @@ -1238,24 +1260,25 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesPlayedByTimestamps( const auto& streamMetadata = containerMetadata_.streams[streamIndex]; const auto& stream = streams_[streamIndex]; - if (seekMode_ == SeekMode::exact) { - double minSeconds = streamMetadata.minPtsSecondsFromScan.value(); - double maxSeconds = streamMetadata.maxPtsSecondsFromScan.value(); + double minSeconds = getMinSeconds(streamMetadata); + double maxSeconds = getMaxSeconds(streamMetadata); - // The frame played at timestamp t and the one played at timestamp `t + - // eps` are probably the same frame, with the same index. The easiest way to - // avoid decoding that unique frame twice is to convert the input timestamps - // to indices, and leverage the de-duplication logic of getFramesAtIndices. + // The frame played at timestamp t and the one played at timestamp `t + + // eps` are probably the same frame, with the same index. The easiest way to + // avoid decoding that unique frame twice is to convert the input timestamps + // to indices, and leverage the de-duplication logic of getFramesAtIndices. - std::vector frameIndices(timestamps.size()); - for (auto i = 0; i < timestamps.size(); ++i) { - auto frameSeconds = timestamps[i]; - TORCH_CHECK( - frameSeconds >= minSeconds && frameSeconds < maxSeconds, - "frame pts is " + std::to_string(frameSeconds) + - "; must be in range [" + std::to_string(minSeconds) + ", " + - std::to_string(maxSeconds) + ")."); + std::vector frameIndices(timestamps.size()); + for (auto i = 0; i < timestamps.size(); ++i) { + auto frameSeconds = timestamps[i]; + TORCH_CHECK( + frameSeconds >= minSeconds && frameSeconds < maxSeconds, + "frame pts is " + std::to_string(frameSeconds) + + "; must be in range [" + std::to_string(minSeconds) + ", " + + std::to_string(maxSeconds) + ")."); + int64_t frameIndex = -1; + if (seekMode_ == SeekMode::exact) { auto it = std::lower_bound( stream.allFrames.begin(), stream.allFrames.end(), @@ -1263,43 +1286,17 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesPlayedByTimestamps( [&stream](const FrameInfo& info, double frameSeconds) { return ptsToSeconds(info.nextPts, stream.timeBase) <= frameSeconds; }); - int64_t frameIndex = it - stream.allFrames.begin(); - frameIndices[i] = frameIndex; - } - - return getFramesAtIndices(streamIndex, frameIndices); - - } else if (seekMode_ == SeekMode::approximate) { - double minSeconds = 0; - double maxSeconds = streamMetadata.durationSeconds.value(); - - // TODO: Figure out if we can be smarter than just iterating over the - // timestamps one-by-one. - - const auto& streamMetadata = containerMetadata_.streams[streamIndex]; - const auto& options = stream.options; - BatchDecodedOutput output(timestamps.size(), options, streamMetadata); - - for (auto i = 0; i < timestamps.size(); ++i) { - auto frameSeconds = timestamps[i]; - TORCH_CHECK( - frameSeconds >= minSeconds && frameSeconds < maxSeconds, - "frame pts is " + std::to_string(frameSeconds) + - "; must be in range [" + std::to_string(minSeconds) + ", " + - std::to_string(maxSeconds) + ")."); - - DecodedOutput singleOut = getFramePlayedAtTimestampNoDemuxInternal( - frameSeconds, output.frames[i]); - output.ptsSeconds[i] = singleOut.ptsSeconds; - output.durationSeconds[i] = singleOut.durationSeconds; + frameIndex = it - stream.allFrames.begin(); + } else if (seekMode_ == SeekMode::approximate) { + frameIndex = std::floor(frameSeconds * streamMetadata.averageFps.value()); + } else { + throw std::runtime_error("Unknown SeekMode"); } - output.frames = maybePermuteHWC2CHW(streamIndex, output.frames); - return output; - - } else { - throw std::runtime_error("Unknown SeekMode"); + frameIndices[i] = frameIndex; } + + return getFramesAtIndices(streamIndex, frameIndices); } VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange( diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 5726b8c5..68bb4305 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -398,6 +398,9 @@ class VideoDecoder { const StreamMetadata& streamMetadata, int64_t frameIndex); + double getMinSeconds(const StreamMetadata& streamMetadata); + double getMaxSeconds(const StreamMetadata& streamMetadata); + void createSwsContext( StreamInfo& streamInfo, const DecodedFrameContext& frameContext,