diff --git a/auto_fp8/modeling.py b/auto_fp8/modeling.py index eb4d2ba..0e4e8cc 100644 --- a/auto_fp8/modeling.py +++ b/auto_fp8/modeling.py @@ -79,6 +79,7 @@ def quantize(self, dataset: Optional[Dataset] = None): model=self.model, dataset=dataset, recipe=recipe, + num_calibration_samples=dataset.shape[0], ) def save_quantized(self, save_directory: str):