Commit d18efa3

Make checkstyle

1 parent: 2a030bf

2 files changed: 8 additions & 24 deletions

benchmark/scripts/benchmark_dyt.py (2 additions & 6 deletions)

@@ -47,9 +47,7 @@ def fwd():
         return torch_compile_dyt(x)
 
     if mode == "forward":
-        ms_50, ms_20, ms_80 = triton.testing.do_bench(
-            fwd, quantiles=QUANTILES, grad_to_none=[x], rep=500
-        )
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(fwd, quantiles=QUANTILES, grad_to_none=[x], rep=500)
     elif mode == "backward":
         y = fwd()
         ms_50, ms_20, ms_80 = triton.testing.do_bench(
@@ -64,9 +62,7 @@ def full():
             y = fwd()
             y.backward(dy)
 
-        ms_50, ms_20, ms_80 = triton.testing.do_bench(
-            full, quantiles=QUANTILES, grad_to_none=[x], rep=500
-        )
+        ms_50, ms_20, ms_80 = triton.testing.do_bench(full, quantiles=QUANTILES, grad_to_none=[x], rep=500)
 
     return SingleBenchmarkRunOutput(
         y_20=ms_20,
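
Both hunks collapse a multi-line triton.testing.do_bench call onto one line; the style checker permits the longer line, so the wrapping was unnecessary. For reference, a minimal sketch of the benchmarking pattern, assuming a CUDA device and that QUANTILES is [0.5, 0.2, 0.8] as in the benchmark utilities (dyt_fn is a stand-in for the op under test, not the repo's kernel):

    import torch
    import triton

    QUANTILES = [0.5, 0.2, 0.8]  # assumed: median, 20th, 80th percentile
    x = torch.randn(4, 2048, device="cuda", requires_grad=True)

    def dyt_fn():
        # Stand-in for the benchmarked DyT forward pass.
        return torch.tanh(0.5 * x)

    # quantiles returns one runtime (ms) per requested quantile, in order;
    # grad_to_none resets x.grad between reps so backward-mode timings do
    # not accumulate gradients across repetitions.
    ms_50, ms_20, ms_80 = triton.testing.do_bench(dyt_fn, quantiles=QUANTILES, grad_to_none=[x], rep=500)
    print(ms_50, ms_20, ms_80)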

test/transformers/test_dyt.py (6 additions & 18 deletions)

@@ -54,16 +54,12 @@ def test_liger_dyt_correctness(B, T, hidden_size, init_alpha, dtype, atol, rtol)
     gamma = torch.randn(hidden_size, device=device, dtype=dtype)
     beta = torch.randn(hidden_size, device=device, dtype=dtype)
 
-    torch_dyt = (
-        TorchDyT(hidden_size=hidden_size, init_alpha=init_alpha).to(device).to(dtype)
-    )
+    torch_dyt = TorchDyT(hidden_size=hidden_size, init_alpha=init_alpha).to(device).to(dtype)
     torch_dyt.alpha.data = alpha.clone()
     torch_dyt.gamma.data = gamma.clone()
     torch_dyt.beta.data = beta.clone()
 
-    liger_dyt = (
-        LigerDyT(hidden_size=hidden_size, init_alpha=init_alpha).to(device).to(dtype)
-    )
+    liger_dyt = LigerDyT(hidden_size=hidden_size, init_alpha=init_alpha).to(device).to(dtype)
     liger_dyt.alpha.data = alpha.clone()
     liger_dyt.gamma.data = gamma.clone()
     liger_dyt.beta.data = beta.clone()
@@ -78,15 +74,9 @@ def test_liger_dyt_correctness(B, T, hidden_size, init_alpha, dtype, atol, rtol)
     liger_output.backward(grad_output)
 
     assert_verbose_allclose(x1.grad, x2.grad, rtol=rtol, atol=atol)
-    assert_verbose_allclose(
-        torch_dyt.alpha.grad, liger_dyt.alpha.grad, rtol=rtol, atol=atol
-    )
-    assert_verbose_allclose(
-        torch_dyt.gamma.grad, liger_dyt.gamma.grad, rtol=rtol, atol=atol
-    )
-    assert_verbose_allclose(
-        torch_dyt.beta.grad, liger_dyt.beta.grad, rtol=rtol, atol=atol
-    )
+    assert_verbose_allclose(torch_dyt.alpha.grad, liger_dyt.alpha.grad, rtol=rtol, atol=atol)
+    assert_verbose_allclose(torch_dyt.gamma.grad, liger_dyt.gamma.grad, rtol=rtol, atol=atol)
+    assert_verbose_allclose(torch_dyt.beta.grad, liger_dyt.beta.grad, rtol=rtol, atol=atol)
 
 
 @pytest.mark.parametrize(
@@ -108,9 +98,7 @@ def test_liger_dyt_correctness(B, T, hidden_size, init_alpha, dtype, atol, rtol)
             torch.bfloat16,
             1e-8,
             5e-2,
-            marks=pytest.mark.skipif(
-                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
-            ),
+            marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
         ),
     ],
 )
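
The test seeds both modules with the same alpha, gamma, and beta tensors, runs forward and backward, and compares the input gradients (x1.grad vs x2.grad) plus all three parameter gradients. As a hedged sketch of what is being compared: DyT ("Dynamic Tanh") replaces a normalization layer with the elementwise map gamma * tanh(alpha * x) + beta. RefDyT below is a hypothetical stand-in written from that formula, not the repo's TorchDyT or LigerDyT:

    import torch
    import torch.nn as nn

    class RefDyT(nn.Module):
        """Dynamic Tanh: y = gamma * tanh(alpha * x) + beta."""

        def __init__(self, hidden_size: int, init_alpha: float = 0.5):
            super().__init__()
            self.alpha = nn.Parameter(torch.full((1,), init_alpha))  # learnable scalar scale
            self.gamma = nn.Parameter(torch.ones(hidden_size))       # per-channel gain
            self.beta = nn.Parameter(torch.zeros(hidden_size))       # per-channel shift

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            return self.gamma * torch.tanh(self.alpha * x) + self.beta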

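
The last hunk keeps the bfloat16 case guarded by a skipif mark. A self-contained sketch of that pytest.param pattern, with supports_bfloat16 approximated here by a compute-capability check (the repo has its own helper, and the float32 tolerances are illustrative):

    import pytest
    import torch

    def supports_bfloat16() -> bool:
        # Assumed check: Ampere (compute capability 8.x) and newer support bf16.
        return torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8

    @pytest.mark.parametrize(
        "dtype, atol, rtol",
        [
            (torch.float32, 1e-5, 1e-5),  # illustrative tolerances
            pytest.param(
                torch.bfloat16,
                1e-8,
                5e-2,
                marks=pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
            ),
        ],
    )
    def test_dtype_is_usable(dtype, atol, rtol):
        # The parametrized case runs only when the guard passes.
        x = torch.randn(8, dtype=dtype)
        assert torch.allclose(x, x.clone(), rtol=rtol, atol=atol)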