Add a "division by zero" check in chunked loss handling in kd_losses.py #2239

Merged: 6 commits, Jan 14, 2025. Changes shown from 4 commits.
tests/torchtune/modules/loss/test_kd_losses.py (30 additions, 2 deletions)

@@ -17,7 +17,7 @@ def random():
 class TestForwardKLWithChunkedOutputLoss:
-    def test_forward_kl_loss(self):
+    def setup_forward_kl_loss(self, set_all_masks_zero: bool = False):
Contributor: nit: I would do something like "ignore_all_tokens" instead of "set_all_masks_zero", as I think it's clearer.

Contributor (Author): Sounds better; updated.

         # Create a sample input and label
         ignore_index = -100
         batch_size = 3

@@ -33,7 +33,10 @@ def test_forward_kl_loss(self):

         # add random ignore index to random tokens in the label
         random_indices = torch.randint(0, num_tokens, (batch_size, num_tokens))
-        labels[random_indices < num_tokens // 5] = ignore_index
+        if set_all_masks_zero:
+            labels[:] = ignore_index
+        else:
+            labels[random_indices < num_tokens // 5] = ignore_index

         # chunked FKL
         chunked_fkl_loss = ForwardKLWithChunkedOutputLoss(

@@ -51,6 +54,31 @@ def test_forward_kl_loss(self):
         teacher_logits = teacher_logits.reshape(-1, teacher_logits.size(-1))
         labels = labels.reshape(-1)
         standard_loss = fkl_loss(logits, teacher_logits, labels)
+        return chunked_loss, standard_loss
+
+    def test_forward_kl_loss(self):
+
+        chunked_loss, standard_loss = self.setup_forward_kl_loss(
+            set_all_masks_zero=False
+        )
+
+        # Assert
+        assert_expected(chunked_loss, standard_loss, rtol=1e-2, atol=1e-2)
+
+    def test_forward_kl_loss_zero_masks(self):
+
+        # set all masks to zero
+        chunked_loss, standard_loss = self.setup_forward_kl_loss(
+            set_all_masks_zero=True
+        )
+
+        # Assert
+        assert_expected(
+            chunked_loss,
+            torch.tensor(0.0, device=chunked_loss.device),
+            rtol=1e-2,
+            atol=1e-2,
+        )
 
         # Assert
         assert_expected(chunked_loss, standard_loss, rtol=1e-2, atol=1e-2)
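For reviewers who want to exercise the new edge case locally, the zero-mask test can be run in isolation with pytest's -k filter from a torchtune checkout (command illustrative):

```
pytest tests/torchtune/modules/loss/test_kd_losses.py -k test_forward_kl_loss_zero_masks
```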
torchtune/modules/loss/kd_losses.py (7 additions, 1 deletion)

@@ -54,7 +54,9 @@ def forward(
         mask = (labels != self.ignore_index).int()
         if not normalize:
             return -torch.sum(x * mask.view(-1), dim=0)
-        if torch.sum(mask.view(-1), dim=0) == 0:
+
+        sum_masks = torch.sum(mask.view(-1), dim=0)
+        if sum_masks == 0:
             return torch.tensor(0.0, device=x.device)
         return -torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0)
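For context, here is a minimal sketch (not part of the diff) of the failure mode this guard prevents: when every label equals ignore_index, the mask sum is zero, so the normalizing division evaluates 0/0 and produces NaN, which would then poison backpropagation.

```python
import torch

# Illustrative repro, not torchtune code: an all-ignored batch makes the
# mask count zero, so the normalizing division is 0/0 and returns NaN.
ignore_index = -100
x = torch.randn(6)                        # stand-in for per-token loss terms
labels = torch.full((6,), ignore_index)   # every token is ignored
mask = (labels != ignore_index).int()     # all zeros

loss = -torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0)
print(loss)  # tensor(nan); with the sum_masks guard this returns 0.0 instead
```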

@@ -137,4 +139,8 @@ def forward(
                 student_chunk, teacher_chunk, label_chunk, normalize=False
             )
 
+        sum_masks = torch.sum(mask.view(-1), dim=0)
+        if sum_masks == 0:
+            return torch.tensor(0.0, device=student_logits[0].device)
+
         return total_fkl_loss / torch.sum(mask.view(-1), dim=0)
Contributor (on the new early return): In the discussion on #2094, it seems like if we hit this point, there's probably something wrong with the training data. In that case, should we error out here? Or drop a warning? Or let people shoot themselves in the foot?

cc @lindawangg and @ebsmothers

Contributor (Author): I will wait for input, and will update #2094 after this PR is resolved.

Contributor (Author): Any input? @lindawangg, @ebsmothers

Contributor: Thanks for the bump here. Personally I think this is the right way to handle it. Yes, we could error or raise a warning, but honestly this behavior should be pretty obvious in a loss curve, as it will clearly just drop to zero on any such iterations (plus I think zero loss would technically be the "correct" loss value if every token is being ignored). No need to overdo it on handling this edge case here, imo.
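To make the resolution concrete, here is a hedged paraphrase of the patched control flow (not torchtune's actual class; the helper name chunked_fkl is invented for illustration): per-chunk losses are accumulated unnormalized, and the single global division is guarded so an all-ignored batch yields an exact 0.0 loss instead of NaN.

```python
import torch

def chunked_fkl(per_chunk_sums, labels, ignore_index=-100):
    # Paraphrase of the patched flow: chunks contribute unnormalized sums,
    # and normalization happens once, over the global token mask.
    total = torch.stack(per_chunk_sums).sum()
    mask = (labels != ignore_index).int()
    sum_masks = torch.sum(mask.view(-1), dim=0)
    if sum_masks == 0:
        # All tokens ignored: return zero, matching the PR's guard,
        # rather than dividing by zero.
        return torch.tensor(0.0)
    return total / sum_masks
```

On a normal batch this matches the usual sum-over-kept-tokens average; on an all-ignored batch it returns 0.0, which is why such iterations show up as a visible drop to zero in the loss curve, as the reviewer notes.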