Skip to content

Commit ae40993

Browse files
committed
[init refactor] refactor setup
1 parent 97e653a commit ae40993

File tree

84 files changed

+172
-19599
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

84 files changed

+172
-19599
lines changed

.gitignore

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ third_party/cambricon/
2323
third_party/iluvatar/iluvatarTritonPlugin.so
2424
third_party/triton_shared/
2525
third_party/xpu/backend/xpu3
26+
third_party/ascend
2627

2728
# Proton
2829
python/triton/profiler
@@ -57,6 +58,7 @@ ptxas
5758
third_party/nvidia/backend/include
5859
third_party/nvidia/backend/lib/cupti
5960

61+
6062
# Docs
6163
docs/_build/
6264
docs/python-api/generated/

python/setup.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828

2929
import pybind11
3030

31-
import setup_helper as helper
31+
from setup_tools import setup_helper as helper
3232

3333

3434
@dataclass
@@ -611,31 +611,31 @@ class plugin_install(install):
611611
def run(self):
612612
add_links()
613613
install.run(self)
614-
helper.post_install(self)
614+
helper.post_install()
615615

616616

617617
class plugin_develop(develop):
618618

619619
def run(self):
620620
add_links()
621621
develop.run(self)
622-
helper.post_install(self)
622+
helper.post_install()
623623

624624

625625
class plugin_bdist_wheel(bdist_wheel):
626626

627627
def run(self):
628628
add_links()
629629
bdist_wheel.run(self)
630-
helper.post_install(self)
630+
helper.post_install()
631631

632632

633633
class plugin_egginfo(egg_info):
634634

635635
def run(self):
636636
add_links()
637637
egg_info.run(self)
638-
helper.post_install(self)
638+
helper.post_install()
639639

640640

641641
package_data_tools = helper.get_package_data_tools()

python/setup_tools/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
from . import setup_helper
2+
from . import utils
3+
4+
__all__ = ["setup_helper", "utils"]
Lines changed: 66 additions & 101 deletions
Original file line numberDiff line numberDiff line change
@@ -8,35 +8,18 @@
88
import urllib.request
99
from pathlib import Path
1010
import hashlib
11-
from dataclasses import dataclass
1211
from distutils.sysconfig import get_python_lib
12+
from . import utils
1313

14-
use_triton_shared = False
15-
necessary_third_party = ["triton_shared"]
16-
default_backends = ["nvidia", "amd"]
1714
extend_backends = []
15+
default_backends = ["nvidia", "amd"]
1816
plugin_backends = ["cambricon", "ascend"]
1917
ext_sourcedir = "triton/_C/"
2018
flagtree_backend = os.getenv("FLAGTREE_BACKEND", "").lower()
2119
flagtree_plugin = os.getenv("FLAGTREE_PLUGIN", "").lower()
2220
device_mapping = {"xpu": "xpu", "mthreads": "musa", "ascend": "ascend"}
23-
24-
25-
@dataclass
26-
class FlagTreeBackend:
27-
name: str
28-
url: str
29-
tag: str
30-
31-
32-
flagtree_backend_info = {
33-
"triton_shared":
34-
FlagTreeBackend(name="triton_shared", url="https://github.com/microsoft/triton-shared.git",
35-
tag="380b87122c88af131530903a702d5318ec59bb33"),
36-
"cambricon":
37-
FlagTreeBackend(name="cambricon", url="https://github.com/Cambricon/triton-linalg.git",
38-
tag="00f51c2e48a943922f86f03d58e29f514def646d"),
39-
}
21+
flagtree_backends = utils.flagtree_backends
22+
backend_utils = utils.activate(flagtree_backend)
4023

4124
set_llvm_env = lambda path: set_env({
4225
'LLVM_INCLUDE_DIRS': Path(path) / "include",
@@ -51,43 +34,68 @@ def get_device_name():
5134

5235
def get_extra_packages():
5336
packages = []
54-
if flagtree_backend == 'ascend':
55-
packages = [
56-
"triton/triton_patch",
57-
"triton/triton_patch/language",
58-
"triton/triton_patch/compiler",
59-
"triton/triton_patch/runtime",
60-
]
37+
try:
38+
packages = backend_utils.get_extra_install_packages()
39+
except Exception:
40+
packages = []
6141
return packages
6242

6343

6444
def get_package_data_tools():
6545
package_data = ["compile.h", "compile.c"]
66-
if flagtree_backend == 'xpu':
67-
package_data += ["compile_xpu.h", "compile_xpu.c"]
46+
try:
47+
package_data += backend_utils.get_package_data_tools()
48+
except Exception:
49+
package_data
6850
return package_data
6951

7052

71-
def post_install(self):
72-
73-
def get_module(module_path):
74-
import importlib.util
75-
import os
76-
module_path = os.path.abspath(module_path)
77-
spec = importlib.util.spec_from_file_location("module", module_path)
78-
module = importlib.util.module_from_spec(spec)
79-
spec.loader.exec_module(module)
80-
return module
81-
82-
def ascend():
83-
utils = get_module("../third_party/ascend/utils.py")
84-
utils.post_install()
85-
86-
code = f"{flagtree_backend}()"
87-
try:
88-
exec(code, globals(), locals())
89-
except: #noqa: E722
90-
pass
53+
def git_clone(lib, lib_path):
54+
import git
55+
MAX_RETRY = 4
56+
print(f"Clone {lib.name} into {lib_path} ...")
57+
retry_count = MAX_RETRY
58+
while (retry_count):
59+
try:
60+
repo = git.Repo.clone_from(lib.url, lib_path)
61+
if lib.tag is not None:
62+
repo.git.checkout(lib.tag)
63+
sub_triton_path = Path(lib_path) / "triton"
64+
if os.path.exists(sub_triton_path):
65+
shutil.rmtree(sub_triton_path)
66+
print(f"successfully clone {lib.name} into {lib_path} ...")
67+
return True
68+
except Exception:
69+
retry_count -= 1
70+
print(f"\n[{MAX_RETRY - retry_count}] retry to clone {lib.name} to {lib_path}")
71+
return False
72+
73+
74+
def download_flagtree_third_party(name, condition, required=False, hock=None):
75+
if not condition:
76+
return
77+
backend = None
78+
for _backend in flagtree_backends:
79+
if _backend.name is name:
80+
backend = _backend
81+
break
82+
if backend is None:
83+
return backend
84+
third_party_base_dir = Path(os.path.dirname(os.path.dirname(__file__))) / "third_party"
85+
lib_path = Path(third_party_base_dir) / backend.name
86+
if not os.path.exists(lib_path):
87+
succ = git_clone(lib=backend, lib_path=lib_path)
88+
if not succ and required:
89+
raise RuntimeError("Bad network ! ")
90+
else:
91+
print(f'Found third_party {backend.name} at {lib_path}\n')
92+
if callable(hock):
93+
hock(third_party_base_dir=third_party_base_dir, backend=backend)
94+
95+
96+
def post_install():
97+
98+
backend_utils.post_install()
9199

92100

93101
class FlagTreeCache:
@@ -304,54 +312,6 @@ def get_package_dir(packages):
304312
package_dict["triton/triton_patch/runtime"] = f"{triton_patch_root_rel_dir}/runtime"
305313
return package_dict
306314

307-
@staticmethod
308-
def download_third_party():
309-
import git
310-
MAX_RETRY = 4
311-
global use_triton_shared, flagtree_backend
312-
third_party_base_dir = Path(os.path.dirname(os.path.dirname(__file__))) / "third_party"
313-
314-
def git_clone(lib, lib_path):
315-
global use_triton_shared
316-
print(f"Clone {lib.name} into {lib_path} ...")
317-
retry_count = MAX_RETRY
318-
while (retry_count):
319-
try:
320-
repo = git.Repo.clone_from(lib.url, lib_path)
321-
repo.git.checkout(lib.tag)
322-
if lib.name in flagtree_backend_info:
323-
sub_triton_path = Path(lib_path) / "triton"
324-
if os.path.exists(sub_triton_path):
325-
shutil.rmtree(sub_triton_path)
326-
print(f"successfully clone {lib.name} into {lib_path} ...")
327-
return
328-
except Exception:
329-
retry_count -= 1
330-
print(f"\n[{MAX_RETRY - retry_count}] retry to clone {lib.name} to {lib_path}")
331-
332-
print(f"Unable to clone third_party {lib.name}")
333-
if lib.name in necessary_third_party:
334-
use_triton_shared = False
335-
print("\n\ttriton_shared is compiled by default, but for "
336-
"some reason we couldn't download triton_shared\n"
337-
"as third_party (most likely for network reasons), "
338-
"so we couldn't compile triton_shared\n")
339-
340-
third_partys = []
341-
if os.environ.get("USE_TRITON_SHARED", "ON") == "ON" and not flagtree_backend:
342-
third_partys.append(flagtree_backend_info["triton_shared"])
343-
else:
344-
use_triton_shared = False
345-
if flagtree_backend in flagtree_backend_info:
346-
third_partys.append(flagtree_backend_info[flagtree_backend])
347-
348-
for lib in third_partys:
349-
lib_path = Path(third_party_base_dir) / lib.name
350-
if not os.path.exists(lib_path):
351-
git_clone(lib=lib, lib_path=lib_path)
352-
else:
353-
print(f'Found third_party {lib.name} at {lib_path}\n')
354-
355315

356316
def handle_flagtree_backend():
357317
global ext_sourcedir
@@ -360,8 +320,6 @@ def handle_flagtree_backend():
360320
extend_backends.append(flagtree_backend)
361321
if "editable_wheel" in sys.argv and flagtree_backend != "ascend":
362322
ext_sourcedir = os.path.abspath(f"../third_party/{flagtree_backend}/python/{ext_sourcedir}") + "/"
363-
if use_triton_shared and not flagtree_backend:
364-
default_backends.append("triton_shared")
365323

366324

367325
def set_env(env_dict: dict):
@@ -373,8 +331,15 @@ def check_env(env_val):
373331
return os.environ.get(env_val, '') != ''
374332

375333

376-
CommonUtils.download_third_party()
334+
download_flagtree_third_party("triton_shared", condition=(not flagtree_backend))
335+
336+
download_flagtree_third_party("triton_ascend", condition=(flagtree_backend == "ascend"),
337+
hock=utils.ascend.precompile_hock, required=True)
338+
339+
download_flagtree_third_party("cambricon", condition=(flagtree_backend == "cambricon"), required=True)
340+
377341
handle_flagtree_backend()
342+
378343
cache = FlagTreeCache()
379344

380345
# iluvatar
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
from dataclasses import dataclass
2+
from pathlib import Path
3+
import importlib.util
4+
import os
5+
from . import ascend
6+
7+
8+
@dataclass
9+
class FlagTreeBackend:
10+
name: str
11+
url: str
12+
tag: str = None
13+
14+
15+
flagtree_backends = (
16+
FlagTreeBackend(name="triton_shared", url="https://github.com/microsoft/triton-shared.git",
17+
tag="380b87122c88af131530903a702d5318ec59bb33"),
18+
FlagTreeBackend(name="cambricon", url="https://github.com/Cambricon/triton-linalg.git",
19+
tag="00f51c2e48a943922f86f03d58e29f514def646d"),
20+
FlagTreeBackend(
21+
name="triton_ascend",
22+
url="https://gitee.com/ascend/triton-ascend.git",
23+
),
24+
)
25+
26+
27+
def activate(backend, suffix=".py"):
28+
module_path = Path(os.path.dirname(__file__)) / backend
29+
module_path = str(module_path) + suffix
30+
spec = importlib.util.spec_from_file_location("module", module_path)
31+
module = importlib.util.module_from_spec(spec)
32+
spec.loader.exec_module(module)
33+
return module
34+
35+
36+
__all__ = ["ascend"]

0 commit comments

Comments
 (0)