Add a "division by zero" check in chunked loss handling in kd_losses.py #2239
```diff
@@ -54,7 +54,9 @@ def forward(
         mask = (labels != self.ignore_index).int()
         if not normalize:
             return -torch.sum(x * mask.view(-1), dim=0)
-        if torch.sum(mask.view(-1), dim=0) == 0:
+        sum_masks = torch.sum(mask.view(-1), dim=0)
+        if sum_masks == 0:
             return torch.tensor(0.0, device=x.device)
         return -torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0)
```
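For context, here is a minimal standalone sketch (plain PyTorch, not the torchtune module itself; the tensor shapes are illustrative) of the failure mode this hunk guards against: when every label equals `ignore_index`, the mask sums to zero and the normalized loss divides zero by zero, yielding `NaN`.

```python
# Minimal sketch, assuming only the masked-mean pattern from the diff above.
import torch

ignore_index = -100
labels = torch.full((4,), ignore_index)   # every token is ignored
x = torch.randn(4)                        # stand-in for the per-token loss terms
mask = (labels != ignore_index).int()

sum_masks = torch.sum(mask.view(-1), dim=0)
unguarded = -torch.sum(x * mask.view(-1), dim=0) / sum_masks
print(unguarded)  # tensor(nan): 0.0 / 0

# The guard added in this PR short-circuits that case and returns zero instead:
loss = (
    torch.tensor(0.0, device=x.device)
    if sum_masks == 0
    else -torch.sum(x * mask.view(-1), dim=0) / sum_masks
)
print(loss)  # tensor(0.)
```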
```diff
@@ -137,4 +139,8 @@ def forward(
                 student_chunk, teacher_chunk, label_chunk, normalize=False
             )
 
+        sum_masks = torch.sum(mask.view(-1), dim=0)
+        if sum_masks == 0:
+            return torch.tensor(0.0, device=student_logits[0].device)
+
         return total_fkl_loss / torch.sum(mask.view(-1), dim=0)
```

Review thread on this hunk:

Reviewer: In the discussion on #2094, it seems like if we hit this point, there's probably something wrong with the training data. In that case, should we error out here? Or drop a warning? Or let people shoot themselves in the foot? cc @lindawangg and @ebsmothers

Author: I will wait for inputs, and will update #2094 after this PR is resolved.

Author: Any inputs?

Reviewer: Thanks for the bump here. Personally I think this is the right way to handle it. Yes, we could error/raise a warning, but honestly this behavior should be pretty obvious in a loss curve as it will clearly just drop to zero on any such iterations (plus I think zero loss would technically be the "correct" loss value if every token is being ignored). No need to overdo it on handling this edge case here imo.
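A hedged sketch of the chunked pattern this hunk protects (the per-chunk loss terms and chunk count below are toy stand-ins, not torchtune's actual `ForwardKLWithChunkedOutputLoss` internals): each chunk contributes an un-normalized sum (`normalize=False`), and the final division by the total mask count is where an all-ignored batch would otherwise produce `NaN`. Note the diff builds the zero tensor on `student_logits[0].device` so the early return stays on the same device as the chunked student logits.

```python
# Sketch under stated assumptions: mirrors the guard from the diff, not the real class.
import torch

ignore_index = -100
labels = torch.full((8,), ignore_index)                     # all tokens ignored in this batch
label_chunks = labels.chunk(2)
per_token_loss_chunks = [torch.randn(4) for _ in range(2)]  # toy per-chunk loss terms

mask = (labels != ignore_index).int()
total_loss = torch.tensor(0.0)
for loss_chunk, label_chunk in zip(per_token_loss_chunks, label_chunks):
    chunk_mask = (label_chunk != ignore_index).int()
    # un-normalized per-chunk sum, mirroring the normalize=False call in the diff
    total_loss += -torch.sum(loss_chunk * chunk_mask, dim=0)

sum_masks = torch.sum(mask.view(-1), dim=0)
if sum_masks == 0:
    loss = torch.tensor(0.0)   # guard: zero loss, consistent with the review discussion
else:
    loss = total_loss / sum_masks
print(loss)  # tensor(0.)
```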
A separate nit thread from the review:

Reviewer (nit): I would do something like "ignore_all_tokens" instead of "set_all_masks_zero" as I think it's clearer.

Author: Sounds better, and updated.
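The rename presumably refers to a test case covering the all-ignored edge case. The sketch below is only an assumption of what such a test might look like; the test name, import path, `ForwardKLLoss` constructor signature, and tensor shapes are not verified against the torchtune test suite.

```python
# Hypothetical sketch: names and signatures are assumptions based on this PR's diff.
import torch
from torchtune.modules.loss import ForwardKLLoss


def test_forward_kl_loss_ignore_all_tokens():
    loss_fn = ForwardKLLoss(ignore_index=-100)
    student_logits = torch.randn(1, 4, 32)
    teacher_logits = torch.randn(1, 4, 32)
    labels = torch.full((1, 4), -100)   # every token ignored
    loss = loss_fn(student_logits, teacher_logits, labels)
    # With the guard, the loss is exactly zero rather than NaN.
    assert loss.item() == 0.0
```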