diff --git a/python/setup.py b/python/setup.py index 685956d15a..43b168dc10 100644 --- a/python/setup.py +++ b/python/setup.py @@ -50,21 +50,29 @@ class BackendInstaller: @staticmethod def prepare(backend_name: str, backend_src_dir: str = None, is_external: bool = False): + dir_mapping = {"mlu": "cambricon"} + actual_dir_name = dir_mapping.get(backend_name, backend_name) # Initialize submodule if there is one for in-tree backends. if not is_external: root_dir = os.path.join(os.pardir, "third_party") - assert backend_name in os.listdir( - root_dir), f"{backend_name} is requested for install but not present in {root_dir}" + assert actual_dir_name in os.listdir( + root_dir), f"{actual_dir_name} is requested for install but not present in {root_dir}" try: - subprocess.run(["git", "submodule", "update", "--init", f"{backend_name}"], check=True, - stdout=subprocess.DEVNULL, cwd=root_dir) + # flagtree: check if the submodule is defined in .gitmodules + check_result = subprocess.run( + ["git", "config", "-f", ".gitmodules", "--get-regexp", f"submodule.*{actual_dir_name}.path"], + check=False, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL, cwd=os.pardir) + # flagtree: only execute git submodule update for actual git submodules + if check_result.returncode == 0: + subprocess.run(["git", "submodule", "update", "--init", f"{actual_dir_name}"], check=True, + stdout=subprocess.DEVNULL, cwd=root_dir) except subprocess.CalledProcessError: pass except FileNotFoundError: pass - backend_src_dir = os.path.join(root_dir, backend_name) + backend_src_dir = os.path.join(root_dir, actual_dir_name) backend_path = os.path.abspath(os.path.join(backend_src_dir, "backend")) assert os.path.exists(backend_path), f"{backend_path} does not exist!" @@ -80,7 +88,7 @@ def prepare(backend_name: str, backend_src_dir: str = None, is_external: bool = for file in ["compiler.py", "driver.py"]: assert os.path.exists(os.path.join(backend_path, file)), f"${file} does not exist in ${backend_path}" - install_dir = os.path.join(os.path.dirname(__file__), "triton", "backends", backend_name) + install_dir = os.path.join(os.path.dirname(__file__), "triton", "backends", actual_dir_name) package_data = [f"{os.path.relpath(p, backend_path)}/*" for p, _, _, in os.walk(backend_path)] language_package_data = []