-
Notifications
You must be signed in to change notification settings - Fork 2.5k
Expand file tree
/
Copy pathcute_dsl_utils.py
More file actions
206 lines (167 loc) · 7.66 KB
/
cute_dsl_utils.py
File metadata and controls
206 lines (167 loc) · 7.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
# Copyright (c) 2025, Tri Dao.
import os
import pathlib
from dataclasses import dataclass, fields
from functools import partial, lru_cache
from typing import Tuple

import torch

try:
    from triton.tools.disasm import extract
except ImportError:
    extract = None

import cutlass
import cutlass.cute as cute
from cutlass.base_dsl.typing import JitArgument
from cutlass.cutlass_dsl import NumericMeta
from cutlass.cute.runtime import from_dlpack
# Field-value types that ParamsBase / ArgumentsBase treat as static: such fields
# are skipped when flattening to MLIR values / C pointers and copied as-is when
# an instance is rebuilt.
StaticTypes = (
    cutlass.Constexpr,
    NumericMeta,
    int,
    bool,
    str,
    float,
    type(None),
)

# Handles to the unpatched CuTe DSL entry points, captured at import time so the
# patched wrappers in this module can delegate to them.
load_cubin_module_data_og = cutlass.base_dsl.runtime.cuda.load_cubin_module_data
cute_compile_og = cute.compile

# torch dtype -> CuTe DSL numeric type (only the dtypes listed here are mapped).
torch2cute_dtype_map = {
    torch.float16: cutlass.Float16,
    torch.bfloat16: cutlass.BFloat16,
    torch.float32: cutlass.Float32,
    torch.float8_e4m3fn: cutlass.Float8E4M3FN,
    torch.float8_e5m2: cutlass.Float8E5M2,
}
@lru_cache
def get_max_active_clusters(cluster_size):
    """Ask ``cutlass.utils.HardwareInfo`` for the max number of active clusters.

    The result is memoized per ``cluster_size`` for the life of the process.
    NOTE(review): the cache key does not include the device; if HardwareInfo()
    queries the current device, results may be stale on mixed-GPU hosts — confirm.
    """
    hw_info = cutlass.utils.HardwareInfo()
    return hw_info.get_max_active_clusters(cluster_size=cluster_size)
@lru_cache
def _get_device_capability_cached(device) -> Tuple[int, int]:
    """Memoized ``torch.cuda.get_device_capability`` for an already-resolved device."""
    return torch.cuda.get_device_capability(device)


def get_device_capacity(device: torch.device = None) -> Tuple[int, int]:
    """Return ``torch.cuda.get_device_capability(device)``, memoized per device.

    ``None``, ``"cuda"`` and ``torch.device("cuda")`` all mean "the current CUDA
    device". They are resolved to a concrete device index *before* the cache is
    consulted. Caching on the unresolved value (as a plain ``@lru_cache`` on this
    function did) returns a stale capability after ``torch.cuda.set_device``
    switches to a different GPU.

    Args:
        device: A ``torch.device``, device string, device index, or ``None``.

    Returns:
        The ``(major, minor)`` tuple reported by PyTorch.
    """
    if isinstance(device, str):
        device = torch.device(device)
    if device is None or (
        isinstance(device, torch.device) and device.type == "cuda" and device.index is None
    ):
        device = torch.cuda.current_device()
    return _get_device_capability_cached(device)


# Preserve the cache-management API callers had when this function was lru_cached.
get_device_capacity.cache_clear = _get_device_capability_cached.cache_clear
get_device_capacity.cache_info = _get_device_capability_cached.cache_info
@dataclass
class ParamsBase:
    """Dataclass base whose non-static fields round-trip through MLIR values.

    A field is static when its current value is an instance of ``StaticTypes``
    (``cutlass.Constexpr``, ``NumericMeta``, ``int``, ``bool``, ``str``,
    ``float`` or ``None``). Static fields are skipped when flattening and copied
    unchanged when rebuilding. Every other field is flattened with
    ``cutlass.extract_mlir_values`` and rebuilt with ``cutlass.new_from_mlir_values``.
    """

    def __extract_mlir_values__(self):
        """Return the MLIR values of all non-static fields, concatenated in field order.

        Side effect: stores how many values each non-static field contributed in
        ``self._values_pos``, which ``__new_from_mlir_values__`` reads back.
        """
        field_values = [getattr(self, f.name) for f in fields(self)]
        dynamic_values = [v for v in field_values if not isinstance(v, StaticTypes)]
        flat = []
        counts = []
        self._values_pos = counts
        for value in dynamic_values:
            extracted = cutlass.extract_mlir_values(value)
            flat.extend(extracted)
            counts.append(len(extracted))
        return flat

    def __new_from_mlir_values__(self, values):
        """Return a new instance whose non-static fields are rebuilt from ``values``.

        ``values`` is consumed left to right, one slice per non-static field, with
        slice lengths taken from ``self._values_pos`` (so ``__extract_mlir_values__``
        must have been called on ``self`` first). Static fields are copied as-is.
        """
        static_kwargs = {}
        dynamic_kwargs = {}
        for f in fields(self):
            value = getattr(self, f.name)
            if isinstance(value, StaticTypes):
                static_kwargs[f.name] = value
            else:
                dynamic_kwargs[f.name] = value
        offset = 0
        for (name, template), count in zip(list(dynamic_kwargs.items()), self._values_pos):
            dynamic_kwargs[name] = cutlass.new_from_mlir_values(
                template, values[offset : offset + count]
            )
            offset += count
        return self.__class__(**dynamic_kwargs, **static_kwargs)
@dataclass
class ArgumentsBase(JitArgument):
    """``JitArgument`` dataclass base that forwards the JIT-argument protocol to its fields.

    Fields whose value is an instance of ``StaticTypes`` are ignored by
    ``__c_pointers__`` / ``__get_mlir_types__`` and copied unchanged by
    ``__new_from_mlir_values__``. Non-static fields take part only if they
    implement the corresponding dunder; otherwise they contribute nothing.
    """

    def __c_pointers__(self):
        """Concatenate ``__c_pointers__()`` of every non-static field that defines it."""
        field_values = [getattr(self, f.name) for f in fields(self)]
        dynamic_values = [v for v in field_values if not isinstance(v, StaticTypes)]
        pointers = []
        for value in dynamic_values:
            if not hasattr(value, "__c_pointers__"):
                continue
            pointers.extend(value.__c_pointers__())
        return pointers

    def __get_mlir_types__(self):
        """Concatenate ``__get_mlir_types__()`` of the non-static fields, in field order.

        Side effect: stores in ``self._values_pos`` the number of types each
        non-static field contributed (0 for fields without ``__get_mlir_types__``);
        ``__new_from_mlir_values__`` reads it back.
        """
        field_values = [getattr(self, f.name) for f in fields(self)]
        dynamic_values = [v for v in field_values if not isinstance(v, StaticTypes)]
        mlir_types = []
        counts = []
        self._values_pos = counts
        for value in dynamic_values:
            if hasattr(value, "__get_mlir_types__"):
                value_types = value.__get_mlir_types__()
                mlir_types.extend(value_types)
                count = len(value_types)
            else:
                count = 0
            counts.append(count)
        return mlir_types

    def __new_from_mlir_values__(self, values):
        """Return a new instance whose non-static fields are rebuilt from ``values``.

        ``values`` is consumed left to right, one slice per non-static field, with
        slice lengths taken from ``self._values_pos`` (so ``__get_mlir_types__``
        must have been called on ``self`` first). Static fields are copied as-is.
        """
        static_kwargs = {}
        dynamic_kwargs = {}
        for f in fields(self):
            value = getattr(self, f.name)
            if isinstance(value, StaticTypes):
                static_kwargs[f.name] = value
            else:
                dynamic_kwargs[f.name] = value
        offset = 0
        for (name, template), count in zip(list(dynamic_kwargs.items()), self._values_pos):
            dynamic_kwargs[name] = cutlass.new_from_mlir_values(
                template, values[offset : offset + count]
            )
            offset += count
        return self.__class__(**dynamic_kwargs, **static_kwargs)
def load_cubin_module_data_patched(cubin_data, filepath):
    """Write ``cubin_data`` to ``filepath``, then delegate to the original loader.

    ``cute_compile_patched`` binds ``filepath`` with ``functools.partial`` and
    installs the result as ``cutlass.base_dsl.runtime.cuda.load_cubin_module_data``.
    """
    out_file = pathlib.Path(filepath)
    out_file.write_bytes(cubin_data)
    return load_cubin_module_data_og(cubin_data)
def cute_compile_patched(*args, **kwargs):
    """Wrap ``cute.compile``, optionally dumping the loaded cubin and its SASS.

    If the ``CUTE_CUBIN_PATH`` environment variable is set to a non-empty path,
    ``cutlass.base_dsl.runtime.cuda.load_cubin_module_data`` is temporarily
    replaced during compilation so that any cubin it loads is also written to that
    path. Afterwards, if ``triton.tools.disasm.extract`` is importable and the file
    exists, an annotated SASS listing is written next to it with the suffix
    ``.annotated.sass``.

    All arguments are forwarded unchanged to the original ``cute.compile`` and its
    return value is returned.
    """
    cubin_path = os.getenv("CUTE_CUBIN_PATH", None)
    # An empty value is treated as "unset" instead of writing to the path "".
    if not cubin_path:
        return cute_compile_og(*args, **kwargs)

    cuda_runtime = cutlass.base_dsl.runtime.cuda
    previous_loader = cuda_runtime.load_cubin_module_data
    cuda_runtime.load_cubin_module_data = partial(
        load_cubin_module_data_patched, filepath=cubin_path
    )
    try:
        output = cute_compile_og(*args, **kwargs)
    finally:
        # Restore even when compilation raises, so later compiles are unaffected.
        cuda_runtime.load_cubin_module_data = previous_loader

    cubin_file = pathlib.Path(cubin_path)
    # Only disassemble when there is a cubin file to read (presumably absent if
    # nothing was loaded during this compile).
    if extract is not None and cubin_file.exists():
        sass = extract(cubin_path, None)
        cubin_file.with_suffix(".annotated.sass").write_text(sass)
    return output
def assume_strides_aligned(t):
    """Return ``t.stride`` with every stride except the last marked 128-bit aligned.

    Each non-last stride that is not a Python ``int`` is wrapped in
    ``cute.assume(..., divby=128 // element_width_bits)``. Python ``int`` strides
    (e.g. the static stride 0 produced by a GQA ``expand``) need no assumption and
    are returned unchanged. The last stride is always returned unchanged.
    """
    divisor = 128 // t.element_type.width
    aligned = []
    for stride in t.stride[:-1]:
        if not isinstance(stride, int):
            stride = cute.assume(stride, divby=divisor)
        aligned.append(stride)
    aligned.append(t.stride[-1])
    return tuple(aligned)
def assume_tensor_aligned(t):
    """Return ``t`` rebuilt over the same iterator with 128-bit stride assumptions.

    The new layout keeps ``t.shape`` and uses the strides from
    ``assume_strides_aligned(t)``. ``None`` is passed through unchanged.
    """
    if t is None:
        return None
    iterator = t.iterator
    layout = cute.make_layout(t.shape, stride=assume_strides_aligned(t))
    return cute.make_tensor(iterator, layout)
def to_cute_tensor(t, assumed_align=16, leading_dim=-1, fully_dynamic=False, enable_tvm_ffi=True):
    """Convert a torch tensor to a CuTe tensor (for TVM FFI) with a dynamic layout.

    Args:
        t: Source torch tensor; it is detached before export.
        assumed_align: Forwarded to ``from_dlpack`` as ``assumed_align``.
        leading_dim: Forwarded to ``mark_layout_dynamic(leading_dim=...)``.
            Negative values count from the end (``-1`` means ``t.ndim - 1``);
            ``None`` is forwarded unchanged.
        fully_dynamic: If True, call ``mark_layout_dynamic()`` with no leading
            dimension; ``leading_dim`` is then ignored.
        enable_tvm_ffi: Forwarded to ``from_dlpack``.

    Returns:
        The tensor returned by ``mark_layout_dynamic``.
    """
    # NOTE: torch 2.9.1 doesn't support fp8 via DLPack, but 2.11.0 nightly does.
    # For now, export the raw bytes as uint8 and tell cutlass the real element
    # type. This can export fp8 directly once torch supports it.
    if t.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        tensor = from_dlpack(
            t.view(torch.uint8).detach(),
            assumed_align=assumed_align,
            enable_tvm_ffi=enable_tvm_ffi,
        )
        tensor.element_type = torch2cute_dtype_map[t.dtype]
    else:
        tensor = from_dlpack(t.detach(), assumed_align=assumed_align, enable_tvm_ffi=enable_tvm_ffi)
    if fully_dynamic:
        return tensor.mark_layout_dynamic()
    # Normalize every negative index (not just -1) to an absolute dimension.
    if leading_dim is not None and leading_dim < 0:
        leading_dim += t.ndim
    return tensor.mark_layout_dynamic(leading_dim=leading_dim)
def to_cute_aux_tensor(t, enable_tvm_ffi=True):
    """Convert a FlexAttention aux tensor (used in custom score_mod callables) to CuTe.

    Users control the conversion through optional attributes on the torch tensor:

    * ``__assumed_align__``: forwarded as ``assumed_align`` (``None`` if absent).
    * ``__leading_dim__``: forwarded as ``leading_dim``. If it is absent or
      ``None``, the layout is marked fully dynamic instead.
    """
    align = getattr(t, "__assumed_align__", None)  # int or None
    lead = getattr(t, "__leading_dim__", None)  # int or None
    return to_cute_tensor(
        t,
        assumed_align=align,
        leading_dim=lead,
        fully_dynamic=lead is None,
        enable_tvm_ffi=enable_tvm_ffi,
    )
def get_aux_tensor_metadata(aux_tensors):
    """Summarize the user-set layout hints of each aux tensor as a tuple of tuples.

    Each entry is ``(assumed_align, leading_dim, has_leading_dim)``, where
    ``assumed_align`` defaults to 0 and ``leading_dim`` to -1 when the
    ``__assumed_align__`` / ``__leading_dim__`` attributes are absent. The third
    item records whether ``__leading_dim__`` was set at all.
    NOTE(review): presumably used as part of a compile-cache key — confirm with callers.
    """

    def _describe(aux):
        return (
            getattr(aux, "__assumed_align__", 0),
            getattr(aux, "__leading_dim__", -1),
            hasattr(aux, "__leading_dim__"),
        )

    return tuple(map(_describe, aux_tensors))
def get_broadcast_dims(tensor: torch.Tensor) -> Tuple[bool, ...]:
    """Return, for each dimension of ``tensor``, whether its stride is 0 (broadcast).

    Useful in compile keys: CuTe's ``mark_layout_dynamic()`` keeps stride-0
    dimensions static, so kernels compiled for different broadcast patterns are
    not interchangeable.
    """
    strides = tensor.stride()
    # Strides are ints, so ``not stride`` is exactly ``stride == 0``.
    return tuple(not stride for stride in strides)