Fix DebertaV2 similarly to deberta.
lenglaender committed Jan 7, 2025
1 parent 4e42007 commit c3ab05a
Showing 2 changed files with 29 additions and 12 deletions.
10 changes: 9 additions & 1 deletion src/adapters/models/deberta/modeling_deberta.py
@@ -15,8 +15,8 @@
"""PyTorch DeBERTa model."""

import torch
from torch import nn
import torch.utils.checkpoint
from torch import nn

from transformers.models.deberta.modeling_deberta import (
DebertaOutput,
@@ -96,8 +96,10 @@ def forward(
"""
# >>> START AH Changes <<<
attention_mask = prefix_attention_mask(attention_mask, dim=3, prefix_value=1) # type: ignore
attention_mask = prefix_attention_mask(attention_mask, dim=2, prefix_value=1) # type: ignore
# >>> END AH Changes <<<

if query_states is None:
qp = self.in_proj(hidden_states) # .split(self.all_head_size, dim=-1)
@@ -110,17 +112,21 @@ def forward(
v = torch.matmul(qkvw[2], hidden_states.t().to(dtype=qkvw[2].dtype))
query_layer, key_layer, value_layer = [self.transpose_for_scores(x) for x in [q, k, v]]

# >>> START AH Changes <<<
query_layer, key_layer, value_layer = match_attn_matrices_for_parallel(query_layer, key_layer, value_layer)
(attention_mask,) = adjust_tensors_for_parallel(query_layer, attention_mask)
# >>> END AH Changes <<<

query_layer = query_layer + self.transpose_for_scores(self.q_bias[None, None, :])
value_layer = value_layer + self.transpose_for_scores(self.v_bias[None, None, :])

# >>> START AH Changes <<<
orig_key_layer = key_layer # save this for relative attention
key_layer, value_layer, attention_mask = self.prefix_tuning(
key_layer, value_layer, hidden_states, attention_mask, False
)
(query_layer, orig_key_layer) = adjust_tensors_for_parallel(key_layer, query_layer, orig_key_layer)
# >>> END AH Changes <<<

rel_att: int = 0
# Take the dot product between "query" and "key" to get the raw attention scores.
@@ -131,9 +137,11 @@ def forward(

if self.relative_attention and rel_embeddings is not None and relative_pos is not None:
rel_embeddings = self.pos_dropout(rel_embeddings)
# >>> START AH Changes <<<
rel_att = self.disentangled_att_bias(
query_layer, orig_key_layer, relative_pos, rel_embeddings, scale_factor
)
# >>> END AH Changes <<<

if rel_att is not None:
attention_scores = attention_scores + rel_att
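
A note on the prefix-mask calls in the hunks above: prefix_attention_mask(attention_mask, dim=..., prefix_value=1) extends the attention mask along the given sequence axis so that the virtual tokens which prefix tuning prepends to the keys and values count as valid positions. The following is only a rough standalone sketch of that idea, not the adapters implementation; extend_mask_for_prefix and prefix_len are illustrative names, and the real helper presumably derives the prefix length from the active forward context.

import torch

def extend_mask_for_prefix(attention_mask: torch.Tensor, prefix_len: int, dim: int) -> torch.Tensor:
    # Prepend `prefix_len` "valid" entries (value 1) along `dim`, so the mask also covers
    # the virtual prefix tokens inserted in front of the keys/values by prefix tuning.
    prefix_shape = list(attention_mask.shape)
    prefix_shape[dim] = prefix_len
    prefix = torch.ones(prefix_shape, dtype=attention_mask.dtype, device=attention_mask.device)
    return torch.cat([prefix, attention_mask], dim=dim)

# Example: a [batch, 1, query_len, key_len] mask extended along the key axis (dim=3).
mask = torch.ones(2, 1, 5, 5)
print(extend_mask_for_prefix(mask, prefix_len=3, dim=3).shape)  # torch.Size([2, 1, 5, 8])
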
31 changes: 20 additions & 11 deletions src/adapters/models/deberta_v2/modeling_deberta_v2.py
@@ -16,12 +16,13 @@

import torch
import torch.utils.checkpoint
from torch import nn

from transformers.models.deberta_v2.modeling_deberta_v2 import (
DebertaV2Output,
DebertaV2SelfOutput,
DisentangledSelfAttention,
XSoftmax,
scaled_size_sqrt,
)

from ...composition import adjust_tensors_for_parallel, match_attn_matrices_for_parallel
@@ -90,11 +91,15 @@ def forward(
The embedding of relative distances. It's a tensor of shape [\\(2 \\times
\\text{max_relative_positions}\\), *hidden_size*].
"""
# >>> START AH Changes <<<
attention_mask = prefix_attention_mask(attention_mask, dim=3, prefix_value=1) # type: ignore
attention_mask = prefix_attention_mask(attention_mask, dim=2, prefix_value=1) # type: ignore
# >>> END AH Changes <<<

if query_states is None:
query_states = hidden_states

# >>> START AH Changes <<<
query_layer = self.transpose_for_scores_extended(self.query_proj(query_states), self.num_attention_heads)
key_layer = self.transpose_for_scores_extended(self.key_proj(hidden_states), self.num_attention_heads)
value_layer = self.transpose_for_scores_extended(self.value_proj(hidden_states), self.num_attention_heads)
@@ -112,6 +117,7 @@ def forward(
key_layer = key_layer.contiguous().view(-1, key_layer.size(2), key_layer.size(-1))
value_layer = value_layer.contiguous().view(-1, value_layer.size(2), value_layer.size(-1))
orig_key_layer = orig_key_layer.contiguous().view(-1, orig_key_layer.size(2), orig_key_layer.size(-1))
# >>> END AH Changes <<<

rel_att = None
# Take the dot product between "query" and "key" to get the raw attention scores.
@@ -120,25 +126,29 @@ def forward(
scale_factor += 1
if "p2c" in self.pos_att_type:
scale_factor += 1
scale = torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2)) / scale.to(dtype=query_layer.dtype)
scale = scaled_size_sqrt(query_layer, scale_factor)
attention_scores = torch.bmm(query_layer, key_layer.transpose(-1, -2) / scale.to(dtype=query_layer.dtype))
if self.relative_attention:
rel_embeddings = self.pos_dropout(rel_embeddings)
# >>> START AH Changes <<<
rel_att = self.disentangled_attention_bias(
query_layer, orig_key_layer, relative_pos, rel_embeddings, scale_factor
)
# >>> END AH Changes <<<

if rel_att is not None:
rel_att_padded = torch.zeros_like(attention_scores)
rel_att_padded[:, :, -rel_att.size(2) :] = rel_att
attention_scores = attention_scores + rel_att_padded
attention_scores = attention_scores + rel_att
attention_scores = attention_scores
attention_scores = attention_scores.view(
-1, self.num_attention_heads, attention_scores.size(-2), attention_scores.size(-1)
)

attention_mask = attention_mask.bool()
attention_scores = attention_scores.masked_fill(~(attention_mask), torch.finfo(query_layer.dtype).min)
# bsz x height x length x dimension
attention_probs = XSoftmax.apply(attention_scores, attention_mask, -1)
attention_probs = nn.functional.softmax(attention_scores, dim=-1)
attention_probs.masked_fill(attention_mask, 0)

attention_probs = self.dropout(attention_probs)
context_layer = torch.bmm(
attention_probs.view(-1, attention_probs.size(-2), attention_probs.size(-1)), value_layer
@@ -150,7 +160,6 @@
)
new_context_layer_shape = context_layer.size()[:-2] + (-1,)
context_layer = context_layer.view(new_context_layer_shape)
if output_attentions:
return (context_layer, attention_probs)
else:
return context_layer
if not output_attentions:
return (context_layer, None)
return (context_layer, attention_probs)
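
The central deberta_v2 change follows the refactored transformers code visible in the hunks above: the XSoftmax import is replaced by scaled_size_sqrt, and the XSoftmax.apply call gives way to filling masked positions with the dtype minimum before a regular softmax. Below is a minimal standalone sketch of that masking-before-softmax pattern, not the library code; masked_softmax and the toy shapes are assumptions for illustration, with a boolean mask in which True marks positions that may be attended to.

import torch
from torch import nn

def masked_softmax(scores: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # `mask` is True where attention is allowed. Filling the disallowed positions with the
    # smallest representable value drives their softmax weight to (numerically) zero.
    mask = mask.bool()
    scores = scores.masked_fill(~mask, torch.finfo(scores.dtype).min)
    return nn.functional.softmax(scores, dim=-1)

# Toy usage: 1 batch, 2 heads, 3 queries, 4 keys, with the last key position masked out.
scores = torch.randn(1, 2, 3, 4)
mask = torch.tensor([True, True, True, False]).expand(1, 2, 3, 4)
probs = masked_softmax(scores, mask)
print(probs[..., -1].max())  # ~0: the masked key receives (almost) no attention weight
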
