diff --git a/src/sparsetensors/quantization/lifecycle/forward.py b/src/sparsetensors/quantization/lifecycle/forward.py index 3624229a..69a3f250 100644 --- a/src/sparsetensors/quantization/lifecycle/forward.py +++ b/src/sparsetensors/quantization/lifecycle/forward.py @@ -24,6 +24,7 @@ __all__ = ["wrap_module_forward_quantized"] +@torch.no_grad() def quantize( x: torch.Tensor, scale: torch.Tensor, @@ -39,6 +40,7 @@ def quantize( ) +@torch.no_grad() def dequantize( x_q: torch.Tensor, scale: torch.Tensor, @@ -47,6 +49,7 @@ def dequantize( return (x_q - zero_point) * scale +@torch.no_grad() def fake_quantize( x: torch.Tensor, scale: torch.Tensor,