Skip to content

Commit 6901ab5

Browse files
authored
Merge branch 'triton_v3.3.x' into workflow_triton_v3.3.x
2 parents a780ea0 + e4eb677 commit 6901ab5

5 files changed

Lines changed: 87 additions & 42 deletions

File tree

python/setup.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -607,13 +607,13 @@ def build_extension(self, ext):
607607
if helper.flagtree_backend:
608608
if helper.flagtree_backend in ("aipu", "tsingmicro", "enflame"):
609609
backends = [
610-
*BackendInstaller.copy(helper.default_backends + helper.extend_backends),
610+
*BackendInstaller.copy(helper.configs.default_backends + tuple(helper.configs.extend_backends)),
611611
*BackendInstaller.copy_externals(),
612612
]
613613
else:
614-
backends = [*BackendInstaller.copy(helper.extend_backends), *BackendInstaller.copy_externals()]
614+
backends = [*BackendInstaller.copy(helper.configs.extend_backends), *BackendInstaller.copy_externals()]
615615
else:
616-
backends = [*BackendInstaller.copy(helper.default_backends), *BackendInstaller.copy_externals()]
616+
backends = [*BackendInstaller.copy(helper.configs.default_backends), *BackendInstaller.copy_externals()]
617617

618618

619619
def add_link_to_backends():
@@ -799,7 +799,7 @@ def get_git_version_suffix():
799799
entry_points=get_entry_points(),
800800
package_data=package_data,
801801
include_package_data=True,
802-
ext_modules=[CMakeExtension("triton", helper.ext_sourcedir)],
802+
ext_modules=[CMakeExtension("triton", helper.configs.ext_sourcedir)],
803803
cmdclass={
804804
"build_ext": CMakeBuild,
805805
"build_py": CMakeBuildPy,

python/setup_tools/setup_helper.py

Lines changed: 18 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -6,17 +6,11 @@
66
import hashlib
77
from distutils.sysconfig import get_python_lib
88
from . import utils
9+
from .utils.tools import flagtree_configs as configs
910

10-
extend_backends = []
11-
default_backends = ["nvidia", "amd"]
12-
plugin_backends = ["cambricon", "ascend", "aipu", "tsingmicro", "enflame"]
13-
ext_sourcedir = "triton/_C/"
14-
flagtree_backend = os.getenv("FLAGTREE_BACKEND", "").lower()
15-
flagtree_plugin = os.getenv("FLAGTREE_PLUGIN", "").lower()
16-
offline_build = os.getenv("FLAGTREE_PLUGIN", "OFF")
17-
device_mapping = {"xpu": "xpu", "mthreads": "musa", "ascend": "ascend"}
18-
activated_module = utils.activate(flagtree_backend)
1911
downloader = utils.tools.DownloadManager()
12+
configs = configs
13+
flagtree_backend = configs.flagtree_backend
2014

2115
set_llvm_env = lambda path: set_env({
2216
'LLVM_INCLUDE_DIRS': Path(path) / "include",
@@ -27,26 +21,26 @@
2721

2822
def install_extension(*args, **kargs):
2923
try:
30-
activated_module.install_extension(*args, **kargs)
24+
configs.activated_module.install_extension(*args, **kargs)
3125
except Exception:
3226
pass
3327

3428

3529
def get_backend_cmake_args(*args, **kargs):
3630
try:
37-
return activated_module.get_backend_cmake_args(*args, **kargs)
31+
return configs.activated_module.get_backend_cmake_args(*args, **kargs)
3832
except Exception:
3933
return []
4034

4135

4236
def get_device_name():
43-
return device_mapping[flagtree_backend]
37+
return configs.device_alias[flagtree_backend]
4438

4539

4640
def get_extra_packages():
4741
packages = []
4842
try:
49-
packages = activated_module.get_extra_install_packages()
43+
packages = configs.activated_module.get_extra_install_packages()
5044
except Exception:
5145
packages = []
5246
return packages
@@ -55,7 +49,7 @@ def get_extra_packages():
5549
def get_package_data_tools():
5650
package_data = ["compile.h", "compile.c"]
5751
try:
58-
package_data += activated_module.get_package_data_tools()
52+
package_data += configs.activated_module.get_package_data_tools()
5953
except Exception:
6054
package_data
6155
return package_data
@@ -81,15 +75,15 @@ def download_flagtree_third_party(name, condition, required=False, hock=None):
8175
submodule = utils.flagtree_submodules[name]
8276
downloader.download(module=submodule, required=required)
8377
if callable(hock):
84-
hock(third_party_base_dir=utils.flagtree_submodule_dir, backend=submodule,
85-
default_backends=default_backends)
78+
configs.default_backends = hock(third_party_base_dir=configs.flagtree_submodule_dir, backend=submodule,
79+
default_backends=configs.default_backends)
8680
else:
8781
print(f"\033[1;33m[Note] Skip downloading {name} since USE_{name.upper()} is set to OFF\033[0m")
8882

8983

9084
def post_install():
9185
try:
92-
activated_module.post_install()
86+
configs.activated_module.post_install()
9387
except Exception:
9488
pass
9589

@@ -250,14 +244,14 @@ def skip_package_dir(package):
250244
if 'backends' in package or 'profiler' in package:
251245
return True
252246
try:
253-
return activated_module.skip_package_dir(package)
247+
return configs.activated_module.skip_package_dir(package)
254248
except Exception:
255249
return False
256250

257251
@staticmethod
258252
def get_package_dir(packages):
259253
package_dict = {}
260-
if flagtree_backend and flagtree_backend not in plugin_backends:
254+
if flagtree_backend and flagtree_backend not in configs.plugin_backends:
261255
connection = []
262256
backend_triton_path = f"../third_party/{flagtree_backend}/python/"
263257
for package in packages:
@@ -267,7 +261,7 @@ def get_package_dir(packages):
267261
connection.append(pair)
268262
package_dict.update(connection)
269263
try:
270-
package_dict.update(activated_module.get_package_dir())
264+
package_dict.update(configs.activated_module.get_package_dir())
271265
except Exception:
272266
pass
273267
return package_dict
@@ -277,8 +271,8 @@ def handle_flagtree_backend():
277271
global ext_sourcedir
278272
if flagtree_backend:
279273
print(f"\033[1;32m[INFO] FlagtreeBackend is {flagtree_backend}\033[0m")
280-
extend_backends.append(flagtree_backend)
281-
if "editable_wheel" in sys.argv and flagtree_backend not in plugin_backends:
274+
configs.extend_backends.append(flagtree_backend)
275+
if "editable_wheel" in sys.argv and flagtree_backend not in configs.plugin_backends:
282276
ext_sourcedir = os.path.abspath(f"../third_party/{flagtree_backend}/python/{ext_sourcedir}") + "/"
283277

284278

@@ -337,7 +331,7 @@ def uninstall_triton():
337331
)
338332

339333
cache.store(
340-
file="iluvatarTritonPlugin.so", condition=("iluvatar" == flagtree_backend) and (not flagtree_plugin), url=
334+
file="iluvatarTritonPlugin.so", condition=("iluvatar" == flagtree_backend) and (not configs.flagtree_plugin), url=
341335
"https://baai-cp-web.ks3-cn-beijing.ksyuncs.com/trans/iluvatarTritonPlugin-cpython3.10-glibc2.30-glibcxx3.4.28-cxxabi1.3.12-ubuntu-x86_64_v0.3.0.tar.gz",
342336
copy_dst_path=f"third_party/{flagtree_backend}", md5_digest="015b9af8")
343337

@@ -376,7 +370,7 @@ def uninstall_triton():
376370
)
377371

378372
cache.store(
379-
file="mthreadsTritonPlugin.so", condition=("mthreads" == flagtree_backend) and (not flagtree_plugin), url=
373+
file="mthreadsTritonPlugin.so", condition=("mthreads" == flagtree_backend) and (not configs.flagtree_plugin), url=
380374
"https://baai-cp-web.ks3-cn-beijing.ksyuncs.com/trans/mthreadsTritonPlugin-cpython3.10-glibc2.35-glibcxx3.4.30-cxxabi1.3.13-ubuntu-x86_64_v0.3.0.tar.gz",
381375
copy_dst_path=f"third_party/{flagtree_backend}", md5_digest="2a9ca0b8")
382376

python/setup_tools/utils/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,16 @@
22
import importlib.util
33
import os
44
from . import tools, default, aipu
5-
from .tools import flagtree_submodule_dir, OfflineBuildManager
5+
from .tools import OfflineBuildManager, flagtree_configs
66

77
flagtree_submodules = {
88
"triton_shared":
99
tools.Module(name="triton_shared", url="https://github.com/microsoft/triton-shared.git",
1010
commit_id="5842469a16b261e45a2c67fbfc308057622b03ee",
11-
dst_path=os.path.join(flagtree_submodule_dir, "triton_shared")),
11+
dst_path=os.path.join(flagtree_configs.flagtree_submodule_dir, "triton_shared")),
1212
"flir":
1313
tools.Module(name="flir", url="https://github.com/FlagTree/flir.git",
14-
dst_path=os.path.join(flagtree_submodule_dir, "flir")),
14+
dst_path=os.path.join(flagtree_configs.flagtree_submodule_dir, "flir")),
1515
}
1616

1717

python/setup_tools/utils/aipu.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
11
def precompile_hock(*args, **kargs):
22
default_backends = kargs["default_backends"]
3-
default_backends.append('flir')
3+
default_backends_list = list(default_backends)
4+
default_backends_list.append('flir')
5+
kargs["default_backends"] = tuple(default_backends_list)
6+
default_backends = tuple(default_backends_list)
7+
return default_backends

python/setup_tools/utils/tools.py

Lines changed: 57 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,61 @@
55
import zipfile
66
from io import BytesIO
77
import urllib.request
8-
from dataclasses import dataclass
98
import json
109
from build_helpers import get_base_dir
1110
import platform
11+
from typing import Mapping
12+
from types import MappingProxyType
13+
import importlib.util
14+
from dataclasses import dataclass, field
1215

13-
flagtree_root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
14-
flagtree_submodule_dir = os.path.join(flagtree_root_dir, "third_party")
15-
flagtree_backend = os.environ.get("FLAGTREE_BACKEND")
16-
use_cuda_toolkit = ["aipu"]
16+
17+
def _get_flagtree_root() -> str:
18+
return str(Path(__file__).resolve().parents[3])
19+
20+
21+
@dataclass
22+
class FlagtreeConfigs:
23+
default_backends: tuple = ("nvidia", "amd")
24+
plugin_backends: tuple = ("cambricon", "ascend", "aipu", "tsingmicro", "enflame")
25+
use_cuda_toolkit_backends: tuple = ('aipu', )
26+
language_extra_backends: tuple = ('xpu', 'mthreads', "cambricon")
27+
ext_sourcedir: str = "triton/_C/"
28+
flagtree_root_dir: str = field(default_factory=_get_flagtree_root)
29+
flagtree_backend: str = field(default_factory=lambda: os.environ.get("FLAGTREE_BACKEND"))
30+
flagtree_plugin: str = field(default_factory=lambda: os.environ.get("FLAGTREE_PLUGIN"))
31+
extend_backends: list = field(default_factory=list)
32+
activated_module: any = None
33+
flagtree_submodule_dir: str = ''
34+
device_alias_map: Mapping[str, str] = field(default_factory=lambda: MappingProxyType({
35+
"xpu": "xpu",
36+
"mthreads": "musa",
37+
"ascend": "ascend",
38+
"cambricon": "mlu",
39+
}))
40+
41+
def __post_init__(self):
42+
object.__setattr__(
43+
self,
44+
"flagtree_submodule_dir",
45+
os.path.join(self.flagtree_root_dir, "third_party"),
46+
)
47+
object.__setattr__(self, "activated_module", self._activate_device_module(self.flagtree_backend))
48+
49+
def _activate_device_module(self, backend, suffix=".py"):
50+
backend = "default" if not backend else backend
51+
module_path = Path(os.path.dirname(__file__)) / backend
52+
module_path = str(module_path) + suffix
53+
spec = importlib.util.spec_from_file_location("module", module_path)
54+
module = importlib.util.module_from_spec(spec)
55+
try:
56+
spec.loader.exec_module(module)
57+
except Exception:
58+
pass
59+
return module
60+
61+
62+
flagtree_configs = FlagtreeConfigs()
1763

1864

1965
@dataclass
@@ -40,7 +86,8 @@ def dir_rollback(deep, base_path):
4086

4187

4288
def is_skip_cuda_toolkits():
43-
return flagtree_backend and (flagtree_backend not in use_cuda_toolkit)
89+
return flagtree_configs.flagtree_backend and (flagtree_configs.flagtree_backend
90+
not in flagtree_configs.use_cuda_toolkit_backends)
4491

4592

4693
def remove_triton_in_modules(model):
@@ -215,7 +262,7 @@ def is_offline_build(self) -> bool:
215262
return os.getenv("TRITON_OFFLINE_BUILD", "OFF") == "ON" or os.getenv("FLAGTREE_OFFLINE_BUILD_DIR")
216263

217264
def copy_to_flagtree_project(self, kargs):
218-
dst_path = os.path.join(flagtree_root_dir,
265+
dst_path = os.path.join(_get_flagtree_root(),
219266
kargs['dst_path']) if 'dst_path' in kargs and kargs['dst_path'] else None
220267
src_path = self.src
221268
if not dst_path:
@@ -264,7 +311,7 @@ def handle_triton_origin_toolkits(self):
264311
shutil.copytree(src_path, toolkit_cache_path, dirs_exist_ok=True)
265312
else:
266313
raise RuntimeError(
267-
f"\n\n \033[31m[ERROR]:\033[0m The {flagtree_backend} offline build dependency \033[93m{src_path}\033[0m does not exist.\n"
314+
f"\n\n \033[31m[ERROR]:\033[0m The {flagtree_configs.flagtree_backend} offline build dependency \033[93m{src_path}\033[0m does not exist.\n"
268315
)
269316

270317
def validate_offline_build_dir(self, path, required=False):
@@ -279,7 +326,7 @@ def validate_offline_build_deps(self, path, kargs, required=False):
279326
url = kargs.get('url', None)
280327
if (not path or not os.path.exists(path)) and required:
281328
raise RuntimeError(
282-
f"\n\n \033[31m[ERROR]:\033[0m The {flagtree_backend} offline build dependency \033[93m{path}\033[0m does not exist.\n"
329+
f"\n\n \033[31m[ERROR]:\033[0m The {flagtree_configs.flagtree_backend} offline build dependency \033[93m{path}\033[0m does not exist.\n"
283330
f" And you can download the dependency package from the \n \033[93m{url}\033[0m \n"
284331
f" then extract it to the \033[93m{self.offline_build_dir}\033[0m directory you specified !\033[0m\n\n")
285332

@@ -300,7 +347,7 @@ def single_build(self, *args, **kargs):
300347
self.copy_to_flagtree_project(kargs)
301348
self.handle_flagtree_hock(kargs)
302349
if is_skip_cuda_toolkits():
303-
print(f"[INFO] Skipping CUDA toolkits for {flagtree_backend} backend in offline build.")
350+
print(f"[INFO] Skipping CUDA toolkits for {flagtree_configs.flagtree_backend} backend in offline build.")
304351
else:
305352
self.handle_triton_origin_toolkits()
306353
return True

0 commit comments

Comments
 (0)