Commit 7267b5a
Consolidated logic for timestamp range.
scotts committed Dec 21, 2024
1 parent 911a3bc commit 7267b5a
Showing 2 changed files with 96 additions and 121 deletions.
207 changes: 86 additions & 121 deletions src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -1179,6 +1179,52 @@ double VideoDecoder::getMaxSeconds(const StreamMetadata& streamMetadata) {
  }
}

+int64_t VideoDecoder::secondsToIndexLowerBound(
+    double seconds,
+    const StreamInfo& streamInfo,
+    const StreamMetadata& streamMetadata) {
+  switch (seekMode_) {
+    case SeekMode::exact: {
+      auto frame = std::lower_bound(
+          streamInfo.allFrames.begin(),
+          streamInfo.allFrames.end(),
+          seconds,
+          [&streamInfo](const FrameInfo& info, double start) {
+            return ptsToSeconds(info.nextPts, streamInfo.timeBase) <= start;
+          });
+
+      return frame - streamInfo.allFrames.begin();
+    }
+    case SeekMode::approximate:
+      return std::floor(seconds * streamMetadata.averageFps.value());
+    default:
+      throw std::runtime_error("Unknown SeekMode");
+  }
+}
+
+int64_t VideoDecoder::secondsToIndexUpperBound(
+    double seconds,
+    const StreamInfo& streamInfo,
+    const StreamMetadata& streamMetadata) {
+  switch (seekMode_) {
+    case SeekMode::exact: {
+      auto frame = std::upper_bound(
+          streamInfo.allFrames.begin(),
+          streamInfo.allFrames.end(),
+          seconds,
+          [&streamInfo](double stop, const FrameInfo& info) {
+            return stop <= ptsToSeconds(info.pts, streamInfo.timeBase);
+          });
+
+      return frame - streamInfo.allFrames.begin();
+    }
+    case SeekMode::approximate:
+      return std::ceil(seconds * streamMetadata.averageFps.value());
+    default:
+      throw std::runtime_error("Unknown SeekMode");
+  }
+}
+
VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndexInternal(
    int streamIndex,
    int64_t frameIndex,
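The exact-mode cases above lean on the half-open conventions of std::lower_bound and std::upper_bound: a frame is considered played from its pts until the next frame's pts, so the lower bound finds the frame still on screen at the given time, and the upper bound finds one past the last frame that starts before it. What follows is a minimal standalone sketch of that mapping, not part of this commit; FrameInfo is simplified here to hold pts values already converted to seconds.

#include <algorithm>
#include <cstdio>
#include <vector>

// Simplified stand-in for FrameInfo: pts already converted to seconds.
struct FrameInfo {
  double ptsSeconds;      // frame starts being displayed here
  double nextPtsSeconds;  // next frame's pts; display ends here
};

int main() {
  // Four frames, 0.25s apart (a hypothetical 4 fps stream).
  std::vector<FrameInfo> allFrames = {
      {0.00, 0.25}, {0.25, 0.50}, {0.50, 0.75}, {0.75, 1.00}};

  double startSeconds = 0.3;
  double stopSeconds = 0.8;

  // Lower bound: first frame whose nextPts is after startSeconds, i.e. the
  // frame still being displayed at startSeconds.
  auto startFrame = std::lower_bound(
      allFrames.begin(),
      allFrames.end(),
      startSeconds,
      [](const FrameInfo& info, double start) {
        return info.nextPtsSeconds <= start;
      });

  // Upper bound: one past the last frame whose pts is before stopSeconds.
  auto stopFrame = std::upper_bound(
      allFrames.begin(),
      allFrames.end(),
      stopSeconds,
      [](double stop, const FrameInfo& info) {
        return stop <= info.ptsSeconds;
      });

  // Prints "frames [1, 4)": frames 1, 2 and 3 play during [0.3, 0.8).
  std::printf(
      "frames [%td, %td)\n",
      startFrame - allFrames.begin(),
      stopFrame - allFrames.begin());
  return 0;
}

With these comparators, the interval [0.3, 0.8) maps to frame indices [1, 4), matching the half-open [startFrameIndex, stopFrameIndex) range the decoder iterates over below.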
@@ -1372,129 +1418,48 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
    return output;
  }

-  if (seekMode_ == SeekMode::exact) {
-    double minSeconds = streamMetadata.minPtsSecondsFromScan.value();
-    double maxSeconds = streamMetadata.maxPtsSecondsFromScan.value();
-    TORCH_CHECK(
-        startSeconds >= minSeconds && startSeconds < maxSeconds,
-        "Start seconds is " + std::to_string(startSeconds) +
-            "; must be in range [" + std::to_string(minSeconds) + ", " +
-            std::to_string(maxSeconds) + ").");
-    TORCH_CHECK(
-        stopSeconds <= maxSeconds,
-        "Stop seconds (" + std::to_string(stopSeconds) +
-            "; must be less than or equal to " + std::to_string(maxSeconds) +
-            ").");
-
-    // Note that we look at nextPts for a frame, and not its pts or duration.
-    // Our abstract player displays frames starting at the pts for that frame
-    // until the pts for the next frame. There are two consequences:
-    //
-    // 1. We ignore the duration for a frame. A frame is played until the
-    //    next frame replaces it. This model is robust to durations being 0 or
-    //    incorrect; our source of truth is the pts for frames. If duration is
-    //    accurate, the nextPts for a frame would be equivalent to pts +
-    //    duration.
-    // 2. In order to establish if the start of an interval maps to a
-    //    particular frame, we need to figure out if it is ordered after the
-    //    frame's pts, but before the next frame's pts.
-
-    auto startFrame = std::lower_bound(
-        stream.allFrames.begin(),
-        stream.allFrames.end(),
-        startSeconds,
-        [&stream](const FrameInfo& info, double start) {
-          return ptsToSeconds(info.nextPts, stream.timeBase) <= start;
-        });
-
-    auto stopFrame = std::upper_bound(
-        stream.allFrames.begin(),
-        stream.allFrames.end(),
-        stopSeconds,
-        [&stream](double stop, const FrameInfo& info) {
-          return stop <= ptsToSeconds(info.pts, stream.timeBase);
-        });
-
-    int64_t startFrameIndex = startFrame - stream.allFrames.begin();
-    int64_t stopFrameIndex = stopFrame - stream.allFrames.begin();
-    int64_t numFrames = stopFrameIndex - startFrameIndex;
-    BatchDecodedOutput output(numFrames, options, streamMetadata);
-    for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) {
-      DecodedOutput singleOut =
-          getFrameAtIndexInternal(streamIndex, i, output.frames[f]);
-      output.ptsSeconds[f] = singleOut.ptsSeconds;
-      output.durationSeconds[f] = singleOut.durationSeconds;
-    }
-    output.frames = maybePermuteHWC2CHW(streamIndex, output.frames);
-
-    return output;
-
-  } else if (seekMode_ == SeekMode::approximate) {
-    double minSeconds = 0;
-    double maxSeconds = streamMetadata.durationSeconds.value();
-    TORCH_CHECK(
-        startSeconds >= minSeconds && startSeconds < maxSeconds,
-        "Start seconds is " + std::to_string(startSeconds) +
-            "; must be in range [" + std::to_string(minSeconds) + ", " +
-            std::to_string(maxSeconds) + ").");
-    TORCH_CHECK(
-        stopSeconds <= maxSeconds,
-        "Stop seconds (" + std::to_string(stopSeconds) +
-            "; must be less than or equal to " + std::to_string(maxSeconds) +
-            ").");
-
-    // Because we can only discover when to stop by doing the actual decoding,
-    // we can't pre-allocate the correct dimensions for our BatchDecodedOutput;
-    // we don't yet know N, the number of frames. So we have to store all of the
-    // decoded frames in a vector, and construct the final data tensor after.
-
-    // TODO: Figure out if there is a better way of doing this. That is, we store
-    // everything in vectors and then call torch::stack and torch::tensor.clone
-    // after the fact. We can't preallocate the final tensor because we don't
-    // know how many frames we're going to decode up front.
-
-    DecodedOutput singleOut =
-        getFramePlayedAtTimestampNoDemuxInternal(startSeconds);
-
-    std::vector<torch::Tensor> frames;
-    std::vector<double> ptsSeconds;
-    std::vector<double> durationSeconds;
-
-    // Note that we only know we've decoded all frames in the range when we have
-    // decoded the first frame outside of the range. That is, we have to decode
-    // one frame past where we want to stop, and conclude from its pts that all
-    // of the prior frames comprise our range. That means we decode one extra
-    // frame; we don't return it, but we decode it.
-    //
-    // This algorithm works fine except when stopSeconds is the duration of the
-    // video. In that case, we're going to hit the end-of-file exception.
-    //
-    // We could avoid decoding an extra frame, and the end-of-file exception, by
-    // using the currently decoded frame's duration to know that the next frame
-    // is outside of our range. This would be more efficient. However, up until
-    // now we have avoided relying on a frame's duration to determine if a frame
-    // is played during a time window. So there is a potential TODO here where
-    // we relax that principle and just do the math to avoid the extra decode.
-    bool eof = false;
-    while (singleOut.ptsSeconds < stopSeconds && !eof) {
-      frames.push_back(singleOut.frame);
-      ptsSeconds.push_back(singleOut.ptsSeconds);
-      durationSeconds.push_back(singleOut.durationSeconds);
-
-      try {
-        singleOut = getNextFrameNoDemuxInternal();
-      } catch (EndOfFileException e) {
-        eof = true;
-      }
-    }
-
-    BatchDecodedOutput output(frames, ptsSeconds, durationSeconds);
-    output.frames = maybePermuteHWC2CHW(streamIndex, output.frames);
-    return output;
-
-  } else {
-    throw std::runtime_error("Unknown SeekMode");
-  }
+  double minSeconds = getMinSeconds(streamMetadata);
+  double maxSeconds = getMaxSeconds(streamMetadata);
+  TORCH_CHECK(
+      startSeconds >= minSeconds && startSeconds < maxSeconds,
+      "Start seconds is " + std::to_string(startSeconds) +
+          "; must be in range [" + std::to_string(minSeconds) + ", " +
+          std::to_string(maxSeconds) + ").");
+  TORCH_CHECK(
+      stopSeconds <= maxSeconds,
+      "Stop seconds (" + std::to_string(stopSeconds) +
+          "; must be less than or equal to " + std::to_string(maxSeconds) +
+          ").");
+
+  // Note that we look at nextPts for a frame, and not its pts or duration.
+  // Our abstract player displays frames starting at the pts for that frame
+  // until the pts for the next frame. There are two consequences:
+  //
+  // 1. We ignore the duration for a frame. A frame is played until the
+  //    next frame replaces it. This model is robust to durations being 0 or
+  //    incorrect; our source of truth is the pts for frames. If duration is
+  //    accurate, the nextPts for a frame would be equivalent to pts +
+  //    duration.
+  // 2. In order to establish if the start of an interval maps to a
+  //    particular frame, we need to figure out if it is ordered after the
+  //    frame's pts, but before the next frame's pts.
+
+  int64_t startFrameIndex =
+      secondsToIndexLowerBound(startSeconds, stream, streamMetadata);
+  int64_t stopFrameIndex =
+      secondsToIndexUpperBound(stopSeconds, stream, streamMetadata);
+  int64_t numFrames = stopFrameIndex - startFrameIndex;
+
+  BatchDecodedOutput output(numFrames, options, streamMetadata);
+  for (int64_t i = startFrameIndex, f = 0; i < stopFrameIndex; ++i, ++f) {
+    DecodedOutput singleOut =
+        getFrameAtIndexInternal(streamIndex, i, output.frames[f]);
+    output.ptsSeconds[f] = singleOut.ptsSeconds;
+    output.durationSeconds[f] = singleOut.durationSeconds;
+  }
+  output.frames = maybePermuteHWC2CHW(streamIndex, output.frames);
+
+  return output;
}
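In approximate mode the same mapping is pure arithmetic on the stream's average frame rate: floor for the start index and ceil for the stop index, which keeps [startFrameIndex, stopFrameIndex) half-open and never undershoots the requested interval. Below is a standalone sketch of the numbers involved, not part of this commit; the 29.97 fps figure and the range values are made-up examples.

#include <cmath>
#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical stream metadata: NTSC-style average frame rate.
  double averageFps = 29.97;
  double startSeconds = 2.0;
  double stopSeconds = 3.5;

  // Mirrors the SeekMode::approximate cases of secondsToIndexLowerBound and
  // secondsToIndexUpperBound: floor for the start, ceil for the stop.
  int64_t startFrameIndex =
      static_cast<int64_t>(std::floor(startSeconds * averageFps));  // 59
  int64_t stopFrameIndex =
      static_cast<int64_t>(std::ceil(stopSeconds * averageFps));  // 105
  int64_t numFrames = stopFrameIndex - startFrameIndex;  // 46

  // Prints "frames [59, 105): 46 frames".
  std::printf(
      "frames [%lld, %lld): %lld frames\n",
      static_cast<long long>(startFrameIndex),
      static_cast<long long>(stopFrameIndex),
      static_cast<long long>(numFrames));
  return 0;
}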

VideoDecoder::RawDecodedOutput VideoDecoder::getNextRawDecodedOutputNoDemux() {
10 changes: 10 additions & 0 deletions src/torchcodec/decoders/_core/VideoDecoder.h
@@ -401,6 +401,16 @@ class VideoDecoder {
  double getMinSeconds(const StreamMetadata& streamMetadata);
  double getMaxSeconds(const StreamMetadata& streamMetadata);

+  int64_t secondsToIndexLowerBound(
+      double seconds,
+      const StreamInfo& streamInfo,
+      const StreamMetadata& streamMetadata);
+
+  int64_t secondsToIndexUpperBound(
+      double seconds,
+      const StreamInfo& streamInfo,
+      const StreamMetadata& streamMetadata);
+
  void createSwsContext(
      StreamInfo& streamInfo,
      const DecodedFrameContext& frameContext,
