@@ -2,7 +2,7 @@ use std::sync::Arc;
22
33use cudarc:: {
44 driver:: { CudaContext , CudaFunction , CudaSlice , CudaStream } ,
5- nvrtc:: { compile_ptx, compile_ptx_with_opts , CompileOptions } ,
5+ nvrtc:: { compile_ptx, CompileOptions } ,
66} ;
77use itertools:: Itertools ;
88use luminal:: {
@@ -235,19 +235,7 @@ extern \"C\" {{
235235 flatten_mul_strides( & self . out_shape, & self . a_stride) . to_kernel( ) ,
236236 flatten_mul_strides( & self . out_shape, & self . b_stride) . to_kernel( )
237237 ) ;
238- let ptx = compile_ptx_with_opts (
239- & kernel,
240- CompileOptions {
241- arch : Some ( "sm_90a" ) ,
242- options : vec ! [ "--std=c++17" . to_string( ) , "-default-device" . to_string( ) ] ,
243- include_paths : vec ! [
244- "/usr/local/cuda/include" . to_string( ) ,
245- "/usr/include" . to_string( ) ,
246- ] ,
247- ..Default :: default ( )
248- } ,
249- )
250- . unwrap ( ) ;
238+ let ptx = compile_ptx ( & kernel) . unwrap ( ) ;
251239 let module = ctx. load_module ( ptx) . unwrap ( ) ;
252240 let func = module. load_function ( "mul_k" ) . unwrap ( ) ;
253241 let constants = vars
@@ -372,19 +360,7 @@ extern \"C\" {{
372360 flatten_mul_strides( & self . out_shape, & self . index_stride) . to_kernel( ) ,
373361 flatten_mul_strides( & self . out_shape, & self . data_stride) . to_kernel( )
374362 ) ;
375- let ptx = compile_ptx_with_opts (
376- & kernel,
377- CompileOptions {
378- arch : Some ( "sm_90a" ) ,
379- options : vec ! [ "--std=c++17" . to_string( ) , "-default-device" . to_string( ) ] ,
380- include_paths : vec ! [
381- "/usr/local/cuda/include" . to_string( ) ,
382- "/usr/include" . to_string( ) ,
383- ] ,
384- ..Default :: default ( )
385- } ,
386- )
387- . unwrap ( ) ;
363+ let ptx = compile_ptx ( & kernel) . unwrap ( ) ;
388364 let module = ctx. load_module ( ptx) . unwrap ( ) ;
389365 let func = module. load_function ( "gather" ) . unwrap ( ) ;
390366 let constants = vars
@@ -482,19 +458,7 @@ extern \"C\" {{
482458 . join( "\n " ) ,
483459 self . expr. to_kernel( ) ,
484460 ) ;
485- let ptx = compile_ptx_with_opts (
486- & kernel,
487- CompileOptions {
488- arch : Some ( "sm_90a" ) ,
489- options : vec ! [ "--std=c++17" . to_string( ) , "-default-device" . to_string( ) ] ,
490- include_paths : vec ! [
491- "/usr/local/cuda/include" . to_string( ) ,
492- "/usr/include" . to_string( ) ,
493- ] ,
494- ..Default :: default ( )
495- } ,
496- )
497- . unwrap ( ) ;
461+ let ptx = compile_ptx ( & kernel) . unwrap ( ) ;
498462 let module = ctx. load_module ( ptx) . unwrap ( ) ;
499463 let func = module. load_function ( "iota_k" ) . unwrap ( ) ;
500464 let constants = vars
0 commit comments