diff --git a/examples/multimodal_vision/pixtral_example.py b/examples/multimodal_vision/pixtral_example.py
index ebb18df12..891819bc6 100644
--- a/examples/multimodal_vision/pixtral_example.py
+++ b/examples/multimodal_vision/pixtral_example.py
@@ -16,18 +16,20 @@
 # Oneshot arguments
 DATASET_ID = "flickr30k"
-DATASET_SPLIT = {"calibration": "test[:512]"}
 NUM_CALIBRATION_SAMPLES = 512
+DATASET_SPLIT = {"calibration": f"test[:{NUM_CALIBRATION_SAMPLES}]"}
 MAX_SEQUENCE_LENGTH = 2048
 
 # Define a oneshot data collator for multimodal inputs.
+# NOTE: for transformers<4.48.0, please squeeze the first dimension of `pixel_values`
+# by appending `[0]` to the end of line 32
 def data_collator(batch):
     assert len(batch) == 1
     return {
         "input_ids": torch.LongTensor(batch[0]["input_ids"]),
         "attention_mask": torch.tensor(batch[0]["attention_mask"]),
-        "pixel_values": torch.tensor(batch[0]["pixel_values"])[0],
+        "pixel_values": torch.tensor(batch[0]["pixel_values"]),
     }
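
For context, below is a minimal, self-contained sketch of how the revised collator behaves on a single preprocessed sample. The dummy field shapes are illustrative assumptions, not the actual output of the Pixtral processor.

```python
import torch


def data_collator(batch):
    # Same collator as after this change: batch size must be 1 and
    # pixel_values keeps its leading (image) dimension.
    assert len(batch) == 1
    return {
        "input_ids": torch.LongTensor(batch[0]["input_ids"]),
        "attention_mask": torch.tensor(batch[0]["attention_mask"]),
        "pixel_values": torch.tensor(batch[0]["pixel_values"]),
    }


# Dummy sample shaped like processor output (shapes are illustrative only).
sample = {
    "input_ids": [[1, 2, 3, 4]],
    "attention_mask": [[1, 1, 1, 1]],
    "pixel_values": [[[[0.0] * 8] * 8] * 3],  # 1 image x 3 channels x 8 x 8
}

out = data_collator([sample])
print(out["pixel_values"].shape)  # torch.Size([1, 3, 8, 8])
# On transformers<4.48.0, append [0] as the NOTE in the diff suggests,
# so the shape becomes torch.Size([3, 8, 8]).
```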