Skip to content

Commit c382d81

Browse files
committed
fix style
1 parent 7316b0d commit c382d81

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

src/liger_kernel/ops/cross_entropy.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from liger_kernel.ops.utils import compare_version
1010
from liger_kernel.ops.utils import element_mul_kernel
1111
from liger_kernel.ops.utils import is_hip
12+
from liger_kernel.utils import infer_device
1213

1314
if compare_version("triton", operator.ge, "3.0.0"):
1415
try:
@@ -258,8 +259,7 @@ def liger_cross_entropy_kernel(
258259
# The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
259260
# However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
260261
# The optimal maximum block size depends on your hardware, your kernel, and your dtype
261-
from liger_kernel.utils import infer_device
262-
MAX_FUSED_SIZE = 4096 if infer_device() == 'xpu' else 65536 // 2 # the best size we found by manually tuning
262+
MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 // 2 # the best size we found by manually tuning
263263

264264

265265
def cross_entropy_forward(

0 commit comments

Comments
 (0)