Skip to content

Commit

Permalink
Verify that submodules are checked out (#1536)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexsamardzic authored Jan 13, 2025
1 parent 12a58cf commit 7b3caa6
Showing 1 changed file with 70 additions and 2 deletions.
72 changes: 70 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import glob
import os
import subprocess
import sys
import time
from datetime import datetime

from setuptools import Extension, find_packages, setup
Expand Down Expand Up @@ -71,6 +73,71 @@ def use_debug_mode():
CUDAExtension,
)

# Constant known variables used throughout this file
cwd = os.path.abspath(os.path.curdir)
third_party_path = os.path.join(cwd, "third_party")


def get_submodule_folders():
git_modules_path = os.path.join(cwd, ".gitmodules")
default_modules_path = [
os.path.join(third_party_path, name)
for name in [
"cutlass",
]
]
if not os.path.exists(git_modules_path):
return default_modules_path
with open(git_modules_path) as f:
return [
os.path.join(cwd, line.split("=", 1)[1].strip())
for line in f
if line.strip().startswith("path")
]


def check_submodules():
def check_for_files(folder, files):
if not any(os.path.exists(os.path.join(folder, f)) for f in files):
print("Could not find any of {} in {}".format(", ".join(files), folder))
print("Did you run 'git submodule update --init --recursive'?")
sys.exit(1)

def not_exists_or_empty(folder):
return not os.path.exists(folder) or (
os.path.isdir(folder) and len(os.listdir(folder)) == 0
)

if bool(os.getenv("USE_SYSTEM_LIBS", False)):
return
folders = get_submodule_folders()
# If none of the submodule folders exists, try to initialize them
if all(not_exists_or_empty(folder) for folder in folders):
try:
print(" --- Trying to initialize submodules")
start = time.time()
subprocess.check_call(
["git", "submodule", "update", "--init", "--recursive"], cwd=cwd
)
end = time.time()
print(f" --- Submodule initialization took {end - start:.2f} sec")
except Exception:
print(" --- Submodule initalization failed")
print("Please run:\n\tgit submodule update --init --recursive")
sys.exit(1)
for folder in folders:
check_for_files(
folder,
[
"CMakeLists.txt",
"Makefile",
"setup.py",
"LICENSE",
"LICENSE.md",
"LICENSE.txt",
],
)


# BuildExtension is a subclass of from setuptools.command.build_ext.build_ext
class TorchAOBuildExt(BuildExtension):
Expand Down Expand Up @@ -172,8 +239,7 @@ def get_extensions():
use_cutlass = False
if use_cuda and not IS_WINDOWS:
use_cutlass = True
this_dir = os.path.abspath(os.path.curdir)
cutlass_dir = os.path.join(this_dir, "third_party", "cutlass")
cutlass_dir = os.path.join(third_party_path, "cutlass")
cutlass_include_dir = os.path.join(cutlass_dir, "include")
if use_cutlass:
extra_compile_args["nvcc"].extend(
Expand Down Expand Up @@ -218,6 +284,8 @@ def get_extensions():
return ext_modules


check_submodules()

setup(
name="torchao",
version=version + version_suffix,
Expand Down

0 comments on commit 7b3caa6

Please sign in to comment.