Skip to content

Commit b538f24

Browse files
committed
Add cumsum tensor op
1 parent ec8e45a commit b538f24

File tree

33 files changed

+395
-3
lines changed

33 files changed

+395
-3
lines changed

burn-book/src/building-blocks/tensor.md

+1
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,7 @@ Those operations are available for numeric tensor kinds: `Float` and `Int`.
196196
| `tensor.clamp_max(max)` | `torch.clamp(tensor, max=max)` |
197197
| `tensor.clamp_min(min)` | `torch.clamp(tensor, min=min)` |
198198
| `tensor.contains_nan()` | N/A |
199+
| `tensor.cumsum(dim)` | `tensor.cumsum(dim)` |
199200
| `tensor.div(other)` or `tensor / other` | `tensor / other` |
200201
| `tensor.div_scalar(scalar)` or `tensor / scalar` | `tensor / scalar` |
201202
| `tensor.equal_elem(other)` | `tensor.eq(other)` |

crates/burn-autodiff/src/ops/int_tensor.rs

+4
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,10 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
127127
B::int_sum(tensor)
128128
}
129129

130+
fn int_cumsum(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
    // Int tensors carry no gradient in autodiff, so delegate straight to the
    // inner backend, like the other int ops in this impl.
    B::int_cumsum(tensor, dim)
}
133+
130134
fn int_sum_dim(tensor: IntTensor<B>, dim: usize) -> IntTensor<B> {
131135
B::int_sum_dim(tensor, dim)
132136
}

crates/burn-autodiff/src/ops/tensor.rs

+32
Original file line numberDiff line numberDiff line change
@@ -1488,6 +1488,38 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
14881488
}
14891489
}
14901490

1491+
fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
1492+
#[derive(Debug)]
1493+
struct CumSum;
1494+
1495+
impl<B: Backend> Backward<B, 1> for CumSum {
1496+
type State = usize;
1497+
1498+
fn backward(
1499+
self,
1500+
ops: Ops<Self::State, 1>,
1501+
grads: &mut Gradients,
1502+
_checkpointer: &mut Checkpointer,
1503+
) {
1504+
let dim = ops.state;
1505+
1506+
unary::<B, _>(ops.parents, ops.node, grads, |grad| {
1507+
let cumsum = B::float_cumsum(grad.clone(), dim);
1508+
B::float_flip(cumsum.clone(), &[dim])
1509+
});
1510+
}
1511+
}
1512+
1513+
match CumSum
1514+
.prepare::<C>([tensor.node])
1515+
.compute_bound()
1516+
.stateful()
1517+
{
1518+
OpsKind::Tracked(prep) => prep.finish(dim, B::float_cumsum(tensor.primitive, dim)),
1519+
OpsKind::UnTracked(prep) => prep.finish(B::float_cumsum(tensor.primitive, dim)),
1520+
}
1521+
}
1522+
14911523
fn float_mean_dim(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
14921524
#[derive(Debug)]
14931525
struct MeanDim;
+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
#[burn_tensor_testgen::testgen(ad_cumsum)]
2+
mod tests {
3+
use super::*;
4+
use burn_tensor::{loss, Tensor, TensorData};
5+
6+
#[test]
7+
fn should_diff_cumsum() {
8+
let device = Default::default();
9+
let tensor_0 =
10+
TestAutodiffTensor::<2>::from_data([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]], &device)
11+
.require_grad();
12+
13+
let dim = 1;
14+
let tensor_1 = tensor_0.clone().cumsum(dim);
15+
16+
let grads = tensor_1.backward();
17+
18+
let grad_0 = tensor_0.grad(&grads).unwrap();
19+
let grad_0_expected = TensorData::from([[3., 2., 1.], [3., 2., 1.]]);
20+
grad_0.into_data().assert_approx_eq(&grad_0_expected, 2);
21+
}
22+
}

crates/burn-autodiff/src/tests/mod.rs

+2
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ mod conv_transpose2d;
2222
mod conv_transpose3d;
2323
mod cos;
2424
mod cross_entropy;
25+
mod cumsum;
2526
mod deform_conv2d;
2627
mod div;
2728
mod erf;
@@ -188,5 +189,6 @@ macro_rules! testgen_with_float_param {
188189
burn_autodiff::testgen_ad_expand!();
189190
burn_autodiff::testgen_ad_sort!();
190191
burn_autodiff::testgen_ad_repeat_dim!();
192+
burn_autodiff::testgen_ad_cumsum!();
191193
};
192194
}

crates/burn-candle/src/lib.rs

+2
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,7 @@ mod tests {
9595
burn_tensor::testgen_round!();
9696
burn_tensor::testgen_floor!();
9797
burn_tensor::testgen_ceil!();
98+
burn_tensor::testgen_cumsum!();
9899

99100
// TODO: https://github.com/tracel-ai/burn/issues/1237
100101
//
@@ -175,4 +176,5 @@ mod tests {
175176
burn_autodiff::testgen_ad_round!();
176177
burn_autodiff::testgen_ad_floor!();
177178
burn_autodiff::testgen_ad_ceil!();
179+
burn_autodiff::testgen_ad_cumsum!();
178180
}

crates/burn-candle/src/ops/base.rs

+28
Original file line numberDiff line numberDiff line change
@@ -145,3 +145,31 @@ pub fn mask_where_broadcasted(
145145

146146
CandleTensor::new(mask.tensor.where_cond(&value.tensor, &tensor).unwrap())
147147
}
148+
149+
// Taken from: https://github.com/mokeyish/candle-ext/blob/main/src/cumsum.rs
//
// Cumulative sum along `dim` via repeated shifted adds. Used instead of
// candle's built-in `Tensor::cumsum`, which lowers to a matmul and therefore
// does not support integer dtypes. Performs O(dim_size) tensor ops, each of
// shrinking length — quadratic in `dim_size` overall, acceptable for the
// typical small dims this is used on.
fn cumsum_ext<D: candle_core::shape::Dim>(
    input: &candle_core::Tensor,
    dim: D,
) -> candle_core::Result<candle_core::Tensor> {
    let dim = dim.to_index(input.shape(), "cumsum")?;
    let dim_size = input.dim(dim)?;

    // One length-1 slice per output position along `dim`; concatenated at the end.
    let mut tensors = Vec::with_capacity(dim_size);

    // Invariant: after the update at step `i`, `a[j] = input[j] + ... + input[j + i]`,
    // so `a[0]` is exactly the cumulative sum up to index `i`.
    let mut a = input.clone();
    for i in 0..dim_size {
        if i > 0 {
            // Drop the first element of the running sums and add the next
            // aligned window of the input (shifted add).
            a = a.narrow(dim, 1, dim_size - i)?;
            let b = input.narrow(dim, 0, dim_size - i)?;
            a = (a + b)?;
        }
        tensors.push(a.narrow(dim, 0, 1)?);
    }
    let cumsum = candle_core::Tensor::cat(&tensors, dim)?;
    Ok(cumsum)
}

/// Cumulative sum (used for int tensors since the default candle implementation uses matmul).
pub fn cumsum(tensor: CandleTensor, dim: usize) -> CandleTensor {
    // NOTE(review): `unwrap` panics on an out-of-range `dim`, matching the
    // panic-on-invalid-dim behavior of the other ops in this module.
    CandleTensor::new(cumsum_ext(&tensor.tensor, dim).unwrap())
}

crates/burn-candle/src/ops/int_tensor.rs

+4
Original file line numberDiff line numberDiff line change
@@ -372,4 +372,8 @@ impl<F: FloatCandleElement, I: IntCandleElement> IntTensorOps<Self> for Candle<F
372372
fn int_sign(tensor: IntTensor<Self>) -> IntTensor<Self> {
373373
sign(tensor)
374374
}
375+
376+
fn int_cumsum(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
    // Use the hand-rolled shifted-add implementation: candle's native
    // `Tensor::cumsum` lowers to a matmul, which does not support int dtypes.
    super::base::cumsum(tensor, dim)
}
375379
}

crates/burn-candle/src/ops/tensor.rs

+4
Original file line numberDiff line numberDiff line change
@@ -481,4 +481,8 @@ impl<F: FloatCandleElement, I: IntCandleElement> FloatTensorOps<Self> for Candle
481481
CandleTensor::new(tensor.tensor.to_dtype(dtype).unwrap())
482482
}
483483
}
484+
485+
fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
    // Float tensors can use candle's native (matmul-based) cumsum directly;
    // only int tensors need the custom fallback in `ops::base`.
    CandleTensor::new(tensor.tensor.cumsum(dim).unwrap())
}
484488
}

crates/burn-fusion/src/ops/float.rs

+27
Original file line numberDiff line numberDiff line change
@@ -2263,4 +2263,31 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
22632263

22642264
out
22652265
}
2266+
2267+
fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
    // Cumulative sum along `dim`, registered lazily on the fusion stream.
    // `dim` travels as the scalar (`rhs`) of a `ScalarOperationDescription`;
    // `noconvert` keeps it a plain `usize` rather than an element scalar.
    scalar_float_ops!(CumsumOps, B::float_cumsum, usize, noconvert);

    let stream = tensor.stream;
    let dtype = tensor.dtype;
    // Output has the same shape as the input.
    let shape = tensor.shape.clone();
    // NOTE(review): the output is allocated with `B::FloatElem::dtype()` while
    // the operation description below carries `tensor.dtype` — confirm these
    // agree when the input is not the backend's default float element type.
    let out = tensor
        .client
        .tensor_uninitialized(shape, B::FloatElem::dtype());

    let desc = ScalarOperationDescription {
        lhs: tensor.into_description(),
        rhs: dim,
        out: out.to_description_out(),
    };
    out.client.register(
        vec![stream],
        OperationDescription::NumericFloat(
            dtype,
            NumericOperationDescription::CumSum(desc.clone()),
        ),
        CumsumOps::<B>::new(desc),
    );

    out
}
22662293
}

crates/burn-fusion/src/ops/int.rs

+27
Original file line numberDiff line numberDiff line change
@@ -1819,4 +1819,31 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
18191819

18201820
out
18211821
}
1822+
1823+
fn int_cumsum(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
1824+
scalar_int_ops!(CumsumOps, B::int_cumsum, usize, noconvert);
1825+
1826+
let stream = tensor.stream;
1827+
let dtype = tensor.dtype;
1828+
let shape = tensor.shape.clone();
1829+
let out = tensor
1830+
.client
1831+
.tensor_uninitialized(shape, B::FloatElem::dtype());
1832+
1833+
let desc = ScalarOperationDescription {
1834+
lhs: tensor.into_description(),
1835+
rhs: dim,
1836+
out: out.to_description_out(),
1837+
};
1838+
out.client.register(
1839+
vec![stream],
1840+
OperationDescription::NumericInt(
1841+
dtype,
1842+
NumericOperationDescription::CumSum(desc.clone()),
1843+
),
1844+
CumsumOps::<B>::new(desc),
1845+
);
1846+
1847+
out
1848+
}
18221849
}

crates/burn-fusion/src/stream/context.rs

+7
Original file line numberDiff line numberDiff line change
@@ -961,6 +961,13 @@ impl<E: Element> RelativeOpsScalar<E> for NumericOperationDescription<E> {
961961
out: desc.out.to_relative(converter),
962962
})
963963
}
964+
NumericOperationDescription::CumSum(desc) => {
965+
NumericOperationDescription::CumSum(ScalarOperationDescription {
966+
lhs: desc.lhs.to_relative(converter),
967+
rhs: desc.rhs,
968+
out: desc.out.to_relative(converter),
969+
})
970+
}
964971
}
965972
}
966973
}

crates/burn-jit/src/ops/float_ops.rs

+4
Original file line numberDiff line numberDiff line change
@@ -665,4 +665,8 @@ where
665665
_ => unimplemented!("Unsupported floating point type cast"),
666666
}
667667
}
668+
669+
// TODO: no JIT kernel exists yet for cumulative sum (a scan operation);
// this backend currently panics via `todo!()` if the op is reached.
fn float_cumsum(_tensor: FloatTensor<Self>, _dim: usize) -> FloatTensor<Self> {
    todo!()
}
668672
}

crates/burn-jit/src/ops/int_ops.rs

+4
Original file line numberDiff line numberDiff line change
@@ -283,4 +283,8 @@ where
283283
fn int_flip(tensor: IntTensor<Self>, axes: &[usize]) -> IntTensor<Self> {
284284
kernel::flip::<R, I, BT>(tensor, axes)
285285
}
286+
287+
// TODO: no JIT kernel exists yet for cumulative sum (a scan operation);
// this backend currently panics via `todo!()` if the op is reached.
fn int_cumsum(_tensor: IntTensor<Self>, _dim: usize) -> IntTensor<Self> {
    todo!()
}
286290
}

crates/burn-ndarray/src/ops/base.rs

+10
Original file line numberDiff line numberDiff line change
@@ -262,6 +262,16 @@ where
262262
NdArrayTensor::from_data(data)
263263
}
264264

265+
/// Cumulative sum of `tensor` along axis `dim`, preserving shape.
pub fn cumsum(tensor: NdArrayTensor<E>, dim: usize) -> NdArrayTensor<E> {
    // `into_owned` detaches from any shared storage so the scan can mutate in
    // place (a no-op copy when the array is already uniquely owned).
    let mut array = tensor.array.into_owned();
    // ndarray's in-place scan: each element becomes itself plus its
    // predecessor along the axis, i.e. a running sum.
    array.accumulate_axis_inplace(Axis(dim), |&prev, curr| {
        *curr += prev;
    });
    // Back to the shared (CoW) representation the tensor type expects.
    let array = array.into_shared();

    NdArrayTensor { array }
}
274+
265275
pub fn mean_dim(tensor: NdArrayTensor<E>, dim: usize) -> NdArrayTensor<E> {
266276
let ndims = tensor.shape().num_dims();
267277
match ndims {

crates/burn-ndarray/src/ops/int_tensor.rs

+4
Original file line numberDiff line numberDiff line change
@@ -351,4 +351,8 @@ impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> IntTensorOps
351351
fn int_expand(tensor: NdArrayTensor<I>, shape: Shape) -> NdArrayTensor<I> {
352352
NdArrayOps::expand(tensor, shape)
353353
}
354+
355+
fn int_cumsum(tensor: NdArrayTensor<I>, dim: usize) -> NdArrayTensor<I> {
356+
NdArrayMathOps::cumsum(tensor, dim)
357+
}
354358
}

crates/burn-ndarray/src/ops/tensor.rs

+4
Original file line numberDiff line numberDiff line change
@@ -575,4 +575,8 @@ impl<E: FloatNdArrayElement, I: IntNdArrayElement, Q: QuantElement> FloatTensorO
575575
_ => panic!("Invalid cast types"),
576576
}
577577
}
578+
579+
fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
580+
execute_with_float_dtype!(tensor, |tensor| NdArrayMathOps::cumsum(tensor, dim))
581+
}
578582
}

crates/burn-router/src/ops/op_float.rs

+19
Original file line numberDiff line numberDiff line change
@@ -1491,4 +1491,23 @@ impl<R: RunnerChannel> FloatTensorOps<Self> for BackendRouter<R> {
14911491

14921492
out
14931493
}
1494+
1495+
fn float_cumsum(tensor: FloatTensor<Self>, dim: usize) -> FloatTensor<Self> {
1496+
let client = tensor.client.clone();
1497+
let dtype = tensor.dtype;
1498+
let out = client.register_empty_tensor(tensor.shape.clone(), dtype);
1499+
1500+
let desc = ScalarOperationDescription {
1501+
lhs: tensor.into_description(),
1502+
rhs: dim,
1503+
out: out.to_description_out(),
1504+
};
1505+
1506+
client.register(OperationDescription::NumericFloat(
1507+
dtype,
1508+
NumericOperationDescription::CumSum(desc),
1509+
));
1510+
1511+
out
1512+
}
14941513
}

crates/burn-router/src/ops/op_int.rs

+19
Original file line numberDiff line numberDiff line change
@@ -1173,4 +1173,23 @@ impl<R: RunnerChannel> IntTensorOps<Self> for BackendRouter<R> {
11731173

11741174
out
11751175
}
1176+
1177+
fn int_cumsum(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
1178+
let client = tensor.client.clone();
1179+
let dtype = tensor.dtype;
1180+
let out = client.register_empty_tensor(tensor.shape.clone(), dtype);
1181+
1182+
let desc = ScalarOperationDescription {
1183+
lhs: tensor.into_description(),
1184+
rhs: dim,
1185+
out: out.to_description_out(),
1186+
};
1187+
1188+
client.register(OperationDescription::NumericInt(
1189+
dtype,
1190+
NumericOperationDescription::CumSum(desc),
1191+
));
1192+
1193+
out
1194+
}
11761195
}

crates/burn-router/src/runner.rs

+6
Original file line numberDiff line numberDiff line change
@@ -573,6 +573,9 @@ impl<B: ReprBackend> RunnerClient for Runner<B> {
573573
NumericOperationDescription::Powf(desc) => {
574574
binary_float_ops!(handles, desc, B::float_powf)
575575
}
576+
NumericOperationDescription::CumSum(desc) => {
577+
scalar_float_dim_ops!(handles, desc, B::float_cumsum)
578+
}
576579
},
577580
OperationDescription::NumericInt(_dtype, op) => match op {
578581
NumericOperationDescription::Add(desc) => {
@@ -764,6 +767,9 @@ impl<B: ReprBackend> RunnerClient for Runner<B> {
764767
let output = B::int_powf(lhs, rhs);
765768
handles.register_int_tensor::<B>(&desc.out.id, output);
766769
}
770+
NumericOperationDescription::CumSum(desc) => {
771+
scalar_int_dim_ops!(handles, desc, B::int_cumsum)
772+
}
767773
},
768774
OperationDescription::Bool(op) => match op {
769775
BoolOperationDescription::IntoFloat(desc) => {

crates/burn-tch/src/ops/base.rs

+7
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,13 @@ impl TchOps {
299299
TchTensor::new(tensor)
300300
}
301301

302+
/// Cumulative sum along `dim`, delegating to `torch.cumsum`.
pub fn cumsum(tensor: TchTensor, dim: usize) -> TchTensor {
    // Passing the input's own `kind()` keeps the output dtype identical to
    // the input (torch otherwise promotes integer inputs to int64).
    // `from_existing` carries the original storage reference along —
    // presumably so the backend can track aliasing/in-place reuse.
    TchTensor::from_existing(
        tensor.tensor.cumsum(dim as i64, tensor.tensor.kind()),
        tensor.storage,
    )
}
308+
302309
pub fn prod_dim(tensor: TchTensor, dim: usize) -> TchTensor {
303310
TchTensor::from_existing(
304311
tensor

crates/burn-tch/src/ops/int_tensor.rs

+4
Original file line numberDiff line numberDiff line change
@@ -416,4 +416,8 @@ impl<E: TchElement, Q: QuantElement> IntTensorOps<Self> for LibTorch<E, Q> {
416416
fn int_argsort(tensor: IntTensor<Self>, dim: usize, descending: bool) -> IntTensor<Self> {
417417
TchOps::argsort(tensor, dim, descending)
418418
}
419+
420+
fn int_cumsum(tensor: IntTensor<Self>, dim: usize) -> IntTensor<Self> {
    // Delegate to the shared tch implementation, which preserves the int dtype.
    TchOps::cumsum(tensor, dim)
}
419423
}

crates/burn-tch/src/ops/tensor.rs

+4
Original file line numberDiff line numberDiff line change
@@ -479,4 +479,8 @@ impl<E: TchElement, Q: QuantElement> FloatTensorOps<Self> for LibTorch<E, Q> {
479479
TchTensor::new(tensor.tensor.to_kind(kind))
480480
}
481481
}
482+
483+
fn float_cumsum(tensor: TchTensor, dim: usize) -> TchTensor {
484+
TchOps::cumsum(tensor, dim)
485+
}
482486
}

0 commit comments

Comments
 (0)