make sure scale/zp on correct device (#12)
Sara Adkins authored Apr 16, 2024
1 parent e77ce53 commit 514e4db
Showing 2 changed files with 10 additions and 5 deletions.
7 changes: 4 additions & 3 deletions src/sparsetensors/quantization/lifecycle/forward.py
```diff
@@ -56,7 +56,7 @@ def fake_quantize(
     zero_point: torch.Tensor,
     args: QuantizationArgs,
 ) -> torch.Tensor:
-    max_q = torch.tensor(2**args.num_bits - 1)
+    max_q = torch.tensor(2**args.num_bits - 1, device=x.device)
     Q = torch.zeros_like(x)
     Q = quantize(x, scale, zero_point, max_q)
     return dequantize(Q, scale, zero_point)
```
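Why the `device=x.device` argument matters: `torch.tensor(...)` allocates on the CPU by default, so a CUDA input would hit a cross-device error when `max_q` is used inside the quantize call. A minimal, self-contained sketch of the failure mode and the fix; the clamp/round round-trip below is a hypothetical stand-in for the module's actual `quantize`/`dequantize` helpers:

```python
import torch

def fake_quantize_sketch(x, scale, zero_point, num_bits=8):
    # Allocate max_q on x's device; a bare torch.tensor(...) defaults to CPU,
    # and torch.clamp would raise a device-mismatch error for CUDA inputs.
    max_q = torch.tensor(2**num_bits - 1, device=x.device)
    q = torch.clamp(torch.round(x / scale + zero_point), 0, max_q)
    return (q - zero_point) * scale

x = torch.randn(4, 4)  # move to "cuda" to reproduce the original mismatch
scale = torch.tensor(0.1, device=x.device)
zero_point = torch.tensor(128, device=x.device)
print(fake_quantize_sketch(x, scale, zero_point))
```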
```diff
@@ -112,6 +112,7 @@ def _maybe_calibrate_or_quantize(
     }:
         return value
 
+    device = next(module.parameters()).device
     scale = getattr(module, f"{base_name}_scale")
     # zero_point = getattr(module, f"{base_name}_zero_point").data
     zero_point = getattr(module, f"{base_name}_zero_point")
@@ -122,7 +123,7 @@
         updated_scale, updated_zero_point = observer(value)
 
         # update scale and zero point
-        scale.data = updated_scale
-        zero_point.data = updated_zero_point
+        scale.data = updated_scale.to(device)
+        zero_point.data = updated_zero_point.to(device)
 
     return fake_quantize(value, scale, zero_point, args)
```
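A hedged sketch of the pattern this hunk fixes: an observer can return its statistics on a different device (often CPU) than the module's parameters, so the `.data` assignment moves them explicitly. The min-max values below are made up; only the `next(module.parameters()).device` idiom and the `.to(device)` moves mirror the diff:

```python
import torch
from torch.nn import Linear, Parameter

module = Linear(4, 4)  # .to("cuda") here is what exposed the original bug
device = next(module.parameters()).device

scale = Parameter(torch.empty(0, device=device), requires_grad=False)
zero_point = Parameter(torch.empty(0, device=device, dtype=int), requires_grad=False)

# Pretend an observer computed these on CPU during calibration
updated_scale = torch.tensor(0.05)
updated_zero_point = torch.tensor(127)

# Moving to the module's device keeps later tensor ops device-consistent
scale.data = updated_scale.to(device)
zero_point.data = updated_zero_point.to(device)
```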
8 changes: 6 additions & 2 deletions src/sparsetensors/quantization/lifecycle/initialize.py
```diff
@@ -78,11 +78,15 @@ def initialize_module_for_quantization(
 def _initialize_scale_zero_point_observer(
     module: Module, base_name: str, quantization_args: QuantizationArgs
 ):
+    device = next(module.parameters()).device
+
     # initializes empty scale and zero point parameters for the module
-    init_scale = Parameter(torch.empty(0), requires_grad=False)
+    init_scale = Parameter(torch.empty(0, device=device), requires_grad=False)
     module.register_parameter(f"{base_name}_scale", init_scale)
 
-    init_zero_point = Parameter(torch.empty(0, dtype=int), requires_grad=False)
+    init_zero_point = Parameter(
+        torch.empty(0, device=device, dtype=int), requires_grad=False
+    )
     module.register_parameter(f"{base_name}_zero_point", init_zero_point)
 
     # initialize observer module and attach as submodule
```
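One caveat worth noting about the `next(module.parameters()).device` idiom used in both files: it assumes the module owns at least one parameter, and raises `StopIteration` otherwise. A small sketch of the registration path; the module and parameter names here are illustrative, not taken from the repository:

```python
import torch
from torch.nn import Linear, Parameter

module = Linear(8, 8)
# Raises StopIteration for parameter-less modules; Linear always has weights.
device = next(module.parameters()).device

init_scale = Parameter(torch.empty(0, device=device), requires_grad=False)
module.register_parameter("weight_scale", init_scale)

init_zero_point = Parameter(torch.empty(0, device=device, dtype=int), requires_grad=False)
module.register_parameter("weight_zero_point", init_zero_point)

# The registered parameters now live wherever the module's weights live
assert module.weight_scale.device == module.weight.device
```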
