diff --git a/src/compressed_tensors/quantization/lifecycle/initialize.py b/src/compressed_tensors/quantization/lifecycle/initialize.py index 2a1efccb..009215e0 100644 --- a/src/compressed_tensors/quantization/lifecycle/initialize.py +++ b/src/compressed_tensors/quantization/lifecycle/initialize.py @@ -173,7 +173,10 @@ def _initialize_scale_zero_point( device = "cpu" if has_offloaded_params(module) else params_device # infer expected scale/zero point shape - expected_shape = 1 # per tensor + if quantization_args.strategy == QuantizationStrategy.TOKEN: + expected_shape = (1, 1) + else: + expected_shape = 1 if base_name == "weight" and weight_shape is not None: if quantization_args.strategy == QuantizationStrategy.CHANNEL: diff --git a/src/compressed_tensors/utils/offload.py b/src/compressed_tensors/utils/offload.py index 5bc00c05..29970bd9 100644 --- a/src/compressed_tensors/utils/offload.py +++ b/src/compressed_tensors/utils/offload.py @@ -179,8 +179,6 @@ def update_offload_data( prefix = module._hf_hook.weights_map.prefix key = f"{prefix}{name}" - breakpoint() - offload_device = ( dataset[key].device if key in dataset