Skip to content

Commit dad63bd

Browse files
authored
[TRITON_KERNELS] some more tweaks (#9350)
1 parent 0bff14b commit dad63bd

4 files changed

Lines changed: 28 additions & 18 deletions

File tree

python/triton_kernels/triton_kernels/tensor.py

Lines changed: 17 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
from .tensor_details import ragged_tensor as ragged_tensor_details
1010
from .tensor_details.layout import BlackwellMXValueLayout, Layout, StridedLayout
1111
from .tensor_details.ragged_tensor import RaggedTensorMetadata
12-
from .tensor_details.dtype import IntegerType, FloatType, DataType, FP4, UINT8, FP8_E4M3FN, FP8_E4M3FNUZ, FP8_E5M2, FP16, BF16, FP32, FP64
12+
from .tensor_details.dtype import IntegerType, FloatType, DataType
13+
from .tensor_details.dtype import FP4, UINT8, FP8_E4M3FN, FP8_E4M3FNUZ, FP8_E5M2, FP16, BF16, FP32, FP64, INT16, INT32, INT64
1314

1415

1516
# storage
@@ -246,6 +247,9 @@ def dtype_to_torch_dtype(dtype: DataType) -> torch.dtype:
246247
FP32: torch.float32,
247248
FP16: torch.float16,
248249
FP64: torch.float64,
250+
INT16: torch.int16,
251+
INT32: torch.int32,
252+
INT64: torch.int64,
249253
}[dtype]
250254

251255

@@ -262,6 +266,9 @@ def torch_dtype_to_dtype(dtype: torch.dtype) -> DataType:
262266
"bfloat16": BF16,
263267
"float32": FP32,
264268
"float64": FP64,
269+
"int16": INT16,
270+
"int32": INT32,
271+
"int64": INT64,
265272
}
266273
if id in vals:
267274
return vals[id]
@@ -270,15 +277,13 @@ def torch_dtype_to_dtype(dtype: torch.dtype) -> DataType:
270277
assert False, f"Unknown dtype: {id}"
271278

272279

273-
def empty(shape: tuple[int], dtype: DataType, device: torch.device, layout=None):
280+
def empty(shape: tuple[int], dtype: DataType, device: torch.device, layout=None,
281+
allow_implicit_conversion: bool = False):
274282
storage_shape = list(shape)
275283
storage_dtype = torch.uint8 if dtype == FP4 else dtype_to_torch_dtype(dtype)
284+
initial_layout = layout if isinstance(layout, StridedLayout) else StridedLayout()
276285
# pack sub-byte datatype along last dimension
277-
if layout is None:
278-
layout = StridedLayout()
279-
# storage shape
280-
assert isinstance(layout, StridedLayout)
281-
order = layout.order(len(storage_shape))
286+
order = initial_layout.order(len(storage_shape))
282287
dim = order[0]
283288
storage_shape[dim] = storage_shape[dim] // (storage_dtype.itemsize * 8 // dtype.bitwidth)
284289
# storage strides
@@ -288,4 +293,8 @@ def empty(shape: tuple[int], dtype: DataType, device: torch.device, layout=None)
288293
strides[d] = running
289294
running *= storage_shape[d]
290295
storage = torch.empty_strided(storage_shape, strides, device=device, dtype=storage_dtype)
291-
return wrap_torch_tensor(storage, dtype=dtype, shape=shape, layout=layout)
296+
ret = wrap_torch_tensor(storage, dtype=dtype, shape=shape, layout=initial_layout)
297+
assert initial_layout == ret.storage.layout or allow_implicit_conversion
298+
if allow_implicit_conversion:
299+
ret = convert_layout(ret, layout)
300+
return ret

python/triton_kernels/triton_kernels/tensor_details/dtype.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,5 +32,8 @@ def bitwidth(self):
3232
FP16 = FloatType(bitwidth_exponent=5, bitwidth_mantissa=10, is_signed=True)
3333
FP32 = FloatType(bitwidth_exponent=8, bitwidth_mantissa=23, is_signed=True)
3434
FP64 = FloatType(bitwidth_exponent=11, bitwidth_mantissa=52, is_signed=True)
35+
INT16 = IntegerType(16, is_signed=True)
36+
INT32 = IntegerType(32, is_signed=True)
37+
INT64 = IntegerType(64, is_signed=True)
3538

3639
DataType: TypeAlias = IntegerType | FloatType

python/triton_kernels/triton_kernels/topk.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -170,8 +170,8 @@ def topk_torch(
170170
if apply_softmax:
171171
y_vals = torch.softmax(y_vals.float(), dim=-1).to(x.dtype)
172172
if not has_user_provided_indx:
173-
y_indx, sort_indices = torch.sort(y_indx, dim=1)
174-
y_vals = torch.gather(y_vals, 1, sort_indices)
173+
y_vals, sort_indices = torch.sort(y_vals.float(), dim=1, descending=True, stable=True)
174+
y_indx = torch.gather(y_indx, 1, sort_indices)
175175
y_indx[n_rows:, :] = -1
176176
rows = torch.arange(x.shape[0], device=device).unsqueeze(1).expand(-1, y_indx.shape[1]).reshape(-1)
177177
cols = y_indx.reshape(-1) # 64-bit safe for div/mod

python/triton_kernels/triton_kernels/topk_details/_topk_forward.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -71,16 +71,14 @@ def streaming_topk(X, stride_xm, n_expts_tot, offs_m, mask_m, N_EXPTS_PAD: tl.co
7171
x = (x.to(x_ultype) << 16) | indx_to_key(offs_x_n, N_EXPTS_PAD)[None, :]
7272
acc = tl.maximum(acc, tl.topk(x, N_EXPTS_ACT, dim=1))
7373

74-
# rotate expert index into upper 16 bits:
75-
# 0000vvvvvvvviiii --> iiii0000vvvvvvvv
76-
acc = (acc << (y_nbits - 16)) | (acc >> 16)
77-
# sort in ascending order of expert (descending order of key)
74+
# sort packed (value_key, index_key) descending:
75+
# this keeps outputs ordered by gate value and uses smaller expert index for ties
7876
acc = tl.sort(acc, dim=1, descending=True)
79-
# iiii0000vvvvvvvv --> 0000iiii:
80-
y_indices_raw = (acc >> (y_nbits - 16)).to(tl.uint32)
77+
# 0000vvvvvvvviiii --> 0000iiii:
78+
y_indices_raw = (acc & 0xFFFF).to(tl.uint32)
8179
y_indices = key_to_indx(y_indices_raw, N_EXPTS_PAD)
82-
# iiii0000vvvvvvvv --> vvvvvvvv:
83-
y_values_raw = acc.to(x_utype)
80+
# 0000vvvvvvvviiii --> vvvvvvvv:
81+
y_values_raw = (acc >> 16).to(x_utype)
8482
y_values = key_to_fpval(y_values_raw).to(x_dtype, bitcast=True)
8583

8684
return y_values, y_indices

0 commit comments

Comments (0)