From faa93c9ab1d7a8ceeb42e2c053d52c70dba8e13e Mon Sep 17 00:00:00 2001
From: Sara Adkins
Date: Tue, 16 Apr 2024 18:43:13 +0000
Subject: [PATCH] make sure scale/zp on correct device

---
 src/sparsetensors/quantization/lifecycle/forward.py    | 7 ++++---
 src/sparsetensors/quantization/lifecycle/initialize.py | 8 ++++++--
 2 files changed, 10 insertions(+), 5 deletions(-)

diff --git a/src/sparsetensors/quantization/lifecycle/forward.py b/src/sparsetensors/quantization/lifecycle/forward.py
index e917c022..6416a10b 100644
--- a/src/sparsetensors/quantization/lifecycle/forward.py
+++ b/src/sparsetensors/quantization/lifecycle/forward.py
@@ -56,7 +56,7 @@ def fake_quantize(
     zero_point: torch.Tensor,
     args: QuantizationArgs,
 ) -> torch.Tensor:
-    max_q = torch.tensor(2**args.num_bits - 1)
+    max_q = torch.tensor(2**args.num_bits - 1, device=x.device)
     Q = torch.zeros_like(x)
     Q = quantize(x, scale, zero_point, max_q)
     return dequantize(Q, scale, zero_point)
@@ -112,6 +112,7 @@ def _maybe_calibrate_or_quantize(
     }:
         return value
 
+    device = next(module.parameters()).device
     scale = getattr(module, f"{base_name}_scale")
     # zero_point = getattr(module, f"{base_name}_zero_point").data
     zero_point = getattr(module, f"{base_name}_zero_point")
@@ -122,7 +123,7 @@ def _maybe_calibrate_or_quantize(
         updated_scale, updated_zero_point = observer(value)
 
         # update scale and zero point
-        scale.data = updated_scale
-        zero_point.data = updated_zero_point
+        scale.data = updated_scale.to(device)
+        zero_point.data = updated_zero_point.to(device)
 
     return fake_quantize(value, scale, zero_point, args)
diff --git a/src/sparsetensors/quantization/lifecycle/initialize.py b/src/sparsetensors/quantization/lifecycle/initialize.py
index a87dbc3d..bfe3d268 100644
--- a/src/sparsetensors/quantization/lifecycle/initialize.py
+++ b/src/sparsetensors/quantization/lifecycle/initialize.py
@@ -59,11 +59,15 @@ def initialize_module_for_quantization(module: Module, scheme: QuantizationSchem
 def _initialize_scale_zero_point_observer(
     module: Module, base_name: str, quantization_args: QuantizationArgs
 ):
+    device = next(module.parameters()).device
+
     # initializes empty scale and zero point parameters for the module
-    init_scale = Parameter(torch.empty(0), requires_grad=False)
+    init_scale = Parameter(torch.empty(0, device=device), requires_grad=False)
     module.register_parameter(f"{base_name}_scale", init_scale)
 
-    init_zero_point = Parameter(torch.empty(0, dtype=int), requires_grad=False)
+    init_zero_point = Parameter(
+        torch.empty(0, device=device, dtype=int), requires_grad=False
+    )
     module.register_parameter(f"{base_name}_zero_point", init_zero_point)
 
     # initialize observer module and attach as submodule