Skip to content

Commit

Permalink
fix: Solve CUDA AV1 decoding (#448)
Browse files Browse the repository at this point in the history
Co-authored-by: Scott Schneider <scott.a.s@gmail.com>
  • Loading branch information
hugo-ijw and scotts authored Jan 13, 2025
1 parent 81de40e commit d8dde5c
Show file tree
Hide file tree
Showing 11 changed files with 98 additions and 3 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/linux_cuda_wheel.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ jobs:
build-command: "BUILD_AGAINST_ALL_FFMPEG_FROM_S3=1 ENABLE_CUDA=1 python -m build --wheel -vvv --no-isolation"

install-and-test:
runs-on: linux.4xlarge.nvidia.gpu
runs-on: linux.g5.4xlarge.nvidia.gpu
strategy:
fail-fast: false
matrix:
Expand Down
3 changes: 2 additions & 1 deletion CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,12 @@ git clone git@github.com:pytorch/torchcodec.git
cd torchcodec

pip install -e ".[dev]" --no-build-isolation -vv
# Or, for cuda support: ENABLE_CUDA=1 pip install -e ".[dev]" --no-build-isolation -vv
```

### Running unit tests

To run python tests run:
To run python tests run (please make sure `torchvision` is installed):

```bash
pytest test -vvv
Expand Down
6 changes: 6 additions & 0 deletions src/torchcodec/decoders/_core/CPUOnlyDevice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,10 @@ void releaseContextOnCuda(
throwUnsupportedDeviceError(device);
}

std::optional<AVCodecPtr> findCudaCodec(
const torch::Device& device,
const AVCodecID& codecId) {
throwUnsupportedDeviceError(device);
}

} // namespace facebook::torchcodec
29 changes: 29 additions & 0 deletions src/torchcodec/decoders/_core/CudaDevice.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -256,4 +256,33 @@ void convertAVFrameToDecodedOutputOnCuda(
<< " took: " << duration.count() << "us" << std::endl;
}

// inspired by https://github.com/FFmpeg/FFmpeg/commit/ad67ea9
// we have to do this because of an FFmpeg bug where hardware decoding is not
// appropriately set, so we just go off and find the matching codec for the CUDA
// device
std::optional<AVCodecPtr> findCudaCodec(
const torch::Device& device,
const AVCodecID& codecId) {
throwErrorIfNonCudaDevice(device);

void* i = NULL;

AVCodecPtr c;
while (c = av_codec_iterate(&i)) {
const AVCodecHWConfig* config;

if (c->id != codecId || !av_codec_is_decoder(c)) {
continue;
}

for (int j = 0; config = avcodec_get_hw_config(c, j); j++) {
if (config->device_type == AV_HWDEVICE_TYPE_CUDA) {
return c;
}
}
}

return std::nullopt;
}

} // namespace facebook::torchcodec
5 changes: 5 additions & 0 deletions src/torchcodec/decoders/_core/DeviceInterface.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#include <memory>
#include <stdexcept>
#include <string>
#include "FFMPEGCommon.h"
#include "src/torchcodec/decoders/_core/VideoDecoder.h"

extern "C" {
Expand Down Expand Up @@ -43,4 +44,8 @@ void releaseContextOnCuda(
const torch::Device& device,
AVCodecContext* codecContext);

std::optional<AVCodecPtr> findCudaCodec(
const torch::Device& device,
const AVCodecID& codecId);

} // namespace facebook::torchcodec
6 changes: 6 additions & 0 deletions src/torchcodec/decoders/_core/VideoDecoder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -461,6 +461,12 @@ void VideoDecoder::addVideoStreamDecoder(
"Stream with index " + std::to_string(streamNumber) +
" is not a video stream.");
}

if (options.device.type() == torch::kCUDA) {
codec = findCudaCodec(options.device, streamInfo.stream->codecpar->codec_id)
.value_or(codec);
}

AVCodecContext* codecContext = avcodec_alloc_context3(codec);
codecContext->thread_count = options.ffmpegThreadCount.value_or(0);
TORCH_CHECK(codecContext != nullptr);
Expand Down
19 changes: 18 additions & 1 deletion test/decoders/test_video_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,14 @@

from torchcodec.decoders import _core, VideoDecoder

from ..utils import assert_frames_equal, cpu_and_cuda, H265_VIDEO, in_fbcode, NASA_VIDEO
from ..utils import (
assert_frames_equal,
AV1_VIDEO,
cpu_and_cuda,
H265_VIDEO,
in_fbcode,
NASA_VIDEO,
)


class TestVideoDecoder:
Expand Down Expand Up @@ -409,6 +416,16 @@ def test_get_frames_at_fails(self, device):
with pytest.raises(RuntimeError, match="Expected a value of type"):
decoder.get_frames_at([0.3])

@pytest.mark.parametrize("device", cpu_and_cuda())
def test_get_frame_at_av1(self, device):
decoder = VideoDecoder(AV1_VIDEO.path, device=device)
ref_frame10 = AV1_VIDEO.get_frame_data_by_index(10)
ref_frame_info10 = AV1_VIDEO.get_frame_info(10)
decoded_frame10 = decoder.get_frame_at(10)
assert decoded_frame10.duration_seconds == ref_frame_info10.duration_seconds
assert decoded_frame10.pts_seconds == ref_frame_info10.pts_seconds
assert_frames_equal(decoded_frame10.data, ref_frame10.to(device=device))

@pytest.mark.parametrize("device", cpu_and_cuda())
def test_get_frame_played_at(self, device):
decoder = VideoDecoder(NASA_VIDEO.path, device=device)
Expand Down
16 changes: 16 additions & 0 deletions test/generate_reference_resources.sh
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,19 @@ do
python3 "$TORCHCODEC_PATH/test/convert_image_to_tensor.py" "$bmp"
rm -f "$bmp"
done

# This video was generated by running the following:
# ffmpeg -f lavfi -i testsrc=duration=5:size=640x360:rate=25,format=yuv420p -c:v libaom-av1 -crf 30 -colorspace bt709 -color_primaries bt709 -color_trc bt709 av1_video.mkv
# Note that this video only has 1 stream, at index 0.
VIDEO_PATH=$RESOURCES_DIR/av1_video.mkv
FRAMES=(10)
for frame in "${FRAMES[@]}"; do
frame_name=$(printf "%06d" "$frame")
ffmpeg -y -i "$VIDEO_PATH" -vf select="eq(n\,$frame)" -vsync vfr -q:v 2 "$VIDEO_PATH.stream0.frame$frame_name.bmp"
done

for bmp in "$RESOURCES_DIR"/*.bmp
do
python3 "$TORCHCODEC_PATH/test/convert_image_to_tensor.py" "$bmp"
rm -f "$bmp"
done
Binary file added test/resources/av1_video.mkv
Binary file not shown.
Binary file not shown.
15 changes: 15 additions & 0 deletions test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,3 +312,18 @@ def get_empty_chw_tensor(self, *, stream_index: int) -> torch.Tensor:
},
},
)

AV1_VIDEO = TestVideo(
filename="av1_video.mkv",
default_stream_index=0,
# This metadata is extracted manually.
# $ ffprobe -v error -hide_banner -select_streams v:0 -show_frames -of json test/resources/av1_video.mkv > out.json
stream_infos={
0: TestVideoStreamInfo(width=640, height=360, num_color_channels=3),
},
frames={
0: {
10: TestFrameInfo(pts_seconds=0.400000, duration_seconds=0.040000),
},
},
)

0 comments on commit d8dde5c

Please sign in to comment.