diff --git a/src/transformers/models/distilbert/modeling_distilbert.py b/src/transformers/models/distilbert/modeling_distilbert.py
index 248dbfcbbbd70d..bf32341d4268c0 100755
--- a/src/transformers/models/distilbert/modeling_distilbert.py
+++ b/src/transformers/models/distilbert/modeling_distilbert.py
@@ -91,7 +91,7 @@ def _create_sinusoidal_embeddings(n_pos, dim, out):
     out.detach_()
 
 
-class QATMatMul(nn.Module):
+class QATAttentionScores(nn.Module):
     def __init__(self):
         super().__init__()
 
@@ -106,6 +106,22 @@ def __init__(self):
     def forward(self, a: torch.Tensor, b: torch.Tensor):
         return torch.matmul(a, b)
 
+class QATContextLayer(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        # behaves like normal torch.matmul unless a SparseML QuantizationModifier
+        # is initialized
+        self.wrap_qat = True
+        self.qat_wrapper_kwargs = {
+            "num_inputs": 2,
+            "num_outputs": 0,
+            "input_qconfigs": ["asymmetric", "symmetric"],
+        }
+
+    def forward(self, a: torch.Tensor, b: torch.Tensor):
+        return torch.matmul(a, b)
+
 
 class Embeddings(nn.Module):
     def __init__(self, config):
@@ -171,8 +187,8 @@ def __init__(self, config):
 
         # non-parameterized matmuls will behave as normal torch.matmul ops unless
         # Quantization-Aware-Training is invoked
-        self.attention_scores_matmul = QATMatMul()
-        self.context_layer_matmul = QATMatMul()
+        self.attention_scores_matmul = QATAttentionScores()
+        self.context_layer_matmul = QATContextLayer()
 
     def prune_heads(self, heads):
         attention_head_size = self.dim // self.n_heads
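
A minimal usage sketch, not part of the patch, showing how the two wrapper classes split the attention-scores matmul and the context-layer matmul into separately wrappable modules. The import path assumes the patched module above; the tensor shapes and scaling are illustrative assumptions for a DistilBERT-base-like setup, not the model's exact attention code. Absent a SparseML QuantizationModifier, both modules behave as plain torch.matmul:

# illustrative sketch only; shapes assume 12 heads with head dim 64
import torch

# available after this patch is applied to the module
from transformers.models.distilbert.modeling_distilbert import (
    QATAttentionScores,
    QATContextLayer,
)

attention_scores_matmul = QATAttentionScores()
context_layer_matmul = QATContextLayer()

q = torch.randn(1, 12, 128, 64)  # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 12, 128, 64)
v = torch.randn(1, 12, 128, 64)

# attention scores via the first wrapper, then a standard softmax
scores = attention_scores_matmul(q, k.transpose(2, 3))  # (1, 12, 128, 128)
weights = torch.softmax(scores / 64**0.5, dim=-1)

# context layer via the second wrapper
context = context_layer_matmul(weights, v)  # (1, 12, 128, 64)

# without quantization-aware training, both are ordinary matmuls
assert torch.equal(scores, torch.matmul(q, k.transpose(2, 3)))
assert torch.equal(context, torch.matmul(weights, v))

Keeping two distinct classes (rather than one shared QATMatMul) lets a quantization framework assign each matmul its own qconfig when the modules are wrapped.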