diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 55e5545c..bc3ec3bb 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -42,11 +42,12 @@ git clone git@github.com:pytorch/torchcodec.git cd torchcodec pip install -e ".[dev]" --no-build-isolation -vv +# Or, for cuda support: ENABLE_CUDA=1 pip install -e ".[dev]" --no-build-isolation -vv ``` ### Running unit tests -To run python tests run: +To run python tests run (please make sure `torchvision` is installed): ```bash pytest test -vvv diff --git a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp index eb605465..7d058130 100644 --- a/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp +++ b/src/torchcodec/decoders/_core/CPUOnlyDevice.cpp @@ -35,4 +35,10 @@ void releaseContextOnCuda( throwUnsupportedDeviceError(device); } +std::optional findCudaCodec( + const torch::Device& device, + const AVCodecID& codecId) { + throwUnsupportedDeviceError(device); +} + } // namespace facebook::torchcodec diff --git a/src/torchcodec/decoders/_core/CudaDevice.cpp b/src/torchcodec/decoders/_core/CudaDevice.cpp index 5d48d26a..69fef471 100644 --- a/src/torchcodec/decoders/_core/CudaDevice.cpp +++ b/src/torchcodec/decoders/_core/CudaDevice.cpp @@ -256,4 +256,33 @@ void convertAVFrameToDecodedOutputOnCuda( << " took: " << duration.count() << "us" << std::endl; } +// inspired by https://github.com/FFmpeg/FFmpeg/commit/ad67ea9 +// we have to do this because of an FFmpeg bug where hardware decoding is not +// appropriately set, so we just go off and find the matching codec for the CUDA +// device +std::optional findCudaCodec( + const torch::Device& device, + const AVCodecID& codecId) { + throwErrorIfNonCudaDevice(device); + + void* i = NULL; + + AVCodecPtr c; + while (c = av_codec_iterate(&i)) { + const AVCodecHWConfig* config; + + if (c->id != codecId || !av_codec_is_decoder(c)) { + continue; + } + + for (int j = 0; config = avcodec_get_hw_config(c, j); j++) { + if (config->device_type == AV_HWDEVICE_TYPE_CUDA) { + return c; + } + } + } + + return std::nullopt; +} + } // namespace facebook::torchcodec diff --git a/src/torchcodec/decoders/_core/DeviceInterface.h b/src/torchcodec/decoders/_core/DeviceInterface.h index 42dd63fc..289308cb 100644 --- a/src/torchcodec/decoders/_core/DeviceInterface.h +++ b/src/torchcodec/decoders/_core/DeviceInterface.h @@ -10,6 +10,7 @@ #include #include #include +#include "FFMPEGCommon.h" #include "src/torchcodec/decoders/_core/VideoDecoder.h" extern "C" { @@ -43,4 +44,8 @@ void releaseContextOnCuda( const torch::Device& device, AVCodecContext* codecContext); +std::optional findCudaCodec( + const torch::Device& device, + const AVCodecID& codecId); + } // namespace facebook::torchcodec diff --git a/src/torchcodec/decoders/_core/VideoDecoder.cpp b/src/torchcodec/decoders/_core/VideoDecoder.cpp index 2c955551..737cf478 100644 --- a/src/torchcodec/decoders/_core/VideoDecoder.cpp +++ b/src/torchcodec/decoders/_core/VideoDecoder.cpp @@ -461,6 +461,12 @@ void VideoDecoder::addVideoStreamDecoder( "Stream with index " + std::to_string(streamNumber) + " is not a video stream."); } + + if (options.device.type() == torch::kCUDA) { + codec = findCudaCodec(options.device, streamInfo.stream->codecpar->codec_id) + .value_or(codec); + } + AVCodecContext* codecContext = avcodec_alloc_context3(codec); codecContext->thread_count = options.ffmpegThreadCount.value_or(0); TORCH_CHECK(codecContext != nullptr); diff --git a/test/decoders/test_video_decoder.py b/test/decoders/test_video_decoder.py index 8b3fca3e..dea38bab 100644 --- a/test/decoders/test_video_decoder.py +++ b/test/decoders/test_video_decoder.py @@ -11,7 +11,14 @@ from torchcodec.decoders import _core, VideoDecoder -from ..utils import assert_frames_equal, cpu_and_cuda, H265_VIDEO, in_fbcode, NASA_VIDEO +from ..utils import ( + assert_frames_equal, + AV1_VIDEO, + cpu_and_cuda, + H265_VIDEO, + in_fbcode, + NASA_VIDEO, +) class TestVideoDecoder: @@ -409,6 +416,16 @@ def test_get_frames_at_fails(self, device): with pytest.raises(RuntimeError, match="Expected a value of type"): decoder.get_frames_at([0.3]) + @pytest.mark.parametrize("device", cpu_and_cuda()) + def test_get_frame_at_av1(self, device): + decoder = VideoDecoder(AV1_VIDEO.path, device=device) + ref_frame10 = AV1_VIDEO.get_frame_data_by_index(10) + ref_frame_info10 = AV1_VIDEO.get_frame_info(10) + decoded_frame10 = decoder.get_frame_at(10) + assert decoded_frame10.duration_seconds == ref_frame_info10.duration_seconds + assert decoded_frame10.pts_seconds == ref_frame_info10.pts_seconds + assert_frames_equal(decoded_frame10.data, ref_frame10.to(device=device)) + @pytest.mark.parametrize("device", cpu_and_cuda()) def test_get_frame_played_at(self, device): decoder = VideoDecoder(NASA_VIDEO.path, device=device) diff --git a/test/generate_reference_resources.sh b/test/generate_reference_resources.sh index e10f451c..fba098a7 100755 --- a/test/generate_reference_resources.sh +++ b/test/generate_reference_resources.sh @@ -61,3 +61,19 @@ do python3 "$TORCHCODEC_PATH/test/convert_image_to_tensor.py" "$bmp" rm -f "$bmp" done + +# This video was generated by running the following: +# ffmpeg -f lavfi -i testsrc=duration=5:size=640x360:rate=25,format=yuv420p -c:v libaom-av1 -crf 30 -colorspace bt709 -color_primaries bt709 -color_trc bt709 av1_video.mkv +# Note that this video only has 1 stream, at index 0. +VIDEO_PATH=$RESOURCES_DIR/av1_video.mkv +FRAMES=(10) +for frame in "${FRAMES[@]}"; do + frame_name=$(printf "%06d" "$frame") + ffmpeg -y -i "$VIDEO_PATH" -vf select="eq(n\,$frame)" -vsync vfr -q:v 2 "$VIDEO_PATH.stream0.frame$frame_name.bmp" +done + +for bmp in "$RESOURCES_DIR"/*.bmp +do + python3 "$TORCHCODEC_PATH/test/convert_image_to_tensor.py" "$bmp" + rm -f "$bmp" +done diff --git a/test/resources/av1_video.mkv b/test/resources/av1_video.mkv new file mode 100644 index 00000000..429a1e9a Binary files /dev/null and b/test/resources/av1_video.mkv differ diff --git a/test/resources/av1_video.mkv.stream0.frame000010.pt b/test/resources/av1_video.mkv.stream0.frame000010.pt new file mode 100644 index 00000000..6263e1ca Binary files /dev/null and b/test/resources/av1_video.mkv.stream0.frame000010.pt differ diff --git a/test/utils.py b/test/utils.py index 14e7db0f..1e2a462f 100644 --- a/test/utils.py +++ b/test/utils.py @@ -312,3 +312,18 @@ def get_empty_chw_tensor(self, *, stream_index: int) -> torch.Tensor: }, }, ) + +AV1_VIDEO = TestVideo( + filename="av1_video.mkv", + default_stream_index=0, + # This metadata is extracted manually. + # $ ffprobe -v error -hide_banner -select_streams v:0 -show_frames -of json test/resources/av1_video.mkv > out.json + stream_infos={ + 0: TestVideoStreamInfo(width=640, height=360, num_color_channels=3), + }, + frames={ + 0: { + 10: TestFrameInfo(pts_seconds=0.400000, duration_seconds=0.040000), + }, + }, +)