Skip to content

Commit 790ed5f

Browse files
some better imports and fixes
1 parent 24daafd commit 790ed5f

File tree

4 files changed

+25
-3
lines changed

4 files changed

+25
-3
lines changed

crates/burn-core/src/backend.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,4 +29,4 @@ pub use burn_tch as libtorch;
2929
pub use burn_tch::LibTorch;
3030

3131
#[cfg(feature = "sparse")]
32-
pub use burn_sparse as sparse;
32+
pub use burn_sparse::decorator as sparse;

crates/burn-core/src/tensor.rs

+5
Original file line numberDiff line numberDiff line change
@@ -1 +1,6 @@
11
pub use burn_tensor::*;
2+
3+
#[cfg(feature = "sparse")]
4+
pub mod sparse {
5+
pub use burn_sparse::backend::*;
6+
}

crates/burn-sparse/src/backend/api.rs

+18-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,14 @@
11
use crate::backend::{Sparse, SparseBackend};
22
use burn_tensor::{Int, Tensor, TensorPrimitive};
33

4-
pub trait SparseTensor<const D: usize, B>
4+
pub trait ToSparse<const D: usize, B>
5+
where
6+
B: SparseBackend,
7+
{
8+
fn into_sparse(self) -> Tensor<B, D, Sparse>;
9+
}
10+
11+
pub trait SparseTensorApi<const D: usize, B>
512
where
613
B: SparseBackend,
714
{
@@ -10,7 +17,16 @@ where
1017
fn dense(self) -> Tensor<B, D>;
1118
}
1219

13-
impl<const D: usize, B> SparseTensor<D, B> for Tensor<B, D, Sparse>
20+
impl<const D: usize, B> ToSparse<D, B> for Tensor<B, D>
21+
where
22+
B: SparseBackend,
23+
{
24+
fn into_sparse(self) -> Tensor<B, D, Sparse> {
25+
Tensor::new(B::sparse_to_sparse(self.into_primitive().tensor()))
26+
}
27+
}
28+
29+
impl<const D: usize, B> SparseTensorApi<D, B> for Tensor<B, D, Sparse>
1430
where
1531
B: SparseBackend,
1632
{

crates/burn-sparse/src/backend/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,6 @@ mod kind;
44
mod sparse_backend;
55

66
pub use alias::*;
7+
pub use api::*;
78
pub use kind::*;
89
pub use sparse_backend::*;

0 commit comments

Comments
 (0)