Lazily init filtergraph so it can respect raw decoded resolution (#432)
scotts authored Dec 13, 2024
1 parent 84cef50 commit 943dc6e
Showing 2 changed files with 49 additions and 48 deletions.
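At a high level, the commit replaces eager filtergraph construction in addVideoStreamDecoder() with lazy construction on the first decoded frame, so the scale filter is configured from the frame's real dimensions rather than from stream metadata. A minimal sketch of that pattern, using simplified stand-ins for torchcodec's StreamInfo/FilterState (the real code stores an AVFilterGraph and does the full buffer/buffersink setup):

#include <memory>
#include <sstream>
#include <string>

struct FilterGraph {
  std::string description;  // e.g. "scale=640:480:sws_flags=bilinear"
};

struct FilterState {
  std::unique_ptr<FilterGraph> filterGraph;  // empty until first use
};

struct StreamInfo {
  FilterState filterState;
};

// Called from the per-frame conversion path, not when the stream is added:
// expectedOutputHeight/Width come from the first raw decoded frame.
void ensureFilterGraph(
    StreamInfo& streamInfo,
    int expectedOutputHeight,
    int expectedOutputWidth) {
  if (streamInfo.filterState.filterGraph) {
    return;  // already initialized by an earlier frame
  }
  std::ostringstream description;
  description << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight
              << ":sws_flags=bilinear";
  auto graph = std::make_unique<FilterGraph>();
  graph->description = description.str();
  streamInfo.filterState.filterGraph = std::move(graph);
}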
86 changes: 43 additions & 43 deletions src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -101,7 +101,7 @@ std::vector<std::string> splitStringWithDelimiters(
return result;
}

-VideoDecoder::ColorConversionLibrary getDefaultColorConversionLibraryForWidth(
+VideoDecoder::ColorConversionLibrary getDefaultColorConversionLibrary(
int width) {
VideoDecoder::ColorConversionLibrary library =
VideoDecoder::ColorConversionLibrary::SWSCALE;
@@ -121,7 +121,7 @@ VideoDecoder::ColorConversionLibrary getDefaultColorConversionLibraryForWidth(
// or 4D.
// Calling permute() is guaranteed to return a view as per the docs:
// https://pytorch.org/docs/stable/generated/torch.permute.html
-torch::Tensor VideoDecoder::MaybePermuteHWC2CHW(
+torch::Tensor VideoDecoder::maybePermuteHWC2CHW(
int streamIndex,
torch::Tensor& hwcTensor) {
if (streams_[streamIndex].options.dimensionOrder == "NHWC") {
@@ -299,31 +299,32 @@ std::unique_ptr<VideoDecoder> VideoDecoder::createFromBuffer(
return decoder;
}

-void VideoDecoder::initializeFilterGraphForStream(
-int streamIndex,
-const VideoStreamDecoderOptions& options) {
-FilterState& filterState = streams_[streamIndex].filterState;
+void VideoDecoder::initializeFilterGraph(
+StreamInfo& streamInfo,
+int expectedOutputHeight,
+int expectedOutputWidth) {
+FilterState& filterState = streamInfo.filterState;
if (filterState.filterGraph) {
return;
}

filterState.filterGraph.reset(avfilter_graph_alloc());
TORCH_CHECK(filterState.filterGraph.get() != nullptr);
-if (options.ffmpegThreadCount.has_value()) {
-filterState.filterGraph->nb_threads = options.ffmpegThreadCount.value();
+if (streamInfo.options.ffmpegThreadCount.has_value()) {
+filterState.filterGraph->nb_threads =
+streamInfo.options.ffmpegThreadCount.value();
}

const AVFilter* buffersrc = avfilter_get_by_name("buffer");
const AVFilter* buffersink = avfilter_get_by_name("buffersink");
-const StreamInfo& activeStream = streams_[streamIndex];
-AVCodecContext* codecContext = activeStream.codecContext.get();
+AVCodecContext* codecContext = streamInfo.codecContext.get();

std::stringstream filterArgs;
filterArgs << "video_size=" << codecContext->width << "x"
<< codecContext->height;
filterArgs << ":pix_fmt=" << codecContext->pix_fmt;
filterArgs << ":time_base=" << activeStream.stream->time_base.num << "/"
<< activeStream.stream->time_base.den;
filterArgs << ":time_base=" << streamInfo.stream->time_base.num << "/"
<< streamInfo.stream->time_base.den;
filterArgs << ":pixel_aspect=" << codecContext->sample_aspect_ratio.num << "/"
<< codecContext->sample_aspect_ratio.den;

@@ -378,10 +379,8 @@ void VideoDecoder::initializeFilterGraphForStream(
inputs->pad_idx = 0;
inputs->next = nullptr;

-auto frameDims = getHeightAndWidthFromOptionsOrMetadata(
-options, containerMetadata_.streams[streamIndex]);
std::stringstream description;
-description << "scale=" << frameDims.width << ":" << frameDims.height;
+description << "scale=" << expectedOutputWidth << ":" << expectedOutputHeight;
description << ":sws_flags=bilinear";

AVFilterInOut* outputsTmp = outputs.release();
@@ -469,26 +468,16 @@ void VideoDecoder::addVideoStreamDecoder(
streamInfo.options = options;
int width = options.width.value_or(codecContext->width);

-// Use swscale for color conversion by default because it is faster.
-VideoDecoder::ColorConversionLibrary defaultColorConversionLibrary =
-getDefaultColorConversionLibraryForWidth(width);
-// If the user specifies the color conversion library (example in
-// benchmarks), we use that instead.
-auto colorConversionLibrary =
-options.colorConversionLibrary.value_or(defaultColorConversionLibrary);
-
-if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) {
-initializeFilterGraphForStream(streamNumber, options);
-streamInfo.colorConversionLibrary = ColorConversionLibrary::FILTERGRAPH;
-} else if (colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
-streamInfo.colorConversionLibrary = ColorConversionLibrary::SWSCALE;
-} else {
-throw std::invalid_argument(
-"Invalid colorConversionLibrary=" +
-std::to_string(static_cast<int>(colorConversionLibrary)) +
-". colorConversionLibrary must be either "
-"filtergraph or swscale.");
-}
+// By default, we want to use swscale for color conversion because it is
+// faster. However, it has width requirements, so we may need to fall back
+// to filtergraph. We also need to respect what was requested from the
+// options; we respect the options unconditionally, so it's possible for
+// swscale's width requirements to be violated. We don't expose the ability to
+// choose color conversion library publicly; we only use this ability
+// internally.
+auto defaultLibrary = getDefaultColorConversionLibrary(width);
+streamInfo.colorConversionLibrary =
+options.colorConversionLibrary.value_or(defaultLibrary);
}

void VideoDecoder::updateMetadataWithCodecContext(
@@ -938,6 +927,17 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
} else if (
streamInfo.colorConversionLibrary ==
ColorConversionLibrary::FILTERGRAPH) {
+// Note that this is a lazy init; we initialize the filtergraph the first time
+// we have a raw decoded frame. We do this lazily because up until this
+// point, we really don't know what the resolution of the frames is
+// without modification. In theory, we should be able to get that from the
+// stream metadata, but in practice, we have encountered videos where the
+// stream metadata had a different resolution from the actual resolution
+// of the raw decoded frames.
+if (!streamInfo.filterState.filterGraph) {
+initializeFilterGraph(
+streamInfo, expectedOutputHeight, expectedOutputWidth);
+}
outputTensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);

// Similarly to above, if this check fails it means the frame wasn't
@@ -952,6 +952,7 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
expectedOutputWidth,
"x3, got ",
shape);

if (preAllocatedOutputTensor.has_value()) {
// We have already validated that preAllocatedOutputTensor and
// outputTensor have the same shape.
@@ -965,7 +966,6 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
"Invalid color conversion library: " +
std::to_string(static_cast<int>(streamInfo.colorConversionLibrary)));
}

} else if (output.streamType == AVMEDIA_TYPE_AUDIO) {
// TODO: https://github.com/pytorch-labs/torchcodec/issues/85 implement
// audio decoding.
@@ -1007,7 +1007,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFramePlayedAtTimestampNoDemux(
});
// Convert the frame to tensor.
auto output = convertAVFrameToDecodedOutput(rawOutput);
-output.frame = MaybePermuteHWC2CHW(output.streamIndex, output.frame);
+output.frame = maybePermuteHWC2CHW(output.streamIndex, output.frame);
return output;
}

@@ -1045,7 +1045,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(
int streamIndex,
int64_t frameIndex) {
auto output = getFrameAtIndexInternal(streamIndex, frameIndex);
-output.frame = MaybePermuteHWC2CHW(streamIndex, output.frame);
+output.frame = maybePermuteHWC2CHW(streamIndex, output.frame);
return output;
}

@@ -1118,7 +1118,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
}
previousIndexInVideo = indexInVideo;
}
-output.frames = MaybePermuteHWC2CHW(streamIndex, output.frames);
+output.frames = maybePermuteHWC2CHW(streamIndex, output.frames);
return output;
}

@@ -1193,7 +1193,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
output.ptsSeconds[f] = singleOut.ptsSeconds;
output.durationSeconds[f] = singleOut.durationSeconds;
}
-output.frames = MaybePermuteHWC2CHW(streamIndex, output.frames);
+output.frames = maybePermuteHWC2CHW(streamIndex, output.frames);
return output;
}

@@ -1246,7 +1246,7 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
// need this special case below.
if (startSeconds == stopSeconds) {
BatchDecodedOutput output(0, options, streamMetadata);
-output.frames = MaybePermuteHWC2CHW(streamIndex, output.frames);
+output.frames = maybePermuteHWC2CHW(streamIndex, output.frames);
return output;
}

@@ -1287,7 +1287,7 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
output.ptsSeconds[f] = singleOut.ptsSeconds;
output.durationSeconds[f] = singleOut.durationSeconds;
}
-output.frames = MaybePermuteHWC2CHW(streamIndex, output.frames);
+output.frames = maybePermuteHWC2CHW(streamIndex, output.frames);

return output;
}
@@ -1303,7 +1303,7 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getNextRawDecodedOutputNoDemux() {

VideoDecoder::DecodedOutput VideoDecoder::getNextFrameNoDemux() {
auto output = getNextFrameOutputNoDemuxInternal();
-output.frame = MaybePermuteHWC2CHW(output.streamIndex, output.frame);
+output.frame = maybePermuteHWC2CHW(output.streamIndex, output.frame);
return output;
}

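The rationale in the new comment — stream metadata can report a different resolution than the raw decoded frames — can be checked directly with FFmpeg's C API. A minimal sketch that prints both values for the first decoded video frame (illustrative only; error handling and cleanup are mostly omitted, and the file path is assumed to come from argv[1]):

extern "C" {
#include <libavcodec/avcodec.h>
#include <libavformat/avformat.h>
}
#include <cstdio>

int main(int argc, char** argv) {
  if (argc < 2) return 1;

  AVFormatContext* formatContext = nullptr;
  if (avformat_open_input(&formatContext, argv[1], nullptr, nullptr) < 0) return 1;
  avformat_find_stream_info(formatContext, nullptr);

  int streamIndex =
      av_find_best_stream(formatContext, AVMEDIA_TYPE_VIDEO, -1, -1, nullptr, 0);
  if (streamIndex < 0) return 1;
  AVStream* stream = formatContext->streams[streamIndex];

  const AVCodec* codec = avcodec_find_decoder(stream->codecpar->codec_id);
  AVCodecContext* codecContext = avcodec_alloc_context3(codec);
  avcodec_parameters_to_context(codecContext, stream->codecpar);
  avcodec_open2(codecContext, codec, nullptr);

  AVPacket* packet = av_packet_alloc();
  AVFrame* frame = av_frame_alloc();
  while (av_read_frame(formatContext, packet) >= 0) {
    if (packet->stream_index == streamIndex &&
        avcodec_send_packet(codecContext, packet) >= 0 &&
        avcodec_receive_frame(codecContext, frame) >= 0) {
      // frame->width/height are the dimensions the filtergraph must be
      // configured with; they may differ from what the metadata claims.
      std::printf(
          "metadata: %dx%d, raw decoded frame: %dx%d\n",
          stream->codecpar->width,
          stream->codecpar->height,
          frame->width,
          frame->height);
      break;
    }
    av_packet_unref(packet);
  }
  return 0;  // cleanup omitted for brevity
}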
11 changes: 6 additions & 5 deletions src/torchcodec/decoders/_core/VideoDecoder.h
@@ -157,7 +157,7 @@ class VideoDecoder {
int streamIndex,
const AudioStreamDecoderOptions& options = AudioStreamDecoderOptions());

-torch::Tensor MaybePermuteHWC2CHW(int streamIndex, torch::Tensor& hwcTensor);
+torch::Tensor maybePermuteHWC2CHW(int streamIndex, torch::Tensor& hwcTensor);

// ---- SINGLE FRAME SEEK AND DECODING API ----
// Places the cursor at the first frame on or after the position in seconds.
@@ -376,9 +376,10 @@ class VideoDecoder {
void validateFrameIndex(const StreamInfo& stream, int64_t frameIndex);
// Creates and initializes a filter graph for a stream. The filter graph can
// do rescaling and color conversion.
-void initializeFilterGraphForStream(
-int streamIndex,
-const VideoStreamDecoderOptions& options);
+void initializeFilterGraph(
+StreamInfo& streamInfo,
+int expectedOutputHeight,
+int expectedOutputWidth);
void maybeSeekToBeforeDesiredPts();
RawDecodedOutput getDecodedOutputWithFilter(
std::function<bool(int, AVFrame*)>);
@@ -436,7 +437,7 @@ class VideoDecoder {
// We always allocate [N]HWC tensors. The low-level decoding functions all
// assume HWC tensors, since this is what FFmpeg natively handles. It's up to
// the high-level decoding entry-points to permute that back to CHW, by calling
-// MaybePermuteHWC2CHW().
+// maybePermuteHWC2CHW().
//
// Also, importantly, the way we figure out the height and width of the
// output frame tensor varies, and depends on the decoding entry-point. In
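For reference, the [N]HWC-to-CHW conversion that the comment above (and maybePermuteHWC2CHW) describes amounts to a single torch::Tensor::permute call, which returns a view rather than copying data. A minimal sketch with a hypothetical free function; the real method is a member that also checks the stream's dimensionOrder option:

#include <torch/torch.h>

// Sketch: convert HWC (3D) or NHWC (4D) tensors to CHW / NCHW. permute()
// returns a view, as guaranteed by the PyTorch docs, so no data is copied.
torch::Tensor permuteHWC2CHW(const torch::Tensor& hwcTensor) {
  if (hwcTensor.dim() == 3) {
    return hwcTensor.permute({2, 0, 1});  // HWC -> CHW
  }
  return hwcTensor.permute({0, 3, 1, 2});  // NHWC -> NCHW
}

// Example: a batch of 5 decoded 480x640 RGB frames.
// torch::Tensor frames = torch::zeros({5, 480, 640, 3}, torch::kUInt8);
// permuteHWC2CHW(frames).sizes() is {5, 3, 480, 640}.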
