Commit bed5bc9

[None][chore] Wrap the swiglu into custom op to avoid redundant device copy. (NVIDIA#7021)
A redundant D2D (device-to-device) copy is observed when torch.compile is enabled for the Llama model, caused by the swiglu Triton kernel, and it adds performance overhead. Wrap the swiglu op in a custom op to avoid this overhead.

Signed-off-by: Yukun He <23156053+hyukn@users.noreply.github.com>
1 parent 82bd187 commit bed5bc9

2 files changed: 53 additions & 34 deletions
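For context, the change follows the standard torch.library custom-op recipe: the eager function launches the Triton kernel, while a register_fake (meta) function reports only the output shape and dtype, so torch.compile can trace the op as a single functional node and, per the commit message, no longer inserts a redundant device-to-device copy around the raw kernel launch. Below is a minimal, self-contained sketch of that recipe; the op name demo::scale and its trivial body are illustrative stand-ins, not part of this commit.

import torch


# Eager implementation: in the real commit this launches the silu_and_mul
# Triton kernel; a trivial elementwise op stands in for it here.
@torch.library.custom_op("demo::scale", mutates_args=())
def scale(x: torch.Tensor, factor: float) -> torch.Tensor:
    return x * factor


# Fake (meta) implementation: reports only the output shape/dtype, so tracing
# never has to execute (or copy the result of) the real kernel.
@scale.register_fake
def _(x: torch.Tensor, factor: float) -> torch.Tensor:
    return torch.empty_like(x)


@torch.compile
def f(x: torch.Tensor) -> torch.Tensor:
    # The custom op is traced as one opaque node with a known output layout,
    # which is what avoids the extra copy described in the commit message.
    return torch.ops.demo.scale(x, 2.0)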

tensorrt_llm/_torch/custom_ops/torch_custom_ops.py

Lines changed: 47 additions & 1 deletion
@@ -1,7 +1,8 @@
 from functools import lru_cache
-from typing import List, Optional, Tuple
+from typing import List, Mapping, Optional, Tuple

 import torch
+import triton  # type: ignore[import]

 import tensorrt_llm.quantization.utils.fp4_utils as fp4_utils
 import tensorrt_llm.quantization.utils.fp8_utils as fp8_utils
@@ -11,6 +12,7 @@
 from ..autotuner import (AutoTuner, ConstraintSpec, DynamicTensorSpec,
                          OptimizationProfile, TunableRunner, TuningConfig)
 from ..modules.multi_stream_utils import do_multi_stream
+from ..modules.swiglu import silu_and_mul_kernel
 from ..utils import (fp4_scale_infer_shape,
                      get_last_power_of_2_num_tokens_buckets,
                      last_positive_power_of_2)
@@ -989,6 +991,50 @@ def _(
     return input.new_empty((input.size(0), weight.size(0)), dtype=output_dtype)


+@torch.library.custom_op("trtllm::silu_and_mul", mutates_args=())
+def silu_and_mul(x: torch.Tensor,
+                 scale: Optional[torch.Tensor] = None,
+                 dtype: Optional[torch.dtype] = None) -> torch.Tensor:
+    b, n = x.shape
+
+    assert n % 2 == 0
+    d = n // 2
+
+    o_dtype = dtype or x.dtype
+    o = torch.empty((b, d), dtype=o_dtype, device=x.device)
+
+    def grid(meta: Mapping[str, int]) -> tuple[int, int]:
+        return (b, triton.cdiv(d, meta["BLOCK_SIZE"]))
+
+    silu_and_mul_kernel[grid](
+        o_ptr=o,
+        o_stride=o.stride(0),
+        o_scale_ptr=scale,
+        x_ptr=x,
+        x_stride=x.stride(0),
+        d=d,
+        BLOCK_SIZE=1024,
+        HAS_O_SCALE=scale is not None,
+    )
+
+    return o
+
+
+@silu_and_mul.register_fake
+def _(
+    x: torch.Tensor,
+    scale: Optional[torch.Tensor] = None,
+    dtype: Optional[torch.dtype] = None,
+) -> torch.Tensor:
+    b, n = x.shape
+
+    assert n % 2 == 0
+    d = n // 2
+
+    o_dtype = dtype or x.dtype
+    return x.new_empty((b, d), dtype=o_dtype)
+
+
 def get_event(event_idx: int):
     from ..utils import get_model_extra_attrs
     extra_attrs = get_model_extra_attrs()
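The register_fake entry above is what lets the new op participate in fake-tensor tracing: the output shape is computed without ever launching the Triton kernel. A quick illustrative check, assuming a recent PyTorch (FakeTensorMode at this import path), a visible CUDA device, and that this custom-ops module has been imported so trtllm::silu_and_mul is registered:

import torch
from torch._subclasses.fake_tensor import FakeTensorMode

# Importing the module registers the op (assumed import path from this repo).
import tensorrt_llm._torch.custom_ops.torch_custom_ops  # noqa: F401

with FakeTensorMode():
    x = torch.empty(8, 4096, dtype=torch.bfloat16, device="cuda")
    # Only the fake implementation runs: the (8, 2048) output shape and dtype
    # are inferred without touching the GPU.
    y = torch.ops.trtllm.silu_and_mul(x)
    assert y.shape == (8, 2048) and y.dtype == torch.bfloat16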

tensorrt_llm/_torch/modules/swiglu.py

Lines changed: 6 additions & 33 deletions
@@ -1,6 +1,3 @@
-from collections.abc import Mapping
-from typing import Optional
-
 import torch
 import triton  # type: ignore[import]
 import triton.language as tl  # type: ignore[import]
@@ -51,37 +48,13 @@ def silu_and_mul_kernel(o_ptr, o_stride, o_scale_ptr, x_ptr, x_stride, d,
     tl.store(o_row_ptr + offsets, result, mask=mask)


-def silu_and_mul(x: torch.Tensor,
-                 scale: Optional[torch.Tensor] = None,
-                 dtype: Optional[torch.dtype] = None) -> torch.Tensor:
-    b, n = x.shape
-
-    assert n % 2 == 0
-    d = n // 2
-
-    o_dtype = dtype or x.dtype
-    o = torch.empty((b, d), dtype=o_dtype, device=x.device)
-
-    def grid(meta: Mapping[str, int]) -> tuple[int, int]:
-        return (b, triton.cdiv(d, meta["BLOCK_SIZE"]))
-
-    silu_and_mul_kernel[grid](
-        o_ptr=o,
-        o_stride=o.stride(0),
-        o_scale_ptr=scale,
-        x_ptr=x,
-        x_stride=x.stride(0),
-        d=d,
-        BLOCK_SIZE=1024,
-        HAS_O_SCALE=scale is not None,
-    )
-
-    return o
-
-
 def swiglu(x, quant_scale: torch.Tensor = None, quant_type=None):
     if quant_scale is not None:
         assert quant_type is not None
-        return silu_and_mul(x, scale=quant_scale, dtype=quant_type)
+        return torch.ops.trtllm.silu_and_mul(
+            x,
+            scale=quant_scale,
+            dtype=quant_type,
+        )

-    return silu_and_mul(x)
+    return torch.ops.trtllm.silu_and_mul(x)
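Call sites are unchanged: swiglu still takes the fused gate/up projection of shape [tokens, 2 * intermediate_size] and returns [tokens, intermediate_size]; only its internals now route through the registered custom op. An illustrative usage sketch, where the tensor sizes and the fp8 output dtype are assumptions rather than values taken from the diff:

import torch

from tensorrt_llm._torch.modules.swiglu import swiglu

# x packs the two halves of the fused projection along the last dim: [b, 2 * d].
x = torch.randn(16, 8192, dtype=torch.bfloat16, device="cuda")

# Unquantized path: dispatches to torch.ops.trtllm.silu_and_mul(x).
y = swiglu(x)  # shape [16, 4096], dtype bfloat16

# Quantized path: the scale is passed to the kernel and the output is written
# in the requested dtype (fp8 here is a hypothetical choice).
scale = torch.tensor(1.0, dtype=torch.float32, device="cuda")
y_q = swiglu(x, quant_scale=scale, quant_type=torch.float8_e4m3fn)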
