1010import hashlib
1111from dataclasses import dataclass
1212
13- use_triton_shared = True
14- necessary_third_party = ["triton_shared " ]
13+ use_triton_shared = False
14+ necessary_third_party = ["flir " ]
1515default_backends = ["nvidia" , "amd" ]
1616extend_backends = []
1717ext_sourcedir = "triton/_C/"
@@ -27,6 +27,9 @@ class FlagTreeBackend:
2727
2828
2929flagtree_backend_info = {
30+ "flir" :
31+ FlagTreeBackend (name = "flir" , url = "git@github.com:FlagTree/flir.git" ,
32+ tag = "e72b83ba46a5a9dd6466c7102f93fd600cde909e" ),
3033 "triton_shared" :
3134 FlagTreeBackend (name = "triton_shared" , url = "https://github.com/microsoft/triton-shared.git" ,
3235 tag = "5842469a16b261e45a2c67fbfc308057622b03ee" ),
@@ -274,13 +277,14 @@ def git_clone(lib, lib_path):
274277
275278 print (f"Unable to clone third_party { lib .name } " )
276279 if lib .name in necessary_third_party :
277- use_triton_shared = False
278- print ("\n \t triton_shared is compiled by default, but for "
280+ use_triton_shared = False # TODO
281+ print (f "\n \t { lib . name } is compiled by default, but for "
279282 "some reason we couldn't download triton_shared\n "
280283 "as third_party (most likely for network reasons), "
281284 "so we couldn't compile triton_shared\n " )
282285
283286 third_partys = []
287+ third_partys .append (flagtree_backend_info ["flir" ])
284288 if os .environ .get ("USE_TRITON_SHARED" , "ON" ) == "ON" :
285289 third_partys .append (flagtree_backend_info ["triton_shared" ])
286290 else :
@@ -303,6 +307,7 @@ def handle_flagtree_backend():
303307 extend_backends .append (flagtree_backend )
304308 if "editable_wheel" in sys .argv and flagtree_backend != "aipu" :
305309 ext_sourcedir = os .path .abspath (f"../third_party/{ flagtree_backend } /python/{ ext_sourcedir } " ) + "/"
310+ default_backends .append ("flir" )
306311 if use_triton_shared :
307312 default_backends .append ("triton_shared" )
308313
0 commit comments