Skip to content

Commit 7b3caa6

Browse files
Verify that submodules are checked out (#1536)
1 parent 12a58cf commit 7b3caa6

File tree

1 file changed

+70
-2
lines changed

1 file changed

+70
-2
lines changed

setup.py

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import glob
77
import os
88
import subprocess
9+
import sys
10+
import time
911
from datetime import datetime
1012

1113
from setuptools import Extension, find_packages, setup
@@ -71,6 +73,71 @@ def use_debug_mode():
7173
CUDAExtension,
7274
)
7375

76+
# Constant known variables used throughout this file
77+
cwd = os.path.abspath(os.path.curdir)
78+
third_party_path = os.path.join(cwd, "third_party")
79+
80+
81+
def get_submodule_folders():
82+
git_modules_path = os.path.join(cwd, ".gitmodules")
83+
default_modules_path = [
84+
os.path.join(third_party_path, name)
85+
for name in [
86+
"cutlass",
87+
]
88+
]
89+
if not os.path.exists(git_modules_path):
90+
return default_modules_path
91+
with open(git_modules_path) as f:
92+
return [
93+
os.path.join(cwd, line.split("=", 1)[1].strip())
94+
for line in f
95+
if line.strip().startswith("path")
96+
]
97+
98+
99+
def check_submodules():
100+
def check_for_files(folder, files):
101+
if not any(os.path.exists(os.path.join(folder, f)) for f in files):
102+
print("Could not find any of {} in {}".format(", ".join(files), folder))
103+
print("Did you run 'git submodule update --init --recursive'?")
104+
sys.exit(1)
105+
106+
def not_exists_or_empty(folder):
107+
return not os.path.exists(folder) or (
108+
os.path.isdir(folder) and len(os.listdir(folder)) == 0
109+
)
110+
111+
if bool(os.getenv("USE_SYSTEM_LIBS", False)):
112+
return
113+
folders = get_submodule_folders()
114+
# If none of the submodule folders exists, try to initialize them
115+
if all(not_exists_or_empty(folder) for folder in folders):
116+
try:
117+
print(" --- Trying to initialize submodules")
118+
start = time.time()
119+
subprocess.check_call(
120+
["git", "submodule", "update", "--init", "--recursive"], cwd=cwd
121+
)
122+
end = time.time()
123+
print(f" --- Submodule initialization took {end - start:.2f} sec")
124+
except Exception:
125+
print(" --- Submodule initalization failed")
126+
print("Please run:\n\tgit submodule update --init --recursive")
127+
sys.exit(1)
128+
for folder in folders:
129+
check_for_files(
130+
folder,
131+
[
132+
"CMakeLists.txt",
133+
"Makefile",
134+
"setup.py",
135+
"LICENSE",
136+
"LICENSE.md",
137+
"LICENSE.txt",
138+
],
139+
)
140+
74141

75142
# BuildExtension is a subclass of from setuptools.command.build_ext.build_ext
76143
class TorchAOBuildExt(BuildExtension):
@@ -172,8 +239,7 @@ def get_extensions():
172239
use_cutlass = False
173240
if use_cuda and not IS_WINDOWS:
174241
use_cutlass = True
175-
this_dir = os.path.abspath(os.path.curdir)
176-
cutlass_dir = os.path.join(this_dir, "third_party", "cutlass")
242+
cutlass_dir = os.path.join(third_party_path, "cutlass")
177243
cutlass_include_dir = os.path.join(cutlass_dir, "include")
178244
if use_cutlass:
179245
extra_compile_args["nvcc"].extend(
@@ -218,6 +284,8 @@ def get_extensions():
218284
return ext_modules
219285

220286

287+
check_submodules()
288+
221289
setup(
222290
name="torchao",
223291
version=version + version_suffix,

0 commit comments

Comments
 (0)