From d64d9fbc6443ddf77b06c0dacc5ec5d6e19fbd7f Mon Sep 17 00:00:00 2001 From: Sara Adkins Date: Wed, 13 Mar 2024 15:17:21 -0400 Subject: [PATCH] [Cherry Pick] allow dataset size smaller than calibration samples (#2091) (#2179) * allow dataset size smaller than calibration samples (#2091) * merge issue --- .../transformers/finetune/data/data_helpers.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/sparseml/transformers/finetune/data/data_helpers.py b/src/sparseml/transformers/finetune/data/data_helpers.py index a92d492bb06..d4794e80220 100644 --- a/src/sparseml/transformers/finetune/data/data_helpers.py +++ b/src/sparseml/transformers/finetune/data/data_helpers.py @@ -49,9 +49,17 @@ def format_calibration_data( :param accelerator: optional accelerator for if preparing in FSDP mode :return: list of trimmed calibration data tensors """ - num_calibration_samples = num_calibration_samples or len(tokenized_dataset) + safe_calibration_samples = len(tokenized_dataset) + if num_calibration_samples is not None: + safe_calibration_samples = min(len(tokenized_dataset), num_calibration_samples) + if safe_calibration_samples != num_calibration_samples: + LOGGER.warn( + f"Requested {num_calibration_samples} calibration samples but " + f"the provided dataset only has {safe_calibration_samples}. " + ) + shuffled_calibration = tokenized_dataset.shuffle() - shuffled_calibration = shuffled_calibration.select(range(num_calibration_samples)) + shuffled_calibration = shuffled_calibration.select(range(safe_calibration_samples)) dataloader_params = { "batch_size": 1,