Remove redundant code from T5 encoder mask creation #27216
Conversation
Thanks for fixing this!
For the code quality checks could you:
- Run `pip install -e .[quality]` to install the required black and ruff versions
- Run `make fix-copies` and `make fixup` and then push the applied changes to the branch
Done :)
Thanks again! One final comment. After that I think we're good to merge :)
if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
    encoder_seq_length = encoder_hidden_states.shape[1]
    encoder_attention_mask = torch.ones(
        batch_size, encoder_seq_length, device=inputs_embeds.device, dtype=torch.long
    )
Just one last thing - we need to add `dtype=torch.long` on L1044 to make sure the mask is created with the same dtype as before.
Hi @amyeroberts, thanks for catching this and for your comments. I worked on this very quickly yesterday (i.e., sloppy PR 🙃). I went over it again now and added typecasting where appropriate.
Almost there - just one small change. Thanks for your patience and iterating on this!
# initialize past_key_values with `None` if past does not exist
if past_key_values is None:
    past_key_values = [None] * len(self.block)

if attention_mask is None:
    attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device, dtype=torch.long)
Sorry - I wasn't clear enough in my last comment. For the creation of `attention_mask` and `encoder_attention_mask` we need both to have the exact same instantiation as before, to avoid any unintended changes to current behaviour. So for `attention_mask` we don't want to have `dtype=torch.long`.
- attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device, dtype=torch.long)
+ attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
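As an aside, a minimal sketch of the dtype difference under discussion: `torch.ones` falls back to the global default dtype (normally `torch.float32`) when no `dtype` argument is passed.

```python
import torch

# Previous behaviour: no dtype argument, so torch.ones returns the default float32.
float_mask = torch.ones(2, 5)
# With the explicit dtype, the mask would instead be int64 ("long").
long_mask = torch.ones(2, 5, dtype=torch.long)

print(float_mask.dtype)  # torch.float32
print(long_mask.dtype)   # torch.int64
```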
No problem at all :) Done!
I agree that it makes sense to preserve the previous behaviour. However, I wonder what the recommended dtype should be. The default that comes out of the tokenizer (with `return_tensors="pt"`) is float. Anyway, thanks for your guidance!
The returned dtype from the tokenizer should be `torch.int64`, i.e. long. If it's float, then I'd be interested to know the checkpoint, as that shouldn't be the case!
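For reference, a minimal sketch to check this, assuming a standard checkpoint such as `t5-small` (the checkpoint and input text are only illustrative):

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("t5-small")
batch = tokenizer("translate English to German: hello", return_tensors="pt")

# The attention mask produced by the tokenizer is expected to be int64 (long).
print(batch["attention_mask"].dtype)  # torch.int64
```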
This means of course that this logic is a bit inconsistent. However, most of the time the models are called with inputs prepared by the tokenizer, so this code path is never touched.
Unfortunately, even though setting `dtype=torch.long` is more correct, part of managing the codebase is handling backwards compatibility. I actually don't think this change would cause issues, but it's very hard to know for sure, especially for popular models like T5, and tracking numerical issues is a pain 😅
Fully understand and agree that backward compatibility is essential!
Even though feeding `input_ids`, `labels`, and `attention_mask` to the forward method of T5 is usually enough, sometimes it's essential to have more control. For example, to compute the likelihood of continuations, i.e. -logp(continuation | context), for evaluation purposes (e.g., the OpenLLM leaderboard). That's why I started looking into this code :)
The backstory is that I was trying to understand what the authors of the T-Few paper were doing. To compute logp(continuation | context) for T5, you need to be able to cache the encoder outputs (i.e., the encoded context) and pass the continuations through the model to collect the log-likelihoods. In this case, you need to pass `decoder_input_ids` yourself. tl;dr: it felt strange to me how the authors computed the `decoder_attention_mask` (see r-three/t-few#32), and to understand their code I started investigating what T5 does by default :)
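For context, here is a rough sketch of that evaluation pattern: encode the context once, then score a continuation by passing `decoder_input_ids` explicitly. The checkpoint name and texts are only illustrative, and `_shift_right` is an internal helper used here for brevity.

```python
import torch
from transformers import AutoTokenizer, T5ForConditionalGeneration

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = T5ForConditionalGeneration.from_pretrained("t5-small").eval()

context = tokenizer("question: is the sky blue? answer:", return_tensors="pt")
continuation_ids = tokenizer("yes", return_tensors="pt").input_ids

with torch.no_grad():
    # Encode the context once; the encoder output can be reused across many continuations.
    encoder_outputs = model.get_encoder()(**context)

    # Build decoder inputs by shifting the continuation right (teacher forcing).
    decoder_input_ids = model._shift_right(continuation_ids)

    logits = model(
        encoder_outputs=encoder_outputs,
        attention_mask=context.attention_mask,
        decoder_input_ids=decoder_input_ids,
    ).logits

    # Sum the log-probabilities of the continuation tokens: logp(continuation | context).
    log_probs = torch.log_softmax(logits, dim=-1)
    token_logps = log_probs.gather(-1, continuation_ids.unsqueeze(-1)).squeeze(-1)
    print(token_logps.sum())
```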
LGTM!
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint.
* remove redundant code
* update
* add typecasting
* make `attention_mask` float again
What does this PR do?
Removes redundant code in the creation of the encoder attention mask for T5 as discussed in #27211.
@amyeroberts