Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -607,13 +607,13 @@ def build_extension(self, ext):
if helper.flagtree_backend:
if helper.flagtree_backend in ("aipu", "tsingmicro", "enflame"):
backends = [
*BackendInstaller.copy(helper.default_backends + helper.extend_backends),
*BackendInstaller.copy(helper.configs.default_backends + tuple(helper.configs.extend_backends)),
*BackendInstaller.copy_externals(),
]
else:
backends = [*BackendInstaller.copy(helper.extend_backends), *BackendInstaller.copy_externals()]
backends = [*BackendInstaller.copy(helper.configs.extend_backends), *BackendInstaller.copy_externals()]
else:
backends = [*BackendInstaller.copy(helper.default_backends), *BackendInstaller.copy_externals()]
backends = [*BackendInstaller.copy(helper.configs.default_backends), *BackendInstaller.copy_externals()]


def add_link_to_backends():
Expand Down Expand Up @@ -799,7 +799,7 @@ def get_git_version_suffix():
entry_points=get_entry_points(),
package_data=package_data,
include_package_data=True,
ext_modules=[CMakeExtension("triton", helper.ext_sourcedir)],
ext_modules=[CMakeExtension("triton", helper.configs.ext_sourcedir)],
cmdclass={
"build_ext": CMakeBuild,
"build_py": CMakeBuildPy,
Expand Down
42 changes: 18 additions & 24 deletions python/setup_tools/setup_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,11 @@
import hashlib
from distutils.sysconfig import get_python_lib
from . import utils
from .utils.tools import flagtree_configs as configs

extend_backends = []
default_backends = ["nvidia", "amd"]
plugin_backends = ["cambricon", "ascend", "aipu", "tsingmicro", "enflame"]
ext_sourcedir = "triton/_C/"
flagtree_backend = os.getenv("FLAGTREE_BACKEND", "").lower()
flagtree_plugin = os.getenv("FLAGTREE_PLUGIN", "").lower()
offline_build = os.getenv("FLAGTREE_PLUGIN", "OFF")
device_mapping = {"xpu": "xpu", "mthreads": "musa", "ascend": "ascend"}
activated_module = utils.activate(flagtree_backend)
downloader = utils.tools.DownloadManager()
configs = configs
flagtree_backend = configs.flagtree_backend

set_llvm_env = lambda path: set_env({
'LLVM_INCLUDE_DIRS': Path(path) / "include",
Expand All @@ -27,26 +21,26 @@

def install_extension(*args, **kargs):
try:
activated_module.install_extension(*args, **kargs)
configs.activated_module.install_extension(*args, **kargs)
except Exception:
pass


def get_backend_cmake_args(*args, **kargs):
try:
return activated_module.get_backend_cmake_args(*args, **kargs)
return configs.activated_module.get_backend_cmake_args(*args, **kargs)
except Exception:
return []


def get_device_name():
return device_mapping[flagtree_backend]
return configs.device_alias[flagtree_backend]


def get_extra_packages():
packages = []
try:
packages = activated_module.get_extra_install_packages()
packages = configs.activated_module.get_extra_install_packages()
except Exception:
packages = []
return packages
Expand All @@ -55,7 +49,7 @@ def get_extra_packages():
def get_package_data_tools():
package_data = ["compile.h", "compile.c"]
try:
package_data += activated_module.get_package_data_tools()
package_data += configs.activated_module.get_package_data_tools()
except Exception:
package_data
return package_data
Expand All @@ -81,15 +75,15 @@ def download_flagtree_third_party(name, condition, required=False, hock=None):
submodule = utils.flagtree_submodules[name]
downloader.download(module=submodule, required=required)
if callable(hock):
hock(third_party_base_dir=utils.flagtree_submodule_dir, backend=submodule,
default_backends=default_backends)
configs.default_backends = hock(third_party_base_dir=configs.flagtree_submodule_dir, backend=submodule,
default_backends=configs.default_backends)
else:
print(f"\033[1;33m[Note] Skip downloading {name} since USE_{name.upper()} is set to OFF\033[0m")


def post_install():
try:
activated_module.post_install()
configs.activated_module.post_install()
except Exception:
pass

Expand Down Expand Up @@ -250,14 +244,14 @@ def skip_package_dir(package):
if 'backends' in package or 'profiler' in package:
return True
try:
return activated_module.skip_package_dir(package)
return configs.activated_module.skip_package_dir(package)
except Exception:
return False

@staticmethod
def get_package_dir(packages):
package_dict = {}
if flagtree_backend and flagtree_backend not in plugin_backends:
if flagtree_backend and flagtree_backend not in configs.plugin_backends:
connection = []
backend_triton_path = f"../third_party/{flagtree_backend}/python/"
for package in packages:
Expand All @@ -267,7 +261,7 @@ def get_package_dir(packages):
connection.append(pair)
package_dict.update(connection)
try:
package_dict.update(activated_module.get_package_dir())
package_dict.update(configs.activated_module.get_package_dir())
except Exception:
pass
return package_dict
Expand All @@ -277,8 +271,8 @@ def handle_flagtree_backend():
global ext_sourcedir
if flagtree_backend:
print(f"\033[1;32m[INFO] FlagtreeBackend is {flagtree_backend}\033[0m")
extend_backends.append(flagtree_backend)
if "editable_wheel" in sys.argv and flagtree_backend not in plugin_backends:
configs.extend_backends.append(flagtree_backend)
if "editable_wheel" in sys.argv and flagtree_backend not in configs.plugin_backends:
ext_sourcedir = os.path.abspath(f"../third_party/{flagtree_backend}/python/{ext_sourcedir}") + "/"


Expand Down Expand Up @@ -337,7 +331,7 @@ def uninstall_triton():
)

cache.store(
file="iluvatarTritonPlugin.so", condition=("iluvatar" == flagtree_backend) and (not flagtree_plugin), url=
file="iluvatarTritonPlugin.so", condition=("iluvatar" == flagtree_backend) and (not configs.flagtree_plugin), url=
"https://baai-cp-web.ks3-cn-beijing.ksyuncs.com/trans/iluvatarTritonPlugin-cpython3.10-glibc2.30-glibcxx3.4.28-cxxabi1.3.12-ubuntu-x86_64_v0.3.0.tar.gz",
copy_dst_path=f"third_party/{flagtree_backend}", md5_digest="015b9af8")

Expand Down Expand Up @@ -376,7 +370,7 @@ def uninstall_triton():
)

cache.store(
file="mthreadsTritonPlugin.so", condition=("mthreads" == flagtree_backend) and (not flagtree_plugin), url=
file="mthreadsTritonPlugin.so", condition=("mthreads" == flagtree_backend) and (not configs.flagtree_plugin), url=
"https://baai-cp-web.ks3-cn-beijing.ksyuncs.com/trans/mthreadsTritonPlugin-cpython3.10-glibc2.35-glibcxx3.4.30-cxxabi1.3.13-ubuntu-x86_64_v0.3.0.tar.gz",
copy_dst_path=f"third_party/{flagtree_backend}", md5_digest="2a9ca0b8")

Expand Down
6 changes: 3 additions & 3 deletions python/setup_tools/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@
import importlib.util
import os
from . import tools, default, aipu
from .tools import flagtree_submodule_dir, OfflineBuildManager
from .tools import OfflineBuildManager, flagtree_configs

flagtree_submodules = {
"triton_shared":
tools.Module(name="triton_shared", url="https://github.com/microsoft/triton-shared.git",
commit_id="5842469a16b261e45a2c67fbfc308057622b03ee",
dst_path=os.path.join(flagtree_submodule_dir, "triton_shared")),
dst_path=os.path.join(flagtree_configs.flagtree_submodule_dir, "triton_shared")),
"flir":
tools.Module(name="flir", url="https://github.com/FlagTree/flir.git",
dst_path=os.path.join(flagtree_submodule_dir, "flir")),
dst_path=os.path.join(flagtree_configs.flagtree_submodule_dir, "flir")),
}


Expand Down
6 changes: 5 additions & 1 deletion python/setup_tools/utils/aipu.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
def precompile_hock(*args, **kargs):
default_backends = kargs["default_backends"]
default_backends.append('flir')
default_backends_list = list(default_backends)
default_backends_list.append('flir')
kargs["default_backends"] = tuple(default_backends_list)
default_backends = tuple(default_backends_list)
return default_backends
67 changes: 57 additions & 10 deletions python/setup_tools/utils/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,61 @@
import zipfile
from io import BytesIO
import urllib.request
from dataclasses import dataclass
import json
from build_helpers import get_base_dir
import platform
from typing import Mapping
from types import MappingProxyType
import importlib.util
from dataclasses import dataclass, field

flagtree_root_dir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
flagtree_submodule_dir = os.path.join(flagtree_root_dir, "third_party")
flagtree_backend = os.environ.get("FLAGTREE_BACKEND")
use_cuda_toolkit = ["aipu"]

def _get_flagtree_root() -> str:
return str(Path(__file__).resolve().parents[3])


@dataclass
class FlagtreeConfigs:
default_backends: tuple = ("nvidia", "amd")
plugin_backends: tuple = ("cambricon", "ascend", "aipu", "tsingmicro", "enflame")
use_cuda_toolkit_backends: tuple = ('aipu', )
language_extra_backends: tuple = ('xpu', 'mthreads', "cambricon")
ext_sourcedir: str = "triton/_C/"
flagtree_root_dir: str = field(default_factory=_get_flagtree_root)
flagtree_backend: str = field(default_factory=lambda: os.environ.get("FLAGTREE_BACKEND"))
flagtree_plugin: str = field(default_factory=lambda: os.environ.get("FLAGTREE_PLUGIN"))
extend_backends: list = field(default_factory=list)
activated_module: any = None
flagtree_submodule_dir: str = ''
device_alias_map: Mapping[str, str] = field(default_factory=lambda: MappingProxyType({
"xpu": "xpu",
"mthreads": "musa",
"ascend": "ascend",
"cambricon": "mlu",
}))

def __post_init__(self):
object.__setattr__(
self,
"flagtree_submodule_dir",
os.path.join(self.flagtree_root_dir, "third_party"),
)
object.__setattr__(self, "activated_module", self._activate_device_module(self.flagtree_backend))

def _activate_device_module(self, backend, suffix=".py"):
backend = "default" if not backend else backend
module_path = Path(os.path.dirname(__file__)) / backend
module_path = str(module_path) + suffix
spec = importlib.util.spec_from_file_location("module", module_path)
module = importlib.util.module_from_spec(spec)
try:
spec.loader.exec_module(module)
except Exception:
pass
return module


flagtree_configs = FlagtreeConfigs()


@dataclass
Expand All @@ -40,7 +86,8 @@ def dir_rollback(deep, base_path):


def is_skip_cuda_toolkits():
return flagtree_backend and (flagtree_backend not in use_cuda_toolkit)
return flagtree_configs.flagtree_backend and (flagtree_configs.flagtree_backend
not in flagtree_configs.use_cuda_toolkit_backends)


def remove_triton_in_modules(model):
Expand Down Expand Up @@ -215,7 +262,7 @@ def is_offline_build(self) -> bool:
return os.getenv("TRITON_OFFLINE_BUILD", "OFF") == "ON" or os.getenv("FLAGTREE_OFFLINE_BUILD_DIR")

def copy_to_flagtree_project(self, kargs):
dst_path = os.path.join(flagtree_root_dir,
dst_path = os.path.join(_get_flagtree_root(),
kargs['dst_path']) if 'dst_path' in kargs and kargs['dst_path'] else None
src_path = self.src
if not dst_path:
Expand Down Expand Up @@ -264,7 +311,7 @@ def handle_triton_origin_toolkits(self):
shutil.copytree(src_path, toolkit_cache_path, dirs_exist_ok=True)
else:
raise RuntimeError(
f"\n\n \033[31m[ERROR]:\033[0m The {flagtree_backend} offline build dependency \033[93m{src_path}\033[0m does not exist.\n"
f"\n\n \033[31m[ERROR]:\033[0m The {flagtree_configs.flagtree_backend} offline build dependency \033[93m{src_path}\033[0m does not exist.\n"
)

def validate_offline_build_dir(self, path, required=False):
Expand All @@ -279,7 +326,7 @@ def validate_offline_build_deps(self, path, kargs, required=False):
url = kargs.get('url', None)
if (not path or not os.path.exists(path)) and required:
raise RuntimeError(
f"\n\n \033[31m[ERROR]:\033[0m The {flagtree_backend} offline build dependency \033[93m{path}\033[0m does not exist.\n"
f"\n\n \033[31m[ERROR]:\033[0m The {flagtree_configs.flagtree_backend} offline build dependency \033[93m{path}\033[0m does not exist.\n"
f" And you can download the dependency package from the \n \033[93m{url}\033[0m \n"
f" then extract it to the \033[93m{self.offline_build_dir}\033[0m directory you specified !\033[0m\n\n")

Expand All @@ -300,7 +347,7 @@ def single_build(self, *args, **kargs):
self.copy_to_flagtree_project(kargs)
self.handle_flagtree_hock(kargs)
if is_skip_cuda_toolkits():
print(f"[INFO] Skipping CUDA toolkits for {flagtree_backend} backend in offline build.")
print(f"[INFO] Skipping CUDA toolkits for {flagtree_configs.flagtree_backend} backend in offline build.")
else:
self.handle_triton_origin_toolkits()
return True