Removed double quantization of output of context layer. (#45)
anmarques authored May 17, 2022
1 parent 86c51a9 commit 5afbd46
Showing 1 changed file with 19 additions and 3 deletions.
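
For context (not part of the commit), the sketch below follows the structure of DistilBERT's MultiHeadSelfAttention to show where the two wrapped matmuls are used; the standalone function signature and the omission of masking, dropout, and head reshapes are simplifications made for illustration. The output of context_layer_matmul feeds the attention output projection (out_lin), which is itself quantized once QAT is applied, so an additional output observer on that matmul would quantize the same tensor twice; that appears to be the double quantization named in the commit title.

# Illustrative sketch only, not code from this commit.
import math
import torch

def attention_core(attention_scores_matmul, context_layer_matmul, q, k, v):
    # q, k, v: (batch, n_heads, seq_len, head_dim)
    q = q / math.sqrt(q.size(-1))
    # attention-score matmul, wrapped by QATAttentionScores (formerly QATMatMul)
    scores = attention_scores_matmul(q, k.transpose(2, 3))
    weights = torch.softmax(scores, dim=-1)
    # context matmul, now wrapped by QATContextLayer with "num_outputs": 0, so
    # its output no longer carries its own fake-quantization observer
    context = context_layer_matmul(weights, v)
    # in the real model, context is reshaped back to (batch, seq_len, dim) and
    # passed to the quantized output projection self.out_lin
    return context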
22 changes: 19 additions & 3 deletions src/transformers/models/distilbert/modeling_distilbert.py
@@ -91,7 +91,7 @@ def _create_sinusoidal_embeddings(n_pos, dim, out):
     out.detach_()


-class QATMatMul(nn.Module):
+class QATAttentionScores(nn.Module):
     def __init__(self):
         super().__init__()

@@ -106,6 +106,22 @@ def __init__(self):
     def forward(self, a: torch.Tensor, b: torch.Tensor):
         return torch.matmul(a, b)

+class QATContextLayer(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+        # behaves like normal torch.matmul unless a SparseML QuantizationModifier
+        # is initialized
+        self.wrap_qat = True
+        self.qat_wrapper_kwargs = {
+            "num_inputs": 2,
+            "num_outputs": 0,
+            "input_qconfigs": ["asymmetric", "symmetric"],
+        }
+
+    def forward(self, a: torch.Tensor, b: torch.Tensor):
+        return torch.matmul(a, b)
+

 class Embeddings(nn.Module):
     def __init__(self, config):
@@ -171,8 +187,8 @@ def __init__(self, config):

         # non-parameterized matmuls will behave as normal torch.matmul ops unless
         # Quantization-Aware-Training is invoked
-        self.attention_scores_matmul = QATMatMul()
-        self.context_layer_matmul = QATMatMul()
+        self.attention_scores_matmul = QATAttentionScores()
+        self.context_layer_matmul = QATContextLayer()

     def prune_heads(self, heads):
         attention_head_size = self.dim // self.n_heads
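
To illustrate the wrap_qat / qat_wrapper_kwargs contract that both classes declare, here is a hypothetical wrapper. It is not SparseML's actual QATWrapper; the class name, the mapping of "symmetric"/"asymmetric" onto concrete observer settings, and the use of torch.quantization.FakeQuantize are assumptions made for the sketch. It attaches fake quantization to the declared inputs and, because QATContextLayer sets "num_outputs" to 0, attaches nothing to the output, leaving that tensor to be observed only once by the quantized layer that consumes it.

# Hypothetical sketch, not SparseML's implementation.
import torch
from torch import nn
from torch.quantization import FakeQuantize, MovingAverageMinMaxObserver


def _fake_quant(qconfig: str) -> FakeQuantize:
    # assumed mapping: "symmetric" -> signed symmetric int8,
    # "asymmetric" -> unsigned affine uint8
    if qconfig == "symmetric":
        return FakeQuantize(
            observer=MovingAverageMinMaxObserver,
            quant_min=-128,
            quant_max=127,
            dtype=torch.qint8,
            qscheme=torch.per_tensor_symmetric,
        )
    return FakeQuantize(
        observer=MovingAverageMinMaxObserver,
        quant_min=0,
        quant_max=255,
        dtype=torch.quint8,
        qscheme=torch.per_tensor_affine,
    )


class IllustrativeQATWrapper(nn.Module):
    """Wraps a module that declares wrap_qat=True and qat_wrapper_kwargs."""

    def __init__(self, module: nn.Module):
        super().__init__()
        kwargs = module.qat_wrapper_kwargs
        self.module = module
        self.input_fake_quants = nn.ModuleList(
            [_fake_quant(cfg) for cfg in kwargs["input_qconfigs"]]
        )
        # QATContextLayer sets num_outputs to 0, so this list stays empty and
        # the matmul output is observed only by the layer that consumes it
        self.output_fake_quants = nn.ModuleList(
            [_fake_quant("asymmetric") for _ in range(kwargs.get("num_outputs", 0))]
        )

    def forward(self, a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
        a = self.input_fake_quants[0](a)
        b = self.input_fake_quants[1](b)
        out = self.module(a, b)
        for fake_quant in self.output_fake_quants:
            out = fake_quant(out)
        return out

As the in-code comments note, in SparseML the switch from plain torch.matmul behavior to wrapped, fake-quantized behavior is triggered by initializing a QuantizationModifier; until then the wrapped modules simply pass through to torch.matmul.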
