forked from triton-lang/triton
-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathlanguage_extra.py
More file actions
90 lines (69 loc) · 2.98 KB
/
language_extra.py
File metadata and controls
90 lines (69 loc) · 2.98 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
################################################################################
#
# Copyright (c) 2025 ByteDance Ltd. and/or its affiliates
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files
# (the "Software"), to deal in the Software without restriction,
# including without limitation the rights to use, copy, modify, merge,
# publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
#
################################################################################
import triton
import triton.language as tl
from triton.language import core
def _str_to_gpu_shfl_mode(mode_str):
# The order of shfl modes is from (llvm-project/mlir/include/mlir/Dialect/GPU/IR/GPUOps.td)
ALL_SHFL_MODES = ["xor", "up", "down", "idx"]
if mode_str not in ALL_SHFL_MODES:
raise RuntimeError(f"unexpected gpu shuffle mode, expecte: {ALL_SHFL_MODES}, but got: {mode_str}")
return ALL_SHFL_MODES.index(mode_str)
@core.extern
def laneid(_semantic=None):
    """Return the builder's lane-id value wrapped as an int32 tensor.

    Calls ``create_laneid`` on ``_semantic.builder`` and wraps the result in
    a ``core.tensor`` typed ``core.int32``.
    """
    lane_handle = _semantic.builder.create_laneid()
    return core.tensor(lane_handle, core.int32)
@core.extern
def tid(axis, _semantic=None):
    """Return the work-item id along ``axis`` via an AMDGCN intrinsic.

    ``axis`` must be 0, 1 or 2; it selects the ``x``, ``y`` or ``z`` variant
    of ``llvm.amdgcn.workitem.id.*``. The intrinsic is emitted through
    ``core.extern_elementwise`` with no arguments, an int32 result type and
    ``is_pure=True``.
    """
    assert axis >= 0 and axis <= 2
    component = ("x", "y", "z")[axis]
    intrinsic_name = f"llvm.amdgcn.workitem.id.{component}"
    # The call takes no inputs, so the only signature entry is keyed by the
    # empty tuple.
    signatures = {(): (intrinsic_name, core.dtype("int32"))}
    return core.extern_elementwise("", "", [], signatures, is_pure=True, _semantic=_semantic)
@core.extern
def __syncthreads(_semantic=None):
    """Emit a barrier through the builder.

    Calls ``create_barrier`` on ``_semantic.builder`` and wraps the result in
    a ``core.tensor`` typed ``core.void``.
    """
    barrier_handle = _semantic.builder.create_barrier()
    return core.tensor(barrier_handle, core.void)
@core.extern
def __shfl_sync_with_mode_i32(
    value,
    offset,
    mode: core.constexpr = "up",
    width: int = 64,
    _semantic=None,
):
    """Emit a warp shuffle of ``value`` using the named shuffle mode.

    Args:
        value: Operand to shuffle; the result tensor reuses ``value.dtype``.
        offset: Second shuffle operand, forwarded unchanged to the builder.
        mode: Shuffle mode name ("xor", "up", "down" or "idx"), either wrapped
            in a ``core.constexpr`` or given as a plain ``str``.
        width: Forwarded unchanged to the builder's ``create_warp_shuffle``.
        _semantic: Object whose ``builder`` creates the shuffle.

    Returns:
        A ``core.tensor`` wrapping the builder's shuffle result, typed like
        ``value``.

    Raises:
        RuntimeError: If the mode name is not a known shuffle mode.
    """
    # The declared default for ``mode`` is a plain ``str``, which has no
    # ``.value`` attribute. Unwrap only when a constexpr wrapper is present, so
    # that relying on the default (or passing a bare string) no longer fails
    # with AttributeError.
    mode_name = mode.value if isinstance(mode, core.constexpr) else mode
    shfl_mode = _str_to_gpu_shfl_mode(mode_name)
    shuffled = _semantic.builder.create_warp_shuffle(value, offset, width, shfl_mode)
    return core.tensor(shuffled, value.dtype)
@triton.jit
def __shfl_sync_i32(value, laneid):
    """Shuffle ``value`` in "idx" mode with width 64, passing ``laneid`` as the offset."""
    shuffled = __shfl_sync_with_mode_i32(value, laneid, mode="idx", width=64)
    return shuffled
@triton.jit
def __shfl_up_sync_i32(value, offset):
    """Shuffle ``value`` in "up" mode with width 64, using the given ``offset``."""
    shuffled = __shfl_sync_with_mode_i32(value, offset, mode="up", width=64)
    return shuffled
@triton.jit
def __shfl_down_sync_i32(value, offset):
    """Shuffle ``value`` in "down" mode with width 64, using the given ``offset``."""
    shuffled = __shfl_sync_with_mode_i32(value, offset, mode="down", width=64)
    return shuffled
@triton.jit
def __shfl_xor_sync_i32(value, offset):
    """Shuffle ``value`` in "xor" mode with width 64, using the given ``offset``."""
    shuffled = __shfl_sync_with_mode_i32(value, offset, mode="xor", width=64)
    return shuffled