From 0ad3f010df90115aadf38138867cdd80ad0c2e20 Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Wed, 22 Jan 2025 11:41:44 -0500 Subject: [PATCH] Implement approximate mode (#440) Co-authored-by: Nicolas Hug --- .../decoders/benchmark_decoders_library.py | 54 ++- .../decoders/_core/VideoDecoder.cpp | 324 ++++++++++++------ src/torchcodec/decoders/_core/VideoDecoder.h | 46 ++- .../decoders/_core/VideoDecoderOps.cpp | 53 ++- .../decoders/_core/VideoDecoderOps.h | 13 +- src/torchcodec/decoders/_core/_metadata.py | 57 ++- .../decoders/_core/video_decoder_ops.py | 12 +- src/torchcodec/decoders/_video_decoder.py | 25 +- src/torchcodec/samplers/_time_based.py | 11 +- test/decoders/VideoDecoderTest.cpp | 6 +- test/decoders/test_metadata.py | 17 +- test/decoders/test_video_decoder.py | 112 +++--- test/decoders/test_video_decoder_ops.py | 15 - test/samplers/test_samplers.py | 15 +- 14 files changed, 514 insertions(+), 246 deletions(-) diff --git a/benchmarks/decoders/benchmark_decoders_library.py b/benchmarks/decoders/benchmark_decoders_library.py index be60d8e8..b0b69dfd 100644 --- a/benchmarks/decoders/benchmark_decoders_library.py +++ b/benchmarks/decoders/benchmark_decoders_library.py @@ -22,7 +22,6 @@ get_frames_by_pts, get_json_metadata, get_next_frame, - scan_all_streams_to_update_metadata, seek_to_pts, ) @@ -154,8 +153,7 @@ def __init__(self, num_threads=None, color_conversion_library=None, device="cpu" self._device = device def decode_frames(self, video_file, pts_list): - decoder = create_from_file(video_file) - scan_all_streams_to_update_metadata(decoder) + decoder = create_from_file(video_file, seek_mode="exact") _add_video_stream( decoder, num_threads=self._num_threads, @@ -170,7 +168,7 @@ def decode_frames(self, video_file, pts_list): return frames def decode_first_n_frames(self, video_file, n): - decoder = create_from_file(video_file) + decoder = create_from_file(video_file, seek_mode="approximate") _add_video_stream( decoder, num_threads=self._num_threads, @@ -197,7 +195,7 @@ def __init__(self, num_threads=None, color_conversion_library=None, device="cpu" self.transforms_v2 = transforms_v2 def decode_frames(self, video_file, pts_list): - decoder = create_from_file(video_file) + decoder = create_from_file(video_file, seek_mode="approximate") num_threads = int(self._num_threads) if self._num_threads else 0 _add_video_stream( decoder, @@ -216,7 +214,7 @@ def decode_frames(self, video_file, pts_list): def decode_first_n_frames(self, video_file, n): num_threads = int(self._num_threads) if self._num_threads else 0 - decoder = create_from_file(video_file) + decoder = create_from_file(video_file, seek_mode="approximate") _add_video_stream( decoder, num_threads=num_threads, @@ -233,7 +231,7 @@ def decode_first_n_frames(self, video_file, n): def decode_and_resize(self, video_file, pts_list, height, width, device): num_threads = int(self._num_threads) if self._num_threads else 1 - decoder = create_from_file(video_file) + decoder = create_from_file(video_file, seek_mode="approximate") _add_video_stream( decoder, num_threads=num_threads, @@ -263,8 +261,7 @@ def __init__(self, num_threads=None, color_conversion_library=None, device="cpu" self._device = device def decode_frames(self, video_file, pts_list): - decoder = create_from_file(video_file) - scan_all_streams_to_update_metadata(decoder) + decoder = create_from_file(video_file, seek_mode="exact") _add_video_stream( decoder, num_threads=self._num_threads, @@ -279,8 +276,7 @@ def decode_frames(self, video_file, pts_list): return frames def 
decode_first_n_frames(self, video_file, n): - decoder = create_from_file(video_file) - scan_all_streams_to_update_metadata(decoder) + decoder = create_from_file(video_file, seek_mode="exact") _add_video_stream( decoder, num_threads=self._num_threads, @@ -297,9 +293,10 @@ def decode_first_n_frames(self, video_file, n): class TorchCodecPublic(AbstractDecoder): - def __init__(self, num_ffmpeg_threads=None, device="cpu"): + def __init__(self, num_ffmpeg_threads=None, device="cpu", seek_mode="exact"): self._num_ffmpeg_threads = num_ffmpeg_threads self._device = device + self._seek_mode = seek_mode from torchvision.transforms import v2 as transforms_v2 @@ -310,7 +307,10 @@ def decode_frames(self, video_file, pts_list): int(self._num_ffmpeg_threads) if self._num_ffmpeg_threads else 0 ) decoder = VideoDecoder( - video_file, num_ffmpeg_threads=num_ffmpeg_threads, device=self._device + video_file, + num_ffmpeg_threads=num_ffmpeg_threads, + device=self._device, + seek_mode=self._seek_mode, ) return decoder.get_frames_played_at(pts_list) @@ -319,7 +319,10 @@ def decode_first_n_frames(self, video_file, n): int(self._num_ffmpeg_threads) if self._num_ffmpeg_threads else 0 ) decoder = VideoDecoder( - video_file, num_ffmpeg_threads=num_ffmpeg_threads, device=self._device + video_file, + num_ffmpeg_threads=num_ffmpeg_threads, + device=self._device, + seek_mode=self._seek_mode, ) frames = [] count = 0 @@ -335,7 +338,10 @@ def decode_and_resize(self, video_file, pts_list, height, width, device): int(self._num_ffmpeg_threads) if self._num_ffmpeg_threads else 1 ) decoder = VideoDecoder( - video_file, num_ffmpeg_threads=num_ffmpeg_threads, device=self._device + video_file, + num_ffmpeg_threads=num_ffmpeg_threads, + device=self._device, + seek_mode=self._seek_mode, ) frames = decoder.get_frames_played_at(pts_list) frames = self.transforms_v2.functional.resize(frames.data, (height, width)) @@ -343,9 +349,10 @@ def decode_and_resize(self, video_file, pts_list, height, width, device): class TorchCodecPublicNonBatch(AbstractDecoder): - def __init__(self, num_ffmpeg_threads=None, device="cpu"): + def __init__(self, num_ffmpeg_threads=None, device="cpu", seek_mode="approximate"): self._num_ffmpeg_threads = num_ffmpeg_threads self._device = device + self._seek_mode = seek_mode from torchvision.transforms import v2 as transforms_v2 @@ -356,7 +363,10 @@ def decode_frames(self, video_file, pts_list): int(self._num_ffmpeg_threads) if self._num_ffmpeg_threads else 0 ) decoder = VideoDecoder( - video_file, num_ffmpeg_threads=num_ffmpeg_threads, device=self._device + video_file, + num_ffmpeg_threads=num_ffmpeg_threads, + device=self._device, + seek_mode=self._seek_mode, ) frames = [] @@ -370,7 +380,10 @@ def decode_first_n_frames(self, video_file, n): int(self._num_ffmpeg_threads) if self._num_ffmpeg_threads else 0 ) decoder = VideoDecoder( - video_file, num_ffmpeg_threads=num_ffmpeg_threads, device=self._device + video_file, + num_ffmpeg_threads=num_ffmpeg_threads, + device=self._device, + seek_mode=self._seek_mode, ) frames = [] count = 0 @@ -386,7 +399,10 @@ def decode_and_resize(self, video_file, pts_list, height, width, device): int(self._num_ffmpeg_threads) if self._num_ffmpeg_threads else 1 ) decoder = VideoDecoder( - video_file, num_ffmpeg_threads=num_ffmpeg_threads, device=self._device + video_file, + num_ffmpeg_threads=num_ffmpeg_threads, + device=self._device, + seek_mode=self._seek_mode, ) frames = [] diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 
528ef8cd..3899bc62 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -218,14 +218,16 @@ bool VideoDecoder::DecodedFrameContext::operator!=( return !(*this == other); } -VideoDecoder::VideoDecoder(const std::string& videoFilePath) { +VideoDecoder::VideoDecoder(const std::string& videoFilePath, SeekMode seekMode) + : seekMode_(seekMode) { AVInput input = createAVFormatContextFromFilePath(videoFilePath); formatContext_ = std::move(input.formatContext); initializeDecoder(); } -VideoDecoder::VideoDecoder(const void* buffer, size_t length) { +VideoDecoder::VideoDecoder(const void* buffer, size_t length, SeekMode seekMode) + : seekMode_(seekMode) { TORCH_CHECK(buffer != nullptr, "Video buffer cannot be nullptr!"); AVInput input = createAVFormatContextFromBuffer(buffer, length); @@ -306,18 +308,26 @@ void VideoDecoder::initializeDecoder() { containerMetadata_.bestAudioStreamIndex = bestAudioStream; } + if (seekMode_ == SeekMode::exact) { + scanFileAndUpdateMetadataAndIndex(); + } + initialized_ = true; } std::unique_ptr<VideoDecoder> VideoDecoder::createFromFilePath( - const std::string& videoFilePath) { - return std::unique_ptr<VideoDecoder>(new VideoDecoder(videoFilePath)); + const std::string& videoFilePath, + SeekMode seekMode) { + return std::unique_ptr<VideoDecoder>( + new VideoDecoder(videoFilePath, seekMode)); } std::unique_ptr<VideoDecoder> VideoDecoder::createFromBuffer( const void* buffer, - size_t length) { - return std::unique_ptr<VideoDecoder>(new VideoDecoder(buffer, length)); + size_t length, + SeekMode seekMode) { + return std::unique_ptr<VideoDecoder>( + new VideoDecoder(buffer, length, seekMode)); } void VideoDecoder::createFilterGraph( @@ -441,6 +451,7 @@ void VideoDecoder::addVideoStreamDecoder( " is already active."); } TORCH_CHECK(formatContext_.get() != nullptr); + AVCodecOnlyUseForCallingAVFindBestStream codec = nullptr; int streamNumber = av_find_best_stream( formatContext_.get(), @@ -453,10 +464,12 @@ throw std::invalid_argument("No valid stream found in input file."); } TORCH_CHECK(codec != nullptr); + StreamInfo& streamInfo = streams_[streamNumber]; streamInfo.streamIndex = streamNumber; streamInfo.timeBase = formatContext_->streams[streamNumber]->time_base; streamInfo.stream = formatContext_->streams[streamNumber]; + if (streamInfo.stream->codecpar->codec_type != AVMEDIA_TYPE_VIDEO) { throw std::invalid_argument( "Stream with index " + std::to_string(streamNumber) + @@ -469,12 +482,23 @@ .value_or(codec)); } + StreamMetadata& streamMetadata = containerMetadata_.streams[streamNumber]; + if (seekMode_ == SeekMode::approximate && + !streamMetadata.averageFps.has_value()) { + throw std::runtime_error( + "Seek mode is approximate, but stream " + std::to_string(streamNumber) + + " does not have an average fps in its metadata."); + } + AVCodecContext* codecContext = avcodec_alloc_context3(codec); - codecContext->thread_count = options.ffmpegThreadCount.value_or(0); TORCH_CHECK(codecContext != nullptr); + codecContext->thread_count = options.ffmpegThreadCount.value_or(0); streamInfo.codecContext.reset(codecContext); + int retVal = avcodec_parameters_to_context( streamInfo.codecContext.get(), streamInfo.stream->codecpar); + TORCH_CHECK_EQ(retVal, AVSUCCESS); + if (options.device.type() == torch::kCPU) { // No more initialization needed for CPU.
} else if (options.device.type() == torch::kCUDA) { @@ -482,16 +506,16 @@ void VideoDecoder::addVideoStreamDecoder( } else { TORCH_CHECK(false, "Invalid device type: " + options.device.str()); } - TORCH_CHECK_EQ(retVal, AVSUCCESS); + retVal = avcodec_open2(streamInfo.codecContext.get(), codec, nullptr); if (retVal < AVSUCCESS) { throw std::invalid_argument(getFFMPEGErrorStringFromErrorCode(retVal)); } + codecContext->time_base = streamInfo.stream->time_base; activeStreamIndices_.insert(streamNumber); updateMetadataWithCodecContext(streamInfo.streamIndex, codecContext); streamInfo.options = options; - int width = options.width.value_or(codecContext->width); // By default, we want to use swscale for color conversion because it is // faster. However, it has width requirements, so we may need to fall back @@ -500,6 +524,7 @@ void VideoDecoder::addVideoStreamDecoder( // swscale's width requirements to be violated. We don't expose the ability to // choose color conversion library publicly; we only use this ability // internally. + int width = options.width.value_or(codecContext->width); auto defaultLibrary = getDefaultColorConversionLibrary(width); streamInfo.colorConversionLibrary = options.colorConversionLibrary.value_or(defaultLibrary); @@ -544,25 +569,32 @@ int VideoDecoder::getKeyFrameIndexForPtsUsingScannedIndex( } void VideoDecoder::scanFileAndUpdateMetadataAndIndex() { - if (scanned_all_streams_) { + if (scannedAllStreams_) { return; } + while (true) { + // Get the next packet. UniqueAVPacket packet(av_packet_alloc()); int ffmpegStatus = av_read_frame(formatContext_.get(), packet.get()); + if (ffmpegStatus == AVERROR_EOF) { break; } + if (ffmpegStatus != AVSUCCESS) { throw std::runtime_error( "Failed to read frame from input file: " + getFFMPEGErrorStringFromErrorCode(ffmpegStatus)); } - int streamIndex = packet->stream_index; if (packet->flags & AV_PKT_FLAG_DISCARD) { continue; } + + // We got a valid packet. Let's figure out what stream it belongs to and + // record its relevant metadata. + int streamIndex = packet->stream_index; auto& streamMetadata = containerMetadata_.streams[streamIndex]; streamMetadata.minPtsFromScan = std::min( streamMetadata.minPtsFromScan.value_or(INT64_MAX), packet->pts); @@ -572,18 +604,24 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() { streamMetadata.numFramesFromScan = streamMetadata.numFramesFromScan.value_or(0) + 1; - FrameInfo frameInfo; - frameInfo.pts = packet->pts; - + // Note that we set the other value in this struct, nextPts, only after + // we have scanned all packets and sorted by pts. + FrameInfo frameInfo = {packet->pts}; if (packet->flags & AV_PKT_FLAG_KEY) { streams_[streamIndex].keyFrames.push_back(frameInfo); } streams_[streamIndex].allFrames.push_back(frameInfo); } + + // Set all per-stream metadata that requires knowing the content of all + // packets. for (size_t streamIndex = 0; streamIndex < containerMetadata_.streams.size(); ++streamIndex) { auto& streamMetadata = containerMetadata_.streams[streamIndex]; auto avStream = formatContext_->streams[streamIndex]; + + streamMetadata.numFramesFromScan = streams_[streamIndex].allFrames.size(); + if (streamMetadata.minPtsFromScan.has_value()) { streamMetadata.minPtsSecondsFromScan = *streamMetadata.minPtsFromScan * av_q2d(avStream->time_base); @@ -593,6 +631,8 @@ void VideoDecoder::scanFileAndUpdateMetadataAndIndex() { *streamMetadata.maxPtsFromScan * av_q2d(avStream->time_base); } } + + // Reset the seek-cursor back to the beginning. 
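+ // (The scan above consumed every packet with av_read_frame, so without + // this rewind, any subsequent decode would start at end-of-file.)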
int ffmpegStatus = avformat_seek_file(formatContext_.get(), 0, INT64_MIN, 0, 0, 0); if (ffmpegStatus < 0) { @@ -600,6 +640,8 @@ "Could not seek file to pts=0: " + getFFMPEGErrorStringFromErrorCode(ffmpegStatus)); } + + // Sort all frames by their pts. for (auto& [streamIndex, streamInfo] : streams_) { std::sort( streamInfo.keyFrames.begin(), @@ -620,7 +662,8 @@ } } } - scanned_all_streams_ = true; + + scannedAllStreams_ = true; } int VideoDecoder::getKeyFrameIndexForPts( @@ -1019,6 +1062,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFramePlayedAtTimestampNoDemux( break; } } + setCursorPtsInSeconds(seconds); RawDecodedOutput rawOutput = getDecodedOutputWithFilter( [seconds, this](int frameStreamIndex, AVFrame* frame) { @@ -1038,8 +1082,9 @@ } return seconds >= frameStartTime && seconds < frameEndTime; }); + // Convert the frame to tensor. - auto output = convertAVFrameToDecodedOutput(rawOutput); + DecodedOutput output = convertAVFrameToDecodedOutput(rawOutput); output.frame = maybePermuteHWC2CHW(output.streamIndex, output.frame); return output; } @@ -1058,20 +1103,21 @@ void VideoDecoder::validateUserProvidedStreamIndex(int streamIndex) { } void VideoDecoder::validateScannedAllStreams(const std::string& msg) { - if (!scanned_all_streams_) { + if (!scannedAllStreams_) { throw std::runtime_error( "Must scan all streams to update metadata before calling " + msg); } } void VideoDecoder::validateFrameIndex( - const StreamInfo& stream, + const StreamMetadata& streamMetadata, int64_t frameIndex) { + int64_t numFrames = getNumFrames(streamMetadata); TORCH_CHECK( - frameIndex >= 0 && frameIndex < static_cast<int64_t>(stream.allFrames.size()), + frameIndex >= 0 && frameIndex < numFrames, "Invalid frame index=" + std::to_string(frameIndex) + - " for streamIndex=" + std::to_string(stream.streamIndex) + - " numFrames=" + std::to_string(stream.allFrames.size())); + " for streamIndex=" + std::to_string(streamMetadata.streamIndex) + + " numFrames=" + std::to_string(numFrames)); } VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex( @@ -1082,26 +1128,119 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex( return output; } +int64_t VideoDecoder::getPts( + const StreamInfo& streamInfo, + const StreamMetadata& streamMetadata, + int64_t frameIndex) { + switch (seekMode_) { + case SeekMode::exact: + return streamInfo.allFrames[frameIndex].pts; + case SeekMode::approximate: + return secondsToClosestPts( + frameIndex / streamMetadata.averageFps.value(), streamInfo.timeBase); + default: + throw std::runtime_error("Unknown SeekMode"); + } +} + +int64_t VideoDecoder::getNumFrames(const StreamMetadata& streamMetadata) { + switch (seekMode_) { + case SeekMode::exact: + return streamMetadata.numFramesFromScan.value(); + case SeekMode::approximate: + return streamMetadata.numFrames.value(); + default: + throw std::runtime_error("Unknown SeekMode"); + } +} + +double VideoDecoder::getMinSeconds(const StreamMetadata& streamMetadata) { + switch (seekMode_) { + case SeekMode::exact: + return streamMetadata.minPtsSecondsFromScan.value(); + case SeekMode::approximate: + return 0; + default: + throw std::runtime_error("Unknown SeekMode"); + } +} + +double VideoDecoder::getMaxSeconds(const StreamMetadata& streamMetadata) { + switch (seekMode_) { + case SeekMode::exact: + return streamMetadata.maxPtsSecondsFromScan.value(); + case 
SeekMode::approximate: + return streamMetadata.durationSeconds.value(); + default: + throw std::runtime_error("Unknown SeekMode"); + } +} + +int64_t VideoDecoder::secondsToIndexLowerBound( + double seconds, + const StreamInfo& streamInfo, + const StreamMetadata& streamMetadata) { + switch (seekMode_) { + case SeekMode::exact: { + auto frame = std::lower_bound( + streamInfo.allFrames.begin(), + streamInfo.allFrames.end(), + seconds, + [&streamInfo](const FrameInfo& info, double start) { + return ptsToSeconds(info.nextPts, streamInfo.timeBase) <= start; + }); + + return frame - streamInfo.allFrames.begin(); + } + case SeekMode::approximate: + return std::floor(seconds * streamMetadata.averageFps.value()); + default: + throw std::runtime_error("Unknown SeekMode"); + } +} + +int64_t VideoDecoder::secondsToIndexUpperBound( + double seconds, + const StreamInfo& streamInfo, + const StreamMetadata& streamMetadata) { + switch (seekMode_) { + case SeekMode::exact: { + auto frame = std::upper_bound( + streamInfo.allFrames.begin(), + streamInfo.allFrames.end(), + seconds, + [&streamInfo](double stop, const FrameInfo& info) { + return stop <= ptsToSeconds(info.pts, streamInfo.timeBase); + }); + + return frame - streamInfo.allFrames.begin(); + } + case SeekMode::approximate: + return std::ceil(seconds * streamMetadata.averageFps.value()); + default: + throw std::runtime_error("Unknown SeekMode"); + } +} + VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndexInternal( int streamIndex, int64_t frameIndex, std::optional<torch::Tensor> preAllocatedOutputTensor) { validateUserProvidedStreamIndex(streamIndex); - validateScannedAllStreams("getFrameAtIndex"); - const auto& stream = streams_[streamIndex]; - validateFrameIndex(stream, frameIndex); + const auto& streamInfo = streams_[streamIndex]; + const auto& streamMetadata = containerMetadata_.streams[streamIndex]; + validateFrameIndex(streamMetadata, frameIndex); - int64_t pts = stream.allFrames[frameIndex].pts; - setCursorPtsInSeconds(ptsToSeconds(pts, stream.timeBase)); - return getNextFrameOutputNoDemuxInternal(preAllocatedOutputTensor); + int64_t pts = getPts(streamInfo, streamMetadata, frameIndex); + setCursorPtsInSeconds(ptsToSeconds(pts, streamInfo.timeBase)); + return getNextFrameNoDemuxInternal(preAllocatedOutputTensor); } VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( int streamIndex, const std::vector<int64_t>& frameIndices) { validateUserProvidedStreamIndex(streamIndex); - validateScannedAllStreams("getFramesAtIndices"); auto indicesAreSorted = std::is_sorted(frameIndices.begin(), frameIndices.end()); @@ -1131,11 +1270,9 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices( for (size_t f = 0; f < frameIndices.size(); ++f) { auto indexInOutput = indicesAreSorted ? f : argsort[f]; auto indexInVideo = frameIndices[indexInOutput]; - if (indexInVideo < 0 || - indexInVideo >= static_cast<int64_t>(stream.allFrames.size())) { - throw std::runtime_error( - "Invalid frame index=" + std::to_string(indexInVideo)); - } + + validateFrameIndex(streamMetadata, indexInVideo); + if ((f > 0) && (indexInVideo == previousIndexInVideo)) { // Avoid decoding the same frame twice auto previousIndexInOutput = indicesAreSorted ? f - 1 : argsort[f - 1]; @@ -1160,38 +1297,29 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesPlayedByTimestamps( int streamIndex, const std::vector<double>& timestamps) { validateUserProvidedStreamIndex(streamIndex); - validateScannedAllStreams("getFramesPlayedByTimestamps"); + + const auto& streamMetadata = containerMetadata_.streams[streamIndex]; + const auto& stream = streams_[streamIndex]; + + double minSeconds = getMinSeconds(streamMetadata); + double maxSeconds = getMaxSeconds(streamMetadata); // The frame played at timestamp t and the one played at timestamp `t + // eps` are probably the same frame, with the same index. The easiest way to // avoid decoding that unique frame twice is to convert the input timestamps // to indices, and leverage the de-duplication logic of getFramesAtIndices. - // This means this function requires a scan. - // TODO: longer term, we should implement this without requiring a scan - - const auto& streamMetadata = containerMetadata_.streams[streamIndex]; - const auto& stream = streams_[streamIndex]; - double minSeconds = streamMetadata.minPtsSecondsFromScan.value(); - double maxSeconds = streamMetadata.maxPtsSecondsFromScan.value(); std::vector<int64_t> frameIndices(timestamps.size()); for (size_t i = 0; i < timestamps.size(); ++i) { - auto framePts = timestamps[i]; + auto frameSeconds = timestamps[i]; TORCH_CHECK( - framePts >= minSeconds && framePts < maxSeconds, - "frame pts is " + std::to_string(framePts) + "; must be in range [" + - std::to_string(minSeconds) + ", " + std::to_string(maxSeconds) + - ")."); - - auto it = std::lower_bound( - stream.allFrames.begin(), - stream.allFrames.end(), - framePts, - [&stream](const FrameInfo& info, double framePts) { - return ptsToSeconds(info.nextPts, stream.timeBase) <= framePts; - }); - int64_t frameIndex = it - stream.allFrames.begin(); - frameIndices[i] = frameIndex; + frameSeconds >= minSeconds && frameSeconds < maxSeconds, + "frame pts is " + std::to_string(frameSeconds) + + "; must be in range [" + std::to_string(minSeconds) + ", " + + std::to_string(maxSeconds) + ")."); + + frameIndices[i] = + secondsToIndexLowerBound(frameSeconds, stream, streamMetadata); } return getFramesAtIndices(streamIndex, frameIndices); @@ -1203,17 +1331,16 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange( int64_t stop, int64_t step) { validateUserProvidedStreamIndex(streamIndex); - validateScannedAllStreams("getFramesInRange"); const auto& streamMetadata = containerMetadata_.streams[streamIndex]; const auto& stream = streams_[streamIndex]; + int64_t numFrames = getNumFrames(streamMetadata); TORCH_CHECK( start >= 0, "Range start, " + std::to_string(start) + " is less than 0."); TORCH_CHECK( - stop <= static_cast<int64_t>(stream.allFrames.size()), + stop <= numFrames, "Range stop, " + std::to_string(stop) + - ", is more than the number of frames, " + - std::to_string(stream.allFrames.size())); + ", is more than the number of frames, " + std::to_string(numFrames)); TORCH_CHECK( step > 0, "Step must be greater than 0; is " + std::to_string(step)); @@ -1237,26 +1364,13 @@ VideoDecoder::getFramesPlayedByTimestampInRange( double startSeconds, double stopSeconds) { validateUserProvidedStreamIndex(streamIndex); - validateScannedAllStreams("getFramesPlayedByTimestampInRange"); const auto& streamMetadata = containerMetadata_.streams[streamIndex]; - double minSeconds = streamMetadata.minPtsSecondsFromScan.value(); - double maxSeconds = streamMetadata.maxPtsSecondsFromScan.value(); TORCH_CHECK( startSeconds <= stopSeconds, "Start seconds (" 
+ std::to_string(startSeconds) + ") must be less than or equal to stop seconds (" + std::to_string(stopSeconds) + ").");  - TORCH_CHECK( - startSeconds >= minSeconds && startSeconds < maxSeconds, - "Start seconds is " + std::to_string(startSeconds) + - "; must be in range [" + std::to_string(minSeconds) + ", " + - std::to_string(maxSeconds) + ")."); - TORCH_CHECK( - stopSeconds <= maxSeconds, - "Stop seconds (" + std::to_string(stopSeconds) + - "; must be less than or equal to " + std::to_string(maxSeconds) + - ")."); const auto& stream = streams_[streamIndex]; const auto& options = stream.options; @@ -1284,36 +1398,38 @@ VideoDecoder::getFramesPlayedByTimestampInRange( return output; } - // Note that we look at nextPts for a frame, and not its pts or duration. Our - // abstract player displays frames starting at the pts for that frame until - // the pts for the next frame. There are two consequences: + double minSeconds = getMinSeconds(streamMetadata); + double maxSeconds = getMaxSeconds(streamMetadata); + TORCH_CHECK( + startSeconds >= minSeconds && startSeconds < maxSeconds, + "Start seconds is " + std::to_string(startSeconds) + + "; must be in range [" + std::to_string(minSeconds) + ", " + + std::to_string(maxSeconds) + ")."); + TORCH_CHECK( + stopSeconds <= maxSeconds, + "Stop seconds (" + std::to_string(stopSeconds) + + ") must be less than or equal to " + std::to_string(maxSeconds) + + "."); + + // Note that we look at nextPts for a frame, and not its pts or duration. + // Our abstract player displays frames starting at the pts for that frame + // until the pts for the next frame. There are two consequences: // // 1. We ignore the duration for a frame. A frame is played until the // next frame replaces it. This model is robust to durations being 0 or // incorrect; our source of truth is the pts for frames. If duration is - // accurate, the nextPts for a frame would be equivalent to pts + duration. - // 2. In order to establish if the start of an interval maps to a particular - // frame, we need to figure out if it is ordered after the frame's pts, but - // before the next frames's pts. - auto startFrame = std::lower_bound( - stream.allFrames.begin(), - stream.allFrames.end(), - startSeconds, - [&stream](const FrameInfo& info, double start) { - return ptsToSeconds(info.nextPts, stream.timeBase) <= start; - }); - - auto stopFrame = std::upper_bound( - stream.allFrames.begin(), - stream.allFrames.end(), - stopSeconds, - [&stream](double stop, const FrameInfo& info) { - return stop <= ptsToSeconds(info.pts, stream.timeBase); - }); - - int64_t startFrameIndex = startFrame - stream.allFrames.begin(); - int64_t stopFrameIndex = stopFrame - stream.allFrames.begin(); + // accurate, the nextPts for a frame would be equivalent to pts + + // duration. + // 2. In order to establish if the start of an interval maps to a + // particular frame, we need to figure out if it is ordered after the + // frame's pts, but before the next frame's pts. 
+ + int64_t startFrameIndex = + secondsToIndexLowerBound(startSeconds, stream, streamMetadata); + int64_t stopFrameIndex = + secondsToIndexUpperBound(stopSeconds, stream, streamMetadata); int64_t numFrames = stopFrameIndex - startFrameIndex; + BatchDecodedOutput output(numFrames, options, streamMetadata); for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) { DecodedOutput singleOut = @@ -1336,12 +1452,12 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getNextRawDecodedOutputNoDemux() { } VideoDecoder::DecodedOutput VideoDecoder::getNextFrameNoDemux() { - auto output = getNextFrameOutputNoDemuxInternal(); + auto output = getNextFrameNoDemuxInternal(); output.frame = maybePermuteHWC2CHW(output.streamIndex, output.frame); return output; } -VideoDecoder::DecodedOutput VideoDecoder::getNextFrameOutputNoDemuxInternal( +VideoDecoder::DecodedOutput VideoDecoder::getNextFrameNoDemuxInternal( std::optional<torch::Tensor> preAllocatedOutputTensor) { auto rawOutput = getNextRawDecodedOutputNoDemux(); return convertAVFrameToDecodedOutput(rawOutput, preAllocatedOutputTensor); @@ -1365,10 +1481,12 @@ double VideoDecoder::getPtsSecondsForFrame( validateUserProvidedStreamIndex(streamIndex); validateScannedAllStreams("getPtsSecondsForFrame"); - const auto& stream = streams_[streamIndex]; - validateFrameIndex(stream, frameIndex); + const auto& streamInfo = streams_[streamIndex]; + const auto& streamMetadata = containerMetadata_.streams[streamIndex]; + validateFrameIndex(streamMetadata, frameIndex); - return ptsToSeconds(stream.allFrames[frameIndex].pts, stream.timeBase); + return ptsToSeconds( + streamInfo.allFrames[frameIndex].pts, streamInfo.timeBase); } void VideoDecoder::createSwsContext( diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 9afce572..ee723695 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -50,19 +50,23 @@ class VideoDecoder { // CONSTRUCTION API // -------------------------------------------------------------------------- + enum class SeekMode { exact, approximate }; + // Creates a VideoDecoder from the video at videoFilePath. - explicit VideoDecoder(const std::string& videoFilePath); + explicit VideoDecoder(const std::string& videoFilePath, SeekMode seekMode); // Creates a VideoDecoder from a given buffer. Note that the buffer is not // owned by the VideoDecoder. - explicit VideoDecoder(const void* buffer, size_t length); + explicit VideoDecoder(const void* buffer, size_t length, SeekMode seekMode); static std::unique_ptr<VideoDecoder> createFromFilePath( - const std::string& videoFilePath); + const std::string& videoFilePath, + SeekMode seekMode = SeekMode::exact); static std::unique_ptr<VideoDecoder> createFromBuffer( const void* buffer, - size_t length); + size_t length, + SeekMode seekMode = SeekMode::exact); // -------------------------------------------------------------------------- // VIDEO METADATA QUERY API // -------------------------------------------------------------------------- @@ -157,7 +161,7 @@ class VideoDecoder { // ---- SINGLE FRAME SEEK AND DECODING API ---- // Places the cursor at the first frame on or after the position in seconds. - // Calling getNextFrameOutputNoDemuxInternal() will return the first frame at + // Calling getNextFrameNoDemuxInternal() will return the first frame at // or after this position. 
void setCursorPtsInSeconds(double seconds); // This structure ensures we always keep the streamIndex and AVFrame together @@ -342,17 +346,42 @@ class VideoDecoder { void initializeDecoder(); void validateUserProvidedStreamIndex(int streamIndex); void validateScannedAllStreams(const std::string& msg); - void validateFrameIndex(const StreamInfo& stream, int64_t frameIndex); + void validateFrameIndex( + const StreamMetadata& streamMetadata, + int64_t frameIndex); + // Creates and initializes a filter graph for a stream. The filter graph can // do rescaling and color conversion. void createFilterGraph( StreamInfo& streamInfo, int expectedOutputHeight, int expectedOutputWidth); + + int64_t getNumFrames(const StreamMetadata& streamMetadata); + + int64_t getPts( + const StreamInfo& streamInfo, + const StreamMetadata& streamMetadata, + int64_t frameIndex); + + double getMinSeconds(const StreamMetadata& streamMetadata); + double getMaxSeconds(const StreamMetadata& streamMetadata); + + int64_t secondsToIndexLowerBound( + double seconds, + const StreamInfo& streamInfo, + const StreamMetadata& streamMetadata); + + int64_t secondsToIndexUpperBound( + double seconds, + const StreamInfo& streamInfo, + const StreamMetadata& streamMetadata); + void createSwsContext( StreamInfo& streamInfo, const DecodedFrameContext& frameContext, const enum AVColorSpace colorspace); + void maybeSeekToBeforeDesiredPts(); RawDecodedOutput getDecodedOutputWithFilter( std::function<bool(int, AVFrame*)>); @@ -379,9 +408,10 @@ DecodedOutput& output, std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt); - DecodedOutput getNextFrameOutputNoDemuxInternal( + DecodedOutput getNextFrameNoDemuxInternal( std::optional<torch::Tensor> preAllocatedOutputTensor = std::nullopt); + SeekMode seekMode_; ContainerMetadata containerMetadata_; UniqueAVFormatContext formatContext_; std::map<int, StreamInfo> streams_; @@ -397,7 +427,7 @@ // Stores the AVIOContext for the input buffer. std::unique_ptr<AVIOBytesContext> ioBytesContext_; // Whether or not we have already scanned all streams to update the metadata. - bool scanned_all_streams_ = false; + bool scannedAllStreams_ = false; // Tracks that we've already been initialized. bool initialized_ = false; }; diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp index 5d941447..c63a89e7 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.cpp @@ -27,8 +27,9 @@ TORCH_LIBRARY(torchcodec_ns, m) { m.impl_abstract_pystub( "torchcodec.decoders._core.video_decoder_ops", "//pytorch/torchcodec:torchcodec"); - m.def("create_from_file(str filename) -> Tensor"); - m.def("create_from_tensor(Tensor video_tensor) -> Tensor"); + m.def("create_from_file(str filename, str? seek_mode=None) -> Tensor"); + m.def( + "create_from_tensor(Tensor video_tensor, str? seek_mode=None) -> Tensor"); m.def( "_add_video_stream(Tensor(a!) decoder, *, int? width=None, int? height=None, int? num_threads=None, str? dimension_order=None, int? stream_index=None, str? device=None, str? 
color_conversion_library=None) -> ()"); m.def( @@ -88,31 +89,67 @@ OpsBatchDecodedOutput makeOpsBatchDecodedOutput( VideoDecoder::BatchDecodedOutput& batch) { return std::make_tuple(batch.frames, batch.ptsSeconds, batch.durationSeconds); } + +VideoDecoder::SeekMode seekModeFromString(std::string_view seekMode) { + if (seekMode == "exact") { + return VideoDecoder::SeekMode::exact; + } else if (seekMode == "approximate") { + return VideoDecoder::SeekMode::approximate; + } else { + throw std::runtime_error("Invalid seek mode: " + std::string(seekMode)); + } +} + } // namespace // ============================== // Implementations for the operators // ============================== -at::Tensor create_from_file(std::string_view filename) { +at::Tensor create_from_file( + std::string_view filename, + std::optional<std::string_view> seek_mode) { std::string filenameStr(filename); + + VideoDecoder::SeekMode realSeek = VideoDecoder::SeekMode::exact; + if (seek_mode.has_value()) { + realSeek = seekModeFromString(seek_mode.value()); + } + std::unique_ptr<VideoDecoder> uniqueDecoder = - VideoDecoder::createFromFilePath(filenameStr); + VideoDecoder::createFromFilePath(filenameStr, realSeek); + return wrapDecoderPointerToTensor(std::move(uniqueDecoder)); } -at::Tensor create_from_tensor(at::Tensor video_tensor) { +at::Tensor create_from_tensor( + at::Tensor video_tensor, + std::optional<std::string_view> seek_mode) { TORCH_CHECK(video_tensor.is_contiguous(), "video_tensor must be contiguous"); void* buffer = video_tensor.mutable_data_ptr(); size_t length = video_tensor.numel(); + + VideoDecoder::SeekMode realSeek = VideoDecoder::SeekMode::exact; + if (seek_mode.has_value()) { + realSeek = seekModeFromString(seek_mode.value()); + } + std::unique_ptr<VideoDecoder> videoDecoder = - VideoDecoder::createFromBuffer(buffer, length); + VideoDecoder::createFromBuffer(buffer, length, realSeek); return wrapDecoderPointerToTensor(std::move(videoDecoder)); } -at::Tensor create_from_buffer(const void* buffer, size_t length) { +at::Tensor create_from_buffer( + const void* buffer, + size_t length, + std::optional<std::string_view> seek_mode) { + VideoDecoder::SeekMode realSeek = VideoDecoder::SeekMode::exact; + if (seek_mode.has_value()) { + realSeek = seekModeFromString(seek_mode.value()); + } + std::unique_ptr<VideoDecoder> uniqueDecoder = - VideoDecoder::createFromBuffer(buffer, length); + VideoDecoder::createFromBuffer(buffer, length, realSeek); return wrapDecoderPointerToTensor(std::move(uniqueDecoder)); } diff --git a/src/torchcodec/decoders/_core/VideoDecoderOps.h b/src/torchcodec/decoders/_core/VideoDecoderOps.h index 7717a48b..60635094 100644 --- a/src/torchcodec/decoders/_core/VideoDecoderOps.h +++ b/src/torchcodec/decoders/_core/VideoDecoderOps.h @@ -20,13 +20,20 @@ namespace facebook::torchcodec { // auto decoderTensor = createDecoderOp.call(videoPath); // Create a VideoDecoder from file and wrap the pointer in a tensor. -at::Tensor create_from_file(std::string_view filename); +at::Tensor create_from_file( + std::string_view filename, + std::optional<std::string_view> seek_mode = std::nullopt); -at::Tensor create_from_tensor(at::Tensor video_tensor); +at::Tensor create_from_tensor( + at::Tensor video_tensor, + std::optional<std::string_view> seek_mode = std::nullopt); // This API is C++ only and will not be exposed via custom ops, use // videodecoder_create_from_bytes in Python -at::Tensor create_from_buffer(const void* buffer, size_t length); +at::Tensor create_from_buffer( + const void* buffer, + size_t length, + std::optional<std::string_view> seek_mode = std::nullopt); // Add a new video stream at `stream_index` using the provided options. 
void add_video_stream( diff --git a/src/torchcodec/decoders/_core/_metadata.py b/src/torchcodec/decoders/_core/_metadata.py index 4d5e9a31..1ec4a87c 100644 --- a/src/torchcodec/decoders/_core/_metadata.py +++ b/src/torchcodec/decoders/_core/_metadata.py @@ -37,12 +37,12 @@ class VideoStreamMetadata: content (the scan doesn't involve decoding). This is more accurate than ``num_frames_from_header``. We recommend using the ``num_frames`` attribute instead. (int or None).""" - begin_stream_seconds: Optional[float] + begin_stream_seconds_from_content: Optional[float] """Beginning of the stream, in seconds (float or None). Conceptually, this corresponds to the first frame's :term:`pts`. It is computed as min(frame.pts) across all frames in the stream. Usually, this is equal to 0.""" - end_stream_seconds: Optional[float] + end_stream_seconds_from_content: Optional[float] """End of the stream, in seconds (float or None). Conceptually, this corresponds to last_frame.pts + last_frame.duration. It is computed as max(frame.pts + frame.duration) across all frames in the @@ -81,9 +81,15 @@ def duration_seconds(self) -> Optional[float]: from the actual frames if a :term:`scan` was performed. Otherwise we fall back to ``duration_seconds_from_header``. """ - if self.end_stream_seconds is None or self.begin_stream_seconds is None: + if ( + self.end_stream_seconds_from_content is None + or self.begin_stream_seconds_from_content is None + ): return self.duration_seconds_from_header - return self.end_stream_seconds - self.begin_stream_seconds + return ( + self.end_stream_seconds_from_content + - self.begin_stream_seconds_from_content + ) @property def average_fps(self) -> Optional[float]: @@ -92,12 +98,39 @@ def average_fps(self) -> Optional[float]: Otherwise we fall back to ``average_fps_from_header``. """ if ( - self.end_stream_seconds is None - or self.begin_stream_seconds is None + self.end_stream_seconds_from_content is None + or self.begin_stream_seconds_from_content is None or self.num_frames is None ): return self.average_fps_from_header - return self.num_frames / (self.end_stream_seconds - self.begin_stream_seconds) + return self.num_frames / ( + self.end_stream_seconds_from_content + - self.begin_stream_seconds_from_content + ) + + @property + def begin_stream_seconds(self) -> float: + """Beginning of the stream, in seconds (float). Conceptually, this + corresponds to the first frame's :term:`pts`. If + ``begin_stream_seconds_from_content`` is not None, then it is returned. + Otherwise, this value is 0. + """ + if self.begin_stream_seconds_from_content is None: + return 0 + else: + return self.begin_stream_seconds_from_content + + @property + def end_stream_seconds(self) -> Optional[float]: + """End of the stream, in seconds (float or None). + Conceptually, this corresponds to last_frame.pts + last_frame.duration. + If ``end_stream_seconds_from_content`` is not None, then that value is + returned. Otherwise, returns ``duration_seconds``. + """ + if self.end_stream_seconds_from_content is None: + return self.duration_seconds + else: + return self.end_stream_seconds_from_content def __repr__(self): # Overridden because properties are not printed by default. 
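These fallbacks are the crux of approximate mode: when no scan has populated the *_from_content fields, the properties quietly degrade to header-derived values. A minimal sketch of the resulting behavior, mirroring the constructor call used in test_metadata.py further down (the concrete field values here are made up for illustration):

from torchcodec.decoders._core import VideoStreamMetadata

meta = VideoStreamMetadata(
    duration_seconds_from_header=13.013,
    bit_rate=None,
    num_frames_from_header=390,
    num_frames_from_content=None,  # None: no scan was performed
    begin_stream_seconds_from_content=None,
    end_stream_seconds_from_content=None,
    codec="h264",
    width=480,
    height=270,
    average_fps_from_header=29.97,
    stream_index=3,
)
assert meta.begin_stream_seconds == 0  # no scanned min pts: defaults to 0
assert meta.end_stream_seconds == 13.013  # falls back to duration_seconds
assert meta.duration_seconds == 13.013  # falls back to the header duration
assert meta.num_frames == 390  # falls back to num_frames_from_header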
@@ -152,8 +185,12 @@ def get_video_metadata(decoder: torch.Tensor) -> VideoMetadata: bit_rate=stream_dict.get("bitRate"), num_frames_from_header=stream_dict.get("numFrames"), num_frames_from_content=stream_dict.get("numFramesFromScan"), - begin_stream_seconds=stream_dict.get("minPtsSecondsFromScan"), - end_stream_seconds=stream_dict.get("maxPtsSecondsFromScan"), + begin_stream_seconds_from_content=stream_dict.get( + "minPtsSecondsFromScan" + ), + end_stream_seconds_from_content=stream_dict.get( + "maxPtsSecondsFromScan" + ), codec=stream_dict.get("codec"), width=stream_dict.get("width"), height=stream_dict.get("height"), @@ -172,4 +209,4 @@ def get_video_metadata_from_header(filename: Union[str, pathlib.Path]) -> VideoMetadata: - return get_video_metadata(create_from_file(str(filename))) + return get_video_metadata(create_from_file(str(filename), seek_mode="approximate")) diff --git a/src/torchcodec/decoders/_core/video_decoder_ops.py b/src/torchcodec/decoders/_core/video_decoder_ops.py index d3f8e9a6..16e63392 100644 --- a/src/torchcodec/decoders/_core/video_decoder_ops.py +++ b/src/torchcodec/decoders/_core/video_decoder_ops.py @@ -94,25 +94,29 @@ def load_torchcodec_extension(): # ============================= # Functions not related to custom ops, but similar implementation to c++ ops # ============================= -def create_from_bytes(video_bytes: bytes) -> torch.Tensor: +def create_from_bytes( + video_bytes: bytes, seek_mode: Optional[str] = None +) -> torch.Tensor: with warnings.catch_warnings(): # Ignore warning stating that the underlying video_bytes buffer is # non-writable. warnings.filterwarnings("ignore", category=UserWarning) buffer = torch.frombuffer(video_bytes, dtype=torch.uint8) - return create_from_tensor(buffer) + return create_from_tensor(buffer, seek_mode) # ============================== # Abstract impl for the operators. Needed by torch.compile. # ============================== @register_fake("torchcodec_ns::create_from_file") -def create_from_file_abstract(filename: str) -> torch.Tensor: +def create_from_file_abstract(filename: str, seek_mode: Optional[str]) -> torch.Tensor: return torch.empty([], dtype=torch.long) @register_fake("torchcodec_ns::create_from_tensor") -def create_from_tensor_abstract(video_tensor: torch.Tensor) -> torch.Tensor: +def create_from_tensor_abstract( + video_tensor: torch.Tensor, seek_mode: Optional[str] +) -> torch.Tensor: return torch.empty([], dtype=torch.long) diff --git a/src/torchcodec/decoders/_video_decoder.py b/src/torchcodec/decoders/_video_decoder.py index 7930756a..46279135 100644 --- a/src/torchcodec/decoders/_video_decoder.py +++ b/src/torchcodec/decoders/_video_decoder.py @@ -50,6 +50,10 @@ class VideoDecoder: Passing 0 lets FFmpeg decide on the number of threads. Default: 1. device (str or torch.device, optional): The device to use for decoding. Default: "cpu". + seek_mode (str, optional): Determines if frame access will be "exact" or + "approximate". Exact guarantees that requesting frame i will always return frame i, + but doing so requires an initial :term:`scan` of the file. Approximate is faster as it avoids scanning the + file, but less accurate as it uses the file's metadata to calculate where frame i probably is. Default: "exact". 
Attributes: @@ -67,15 +71,23 @@ def __init__( dimension_order: Literal["NCHW", "NHWC"] = "NCHW", num_ffmpeg_threads: int = 1, device: Optional[Union[str, device]] = "cpu", + seek_mode: Literal["exact", "approximate"] = "exact", ): + allowed_seek_modes = ("exact", "approximate") + if seek_mode not in allowed_seek_modes: + raise ValueError( + f"Invalid seek mode ({seek_mode}). " + f"Supported values are {', '.join(allowed_seek_modes)}." + ) + if isinstance(source, str): - self._decoder = core.create_from_file(source) + self._decoder = core.create_from_file(source, seek_mode) elif isinstance(source, Path): - self._decoder = core.create_from_file(str(source)) + self._decoder = core.create_from_file(str(source), seek_mode) elif isinstance(source, bytes): - self._decoder = core.create_from_bytes(source) + self._decoder = core.create_from_bytes(source, seek_mode) elif isinstance(source, Tensor): - self._decoder = core.create_from_tensor(source) + self._decoder = core.create_from_tensor(source, seek_mode) else: raise TypeError( f"Unknown source type: {type(source)}. " @@ -92,7 +104,6 @@ def __init__( if num_ffmpeg_threads is None: raise ValueError(f"{num_ffmpeg_threads = } should be an int.") - core.scan_all_streams_to_update_metadata(self._decoder) core.add_video_stream( self._decoder, stream_index=stream_index, @@ -105,11 +116,11 @@ def __init__( self._decoder, stream_index ) - if self.metadata.num_frames_from_content is None: + if self.metadata.num_frames is None: raise ValueError( "The number of frames is unknown. " + _ERROR_REPORTING_INSTRUCTIONS ) - self._num_frames = self.metadata.num_frames_from_content + self._num_frames = self.metadata.num_frames if self.metadata.begin_stream_seconds is None: raise ValueError( diff --git a/src/torchcodec/samplers/_time_based.py b/src/torchcodec/samplers/_time_based.py index 2b531e53..03d11575 100644 --- a/src/torchcodec/samplers/_time_based.py +++ b/src/torchcodec/samplers/_time_based.py @@ -38,12 +38,13 @@ def _validate_params_time_based( "Could not infer average fps from video metadata. " "Try using an index-based sampler instead." ) - if ( - decoder.metadata.end_stream_seconds is None - or decoder.metadata.begin_stream_seconds is None - ): + + # Note that metadata.begin_stream_seconds is a property that will always yield a valid + # value; if it is not present in the actual metadata, the metadata object will return 0. + # Hence, we do not test for it here and only test metadata.end_stream_seconds. + if decoder.metadata.end_stream_seconds is None: raise ValueError( - "Could not infer stream end and start from video metadata. " + "Could not infer stream end from video metadata. " "Try using an index-based sampler instead." 
) diff --git a/test/decoders/VideoDecoderTest.cpp b/test/decoders/VideoDecoderTest.cpp index 6b3f32d6..e3b3f1e3 100644 --- a/test/decoders/VideoDecoderTest.cpp +++ b/test/decoders/VideoDecoderTest.cpp @@ -50,9 +50,11 @@ class VideoDecoderTest : public testing::TestWithParam<bool> { content_ = outputStringStream.str(); void* buffer = content_.data(); size_t length = outputStringStream.str().length(); - return VideoDecoder::createFromBuffer(buffer, length); + return VideoDecoder::createFromBuffer( + buffer, length, VideoDecoder::SeekMode::approximate); } else { - return VideoDecoder::createFromFilePath(filepath); + return VideoDecoder::createFromFilePath( + filepath, VideoDecoder::SeekMode::approximate); } } std::string content_; diff --git a/test/decoders/test_metadata.py b/test/decoders/test_metadata.py index 1ccceb62..e4caad37 100644 --- a/test/decoders/test_metadata.py +++ b/test/decoders/test_metadata.py @@ -13,17 +13,14 @@ get_ffmpeg_library_versions, get_video_metadata, get_video_metadata_from_header, - scan_all_streams_to_update_metadata, VideoStreamMetadata, ) from ..utils import NASA_VIDEO -def _get_video_metadata(path, with_scan: bool): - decoder = create_from_file(str(path)) - if with_scan: - scan_all_streams_to_update_metadata(decoder) +def _get_video_metadata(path, seek_mode): + decoder = create_from_file(str(path), seek_mode=seek_mode) return get_video_metadata(decoder) @@ -31,13 +28,13 @@ "metadata_getter", ( get_video_metadata_from_header, - functools.partial(_get_video_metadata, with_scan=False), - functools.partial(_get_video_metadata, with_scan=True), + functools.partial(_get_video_metadata, seek_mode="approximate"), + functools.partial(_get_video_metadata, seek_mode="exact"), ), ) def test_get_metadata(metadata_getter): with_scan = ( - metadata_getter.keywords["with_scan"] + metadata_getter.keywords["seek_mode"] == "exact" if isinstance(metadata_getter, functools.partial) else False ) @@ -92,8 +89,8 @@ def test_num_frames_fallback( bit_rate=123, num_frames_from_header=num_frames_from_header, num_frames_from_content=num_frames_from_content, - begin_stream_seconds=0, - end_stream_seconds=4, + begin_stream_seconds_from_content=0, + end_stream_seconds_from_content=4, codec="whatever", width=123, height=321, diff --git a/test/decoders/test_video_decoder.py b/test/decoders/test_video_decoder.py index 3cbacd9f..c05b33d9 100644 --- a/test/decoders/test_video_decoder.py +++ b/test/decoders/test_video_decoder.py @@ -23,7 +23,8 @@ class TestVideoDecoder: @pytest.mark.parametrize("source_kind", ("str", "path", "tensor", "bytes")) - def test_create(self, source_kind): + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_create(self, source_kind, seek_mode): if source_kind == "str": source = str(NASA_VIDEO.path) elif source_kind == "path": @@ -37,14 +38,9 @@ else: raise ValueError("Oops, double check the parametrization of this test!") - decoder = VideoDecoder(source) + decoder = VideoDecoder(source, seek_mode=seek_mode) assert isinstance(decoder.metadata, _core.VideoStreamMetadata) - assert ( - len(decoder) - == decoder._num_frames - == decoder.metadata.num_frames_from_content - == 390 - ) + assert len(decoder) == decoder._num_frames == 390 assert decoder.stream_index == decoder.metadata.stream_index == 3 assert decoder.metadata.duration_seconds == pytest.approx(13.013) assert decoder.metadata.average_fps == pytest.approx(29.970029) @@ -62,11 +58,18 @@ def test_create_fails(self): 
with pytest.raises(ValueError, match="No valid stream found"): decoder = VideoDecoder(NASA_VIDEO.path, stream_index=1) # noqa + with pytest.raises(ValueError, match="Invalid seek mode"): + decoder = VideoDecoder(NASA_VIDEO.path, seek_mode="blah") # noqa + @pytest.mark.parametrize("num_ffmpeg_threads", (1, 4)) @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_getitem_int(self, num_ffmpeg_threads, device): + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_getitem_int(self, num_ffmpeg_threads, device, seek_mode): decoder = VideoDecoder( - NASA_VIDEO.path, num_ffmpeg_threads=num_ffmpeg_threads, device=device + NASA_VIDEO.path, + num_ffmpeg_threads=num_ffmpeg_threads, + device=device, + seek_mode=seek_mode, ) ref_frame0 = NASA_VIDEO.get_frame_data_by_index(0).to(device) @@ -110,8 +113,9 @@ def test_getitem_numpy_int(self): assert_frames_equal(ref_frame180, decoder[numpy.uint32(180)]) @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_getitem_slice(self, device): - decoder = VideoDecoder(NASA_VIDEO.path, device=device) + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_getitem_slice(self, device, seek_mode): + decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) # ensure that the degenerate case of a range of size 1 works @@ -251,8 +255,9 @@ def test_getitem_slice(self, device): assert_frames_equal(sliced, ref) @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_getitem_fails(self, device): - decoder = VideoDecoder(NASA_VIDEO.path, device=device) + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_getitem_fails(self, device, seek_mode): + decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) with pytest.raises(IndexError, match="out of bounds"): frame = decoder[1000] # noqa @@ -267,8 +272,9 @@ def test_getitem_fails(self, device): frame = decoder[2.3] # noqa @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_iteration(self, device): - decoder = VideoDecoder(NASA_VIDEO.path, device=device) + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_iteration(self, device, seek_mode): + decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) ref_frame0 = NASA_VIDEO.get_frame_data_by_index(0).to(device) ref_frame1 = NASA_VIDEO.get_frame_data_by_index(1).to(device) @@ -312,8 +318,9 @@ def test_iteration_slow(self): assert iterations == len(decoder) == 390 @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_get_frame_at(self, device): - decoder = VideoDecoder(NASA_VIDEO.path, device=device) + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_get_frame_at(self, device, seek_mode): + decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) ref_frame9 = NASA_VIDEO.get_frame_data_by_index(9).to(device) frame9 = decoder.get_frame_at(9) @@ -355,8 +362,9 @@ def test_get_frame_at_tuple_unpacking(self, device): assert frame.duration_seconds == duration @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_get_frame_at_fails(self, device): - decoder = VideoDecoder(NASA_VIDEO.path, device=device) + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_get_frame_at_fails(self, device, seek_mode): + decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) with pytest.raises(IndexError, match="out of bounds"): frame = decoder.get_frame_at(-1) # noqa @@ -365,8 +373,9 @@ def test_get_frame_at_fails(self, device): 
frame = decoder.get_frame_at(10000) # noqa @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_get_frames_at(self, device): - decoder = VideoDecoder(NASA_VIDEO.path, device=device) + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_get_frames_at(self, device, seek_mode): + decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) frames = decoder.get_frames_at([35, 25]) @@ -404,8 +413,9 @@ def test_get_frames_at(self, device): ) @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_get_frames_at_fails(self, device): - decoder = VideoDecoder(NASA_VIDEO.path, device=device) + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_get_frames_at_fails(self, device, seek_mode): + decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) with pytest.raises(RuntimeError, match="Invalid frame index=-1"): decoder.get_frames_at([-1]) @@ -430,8 +440,9 @@ def test_get_frame_at_av1(self, device): assert_frames_equal(decoded_frame10.data, ref_frame10.to(device=device)) @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_get_frame_played_at(self, device): - decoder = VideoDecoder(NASA_VIDEO.path, device=device) + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_get_frame_played_at(self, device, seek_mode): + decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) ref_frame_played_at_6 = NASA_VIDEO.get_frame_data_by_index(180).to(device) assert_frames_equal( @@ -451,13 +462,16 @@ def test_get_frame_played_at_h265(self): # We don't parametrize with CUDA because the current GPUs on CI do not # support x265: # https://github.com/pytorch/torchcodec/pull/350#issuecomment-2465011730 - decoder = VideoDecoder(H265_VIDEO.path) + # Note that because our internal fix-up depends on the key frame index, it + # only works in exact seeking mode. + decoder = VideoDecoder(H265_VIDEO.path, seek_mode="exact") ref_frame6 = H265_VIDEO.get_frame_data_by_index(5) assert_frames_equal(ref_frame6, decoder.get_frame_played_at(0.5).data) @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_get_frame_played_at_fails(self, device): - decoder = VideoDecoder(NASA_VIDEO.path, device=device) + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_get_frame_played_at_fails(self, device, seek_mode): + decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) with pytest.raises(IndexError, match="Invalid pts in seconds"): frame = decoder.get_frame_played_at(-1.0) # noqa @@ -466,9 +480,10 @@ def test_get_frame_played_at_fails(self, device): frame = decoder.get_frame_played_at(100.0) # noqa @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_get_frames_played_at(self, device): + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_get_frames_played_at(self, device, seek_mode): - decoder = VideoDecoder(NASA_VIDEO.path, device=device) + decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) # Note: We know the frame at ~0.84s has index 25, the one at 1.16s has # index 35. We use those indices as reference to test against. 
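A note on how approximate mode resolves these reference timestamps: secondsToIndexLowerBound reduces to floor(seconds * averageFps) when there is no scan index, so the indices above can be sanity-checked by hand. A rough sketch (the fps value is the one asserted for NASA_VIDEO in test_create):

import math

average_fps = 29.970029  # NASA_VIDEO's average fps, asserted earlier in this file
# The frame played at ~0.84s maps to floor(0.84 * 29.970029) == 25, matching
# the reference index used in the assertions below.
assert math.floor(0.84 * average_fps) == 25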
@@ -482,6 +497,7 @@ def test_get_frames_played_at(self, device): assert_frames_equal( frames.data[i], NASA_VIDEO.get_frame_data_by_index(reference_indices[i]).to(device), + msg=f"index {i}", ) assert frames.pts_seconds.device.type == "cpu" @@ -503,8 +519,9 @@ def test_get_frames_played_at(self, device): ) @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_get_frames_played_at_fails(self, device): - decoder = VideoDecoder(NASA_VIDEO.path, device=device) + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_get_frames_played_at_fails(self, device, seek_mode): + decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) with pytest.raises(RuntimeError, match="must be in range"): decoder.get_frames_played_at([-1]) @@ -517,9 +534,13 @@ def test_get_frames_played_at_fails(self, device): @pytest.mark.parametrize("device", cpu_and_cuda()) @pytest.mark.parametrize("stream_index", [0, 3, None]) - def test_get_frames_in_range(self, stream_index, device): + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_get_frames_in_range(self, stream_index, device, seek_mode): decoder = VideoDecoder( - NASA_VIDEO.path, stream_index=stream_index, device=device + NASA_VIDEO.path, + stream_index=stream_index, + device=device, + seek_mode=seek_mode, ) # test degenerate case where we only actually get 1 frame @@ -630,9 +651,13 @@ def test_get_frames_in_range(self, stream_index, device): ), ) @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_dimension_order(self, dimension_order, frame_getter, device): + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_dimension_order(self, dimension_order, frame_getter, device, seek_mode): decoder = VideoDecoder( - NASA_VIDEO.path, dimension_order=dimension_order, device=device + NASA_VIDEO.path, + dimension_order=dimension_order, + device=device, + seek_mode=seek_mode, ) frame = frame_getter(decoder) @@ -654,9 +679,13 @@ def test_dimension_order_fails(self): @pytest.mark.parametrize("stream_index", [0, 3, None]) @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_get_frames_by_pts_in_range(self, stream_index, device): + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_get_frames_by_pts_in_range(self, stream_index, device, seek_mode): decoder = VideoDecoder( - NASA_VIDEO.path, stream_index=stream_index, device=device + NASA_VIDEO.path, + stream_index=stream_index, + device=device, + seek_mode=seek_mode, ) # Note that we are comparing the results of VideoDecoder's method: @@ -789,8 +818,9 @@ def test_get_frames_by_pts_in_range(self, stream_index, device): assert_frames_equal(all_frames.data, decoder[:]) @pytest.mark.parametrize("device", cpu_and_cuda()) - def test_get_frames_by_pts_in_range_fails(self, device): - decoder = VideoDecoder(NASA_VIDEO.path, device=device) + @pytest.mark.parametrize("seek_mode", ("exact", "approximate")) + def test_get_frames_by_pts_in_range_fails(self, device, seek_mode): + decoder = VideoDecoder(NASA_VIDEO.path, device=device, seek_mode=seek_mode) with pytest.raises(ValueError, match="Invalid start seconds"): frame = decoder.get_frames_played_in_range(100.0, 1.0) # noqa diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py index 9baf6a39..9b41126f 100644 --- a/test/decoders/test_video_decoder_ops.py +++ b/test/decoders/test_video_decoder_ops.py @@ -32,7 +32,6 @@ get_frames_in_range, get_json_metadata, get_next_frame, - scan_all_streams_to_update_metadata, seek_to_pts, 
diff --git a/test/decoders/test_video_decoder_ops.py b/test/decoders/test_video_decoder_ops.py
index 9baf6a39..9b41126f 100644
--- a/test/decoders/test_video_decoder_ops.py
+++ b/test/decoders/test_video_decoder_ops.py
@@ -32,7 +32,6 @@
     get_frames_in_range,
     get_json_metadata,
     get_next_frame,
-    scan_all_streams_to_update_metadata,
     seek_to_pts,
 )
@@ -84,7 +83,6 @@ def test_seek_and_next(self, device):
     @pytest.mark.parametrize("device", cpu_and_cuda())
     def test_seek_to_negative_pts(self, device):
         decoder = create_from_file(str(NASA_VIDEO.path))
-        scan_all_streams_to_update_metadata(decoder)
         add_video_stream(decoder, device=device)
         frame0, _, _ = get_next_frame(decoder)
         reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0)
@@ -121,7 +119,6 @@ def test_get_frame_at_pts(self, device):
     @pytest.mark.parametrize("device", cpu_and_cuda())
     def test_get_frame_at_index(self, device):
         decoder = create_from_file(str(NASA_VIDEO.path))
-        scan_all_streams_to_update_metadata(decoder)
         add_video_stream(decoder, device=device)
         frame0, _, _ = get_frame_at_index(decoder, stream_index=3, frame_index=0)
         reference_frame0 = NASA_VIDEO.get_frame_data_by_index(0)
@@ -136,7 +133,6 @@ def test_get_frame_with_info_at_index(self, device):
         decoder = create_from_file(str(NASA_VIDEO.path))
-        scan_all_streams_to_update_metadata(decoder)
         add_video_stream(decoder, device=device)
         frame6, pts, duration = get_frame_at_index(
             decoder, stream_index=3, frame_index=180
         )
@@ -151,7 +147,6 @@ def test_get_frames_at_indices(self, device):
         decoder = create_from_file(str(NASA_VIDEO.path))
-        scan_all_streams_to_update_metadata(decoder)
         add_video_stream(decoder, device=device)
         frames0and180, *_ = get_frames_at_indices(
             decoder, stream_index=3, frame_indices=[0, 180]
         )
@@ -167,7 +162,6 @@ def test_get_frames_at_indices_unsorted_indices(self, device):
         decoder = create_from_file(str(NASA_VIDEO.path))
         _add_video_stream(decoder, device=device)
-        scan_all_streams_to_update_metadata(decoder)
 
         stream_index = 3
         frame_indices = [2, 0, 1, 0, 2]
@@ -199,7 +193,6 @@ def test_get_frames_by_pts(self, device):
         decoder = create_from_file(str(NASA_VIDEO.path))
         _add_video_stream(decoder, device=device)
-        scan_all_streams_to_update_metadata(decoder)
 
         stream_index = 3
         # Note: 13.01 should give the last video frame for the NASA video
@@ -233,7 +226,6 @@ def test_pts_apis_against_index_ref(self, device):
         # APIs exactly where those frames are supposed to start. We assert that
         # we get the expected frame.
         decoder = create_from_file(str(NASA_VIDEO.path))
-        scan_all_streams_to_update_metadata(decoder)
         add_video_stream(decoder, device=device)
         metadata = get_json_metadata(decoder)
@@ -290,7 +282,6 @@ def test_pts_apis_against_index_ref(self, device):
     @pytest.mark.parametrize("device", cpu_and_cuda())
     def test_get_frames_in_range(self, device):
         decoder = create_from_file(str(NASA_VIDEO.path))
-        scan_all_streams_to_update_metadata(decoder)
         add_video_stream(decoder, device=device)
 
         # ensure that the degenerate case of a range of size 1 works
@@ -452,7 +443,6 @@ def test_video_get_json_metadata(self):
     def test_video_get_json_metadata_with_stream(self):
         decoder = create_from_file(str(NASA_VIDEO.path))
-        scan_all_streams_to_update_metadata(decoder)
         add_video_stream(decoder)
         metadata = get_json_metadata(decoder)
         metadata_dict = json.loads(metadata)
@@ -480,7 +470,6 @@ def test_get_ffmpeg_version(self):
     def test_frame_pts_equality(self):
         decoder = create_from_file(str(NASA_VIDEO.path))
-        scan_all_streams_to_update_metadata(decoder)
         add_video_stream(decoder)
 
         # Note that for all of these tests, we store the return value of
@@ -526,7 +515,6 @@ def test_color_conversion_library_with_scaling(
         self, input_video, width_scaling_factor, height_scaling_factor
     ):
         decoder = create_from_file(str(input_video.path))
-        scan_all_streams_to_update_metadata(decoder)
         add_video_stream(decoder)
         metadata = get_json_metadata(decoder)
         metadata_dict = json.loads(metadata)
@@ -570,7 +558,6 @@ def test_color_conversion_library_with_dimension_order(
             color_conversion_library=color_conversion_library,
             dimension_order=dimension_order,
         )
-        scan_all_streams_to_update_metadata(decoder)
 
         frame0_ref = NASA_VIDEO.get_frame_data_by_index(0)
         if dimension_order == "NHWC":
@@ -651,7 +638,6 @@ def test_color_conversion_library_with_generated_videos(
         subprocess.check_call(command)
 
         decoder = create_from_file(str(video_path))
-        scan_all_streams_to_update_metadata(decoder)
         add_video_stream(decoder)
         metadata = get_json_metadata(decoder)
         metadata_dict = json.loads(metadata)
@@ -686,7 +672,6 @@ def test_color_conversion_library_with_generated_videos(
     @needs_cuda
     def test_cuda_decoder(self):
         decoder = create_from_file(str(NASA_VIDEO.path))
-        scan_all_streams_to_update_metadata(decoder)
         add_video_stream(decoder, device="cuda")
         frame0, pts, duration = get_next_frame(decoder)
         assert frame0.device.type == "cuda"
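The pattern applied throughout this file: the explicit `scan_all_streams_to_update_metadata` call is removed, and exactness is instead requested at creation time via `seek_mode`, with the scan happening internally when needed. A rough before/after sketch at the core-ops level, assuming the same imports these tests use and a placeholder path:

    from torchcodec.decoders._core import (
        add_video_stream,
        create_from_file,
        get_next_frame,
    )

    # Before this patch:
    #     decoder = create_from_file("video.mp4")
    #     scan_all_streams_to_update_metadata(decoder)
    # After: the full-file scan is implied by the seek mode instead.
    decoder = create_from_file("video.mp4", seek_mode="exact")
    add_video_stream(decoder)
    frame, pts, duration = get_next_frame(decoder)

Folding the scan into `seek_mode` keeps the ops surface smaller and makes the cost of exact seeking an explicit, one-time construction choice.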
diff --git a/test/samplers/test_samplers.py b/test/samplers/test_samplers.py
index d5c7eb44..7cd0a2da 100644
--- a/test/samplers/test_samplers.py
+++ b/test/samplers/test_samplers.py
@@ -590,22 +590,15 @@ def restore_metadata():
             decoder.metadata = original_metadata
 
     with restore_metadata():
-        decoder.metadata.begin_stream_seconds = None
+        decoder.metadata.end_stream_seconds_from_content = None
+        decoder.metadata.duration_seconds_from_header = None
         with pytest.raises(
-            ValueError, match="Could not infer stream end and start from video metadata"
+            ValueError, match="Could not infer stream end from video metadata"
        ):
             sampler(decoder)
 
     with restore_metadata():
-        decoder.metadata.end_stream_seconds = None
-        with pytest.raises(
-            ValueError, match="Could not infer stream end and start from video metadata"
-        ):
-            sampler(decoder)
-
-    with restore_metadata():
-        decoder.metadata.begin_stream_seconds = None
-        decoder.metadata.end_stream_seconds = None
+        decoder.metadata.end_stream_seconds_from_content = None
         decoder.metadata.average_fps_from_header = None
         with pytest.raises(ValueError, match="Could not infer average fps"):
             sampler(decoder)
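These sampler tests reflect the metadata split: instead of nulling `begin_stream_seconds`/`end_stream_seconds` directly, they null the source fields, `..._from_content` values that only exist after a scan and `..._from_header` values read from the container. A hypothetical helper shaped like the fallback the tests exercise, not the library's actual implementation:

    def infer_end_stream_seconds(metadata):
        # Scanned ("from_content") values win when present; header values are
        # the approximate-mode fallback.
        if metadata.end_stream_seconds_from_content is not None:
            return metadata.end_stream_seconds_from_content
        if metadata.duration_seconds_from_header is None:
            raise ValueError("Could not infer stream end from video metadata")
        if metadata.average_fps_from_header is None:
            raise ValueError("Could not infer average fps")
        # With only header data, derive the end from duration and fps
        # (illustrative; the real sampler logic may differ).
        fps = metadata.average_fps_from_header
        return int(metadata.duration_seconds_from_header * fps) / fps

The error strings match the `pytest.raises(..., match=...)` patterns above, which is why the old combined "stream end and start" message is gone.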