Skip to content

Commit 7b6de79

Browse files
committed
Remove logic where CUDA archs env var is overwritten
1 parent 9703227 commit 7b6de79

File tree

1 file changed

+25
-23
lines changed

1 file changed

+25
-23
lines changed

megatron/fused_kernels/__init__.py

Lines changed: 25 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -31,31 +31,33 @@
3131
# leading to recompilation of fused kernels. Set it to empty string
3232
# to avoid recompilation and assign arch flags explicitly in
3333
# extra_cuda_cflags below
34-
os.environ["TORCH_CUDA_ARCH_LIST"] = ""
34+
35+
# TODO: Remove this
36+
# os.environ["TORCH_CUDA_ARCH_LIST"] = ""
3537

3638

3739
def load(neox_args=None):
3840
print("\n" + "="*80)
3941
print("FUSED KERNELS: Starting fused kernel loading process...")
4042
print("="*80)
4143
start_time = time.time()
42-
44+
4345
# Check if cuda 11 is installed for compute capability 8.0
4446
cc_flag = []
4547
if torch.version.hip is None:
4648
print(f"FUSED KERNELS: Detected PyTorch with CUDA support")
4749
print(f"FUSED KERNELS: CUDA_HOME = {cpp_extension.CUDA_HOME}")
48-
50+
4951
raw_output, bare_metal_major, bare_metal_minor = _get_cuda_bare_metal_version(
5052
cpp_extension.CUDA_HOME
5153
)
5254
print(f"FUSED KERNELS: Detected CUDA version {bare_metal_major}.{bare_metal_minor}")
53-
55+
5456
if int(bare_metal_major) >= 11:
5557
cc_flag.append("-gencode")
5658
cc_flag.append("arch=compute_80,code=sm_80")
5759
print(f"FUSED KERNELS: Added compute capability 8.0 (A100)")
58-
60+
5961
if int(bare_metal_minor) >= 1:
6062
cc_flag.append("-gencode")
6163
cc_flag.append("arch=compute_86,code=sm_86")
@@ -97,7 +99,7 @@ def _cpp_extention_load_helper(
9799
+ extra_cuda_flags
98100
+ cc_flag
99101
)
100-
102+
101103
# Check if kernel is already built
102104
kernel_path = buildpath / name
103105
if os.path.exists(kernel_path) and any(f.endswith('.so') for f in os.listdir(kernel_path) if os.path.isfile(os.path.join(kernel_path, f))):
@@ -107,23 +109,23 @@ def _cpp_extention_load_helper(
107109
print(f"FUSED KERNELS: {name} needs to be built")
108110
print(f"FUSED KERNELS: This will take 30-60 seconds...")
109111
print(f"FUSED KERNELS: Building with flags: {extra_cuda_cflags}")
110-
112+
111113
sys.stdout.flush() # Force flush to ensure messages appear
112-
114+
113115
try:
114116
print(f"FUSED KERNELS: Calling cpp_extension.load for {name}...")
115117
build_start = time.time()
116-
118+
117119
# Monkey-patch the ninja build to add progress messages
118120
original_build = cpp_extension._write_ninja_file_and_build_library
119121
def build_with_progress(*args, **kwargs):
120122
print(f"FUSED KERNELS: JIT compiling {name} with ninja...")
121123
print(f"FUSED KERNELS: This involves compiling CUDA kernels - please be patient...")
122124
sys.stdout.flush()
123125
return original_build(*args, **kwargs)
124-
126+
125127
cpp_extension._write_ninja_file_and_build_library = build_with_progress
126-
128+
127129
try:
128130
loaded_module = cpp_extension.load(
129131
name=name,
@@ -139,15 +141,15 @@ def build_with_progress(*args, **kwargs):
139141
finally:
140142
# Restore original function
141143
cpp_extension._write_ninja_file_and_build_library = original_build
142-
144+
143145
build_time = time.time() - build_start
144146
print(f"FUSED KERNELS: Successfully loaded {name} in {build_time:.2f} seconds")
145147
return loaded_module
146-
148+
147149
except Exception as e:
148150
print(f"\nFUSED KERNELS ERROR: Failed to build/load {name}")
149151
print(f"FUSED KERNELS ERROR: {str(e)}")
150-
152+
151153
# Check for common issues
152154
if "Permission denied" in str(e) or "cannot create directory" in str(e):
153155
print(f"FUSED KERNELS ERROR: This might be a file permission issue.")
@@ -160,7 +162,7 @@ def build_with_progress(*args, **kwargs):
160162
elif "nvcc not found" in str(e) or "CUDA_HOME" in str(e):
161163
print(f"FUSED KERNELS ERROR: CUDA installation issue detected.")
162164
print(f"FUSED KERNELS ERROR: Make sure CUDA is properly installed and CUDA_HOME is set.")
163-
165+
164166
print(f"FUSED KERNELS ERROR: Full build directory path: {buildpath}")
165167
raise
166168

@@ -188,7 +190,7 @@ def build_with_progress(*args, **kwargs):
188190

189191
print("\nFUSED KERNELS: Building/loading 3 fused kernels...")
190192
print("-"*60)
191-
193+
192194
# Upper triangular softmax.
193195
print("\n[1/3] Building scaled_upper_triang_masked_softmax_cuda...")
194196
sources = [
@@ -201,7 +203,7 @@ def build_with_progress(*args, **kwargs):
201203
extra_cuda_flags,
202204
extra_include_paths,
203205
)
204-
206+
205207
# Masked softmax.
206208
print("\n[2/3] Building scaled_masked_softmax_cuda...")
207209
sources = [
@@ -211,7 +213,7 @@ def build_with_progress(*args, **kwargs):
211213
scaled_masked_softmax_cuda = _cpp_extention_load_helper(
212214
"scaled_masked_softmax_cuda", sources, extra_cuda_flags, extra_include_paths
213215
)
214-
216+
215217
# fused rope
216218
print("\n[3/3] Building fused_rotary_positional_embedding...")
217219
sources = [
@@ -224,7 +226,7 @@ def build_with_progress(*args, **kwargs):
224226
extra_cuda_flags,
225227
extra_include_paths,
226228
)
227-
229+
228230
total_time = time.time() - start_time
229231
print("\n" + "="*80)
230232
print(f"FUSED KERNELS: All kernels loaded successfully!")
@@ -275,15 +277,15 @@ def load_fused_kernels():
275277
try:
276278
import scaled_upper_triang_masked_softmax_cuda
277279
print("FUSED KERNELS: ✓ scaled_upper_triang_masked_softmax_cuda imported successfully")
278-
280+
279281
import scaled_masked_softmax_cuda
280282
print("FUSED KERNELS: ✓ scaled_masked_softmax_cuda imported successfully")
281-
283+
282284
import fused_rotary_positional_embedding
283285
print("FUSED KERNELS: ✓ fused_rotary_positional_embedding imported successfully")
284-
286+
285287
print("FUSED KERNELS: All fused kernels are available and ready to use!")
286-
288+
287289
except (ImportError, ModuleNotFoundError) as e:
288290
print("\n" + "!"*100)
289291
print("FUSED KERNELS ERROR: Failed to import fused kernels!")

0 commit comments

Comments
 (0)