Commit 7bee576

agron911 authored and meta-codesync[bot] committed
[lint-autofix] Fix pre-commit formatting issues (#1287)
Summary:
Pull Request resolved: #1287

Automated formatting fix generated by running `pre-commit run --all-files` and `pre-commit run clang-format --all-files` on the GitHub mirror (facebookexperimental/triton). 61 file(s) fixed.

Reviewed By: xuzhao9

Differential Revision: D101703819

fbshipit-source-id: 73152b426940741b18b1395afa30bb8ea31f885f
1 parent ab27190 · commit 7bee576

61 files changed

Lines changed: 917 additions & 1334 deletions
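
The whole change was produced by the two `pre-commit` invocations quoted in the summary. For readers unfamiliar with the flow, a minimal sketch of driving the same commands from a script (this wrapper is hypothetical, not the bot's actual tooling):

```python
import subprocess

def run_autofix(repo_root: str = ".") -> None:
    # The same two commands named in the commit summary. pre-commit exits
    # non-zero when hooks modify files, so check=True is deliberately omitted.
    for cmd in (
        ["pre-commit", "run", "--all-files"],
        ["pre-commit", "run", "clang-format", "--all-files"],
    ):
        subprocess.run(cmd, cwd=repo_root)
```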


third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/DotOpToLLVM/MMAv5.cpp

Lines changed: 5 additions & 6 deletions
@@ -667,12 +667,11 @@ LogicalResult convertScaledDot(const LLVMTypeConverter &typeConverter,
                                mxfpInstKind, twoCTAs);
   };

-  return convertDotImpl(typeConverter, rewriter, loc, op.getA(), op.getB(),
-                        adaptor.getA(), adaptor.getB(), dTensorTy,
-                        adaptor.getUseD(), adaptor.getPred(),
-                        adaptor.getBarriers(), adaptor.getBarrierPreds(),
-                        twoCTAs, tlx::tlxEnablePairedMMA(op), opKindIsMXFP4,
-                        dot);
+  return convertDotImpl(
+      typeConverter, rewriter, loc, op.getA(), op.getB(), adaptor.getA(),
+      adaptor.getB(), dTensorTy, adaptor.getUseD(), adaptor.getPred(),
+      adaptor.getBarriers(), adaptor.getBarrierPreds(), twoCTAs,
+      tlx::tlxEnablePairedMMA(op), opKindIsMXFP4, dot);
 }

 //===----------------------------------------------------------------------===//

third_party/tileir/PerformanceTuningTips.md

Lines changed: 3 additions & 3 deletions
@@ -14,7 +14,7 @@ The **occupancy** hint accepts an integer N from 1 to 32, indicating that the pr

 Unlike the Triton PTX backend, the CUDA Tile IR Backend disables approx and ftz by default. Setting `TILEIR_ENABLE_APPROX=1` and `TILEIR_ENABLE_FTZ=1` can provide performance improvements in certain workloads (with precision degradation within acceptable ranges), such as **`attention`** and its variant kernels.

-Note that the TileIR compiler (`tileiras`) shipping in CUDA 13.1 does not automatically optimize `exp.approx -> ex2 + mulf`. For performance and precision parity with the Triton PTX backend, please explicitly rewrite `expOp` to use `ex2 + mulf` instead. 
+Note that the TileIR compiler (`tileiras`) shipping in CUDA 13.1 does not automatically optimize `exp.approx -> ex2 + mulf`. For performance and precision parity with the Triton PTX backend, please explicitly rewrite `expOp` to use `ex2 + mulf` instead.

 #### opt-level

@@ -68,11 +68,11 @@ sudo nvidia-smi -i 0 -pm 1; sudo nvidia-smi -i 0 -pl 1000; sudo nvidia-smi -i 0

 ![Fused Attention Backward Benchmark](./fused-attention-bwd.png)

-### Persistent Matmul (09-persistent-matmul.py) 
+### Persistent Matmul (09-persistent-matmul.py)

 > TFLOPS by Proton

-#### NVIDIA PTX backend 
+#### NVIDIA PTX backend

 | Kernel Name | K=512 | K=1024 | K=1536 | K=2048 | K=2560 | K=3072 | K=3584 | K=4096 | K=4608 | K=5120 | K=5632 | K=6144 | K=6656 | K=7168 | K=7680 | K=8192 |
 |---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
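
The `exp.approx -> ex2 + mulf` rewrite called for in the doc change above follows from exp(x) = 2^(x · log2 e), i.e. one multiply feeding an `ex2`. A minimal Triton sketch of what such an explicit rewrite could look like in kernel code (the kernel and its names are illustrative, not from this commit):

```python
import triton
import triton.language as tl

LOG2E = 1.4426950408889634  # log2(e): exp(x) == exp2(x * log2(e))

@triton.jit
def exp_kernel(x_ptr, y_ptr, n, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    # Instead of tl.exp(x): emit the mulf + ex2 pair explicitly so the
    # TileIR backend matches the PTX backend's exp.approx lowering.
    y = tl.exp2(x * LOG2E)
    tl.store(y_ptr + offs, y, mask=mask)
```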

third_party/tileir/backend/code_generator.py

Lines changed: 3 additions & 4 deletions
@@ -35,6 +35,7 @@
 
 from triton.backends.tileir.conf import TileIREnvConf
 
+
 def mangle_fn(name, arg_tys, caller_context):
     # doesn't mangle ret type, which must be a function of arg tys
     mangled_args = '_'.join([tileir_mangle_ty(ty) for ty in arg_tys])
@@ -46,20 +47,18 @@ def mangle_fn(name, arg_tys, caller_context):
         ret += caller_context.mangle()
     return ret
 
+
 def tileir_mangle_ty(ty):
     return ty.mangle()
 
 
 def tileir_mangle_fn(name, arg_tys, constants):
     # doesn't mangle ret type, which must be a function of arg tys
     mangled_arg_names = "_".join([tileir_mangle_ty(ty) for ty in arg_tys])
-    mangled_constants = "_".join(
-        [f"{i}c{repr(constants[i])}" for i in sorted(constants)]
-    )
+    mangled_constants = "_".join([f"{i}c{repr(constants[i])}" for i in sorted(constants)])
     mangled_constants = mangled_constants.replace(".", "_d_")
     mangled_constants = mangled_constants.replace("'", "_sq_")
     # [ and ] are not allowed in LLVM identifiers
     mangled_constants = mangled_constants.replace('[', '_').replace(']', '_')
     ret = f'{name}__{mangled_arg_names}__{mangled_constants}'
     return ret
-
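
As an aside for readers of `tileir_mangle_fn` above: the mangled name is the kernel name, the mangled argument types, and the sanitized constant values joined with double underscores. A standalone run of the same replacement rules (the type strings are stand-ins for what the real `ty.mangle()` would return):

```python
def mangle(name, arg_ty_strs, constants):
    # Same sanitization steps as tileir_mangle_fn above.
    consts = "_".join(f"{i}c{repr(constants[i])}" for i in sorted(constants))
    consts = consts.replace(".", "_d_").replace("'", "_sq_")
    consts = consts.replace('[', '_').replace(']', '_')  # not legal in LLVM ids
    return f"{name}__{'_'.join(arg_ty_strs)}__{consts}"

print(mangle("kern", ["fp32", "i32"], {2: 0.5}))  # kern__fp32_i32__2c0_d_5
```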

third_party/tileir/backend/compiler.py

Lines changed: 15 additions & 35 deletions
@@ -17,6 +17,8 @@
 import subprocess
 import sys
 from pathlib import Path
+
+
 def format_compute_capability(capability: int) -> str:
     """
     Format compute capability for GPU architecture.
@@ -52,14 +54,15 @@ def TemporaryDirectory(suffix=None, prefix=None, dir=None, delete=True):
     if delete:
         shutil.rmtree(temp_dir)
 
+
 @dataclass(frozen=True)
 class TileIROptions:
     ########################## tileIR core options ##########################
     backend_name: str = 'tileir'
     arch: str = None
     num_ctas: int = 1
     # tileir use num_stages to control the op cost, see <tileir_link>
-    num_stages: int = 3 
+    num_stages: int = 3
     # tileir use opt_level to control the optimization level, see <tileir_link>
     opt_level: int = 3
     # tileir use occupancy to control the register usage, see <tileir_link>
@@ -103,10 +106,10 @@ def enable_ftz(self):
     @property
     def enable_approx(self):
         return TileIREnvConf.enable_approx()
+
     def __post_init__(self):
-        assert (
-            self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0
-        ), "num_warps must be a power of 2"
+        assert (self.num_warps > 0 and (self.num_warps & (self.num_warps - 1)) == 0), "num_warps must be a power of 2"
+
     def hash(self):
         hash_dict = dict(self.__dict__)
         # Get all property values from class __dict__
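
The reformatted assert keeps the usual power-of-two bit test: for n > 0, `n & (n - 1)` clears the lowest set bit, so it is zero exactly when n has a single set bit. A quick standalone check:

```python
def is_power_of_two(n: int) -> bool:
    # n - 1 flips every bit up to and including the lowest set bit,
    # so the AND is zero only when n had exactly one set bit.
    return n > 0 and (n & (n - 1)) == 0

assert [n for n in range(1, 17) if is_power_of_two(n)] == [1, 2, 4, 8, 16]
```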
@@ -115,13 +118,7 @@ def hash(self):
             hash_dict[name] = getattr(self, name)
         # Exclude num_warps from hash since it doesn't affect compilation output.
         # This enables kernel cache sharing for configs that only differ in num_warps.
-        key = "_".join(
-            [
-                f"{name}-{val}"
-                for name, val in sorted(hash_dict.items())
-                if name != "num_warps"
-            ]
-        )
+        key = "_".join([f"{name}-{val}" for name, val in sorted(hash_dict.items()) if name != "num_warps"])
         return hashlib.sha256(key.encode("utf-8")).hexdigest()
 
 
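Because `num_warps` is filtered out of the key, option sets that differ only in `num_warps` map to the same cache entry, which is exactly the sharing behavior the comment describes. A toy illustration with plain dicts (the values are hypothetical; the key construction mirrors `hash()` above):

```python
import hashlib

def options_key(opts: dict) -> str:
    # Sorted name-value pairs, num_warps excluded, as in TileIROptions.hash().
    key = "_".join(f"{k}-{v}" for k, v in sorted(opts.items()) if k != "num_warps")
    return hashlib.sha256(key.encode("utf-8")).hexdigest()

a = options_key({"opt_level": 3, "num_stages": 3, "num_warps": 4})
b = options_key({"opt_level": 3, "num_stages": 3, "num_warps": 8})
assert a == b  # only num_warps differs, so the kernel cache entry is shared
```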
@@ -130,6 +127,7 @@ def get_tileir_version():
 
 
 class TileIRBackend(BaseBackend):
+
     def get_module_map(self):
         from triton.language.extra.cuda import libdevice
 
@@ -152,14 +150,7 @@ def __init__(self, target: GPUTarget) -> None:
 
     def parse_options(self, opts) -> Any:
         args = {"arch": os.getenv("TRITON_OVERRIDE_ARCH", f"sm{self.target.arch}")}
-        args.update(
-            {
-                k: opts[k]
-                for k in TileIROptions.__dataclass_fields__.keys()
-                if k in opts
-                if opts[k] is not None
-            }
-        )
+        args.update({k: opts[k] for k in TileIROptions.__dataclass_fields__.keys() if k in opts if opts[k] is not None})
         capability = int(self._parse_arch(args["arch"]))
         if "supported_fp8_dtypes" not in args:
             supported_fp8_dtypes = set(TileIROptions.supported_fp8_dtypes)
@@ -288,19 +279,14 @@ def make_tileir(mod, metadata, opt: TileIROptions, capability):
             opt.occupancy,
             metadata["num_stages"],
         )
-        tileir.passes.add_auto_gen_memtoken(
-            pm,
-            opt.enable_autogen_alias_mem_token
-        )
+        tileir.passes.add_auto_gen_memtoken(pm, opt.enable_autogen_alias_mem_token)
         passes.common.add_inliner(pm)
         if opt.enable_fp_fusion:
             tileir.passes.add_fma_fusion(pm)
         tileir.passes.add_strip_debuginfo(pm)
         pm.run(mod, "make_tileir")
         if not tileir.only_contain_legal_dialects(mod):
-            raise RuntimeError(
-                "Triton ttir to tileir ir failed. Some ttir ops cannot be converted to tileir."
-            )
+            raise RuntimeError("Triton ttir to tileir ir failed. Some ttir ops cannot be converted to tileir.")
 
         pattern = r"entry @([a-zA-Z0-9_]*)\("
         match = re.findall(pattern, mod.__str__())
@@ -316,15 +302,9 @@ def make_cubin(mod, metadata, opt: TileIROptions, capability):
     def add_stages(self, stages, options, language):
         assert language == Language.TRITON, "Only TRITON language is supported for now"
         capability = int(self._parse_arch(options.arch))
-        stages["ttir"] = lambda src, metadata: self.make_ttir(
-            src, metadata, options, capability
-        )
-        stages["tileIR"] = lambda src, metadata: self.make_tileir(
-            src, metadata, options, capability
-        )
-        stages["cubin"] = lambda src, metadata: self.make_cubin(
-            src, metadata, options, capability
-        )
+        stages["ttir"] = lambda src, metadata: self.make_ttir(src, metadata, options, capability)
+        stages["tileIR"] = lambda src, metadata: self.make_tileir(src, metadata, options, capability)
+        stages["cubin"] = lambda src, metadata: self.make_cubin(src, metadata, options, capability)
 
     @functools.lru_cache()
     def hash(self):
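
`add_stages` registers the ttir → tileIR → cubin pipeline as a dict of lambdas keyed by stage name. A toy model of how a driver might thread an artifact through such a dict (the runner loop is an assumption; only the dict-of-lambdas shape comes from the code above):

```python
# Hypothetical stage functions standing in for make_ttir/make_tileir/make_cubin.
stages = {
    "ttir": lambda src, metadata: f"ttir({src})",
    "tileIR": lambda src, metadata: f"tileIR({src})",
    "cubin": lambda src, metadata: f"cubin({src})",
}

artifact, metadata = "kernel_src", {}
for name, stage in stages.items():  # dicts preserve insertion order
    artifact = stage(artifact, metadata)
print(artifact)  # cubin(tileIR(ttir(kernel_src)))
```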

third_party/tileir/backend/conf.py

Lines changed: 4 additions & 1 deletion
@@ -5,6 +5,7 @@
 
 
 class TileIREnvConf:
+
     @staticmethod
     def enable_approx():
         # Enable approximate calculation, trading off numerical precision for performance gains
@@ -35,7 +36,8 @@ def get_tileiras_path():
         path = os.path.join(cuda_home, "bin", "tileiras")
         if os.path.exists(path):
             import subprocess
-            version_output = subprocess.check_output([path, "--version"], encoding="utf-8", stderr=subprocess.STDOUT)
+            version_output = subprocess.check_output([path, "--version"], encoding="utf-8",
+                                                     stderr=subprocess.STDOUT)
             if "release 13.1" in version_output:
                 return path
         from shutil import which
@@ -71,6 +73,7 @@ def get_sm_arch():
     def enable_tma_offset_assert_check():
         return os.getenv("NVT_TMA_OFFSET_CHECK", "0") == "1"
 
+
 @contextmanager
 def set_env_var(var_name, new_value):
     # Save the original value of the environment variable
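
Only the first line of `set_env_var` is visible in this hunk; the rest of a save/restore context manager typically looks like the sketch below (the restore logic is an assumption, not shown in the diff):

```python
import os
from contextlib import contextmanager

@contextmanager
def set_env_var(var_name, new_value):
    # Save the original value of the environment variable
    original = os.environ.get(var_name)
    os.environ[var_name] = new_value
    try:
        yield
    finally:
        # Assumed restore step: put back the old value, or unset it.
        if original is None:
            os.environ.pop(var_name, None)
        else:
            os.environ[var_name] = original
```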

third_party/tileir/backend/driver.c

Lines changed: 6 additions & 9 deletions
@@ -76,23 +76,21 @@ static PyObject *loadtileIRBinary(PyObject *self, PyObject *args) {
       cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun));
   CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
       cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun));
-    n_spills /= 4; // Convert bytes to number of 32-bit registers.
-  CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(
-      cuFuncGetAttribute(&static_smem_bytes, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun));
+  n_spills /= 4; // Convert bytes to number of 32-bit registers.
+  CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncGetAttribute(
+      &static_smem_bytes, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun));
   CUDA_CHECK_AND_RETURN_NULL_ALLOW_THREADS(cuFuncGetAttribute(
-          &n_max_threads, CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK, fun));
+      &n_max_threads, CU_FUNC_ATTRIBUTE_MAX_THREADS_PER_BLOCK, fun));
 
   Py_END_ALLOW_THREADS;
 
   if (PyErr_Occurred()) {
     return NULL;
   }
-  return Py_BuildValue(
-      "(KKiii)", (uint64_t)mod, (uint64_t)fun, n_regs, n_spills, n_max_threads
-  );
+  return Py_BuildValue("(KKiii)", (uint64_t)mod, (uint64_t)fun, n_regs,
+                       n_spills, n_max_threads);
 }
 
-
 static PyMethodDef ModuleMethods[] = {
     {"load_tileir_binary", loadtileIRBinary, METH_VARARGS,
      "Load provided tileir into CUDA driver"},
@@ -114,4 +112,3 @@ PyMODINIT_FUNC PyInit_tileir_utils(void) {
 
   return m;
 }
-
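
`loadtileIRBinary` hands Python the `Py_BuildValue("(KKiii)")` tuple above: module handle, function handle, register count, spill count (already converted from bytes to 32-bit registers), and max threads per block. A self-contained sketch of consuming that tuple (the stand-in function below is hypothetical; the real values come from the extension module):

```python
def load_tileir_binary_result():
    # Stand-in for the extension call; real code returns the "(KKiii)" tuple.
    return (0x1000, 0x2000, 32, 512 // 4, 1024)

mod, fun, n_regs, n_spills, n_max_threads = load_tileir_binary_result()
# n_spills counts 32-bit registers: LOCAL_SIZE_BYTES (512 here) / 4.
print(f"{n_regs} regs, {n_spills} spill regs, <= {n_max_threads} threads/block")
```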

third_party/tileir/backend/driver.py

Lines changed: 14 additions & 21 deletions
@@ -9,12 +9,7 @@
 import tempfile
 import threading
 import torch
-from triton.backends.nvidia.driver import (
-    library_dirs,
-    include_dirs,
-    libraries,
-    ty_to_cpp
-)
+from triton.backends.nvidia.driver import (library_dirs, include_dirs, libraries, ty_to_cpp)
 
 from triton import knobs
 from triton.runtime.build import compile_module_from_src
@@ -24,13 +19,13 @@
 from triton.backends.tileir.conf import TileIREnvConf
 from triton.tools.tensor_descriptor import TensorDescriptor
 
-
 # ------------------------
 # Utils
 # ------------------------
 
 
 class TileIRUtils(object):
+
     def __new__(cls):
         if not hasattr(cls, "instance"):
             cls.instance = super(TileIRUtils, cls).__new__(cls)
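
`TileIRUtils.__new__` above is the cached-instance singleton pattern: the first construction stores the instance on the class and every later call returns it. The same pattern in isolation (the trailing `return` is implied by the pattern; it sits past the end of this hunk):

```python
class Singleton(object):
    def __new__(cls):
        if not hasattr(cls, "instance"):
            # First construction: cache the instance on the class.
            cls.instance = super(Singleton, cls).__new__(cls)
        return cls.instance

assert Singleton() is Singleton()
```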
@@ -40,11 +35,11 @@ def __init__(self):
         tile_mod_path = dirname
         nvidia_mod_path = os.path.join(os.path.dirname(dirname), "nvidia")
         tile_mod = compile_module_from_src(
-            Path(os.path.join(tile_mod_path, "driver.c")).read_text(), "tileir_utils", library_dirs(), include_dirs, libraries
-        )
+            Path(os.path.join(tile_mod_path, "driver.c")).read_text(), "tileir_utils", library_dirs(), include_dirs,
+            libraries)
         nvidia_mod = compile_module_from_src(
-            Path(os.path.join(nvidia_mod_path, "driver.c")).read_text(), "cuda_utils", library_dirs(), include_dirs, libraries
-        )
+            Path(os.path.join(nvidia_mod_path, "driver.c")).read_text(), "cuda_utils", library_dirs(), include_dirs,
+            libraries)
         self.init_nvidia_function(nvidia_mod)
         self.init_tileir_function(tile_mod)
 
@@ -61,7 +56,6 @@ def init_nvidia_function(self, mod):
 # Launcher
 # ------------------------
 
-
 dirname = os.path.dirname(__file__)
 
 FLOAT_STORAGE_TYPE = {
@@ -79,12 +73,12 @@ def init_nvidia_function(self, mod):
     "fp64": "pack_fp64",
 }
 
-
 _BASE_ARGS_FORMAT = "iiiKKpOOOO"
 _BASE_ARGS_FORMAT_LEN = len(_BASE_ARGS_FORMAT)
 
 
 def make_launcher(constants, signature):
+
     def _flatten_signature(sig, output):
         # Flatten tuples
         if isinstance(sig, tuple):
@@ -353,7 +347,7 @@ def format_of(ty):
   {"; ".join([f"DevicePtrInfo ptr_info{i} = getPointer(_arg{i}, {i}); if (!ptr_info{i}.valid) return NULL;" if ty[0] == "*" else "" for i, ty in signature.items()])};
   {newline.join(float_storage_decls)}
   Py_BEGIN_ALLOW_THREADS;
-  
+
   _launch(numTilesX, numTilesY, numTilesZ, launch_pdl, (CUstream)_stream, (CUfunction)_function{', ' + ', '.join(internal_args_list) if len(internal_args_list) > 0 else ''});
   Py_END_ALLOW_THREADS;
   if (PyErr_Occurred()) {{
@@ -399,7 +393,6 @@ def format_of(ty):
     return src
 
 
-
 # This function unpacks a tensordesc object into its components:
 # - data pointer
 # - shape dimensions
@@ -418,6 +411,7 @@ def make_tensordesc_arg(arg):
 
 
 def wrap_handle_tensordesc(launcher):
+
     def inner(*args):
         # 9 is the metadata arguments in `args` defined in `make_launcher`
         meta_args = args[:9]
@@ -429,6 +423,7 @@ def inner(*args):
         else:
             final_args.append(arg)
         return launcher(*meta_args, *final_args)
+
     return inner
 
 
@@ -438,7 +433,7 @@ def __init__(self, src, metadata):
         ids = {"ids_of_const_exprs": src.fn.constexprs if hasattr(src, "fn") else tuple()}
 
         constants = src.constants if hasattr(src, "constants") else dict()
-        arg_idx = lambda x: (src.fn.arg_names.index(x),) if isinstance(x, str) else x
+        arg_idx = lambda x: (src.fn.arg_names.index(x), ) if isinstance(x, str) else x
         constants = {arg_idx(idx): value for idx, value in constants.items()}
         signature = {idx: value for idx, value in src.signature.items()}
         has_tensordesc = any("tensordesc" in value for value in signature.values())
@@ -473,7 +468,6 @@ def __init__(self, src, metadata):
         self.launch = mod.launch
         self.launch_pdl = metadata.launch_pdl
 
-
     def __call__(self, *args, **kwargs):
         # TODO: below if branch is for torch 2.8.0a0+5228986c39.nvinternal commit
         # where constexpr arguments are not passed to the launch function by inductor
@@ -482,13 +476,11 @@ def __call__(self, *args, **kwargs):
         num_launch_args = 9
         num_params = len(args) - num_launch_args
         if num_params < self.ori_signature_len:
-            extra_args = [
-                self.constants[(i,)] for i in range(num_params, self.ori_signature_len)
-            ]
+            extra_args = [self.constants[(i, )] for i in range(num_params, self.ori_signature_len)]
             model_args = args + tuple(extra_args)
         else:
             model_args = args
-        model_args = model_args[:5] + (self.launch_pdl,) + model_args[5:]
+        model_args = model_args[:5] + (self.launch_pdl, ) + model_args[5:]
 
         self.launch(*model_args, **kwargs)
@@ -543,4 +535,5 @@ def get_empty_cache_for_benchmark(self):
     def clear_cache(self, cache):
         cache.zero_()
 
+
 GlobalTileIRDriver = TileIRDriver()
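
Back in `wrap_handle_tensordesc` above, the wrapper keeps the first nine metadata arguments intact and flattens each `TensorDescriptor` kernel argument into scalar components before delegating to the real launcher (the comment block in the diff names the data pointer and shape dimensions). A self-contained sketch of that shape of wrapper; the exact flattening into pointer/shape/strides is an assumption, since only the else-branch appears in this diff:

```python
def wrap_handle_tensordesc(launcher, is_desc=lambda a: hasattr(a, "strides")):
    def inner(*args):
        meta_args = args[:9]  # launch metadata, as in make_launcher
        final_args = []
        for arg in args[9:]:
            if is_desc(arg):
                # Assumed flattening: base pointer, then shape, then strides.
                final_args.extend([arg.base, *arg.shape, *arg.strides])
            else:
                final_args.append(arg)
        return launcher(*meta_args, *final_args)
    return inner
```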
