From d5576e22aa8513f7426a8b39555683b9213ddc5b Mon Sep 17 00:00:00 2001 From: Insop Song Date: Mon, 13 Jan 2025 11:18:14 -0800 Subject: [PATCH] lint fix --- tests/torchtune/modules/loss/test_kd_losses.py | 13 +++++++++++-- torchtune/modules/loss/__init__.py | 9 ++++++++- 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/tests/torchtune/modules/loss/test_kd_losses.py b/tests/torchtune/modules/loss/test_kd_losses.py index bb52738e0e..ddfdd4012c 100644 --- a/tests/torchtune/modules/loss/test_kd_losses.py +++ b/tests/torchtune/modules/loss/test_kd_losses.py @@ -7,7 +7,14 @@ import pytest import torch from tests.test_utils import assert_expected -from torchtune.modules.loss import ForwardKLLoss, ForwardKLWithChunkedOutputLoss, ReverseKLLoss, ReverseKLWithChunkedOutputLoss, SymmetricKLLoss, SymmetricKLWithChunkedOutputLoss +from torchtune.modules.loss import ( + ForwardKLLoss, + ForwardKLWithChunkedOutputLoss, + ReverseKLLoss, + ReverseKLWithChunkedOutputLoss, + SymmetricKLLoss, + SymmetricKLWithChunkedOutputLoss, +) from torchtune.training.seed import set_seed @@ -115,6 +122,7 @@ def test_forward_kl_loss_expected(self): assert_expected(chunked_loss, expected_loss, rtol=1e-2, atol=1e-2) assert_expected(standard_loss, expected_loss, rtol=1e-2, atol=1e-2) + class TestReverseKLWithChunkedOutputLoss: def test_reverse_kl_loss(self): # Create a sample input and label @@ -214,6 +222,7 @@ def test_reverse_kl_loss_expected(self): assert_expected(chunked_loss, expected_loss, rtol=1e-2, atol=1e-2) assert_expected(standard_loss, expected_loss, rtol=1e-2, atol=1e-2) + class TestSymmetricKLWithChunkedOutputLoss: def test_symmetric_kl_loss(self): # Create a sample input and label @@ -311,4 +320,4 @@ def test_symmetric_kl_loss_expected(self): # assert assert_expected(chunked_loss, expected_loss, rtol=1e-2, atol=1e-2) - assert_expected(standard_loss, expected_loss, rtol=1e-2, atol=1e-2) \ No newline at end of file + assert_expected(standard_loss, expected_loss, rtol=1e-2, atol=1e-2) diff --git a/torchtune/modules/loss/__init__.py b/torchtune/modules/loss/__init__.py index 4baf3849eb..194c95ae95 100644 --- a/torchtune/modules/loss/__init__.py +++ b/torchtune/modules/loss/__init__.py @@ -5,7 +5,14 @@ # LICENSE file in the root directory of this source tree. from .ce_chunked_output_loss import CEWithChunkedOutputLoss -from .kd_losses import ForwardKLLoss, ForwardKLWithChunkedOutputLoss, ReverseKLLoss, ReverseKLWithChunkedOutputLoss, SymmetricKLLoss, SymmetricKLWithChunkedOutputLoss +from .kd_losses import ( + ForwardKLLoss, + ForwardKLWithChunkedOutputLoss, + ReverseKLLoss, + ReverseKLWithChunkedOutputLoss, + SymmetricKLLoss, + SymmetricKLWithChunkedOutputLoss, +) __all__ = [ "CEWithChunkedOutputLoss",