Skip to content

Commit 262b772

Browse files
committed
changing workflow
1 parent 5938a06 commit 262b772

File tree

2 files changed

+4
-18
lines changed

2 files changed

+4
-18
lines changed

.github/workflows/test.yml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,7 @@ jobs:
4646

4747
steps:
4848
- uses: actions/checkout@v4
49-
- name: GPU sanity check
50-
run: sudo apt-get update; sudo apt-get install -y cuda-nvrtc-12-8; sudo apt install tree; ls /usr/local/cuda-12.8/lib64
51-
- name: GPU sanity check 1
52-
run: sudo apt install tree; ls /usr/local; ls /usr/local/lib; tree /usr/lib -L 2
49+
- name: Install nvrtc
50+
run: sudo apt-get update; sudo apt-get install -y cuda-nvrtc-12-8
5351
- name: Run CUDA crate tests
5452
run: export LD_LIBRARY_PATH=/usr/local/cuda-12.8/lib64:$LD_LIBRARY_PATH; curl https://sh.rustup.rs -sSf | sh -s -- -y && source "$HOME/.cargo/env"; sudo apt-get update && sudo apt-get install -y protobuf-compiler; rustup update; cargo test -p luminal_cuda --verbose

crates/luminal_cuda/src/kernel/ops.rs

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ use std::sync::Arc;
22

33
use cudarc::{
44
driver::{CudaContext, CudaFunction, CudaSlice, CudaStream},
5-
nvrtc::{compile_ptx_with_opts, CompileOptions},
5+
nvrtc::{compile_ptx, compile_ptx_with_opts, CompileOptions},
66
};
77
use itertools::Itertools;
88
use luminal::{
@@ -115,19 +115,7 @@ extern \"C\" {{
115115
flatten_strides(&self.out_shape, &self.a_stride).to_kernel(),
116116
flatten_strides(&self.out_shape, &self.b_stride).to_kernel()
117117
);
118-
let ptx = compile_ptx_with_opts(
119-
&kernel,
120-
CompileOptions {
121-
// arch: Some("sm_90a"),
122-
// options: vec!["--std=c++17".to_string(), "-default-device".to_string()],
123-
// include_paths: vec![
124-
// "/usr/local/cuda/include".to_string(),
125-
// "/usr/include".to_string(),
126-
// ],
127-
..Default::default()
128-
},
129-
)
130-
.unwrap();
118+
let ptx = compile_ptx(&kernel).unwrap();
131119
let module = ctx.load_module(ptx).unwrap();
132120
let func = module.load_function("add_k").unwrap();
133121
let constants = vars

0 commit comments

Comments
 (0)