From 317c10ca901901c43f767c50aac8362e2c930448 Mon Sep 17 00:00:00 2001
From: Kyle Sayers
Date: Thu, 30 Jan 2025 16:36:59 -0500
Subject: [PATCH] [VLM] Update pixtral data collator to reflect latest transformers changes (#1116)

## Purpose ##
* In transformers==4.48.0, the Pixtral processor was updated to no longer add an
  additional layer of wrapping around `pixel_values`
  (https://github.com/huggingface/transformers/pull/34801). This is more in line
  with how other processors handle multimodal inputs.
* Because the data collator was previously used to unwrap this unnecessary
  wrapping, attempting to quantize Pixtral with transformers>=4.48.0 fails.

## Changes ##
* Update the pixtral data collator to match the latest transformers version.
* Add a comment for those who want to use transformers<4.48.0.

## Testing ##
* Ran the pixtral example to completion; @shubhra ran Pixtral Large.

---------

Signed-off-by: Kyle Sayers
---
 examples/multimodal_vision/pixtral_example.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/examples/multimodal_vision/pixtral_example.py b/examples/multimodal_vision/pixtral_example.py
index ebb18df12..891819bc6 100644
--- a/examples/multimodal_vision/pixtral_example.py
+++ b/examples/multimodal_vision/pixtral_example.py
@@ -16,18 +16,20 @@
 # Oneshot arguments
 DATASET_ID = "flickr30k"
-DATASET_SPLIT = {"calibration": "test[:512]"}
 NUM_CALIBRATION_SAMPLES = 512
+DATASET_SPLIT = {"calibration": f"test[:{NUM_CALIBRATION_SAMPLES}]"}
 MAX_SEQUENCE_LENGTH = 2048
 
 
 # Define a oneshot data collator for multimodal inputs.
+# NOTE: for transformers<4.48.0, please squeeze the first dimension of `pixel_values`
+# by appending `[0]` to the end of line 32
 def data_collator(batch):
     assert len(batch) == 1
     return {
         "input_ids": torch.LongTensor(batch[0]["input_ids"]),
         "attention_mask": torch.tensor(batch[0]["attention_mask"]),
-        "pixel_values": torch.tensor(batch[0]["pixel_values"])[0],
+        "pixel_values": torch.tensor(batch[0]["pixel_values"]),
     }
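For reference, a minimal version-agnostic sketch of the collator (an illustration under stated assumptions, not part of this patch): it squeezes the extra wrapping layer only when an older processor (transformers<4.48.0) adds it, assuming one sample and one image per batch so the wrapped tensor has rank 5 and the unwrapped tensor has rank 4.

```python
import torch


# Illustrative sketch, not from this patch: handle both transformers<4.48.0
# (which wraps pixel_values in an extra dimension) and >=4.48.0 (which does not).
# Assumes a single sample per batch and a single image per sample.
def data_collator(batch):
    assert len(batch) == 1
    pixel_values = torch.tensor(batch[0]["pixel_values"])
    # Older processors produce shape (1, num_images, C, H, W); squeeze the wrapper.
    if pixel_values.dim() == 5:
        pixel_values = pixel_values[0]
    return {
        "input_ids": torch.LongTensor(batch[0]["input_ids"]),
        "attention_mask": torch.tensor(batch[0]["attention_mask"]),
        "pixel_values": pixel_values,
    }
```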