diff --git a/tests/torchtune/modules/loss/test_kd_losses.py b/tests/torchtune/modules/loss/test_kd_losses.py index 6903f696e8..d33409fb44 100644 --- a/tests/torchtune/modules/loss/test_kd_losses.py +++ b/tests/torchtune/modules/loss/test_kd_losses.py @@ -17,7 +17,7 @@ def random(): class TestForwardKLWithChunkedOutputLoss: - def test_forward_kl_loss(self): + def setup_forward_kl_loss(self, ignore_all_tokens: bool = False): # Create a sample input and label ignore_index = -100 batch_size = 3 @@ -33,7 +33,10 @@ def test_forward_kl_loss(self): # add random ignore index to random tokens in the label random_indices = torch.randint(0, num_tokens, (batch_size, num_tokens)) - labels[random_indices < num_tokens // 5] = ignore_index + if ignore_all_tokens: + labels[:] = ignore_index + else: + labels[random_indices < num_tokens // 5] = ignore_index # chunked FKL chunked_fkl_loss = ForwardKLWithChunkedOutputLoss( @@ -51,6 +54,29 @@ def test_forward_kl_loss(self): teacher_logits = teacher_logits.reshape(-1, teacher_logits.size(-1)) labels = labels.reshape(-1) standard_loss = fkl_loss(logits, teacher_logits, labels) + return chunked_loss, standard_loss + + def test_forward_kl_loss(self): + + chunked_loss, standard_loss = self.setup_forward_kl_loss( + ignore_all_tokens=False + ) + + # Assert + assert_expected(chunked_loss, standard_loss, rtol=1e-2, atol=1e-2) + + def test_forward_kl_loss_zero_masks(self): + + # set all masks to zero + chunked_loss, standard_loss = self.setup_forward_kl_loss(ignore_all_tokens=True) + + # Assert + assert_expected( + chunked_loss, + torch.tensor(0.0, device=chunked_loss.device), + rtol=1e-2, + atol=1e-2, + ) # Assert assert_expected(chunked_loss, standard_loss, rtol=1e-2, atol=1e-2) diff --git a/torchtune/modules/loss/kd_losses.py b/torchtune/modules/loss/kd_losses.py index 53875dc813..820fa9156b 100644 --- a/torchtune/modules/loss/kd_losses.py +++ b/torchtune/modules/loss/kd_losses.py @@ -54,7 +54,9 @@ def forward( mask = (labels != self.ignore_index).int() if not normalize: return -torch.sum(x * mask.view(-1), dim=0) - if torch.sum(mask.view(-1), dim=0) == 0: + + sum_masks = torch.sum(mask.view(-1), dim=0) + if sum_masks == 0: return torch.tensor(0.0, device=x.device) return -torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0) @@ -137,4 +139,8 @@ def forward( student_chunk, teacher_chunk, label_chunk, normalize=False ) + sum_masks = torch.sum(mask.view(-1), dim=0) + if sum_masks == 0: + return torch.tensor(0.0, device=student_logits[0].device) + return total_fkl_loss / torch.sum(mask.view(-1), dim=0)