88import urllib .request
99from pathlib import Path
1010import hashlib
11- from dataclasses import dataclass
1211from distutils .sysconfig import get_python_lib
12+ from . import utils
1313
14- use_triton_shared = False
15- necessary_third_party = ["triton_shared" ]
16- default_backends = ["nvidia" , "amd" ]
1714extend_backends = []
15+ default_backends = ["nvidia" , "amd" ]
1816plugin_backends = ["cambricon" , "ascend" ]
1917ext_sourcedir = "triton/_C/"
2018flagtree_backend = os .getenv ("FLAGTREE_BACKEND" , "" ).lower ()
2119flagtree_plugin = os .getenv ("FLAGTREE_PLUGIN" , "" ).lower ()
2220device_mapping = {"xpu" : "xpu" , "mthreads" : "musa" , "ascend" : "ascend" }
23-
24-
25- @dataclass
26- class FlagTreeBackend :
27- name : str
28- url : str
29- tag : str
30-
31-
32- flagtree_backend_info = {
33- "triton_shared" :
34- FlagTreeBackend (name = "triton_shared" , url = "https://github.com/microsoft/triton-shared.git" ,
35- tag = "380b87122c88af131530903a702d5318ec59bb33" ),
36- "cambricon" :
37- FlagTreeBackend (name = "cambricon" , url = "https://github.com/Cambricon/triton-linalg.git" ,
38- tag = "00f51c2e48a943922f86f03d58e29f514def646d" ),
39- }
21+ flagtree_backends = utils .flagtree_backends
22+ backend_utils = utils .activate (flagtree_backend )
4023
4124set_llvm_env = lambda path : set_env ({
4225 'LLVM_INCLUDE_DIRS' : Path (path ) / "include" ,
@@ -51,43 +34,68 @@ def get_device_name():
5134
5235def get_extra_packages ():
5336 packages = []
54- if flagtree_backend == 'ascend' :
55- packages = [
56- "triton/triton_patch" ,
57- "triton/triton_patch/language" ,
58- "triton/triton_patch/compiler" ,
59- "triton/triton_patch/runtime" ,
60- ]
37+ try :
38+ packages = backend_utils .get_extra_install_packages ()
39+ except Exception :
40+ packages = []
6141 return packages
6242
6343
6444def get_package_data_tools ():
6545 package_data = ["compile.h" , "compile.c" ]
66- if flagtree_backend == 'xpu' :
67- package_data += ["compile_xpu.h" , "compile_xpu.c" ]
46+ try :
47+ package_data += backend_utils .get_package_data_tools ()
48+ except Exception :
49+ package_data
6850 return package_data
6951
7052
71- def post_install (self ):
72-
73- def get_module (module_path ):
74- import importlib .util
75- import os
76- module_path = os .path .abspath (module_path )
77- spec = importlib .util .spec_from_file_location ("module" , module_path )
78- module = importlib .util .module_from_spec (spec )
79- spec .loader .exec_module (module )
80- return module
81-
82- def ascend ():
83- utils = get_module ("../third_party/ascend/utils.py" )
84- utils .post_install ()
85-
86- code = f"{ flagtree_backend } ()"
87- try :
88- exec (code , globals (), locals ())
89- except : #noqa: E722
90- pass
53+ def git_clone (lib , lib_path ):
54+ import git
55+ MAX_RETRY = 4
56+ print (f"Clone { lib .name } into { lib_path } ..." )
57+ retry_count = MAX_RETRY
58+ while (retry_count ):
59+ try :
60+ repo = git .Repo .clone_from (lib .url , lib_path )
61+ if lib .tag is not None :
62+ repo .git .checkout (lib .tag )
63+ sub_triton_path = Path (lib_path ) / "triton"
64+ if os .path .exists (sub_triton_path ):
65+ shutil .rmtree (sub_triton_path )
66+ print (f"successfully clone { lib .name } into { lib_path } ..." )
67+ return True
68+ except Exception :
69+ retry_count -= 1
70+ print (f"\n [{ MAX_RETRY - retry_count } ] retry to clone { lib .name } to { lib_path } " )
71+ return False
72+
73+
74+ def download_flagtree_third_party (name , condition , required = False , hock = None ):
75+ if not condition :
76+ return
77+ backend = None
78+ for _backend in flagtree_backends :
79+ if _backend .name is name :
80+ backend = _backend
81+ break
82+ if backend is None :
83+ return backend
84+ third_party_base_dir = Path (os .path .dirname (os .path .dirname (__file__ ))) / "third_party"
85+ lib_path = Path (third_party_base_dir ) / backend .name
86+ if not os .path .exists (lib_path ):
87+ succ = git_clone (lib = backend , lib_path = lib_path )
88+ if not succ and required :
89+ raise RuntimeError ("Bad network ! " )
90+ else :
91+ print (f'Found third_party { backend .name } at { lib_path } \n ' )
92+ if callable (hock ):
93+ hock (third_party_base_dir = third_party_base_dir , backend = backend )
94+
95+
96+ def post_install ():
97+
98+ backend_utils .post_install ()
9199
92100
93101class FlagTreeCache :
@@ -304,54 +312,6 @@ def get_package_dir(packages):
304312 package_dict ["triton/triton_patch/runtime" ] = f"{ triton_patch_root_rel_dir } /runtime"
305313 return package_dict
306314
307- @staticmethod
308- def download_third_party ():
309- import git
310- MAX_RETRY = 4
311- global use_triton_shared , flagtree_backend
312- third_party_base_dir = Path (os .path .dirname (os .path .dirname (__file__ ))) / "third_party"
313-
314- def git_clone (lib , lib_path ):
315- global use_triton_shared
316- print (f"Clone { lib .name } into { lib_path } ..." )
317- retry_count = MAX_RETRY
318- while (retry_count ):
319- try :
320- repo = git .Repo .clone_from (lib .url , lib_path )
321- repo .git .checkout (lib .tag )
322- if lib .name in flagtree_backend_info :
323- sub_triton_path = Path (lib_path ) / "triton"
324- if os .path .exists (sub_triton_path ):
325- shutil .rmtree (sub_triton_path )
326- print (f"successfully clone { lib .name } into { lib_path } ..." )
327- return
328- except Exception :
329- retry_count -= 1
330- print (f"\n [{ MAX_RETRY - retry_count } ] retry to clone { lib .name } to { lib_path } " )
331-
332- print (f"Unable to clone third_party { lib .name } " )
333- if lib .name in necessary_third_party :
334- use_triton_shared = False
335- print ("\n \t triton_shared is compiled by default, but for "
336- "some reason we couldn't download triton_shared\n "
337- "as third_party (most likely for network reasons), "
338- "so we couldn't compile triton_shared\n " )
339-
340- third_partys = []
341- if os .environ .get ("USE_TRITON_SHARED" , "ON" ) == "ON" and not flagtree_backend :
342- third_partys .append (flagtree_backend_info ["triton_shared" ])
343- else :
344- use_triton_shared = False
345- if flagtree_backend in flagtree_backend_info :
346- third_partys .append (flagtree_backend_info [flagtree_backend ])
347-
348- for lib in third_partys :
349- lib_path = Path (third_party_base_dir ) / lib .name
350- if not os .path .exists (lib_path ):
351- git_clone (lib = lib , lib_path = lib_path )
352- else :
353- print (f'Found third_party { lib .name } at { lib_path } \n ' )
354-
355315
356316def handle_flagtree_backend ():
357317 global ext_sourcedir
@@ -360,8 +320,6 @@ def handle_flagtree_backend():
360320 extend_backends .append (flagtree_backend )
361321 if "editable_wheel" in sys .argv and flagtree_backend != "ascend" :
362322 ext_sourcedir = os .path .abspath (f"../third_party/{ flagtree_backend } /python/{ ext_sourcedir } " ) + "/"
363- if use_triton_shared and not flagtree_backend :
364- default_backends .append ("triton_shared" )
365323
366324
367325def set_env (env_dict : dict ):
@@ -373,8 +331,15 @@ def check_env(env_val):
373331 return os .environ .get (env_val , '' ) != ''
374332
375333
376- CommonUtils .download_third_party ()
334+ download_flagtree_third_party ("triton_shared" , condition = (not flagtree_backend ))
335+
336+ download_flagtree_third_party ("triton_ascend" , condition = (flagtree_backend == "ascend" ),
337+ hock = utils .ascend .precompile_hock , required = True )
338+
339+ download_flagtree_third_party ("cambricon" , condition = (flagtree_backend == "cambricon" ), required = True )
340+
377341handle_flagtree_backend ()
342+
378343cache = FlagTreeCache ()
379344
380345# iluvatar
0 commit comments