7
7
import pytest
8
8
import torch
9
9
from tests .test_utils import assert_expected
10
- from torchtune .modules .loss import ForwardKLLoss , ForwardKLWithChunkedOutputLoss , ReverseKLLoss , ReverseKLWithChunkedOutputLoss , SymmetricKLLoss , SymmetricKLWithChunkedOutputLoss
10
+ from torchtune .modules .loss import (
11
+ ForwardKLLoss ,
12
+ ForwardKLWithChunkedOutputLoss ,
13
+ ReverseKLLoss ,
14
+ ReverseKLWithChunkedOutputLoss ,
15
+ SymmetricKLLoss ,
16
+ SymmetricKLWithChunkedOutputLoss ,
17
+ )
11
18
from torchtune .training .seed import set_seed
12
19
13
20
@@ -115,6 +122,7 @@ def test_forward_kl_loss_expected(self):
115
122
assert_expected (chunked_loss , expected_loss , rtol = 1e-2 , atol = 1e-2 )
116
123
assert_expected (standard_loss , expected_loss , rtol = 1e-2 , atol = 1e-2 )
117
124
125
+
118
126
class TestReverseKLWithChunkedOutputLoss :
119
127
def test_reverse_kl_loss (self ):
120
128
# Create a sample input and label
@@ -214,6 +222,7 @@ def test_reverse_kl_loss_expected(self):
214
222
assert_expected (chunked_loss , expected_loss , rtol = 1e-2 , atol = 1e-2 )
215
223
assert_expected (standard_loss , expected_loss , rtol = 1e-2 , atol = 1e-2 )
216
224
225
+
217
226
class TestSymmetricKLWithChunkedOutputLoss :
218
227
def test_symmetric_kl_loss (self ):
219
228
# Create a sample input and label
@@ -311,4 +320,4 @@ def test_symmetric_kl_loss_expected(self):
311
320
312
321
# assert
313
322
assert_expected (chunked_loss , expected_loss , rtol = 1e-2 , atol = 1e-2 )
314
- assert_expected (standard_loss , expected_loss , rtol = 1e-2 , atol = 1e-2 )
323
+ assert_expected (standard_loss , expected_loss , rtol = 1e-2 , atol = 1e-2 )
0 commit comments