
Commit 3e73353
[WIP] towards pytorch.unfold()
1 parent c339df5

File tree: 23 files changed, +492 −14 lines
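The commit threads an `unfold` operation through the autodiff, cubecl, and fusion backends, aiming at the semantics of `torch.Tensor.unfold`: take every complete window of `size` elements along `dim`, advancing the window start by `step` each time. As a standalone sketch of those semantics (plain Rust; `unfold_1d` is a hypothetical helper, not part of the commit):

```rust
/// Reference semantics of `unfold` on one dimension: every complete window of
/// `size` elements, with starts at 0, step, 2 * step, ...
fn unfold_1d<T: Copy>(data: &[T], size: usize, step: usize) -> Vec<Vec<T>> {
    let windows = if data.len() < size {
        0
    } else {
        (data.len() - size) / step + 1
    };
    (0..windows)
        .map(|w| data[w * step..w * step + size].to_vec())
        .collect()
}

fn main() {
    // Mirrors torch.arange(1., 8.).unfold(0, 2, 2): windows [1,2], [3,4], [5,6];
    // the trailing element 7 does not form a complete window and is dropped.
    assert_eq!(
        unfold_1d(&[1, 2, 3, 4, 5, 6, 7], 2, 2),
        [[1, 2], [3, 4], [5, 6]]
    );
}
```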

crates/burn-autodiff/src/ops/bool_tensor.rs

Lines changed: 4 additions & 0 deletions
```diff
@@ -107,4 +107,8 @@ impl<B: Backend, C: CheckpointStrategy> BoolTensorOps<Self> for Autodiff<B, C> {
     fn bool_repeat_dim(tensor: BoolTensor<B>, dim: usize, times: usize) -> BoolTensor<B> {
         B::bool_repeat_dim(tensor, dim, times)
     }
+
+    fn bool_unfold(tensor: BoolTensor<Self>, dim: usize, size: usize, step: usize) -> BoolTensor<Self> {
+        B::bool_unfold(tensor, dim, size, step)
+    }
 }
```

crates/burn-autodiff/src/ops/int_tensor.rs

Lines changed: 4 additions & 0 deletions
```diff
@@ -377,4 +377,8 @@ impl<B: Backend, C: CheckpointStrategy> IntTensorOps<Self> for Autodiff<B, C> {
     fn int_cast(tensor: IntTensor<Self>, dtype: IntDType) -> IntTensor<Self> {
         B::int_cast(tensor, dtype)
     }
+
+    fn int_unfold(tensor: IntTensor<Self>, dim: usize, size: usize, step: usize) -> IntTensor<Self> {
+        B::int_unfold(tensor, dim, size, step)
+    }
 }
```

crates/burn-autodiff/src/ops/tensor.rs

Lines changed: 4 additions & 0 deletions
```diff
@@ -2592,6 +2592,10 @@ impl<B: Backend, C: CheckpointStrategy> FloatTensorOps<Self> for Autodiff<B, C>
 
     // TODO: Implement float_prod and float_sum
     // https://github.com/tracel-ai/burn/issues/1458
+
+    fn float_unfold(tensor: FloatTensor<Self>, dim: usize, size: usize, step: usize) -> FloatTensor<Self> {
+        AutodiffTensor::new(B::float_unfold(tensor.primitive, dim, size, step))
+    }
 }
 
 #[derive(Debug, Clone)]
```

crates/burn-cubecl/src/ops/base.rs

Lines changed: 51 additions & 0 deletions

```diff
@@ -213,3 +213,54 @@ pub(crate) fn max_line_size_many<R: CubeRuntime>(tensors: &[&CubeTensor<R>], dim
 
     vec.unwrap_or(0)
 }
+
+/// Unfold windows along a dimension.
+///
+/// Returns a view of the tensor containing all complete windows of size `size` in
+/// dimension `dim`, where consecutive windows are advanced by `step` elements.
+///
+/// The number of windows is `(shape[dim] - size) / step + 1`, or zero when
+/// `size > shape[dim]`.
+///
+/// # Arguments
+///
+/// * `tensor` - The input tensor to unfold, of shape `[pre=..., shape[dim], post=...]`.
+/// * `dim` - The dimension to unfold.
+/// * `size` - The size of each unfolded window.
+/// * `step` - The step between consecutive windows.
+///
+/// # Returns
+///
+/// A tensor view with shape `[pre=..., windows, size, post=...]`.
+pub fn unfold<R: CubeRuntime>(
+    tensor: CubeTensor<R>,
+    dim: usize,
+    size: usize,
+    step: usize,
+) -> CubeTensor<R> {
+    let d_shape = tensor.shape.dims[dim];
+    let d_stride = tensor.strides[dim];
+
+    // Number of complete windows; zero when the dimension is shorter than `size`.
+    let windows = if d_shape < size {
+        0
+    } else {
+        (d_shape - size) / step + 1
+    };
+
+    let mut shape = tensor.shape.clone();
+    shape.dims[dim] = windows;
+    shape.dims.insert(dim + 1, size);
+
+    // Each window starts `step` elements after the previous one, and elements within
+    // a window keep the original stride, so the result is a view (no copy).
+    let mut strides = tensor.strides.clone();
+    strides[dim] = step * d_stride;
+    strides.insert(dim + 1, d_stride);
+
+    CubeTensor {
+        shape,
+        strides,
+        ..tensor
+    }
+}
```
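Since only the shape and stride metadata change, `unfold` above is zero-copy. A quick check of the stride arithmetic on bare layout metadata (a hypothetical `Layout` stand-in for illustration, not the actual `CubeTensor` type):

```rust
/// Minimal stand-in for a tensor's layout metadata.
struct Layout {
    shape: Vec<usize>,
    strides: Vec<usize>,
}

/// Same shape/stride arithmetic as the `unfold` view above.
fn unfold_layout(mut l: Layout, dim: usize, size: usize, step: usize) -> Layout {
    let d_shape = l.shape[dim];
    let d_stride = l.strides[dim];
    l.shape[dim] = if d_shape < size { 0 } else { (d_shape - size) / step + 1 };
    l.shape.insert(dim + 1, size);
    l.strides[dim] = step * d_stride;
    l.strides.insert(dim + 1, d_stride);
    l
}

fn main() {
    // A contiguous [4, 6] tensor has strides [6, 1]. Unfolding dim 1 with
    // size = 3, step = 2 gives shape [4, 2, 3]: element (r, w, i) reads offset
    // r * 6 + w * 2 + i * 1 of the original buffer, so windows overlap at no cost.
    let l = Layout { shape: vec![4, 6], strides: vec![6, 1] };
    let u = unfold_layout(l, 1, 3, 2);
    assert_eq!(u.shape, [4, 2, 3]);
    assert_eq!(u.strides, [6, 2, 1]);
}
```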

crates/burn-cubecl/src/ops/bool_ops.rs

Lines changed: 5 additions & 1 deletion
```diff
@@ -7,7 +7,7 @@ use burn_tensor::ops::{BoolTensor, BoolTensorOps, Device, FloatTensor, IntTensor
 use burn_tensor::{Shape, TensorData};
 use std::ops::Range;
 
-use super::{expand, numeric, permute};
+use super::{expand, numeric, permute, unfold};
 
 impl<R, F, I, BT> BoolTensorOps<Self> for CubeBackend<R, F, I, BT>
 where
@@ -141,4 +141,8 @@ where
     fn bool_flip(tensor: BoolTensor<Self>, axes: &[usize]) -> BoolTensor<Self> {
         kernel::flip::<R, BT, BT>(tensor, axes)
     }
+
+    fn bool_unfold(tensor: BoolTensor<Self>, dim: usize, size: usize, step: usize) -> BoolTensor<Self> {
+        unfold(tensor, dim, size, step)
+    }
 }
```

crates/burn-cubecl/src/ops/float_ops.rs

Lines changed: 5 additions & 1 deletion
```diff
@@ -1,4 +1,4 @@
-use super::{expand, numeric, permute};
+use super::{expand, numeric, permute, unfold};
 use crate::kernel::prng::{random_bernoulli, random_normal, random_uniform};
 use crate::kernel::unary_basic::BasicFloatUnaryKind;
 use crate::kernel::{
@@ -702,4 +702,8 @@ where
             _ => unimplemented!("Unsupported floating point type cast"),
         }
     }
+
+    fn float_unfold(tensor: FloatTensor<Self>, dim: usize, size: usize, step: usize) -> FloatTensor<Self> {
+        unfold(tensor, dim, size, step)
+    }
 }
```

crates/burn-cubecl/src/ops/int_ops.rs

Lines changed: 5 additions & 1 deletion
```diff
@@ -1,6 +1,6 @@
 use self::unary_basic_int::BasicIntUnaryKind;
 
-use super::{expand, numeric, permute};
+use super::{expand, numeric, permute, unfold};
 use crate::{
     CubeBackend, CubeRuntime, FloatElement, IntElement,
     kernel::{
@@ -684,4 +684,8 @@ where
         }
         )
     }
+
+    fn int_unfold(tensor: IntTensor<Self>, dim: usize, size: usize, step: usize) -> IntTensor<Self> {
+        unfold(tensor, dim, size, step)
+    }
 }
```

crates/burn-fusion/src/ops/boolean.rs

Lines changed: 48 additions & 1 deletion

```diff
@@ -1,8 +1,8 @@
 use burn_ir::{
     BaseOperationIr, BinaryOpIr, BoolOperationIr, CatOpIr, ExpandOpIr, FlipOpIr, HandleContainer,
     InitOperationIr, OperationIr, PermuteOpIr, RepeatDimOpIr, SliceAssignOpIr, SliceOpIr,
-    SwapDimsOpIr, TensorIr, UnaryOpIr,
+    SwapDimsOpIr, TensorIr, UnaryOpIr, UnfoldOpIr,
 };
 use burn_tensor::{
     Device, Element, Shape, Slice, TensorData, TensorMetadata,
     ops::{BoolTensor, BoolTensorOps, FloatTensor, IntTensor, binary_ops_shape},
@@ -744,4 +744,51 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
 
         out
     }
+
+    fn bool_unfold(tensor: BoolTensor<Self>, dim: usize, size: usize, step: usize) -> BoolTensor<Self> {
+        #[derive(new, Debug)]
+        struct UnfoldOps<B: FusionBackend> {
+            desc: UnfoldOpIr,
+            _b: PhantomData<B>,
+        }
+
+        impl<B: FusionBackend> Operation<B::FusionRuntime> for UnfoldOps<B> {
+            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
+                let input = handles.get_bool_tensor::<B>(&self.desc.input);
+                let output = B::bool_unfold(input, self.desc.dim, self.desc.size, self.desc.step);
+
+                handles.register_bool_tensor::<B>(&self.desc.out.id, output);
+            }
+        }
+
+        let mut streams = OperationStreams::default();
+        streams.tensor(&tensor);
+
+        let mut shape = tensor.shape().dims.clone();
+        let d_shape = shape[dim];
+        // Number of complete windows; zero when the dimension is shorter than `size`.
+        let windows = if d_shape < size { 0 } else { (d_shape - size) / step + 1 };
+        shape[dim] = windows;
+        shape.insert(dim + 1, size);
+
+        let out = tensor
+            .client
+            .tensor_uninitialized(shape.clone(), tensor.dtype);
+
+        let desc = UnfoldOpIr {
+            input: tensor.into_ir(),
+            out: out.to_ir_out(),
+            dim,
+            size,
+            step,
+        };
+
+        out.client.register(
+            streams,
+            OperationIr::BaseBool(BaseOperationIr::Unfold(desc.clone())),
+            UnfoldOps::<B>::new(desc),
+        );
+
+        out
+    }
 }
```
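`float.rs` and `int.rs` below repeat this same pattern for the other dtypes: compute the output shape eagerly, describe the op as an `UnfoldOpIr`, and defer the actual `B::*_unfold` call until the fused stream executes. The `UnfoldOpIr` definition itself lives in `burn-ir` (among the 23 changed files, not shown in this view); judging from the fields used here, it is presumably along these lines:

```rust
use burn_ir::TensorIr;

/// Hypothetical reconstruction of the IR node added in `burn-ir`; the actual
/// definition (and whatever derives the other `*OpIr` types carry) is in a
/// changed file not displayed above.
#[derive(Clone, Debug)]
pub struct UnfoldOpIr {
    /// Input tensor to unfold.
    pub input: TensorIr,
    /// Output tensor, shape `[pre=..., windows, size, post=...]`.
    pub out: TensorIr,
    /// Dimension along which windows are taken.
    pub dim: usize,
    /// Window length.
    pub size: usize,
    /// Step between consecutive window starts.
    pub step: usize,
}
```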

crates/burn-fusion/src/ops/float.rs

Lines changed: 47 additions & 0 deletions

```diff
@@ -2259,4 +2259,51 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
 
         out
     }
+
+    fn float_unfold(tensor: FloatTensor<Self>, dim: usize, size: usize, step: usize) -> FloatTensor<Self> {
+        #[derive(new, Debug)]
+        struct UnfoldOps<B: FusionBackend> {
+            desc: UnfoldOpIr,
+            _b: PhantomData<B>,
+        }
+
+        impl<B: FusionBackend> Operation<B::FusionRuntime> for UnfoldOps<B> {
+            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
+                let input = handles.get_float_tensor::<B>(&self.desc.input);
+                let output = B::float_unfold(input, self.desc.dim, self.desc.size, self.desc.step);
+
+                handles.register_float_tensor::<B>(&self.desc.out.id, output);
+            }
+        }
+
+        let mut streams = OperationStreams::default();
+        streams.tensor(&tensor);
+
+        let mut shape = tensor.shape().dims.clone();
+        let d_shape = shape[dim];
+        // Number of complete windows; zero when the dimension is shorter than `size`.
+        let windows = if d_shape < size { 0 } else { (d_shape - size) / step + 1 };
+        shape[dim] = windows;
+        shape.insert(dim + 1, size);
+
+        let out = tensor
+            .client
+            .tensor_uninitialized(shape.clone(), tensor.dtype);
+
+        let desc = UnfoldOpIr {
+            input: tensor.into_ir(),
+            out: out.to_ir_out(),
+            dim,
+            size,
+            step,
+        };
+
+        out.client.register(
+            streams,
+            OperationIr::BaseFloat(BaseOperationIr::Unfold(desc.clone())),
+            UnfoldOps::<B>::new(desc),
+        );
+
+        out
+    }
 }
```

crates/burn-fusion/src/ops/int.rs

Lines changed: 47 additions & 0 deletions

```diff
@@ -2171,4 +2171,51 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
 
         out
     }
+
+    fn int_unfold(tensor: IntTensor<Self>, dim: usize, size: usize, step: usize) -> IntTensor<Self> {
+        #[derive(new, Debug)]
+        struct UnfoldOps<B: FusionBackend> {
+            desc: UnfoldOpIr,
+            _b: PhantomData<B>,
+        }
+
+        impl<B: FusionBackend> Operation<B::FusionRuntime> for UnfoldOps<B> {
+            fn execute(&self, handles: &mut HandleContainer<B::Handle>) {
+                let input = handles.get_int_tensor::<B>(&self.desc.input);
+                let output = B::int_unfold(input, self.desc.dim, self.desc.size, self.desc.step);
+
+                handles.register_int_tensor::<B>(&self.desc.out.id, output);
+            }
+        }
+
+        let mut streams = OperationStreams::default();
+        streams.tensor(&tensor);
+
+        let mut shape = tensor.shape().dims.clone();
+        let d_shape = shape[dim];
+        // Number of complete windows; zero when the dimension is shorter than `size`.
+        let windows = if d_shape < size { 0 } else { (d_shape - size) / step + 1 };
+        shape[dim] = windows;
+        shape.insert(dim + 1, size);
+
+        let out = tensor
+            .client
+            .tensor_uninitialized(shape.clone(), tensor.dtype);
+
+        let desc = UnfoldOpIr {
+            input: tensor.into_ir(),
+            out: out.to_ir_out(),
+            dim,
+            size,
+            step,
+        };
+
+        out.client.register(
+            streams,
+            OperationIr::BaseInt(BaseOperationIr::Unfold(desc.clone())),
+            UnfoldOps::<B>::new(desc),
+        );
+
+        out
+    }
 }
```
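One boundary worth pinning down in tests once the op lands everywhere: the window count `(shape[dim] - size) / step + 1` used throughout this commit. A few spot checks (plain Rust, hypothetical helper):

```rust
/// Window count used by the unfold implementations in this commit.
fn window_count(d_shape: usize, size: usize, step: usize) -> usize {
    if d_shape < size { 0 } else { (d_shape - size) / step + 1 }
}

fn main() {
    assert_eq!(window_count(7, 2, 1), 6); // dense windows: starts 0..=5
    assert_eq!(window_count(7, 2, 2), 3); // starts 0, 2, 4; element 6 is dropped
    assert_eq!(window_count(5, 5, 1), 1); // window spans the whole dimension
    assert_eq!(window_count(4, 5, 1), 0); // dimension shorter than the window
}
```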
