From 4d44da6fbb7e0158936e744faaf2304ccae888d9 Mon Sep 17 00:00:00 2001 From: Benjamin Fineran Date: Mon, 22 Aug 2022 17:28:12 -0400 Subject: [PATCH] set embeddings weight QAT params to correct device in DP mode (#1013) (#1014) --- .../pytorch/sparsification/quantization/helpers.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/sparseml/pytorch/sparsification/quantization/helpers.py b/src/sparseml/pytorch/sparsification/quantization/helpers.py index 80309eb2126..288c51fedba 100644 --- a/src/sparseml/pytorch/sparsification/quantization/helpers.py +++ b/src/sparseml/pytorch/sparsification/quantization/helpers.py @@ -796,9 +796,15 @@ def _prepare_qat_embedding(embedding: Module, qconfig: "torch.quantization.QConf embedding.weight_fake_quant = qconfig.weight() def _qat_forward(self, input: torch.Tensor) -> torch.Tensor: + weight = self.weight_fake_quant(self.weight) + if weight.device != input.device: + # torch DataParallel may not pick up overwritten bound method + # send weight to correct device + weight = weight.to(input.device) + return torch.nn.functional.embedding( input, - self.weight_fake_quant(self.weight), + weight, self.padding_idx, self.max_norm, self.norm_type, @@ -808,6 +814,7 @@ def _qat_forward(self, input: torch.Tensor) -> torch.Tensor: # bind qat forward to embedding qat_forward_bound = _qat_forward.__get__(embedding, embedding.__class__) + embedding.to(embedding.weight.device) # set weight_fake_quant to correct device setattr(embedding, "forward", qat_forward_bound)