Skip to content

Commit

Permalink
lint fix
Browse files Browse the repository at this point in the history
  • Loading branch information
insop committed Jan 13, 2025
1 parent a4c818c commit d5576e2
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 3 deletions.
13 changes: 11 additions & 2 deletions tests/torchtune/modules/loss/test_kd_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,14 @@
import pytest
import torch
from tests.test_utils import assert_expected
from torchtune.modules.loss import ForwardKLLoss, ForwardKLWithChunkedOutputLoss, ReverseKLLoss, ReverseKLWithChunkedOutputLoss, SymmetricKLLoss, SymmetricKLWithChunkedOutputLoss
from torchtune.modules.loss import (
ForwardKLLoss,
ForwardKLWithChunkedOutputLoss,
ReverseKLLoss,
ReverseKLWithChunkedOutputLoss,
SymmetricKLLoss,
SymmetricKLWithChunkedOutputLoss,
)
from torchtune.training.seed import set_seed


Expand Down Expand Up @@ -115,6 +122,7 @@ def test_forward_kl_loss_expected(self):
assert_expected(chunked_loss, expected_loss, rtol=1e-2, atol=1e-2)
assert_expected(standard_loss, expected_loss, rtol=1e-2, atol=1e-2)


class TestReverseKLWithChunkedOutputLoss:
def test_reverse_kl_loss(self):
# Create a sample input and label
Expand Down Expand Up @@ -214,6 +222,7 @@ def test_reverse_kl_loss_expected(self):
assert_expected(chunked_loss, expected_loss, rtol=1e-2, atol=1e-2)
assert_expected(standard_loss, expected_loss, rtol=1e-2, atol=1e-2)


class TestSymmetricKLWithChunkedOutputLoss:
def test_symmetric_kl_loss(self):
# Create a sample input and label
Expand Down Expand Up @@ -311,4 +320,4 @@ def test_symmetric_kl_loss_expected(self):

# assert
assert_expected(chunked_loss, expected_loss, rtol=1e-2, atol=1e-2)
assert_expected(standard_loss, expected_loss, rtol=1e-2, atol=1e-2)
assert_expected(standard_loss, expected_loss, rtol=1e-2, atol=1e-2)
9 changes: 8 additions & 1 deletion torchtune/modules/loss/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,14 @@
# LICENSE file in the root directory of this source tree.

from .ce_chunked_output_loss import CEWithChunkedOutputLoss
from .kd_losses import ForwardKLLoss, ForwardKLWithChunkedOutputLoss, ReverseKLLoss, ReverseKLWithChunkedOutputLoss, SymmetricKLLoss, SymmetricKLWithChunkedOutputLoss
from .kd_losses import (
ForwardKLLoss,
ForwardKLWithChunkedOutputLoss,
ReverseKLLoss,
ReverseKLWithChunkedOutputLoss,
SymmetricKLLoss,
SymmetricKLWithChunkedOutputLoss,
)

__all__ = [
"CEWithChunkedOutputLoss",
Expand Down

0 comments on commit d5576e2

Please sign in to comment.