
# Commit 4854fd4: add xpu tuning to FLJSD (#647)
## Summary

Tuning on XPU: in fused linear JSD, if the device is XPU, set `MAX_FUSED_SIZE` to 4096 instead of the default `65536 // 2`. This gives slightly better performance on XPU. Very similar to #645.

## Testing Done

- Hardware Type: Intel(R) Data Center GPU Max 1550
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
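For context, `MAX_FUSED_SIZE` is a cap on the Triton block size rather than the block size itself. A minimal sketch, assuming the usual liger_kernel pattern of rounding the row width up to a power of two and clamping to the cap (`pick_block_size` and `n_cols` are illustrative names, not from this commit):

```python
import triton

# XPU value introduced by this commit; the CUDA default stays 65536 // 2.
MAX_FUSED_SIZE = 4096

def pick_block_size(n_cols: int) -> int:
    # Round the row width (e.g. vocab size) up to a power of two, then clamp
    # to the cap so the kernel stays within the device's register budget.
    return min(MAX_FUSED_SIZE, triton.next_power_of_2(n_cols))
```

Because the constant is assigned at module import time (see the diff below), the XPU cap applies process-wide rather than per call.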
1 parent bebe030 · commit 4854fd4

1 file changed: `src/liger_kernel/ops/fused_linear_jsd.py` (2 additions, 1 deletion)
```diff
@@ -8,11 +8,12 @@
 from liger_kernel.ops.utils import amp_custom_fwd
 from liger_kernel.ops.utils import element_mul_kernel
 from liger_kernel.ops.utils import is_hip
+from liger_kernel.utils import infer_device
 
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
 # The optimal maximum block size depends on your hardware, your kernel, and your dtype
-MAX_FUSED_SIZE = 65536 // 2
+MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 // 2
 
 
 def fused_linear_jsd_forward(
```
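The gating relies on `liger_kernel.utils.infer_device`. As a plausible sketch only, assuming it follows the common pattern of probing torch backends in priority order (the function body below is illustrative, not verbatim library code):

```python
import torch

def infer_device_sketch() -> str:
    # Probe accelerator backends in priority order and fall back to CPU.
    if torch.cuda.is_available():  # NVIDIA (or AMD via the HIP build)
        return "cuda"
    if hasattr(torch, "xpu") and torch.xpu.is_available():  # Intel GPUs
        return "xpu"
    return "cpu"
```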
