Skip to content

Commit

Permalink
Consolidated logic for timestamp batch. Big perf win.
Browse files Browse the repository at this point in the history
  • Loading branch information
scotts committed Dec 20, 2024
1 parent 802b881 commit 911a3bc
Show file tree
Hide file tree
Showing 3 changed files with 50 additions and 50 deletions.
2 changes: 1 addition & 1 deletion benchmarks/decoders/benchmark_decoders_library.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ def decode_and_resize(self, video_file, pts_list, height, width, device):


class TorchCodecPublicNonBatch(AbstractDecoder):
def __init__(self, num_ffmpeg_threads=None, device="cpu", seek_mode="exact"):
def __init__(self, num_ffmpeg_threads=None, device="cpu", seek_mode="approximate"):
self._num_ffmpeg_threads = num_ffmpeg_threads
self._device = device
self._seek_mode = seek_mode
Expand Down
95 changes: 46 additions & 49 deletions src/torchcodec/decoders/_core/VideoDecoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1157,6 +1157,28 @@ int64_t VideoDecoder::getFramesSize(
}
}

double VideoDecoder::getMinSeconds(const StreamMetadata& streamMetadata) {
switch (seekMode_) {
case SeekMode::exact:
return streamMetadata.minPtsSecondsFromScan.value();
case SeekMode::approximate:
return 0;
default:
throw std::runtime_error("Unknown SeekMode");
}
}

double VideoDecoder::getMaxSeconds(const StreamMetadata& streamMetadata) {
switch (seekMode_) {
case SeekMode::exact:
return streamMetadata.maxPtsSecondsFromScan.value();
case SeekMode::approximate:
return streamMetadata.durationSeconds.value();
default:
throw std::runtime_error("Unknown SeekMode");
}
}

VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndexInternal(
int streamIndex,
int64_t frameIndex,
Expand Down Expand Up @@ -1238,68 +1260,43 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesPlayedByTimestamps(
const auto& streamMetadata = containerMetadata_.streams[streamIndex];
const auto& stream = streams_[streamIndex];

if (seekMode_ == SeekMode::exact) {
double minSeconds = streamMetadata.minPtsSecondsFromScan.value();
double maxSeconds = streamMetadata.maxPtsSecondsFromScan.value();
double minSeconds = getMinSeconds(streamMetadata);
double maxSeconds = getMaxSeconds(streamMetadata);

// The frame played at timestamp t and the one played at timestamp `t +
// eps` are probably the same frame, with the same index. The easiest way to
// avoid decoding that unique frame twice is to convert the input timestamps
// to indices, and leverage the de-duplication logic of getFramesAtIndices.
// The frame played at timestamp t and the one played at timestamp `t +
// eps` are probably the same frame, with the same index. The easiest way to
// avoid decoding that unique frame twice is to convert the input timestamps
// to indices, and leverage the de-duplication logic of getFramesAtIndices.

std::vector<int64_t> frameIndices(timestamps.size());
for (auto i = 0; i < timestamps.size(); ++i) {
auto frameSeconds = timestamps[i];
TORCH_CHECK(
frameSeconds >= minSeconds && frameSeconds < maxSeconds,
"frame pts is " + std::to_string(frameSeconds) +
"; must be in range [" + std::to_string(minSeconds) + ", " +
std::to_string(maxSeconds) + ").");
std::vector<int64_t> frameIndices(timestamps.size());
for (auto i = 0; i < timestamps.size(); ++i) {
auto frameSeconds = timestamps[i];
TORCH_CHECK(
frameSeconds >= minSeconds && frameSeconds < maxSeconds,
"frame pts is " + std::to_string(frameSeconds) +
"; must be in range [" + std::to_string(minSeconds) + ", " +
std::to_string(maxSeconds) + ").");

int64_t frameIndex = -1;
if (seekMode_ == SeekMode::exact) {
auto it = std::lower_bound(
stream.allFrames.begin(),
stream.allFrames.end(),
frameSeconds,
[&stream](const FrameInfo& info, double frameSeconds) {
return ptsToSeconds(info.nextPts, stream.timeBase) <= frameSeconds;
});
int64_t frameIndex = it - stream.allFrames.begin();
frameIndices[i] = frameIndex;
}

return getFramesAtIndices(streamIndex, frameIndices);

} else if (seekMode_ == SeekMode::approximate) {
double minSeconds = 0;
double maxSeconds = streamMetadata.durationSeconds.value();

// TODO: Figure out if we can be smarter than just iterating over the
// timestamps one-by-one.

const auto& streamMetadata = containerMetadata_.streams[streamIndex];
const auto& options = stream.options;
BatchDecodedOutput output(timestamps.size(), options, streamMetadata);

for (auto i = 0; i < timestamps.size(); ++i) {
auto frameSeconds = timestamps[i];
TORCH_CHECK(
frameSeconds >= minSeconds && frameSeconds < maxSeconds,
"frame pts is " + std::to_string(frameSeconds) +
"; must be in range [" + std::to_string(minSeconds) + ", " +
std::to_string(maxSeconds) + ").");

DecodedOutput singleOut = getFramePlayedAtTimestampNoDemuxInternal(
frameSeconds, output.frames[i]);
output.ptsSeconds[i] = singleOut.ptsSeconds;
output.durationSeconds[i] = singleOut.durationSeconds;
frameIndex = it - stream.allFrames.begin();
} else if (seekMode_ == SeekMode::approximate) {
frameIndex = std::floor(frameSeconds * streamMetadata.averageFps.value());
} else {
throw std::runtime_error("Unknown SeekMode");
}

output.frames = maybePermuteHWC2CHW(streamIndex, output.frames);
return output;

} else {
throw std::runtime_error("Unknown SeekMode");
frameIndices[i] = frameIndex;
}

return getFramesAtIndices(streamIndex, frameIndices);
}

VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
Expand Down
3 changes: 3 additions & 0 deletions src/torchcodec/decoders/_core/VideoDecoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,9 @@ class VideoDecoder {
const StreamMetadata& streamMetadata,
int64_t frameIndex);

double getMinSeconds(const StreamMetadata& streamMetadata);
double getMaxSeconds(const StreamMetadata& streamMetadata);

void createSwsContext(
StreamInfo& streamInfo,
const DecodedFrameContext& frameContext,
Expand Down

0 comments on commit 911a3bc

Please sign in to comment.