Remove streamType field from DecodedOutput (#457)
NicolasHug authored Jan 21, 2025
1 parent 4f3e491 commit b1719e7
Showing 2 changed files with 68 additions and 78 deletions.
144 changes: 68 additions & 76 deletions src/torchcodec/decoders/_core/VideoDecoder.cpp
@@ -871,7 +871,7 @@ VideoDecoder::DecodedOutput VideoDecoder::convertAVFrameToDecodedOutput(
   AVFrame* frame = rawOutput.frame.get();
   output.streamIndex = streamIndex;
   auto& streamInfo = streams_[streamIndex];
-  output.streamType = streams_[streamIndex].stream->codecpar->codec_type;
+  TORCH_CHECK(streamInfo.stream->codecpar->codec_type == AVMEDIA_TYPE_VIDEO);
   output.pts = frame->pts;
   output.ptsSeconds =
       ptsToSeconds(frame->pts, formatContext_->streams[streamIndex]->time_base);
@@ -932,86 +932,78 @@ void VideoDecoder::convertAVFrameToDecodedOutputOnCPU(
   }
 
   torch::Tensor outputTensor;
-  if (output.streamType == AVMEDIA_TYPE_VIDEO) {
-    // We need to compare the current frame context with our previous frame
-    // context. If they are different, then we need to re-create our colorspace
-    // conversion objects. We create our colorspace conversion objects late so
-    // that we don't have to depend on the unreliable metadata in the header.
-    // And we sometimes re-create them because it's possible for frame
-    // resolution to change mid-stream. Finally, we want to reuse the colorspace
-    // conversion objects as much as possible for performance reasons.
-    enum AVPixelFormat frameFormat =
-        static_cast<enum AVPixelFormat>(frame->format);
-    auto frameContext = DecodedFrameContext{
-        frame->width,
-        frame->height,
-        frameFormat,
-        expectedOutputWidth,
-        expectedOutputHeight};
+  // We need to compare the current frame context with our previous frame
+  // context. If they are different, then we need to re-create our colorspace
+  // conversion objects. We create our colorspace conversion objects late so
+  // that we don't have to depend on the unreliable metadata in the header.
+  // And we sometimes re-create them because it's possible for frame
+  // resolution to change mid-stream. Finally, we want to reuse the colorspace
+  // conversion objects as much as possible for performance reasons.
+  enum AVPixelFormat frameFormat =
+      static_cast<enum AVPixelFormat>(frame->format);
+  auto frameContext = DecodedFrameContext{
+      frame->width,
+      frame->height,
+      frameFormat,
+      expectedOutputWidth,
+      expectedOutputHeight};
 
-    if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
-      outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
-          expectedOutputHeight, expectedOutputWidth, torch::kCPU));
+  if (streamInfo.colorConversionLibrary == ColorConversionLibrary::SWSCALE) {
+    outputTensor = preAllocatedOutputTensor.value_or(allocateEmptyHWCTensor(
+        expectedOutputHeight, expectedOutputWidth, torch::kCPU));
 
-      if (!streamInfo.swsContext ||
-          streamInfo.prevFrameContext != frameContext) {
-        createSwsContext(streamInfo, frameContext, frame->colorspace);
-        streamInfo.prevFrameContext = frameContext;
-      }
-      int resultHeight =
-          convertFrameToTensorUsingSwsScale(streamIndex, frame, outputTensor);
-      // If this check failed, it would mean that the frame wasn't reshaped to
-      // the expected height.
-      // TODO: Can we do the same check for width?
-      TORCH_CHECK(
-          resultHeight == expectedOutputHeight,
-          "resultHeight != expectedOutputHeight: ",
-          resultHeight,
-          " != ",
-          expectedOutputHeight);
+    if (!streamInfo.swsContext || streamInfo.prevFrameContext != frameContext) {
+      createSwsContext(streamInfo, frameContext, frame->colorspace);
+      streamInfo.prevFrameContext = frameContext;
+    }
+    int resultHeight =
+        convertFrameToTensorUsingSwsScale(streamIndex, frame, outputTensor);
+    // If this check failed, it would mean that the frame wasn't reshaped to
+    // the expected height.
+    // TODO: Can we do the same check for width?
+    TORCH_CHECK(
+        resultHeight == expectedOutputHeight,
+        "resultHeight != expectedOutputHeight: ",
+        resultHeight,
+        " != ",
+        expectedOutputHeight);
 
-      output.frame = outputTensor;
-    } else if (
-        streamInfo.colorConversionLibrary ==
-        ColorConversionLibrary::FILTERGRAPH) {
-      if (!streamInfo.filterState.filterGraph ||
-          streamInfo.prevFrameContext != frameContext) {
-        createFilterGraph(
-            streamInfo, expectedOutputHeight, expectedOutputWidth);
-        streamInfo.prevFrameContext = frameContext;
-      }
-      outputTensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
+    output.frame = outputTensor;
+  } else if (
+      streamInfo.colorConversionLibrary ==
+      ColorConversionLibrary::FILTERGRAPH) {
+    if (!streamInfo.filterState.filterGraph ||
+        streamInfo.prevFrameContext != frameContext) {
+      createFilterGraph(streamInfo, expectedOutputHeight, expectedOutputWidth);
+      streamInfo.prevFrameContext = frameContext;
+    }
+    outputTensor = convertFrameToTensorUsingFilterGraph(streamIndex, frame);
 
-      // Similarly to above, if this check fails it means the frame wasn't
-      // reshaped to its expected dimensions by filtergraph.
-      auto shape = outputTensor.sizes();
-      TORCH_CHECK(
-          (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
-              (shape[1] == expectedOutputWidth) && (shape[2] == 3),
-          "Expected output tensor of shape ",
-          expectedOutputHeight,
-          "x",
-          expectedOutputWidth,
-          "x3, got ",
-          shape);
+    // Similarly to above, if this check fails it means the frame wasn't
+    // reshaped to its expected dimensions by filtergraph.
+    auto shape = outputTensor.sizes();
+    TORCH_CHECK(
+        (shape.size() == 3) && (shape[0] == expectedOutputHeight) &&
+            (shape[1] == expectedOutputWidth) && (shape[2] == 3),
+        "Expected output tensor of shape ",
+        expectedOutputHeight,
+        "x",
+        expectedOutputWidth,
+        "x3, got ",
+        shape);
 
-      if (preAllocatedOutputTensor.has_value()) {
-        // We have already validated that preAllocatedOutputTensor and
-        // outputTensor have the same shape.
-        preAllocatedOutputTensor.value().copy_(outputTensor);
-        output.frame = preAllocatedOutputTensor.value();
-      } else {
-        output.frame = outputTensor;
-      }
-    } else {
-      throw std::runtime_error(
-          "Invalid color conversion library: " +
-          std::to_string(static_cast<int>(streamInfo.colorConversionLibrary)));
-    }
-  } else if (output.streamType == AVMEDIA_TYPE_AUDIO) {
-    // TODO: https://github.com/pytorch-labs/torchcodec/issues/85 implement
-    // audio decoding.
-    throw std::runtime_error("Audio is not supported yet.");
+    if (preAllocatedOutputTensor.has_value()) {
+      // We have already validated that preAllocatedOutputTensor and
+      // outputTensor have the same shape.
+      preAllocatedOutputTensor.value().copy_(outputTensor);
+      output.frame = preAllocatedOutputTensor.value();
+    } else {
+      output.frame = outputTensor;
+    }
+  } else {
+    throw std::runtime_error(
+        "Invalid color conversion library: " +
+        std::to_string(static_cast<int>(streamInfo.colorConversionLibrary)));
   }
 }

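Note on usage after this change: DecodedOutput no longer records what kind of stream a frame came from, but the media type is still available from the underlying FFmpeg format context, which is the same codecpar->codec_type field the new TORCH_CHECK reads. The snippet below is an illustrative sketch only and is not part of this commit; the getMediaType helper is hypothetical, and it assumes access to an AVFormatContext pointer such as the decoder's formatContext_.

extern "C" {
#include <libavformat/avformat.h>  // AVFormatContext, AVStream, AVMediaType
}

// Hypothetical helper (not in this commit): look up a stream's media type
// directly from the format context, i.e. the same codecpar->codec_type field
// that the new TORCH_CHECK in convertAVFrameToDecodedOutput inspects.
static AVMediaType getMediaType(
    const AVFormatContext* formatContext,
    int streamIndex) {
  return formatContext->streams[streamIndex]->codecpar->codec_type;
}

// Usage sketch: code that previously branched on output.streamType can branch
// on output.streamIndex instead, e.g.
//   if (getMediaType(formatContext, output.streamIndex) == AVMEDIA_TYPE_VIDEO) {
//     ...
//   }
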
2 changes: 0 additions & 2 deletions src/torchcodec/decoders/_core/VideoDecoder.h
@@ -171,8 +171,6 @@ class VideoDecoder {
   struct DecodedOutput {
     // The actual decoded output as a Tensor.
     torch::Tensor frame;
-    // Could be AVMEDIA_TYPE_VIDEO or AVMEDIA_TYPE_AUDIO.
-    AVMediaType streamType;
     // The stream index of the decoded frame. Used to distinguish
     // between streams that are of the same type.
     int streamIndex;
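For reference, a partial sketch of the struct as it stands after this commit, limited to the members visible in this diff; the pts/ptsSeconds members assigned in convertAVFrameToDecodedOutput and any other fields are elided.

// Partial sketch of DecodedOutput after this commit; only the members shown
// in this diff are listed, everything else is elided.
struct DecodedOutput {
  // The actual decoded output as a Tensor.
  torch::Tensor frame;
  // The stream index of the decoded frame. Used to distinguish
  // between streams that are of the same type.
  int streamIndex;
};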
