Commit bebe030

# add xpu tuning to CE (#645)
## Summary

Tuning on XPU: in cross-entropy, if the device is `xpu`, set `MAX_FUSED_SIZE` to `4096` instead of the default `65536 // 2`. This gives slightly better performance on XPU.

## Testing Done

- Hardware Type: Intel(R) Data Center GPU Max 1550
- [x] run `make test` to ensure correctness
- [x] run `make checkstyle` to ensure code style
- [x] run `make test-convergence` to ensure convergence
1 parent bf2c67b commit bebe030

1 file changed: `src/liger_kernel/ops/cross_entropy.py` (3 additions & 2 deletions)
```diff
@@ -9,6 +9,7 @@
 from liger_kernel.ops.utils import compare_version
 from liger_kernel.ops.utils import element_mul_kernel
 from liger_kernel.ops.utils import is_hip
+from liger_kernel.utils import infer_device

 if compare_version("triton", operator.ge, "3.0.0"):
     try:
@@ -59,7 +60,7 @@ def liger_cross_entropy_kernel(
         z_loss_ptr: Pointer to tensor to store the z loss. No operation if RETURN_Z_LOSS is 0.
         loss_stride (int): The stride of the loss tensor.
         n_cols (int): The number of columns in the input tensor.
-        n_non_ignore (flaot): The number of non-ignored elements in the batch.
+        n_non_ignore (float): The number of non-ignored elements in the batch.
         sum_non_ignore_weight (float): The sum of non-ignored target's weights in the batch.
         weight_sum (float): The sum of weight tensor.
         ignore_index (int): The index to ignore in the target.
@@ -258,7 +259,7 @@ def liger_cross_entropy_kernel(
 # The hard limit of TRITON_MAX_TENSOR_NUMEL is 1048576 https://github.com/triton-lang/triton/blob/ba42a5c68fd0505f8c42f4202d53be0f8d9a5fe0/python/triton/language/core.py#L19
 # However, setting limit as 65536 as in LayerNorm tutorial is faster because of less register spilling
 # The optimal maximum block size depends on your hardware, your kernel, and your dtype
-MAX_FUSED_SIZE = 65536 // 2  # the best size we found by manually tuning
+MAX_FUSED_SIZE = 4096 if infer_device() == "xpu" else 65536 // 2  # the best size we found by manually tuning


 def cross_entropy_forward(
```
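For context on what this constant controls: Triton kernels typically round the column count up to the next power of two and clamp it at `MAX_FUSED_SIZE` to pick a launch block size. The sketch below illustrates that selection logic; `pick_block_size` and `next_power_of_2` are hypothetical stand-ins written here for illustration (the real code uses `triton.next_power_of_2` and the constant shown in the diff), not functions from Liger-Kernel's public API.

```python
def next_power_of_2(n: int) -> int:
    """Smallest power of two >= n (pure-Python stand-in for triton.next_power_of_2)."""
    return 1 << (n - 1).bit_length()


def pick_block_size(n_cols: int, device: str) -> int:
    """Sketch of how MAX_FUSED_SIZE caps the block size per device.

    4096 on XPU comes from the manual tuning described in this commit;
    other devices keep the previous 65536 // 2 limit.
    """
    max_fused_size = 4096 if device == "xpu" else 65536 // 2
    return min(max_fused_size, next_power_of_2(n_cols))
```

For a typical LLM vocabulary of 32000 columns, a CUDA device gets a 32768-wide block, while an XPU device is capped at 4096, trading fewer columns per program for less register pressure on Intel hardware.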
