Skip to content

Commit 2f3fe55

Browse files
committed
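Fixed compile args: replaced `compile_ptx_with_opts` (hard-coded `sm_90a` arch, extra flags and include paths) with plain `compile_ptx`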
1 parent af7fa41 commit 2f3fe55

File tree

1 file changed

+4
-40
lines changed
  • crates/luminal_cuda/src/kernel

1 file changed

+4
-40
lines changed

crates/luminal_cuda/src/kernel/ops.rs

Lines changed: 4 additions & 40 deletions
Original file line number | Diff line number | Diff line change
@@ -2,7 +2,7 @@ use std::sync::Arc;
22

33
use cudarc::{
44
driver::{CudaContext, CudaFunction, CudaSlice, CudaStream},
5-
nvrtc::{compile_ptx, compile_ptx_with_opts, CompileOptions},
5+
nvrtc::{compile_ptx, CompileOptions},
66
};
77
use itertools::Itertools;
88
use luminal::{
@@ -235,19 +235,7 @@ extern \"C\" {{
235235
flatten_mul_strides(&self.out_shape, &self.a_stride).to_kernel(),
236236
flatten_mul_strides(&self.out_shape, &self.b_stride).to_kernel()
237237
);
238-
let ptx = compile_ptx_with_opts(
239-
&kernel,
240-
CompileOptions {
241-
arch: Some("sm_90a"),
242-
options: vec!["--std=c++17".to_string(), "-default-device".to_string()],
243-
include_paths: vec![
244-
"/usr/local/cuda/include".to_string(),
245-
"/usr/include".to_string(),
246-
],
247-
..Default::default()
248-
},
249-
)
250-
.unwrap();
238+
let ptx = compile_ptx(&kernel).unwrap();
251239
let module = ctx.load_module(ptx).unwrap();
252240
let func = module.load_function("mul_k").unwrap();
253241
let constants = vars
@@ -372,19 +360,7 @@ extern \"C\" {{
372360
flatten_mul_strides(&self.out_shape, &self.index_stride).to_kernel(),
373361
flatten_mul_strides(&self.out_shape, &self.data_stride).to_kernel()
374362
);
375-
let ptx = compile_ptx_with_opts(
376-
&kernel,
377-
CompileOptions {
378-
arch: Some("sm_90a"),
379-
options: vec!["--std=c++17".to_string(), "-default-device".to_string()],
380-
include_paths: vec![
381-
"/usr/local/cuda/include".to_string(),
382-
"/usr/include".to_string(),
383-
],
384-
..Default::default()
385-
},
386-
)
387-
.unwrap();
363+
let ptx = compile_ptx(&kernel).unwrap();
388364
let module = ctx.load_module(ptx).unwrap();
389365
let func = module.load_function("gather").unwrap();
390366
let constants = vars
@@ -482,19 +458,7 @@ extern \"C\" {{
482458
.join("\n"),
483459
self.expr.to_kernel(),
484460
);
485-
let ptx = compile_ptx_with_opts(
486-
&kernel,
487-
CompileOptions {
488-
arch: Some("sm_90a"),
489-
options: vec!["--std=c++17".to_string(), "-default-device".to_string()],
490-
include_paths: vec![
491-
"/usr/local/cuda/include".to_string(),
492-
"/usr/include".to_string(),
493-
],
494-
..Default::default()
495-
},
496-
)
497-
.unwrap();
461+
let ptx = compile_ptx(&kernel).unwrap();
498462
let module = ctx.load_module(ptx).unwrap();
499463
let func = module.load_function("iota_k").unwrap();
500464
let constants = vars

0 commit comments

Comments
 (0)