Skip to content

Commit 24daafd

Browse files
Fixed errors from moving sparse
1 parent 60ee8e6 commit 24daafd

File tree

3 files changed

+115
-22
lines changed

3 files changed

+115
-22
lines changed

crates/burn-sparse/src/backend/kind.rs

+94
Original file line numberDiff line numberDiff line change
@@ -58,4 +58,98 @@ impl<B: SparseBackend> BasicOps<B> for Sparse {
5858
) -> Self::Primitive<D1> {
5959
B::sparse_slice(tensor, ranges)
6060
}
61+
62+
fn reshape<const D1: usize, const D2: usize>(
63+
tensor: Self::Primitive<D1>,
64+
shape: Shape<D2>,
65+
) -> Self::Primitive<D2> {
66+
todo!()
67+
}
68+
69+
fn transpose<const D: usize>(tensor: Self::Primitive<D>) -> Self::Primitive<D> {
70+
todo!()
71+
}
72+
73+
fn swap_dims<const D: usize>(
74+
tensor: Self::Primitive<D>,
75+
dim1: usize,
76+
dim2: usize,
77+
) -> Self::Primitive<D> {
78+
todo!()
79+
}
80+
81+
fn permute<const D: usize>(tensor: Self::Primitive<D>, axes: [usize; D]) -> Self::Primitive<D> {
82+
todo!()
83+
}
84+
85+
fn flip<const D: usize>(tensor: Self::Primitive<D>, axes: &[usize]) -> Self::Primitive<D> {
86+
todo!()
87+
}
88+
89+
fn slice_assign<const D1: usize, const D2: usize>(
90+
tensor: Self::Primitive<D1>,
91+
ranges: [Range<usize>; D2],
92+
value: Self::Primitive<D1>,
93+
) -> Self::Primitive<D1> {
94+
todo!()
95+
}
96+
97+
fn repeat<const D: usize>(
98+
tensor: Self::Primitive<D>,
99+
dim: usize,
100+
times: usize,
101+
) -> Self::Primitive<D> {
102+
todo!()
103+
}
104+
105+
fn cat<const D: usize>(vectors: Vec<Self::Primitive<D>>, dim: usize) -> Self::Primitive<D> {
106+
todo!()
107+
}
108+
109+
fn equal<const D: usize>(
110+
lhs: Self::Primitive<D>,
111+
rhs: Self::Primitive<D>,
112+
) -> burn_tensor::Tensor<B, D, burn_tensor::Bool> {
113+
todo!()
114+
}
115+
116+
fn not_equal<const D: usize>(
117+
lhs: Self::Primitive<D>,
118+
rhs: Self::Primitive<D>,
119+
) -> burn_tensor::Tensor<B, D, burn_tensor::Bool> {
120+
todo!()
121+
}
122+
123+
fn any<const D: usize>(
124+
tensor: Self::Primitive<D>,
125+
) -> burn_tensor::Tensor<B, 1, burn_tensor::Bool> {
126+
todo!()
127+
}
128+
129+
fn any_dim<const D: usize>(
130+
tensor: Self::Primitive<D>,
131+
dim: usize,
132+
) -> burn_tensor::Tensor<B, D, burn_tensor::Bool> {
133+
todo!()
134+
}
135+
136+
fn all<const D: usize>(
137+
tensor: Self::Primitive<D>,
138+
) -> burn_tensor::Tensor<B, 1, burn_tensor::Bool> {
139+
todo!()
140+
}
141+
142+
fn all_dim<const D: usize>(
143+
tensor: Self::Primitive<D>,
144+
dim: usize,
145+
) -> burn_tensor::Tensor<B, D, burn_tensor::Bool> {
146+
todo!()
147+
}
148+
149+
fn expand<const D1: usize, const D2: usize>(
150+
tensor: Self::Primitive<D1>,
151+
shape: Shape<D2>,
152+
) -> Self::Primitive<D2> {
153+
todo!()
154+
}
61155
}

crates/burn-sparse/src/decorator/sparse_coo.rs

+9-8
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1+
use crate::backend::SparseBackend;
2+
use crate::backend::SparseTensor;
13
use crate::decorator::SparseCOO;
24
use crate::decorator::SparseDecorator;
35
use burn_tensor::{
4-
backend::Backend, ops::SparseTensor, sparse_backend::SparseBackend, ElementConversion, Float,
5-
Int, Shape, Tensor, TensorData, TensorPrimitive,
6+
backend::Backend, ElementConversion, Float, Int, Shape, Tensor, TensorData, TensorPrimitive,
67
};
78

89
#[derive(Clone, Debug)]
@@ -174,9 +175,9 @@ where
174175
}
175176

176177
fn sparse_to_device<const D: usize>(
177-
tensor: burn_tensor::ops::SparseTensor<Self, D>,
178+
tensor: SparseTensor<Self, D>,
178179
device: &burn_tensor::Device<Self>,
179-
) -> burn_tensor::ops::SparseTensor<Self, D> {
180+
) -> SparseTensor<Self, D> {
180181
SparseCOOTensor {
181182
coordinates: tensor.coordinates.to_device(device),
182183
values: tensor.values.to_device(device),
@@ -193,7 +194,7 @@ where
193194
fn sparse_empty<const D: usize>(
194195
shape: burn_tensor::Shape<D>,
195196
device: &burn_tensor::Device<B>,
196-
) -> burn_tensor::ops::SparseTensor<Self, D> {
197+
) -> SparseTensor<Self, D> {
197198
SparseCOOTensor {
198199
coordinates: Tensor::from_primitive(B::int_empty(
199200
burn_tensor::Shape::new([0, 0]),
@@ -210,7 +211,7 @@ where
210211
fn sparse_slice<const D1: usize, const D2: usize>(
211212
tensor: Self::SparseTensorPrimitive<D1>,
212213
indices: [std::ops::Range<usize>; D2],
213-
) -> burn_tensor::ops::SparseTensor<Self, D1> {
214+
) -> SparseTensor<Self, D1> {
214215
let SparseCOOTensor {
215216
coordinates,
216217
values,
@@ -259,13 +260,13 @@ where
259260
fn sparse_from_data<const D: usize>(
260261
data: TensorData,
261262
device: &burn_tensor::Device<Self>,
262-
) -> burn_tensor::ops::SparseTensor<Self, D> {
263+
) -> SparseTensor<Self, D> {
263264
let dense = B::float_from_data(data, &device);
264265
Self::sparse_to_sparse(dense)
265266
}
266267

267268
fn sparse_into_data<const D: usize>(
268-
tensor: burn_tensor::ops::SparseTensor<Self, D>,
269+
tensor: SparseTensor<Self, D>,
269270
) -> impl std::future::Future<Output = TensorData> + Send {
270271
// TODO this could be way better
271272
B::float_into_data(Self::sparse_to_dense(tensor))

crates/burn-sparse/src/decorator/sparse_csr.rs

+12-14
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1+
use crate::backend::SparseBackend;
2+
use crate::backend::SparseTensor;
13
use crate::decorator::SparseCSR;
24
use crate::decorator::SparseDecorator;
3-
use burn_tensor::{backend::Backend, sparse_backend::SparseBackend};
5+
use burn_tensor::backend::Backend;
46
use core::marker::PhantomData;
57

68
#[derive(Debug, Default, Clone)]
@@ -17,7 +19,7 @@ where
1719
fn sparse_empty<const D: usize>(
1820
shape: burn_tensor::Shape<D>,
1921
device: &burn_tensor::Device<Self>,
20-
) -> burn_tensor::ops::SparseTensor<Self, D> {
22+
) -> SparseTensor<Self, D> {
2123
todo!()
2224
}
2325

@@ -48,41 +50,37 @@ where
4850
}
4951

5052
fn sparse_slice<const D1: usize, const D2: usize>(
51-
tensor: burn_tensor::ops::SparseTensor<Self, D1>,
53+
tensor: SparseTensor<Self, D1>,
5254
indices: [std::ops::Range<usize>; D2],
53-
) -> burn_tensor::ops::SparseTensor<Self, D1> {
55+
) -> SparseTensor<Self, D1> {
5456
todo!()
5557
}
5658

57-
fn sparse_device<const D: usize>(
58-
tensor: &burn_tensor::ops::SparseTensor<Self, D>,
59-
) -> burn_tensor::Device<Self> {
59+
fn sparse_device<const D: usize>(tensor: &SparseTensor<Self, D>) -> burn_tensor::Device<Self> {
6060
todo!()
6161
}
6262

6363
fn sparse_to_device<const D: usize>(
64-
tensor: burn_tensor::ops::SparseTensor<Self, D>,
64+
tensor: SparseTensor<Self, D>,
6565
device: &burn_tensor::Device<Self>,
66-
) -> burn_tensor::ops::SparseTensor<Self, D> {
66+
) -> SparseTensor<Self, D> {
6767
todo!()
6868
}
6969

70-
fn sparse_shape<const D: usize>(
71-
tensor: &burn_tensor::ops::SparseTensor<Self, D>,
72-
) -> burn_tensor::Shape<D> {
70+
fn sparse_shape<const D: usize>(tensor: &SparseTensor<Self, D>) -> burn_tensor::Shape<D> {
7371
todo!()
7472
}
7573

7674
fn sparse_into_data<const D: usize>(
77-
tensor: burn_tensor::ops::SparseTensor<Self, D>,
75+
tensor: SparseTensor<Self, D>,
7876
) -> impl std::future::Future<Output = burn_tensor::TensorData> + Send {
7977
async { todo!() }
8078
}
8179

8280
fn sparse_from_data<const D: usize>(
8381
data: burn_tensor::TensorData,
8482
device: &burn_tensor::Device<Self>,
85-
) -> burn_tensor::ops::SparseTensor<Self, D> {
83+
) -> SparseTensor<Self, D> {
8684
todo!()
8785
}
8886
}

0 commit comments

Comments
 (0)