[VLM] Update pixtral data collator to reflect latest transformers changes (#1116)

## Purpose ##
* In transformers==4.48.0, the Pixtral processor was updated to no longer add
an additional layer of wrapping around `pixel_values`
(huggingface/transformers#34801). This is more in line
with how other processors handle multimodal inputs.
* Because the data collator was previously responsible for unwrapping this
extra layer, attempting to quantize Pixtral with
transformers>=4.48.0 fails.

## Changes ##
* Update the Pixtral data collator to match the latest transformers version
* Add a comment for those who want to use transformers<4.48.0

## Testing ##
* Ran the Pixtral example to completion; @shubhra ran Pixtral Large

---------

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
kylesayrs authored Jan 30, 2025
1 parent 999d660 commit 317c10c
Showing 1 changed file with 4 additions and 2 deletions.
examples/multimodal_vision/pixtral_example.py:

```diff
@@ -16,18 +16,20 @@

 # Oneshot arguments
 DATASET_ID = "flickr30k"
-DATASET_SPLIT = {"calibration": "test[:512]"}
 NUM_CALIBRATION_SAMPLES = 512
+DATASET_SPLIT = {"calibration": f"test[:{NUM_CALIBRATION_SAMPLES}]"}
 MAX_SEQUENCE_LENGTH = 2048


 # Define a oneshot data collator for multimodal inputs.
+# NOTE: for transformers<4.48.0, please squeeze the first dimension of `pixel_values`
+# by appending `[0]` to the end of line 32
 def data_collator(batch):
     assert len(batch) == 1
     return {
         "input_ids": torch.LongTensor(batch[0]["input_ids"]),
         "attention_mask": torch.tensor(batch[0]["attention_mask"]),
-        "pixel_values": torch.tensor(batch[0]["pixel_values"])[0],
+        "pixel_values": torch.tensor(batch[0]["pixel_values"]),
     }
```
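For readers who need to support both transformers versions at once, the collator can be made version-agnostic by detecting the extra wrapping at runtime instead of hard-coding one behavior. The sketch below is a hypothetical variant, not part of this commit: it assumes the processed `pixel_values` for a single Pixtral sample is 4-D (images, channels, height, width), so a 5-D tensor indicates the pre-4.48.0 extra layer and is squeezed.

```python
import torch


def data_collator(batch):
    """Collate a single-sample multimodal batch for oneshot calibration.

    Hypothetical version-agnostic variant: transformers<4.48.0 wraps
    `pixel_values` in one extra list, yielding a leading singleton
    dimension (5-D instead of 4-D), which is dropped here so both
    transformers versions produce the same output shape.
    """
    assert len(batch) == 1, "oneshot calibration expects batch_size=1"
    pixel_values = torch.tensor(batch[0]["pixel_values"])
    # Assumed layout: 4-D (images, channels, height, width) per sample.
    # A 5-D tensor with a leading singleton dim is the old extra wrapping.
    if pixel_values.dim() == 5 and pixel_values.shape[0] == 1:
        pixel_values = pixel_values[0]
    return {
        "input_ids": torch.LongTensor(batch[0]["input_ids"]),
        "attention_mask": torch.tensor(batch[0]["attention_mask"]),
        "pixel_values": pixel_values,
    }
```

This trades a little implicitness for not having to edit the example when switching transformers versions; the commit itself takes the simpler route of matching only the latest version and leaving a comment for older ones.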
