|
6 | 6 | import glob
|
7 | 7 | import os
|
8 | 8 | import subprocess
|
| 9 | +import sys |
| 10 | +import time |
9 | 11 | from datetime import datetime
|
10 | 12 |
|
11 | 13 | from setuptools import Extension, find_packages, setup
|
@@ -71,6 +73,71 @@ def use_debug_mode():
|
71 | 73 | CUDAExtension,
|
72 | 74 | )
|
73 | 75 |
|
| 76 | +# Constant known variables used throughout this file |
| 77 | +cwd = os.path.abspath(os.path.curdir) |
| 78 | +third_party_path = os.path.join(cwd, "third_party") |
| 79 | + |
| 80 | + |
| 81 | +def get_submodule_folders(): |
| 82 | + git_modules_path = os.path.join(cwd, ".gitmodules") |
| 83 | + default_modules_path = [ |
| 84 | + os.path.join(third_party_path, name) |
| 85 | + for name in [ |
| 86 | + "cutlass", |
| 87 | + ] |
| 88 | + ] |
| 89 | + if not os.path.exists(git_modules_path): |
| 90 | + return default_modules_path |
| 91 | + with open(git_modules_path) as f: |
| 92 | + return [ |
| 93 | + os.path.join(cwd, line.split("=", 1)[1].strip()) |
| 94 | + for line in f |
| 95 | + if line.strip().startswith("path") |
| 96 | + ] |
| 97 | + |
| 98 | + |
| 99 | +def check_submodules(): |
| 100 | + def check_for_files(folder, files): |
| 101 | + if not any(os.path.exists(os.path.join(folder, f)) for f in files): |
| 102 | + print("Could not find any of {} in {}".format(", ".join(files), folder)) |
| 103 | + print("Did you run 'git submodule update --init --recursive'?") |
| 104 | + sys.exit(1) |
| 105 | + |
| 106 | + def not_exists_or_empty(folder): |
| 107 | + return not os.path.exists(folder) or ( |
| 108 | + os.path.isdir(folder) and len(os.listdir(folder)) == 0 |
| 109 | + ) |
| 110 | + |
| 111 | + if bool(os.getenv("USE_SYSTEM_LIBS", False)): |
| 112 | + return |
| 113 | + folders = get_submodule_folders() |
| 114 | + # If none of the submodule folders exists, try to initialize them |
| 115 | + if all(not_exists_or_empty(folder) for folder in folders): |
| 116 | + try: |
| 117 | + print(" --- Trying to initialize submodules") |
| 118 | + start = time.time() |
| 119 | + subprocess.check_call( |
| 120 | + ["git", "submodule", "update", "--init", "--recursive"], cwd=cwd |
| 121 | + ) |
| 122 | + end = time.time() |
| 123 | + print(f" --- Submodule initialization took {end - start:.2f} sec") |
| 124 | + except Exception: |
| 125 | + print(" --- Submodule initalization failed") |
| 126 | + print("Please run:\n\tgit submodule update --init --recursive") |
| 127 | + sys.exit(1) |
| 128 | + for folder in folders: |
| 129 | + check_for_files( |
| 130 | + folder, |
| 131 | + [ |
| 132 | + "CMakeLists.txt", |
| 133 | + "Makefile", |
| 134 | + "setup.py", |
| 135 | + "LICENSE", |
| 136 | + "LICENSE.md", |
| 137 | + "LICENSE.txt", |
| 138 | + ], |
| 139 | + ) |
| 140 | + |
74 | 141 |
|
75 | 142 | # BuildExtension is a subclass of from setuptools.command.build_ext.build_ext
|
76 | 143 | class TorchAOBuildExt(BuildExtension):
|
@@ -172,8 +239,7 @@ def get_extensions():
|
172 | 239 | use_cutlass = False
|
173 | 240 | if use_cuda and not IS_WINDOWS:
|
174 | 241 | use_cutlass = True
|
175 |
| - this_dir = os.path.abspath(os.path.curdir) |
176 |
| - cutlass_dir = os.path.join(this_dir, "third_party", "cutlass") |
| 242 | + cutlass_dir = os.path.join(third_party_path, "cutlass") |
177 | 243 | cutlass_include_dir = os.path.join(cutlass_dir, "include")
|
178 | 244 | if use_cutlass:
|
179 | 245 | extra_compile_args["nvcc"].extend(
|
@@ -218,6 +284,8 @@ def get_extensions():
|
218 | 284 | return ext_modules
|
219 | 285 |
|
220 | 286 |
|
| 287 | +check_submodules() |
| 288 | + |
221 | 289 | setup(
|
222 | 290 | name="torchao",
|
223 | 291 | version=version + version_suffix,
|
|
0 commit comments