From 943dc6efaf14140958bdf73eba76ddd5b221ca61 Mon Sep 17 00:00:00 2001
From: Scott Schneider
Date: Fri, 13 Dec 2024 11:02:15 -0500
Subject: [PATCH] Lazily init filtergraph so it can respect raw decoded resolution (#432)
---
 .../decoders/_core/VideoDecoder.cpp          | 86 +++++++++----------
 src/torchcodec/decoders/_core/VideoDecoder.h | 11 +--
 2 files changed, 49 insertions(+), 48 deletions(-)

diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp
index ac3b5a6b..ae951b91 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -101,7 +101,7 @@ std::vector<std::string> splitStringWithDelimiters(
   return result;
 }
 
-VideoDecoder::ColorConversionLibrary getDefaultColorConversionLibraryForWidth(
+VideoDecoder::ColorConversionLibrary getDefaultColorConversionLibrary(
     int width) {
   VideoDecoder::ColorConversionLibrary library =
       VideoDecoder::ColorConversionLibrary::SWSCALE;
@@ -121,7 +121,7 @@ VideoDecoder::ColorConversionLibrary getDefaultColorConversionLibraryForWidth(
 // or 4D.
 // Calling permute() is guaranteed to return a view as per the docs:
 // https://pytorch.org/docs/stable/generated/torch.permute.html
-torch::Tensor VideoDecoder::MaybePermuteHWC2CHW(
+torch::Tensor VideoDecoder::maybePermuteHWC2CHW(
     int streamIndex,
     torch::Tensor& hwcTensor) {
   if (streams_[streamIndex].options.dimensionOrder == "NHWC") {
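The body of getDefaultColorConversionLibrary is mostly elided by the hunk above. As a rough sketch of the selection it performs — assuming swscale's constraint is a multiple-of-32 width, which is not visible in this patch — it reads along these lines:

// Sketch only: the selection suggested by the context lines above. The exact
// width constraint is not shown in this patch; a multiple-of-32 alignment
// requirement from swscale is assumed here.
VideoDecoder::ColorConversionLibrary getDefaultColorConversionLibrary(
    int width) {
  VideoDecoder::ColorConversionLibrary library =
      VideoDecoder::ColorConversionLibrary::SWSCALE;
  if (width % 32 != 0) {
    // swscale cannot handle this width; fall back to filtergraph, which has
    // no such restriction.
    library = VideoDecoder::ColorConversionLibrary::FILTERGRAPH;
  }
  return library;
}

The rename from getDefaultColorConversionLibraryForWidth reflects that the width is now just an input to the decision rather than the function's whole subject.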
":sws_flags=bilinear"; AVFilterInOut* outputsTmp = outputs.release(); @@ -469,26 +468,16 @@ void VideoDecoder::addVideoStreamDecoder( streamInfo.options = options; int width = options.width.value_or(codecContext->width); - // Use swscale for color conversion by default because it is faster. - VideoDecoder::ColorConversionLibrary defaultColorConversionLibrary = - getDefaultColorConversionLibraryForWidth(width); - // If the user specifies the color conversion library (example in - // benchmarks), we use that instead. - auto colorConversionLibrary = - options.colorConversionLibrary.value_or(defaultColorConversionLibrary); - - if (colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) { - initializeFilterGraphForStream(streamNumber, options); - streamInfo.colorConversionLibrary = ColorConversionLibrary::FILTERGRAPH; - } else if (colorConversionLibrary == ColorConversionLibrary::SWSCALE) { - streamInfo.colorConversionLibrary = ColorConversionLibrary::SWSCALE; - } else { - throw std::invalid_argument( - "Invalid colorConversionLibrary=" + - std::to_string(static_cast(colorConversionLibrary)) + - ". colorConversionLibrary must be either " - "filtergraph or swscale."); - } + // By default, we want to use swscale for color conversion because it is + // faster. However, it has width requirements, so we may need to fall back + // to filtergraph. We also need to respect what was requested from the + // options; we respect the options unconditionally, so it's possible for + // swscale's width requirements to be violated. We don't expose the ability to + // choose color conversion library publicly; we only use this ability + // internally. + auto defaultLibrary = getDefaultColorConversionLibrary(width); + streamInfo.colorConversionLibrary = + options.colorConversionLibrary.value_or(defaultLibrary); } void VideoDecoder::updateMetadataWithCodecContext( @@ -938,6 +927,17 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU( } else if ( streamInfo.colorConversionLibrary == ColorConversionLibrary::FILTERGRAPH) { + // Note that is a lazy init; we initialize filtergraph the first time + // we have a raw decoded frame. We do this lazily because up until this + // point, we really don't know what the resolution of the frames are + // without modification. In theory, we should be able to get that from the + // stream metadata, but in practice, we have encountered videos where the + // stream metadata had a different resolution from the actual resolution + // of the raw decoded frames. + if (!streamInfo.filterState.filterGraph) { + initializeFilterGraph( + streamInfo, expectedOutputHeight, expectedOutputWidth); + } outputTensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame); // Similarly to above, if this check fails it means the frame wasn't @@ -952,6 +952,7 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU( expectedOutputWidth, "x3, got ", shape); + if (preAllocatedOutputTensor.has_value()) { // We have already validated that preAllocatedOutputTensor and // outputTensor have the same shape. @@ -965,7 +966,6 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU( "Invalid color conversion library: " + std::to_string(static_cast(streamInfo.colorConversionLibrary))); } - } else if (output.streamType == AVMEDIA_TYPE_AUDIO) { // TODO: https://github.com/pytorch-labs/torchcodec/issues/85 implement // audio decoding. @@ -1007,7 +1007,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFramePlayedAtTimestampNoDemux( }); // Convert the frame to tensor. 
@@ -1007,7 +1007,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFramePlayedAtTimestampNoDemux(
       });
 
   // Convert the frame to tensor.
   auto output = convertAVFrameToDecodedOutput(rawOutput);
-  output.frame = MaybePermuteHWC2CHW(output.streamIndex, output.frame);
+  output.frame = maybePermuteHWC2CHW(output.streamIndex, output.frame);
   return output;
 }
@@ -1045,7 +1045,7 @@ VideoDecoder::DecodedOutput VideoDecoder::getFrameAtIndex(
     int streamIndex,
     int64_t frameIndex) {
   auto output = getFrameAtIndexInternal(streamIndex, frameIndex);
-  output.frame = MaybePermuteHWC2CHW(streamIndex, output.frame);
+  output.frame = maybePermuteHWC2CHW(streamIndex, output.frame);
   return output;
 }
 
@@ -1118,7 +1118,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesAtIndices(
     }
     previousIndexInVideo = indexInVideo;
   }
-  output.frames = MaybePermuteHWC2CHW(streamIndex, output.frames);
+  output.frames = maybePermuteHWC2CHW(streamIndex, output.frames);
   return output;
 }
 
@@ -1193,7 +1193,7 @@ VideoDecoder::BatchDecodedOutput VideoDecoder::getFramesInRange(
     output.ptsSeconds[f] = singleOut.ptsSeconds;
     output.durationSeconds[f] = singleOut.durationSeconds;
   }
-  output.frames = MaybePermuteHWC2CHW(streamIndex, output.frames);
+  output.frames = maybePermuteHWC2CHW(streamIndex, output.frames);
   return output;
 }
 
@@ -1246,7 +1246,7 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
   // need this special case below.
   if (startSeconds == stopSeconds) {
     BatchDecodedOutput output(0, options, streamMetadata);
-    output.frames = MaybePermuteHWC2CHW(streamIndex, output.frames);
+    output.frames = maybePermuteHWC2CHW(streamIndex, output.frames);
     return output;
   }
 
@@ -1287,7 +1287,7 @@ VideoDecoder::getFramesPlayedByTimestampInRange(
     output.ptsSeconds[f] = singleOut.ptsSeconds;
     output.durationSeconds[f] = singleOut.durationSeconds;
   }
-  output.frames = MaybePermuteHWC2CHW(streamIndex, output.frames);
+  output.frames = maybePermuteHWC2CHW(streamIndex, output.frames);
   return output;
 }
 
@@ -1303,7 +1303,7 @@ VideoDecoder::RawDecodedOutput VideoDecoder::getNextRawDecodedOutputNoDemux() {
 
 VideoDecoder::DecodedOutput VideoDecoder::getNextFrameNoDemux() {
   auto output = getNextFrameOutputNoDemuxInternal();
-  output.frame = MaybePermuteHWC2CHW(output.streamIndex, output.frame);
+  output.frame = maybePermuteHWC2CHW(output.streamIndex, output.frame);
   return output;
 }
 
diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h
index 8ae7cc17..1da745b3 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.h
+++ b/src/torchcodec/decoders/_core/VideoDecoder.h
@@ -157,7 +157,7 @@ class VideoDecoder {
       int streamIndex,
       const AudioStreamDecoderOptions& options = AudioStreamDecoderOptions());
 
-  torch::Tensor MaybePermuteHWC2CHW(int streamIndex, torch::Tensor& hwcTensor);
+  torch::Tensor maybePermuteHWC2CHW(int streamIndex, torch::Tensor& hwcTensor);
 
   // ---- SINGLE FRAME SEEK AND DECODING API ----
   // Places the cursor at the first frame on or after the position in seconds.
@@ -376,9 +376,10 @@ class VideoDecoder {
   void validateFrameIndex(const StreamInfo& stream, int64_t frameIndex);
   // Creates and initializes a filter graph for a stream. The filter graph can
   // do rescaling and color conversion.
-  void initializeFilterGraphForStream(
-      int streamIndex,
-      const VideoStreamDecoderOptions& options);
+  void initializeFilterGraph(
+      StreamInfo& streamInfo,
+      int expectedOutputHeight,
+      int expectedOutputWidth);
   void maybeSeekToBeforeDesiredPts();
   RawDecodedOutput getDecodedOutputWithFilter(
       std::function<bool(int, AVFrame*)>);
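All of the renamed call sites above funnel through the same small helper, whose signature and NHWC early return appear in the @@ -121 hunk near the top of the patch. A sketch consistent with those context lines follows; the exact permute orders are inferred from the "3D or 4D" comment, not shown in the patch.

// Sketch of maybePermuteHWC2CHW, consistent with the context lines in the
// @@ -121 hunk. permute() returns a view (no copy), so doing this once at
// each public entry point is cheap even for batched [N]HWC tensors.
torch::Tensor VideoDecoder::maybePermuteHWC2CHW(
    int streamIndex,
    torch::Tensor& hwcTensor) {
  if (streams_[streamIndex].options.dimensionOrder == "NHWC") {
    // The user asked for channels-last; leave the tensor as decoded.
    return hwcTensor;
  }
  if (hwcTensor.dim() == 3) {
    return hwcTensor.permute({2, 0, 1}); // HWC -> CHW
  }
  return hwcTensor.permute({0, 3, 1, 2}); // NHWC -> NCHW
}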
@@ -436,7 +437,7 @@
 // We always allocate [N]HWC tensors. The low-level decoding functions all
 // assume HWC tensors, since this is what FFmpeg natively handles. It's up to
 // the high-level decoding entry-points to permute that back to CHW, by calling
-// MaybePermuteHWC2CHW().
+// maybePermuteHWC2CHW().
 //
 // Also, importantly, the way we figure out the height and width of the
 // output frame tensor varies, and depends on the decoding entry-point. In
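For reference, the sizing rule that the lazy init makes possible can be sketched as follows. FrameDims and the helper name are illustrative, not necessarily what is in the tree, and VideoStreamDecoderOptions is assumed to expose optional height and width (the width case is visible in the addVideoStreamDecoder hunk above).

// Illustrative sketch of the "options, else raw decoded frame" rule that the
// lazy init enables. Stream metadata is deliberately not consulted: as the
// comment in convertAVFrameToDecodedOutputOnCPU notes, it can disagree with
// the real resolution of the decoded frames.
struct FrameDims {
  int height;
  int width;
};

FrameDims getHeightAndWidthFromOptionsOrAVFrame(
    const VideoStreamDecoderOptions& options,
    const AVFrame& avFrame) {
  return FrameDims{
      options.height.value_or(avFrame.height),
      options.width.value_or(avFrame.width)};
}

Deferring filtergraph construction until such an AVFrame exists is what lets the scale filter target these dimensions instead of whatever the container metadata claims.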