diff --git a/benchmarks/decoders/generate_readme_data.py b/benchmarks/decoders/generate_readme_data.py index 891e5dca..262e21c2 100644 --- a/benchmarks/decoders/generate_readme_data.py +++ b/benchmarks/decoders/generate_readme_data.py @@ -61,7 +61,8 @@ def main() -> None: decoder_dict = {} decoder_dict["torchcodec"] = TorchCodecPublic() - decoder_dict["torchcodec[cuda]"] = TorchCodecPublic(device="cuda") + if torch.cuda.is_available(): + decoder_dict["torchcodec[cuda]"] = TorchCodecPublic(device="cuda") decoder_dict["torchvision[video_reader]"] = TorchVision("video_reader") decoder_dict["torchaudio"] = TorchAudioDecoder() @@ -80,7 +81,7 @@ def main() -> None: batch_parameters=BatchParameters(batch_size=50, num_threads=10), resize_height=256, resize_width=256, - resize_device="cuda", + resize_device="cuda" if torch.cuda.is_available() else "cpu", ), ) data_for_writing = {