Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .cargo/config.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
[build]
rustflags = ["-C", "target-cpu=native"]

[target.x86_64-unknown-linux-gnu]
rustflags = ["-C", "target-cpu=native"]
27 changes: 27 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
name: CI

on:
push:
branches: [ main ]
pull_request:
branches: [ main ]

jobs:
build:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Install Rust
uses: actions-rs/toolchain@v1
with:
toolchain: stable
profile: minimal
- name: Cache cargo registry
uses: actions/cache@v4
with:
path: |
~/.cargo/registry
~/.cargo/git
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
- name: Cargo build (workspace)
run: cargo build --locked --workspace
66 changes: 45 additions & 21 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,38 +16,62 @@ jobs:
timeout-minutes: 20

steps:
- uses: actions/checkout@v4
- name: Run tests
run: rustup update; cargo test --workspace --verbose
- uses: actions/checkout@v4
- name: Cache cargo registry
uses: actions/cache@v3
with:
path: |
~/.cargo/registry
~/.cargo/git
target
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
- name: Run tests
run: rustup update; cargo test --workspace --verbose

clippy:
name: Clippy
runs-on: ubuntu-latest
timeout-minutes: 20

steps:
- uses: actions/checkout@v4
- name: Run clippy
run: rustup update; cargo clippy --all-targets -- -D warnings
- uses: actions/checkout@v4
- name: Cache cargo registry
uses: actions/cache@v3
with:
path: |
~/.cargo/registry
~/.cargo/git
target
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
- name: Run clippy
run: rustup update; cargo clippy --all-targets -- -D warnings

fmt:
name: Fmt
runs-on: ubuntu-latest
timeout-minutes: 20

steps:
- uses: actions/checkout@v4
- name: Format
run: cargo fmt --all --check

# macos_test:
# name: MacOS Tests
# runs-on: macos-13
# timeout-minutes: 20

# steps:
# - uses: actions/checkout@v4
# - name: Build
# run: cargo build --verbose
# - name: Run tests
# run: cargo test --verbose -- --test-threads 1
- uses: actions/checkout@v4
- name: Run format check
run: cargo fmt --all --check

macos_test:
name: MacOS Tests
runs-on: macos-13
timeout-minutes: 20

steps:
- uses: actions/checkout@v4
- name: Cache cargo registry
uses: actions/cache@v3
with:
path: |
~/.cargo/registry
~/.cargo/git
target
key: ${{ runner.os }}-cargo-${{ hashFiles('**/Cargo.lock') }}
- name: Build
run: cargo build --verbose
- name: Run tests
run: cargo test --verbose -- --test-threads 1
31 changes: 18 additions & 13 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,36 @@ description = "Deep learning at the speed of light."
license = "MIT OR Apache-2.0"

[dependencies]
as-any = "0.3.1"
colored = "2.0.4"
dyn-clone = "1.0.12"
egg = "0.9.5"
generational-box = "0.5.6"
half = "*"
itertools = "0.11.0"
num-traits = "0.2.16"
petgraph = "0.6.4"
rand = "0.9.2"
urlencoding = "2.1.2"
webbrowser = "1.0.0"
dyn-clone = "1.0.12"
half = "*"
tinyvec = { version = "1.6.0", features = ["serde"] }
term_size = "0.3.2"
colored = "2.0.4"
regex = "1.9.5"
rustc-hash = "2.1.1"
uuid = { version = "1.7.0", features = ["v4"] }
as-any = "0.3.1"
egg = "0.9.5"
symbolic_expressions = "5.0.3"
serde = { version = "1.0.202", features = ["derive"] }
thread_local = "1.1.8"
generational-box = "0.5.6"
serde_json = "1.0.140"
symbolic_expressions = "5.0.3"
term_size = "0.3.2"
thread_local = "1.1.8"
tinyvec = { version = "1.6.0", features = ["serde"] }
urlencoding = "2.1.2"
uuid = { version = "1.7.0", features = ["v4"] }
webbrowser = "1.0.0"

[dev-dependencies]
dfdx = { version = "0.13", features = ["f16"] }

[profile.release]
codegen-units = 1
lto = "fat"
opt-level = 3

[workspace]
members = [
"examples/*",
Expand Down
4 changes: 2 additions & 2 deletions crates/luminal_2/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ metal = ["dep:objc2", "dep:objc2-metal", "dep:objc2-foundation"]
[dependencies]
luminal = { path = "../../" }
luminal_cuda = { path = "../luminal_cuda", optional = true }
cudarc = { version = "0.16.6", features = [
cudarc = { git = "https://github.com/rust-cuda/cudarc.git", branch = "main", features = [
"f16",
"cuda-12080",
"cuda-12200",
], optional = true }
#metal-rs = { version = "0.28.0", package = "metal", optional=true }
objc2 = { version = "0.6.2", optional = true }
Expand Down
4 changes: 2 additions & 2 deletions crates/luminal_cuda/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ license = "MIT OR Apache-2.0"

[dependencies]
luminal = { path = "../.." }
cudarc = { version = "0.16.6", features = [
cudarc = { git = "https://github.com/rust-cuda/cudarc.git", branch = "main", features = [
"f16",
"cuda-12080",
"cuda-12200",
] }
itertools = "0.12.1"
rustc-hash = "2.1.1"
Expand Down
98 changes: 97 additions & 1 deletion crates/luminal_cuda/src/matmul.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,102 @@ impl<T: CudaFloat> Operator for Matmul<T> {
}
}

#[derive(Clone)]
pub struct CublasGemm<T>(Arc<CudaBlas>, Arc<CudaContext>, PhantomData<T>);
crate::debug_type!(CublasGemm);

impl<T: CudaFloat> Operator for CublasGemm<T> {
fn process(&mut self, inp: Vec<(InputTensor, ShapeTracker)>) -> Vec<Tensor> {
// Delegate to the same fast cuBLAS path used by Matmul.
let stream = self.1.default_stream();
let (a_shape, b_shape) = (inp[0].1.dims(), inp[1].1.dims());
let (batch_size, m, k, n) = (
a_shape
.iter()
.take(a_shape.len() - 2)
.map(|i| i.to_usize().unwrap())
.product::<usize>() as i32,
a_shape[a_shape.len() - 2].to_usize().unwrap() as i32,
a_shape[a_shape.len() - 1].to_usize().unwrap() as i32,
b_shape[b_shape.len() - 1].to_usize().unwrap() as i32,
);
let a = get_buffer_from_tensor::<T>(&inp[0].0);
let b = get_buffer_from_tensor::<T>(&inp[1].0);
let out = self
.1
.default_stream()
.alloc_zeros::<T>((m * n * batch_size) as usize)
.unwrap();
let (a_row_major, b_row_major) = (
inp[0].1.indexes[inp[0].1.len() - 1] > inp[0].1.indexes[inp[0].1.len() - 2],
inp[1].1.indexes[inp[1].1.len() - 1] > inp[1].1.indexes[inp[1].1.len() - 2],
);
let (transa, transb) = match (a_row_major, b_row_major) {
(true, true) => (CUBLAS_OP_N, CUBLAS_OP_N),
(false, false) => (CUBLAS_OP_T, CUBLAS_OP_T),
(false, true) => (CUBLAS_OP_N, CUBLAS_OP_T),
(true, false) => (CUBLAS_OP_T, CUBLAS_OP_N),
};

let a_dims = inp[0].1.fake.iter().filter(|f| !**f).count();
let b_dims = inp[1].1.fake.iter().filter(|f| !**f).count();
let (a_ptr, _a) = a.device_ptr(&stream);
let (b_ptr, _b) = b.device_ptr(&stream);
let (out_ptr, _out) = out.device_ptr(&stream);
if T::is_f32() {
unsafe {
cudarc::cublas::result::sgemm_strided_batched(
*self.0.handle(),
transa,
transb,
n,
m,
k,
&1.0_f32 as *const f32,
b_ptr as *const f32,
if b_row_major { n } else { k },
if b_dims == 2 { 0 } else { (n * k) as i64 },
a_ptr as *const f32,
if a_row_major { k } else { m },
if a_dims == 2 { 0 } else { (m * k) as i64 },
&0.0_f32 as *const f32,
out_ptr as *mut f32,
n,
(m * n) as i64,
batch_size,
)
.unwrap();
}
} else {
unsafe {
cudarc::cublas::result::hgemm_strided_batched(
*self.0.handle(),
transa,
transb,
n,
m,
k,
&f16::from_f32(1.0) as *const f16,
b_ptr as *const f16,
if b_row_major { n } else { k },
if b_dims == 2 { 0 } else { (n * k) as i64 },
a_ptr as *const f16,
if a_row_major { k } else { m },
if a_dims == 2 { 0 } else { (m * k) as i64 },
&f16::from_f32(0.0) as *const f16,
out_ptr as *mut f16,
n,
(m * n) as i64,
batch_size,
)
.unwrap();
}
}
drop(_out);
vec![Tensor::new(CudaData(out))]
}
}

#[derive(Default)]
pub struct MatMulCompiler<T>(PhantomData<T>);

Expand Down Expand Up @@ -224,7 +320,7 @@ where
dims.swap(src2_shape.len() - 2, src2_shape.len() - 1);
src2_shape.permute(&dims);
let new_op = graph
.add_op(Matmul::<T>(
.add_op(CublasGemm::<T>(
Arc::new(CudaBlas::new(dev.default_stream()).unwrap()),
dev.clone(),
Default::default(),
Expand Down
4 changes: 2 additions & 2 deletions demos/matmul/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,9 @@ metal = [
itertools = "0.14.0"
luminal = { path = "../.." }
luminal_cuda = { path = "../../crates/luminal_cuda", optional = true }
cudarc = { version = "0.16.6", features = [
cudarc = { git = "https://github.com/rust-cuda/cudarc.git", branch = "main", features = [
"f16",
"cuda-12080",
"cuda-12200",
], optional = true }
luminal_nn = { path = "../../crates/luminal_nn" }
luminal_2 = { path = "../../crates/luminal_2" }
Expand Down
Loading
Loading