
Commit 60ee8e6

sparse
1 parent 69be99b commit 60ee8e6

File tree

18 files changed: +1900 −0 lines changed


Cargo.lock

Lines changed: 19 additions & 0 deletions
Some generated files are not rendered by default.

crates/burn-core/Cargo.toml

Lines changed: 2 additions & 0 deletions
@@ -68,6 +68,7 @@ vision = ["burn-dataset?/vision", "burn-common/network"]
 # Backend
 autodiff = ["burn-autodiff"]
 fusion = ["burn-wgpu?/fusion"]
+sparse = ["burn-sparse"]

 ## Backend features
 metal = ["burn-candle?/metal"]
@@ -111,6 +112,7 @@ burn-wgpu = { path = "../burn-wgpu", version = "0.14.0", optional = true, defaul
 burn-autodiff = { path = "../burn-autodiff", version = "0.14.0", optional = true }
 burn-tch = { path = "../burn-tch", version = "0.14.0", optional = true }
 burn-candle = { path = "../burn-candle", version = "0.14.0", optional = true }
+burn-sparse = { path = "../burn-sparse", version = "0.14.0", optional = true }

 derive-new = { workspace = true }
 log = { workspace = true, optional = true }
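
Together these two hunks register `burn-sparse` as an optional dependency behind a new `sparse` feature. A minimal sketch of a downstream manifest opting in; it assumes the `burn` facade crate forwards the flag to `burn-core` the way it does for the other backend features, which this commit does not show:

# Hypothetical downstream Cargo.toml (feature forwarding assumed).
[dependencies]
burn = { version = "0.14.0", features = ["sparse"] }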

crates/burn-core/src/backend.rs

Lines changed: 3 additions & 0 deletions
@@ -27,3 +27,6 @@ pub use burn_tch as libtorch;

 #[cfg(feature = "tch")]
 pub use burn_tch::LibTorch;
+
+#[cfg(feature = "sparse")]
+pub use burn_sparse as sparse;
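
With the flag enabled, downstream code reaches the new crate through `burn_core::backend`; a one-line sketch grounded in the re-export above:

// The `sparse` alias added above; the items inside it depend on burn-sparse's
// own module layout, which this excerpt does not show.
#[cfg(feature = "sparse")]
use burn_core::backend::sparse;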

crates/burn-sparse/Cargo.toml

Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
[package]
authors = []
categories = ["science", "no-std", "embedded", "wasm"]
description = "Sparse tensor crate that offers a default sparse backend wrapper around burn backends."
edition.workspace = true
keywords = ["deep-learning", "machine-learning", "tensor", "sparse"]
license.workspace = true
name = "burn-sparse"
readme.workspace = true
repository = "https://github.com/tracel-ai/burn/tree/main/burn-sparse"
version.workspace = true

[features]
default = ["std"]
doc = ["default"]
experimental-named-tensor = []
std = ["rand/std", "half/std", "num-traits/std"]
wasm-sync = []

[dependencies]
burn-common = { path = "../burn-common", version = "0.14.0", default-features = false }
burn-tensor = { path = "../burn-tensor", version = "0.14.0" }

proc-macro2 = { workspace = true }
quote = { workspace = true }
syn = { workspace = true }
derive-new = { workspace = true }
half = { workspace = true }
num-traits = { workspace = true }
rand = { workspace = true }
rand_distr = { workspace = true } # use instead of statrs because it supports no_std

# The same implementation of HashMap in std but with no_std support (only needs alloc crate)
hashbrown = { workspace = true } # no_std compatible

# Serialization
serde = { workspace = true }

[dev-dependencies]
rand = { workspace = true, features = ["std", "std_rng"] } # Default enables std

[package.metadata.docs.rs]
features = ["doc"]

crates/burn-sparse/src/backend/alias.rs

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
use crate::backend::SparseBackend;

/// Sparse tensor primitive type used by the backend.
pub type SparseTensor<B, const D: usize> = <B as SparseBackend>::SparseTensorPrimitive<D>;

crates/burn-sparse/src/backend/api.rs

Lines changed: 33 additions & 0 deletions
@@ -0,0 +1,33 @@
use crate::backend::{Sparse, SparseBackend};
use burn_tensor::{Int, Tensor, TensorPrimitive};

pub trait SparseTensor<const D: usize, B>
where
    B: SparseBackend,
{
    fn dense_int(self) -> Tensor<B, D, Int>;
    fn spmm(self, rhs: Tensor<B, D>) -> Tensor<B, D>;
    fn dense(self) -> Tensor<B, D>;
}

impl<const D: usize, B> SparseTensor<D, B> for Tensor<B, D, Sparse>
where
    B: SparseBackend,
{
    fn dense(self) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::Float(B::sparse_to_dense(
            self.into_primitive(),
        )))
    }

    fn dense_int(self) -> Tensor<B, D, Int> {
        self.dense().int()
    }

    fn spmm(self, rhs: Tensor<B, D>) -> Tensor<B, D> {
        Tensor::new(TensorPrimitive::Float(B::sparse_spmm(
            self.into_primitive(),
            rhs.into_primitive().tensor(),
        )))
    }
}
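
The extension trait gives `Tensor<B, D, Sparse>` a small dense-interop surface. A hedged sketch of a call site, for some backend `B` that implements `SparseBackend` (no such implementation appears in this excerpt, and the import paths assume `burn-sparse` exposes its `backend` module publicly):

use burn_sparse::backend::{Sparse, SparseBackend, SparseTensor};
use burn_tensor::Tensor;

// Hypothetical helper: multiply a sparse matrix by a dense one.
fn spmm_into_dense<B: SparseBackend>(
    lhs: Tensor<B, 2, Sparse>,
    rhs: Tensor<B, 2>,
) -> Tensor<B, 2> {
    // `spmm` consumes the sparse lhs and returns a dense tensor.
    lhs.spmm(rhs)
}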

crates/burn-sparse/src/backend/kind.rs

Lines changed: 61 additions & 0 deletions
@@ -0,0 +1,61 @@
use std::{future::Future, ops::Range};

use crate::backend::SparseBackend;
use burn_tensor::{backend::Backend, BasicOps, Shape, TensorData, TensorKind};

/// A type-level representation of the kind of a sparse (float) tensor.
#[derive(Clone, Debug)]
pub struct Sparse;

impl<B: SparseBackend> TensorKind<B> for Sparse {
    type Primitive<const D: usize> = B::SparseTensorPrimitive<D>;
    fn name() -> &'static str {
        "Sparse"
    }
}

impl<B: SparseBackend> BasicOps<B> for Sparse {
    type Elem = B::FloatElem;

    fn into_data_async<const D: usize>(
        tensor: Self::Primitive<D>,
    ) -> impl Future<Output = TensorData> + Send {
        B::sparse_into_data(tensor)
    }

    fn device<const D: usize>(tensor: &Self::Primitive<D>) -> <B as Backend>::Device {
        B::sparse_device(tensor)
    }

    fn to_device<const D: usize>(
        tensor: Self::Primitive<D>,
        device: &<B as Backend>::Device,
    ) -> Self::Primitive<D> {
        B::sparse_to_device(tensor, device)
    }

    fn from_data<const D: usize>(
        data: TensorData,
        device: &<B as Backend>::Device,
    ) -> Self::Primitive<D> {
        B::sparse_from_data(data, device)
    }

    fn shape<const D: usize>(tensor: &Self::Primitive<D>) -> Shape<D> {
        B::sparse_shape(tensor)
    }

    fn empty<const D: usize>(
        shape: Shape<D>,
        device: &<B as Backend>::Device,
    ) -> Self::Primitive<D> {
        B::sparse_empty(shape, device)
    }

    fn slice<const D1: usize, const D2: usize>(
        tensor: Self::Primitive<D1>,
        ranges: [Range<usize>; D2],
    ) -> Self::Primitive<D1> {
        B::sparse_slice(tensor, ranges)
    }
}
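
Since `Sparse` now satisfies `BasicOps`, the generic `Tensor` constructors work for it unchanged. A minimal sketch, assuming the same public module layout as above:

use burn_sparse::backend::{Sparse, SparseBackend};
use burn_tensor::{Tensor, TensorData};

// Hypothetical helper: builds a 2-D sparse tensor from host data via the
// `BasicOps::from_data` path implemented above.
fn sparse_from_host<B: SparseBackend>(
    data: TensorData,
    device: &B::Device,
) -> Tensor<B, 2, Sparse> {
    Tensor::from_data(data, device)
}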

crates/burn-sparse/src/backend/mod.rs

Lines changed: 8 additions & 0 deletions
@@ -0,0 +1,8 @@
mod alias;
mod api;
mod kind;
mod sparse_backend;

pub use alias::*;
pub use kind::*;
pub use sparse_backend::*;

crates/burn-sparse/src/backend/sparse_backend.rs

Lines changed: 110 additions & 0 deletions
@@ -0,0 +1,110 @@
use crate::backend::SparseTensor;
use burn_tensor::{backend::Backend, Device, Shape, TensorData};
use core::{future::Future, ops::Range};

pub trait SparseBackend: Backend {
    type SparseTensorPrimitive<const D: usize>: Clone + Send + 'static + core::fmt::Debug;

    /// Creates a new sparse tensor with the given shape, containing no values.
    fn sparse_empty<const D: usize>(
        shape: Shape<D>,
        device: &Device<Self>,
    ) -> SparseTensor<Self, D>;

    /// Converts a dense tensor into its sparse representation.
    fn sparse_to_sparse<const D: usize>(
        dense: Self::FloatTensorPrimitive<D>,
    ) -> Self::SparseTensorPrimitive<D>;

    /// Converts a sparse tensor into an equivalent dense tensor.
    fn sparse_to_dense<const D: usize>(
        sparse: Self::SparseTensorPrimitive<D>,
    ) -> Self::FloatTensorPrimitive<D>;

    /// Multiplies a sparse tensor by a dense tensor (SpMM), returning a dense result.
    fn sparse_spmm<const D: usize>(
        lhs: Self::SparseTensorPrimitive<D>,
        rhs: Self::FloatTensorPrimitive<D>,
    ) -> Self::FloatTensorPrimitive<D>;

    /// Sampled dense-dense matrix multiplication (SDDMM), returning a sparse result.
    fn sparse_sddmm<const D: usize>(
        lhs: Self::SparseTensorPrimitive<D>,
        rhs: Self::FloatTensorPrimitive<D>,
    ) -> Self::SparseTensorPrimitive<D>;

    /// Slices the tensor along the given ranges.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `indices` - The ranges to slice along each dimension.
    ///
    /// # Returns
    ///
    /// The sliced tensor.
    fn sparse_slice<const D1: usize, const D2: usize>(
        tensor: SparseTensor<Self, D1>,
        indices: [Range<usize>; D2],
    ) -> SparseTensor<Self, D1>;

    /// Gets the device of the tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The device of the tensor.
    fn sparse_device<const D: usize>(tensor: &SparseTensor<Self, D>) -> Device<Self>;

    /// Moves the tensor to the given device.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    /// * `device` - The device to move the tensor to.
    ///
    /// # Returns
    ///
    /// The tensor on the given device.
    fn sparse_to_device<const D: usize>(
        tensor: SparseTensor<Self, D>,
        device: &Device<Self>,
    ) -> SparseTensor<Self, D>;

    /// Gets the shape of the tensor.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The shape of the tensor.
    fn sparse_shape<const D: usize>(tensor: &SparseTensor<Self, D>) -> Shape<D>;

    /// Converts the tensor to a data structure.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor.
    ///
    /// # Returns
    ///
    /// The data structure with the tensor's data.
    fn sparse_into_data<const D: usize>(
        tensor: SparseTensor<Self, D>,
    ) -> impl Future<Output = TensorData> + Send;

    /// Creates a tensor from the data structure.
    ///
    /// # Arguments
    ///
    /// * `data` - The data structure.
    /// * `device` - The device to create the tensor on.
    ///
    /// # Returns
    ///
    /// The tensor with the data.
    fn sparse_from_data<const D: usize>(
        data: TensorData,
        device: &Device<Self>,
    ) -> SparseTensor<Self, D>;
}
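
With every operation an associated function on `SparseBackend`, callers can stay generic over the concrete backend. A small sketch exercising only methods declared above (module path assumed as before):

use burn_sparse::backend::{SparseBackend, SparseTensor};

// Hypothetical round trip: densify a sparse primitive, then convert the
// dense result back into the sparse representation.
fn round_trip<B: SparseBackend, const D: usize>(
    tensor: SparseTensor<B, D>,
) -> SparseTensor<B, D> {
    let dense = B::sparse_to_dense(tensor);
    B::sparse_to_sparse(dense)
}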

crates/burn-sparse/src/decorator/backend.rs

Lines changed: 42 additions & 0 deletions
@@ -0,0 +1,42 @@
use crate::decorator::FullPrecisionBridge;
use crate::decorator::SparseRepresentation;
use burn_tensor::backend::Backend;
use core::marker::PhantomData;
use derive_new::new;

/// Tensor backend that extends existing backends with sparse tensor support.
/// This backend abstracts over all backends, and so lacks the performance of a direct implementation.
/// Backends with a direct SparseBackend implementation should be preferred where possible.
#[derive(new, Clone, Copy, Default, Debug)]
pub struct SparseDecorator<B: Backend, R: SparseRepresentation> {
    _p: PhantomData<B>,
    _r: PhantomData<R>,
}

impl<B: Backend, R: SparseRepresentation> Backend for SparseDecorator<B, R> {
    type Device = B::Device;

    type FullPrecisionBridge = FullPrecisionBridge<B::FullPrecisionBridge>;

    type FloatTensorPrimitive<const D: usize> = B::FloatTensorPrimitive<D>;

    type FloatElem = B::FloatElem;

    type IntTensorPrimitive<const D: usize> = B::IntTensorPrimitive<D>;

    type IntElem = B::IntElem;

    type BoolTensorPrimitive<const D: usize> = B::BoolTensorPrimitive<D>;

    type QuantizedTensorPrimitive<const D: usize> = B::QuantizedTensorPrimitive<D>;

    fn name() -> String {
        format!("SparseDecorator<{}>", B::name())
    }

    fn seed(seed: u64) {
        B::seed(seed)
    }
}

impl<B: Backend, R: SparseRepresentation> SparseDecorator<B, R> {}
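
The decorator forwards the dense half of the `Backend` contract straight to `B`; the sparse half presumably lives in the sibling `ops` module per representation. A minimal generic sketch using only what this file defines:

use burn_sparse::decorator::{SparseDecorator, SparseRepresentation};
use burn_tensor::backend::Backend;

// Hypothetical helper: surfaces the decorated backend's name, which the
// `name()` impl above formats as "SparseDecorator<inner-name>".
fn decorated_name<B: Backend, R: SparseRepresentation>() -> String {
    <SparseDecorator<B, R> as Backend>::name()
}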

crates/burn-sparse/src/decorator/mod.rs

Lines changed: 10 additions & 0 deletions
@@ -0,0 +1,10 @@
mod backend;
mod ops;
mod precision_bridge;
mod representation;
mod sparse_coo;
mod sparse_csr;

pub use backend::*;
pub use precision_bridge::*;
pub use representation::*;
