From b1719e77a8b20022fd92bf24ac58b4dfac6d4e69 Mon Sep 17 00:00:00 2001
From: Nicolas Hug
Date: Tue, 21 Jan 2025 17:11:00 +0000
Subject: [PATCH] Remove `streamType` field from `DecodedOutput` (#457)

---
 .../decoders/_core/VideoDecoder.cpp          | 144 +++++++++---------
 src/torchcodec/decoders/_core/VideoDecoder.h |   2 -
 2 files changed, 68 insertions(+), 78 deletions(-)

diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp
index 6a74f51d..f1d53308 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.cpp
+++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -871,7 +871,7 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
   AVFrame* frame = rawOutput.frame.get();
   output.streamIndex = streamIndex;
   auto& streamInfo = streams_[streamIndex];
-  output.streamType = streams_[streamIndex].stream->codecpar->codec_type;
+  TORCH_CHECK(streamInfo.stream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO);
   output.pts = frame->pts;
   output.ptsSeconds = ptsToSeconds(
       frame->pts, formatContext_->streams[streamIndex]->time_base);
@@ -932,86 +932,78 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
   }
 
   torch::Tensor outputTensor;
-  if (output.streamType == AVMEDIA_TYPE_VIDEO) {
-    // We need to compare the current frame context with our previous frame
-    // context. If they are different, then we need to re-create our colorspace
-    // conversion objects. We create our colorspace conversion objects late so
-    // that we don't have to depend on the unreliable metadata in the header.
-    // And we sometimes re-create them because it's possible for frame
-    // resolution to change mid-stream. Finally, we want to reuse the colorspace
-    // conversion objects as much as possible for performance reasons.
-    enum AVPixelFormat frameFormat =
-        static_cast<enum AVPixelFormat>(frame->format);
-    auto frameContext = DecodedFrameContext{
-        frame->width,
-        frame->height,
-        frameFormat,
-        expectedOutputWidth,
-        expectedOutputHeight};
+  // We need to compare the current frame context with our previous frame
+  // context. If they are different, then we need to re-create our colorspace
+  // conversion objects. We create our colorspace conversion objects late so
+  // that we don't have to depend on the unreliable metadata in the header.
+  // And we sometimes re-create them because it's possible for frame
+  // resolution to change mid-stream. Finally, we want to reuse the colorspace
+  // conversion objects as much as possible for performance reasons.
+  enum AVPixelFormat frameFormat =
+      static_cast<enum AVPixelFormat>(frame->format);
+  auto frameContext = DecodedFrameContext{
+      frame->width,
+      frame->height,
+      frameFormat,
+      expectedOutputWidth,
+      expectedOutputHeight};
 
-    if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
-      outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
-          expectedOutputHeight, expectedOutputWidth, torch::kCPU));
+  if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
+    outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
+        expectedOutputHeight, expectedOutputWidth, torch::kCPU));
 
-      if (!streamInfo.swsContext ||
-          streamInfo.prevFrameContext != frameContext) {
-        createSwsContext(streamInfo, frameContext, frame->colorspace);
-        streamInfo.prevFrameContext = frameContext;
-      }
-      int resultHeight =
-          convertFrameToTensorUsingSwsScale(streamIndex, frame, outputTensor);
-      // If this check failed, it would mean that the frame wasn't reshaped to
-      // the expected height.
-      // TODO: Can we do the same check for width?
-      TORCH_CHECK(
-          resultHeight == expectedOutputHeight,
-          "resultHeight != expectedOutputHeight: ",
-          resultHeight,
-          " != ",
-          expectedOutputHeight);
+    if (!streamInfo.swsContext || streamInfo.prevFrameContext != frameContext) {
+      createSwsContext(streamInfo, frameContext, frame->colorspace);
+      streamInfo.prevFrameContext = frameContext;
+    }
+    int resultHeight =
+        convertFrameToTensorUsingSwsScale(streamIndex, frame, outputTensor);
+    // If this check failed, it would mean that the frame wasn't reshaped to
+    // the expected height.
+    // TODO: Can we do the same check for width?
+    TORCH_CHECK(
+        resultHeight == expectedOutputHeight,
+        "resultHeight != expectedOutputHeight: ",
+        resultHeight,
+        " != ",
+        expectedOutputHeight);
 
-      output.frame = outputTensor;
-    } else if (
-        streamInfo.colorConversionLibrary ==
-        ColorConversionLibrary::FILTERGRAPH) {
-      if (!streamInfo.filterState.filterGraph ||
-          streamInfo.prevFrameContext != frameContext) {
-        createFilterGraph(
-            streamInfo, expectedOutputHeight, expectedOutputWidth);
-        streamInfo.prevFrameContext = frameContext;
-      }
-      outputTensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
+    output.frame = outputTensor;
+  } else if (
+      streamInfo.colorConversionLibrary ==
+      ColorConversionLibrary::FILTERGRAPH) {
+    if (!streamInfo.filterState.filterGraph ||
+        streamInfo.prevFrameContext != frameContext) {
+      createFilterGraph(streamInfo, expectedOutputHeight, expectedOutputWidth);
+      streamInfo.prevFrameContext = frameContext;
+    }
+    outputTensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
 
-      // Similarly to above, if this check fails it means the frame wasn't
-      // reshaped to its expected dimensions by filtergraph.
-      auto shape = outputTensor.sizes();
-      TORCH_CHECK(
-          (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
-              (shape[1] == expectedOutputWidth) && (shape[2] == 3),
-          "Expected output tensor of shape ",
-          expectedOutputHeight,
-          "x",
-          expectedOutputWidth,
-          "x3, got ",
-          shape);
+    // Similarly to above, if this check fails it means the frame wasn't
+    // reshaped to its expected dimensions by filtergraph.
+    auto shape = outputTensor.sizes();
+    TORCH_CHECK(
+        (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
+            (shape[1] == expectedOutputWidth) && (shape[2] == 3),
+        "Expected output tensor of shape ",
+        expectedOutputHeight,
+        "x",
+        expectedOutputWidth,
+        "x3, got ",
+        shape);
 
-      if (preAllocatedOutputTensor.has_value()) {
-        // We have already validated that preAllocatedOutputTensor and
-        // outputTensor have the same shape.
-        preAllocatedOutputTensor.value().copy_(outputTensor);
-        output.frame = preAllocatedOutputTensor.value();
-      } else {
-        output.frame = outputTensor;
-      }
+    if (preAllocatedOutputTensor.has_value()) {
+      // We have already validated that preAllocatedOutputTensor and
+      // outputTensor have the same shape.
+      preAllocatedOutputTensor.value().copy_(outputTensor);
+      output.frame = preAllocatedOutputTensor.value();
     } else {
-      throw std::runtime_error(
-          "Invalid color conversion library: " +
-          std::to_string(static_cast<int>(streamInfo.colorConversionLibrary)));
+      output.frame = outputTensor;
     }
-  } else if (output.streamType == AVMEDIA_TYPE_AUDIO) {
-    // TODO: https://github.com/pytorch-labs/torchcodec/issues/85 implement
-    // audio decoding.
-    throw std::runtime_error("Audio is not supported yet.");
+  } else {
+    throw std::runtime_error(
+        "Invalid color conversion library: " +
+        std::to_string(static_cast<int>(streamInfo.colorConversionLibrary)));
   }
 }
diff --git a/src/torchcodec/decoders/_core/VideoDecoder.h b/src/torchcodec/decoders/_core/VideoDecoder.h
index c6083a2f..bed2ed05 100644
--- a/src/torchcodec/decoders/_core/VideoDecoder.h
+++ b/src/torchcodec/decoders/_core/VideoDecoder.h
@@ -171,8 +171,6 @@ class VideoDecoder {
   struct DecodedOutput {
     // The actual decoded output as a Tensor.
     torch::Tensor frame;
-    // Could be AVMEDIA_TYPE_VIDEO or AVMEDIA_TYPE_AUDIO.
-    AVMediaType streamType;
     // The stream index of the decoded frame. Used to distinguish
     // between streams that are of the same type.
     int streamIndex;