Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ascend-build-and-test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,4 @@ jobs:
shell: bash
run: |
source /usr/local/Ascend/ascend-toolkit/set_env.sh
python3.9 third_party/ascend/python/tutorials/01-vector-add.py
python3.9 third_party/ascend/examples/tutorials/01-vector-add.py
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ third_party/cambricon/
third_party/iluvatar/iluvatarTritonPlugin.so
third_party/triton_shared/
third_party/xpu/backend/xpu3
third_party/ascend

# Proton
python/triton/profiler
Expand Down Expand Up @@ -57,6 +58,7 @@ ptxas
third_party/nvidia/backend/include
third_party/nvidia/backend/lib/cupti


Comment thread
zhzhcookie marked this conversation as resolved.
Outdated
# Docs
docs/_build/
docs/python-api/generated/
Expand Down
10 changes: 5 additions & 5 deletions python/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

import pybind11

import setup_helper as helper
from setup_tools import setup_helper as helper


@dataclass
Expand Down Expand Up @@ -611,31 +611,31 @@ class plugin_install(install):
def run(self):
add_links()
install.run(self)
helper.post_install(self)
helper.post_install()


class plugin_develop(develop):

def run(self):
add_links()
develop.run(self)
helper.post_install(self)
helper.post_install()


class plugin_bdist_wheel(bdist_wheel):

def run(self):
add_links()
bdist_wheel.run(self)
helper.post_install(self)
helper.post_install()


class plugin_egginfo(egg_info):

def run(self):
add_links()
egg_info.run(self)
helper.post_install(self)
helper.post_install()


package_data_tools = helper.get_package_data_tools()
Expand Down
4 changes: 4 additions & 0 deletions python/setup_tools/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from . import setup_helper
from . import utils

__all__ = ["setup_helper", "utils"]
193 changes: 84 additions & 109 deletions python/setup_helper.py → python/setup_tools/setup_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,35 +8,19 @@
import urllib.request
from pathlib import Path
import hashlib
from dataclasses import dataclass
from distutils.sysconfig import get_python_lib
from . import utils

use_triton_shared = False
necessary_third_party = ["triton_shared"]
default_backends = ["nvidia", "amd"]
extend_backends = []
default_backends = ["nvidia", "amd"]
plugin_backends = ["cambricon", "ascend"]
ext_sourcedir = "triton/_C/"
flagtree_backend = os.getenv("FLAGTREE_BACKEND", "").lower()
flagtree_plugin = os.getenv("FLAGTREE_PLUGIN", "").lower()
offline_build = os.getenv("FLAGTREE_PLUGIN", "OFF")
device_mapping = {"xpu": "xpu", "mthreads": "musa", "ascend": "ascend"}


@dataclass
class FlagTreeBackend:
name: str
url: str
tag: str


flagtree_backend_info = {
"triton_shared":
FlagTreeBackend(name="triton_shared", url="https://github.com/microsoft/triton-shared.git",
tag="380b87122c88af131530903a702d5318ec59bb33"),
"cambricon":
FlagTreeBackend(name="cambricon", url="https://github.com/Cambricon/triton-linalg.git",
tag="00f51c2e48a943922f86f03d58e29f514def646d"),
}
flagtree_backends = utils.flagtree_backends
backend_utils = utils.activate(flagtree_backend)

set_llvm_env = lambda path: set_env({
'LLVM_INCLUDE_DIRS': Path(path) / "include",
Expand All @@ -51,42 +35,78 @@ def get_device_name():

def get_extra_packages():
packages = []
if flagtree_backend == 'ascend':
packages = [
"triton/triton_patch",
"triton/triton_patch/language",
"triton/triton_patch/compiler",
"triton/triton_patch/runtime",
]
try:
packages = backend_utils.get_extra_install_packages()
except Exception:
packages = []
return packages


def get_package_data_tools():
package_data = ["compile.h", "compile.c"]
if flagtree_backend == 'xpu':
package_data += ["compile_xpu.h", "compile_xpu.c"]
try:
package_data += backend_utils.get_package_data_tools()
except Exception:
package_data
return package_data


def post_install(self):

def get_module(module_path):
import importlib.util
import os
module_path = os.path.abspath(module_path)
spec = importlib.util.spec_from_file_location("module", module_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module

def ascend():
utils = get_module("../third_party/ascend/utils.py")
utils.post_install()

code = f"{flagtree_backend}()"
def git_clone(lib, lib_path):
import git
MAX_RETRY = 4
print(f"Clone {lib.name} into {lib_path} ...")
retry_count = MAX_RETRY
while (retry_count):
try:
repo = git.Repo.clone_from(lib.url, lib_path)
if lib.tag is not None:
repo.git.checkout(lib.tag)
sub_triton_path = Path(lib_path) / "triton"
if os.path.exists(sub_triton_path):
shutil.rmtree(sub_triton_path)
print(f"successfully clone {lib.name} into {lib_path} ...")
return True
except Exception:
retry_count -= 1
print(f"\n[{MAX_RETRY - retry_count}] retry to clone {lib.name} to {lib_path}")
return False


def dir_rollback(deep, base_path):
while (deep):
base_path = os.path.dirname(base_path)
deep -= 1
return Path(base_path)


def download_flagtree_third_party(name, condition, required=False, hock=None):
if not condition:
return
backend = None
for _backend in flagtree_backends:
if _backend.name in name:
backend = _backend
break
if backend is None:
return backend
base_dir = dir_rollback(3, __file__) / "third_party"
prelib_path = Path(base_dir) / name
lib_path = Path(base_dir) / _backend.name

if not os.path.exists(prelib_path) and not os.path.exists(lib_path):
succ = git_clone(lib=backend, lib_path=prelib_path)
if not succ and required:
raise RuntimeError("Bad network ! ")
if callable(hock):
hock(third_party_base_dir=base_dir, backend=backend)
else:
print(f'Found third_party {backend.name} at {lib_path}\n')


def post_install():
try:
exec(code, globals(), locals())
except: #noqa: E722
backend_utils.post_install()
except Exception:
pass


Expand Down Expand Up @@ -279,10 +299,10 @@ def unlink():
def skip_package_dir(package):
if 'backends' in package or 'profiler' in package:
return True
if flagtree_backend in ['cambricon']:
if package not in ['triton', 'triton/_C']:
return True
return False
try:
return backend_utils.skip_package_dir(package)
except Exception:
return False

@staticmethod
def get_package_dir(packages):
Expand All @@ -296,62 +316,12 @@ def get_package_dir(packages):
pair = (package, f"{backend_triton_path}{package}")
connection.append(pair)
package_dict.update(connection)
if flagtree_backend == "ascend":
triton_patch_root_rel_dir = "../third_party/ascend/triton_patch/python/triton_patch"
package_dict["triton/triton_patch"] = f"{triton_patch_root_rel_dir}"
package_dict["triton/triton_patch/language"] = f"{triton_patch_root_rel_dir}/language"
package_dict["triton/triton_patch/compiler"] = f"{triton_patch_root_rel_dir}/compiler"
package_dict["triton/triton_patch/runtime"] = f"{triton_patch_root_rel_dir}/runtime"
try:
package_dict.update(backend_utils.get_package_dir())
except Exception:
pass
return package_dict

@staticmethod
def download_third_party():
import git
MAX_RETRY = 4
global use_triton_shared, flagtree_backend
third_party_base_dir = Path(os.path.dirname(os.path.dirname(__file__))) / "third_party"

def git_clone(lib, lib_path):
global use_triton_shared
print(f"Clone {lib.name} into {lib_path} ...")
retry_count = MAX_RETRY
while (retry_count):
try:
repo = git.Repo.clone_from(lib.url, lib_path)
repo.git.checkout(lib.tag)
if lib.name in flagtree_backend_info:
sub_triton_path = Path(lib_path) / "triton"
if os.path.exists(sub_triton_path):
shutil.rmtree(sub_triton_path)
print(f"successfully clone {lib.name} into {lib_path} ...")
return
except Exception:
retry_count -= 1
print(f"\n[{MAX_RETRY - retry_count}] retry to clone {lib.name} to {lib_path}")

print(f"Unable to clone third_party {lib.name}")
if lib.name in necessary_third_party:
use_triton_shared = False
print("\n\ttriton_shared is compiled by default, but for "
"some reason we couldn't download triton_shared\n"
"as third_party (most likely for network reasons), "
"so we couldn't compile triton_shared\n")

third_partys = []
if os.environ.get("USE_TRITON_SHARED", "ON") == "ON" and not flagtree_backend:
third_partys.append(flagtree_backend_info["triton_shared"])
else:
use_triton_shared = False
if flagtree_backend in flagtree_backend_info:
third_partys.append(flagtree_backend_info[flagtree_backend])

for lib in third_partys:
lib_path = Path(third_party_base_dir) / lib.name
if not os.path.exists(lib_path):
git_clone(lib=lib, lib_path=lib_path)
else:
print(f'Found third_party {lib.name} at {lib_path}\n')


def handle_flagtree_backend():
global ext_sourcedir
Expand All @@ -360,8 +330,6 @@ def handle_flagtree_backend():
extend_backends.append(flagtree_backend)
if "editable_wheel" in sys.argv and flagtree_backend != "ascend":
ext_sourcedir = os.path.abspath(f"../third_party/{flagtree_backend}/python/{ext_sourcedir}") + "/"
if use_triton_shared and not flagtree_backend:
default_backends.append("triton_shared")


def set_env(env_dict: dict):
Expand All @@ -373,8 +341,15 @@ def check_env(env_val):
return os.environ.get(env_val, '') != ''


CommonUtils.download_third_party()
download_flagtree_third_party("triton_shared", condition=(not flagtree_backend))

download_flagtree_third_party("triton_ascend", condition=(flagtree_backend == "ascend"),
hock=utils.ascend.precompile_hock, required=True)

download_flagtree_third_party("cambricon", condition=(flagtree_backend == "cambricon"), required=True)

handle_flagtree_backend()

cache = FlagTreeCache()

# iluvatar
Expand Down
38 changes: 38 additions & 0 deletions python/setup_tools/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from dataclasses import dataclass
from pathlib import Path
import importlib.util
import os
from . import ascend


@dataclass
class FlagTreeBackend:
name: str
url: str
tag: str = None


flagtree_backends = (
FlagTreeBackend(name="triton_shared", url="https://github.com/microsoft/triton-shared.git",
tag="380b87122c88af131530903a702d5318ec59bb33"),
FlagTreeBackend(name="cambricon", url="https://github.com/Cambricon/triton-linalg.git",
tag="00f51c2e48a943922f86f03d58e29f514def646d"),
FlagTreeBackend(
name="ascend",
url="https://gitee.com/ascend/triton-ascend.git",
),
)


def activate(backend, suffix=".py"):
if not backend:
return
module_path = Path(os.path.dirname(__file__)) / backend
module_path = str(module_path) + suffix
spec = importlib.util.spec_from_file_location("module", module_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module


__all__ = ["ascend"]
Loading
Loading