diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 07e25739..c77bdfd8 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -1179,6 +1179,52 @@ double VideoDecoder::getMaxSeconds(const StreamMetadata& streamMetadata) { } } /* secondsToIndexLowerBound: converts a timestamp in seconds into a frame index according to seekMode_. In exact mode the index comes from a binary search over streamInfo.allFrames; in approximate mode it is computed from streamMetadata.averageFps. Any other seekMode_ value throws std::runtime_error. NOTE(review): the binary search presumably relies on allFrames being ordered by pts - confirm. */ +int64_t VideoDecoder::secondsToIndexLowerBound( + double seconds, + const StreamInfo& streamInfo, + const StreamMetadata& streamMetadata) { + switch (seekMode_) { + case SeekMode::exact: { /* std::lower_bound with this comparator yields the first frame whose nextPts, converted to seconds, is strictly greater than seconds. */ + auto frame = std::lower_bound( + streamInfo.allFrames.begin(), + streamInfo.allFrames.end(), + seconds, + [&streamInfo](const FrameInfo& info, double start) { + return ptsToSeconds(info.nextPts, streamInfo.timeBase) <= start; + }); + /* Iterator distance from begin() is the frame index; it equals allFrames.size() when no frame satisfies the search. */ + return frame - streamInfo.allFrames.begin(); + } + case SeekMode::approximate: /* floor of seconds times the average fps; the double result is implicitly converted to int64_t on return. NOTE(review): averageFps looks like an optional, so value() presumably throws when it is unset - confirm. */ + return std::floor(seconds * streamMetadata.averageFps.value()); + default: + throw std::runtime_error("Unknown SeekMode"); + } +} + /* secondsToIndexUpperBound: counterpart of secondsToIndexLowerBound that converts a timestamp in seconds into a frame index according to seekMode_. In exact mode the index comes from a binary search over streamInfo.allFrames keyed on each frame pts; in approximate mode it is computed from streamMetadata.averageFps and rounded up. Any other seekMode_ value throws std::runtime_error. NOTE(review): the binary search presumably relies on allFrames being ordered by pts - confirm. */ +int64_t VideoDecoder::secondsToIndexUpperBound( + double seconds, + const StreamInfo& streamInfo, + const StreamMetadata& streamMetadata) { + switch (seekMode_) { + case SeekMode::exact: { /* std::upper_bound with this comparator yields the first frame whose pts, converted to seconds, is greater than or equal to seconds. */ + auto frame = std::upper_bound( + streamInfo.allFrames.begin(), + streamInfo.allFrames.end(), + seconds, + [&streamInfo](double stop, const FrameInfo& info) { + return stop <= ptsToSeconds(info.pts, streamInfo.timeBase); + }); + /* Iterator distance from begin() is the frame index; it equals allFrames.size() when every frame pts is below seconds. */ + return frame - streamInfo.allFrames.begin(); + } + case SeekMode::approximate: /* ceil of seconds times the average fps; the double result is implicitly converted to int64_t on return. NOTE(review): averageFps looks like an optional, so value() presumably throws when it is unset - confirm. */ + return std::ceil(seconds * streamMetadata.averageFps.value()); + default: + throw std::runtime_error("Unknown SeekMode"); + } +} + VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndexInternal( int streamIndex, int64_t frameIndex, @@ -1372,129 +1418,48 @@ VideoDecoder::getFramesPlayedByTimestampInRange( return output; } - if (seekMode_ == SeekMode::exact) { - double minSeconds = streamMetadata.minPtsSecondsFromScan.value(); - double maxSeconds = streamMetadata.maxPtsSecondsFromScan.value(); - 
TORCH_CHECK( - startSeconds >= minSeconds && startSeconds < maxSeconds, - "Start seconds is " + std::to_string(startSeconds) + - "; must be in range [" + std::to_string(minSeconds) + ", " + - std::to_string(maxSeconds) + ")."); - TORCH_CHECK( - stopSeconds <= maxSeconds, - "Stop seconds (" + std::to_string(stopSeconds) + - "; must be less than or equal to " + std::to_string(maxSeconds) + - ")."); - - // Note that we look at nextPts for a frame, and not its pts or duration. - // Our abstract player displays frames starting at the pts for that frame - // until the pts for the next frame. There are two consequences: - // - // 1. We ignore the duration for a frame. A frame is played until the - // next frame replaces it. This model is robust to durations being 0 or - // incorrect; our source of truth is the pts for frames. If duration is - // accurate, the nextPts for a frame would be equivalent to pts + - // duration. - // 2. In order to establish if the start of an interval maps to a - // particular frame, we need to figure out if it is ordered after the - // frame's pts, but before the next frames's pts. 
- - auto startFrame = std::lower_bound( - stream.allFrames.begin(), - stream.allFrames.end(), - startSeconds, - [&stream](const FrameInfo& info, double start) { - return ptsToSeconds(info.nextPts, stream.timeBase) <= start; - }); - - auto stopFrame = std::upper_bound( - stream.allFrames.begin(), - stream.allFrames.end(), - stopSeconds, - [&stream](double stop, const FrameInfo& info) { - return stop <= ptsToSeconds(info.pts, stream.timeBase); - }); - - int64_t startFrameIndex = startFrame - stream.allFrames.begin(); - int64_t stopFrameIndex = stopFrame - stream.allFrames.begin(); - int64_t numFrames = stopFrameIndex - startFrameIndex; - BatchDecodedOutput output(numFrames, options, streamMetadata); - for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) { - DecodedOutput singleOut = - getFrameAtIndexInternal(streamIndex, i, output.frames[f]); - output.ptsSeconds[f] = singleOut.ptsSeconds; - output.durationSeconds[f] = singleOut.durationSeconds; - } - output.frames = maybePermuteHWC2CHW(streamIndex, output.frames); - - return output; - - } else if (seekMode_ == SeekMode::approximate) { - double minSeconds = 0; - double maxSeconds = streamMetadata.durationSeconds.value(); - TORCH_CHECK( - startSeconds >= minSeconds && startSeconds < maxSeconds, - "Start seconds is " + std::to_string(startSeconds) + - "; must be in range [" + std::to_string(minSeconds) + ", " + - std::to_string(maxSeconds) + ")."); - TORCH_CHECK( - stopSeconds <= maxSeconds, - "Stop seconds (" + std::to_string(stopSeconds) + - "; must be less than or equal to " + std::to_string(maxSeconds) + - ")."); - - // Because we can only discover when to stop by doing the actual decoding, - // we can't pre-allocate the correct dimensions for our BatchDecodedOutput; - // we don't yet know N, the number of frames. So we have to store all of the - // decoded frames in a vector, and construct the final data tensor after. - - // TODO: Figure out if there is a better of doing this. 
That is, we store - // everything in vectors and then call torch::stack and torch::tensor.clone - // after the fact. We can't preallocate the final tensor because we don't - // know how many frames we're going to decode up front. - + double minSeconds = getMinSeconds(streamMetadata); + double maxSeconds = getMaxSeconds(streamMetadata); + TORCH_CHECK( + startSeconds >= minSeconds && startSeconds < maxSeconds, + "Start seconds is " + std::to_string(startSeconds) + + "; must be in range [" + std::to_string(minSeconds) + ", " + + std::to_string(maxSeconds) + ")."); + TORCH_CHECK( + stopSeconds <= maxSeconds, + "Stop seconds (" + std::to_string(stopSeconds) + + "; must be less than or equal to " + std::to_string(maxSeconds) + + ")."); + + // Note that we look at nextPts for a frame, and not its pts or duration. + // Our abstract player displays frames starting at the pts for that frame + // until the pts for the next frame. There are two consequences: + // + // 1. We ignore the duration for a frame. A frame is played until the + // next frame replaces it. This model is robust to durations being 0 or + // incorrect; our source of truth is the pts for frames. If duration is + // accurate, the nextPts for a frame would be equivalent to pts + + // duration. + // 2. In order to establish if the start of an interval maps to a + // particular frame, we need to figure out if it is ordered after the + // frame's pts, but before the next frame's pts. 
+ + int64_t startFrameIndex = + secondsToIndexLowerBound(startSeconds, stream, streamMetadata); + int64_t stopFrameIndex = + secondsToIndexUpperBound(stopSeconds, stream, streamMetadata); + int64_t numFrames = stopFrameIndex - startFrameIndex; + + BatchDecodedOutput output(numFrames, options, streamMetadata); + for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) { DecodedOutput singleOut = - getFramePlayedAtTimestampNoDemuxInternal(startSeconds); - - std::vector frames; - std::vector ptsSeconds; - std::vector durationSeconds; - - // Note that we only know we've decoded all frames in the range when we have - // decoded the first frame outside of the range. That is, we have to decode - // one frame past where we want to stop, and conclude from its pts that all - // of the prior frames comprises our range. That means we decode one extra - // frame; we don't return it, but we decode it. - // - // This algorithm works fine except when stopSeconds is the duration of the - // video. In that case, we're going to hit the end-of-file exception. - // - // We could avoid decoding an extra frame, and the end-of-file exception, by - // using the currently decoded frame's duration to know that the next frame - // is outside of our range. This would be more efficient. However, up until - // now we have avoided relying on a frame's duration to determine if a frame - // is played during a time window. So there is a potential TODO here where - // we relax that principle and just do the math to avoid the extra decode. 
- bool eof = false; - while (singleOut.ptsSeconds < stopSeconds && !eof) { - frames.push_back(singleOut.frame); - ptsSeconds.push_back(singleOut.ptsSeconds); - durationSeconds.push_back(singleOut.durationSeconds); - - try { - singleOut = getNextFrameNoDemuxInternal(); - } catch (EndOfFileException e) { - eof = true; - } - } - - BatchDecodedOutput output(frames, ptsSeconds, durationSeconds); - output.frames = maybePermuteHWC2CHW(streamIndex, output.frames); - return output; - - } else { - throw std::runtime_error("Unknown SeekMode"); + getFrameAtIndexInternal(streamIndex, i, output.frames[f]); + output.ptsSeconds[f] = singleOut.ptsSeconds; + output.durationSeconds[f] = singleOut.durationSeconds; } + output.frames = maybePermuteHWC2CHW(streamIndex, output.frames); + + return output; } VideoDecoder::RawDecodedOutput VideoDecoder::getNextRawDecodedOutputNoDemux() { diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h index 68bb4305..187e199c 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.h +++ b/src/torchcodec/decoders/_core/VideoDecoder.h @@ -401,6 +401,16 @@ class VideoDecoder { double getMinSeconds(const StreamMetadata& streamMetadata); double getMaxSeconds(const StreamMetadata& streamMetadata); + int64_t secondsToIndexLowerBound( + double seconds, + const StreamInfo& streamInfo, + const StreamMetadata& streamMetadata); + + int64_t secondsToIndexUpperBound( + double seconds, + const StreamInfo& streamInfo, + const StreamMetadata& streamMetadata); + void createSwsContext( StreamInfo& streamInfo, const DecodedFrameContext& frameContext,