diff --git a/test/decoders/test_video_decoder.py b/test/decoders/test_video_decoder.py index 8b15c1a3..4632e5ff 100644 --- a/test/decoders/test_video_decoder.py +++ b/test/decoders/test_video_decoder.py @@ -448,7 +448,9 @@ def test_get_frame_played_at_h265(self, device): # Non-regression test for https://github.com/pytorch/torchcodec/issues/179 decoder = VideoDecoder(H265_VIDEO.path, device=device) ref_frame6 = H265_VIDEO.get_frame_data_by_index(5) - assert_frames_equal(ref_frame6, decoder.get_frame_played_at(0.5).data) + assert_frames_equal( + ref_frame6.to(device=device), decoder.get_frame_played_at(0.5).data + ) @pytest.mark.parametrize("device", cpu_and_cuda()) def test_get_frame_played_at_fails(self, device):