@@ -54,16 +54,12 @@ def test_liger_dyt_correctness(B, T, hidden_size, init_alpha, dtype, atol, rtol)
5454 gamma = torch .randn (hidden_size , device = device , dtype = dtype )
5555 beta = torch .randn (hidden_size , device = device , dtype = dtype )
5656
57- torch_dyt = (
58- TorchDyT (hidden_size = hidden_size , init_alpha = init_alpha ).to (device ).to (dtype )
59- )
57+ torch_dyt = TorchDyT (hidden_size = hidden_size , init_alpha = init_alpha ).to (device ).to (dtype )
6058 torch_dyt .alpha .data = alpha .clone ()
6159 torch_dyt .gamma .data = gamma .clone ()
6260 torch_dyt .beta .data = beta .clone ()
6361
64- liger_dyt = (
65- LigerDyT (hidden_size = hidden_size , init_alpha = init_alpha ).to (device ).to (dtype )
66- )
62+ liger_dyt = LigerDyT (hidden_size = hidden_size , init_alpha = init_alpha ).to (device ).to (dtype )
6763 liger_dyt .alpha .data = alpha .clone ()
6864 liger_dyt .gamma .data = gamma .clone ()
6965 liger_dyt .beta .data = beta .clone ()
@@ -78,15 +74,9 @@ def test_liger_dyt_correctness(B, T, hidden_size, init_alpha, dtype, atol, rtol)
7874 liger_output .backward (grad_output )
7975
8076 assert_verbose_allclose (x1 .grad , x2 .grad , rtol = rtol , atol = atol )
81- assert_verbose_allclose (
82- torch_dyt .alpha .grad , liger_dyt .alpha .grad , rtol = rtol , atol = atol
83- )
84- assert_verbose_allclose (
85- torch_dyt .gamma .grad , liger_dyt .gamma .grad , rtol = rtol , atol = atol
86- )
87- assert_verbose_allclose (
88- torch_dyt .beta .grad , liger_dyt .beta .grad , rtol = rtol , atol = atol
89- )
77+ assert_verbose_allclose (torch_dyt .alpha .grad , liger_dyt .alpha .grad , rtol = rtol , atol = atol )
78+ assert_verbose_allclose (torch_dyt .gamma .grad , liger_dyt .gamma .grad , rtol = rtol , atol = atol )
79+ assert_verbose_allclose (torch_dyt .beta .grad , liger_dyt .beta .grad , rtol = rtol , atol = atol )
9080
9181
9282@pytest .mark .parametrize (
@@ -108,9 +98,7 @@ def test_liger_dyt_correctness(B, T, hidden_size, init_alpha, dtype, atol, rtol)
10898 torch .bfloat16 ,
10999 1e-8 ,
110100 5e-2 ,
111- marks = pytest .mark .skipif (
112- not supports_bfloat16 (), reason = "bfloat16 not supported on this GPU"
113- ),
101+ marks = pytest .mark .skipif (not supports_bfloat16 (), reason = "bfloat16 not supported on this GPU" ),
114102 ),
115103 ],
116104)
0 commit comments