Skip to content

Commit

Permalink
set embeddings weight QAT params to correct device in DP mode (#1013)
Browse files Browse the repository at this point in the history
  • Loading branch information
bfineran authored Aug 22, 2022
1 parent ea432e9 commit 4d44da6
Showing 1 changed file with 8 additions and 1 deletion.
9 changes: 8 additions & 1 deletion src/sparseml/pytorch/sparsification/quantization/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -796,9 +796,15 @@ def _prepare_qat_embedding(embedding: Module, qconfig: "torch.quantization.QConf
embedding.weight_fake_quant = qconfig.weight()

def _qat_forward(self, input: torch.Tensor) -> torch.Tensor:
weight = self.weight_fake_quant(self.weight)
if weight.device != input.device:
# torch DataParallel may not pick up overwritten bound method
# send weight to correct device
weight = weight.to(input.device)

return torch.nn.functional.embedding(
input,
self.weight_fake_quant(self.weight),
weight,
self.padding_idx,
self.max_norm,
self.norm_type,
Expand All @@ -808,6 +814,7 @@ def _qat_forward(self, input: torch.Tensor) -> torch.Tensor:

# bind qat forward to embedding
qat_forward_bound = _qat_forward.__get__(embedding, embedding.__class__)
embedding.to(embedding.weight.device) # set weight_fake_quant to correct device
setattr(embedding, "forward", qat_forward_bound)


Expand Down

0 comments on commit 4d44da6

Please sign in to comment.