From e01c7e08df8097125924f2fc045f3910f12b2483 Mon Sep 17 00:00:00 2001
From: Insop Song
Date: Tue, 7 Jan 2025 22:39:43 -0800
Subject: [PATCH 1/4] add a "division by zero" check in chunked loss handling in kd_losses.py #2225

---
 torchtune/modules/loss/kd_losses.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/torchtune/modules/loss/kd_losses.py b/torchtune/modules/loss/kd_losses.py
index 53875dc813..74d751a308 100644
--- a/torchtune/modules/loss/kd_losses.py
+++ b/torchtune/modules/loss/kd_losses.py
@@ -137,4 +137,7 @@ def forward(
                 student_chunk, teacher_chunk, label_chunk, normalize=False
             )
 
+        if torch.sum(mask.view(-1), dim=0) == 0:
+            return torch.tensor(0.0, device=student_logits[0].device)
+
         return total_fkl_loss / torch.sum(mask.view(-1), dim=0)

From 383ec2e694818bc9426bc6809022d613e422e3d4 Mon Sep 17 00:00:00 2001
From: Insop Song
Date: Wed, 8 Jan 2025 19:49:31 -0800
Subject: [PATCH 2/4] Add test for added code

- test kd_loss when all mask is zero
---
 .../torchtune/modules/loss/test_kd_losses.py | 32 +++++++++++++++++--
 1 file changed, 30 insertions(+), 2 deletions(-)

diff --git a/tests/torchtune/modules/loss/test_kd_losses.py b/tests/torchtune/modules/loss/test_kd_losses.py
index 6903f696e8..330a8564dc 100644
--- a/tests/torchtune/modules/loss/test_kd_losses.py
+++ b/tests/torchtune/modules/loss/test_kd_losses.py
@@ -17,7 +17,7 @@ def random():
 
 
 class TestForwardKLWithChunkedOutputLoss:
-    def test_forward_kl_loss(self):
+    def setup_forward_kl_loss(self, set_all_masks_zero: bool = False):
         # Create a sample input and label
         ignore_index = -100
         batch_size = 3
@@ -33,7 +33,10 @@ def setup_forward_kl_loss(self, set_all_masks_zero: bool = False):
 
         # add random ignore index to random tokens in the label
         random_indices = torch.randint(0, num_tokens, (batch_size, num_tokens))
-        labels[random_indices < num_tokens // 5] = ignore_index
+        if set_all_masks_zero:
+            labels[:] = ignore_index
+        else:
+            labels[random_indices < num_tokens // 5] = ignore_index
 
         # chunked FKL
         chunked_fkl_loss = ForwardKLWithChunkedOutputLoss(
@@ -51,6 +54,31 @@ def setup_forward_kl_loss(self, set_all_masks_zero: bool = False):
         teacher_logits = teacher_logits.reshape(-1, teacher_logits.size(-1))
         labels = labels.reshape(-1)
         standard_loss = fkl_loss(logits, teacher_logits, labels)
+        return chunked_loss, standard_loss
+
+    def test_forward_kl_loss(self):
+
+        chunked_loss, standard_loss = self.setup_forward_kl_loss(
+            set_all_masks_zero=False
+        )
+
+        # Assert
+        assert_expected(chunked_loss, standard_loss, rtol=1e-2, atol=1e-2)
+
+    def test_forward_kl_loss_zero_masks(self):
+
+        # set all masks to zero
+        chunked_loss, standard_loss = self.setup_forward_kl_loss(
+            set_all_masks_zero=True
+        )
+
+        # Assert
+        assert_expected(
+            chunked_loss,
+            torch.tensor(0.0, device=chunked_loss.device),
+            rtol=1e-2,
+            atol=1e-2,
+        )
 
         # Assert
         assert_expected(chunked_loss, standard_loss, rtol=1e-2, atol=1e-2)

From 3760196e6a422a9c531b9763cf441a9d93ee4f3c Mon Sep 17 00:00:00 2001
From: Insop Song
Date: Thu, 9 Jan 2025 16:37:12 -0800
Subject: [PATCH 3/4] Pull out to a variable, review feedback

---
 torchtune/modules/loss/kd_losses.py | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/torchtune/modules/loss/kd_losses.py b/torchtune/modules/loss/kd_losses.py
index 74d751a308..820fa9156b 100644
--- a/torchtune/modules/loss/kd_losses.py
+++ b/torchtune/modules/loss/kd_losses.py
@@ -54,7 +54,9 @@ def forward(
         mask = (labels != self.ignore_index).int()
         if not normalize:
             return -torch.sum(x * mask.view(-1), dim=0)
-        if torch.sum(mask.view(-1), dim=0) == 0:
+
+        sum_masks = torch.sum(mask.view(-1), dim=0)
+        if sum_masks == 0:
             return torch.tensor(0.0, device=x.device)
         return -torch.sum(x * mask.view(-1), dim=0) / torch.sum(mask.view(-1), dim=0)
 
@@ -137,7 +139,8 @@ def forward(
                 student_chunk, teacher_chunk, label_chunk, normalize=False
             )
 
-        if torch.sum(mask.view(-1), dim=0) == 0:
+        sum_masks = torch.sum(mask.view(-1), dim=0)
+        if sum_masks == 0:
             return torch.tensor(0.0, device=student_logits[0].device)
 
         return total_fkl_loss / torch.sum(mask.view(-1), dim=0)

From dc634282e9b81d2eb8682ccf5add8741f23b96ad Mon Sep 17 00:00:00 2001
From: Insop Song
Date: Fri, 10 Jan 2025 17:02:18 -0800
Subject: [PATCH 4/4] review feedback

---
 tests/torchtune/modules/loss/test_kd_losses.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/tests/torchtune/modules/loss/test_kd_losses.py b/tests/torchtune/modules/loss/test_kd_losses.py
index 330a8564dc..d33409fb44 100644
--- a/tests/torchtune/modules/loss/test_kd_losses.py
+++ b/tests/torchtune/modules/loss/test_kd_losses.py
@@ -17,7 +17,7 @@ def random():
 
 
 class TestForwardKLWithChunkedOutputLoss:
-    def setup_forward_kl_loss(self, set_all_masks_zero: bool = False):
+    def setup_forward_kl_loss(self, ignore_all_tokens: bool = False):
         # Create a sample input and label
         ignore_index = -100
         batch_size = 3
@@ -33,7 +33,7 @@ def setup_forward_kl_loss(self, ignore_all_tokens: bool = False):
 
         # add random ignore index to random tokens in the label
         random_indices = torch.randint(0, num_tokens, (batch_size, num_tokens))
-        if set_all_masks_zero:
+        if ignore_all_tokens:
             labels[:] = ignore_index
         else:
             labels[random_indices < num_tokens // 5] = ignore_index
@@ -59,7 +59,7 @@ def setup_forward_kl_loss(self, ignore_all_tokens: bool = False):
     def test_forward_kl_loss(self):
 
         chunked_loss, standard_loss = self.setup_forward_kl_loss(
-            set_all_masks_zero=False
+            ignore_all_tokens=False
        )
 
         # Assert
@@ -68,9 +68,7 @@ def test_forward_kl_loss(self):
     def test_forward_kl_loss_zero_masks(self):
 
         # set all masks to zero
-        chunked_loss, standard_loss = self.setup_forward_kl_loss(
-            set_all_masks_zero=True
-        )
+        chunked_loss, standard_loss = self.setup_forward_kl_loss(ignore_all_tokens=True)
 
         # Assert
         assert_expected(
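
The snippet below is a minimal sketch of the edge case this series guards against: when every label equals ignore_index, the normalization mask sums to zero, so the unpatched code divided by zero and produced NaN, whereas the patched code returns an exact zero loss. It assumes torchtune is installed, that ForwardKLWithChunkedOutputLoss is importable from torchtune.modules.loss and constructed with num_output_chunks and ignore_index, and that it is called with per-chunk logit tensors plus a full label tensor, as the diffs above suggest; treat the names and shapes here as illustrative rather than authoritative.

import torch
from torchtune.modules.loss import ForwardKLWithChunkedOutputLoss

ignore_index = -100
batch_size, num_tokens, vocab_size, num_chunks = 2, 32, 128, 4

# Constructor arguments mirror the usage shown in the test diff above; they are
# assumptions if your version of torchtune differs.
loss_fn = ForwardKLWithChunkedOutputLoss(
    num_output_chunks=num_chunks, ignore_index=ignore_index
)

student_logits = torch.randn(batch_size, num_tokens, vocab_size)
teacher_logits = torch.randn(batch_size, num_tokens, vocab_size)

# Every token is ignored, so the mask used for normalization sums to zero.
labels = torch.full((batch_size, num_tokens), ignore_index, dtype=torch.long)

# Logits are split into per-chunk tensors along the token dimension.
loss = loss_fn(
    list(student_logits.chunk(num_chunks, dim=1)),
    list(teacher_logits.chunk(num_chunks, dim=1)),
    labels,
)
print(loss)  # tensor(0.) with this series applied; NaN before the guard existed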