3131# leading to recompilation of fused kernels. Set it to empty string
3232# to avoid recompilation and assign arch flags explicitly in
3333# extra_cuda_cflags below
34- os .environ ["TORCH_CUDA_ARCH_LIST" ] = ""
34+
35+ # TODO: Remove this
36+ # os.environ["TORCH_CUDA_ARCH_LIST"] = ""
3537
3638
3739def load (neox_args = None ):
3840 print ("\n " + "=" * 80 )
3941 print ("FUSED KERNELS: Starting fused kernel loading process..." )
4042 print ("=" * 80 )
4143 start_time = time .time ()
42-
44+
4345 # Check if cuda 11 is installed for compute capability 8.0
4446 cc_flag = []
4547 if torch .version .hip is None :
4648 print (f"FUSED KERNELS: Detected PyTorch with CUDA support" )
4749 print (f"FUSED KERNELS: CUDA_HOME = { cpp_extension .CUDA_HOME } " )
48-
50+
4951 raw_output , bare_metal_major , bare_metal_minor = _get_cuda_bare_metal_version (
5052 cpp_extension .CUDA_HOME
5153 )
5254 print (f"FUSED KERNELS: Detected CUDA version { bare_metal_major } .{ bare_metal_minor } " )
53-
55+
5456 if int (bare_metal_major ) >= 11 :
5557 cc_flag .append ("-gencode" )
5658 cc_flag .append ("arch=compute_80,code=sm_80" )
5759 print (f"FUSED KERNELS: Added compute capability 8.0 (A100)" )
58-
60+
5961 if int (bare_metal_minor ) >= 1 :
6062 cc_flag .append ("-gencode" )
6163 cc_flag .append ("arch=compute_86,code=sm_86" )
@@ -97,7 +99,7 @@ def _cpp_extention_load_helper(
9799 + extra_cuda_flags
98100 + cc_flag
99101 )
100-
102+
101103 # Check if kernel is already built
102104 kernel_path = buildpath / name
103105 if os .path .exists (kernel_path ) and any (f .endswith ('.so' ) for f in os .listdir (kernel_path ) if os .path .isfile (os .path .join (kernel_path , f ))):
@@ -107,23 +109,23 @@ def _cpp_extention_load_helper(
107109 print (f"FUSED KERNELS: { name } needs to be built" )
108110 print (f"FUSED KERNELS: This will take 30-60 seconds..." )
109111 print (f"FUSED KERNELS: Building with flags: { extra_cuda_cflags } " )
110-
112+
111113 sys .stdout .flush () # Force flush to ensure messages appear
112-
114+
113115 try :
114116 print (f"FUSED KERNELS: Calling cpp_extension.load for { name } ..." )
115117 build_start = time .time ()
116-
118+
117119 # Monkey-patch the ninja build to add progress messages
118120 original_build = cpp_extension ._write_ninja_file_and_build_library
119121 def build_with_progress (* args , ** kwargs ):
120122 print (f"FUSED KERNELS: JIT compiling { name } with ninja..." )
121123 print (f"FUSED KERNELS: This involves compiling CUDA kernels - please be patient..." )
122124 sys .stdout .flush ()
123125 return original_build (* args , ** kwargs )
124-
126+
125127 cpp_extension ._write_ninja_file_and_build_library = build_with_progress
126-
128+
127129 try :
128130 loaded_module = cpp_extension .load (
129131 name = name ,
@@ -139,15 +141,15 @@ def build_with_progress(*args, **kwargs):
139141 finally :
140142 # Restore original function
141143 cpp_extension ._write_ninja_file_and_build_library = original_build
142-
144+
143145 build_time = time .time () - build_start
144146 print (f"FUSED KERNELS: Successfully loaded { name } in { build_time :.2f} seconds" )
145147 return loaded_module
146-
148+
147149 except Exception as e :
148150 print (f"\n FUSED KERNELS ERROR: Failed to build/load { name } " )
149151 print (f"FUSED KERNELS ERROR: { str (e )} " )
150-
152+
151153 # Check for common issues
152154 if "Permission denied" in str (e ) or "cannot create directory" in str (e ):
153155 print (f"FUSED KERNELS ERROR: This might be a file permission issue." )
@@ -160,7 +162,7 @@ def build_with_progress(*args, **kwargs):
160162 elif "nvcc not found" in str (e ) or "CUDA_HOME" in str (e ):
161163 print (f"FUSED KERNELS ERROR: CUDA installation issue detected." )
162164 print (f"FUSED KERNELS ERROR: Make sure CUDA is properly installed and CUDA_HOME is set." )
163-
165+
164166 print (f"FUSED KERNELS ERROR: Full build directory path: { buildpath } " )
165167 raise
166168
@@ -188,7 +190,7 @@ def build_with_progress(*args, **kwargs):
188190
189191 print ("\n FUSED KERNELS: Building/loading 3 fused kernels..." )
190192 print ("-" * 60 )
191-
193+
192194 # Upper triangular softmax.
193195 print ("\n [1/3] Building scaled_upper_triang_masked_softmax_cuda..." )
194196 sources = [
@@ -201,7 +203,7 @@ def build_with_progress(*args, **kwargs):
201203 extra_cuda_flags ,
202204 extra_include_paths ,
203205 )
204-
206+
205207 # Masked softmax.
206208 print ("\n [2/3] Building scaled_masked_softmax_cuda..." )
207209 sources = [
@@ -211,7 +213,7 @@ def build_with_progress(*args, **kwargs):
211213 scaled_masked_softmax_cuda = _cpp_extention_load_helper (
212214 "scaled_masked_softmax_cuda" , sources , extra_cuda_flags , extra_include_paths
213215 )
214-
216+
215217 # fused rope
216218 print ("\n [3/3] Building fused_rotary_positional_embedding..." )
217219 sources = [
@@ -224,7 +226,7 @@ def build_with_progress(*args, **kwargs):
224226 extra_cuda_flags ,
225227 extra_include_paths ,
226228 )
227-
229+
228230 total_time = time .time () - start_time
229231 print ("\n " + "=" * 80 )
230232 print (f"FUSED KERNELS: All kernels loaded successfully!" )
@@ -275,15 +277,15 @@ def load_fused_kernels():
275277 try :
276278 import scaled_upper_triang_masked_softmax_cuda
277279 print ("FUSED KERNELS: ✓ scaled_upper_triang_masked_softmax_cuda imported successfully" )
278-
280+
279281 import scaled_masked_softmax_cuda
280282 print ("FUSED KERNELS: ✓ scaled_masked_softmax_cuda imported successfully" )
281-
283+
282284 import fused_rotary_positional_embedding
283285 print ("FUSED KERNELS: ✓ fused_rotary_positional_embedding imported successfully" )
284-
286+
285287 print ("FUSED KERNELS: All fused kernels are available and ready to use!" )
286-
288+
287289 except (ImportError , ModuleNotFoundError ) as e :
288290 print ("\n " + "!" * 100 )
289291 print ("FUSED KERNELS ERROR: Failed to import fused kernels!" )
0 commit comments