Skip to content

Commit f3eb3cd

Browse files
committed
polish_code
1 parent a1ba666 commit f3eb3cd

6 files changed

Lines changed: 61 additions & 64 deletions

File tree

python/setup_tools/setup_helper.py

Lines changed: 46 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,9 @@
2020

2121

2222
def install_extension(*args, **kargs):
23-
try:
24-
configs.activated_module.install_extension(*args, **kargs)
25-
except Exception:
26-
pass
23+
backend_spec_install_extension_fn = get_hook_instance("install_extension")
24+
if backend_spec_install_extension_fn:
25+
backend_spec_install_extension_fn(*args, **kargs)
2726

2827

2928
def get_backend_cmake_args(*args, **kargs):
@@ -62,30 +61,41 @@ def dir_rollback(deep, base_path):
6261
return Path(base_path)
6362

6463

64+
def get_hook_instance(hook_name):
65+
if not hook_name or not configs.activated_module:
66+
return None
67+
hook_instance = getattr(configs.activated_module, hook_name, None)
68+
return hook_instance if callable(hook_instance) else None
69+
70+
71+
def print_info(message):
72+
print(f"\033[1;32m[INFO] {message}\033[0m")
73+
74+
6575
def enable_flagtree_third_party(name):
6676
if name in ["triton_shared"]:
6777
return os.environ.get(f"USE_{name.upper()}", 'OFF') == 'ON'
6878
else:
6979
return os.environ.get(f"USE_{name.upper()}", 'ON') == 'ON'
7080

7181

72-
def download_flagtree_third_party(name, condition, required=False, hock=None):
82+
def download_flagtree_third_party(name, condition, required=False, hook=None):
7383
if condition:
7484
if enable_flagtree_third_party(name):
7585
submodule = utils.flagtree_submodules[name]
7686
downloader.download(module=submodule, required=required)
77-
if callable(hock):
78-
configs.default_backends = hock(third_party_base_dir=configs.flagtree_submodule_dir, backend=submodule,
79-
default_backends=configs.default_backends)
87+
hook_func = get_hook_instance(hook)
88+
if hook_func:
89+
configs.default_backends = hook_func(third_party_base_dir=configs.flagtree_submodule_dir,
90+
backend=submodule, default_backends=configs.default_backends)
8091
else:
81-
print(f"\033[1;33m[Note] Skip downloading {name} since USE_{name.upper()} is set to OFF\033[0m")
92+
print_info(f"Skip downloading {name} since USE_{name.upper()} is set to OFF")
8293

8394

8495
def post_install():
85-
try:
86-
configs.activated_module.post_install()
87-
except Exception:
88-
pass
96+
backend_spec_post_install_fn = get_hook_instance("post_install")
97+
if backend_spec_post_install_fn:
98+
backend_spec_post_install_fn()
8999

90100

91101
class FlagTreeCache:
@@ -169,9 +179,9 @@ def reverse_copy(self, src_path, cache_file_path, md5_digest):
169179
return False
170180

171181
def store(self, file=None, condition=None, url=None, copy_src_path=None, copy_dst_path=None, files=None,
172-
md5_digest=None, pre_hock=None, post_hock=None):
182+
md5_digest=None, pre_hook=None, post_hook=None):
173183

174-
if not condition or (pre_hock and pre_hock()):
184+
if not condition or (pre_hook and pre_hook()):
175185
return
176186
is_url = False if url is None else True
177187
path = self.sub_dirs[flagtree_backend] if flagtree_backend else self.dir_path
@@ -209,7 +219,7 @@ def store(self, file=None, condition=None, url=None, copy_src_path=None, copy_ds
209219
shutil.copytree(src_path, dst_path, dirs_exist_ok=True)
210220
else:
211221
shutil.copy(src_path, dst_path)
212-
post_hock(self.cache_files[file]) if post_hock else False
222+
post_hook(self.cache_files[file]) if post_hook else False
213223

214224
def get(self, file_name) -> Path:
215225
return self.cache_files[file_name]
@@ -312,13 +322,9 @@ def uninstall_triton():
312322
print('[INFO] FlagTree Offline Build: No offline build for triton origin toolkits')
313323
offline_build = False
314324

315-
download_flagtree_third_party("triton_shared", hock=utils.default.precompile_hock, condition=(not flagtree_backend))
316-
317-
download_flagtree_third_party("flir", condition=(flagtree_backend == "aipu"), hock=utils.aipu.precompile_hock,
318-
required=True)
325+
download_flagtree_third_party("triton_shared", hook="precompile_hook", condition=(not flagtree_backend))
319326

320-
download_flagtree_third_party("flir", condition=(flagtree_backend == "tsingmicro"),
321-
hock=utils.tsingmicro.precompile_hock, required=True)
327+
download_flagtree_third_party("flir", condition=(flagtree_backend in configs.use_filr), required=True)
322328

323329
handle_flagtree_backend()
324330

@@ -329,8 +335,8 @@ def uninstall_triton():
329335
file="iluvatar-llvm18-x86_64",
330336
condition=("iluvatar" == flagtree_backend),
331337
url="https://baai-cp-web.ks3-cn-beijing.ksyuncs.com/trans/iluvatar-llvm18-x86_64_v0.3.0.tar.gz",
332-
pre_hock=lambda: check_env('LLVM_SYSPATH'),
333-
post_hock=set_llvm_env,
338+
pre_hook=lambda: check_env('LLVM_SYSPATH'),
339+
post_hook=set_llvm_env,
334340
)
335341

336342
cache.store(
@@ -343,8 +349,8 @@ def uninstall_triton():
343349
file="XTDK-llvm18-ubuntu2004_x86_64",
344350
condition=("xpu" == flagtree_backend),
345351
url="https://baai-cp-web.ks3-cn-beijing.ksyuncs.com/trans/XTDK-llvm19-ubuntu2004_x86_64_v0.3.0.tar.gz",
346-
pre_hock=lambda: check_env('LLVM_SYSPATH'),
347-
post_hock=set_llvm_env,
352+
pre_hook=lambda: check_env('LLVM_SYSPATH'),
353+
post_hook=set_llvm_env,
348354
)
349355

350356
cache.store(file="xre-Linux-x86_64", condition=("xpu" == flagtree_backend),
@@ -368,8 +374,8 @@ def uninstall_triton():
368374
file="mthreads-llvm19-glibc2.34-glibcxx3.4.30-x64",
369375
condition=("mthreads" == flagtree_backend),
370376
url="https://baai-cp-web.ks3-cn-beijing.ksyuncs.com/trans/mthreads-llvm19-glibc2.34-glibcxx3.4.30-x64_v0.1.0.tar.gz",
371-
pre_hock=lambda: check_env('LLVM_SYSPATH'),
372-
post_hock=set_llvm_env,
377+
pre_hook=lambda: check_env('LLVM_SYSPATH'),
378+
post_hook=set_llvm_env,
373379
)
374380

375381
cache.store(
@@ -382,26 +388,26 @@ def uninstall_triton():
382388
file="llvm-b5cc222d-ubuntu-arm64",
383389
condition=("ascend" == flagtree_backend),
384390
url="https://oaitriton.blob.core.windows.net/public/llvm-builds/llvm-b5cc222d-ubuntu-arm64.tar.gz",
385-
pre_hock=lambda: check_env('LLVM_SYSPATH'),
386-
post_hock=set_llvm_env,
391+
pre_hook=lambda: check_env('LLVM_SYSPATH'),
392+
post_hook=set_llvm_env,
387393
)
388394

389395
# aipu
390396
cache.store(
391397
file="llvm-a66376b0-ubuntu-x64-clang16-lld16",
392398
condition=("aipu" == flagtree_backend),
393399
url="https://baai-cp-web.ks3-cn-beijing.ksyuncs.com/trans/llvm-a66376b0-ubuntu-x64-clang16-lld16_v0.4.0.tar.gz",
394-
pre_hock=lambda: check_env('LLVM_SYSPATH'),
395-
post_hock=set_llvm_env,
400+
pre_hook=lambda: check_env('LLVM_SYSPATH'),
401+
post_hook=set_llvm_env,
396402
)
397403

398404
# enflame
399405
cache.store(
400406
file="llvm-d752c5b-gcc9-x64",
401407
condition=("enflame" == flagtree_backend),
402408
url="https://baai-cp-web.ks3-cn-beijing.ksyuncs.com/trans/enflame-llvm21-d752c5b-gcc9-x64_v0.3.0.tar.gz",
403-
pre_hock=lambda: check_env('KURAMA_LLVM_DIR_GCU300'),
404-
post_hock=lambda path: set_env({
409+
pre_hook=lambda: check_env('KURAMA_LLVM_DIR_GCU300'),
410+
post_hook=lambda path: set_env({
405411
'KURAMA_LLVM_DIR_GCU300': path,
406412
'LLVM_INCLUDE_DIRS': Path(path) / "include",
407413
'LLVM_LIBRARY_DIR': Path(path) / "lib",
@@ -415,16 +421,16 @@ def uninstall_triton():
415421
condition=("tsingmicro" == flagtree_backend),
416422
url=
417423
"https://baai-cp-web.ks3-cn-beijing.ksyuncs.com/trans/tsingmicro-llvm21-glibc2.30-glibcxx3.4.28-python3.10-x64_v0.4.0.tar.gz",
418-
pre_hock=lambda: check_env('LLVM_SYSPATH'),
419-
post_hock=set_llvm_env,
424+
pre_hook=lambda: check_env('LLVM_SYSPATH'),
425+
post_hook=set_llvm_env,
420426
)
421427

422428
cache.store(
423429
file="tx8_deps",
424430
condition=("tsingmicro" == flagtree_backend),
425431
url="https://baai-cp-web.ks3-cn-beijing.ksyuncs.com/trans/tx8_depends_dev_20251218_164108_v0.4.0.tar.gz",
426-
pre_hock=lambda: check_env('TX8_DEPS_ROOT'),
427-
post_hock=lambda path: set_env({
432+
pre_hook=lambda: check_env('TX8_DEPS_ROOT'),
433+
post_hook=lambda path: set_env({
428434
'TX8_DEPS_ROOT': path,
429435
}),
430436
)
@@ -435,6 +441,6 @@ def uninstall_triton():
435441
condition=("hcu" == flagtree_backend),
436442
url=
437443
"https://baai-cp-web.ks3-cn-beijing.ksyuncs.com/trans/hcu-llvm20-df0864e-glibc2.35-glibcxx3.4.30-ubuntu-x86_64_v0.3.0.tar.gz",
438-
pre_hock=lambda: check_env('LLVM_SYSPATH'),
439-
post_hock=set_llvm_env,
444+
pre_hook=lambda: check_env('LLVM_SYSPATH'),
445+
post_hook=set_llvm_env,
440446
)

python/setup_tools/utils/aipu.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +0,0 @@
1-
def precompile_hock(*args, **kargs):
2-
default_backends = kargs["default_backends"]
3-
default_backends_list = list(default_backends)
4-
default_backends_list.append('flir')
5-
kargs["default_backends"] = tuple(default_backends_list)
6-
default_backends = tuple(default_backends_list)
7-
return default_backends

python/setup_tools/utils/ascend.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def get_extra_install_packages():
190190
]
191191

192192

193-
def precompile_hock(*args, **kargs):
193+
def precompile_hook(*args, **kargs):
194194
third_party_base_dir = Path(kargs['third_party_base_dir'])
195195
ascend_path = Path(third_party_base_dir) / "ascend"
196196
patch_path = Path(ascend_path) / "triton_patch"
Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,7 @@
1-
def precompile_hock(*args, **kargs):
2-
default_backends = kargs['default_backends']
3-
default_backends.append('triton_shared')
1+
def precompile_hook(*args, **kargs):
2+
default_backends = kargs["default_backends"]
3+
default_backends_list = list(default_backends)
4+
default_backends_list.append('triton_shared')
5+
kargs["default_backends"] = tuple(default_backends_list)
6+
default_backends = tuple(default_backends_list)
7+
return default_backends

python/setup_tools/utils/tools.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ def _get_flagtree_root() -> str:
2020

2121
@dataclass
2222
class FlagtreeConfigs:
23+
use_filr: tuple = ("aipu", "tsingmicro")
2324
default_backends: tuple = ("nvidia", "amd")
2425
plugin_backends: tuple = ("cambricon", "ascend", "aipu", "tsingmicro", "enflame")
2526
use_cuda_toolkit_backends: tuple = ('aipu', )
@@ -45,6 +46,8 @@ def __post_init__(self):
4546
os.path.join(self.flagtree_root_dir, "third_party"),
4647
)
4748
object.__setattr__(self, "activated_module", self._activate_device_module(self.flagtree_backend))
49+
if self.flagtree_backend in self.use_filr:
50+
self.default_backends = self.default_backends + ("flir", )
4851

4952
def _activate_device_module(self, backend, suffix=".py"):
5053
backend = "default" if not backend else backend
@@ -274,9 +277,9 @@ def copy_to_flagtree_project(self, kargs):
274277
else:
275278
shutil.copy(src_path, dst_path)
276279

277-
def handle_flagtree_hock(self, kargs):
278-
if 'post_hock' in kargs and kargs['post_hock']:
279-
kargs['post_hock'](self.src)
280+
def handle_flagtree_hook(self, kargs):
281+
if 'post_hook' in kargs and kargs['post_hook']:
282+
kargs['post_hook'](self.src)
280283

281284
def handle_triton_origin_toolkits(self):
282285

@@ -345,7 +348,7 @@ def single_build(self, *args, **kargs):
345348
self.validate_offline_build(self.src, required, kargs=kargs)
346349
print(f"[INFO] Building in offline mode using directory: {self.src}")
347350
self.copy_to_flagtree_project(kargs)
348-
self.handle_flagtree_hock(kargs)
351+
self.handle_flagtree_hook(kargs)
349352
if is_skip_cuda_toolkits():
350353
print(f"[INFO] Skipping CUDA toolkits for {flagtree_configs.flagtree_backend} backend in offline build.")
351354
else:

python/setup_tools/utils/tsingmicro.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,6 @@
11
import os
22

33

4-
def precompile_hock(*args, **kargs):
5-
default_backends = kargs["default_backends"]
6-
default_backends_list = list(default_backends)
7-
default_backends_list.append('flir')
8-
kargs["default_backends"] = tuple(default_backends_list)
9-
default_backends = tuple(default_backends_list)
10-
return default_backends
11-
12-
134
def get_backend_cmake_args(*args, **kargs):
145
build_ext = kargs['build_ext']
156
src_ext_path = build_ext.get_ext_fullpath("triton")

0 commit comments

Comments
 (0)