From 35008383df2b096c3838b28133c64464e44abadd Mon Sep 17 00:00:00 2001 From: Scott Schneider Date: Wed, 15 Jan 2025 14:02:06 -0800 Subject: [PATCH] Reference must be on device --- test/decoders/test_video_decoder.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/test/decoders/test_video_decoder.py b/test/decoders/test_video_decoder.py index 8b15c1a3..4632e5ff 100644 --- a/test/decoders/test_video_decoder.py +++ b/test/decoders/test_video_decoder.py @@ -448,7 +448,9 @@ def test_get_frame_played_at_h265(self, device): # Non-regression test for https://github.com/pytorch/torchcodec/issues/179 decoder = VideoDecoder(H265_VIDEO.path, device=device) ref_frame6 = H265_VIDEO.get_frame_data_by_index(5) - assert_frames_equal(ref_frame6, decoder.get_frame_played_at(0.5).data) + assert_frames_equal( + ref_frame6.to(device=device), decoder.get_frame_played_at(0.5).data + ) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_get_frame_played_at_fails(self, device):