Skip to content

Commit f780e0d

Browse files
committed
add attr
1 parent ae40993 commit f780e0d

File tree

4 files changed

+36
-20
lines changed

4 files changed

+36
-20
lines changed

python/setup_tools/setup_helper.py

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
ext_sourcedir = "triton/_C/"
1818
flagtree_backend = os.getenv("FLAGTREE_BACKEND", "").lower()
1919
flagtree_plugin = os.getenv("FLAGTREE_PLUGIN", "").lower()
20+
offline_build = os.getenv("FLAGTREE_PLUGIN", "OFF")
2021
device_mapping = {"xpu": "xpu", "mthreads": "musa", "ascend": "ascend"}
2122
flagtree_backends = utils.flagtree_backends
2223
backend_utils = utils.activate(flagtree_backend)
@@ -76,26 +77,29 @@ def download_flagtree_third_party(name, condition, required=False, hock=None):
7677
return
7778
backend = None
7879
for _backend in flagtree_backends:
79-
if _backend.name is name:
80+
if _backend.name in name:
8081
backend = _backend
8182
break
8283
if backend is None:
8384
return backend
84-
third_party_base_dir = Path(os.path.dirname(os.path.dirname(__file__))) / "third_party"
85-
lib_path = Path(third_party_base_dir) / backend.name
86-
if not os.path.exists(lib_path):
87-
succ = git_clone(lib=backend, lib_path=lib_path)
85+
third_party_base_dir = Path(os.path.dirname(os.path.dirname(os.path.dirname(__file__)))) / "third_party"
86+
prelib_path = Path(third_party_base_dir) / name
87+
lib_path = Path(third_party_base_dir) / _backend.name
88+
if not os.path.exists(prelib_path) and not os.path.exists(lib_path):
89+
succ = git_clone(lib=backend, lib_path=prelib_path)
8890
if not succ and required:
8991
raise RuntimeError("Bad network ! ")
92+
if callable(hock):
93+
hock(third_party_base_dir=third_party_base_dir, backend=backend)
9094
else:
9195
print(f'Found third_party {backend.name} at {lib_path}\n')
92-
if callable(hock):
93-
hock(third_party_base_dir=third_party_base_dir, backend=backend)
9496

9597

9698
def post_install():
97-
98-
backend_utils.post_install()
99+
try:
100+
backend_utils.post_install()
101+
except Exception:
102+
pass
99103

100104

101105
class FlagTreeCache:
@@ -287,10 +291,10 @@ def unlink():
287291
def skip_package_dir(package):
288292
if 'backends' in package or 'profiler' in package:
289293
return True
290-
if flagtree_backend in ['cambricon']:
291-
if package not in ['triton', 'triton/_C']:
292-
return True
293-
return False
294+
try:
295+
return backend_utils.skip_package_dir(package)
296+
except Exception:
297+
return False
294298

295299
@staticmethod
296300
def get_package_dir(packages):
@@ -304,12 +308,10 @@ def get_package_dir(packages):
304308
pair = (package, f"{backend_triton_path}{package}")
305309
connection.append(pair)
306310
package_dict.update(connection)
307-
if flagtree_backend == "ascend":
308-
triton_patch_root_rel_dir = "../third_party/ascend/triton_patch/python/triton_patch"
309-
package_dict["triton/triton_patch"] = f"{triton_patch_root_rel_dir}"
310-
package_dict["triton/triton_patch/language"] = f"{triton_patch_root_rel_dir}/language"
311-
package_dict["triton/triton_patch/compiler"] = f"{triton_patch_root_rel_dir}/compiler"
312-
package_dict["triton/triton_patch/runtime"] = f"{triton_patch_root_rel_dir}/runtime"
311+
try:
312+
package_dict.update(backend_utils.get_package_dir())
313+
except Exception:
314+
pass
313315
return package_dict
314316

315317

python/setup_tools/utils/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class FlagTreeBackend:
1818
FlagTreeBackend(name="cambricon", url="https://github.com/Cambricon/triton-linalg.git",
1919
tag="00f51c2e48a943922f86f03d58e29f514def646d"),
2020
FlagTreeBackend(
21-
name="triton_ascend",
21+
name="ascend",
2222
url="https://gitee.com/ascend/triton-ascend.git",
2323
),
2424
)

python/setup_tools/utils/ascend.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,16 @@
33
from pathlib import Path
44

55

6+
def get_package_dir():
7+
package_dict = {}
8+
triton_patch_root_rel_dir = "../third_party/ascend/triton_patch/python/triton_patch"
9+
package_dict["triton/triton_patch"] = f"{triton_patch_root_rel_dir}"
10+
package_dict["triton/triton_patch/language"] = f"{triton_patch_root_rel_dir}/language"
11+
package_dict["triton/triton_patch/compiler"] = f"{triton_patch_root_rel_dir}/compiler"
12+
package_dict["triton/triton_patch/runtime"] = f"{triton_patch_root_rel_dir}/runtime"
13+
return package_dict
14+
15+
616
def insert_at_file_start(filepath, import_lines):
717
import tempfile
818
try:
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
def skip_package_dir(package):
2+
if package not in ['triton', 'triton/_C']:
3+
return True
4+
return False

0 commit comments

Comments
 (0)