Skip to content

Commit d5576e2

Browse files
committed
lint fix
1 parent a4c818c commit d5576e2

File tree

2 files changed

+19
-3
lines changed

2 files changed

+19
-3
lines changed

tests/torchtune/modules/loss/test_kd_losses.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,14 @@
77
import pytest
88
import torch
99
from tests.test_utils import assert_expected
10-
from torchtune.modules.loss import ForwardKLLoss, ForwardKLWithChunkedOutputLoss, ReverseKLLoss, ReverseKLWithChunkedOutputLoss, SymmetricKLLoss, SymmetricKLWithChunkedOutputLoss
10+
from torchtune.modules.loss import (
11+
ForwardKLLoss,
12+
ForwardKLWithChunkedOutputLoss,
13+
ReverseKLLoss,
14+
ReverseKLWithChunkedOutputLoss,
15+
SymmetricKLLoss,
16+
SymmetricKLWithChunkedOutputLoss,
17+
)
1118
from torchtune.training.seed import set_seed
1219

1320

@@ -115,6 +122,7 @@ def test_forward_kl_loss_expected(self):
115122
assert_expected(chunked_loss, expected_loss, rtol=1e-2, atol=1e-2)
116123
assert_expected(standard_loss, expected_loss, rtol=1e-2, atol=1e-2)
117124

125+
118126
class TestReverseKLWithChunkedOutputLoss:
119127
def test_reverse_kl_loss(self):
120128
# Create a sample input and label
@@ -214,6 +222,7 @@ def test_reverse_kl_loss_expected(self):
214222
assert_expected(chunked_loss, expected_loss, rtol=1e-2, atol=1e-2)
215223
assert_expected(standard_loss, expected_loss, rtol=1e-2, atol=1e-2)
216224

225+
217226
class TestSymmetricKLWithChunkedOutputLoss:
218227
def test_symmetric_kl_loss(self):
219228
# Create a sample input and label
@@ -311,4 +320,4 @@ def test_symmetric_kl_loss_expected(self):
311320

312321
# assert
313322
assert_expected(chunked_loss, expected_loss, rtol=1e-2, atol=1e-2)
314-
assert_expected(standard_loss, expected_loss, rtol=1e-2, atol=1e-2)
323+
assert_expected(standard_loss, expected_loss, rtol=1e-2, atol=1e-2)

torchtune/modules/loss/__init__.py

+8-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,14 @@
55
# LICENSE file in the root directory of this source tree.
66

77
from .ce_chunked_output_loss import CEWithChunkedOutputLoss
8-
from .kd_losses import ForwardKLLoss, ForwardKLWithChunkedOutputLoss, ReverseKLLoss, ReverseKLWithChunkedOutputLoss, SymmetricKLLoss, SymmetricKLWithChunkedOutputLoss
8+
from .kd_losses import (
9+
ForwardKLLoss,
10+
ForwardKLWithChunkedOutputLoss,
11+
ReverseKLLoss,
12+
ReverseKLWithChunkedOutputLoss,
13+
SymmetricKLLoss,
14+
SymmetricKLWithChunkedOutputLoss,
15+
)
916

1017
__all__ = [
1118
"CEWithChunkedOutputLoss",

0 commit comments

Comments
 (0)