From 0ac6f60dd72d7638363a63fbd0f14326d52651cb Mon Sep 17 00:00:00 2001
From: Benjamin Fineran
Date: Mon, 15 Apr 2024 11:50:56 -0400
Subject: [PATCH] decorate fake quant with torch.no_grad (#8)

gradients shouldn't be computed for Q/DQ in QAT
---
 src/sparsetensors/quantization/lifecycle/forward.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/src/sparsetensors/quantization/lifecycle/forward.py b/src/sparsetensors/quantization/lifecycle/forward.py
index 3624229a..69a3f250 100644
--- a/src/sparsetensors/quantization/lifecycle/forward.py
+++ b/src/sparsetensors/quantization/lifecycle/forward.py
@@ -24,6 +24,7 @@
 __all__ = ["wrap_module_forward_quantized"]
 
 
+@torch.no_grad()
 def quantize(
     x: torch.Tensor,
     scale: torch.Tensor,
@@ -39,6 +40,7 @@ def quantize(
     )
 
 
+@torch.no_grad()
 def dequantize(
     x_q: torch.Tensor,
     scale: torch.Tensor,
@@ -47,6 +49,7 @@ def dequantize(
     return (x_q - zero_point) * scale
 
 
+@torch.no_grad()
 def fake_quantize(
     x: torch.Tensor,
     scale: torch.Tensor,
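
For context (not part of the patch): a minimal sketch of what decorating a Q/DQ round trip with `@torch.no_grad()` does. The function name `fake_quantize_sketch` and the hard-coded int8 bounds are hypothetical stand-ins for the real functions in forward.py; the point is only that operations inside the decorated call are excluded from the autograd graph.

```python
import torch


@torch.no_grad()
def fake_quantize_sketch(
    x: torch.Tensor, scale: torch.Tensor, zero_point: torch.Tensor
) -> torch.Tensor:
    # Quantize then dequantize; because of the decorator, none of these
    # ops are recorded by autograd, so no gradients flow through them.
    x_q = torch.clamp(torch.round(x / scale + zero_point), -128, 127)
    return (x_q - zero_point) * scale


x = torch.randn(4, requires_grad=True)
y = fake_quantize_sketch(x, torch.tensor(0.1), torch.tensor(0.0))
print(y.requires_grad)  # False: the Q/DQ round trip builds no autograd graph
```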