Commit 6ce3b8b

Expand unfold impl.

- Switched to PyTorch's return shape.
- Added burn-router support.
- Exposed the unfold calculation module.
- ndarray and candle both need either upstream support or work-arounds;
  candle has a PR in flight (from me): huggingface/candle#3091
1 parent 730567e commit 6ce3b8b
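For context on the PyTorch return shape: `Tensor.unfold(dim, size, step)` in PyTorch replaces the unfolded dimension with the number of windows and appends a new dimension of length `size` as the right-most axis. A minimal sketch of what this means for callers, assuming a public `Tensor::unfold(dim, size, step)` method that mirrors the backend ops below (the surface-level method is not part of this diff):

use burn_tensor::Tensor;
use burn_tensor::backend::Backend;

// Hypothetical shape check; `Tensor::unfold` is assumed to mirror the
// backend `*_unfold` ops added in this commit.
fn unfold_shape_demo<B: Backend>(device: &B::Device) {
    let x = Tensor::<B, 3>::zeros([2, 5, 4], device);

    // Along dim 1: windows = (5 - 3) / 1 + 1 = 3, and the window
    // elements land in a new trailing axis of length `size = 3`.
    let y = x.unfold(1, 3, 1);
    assert_eq!(y.dims(), [2, 3, 4, 3]); // new: [pre, windows, post, size]
    // The previous behavior inserted the window axis at dim + 1 instead,
    // which would have produced [2, 3, 3, 4].
}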

File tree

16 files changed: +251 −52 lines
crates/burn-candle/src/ops/base.rs

Lines changed: 6 additions & 1 deletion
@@ -1,7 +1,8 @@
+use std::cmp::max;
 use std::marker::PhantomData;
 
 use burn_tensor::{Element, Shape, TensorData, TensorMetadata, backend::Backend};
-use candle_core::WithDType;
+use candle_core::{Layout, WithDType};
 use half::{bf16, f16};
 
 use crate::{
@@ -133,6 +134,10 @@ pub fn expand(tensor: CandleTensor, shape: Shape) -> CandleTensor {
     CandleTensor::new(tensor.tensor.broadcast_as(shape.dims).unwrap())
 }
 
+pub fn unfold(tensor: CandleTensor, dim: usize, size: usize, step: usize) -> CandleTensor {
+    todo!()
+}
+
 pub fn sign(tensor: CandleTensor) -> CandleTensor {
     CandleTensor::new(tensor.tensor.sign().unwrap())
 }
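Until the upstream candle change lands, one possible interim work-around (an illustrative sketch, not this commit's code) is to materialize the windows with candle's existing `narrow`, `permute`, and `stack`, at the cost of copying instead of returning a stride view:

use candle_core::{Result, Tensor};

// Copying fallback sketch: build each window with `narrow`, move the
// window axis to the back, then stack the windows at `dim`. The empty
// `windows == 0` case is ignored here for brevity.
fn unfold_copying(t: &Tensor, dim: usize, size: usize, step: usize) -> Result<Tensor> {
    let d = t.dims()[dim];
    let windows = if size > d { 0 } else { (d - size) / step + 1 };

    // Permutation that moves `dim` to the right-most position, so each
    // slice has shape [pre, post, size].
    let mut perm: Vec<usize> = (0..t.rank()).filter(|&a| a != dim).collect();
    perm.push(dim);

    let slices = (0..windows)
        .map(|w| t.narrow(dim, w * step, size)?.permute(perm.clone()))
        .collect::<Result<Vec<_>>>()?;

    // Stacking at `dim` restores the axis order: [pre, windows, post, size].
    Tensor::stack(&slices, dim)
}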

crates/burn-cubecl/src/ops/base.rs

Lines changed: 9 additions & 5 deletions
@@ -1,11 +1,11 @@
 use crate::{CubeRuntime, element::CubeElement, kernel, tensor::CubeTensor};
 use burn_common::tensor::{ReshapeAction, reshape_action};
+use burn_tensor::ops::unfold::calculate_unfold_windows;
 use burn_tensor::{
     Shape, TensorData,
     quantization::{QTensorPrimitive, QuantLevel},
 };
 use cubecl::{server::CopyDescriptor, tensor_vectorization_factor};
-use std::cmp::max;
 
 pub(crate) fn from_data<R: CubeRuntime>(data: TensorData, device: &R::Device) -> CubeTensor<R> {
     let shape: Shape = (&data.shape).into();
@@ -222,6 +222,10 @@ pub(crate) fn max_line_size_many<R: CubeRuntime>(tensors: &[&CubeTensor<R>], dim
 ///
 /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.
 ///
+/// The new view will have the unfolded dimension replaced by two dimensions;
+/// one in the position of the original dimension, with size equal to the number of windows,
+/// and one appended to the right-most position, with size equal to `size`.
+///
 /// # Arguments
 ///
 /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``
@@ -231,7 +235,7 @@
 ///
 /// # Returns
 ///
-/// A tensor view with shape ``[pre=..., windows, size, post=...]``.
+/// A tensor view with the shape ``[pre=..., windows, post=..., size]``.
 pub fn unfold<R: CubeRuntime>(
     tensor: CubeTensor<R>,
     dim: usize,
@@ -241,15 +245,15 @@
     let d_shape = tensor.shape.dims[dim];
     let d_stride = tensor.strides[dim];
 
-    let windows = max(0, (d_shape - size).div_ceil(step));
+    let windows = calculate_unfold_windows(d_shape, size, step);
 
     let mut shape = tensor.shape.clone();
     shape.dims[dim] = windows;
-    shape.dims.insert(dim + 1, size);
+    shape.dims.push(size);
 
     let mut strides = tensor.strides.clone();
     strides[dim] = step * d_stride;
-    strides.insert(dim + 1, d_stride);
+    strides.push(d_stride);
 
     CubeTensor {
         shape,
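The unfold here is a zero-copy stride trick. A standalone sketch of the same metadata arithmetic with plain vectors (illustrative; the real code above operates on `CubeTensor` and delegates the window count to `calculate_unfold_windows`):

// Window `w`, element `i` of the result reads the input at offset
// `w * (step * stride[dim]) + i * stride[dim]`, so unfold is a zero-copy
// view: only shape and strides change.
fn unfold_view_meta(
    shape: &mut Vec<usize>,
    strides: &mut Vec<usize>,
    dim: usize,
    size: usize,
    step: usize,
) {
    let d_stride = strides[dim];
    // PyTorch-style window count (assumed behavior of the helper).
    let windows = if size > shape[dim] { 0 } else { (shape[dim] - size) / step + 1 };
    shape[dim] = windows;    // window count replaces the unfolded dim
    shape.push(size);        // window elements become a trailing axis
    strides[dim] *= step;    // successive windows start `step` elements apart
    strides.push(d_stride);  // within a window, the original stride applies
}

For a contiguous `[2, 5, 4]` tensor with strides `[20, 4, 1]`, unfolding dim 1 with `size = 3`, `step = 1` yields shape `[2, 3, 4, 3]` and strides `[20, 4, 1, 4]`.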

crates/burn-fusion/src/ops/boolean.rs

Lines changed: 10 additions & 10 deletions
@@ -1,22 +1,21 @@
+use crate::{
+    Fusion, FusionBackend,
+    client::FusionClient,
+    get_client,
+    stream::{OperationStreams, StreamId, execution::Operation},
+};
 use burn_ir::{
     BaseOperationIr, BinaryOpIr, BoolOperationIr, CatOpIr, ExpandOpIr, FlipOpIr, HandleContainer,
     InitOperationIr, OperationIr, PermuteOpIr, RepeatDimOpIr, SliceAssignOpIr, SliceOpIr,
     SwapDimsOpIr, TensorIr, UnaryOpIr, UnfoldOpIr,
 };
+use burn_tensor::ops::unfold::calculate_unfold_windows;
 use burn_tensor::{
     Device, Element, Shape, TensorData, TensorMetadata,
     ops::{BoolTensor, BoolTensorOps, FloatTensor, IntTensor, binary_ops_shape},
 };
-use std::cmp::max;
 use std::marker::PhantomData;
 
-use crate::{
-    Fusion, FusionBackend,
-    client::FusionClient,
-    get_client,
-    stream::{OperationStreams, StreamId, execution::Operation},
-};
-
 use super::NoOp;
 
 impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
@@ -777,9 +776,10 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> {
 
         let mut shape = tensor.shape().dims.clone();
         let d_shape = shape[dim];
-        let windows = max(0, (d_shape - size).div_ceil(step));
+        let windows = calculate_unfold_windows(d_shape, size, step);
+
         shape[dim] = windows;
-        shape.insert(dim + 1, size);
+        shape.push(size);
 
         let out = tensor
             .client
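All the fusion and router call sites now delegate the window count to the newly exposed `burn_tensor::ops::unfold::calculate_unfold_windows`, whose body is not included in this diff. A plausible stand-in, assuming PyTorch semantics for the window count:

// Hypothetical stand-in for `calculate_unfold_windows` (the real body is
// not shown in this diff): a window starts every `step` elements for as
// long as a full `size` window still fits in the dimension.
pub fn calculate_unfold_windows(d_shape: usize, size: usize, step: usize) -> usize {
    if size > d_shape {
        0 // saturate rather than underflow `d_shape - size`
    } else {
        (d_shape - size) / step + 1
    }
}

Whatever the exact formula, centralizing it fixes a latent issue in the replaced expression: on `usize`, `max(0, …)` is a no-op and `d_shape - size` underflows whenever `size > d_shape`.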

crates/burn-fusion/src/ops/float.rs

Lines changed: 5 additions & 3 deletions
@@ -9,11 +9,11 @@ use crate::{
     unary_float_ops,
 };
 use burn_ir::*;
+use burn_tensor::ops::unfold::calculate_unfold_windows;
 use burn_tensor::{
     Device, Distribution, Element, FloatDType, Shape, TensorData, TensorMetadata,
     ops::{BoolTensor, FloatElem, FloatTensor, FloatTensorOps, IntTensor, binary_ops_shape},
 };
-use std::cmp::max;
 use std::{marker::PhantomData, ops::Range};
 
 impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
@@ -2291,9 +2291,11 @@ impl<B: FusionBackend> FloatTensorOps<Self> for Fusion<B> {
 
         let mut shape = tensor.shape().dims.clone();
         let d_shape = shape[dim];
-        let windows = max(0, (d_shape - size).div_ceil(step));
+
+        let windows = calculate_unfold_windows(d_shape, size, step);
+
         shape[dim] = windows;
-        shape.insert(dim + 1, size);
+        shape.push(size);
 
         let out = tensor
             .client

crates/burn-fusion/src/ops/int.rs

Lines changed: 6 additions & 5 deletions
@@ -1,3 +1,4 @@
+use super::NoOp;
 use crate::{
     Fusion, FusionBackend, binary_int_cmp_ops, binary_int_ops,
     client::FusionClient,
@@ -6,16 +7,14 @@ use crate::{
     unary_int_ops,
 };
 use burn_ir::*;
+use burn_tensor::ops::unfold::calculate_unfold_windows;
 use burn_tensor::{
     Device, Distribution, Element, IntDType, Shape, TensorData, TensorMetadata,
     ops::{BoolTensor, FloatTensor, IntElem, IntTensor, IntTensorOps, binary_ops_shape},
 };
 use core::ops::Range;
-use std::cmp::max;
 use std::marker::PhantomData;
 
-use super::NoOp;
-
 impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
     fn int_empty(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
         #[derive(new, Debug)]
@@ -2204,9 +2203,11 @@ impl<B: FusionBackend> IntTensorOps<Self> for Fusion<B> {
 
         let mut shape = tensor.shape().dims.clone();
         let d_shape = shape[dim];
-        let windows = max(0, (d_shape - size).div_ceil(step));
+
+        let windows = calculate_unfold_windows(d_shape, size, step);
+
         shape[dim] = windows;
-        shape.insert(dim + 1, size);
+        shape.push(size);
 
         let out = tensor
             .client

crates/burn-ndarray/src/ops/base.rs

Lines changed: 1 addition & 1 deletion
@@ -177,7 +177,7 @@ where
 ///
 /// # Returns
 ///
-/// A tensor view with shape ``[pre=..., windows, size, post=...]``.
+/// A tensor view with shape ``[pre=..., windows, post=..., size]``.
 #[allow(unused)]
 pub(crate) fn unfold(
     tensor: SharedArray<E>,

crates/burn-router/src/ops/op_bool.rs

Lines changed: 32 additions & 2 deletions
@@ -1,14 +1,15 @@
 use alloc::vec::Vec;
 
+use crate::{BackendRouter, RunnerChannel, RunnerClient, get_client};
 use burn_ir::{
     BaseOperationIr, BinaryOpIr, BoolOperationIr, CatOpIr, ExpandOpIr, FlipOpIr, InitOperationIr,
     OperationIr, PermuteOpIr, RepeatDimOpIr, SliceAssignOpIr, SliceOpIr, SwapDimsOpIr, UnaryOpIr,
+    UnfoldOpIr,
 };
+use burn_tensor::ops::unfold::calculate_unfold_windows;
 use burn_tensor::ops::{BoolTensor, BoolTensorOps, FloatElem, FloatTensor, IntElem, IntTensor};
 use burn_tensor::{Device, Element, Shape, TensorData, TensorMetadata};
 
-use crate::{BackendRouter, RunnerChannel, RunnerClient, get_client};
-
 impl<R: RunnerChannel> BoolTensorOps<Self> for BackendRouter<R> {
     fn bool_empty(shape: Shape, device: &Device<Self>) -> BoolTensor<Self> {
         // Get the runtime client on which to register the operation for execution.
@@ -323,4 +324,33 @@ impl<R: RunnerChannel> BoolTensorOps<Self> for BackendRouter<R> {
 
         out
     }
+
+    fn bool_unfold(
+        tensor: BoolTensor<Self>,
+        dim: usize,
+        size: usize,
+        step: usize,
+    ) -> BoolTensor<Self> {
+        let client = tensor.client.clone();
+
+        let mut shape = tensor.shape().dims.clone();
+        let d_shape = shape[dim];
+        let windows = calculate_unfold_windows(d_shape, size, step);
+        shape[dim] = windows;
+        shape.push(size);
+
+        let out = client.register_empty_tensor(shape.clone(), tensor.dtype);
+
+        let desc = UnfoldOpIr {
+            input: tensor.into_ir(),
+            out: out.to_ir_out(),
+            dim,
+            size,
+            step,
+        };
+
+        client.register(OperationIr::BaseBool(BaseOperationIr::Unfold(desc)));
+
+        out
+    }
 }
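The router methods all build an `UnfoldOpIr` record. Its definition lives in `burn-ir` and is not part of this diff, but the fields can be read off the construction sites above (sketch for orientation only):

// Inferred shape of the IR node; the authoritative definition is in burn-ir.
pub struct UnfoldOpIr {
    pub input: TensorIr, // tensor being unfolded
    pub out: TensorIr,   // pre-registered output descriptor
    pub dim: usize,      // dimension to unfold
    pub size: usize,     // window length
    pub step: usize,     // stride between window starts
}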

crates/burn-router/src/ops/op_float.rs

Lines changed: 32 additions & 3 deletions
@@ -2,20 +2,20 @@ use alloc::{vec, vec::Vec};
 use burn_tensor::backend::Backend;
 use core::ops::Range;
 
+use crate::{BackendRouter, RunnerChannel, RunnerClient, get_client};
 use burn_ir::{
     BaseOperationIr, BinaryOpIr, CatOpIr, ClampOpIr, ExpandOpIr, FlipOpIr, FloatOperationIr,
     GatherOpIr, InitOperationIr, MaskFillOpIr, MaskWhereOpIr, NumericOperationIr, OperationIr,
     PermuteOpIr, RandomOpIr, ReduceDimOpIr, ReduceDimWithIndicesOpIr, RepeatDimOpIr, ScalarIr,
     ScalarOpIr, ScatterOpIr, SelectAssignOpIr, SelectOpIr, SliceAssignOpIr, SliceOpIr,
-    SwapDimsOpIr, UnaryOpIr,
+    SwapDimsOpIr, UnaryOpIr, UnfoldOpIr,
 };
+use burn_tensor::ops::unfold::calculate_unfold_windows;
 use burn_tensor::ops::{
     BoolTensor, FloatElem, FloatTensor, FloatTensorOps, IntElem, IntTensor, binary_ops_shape,
 };
 use burn_tensor::{Device, Distribution, Element, FloatDType, Shape, TensorData, TensorMetadata};
 
-use crate::{BackendRouter, RunnerChannel, RunnerClient, get_client};
-
 impl<R: RunnerChannel> FloatTensorOps<Self> for BackendRouter<R> {
     fn float_from_data(data: TensorData, device: &Device<Self>) -> FloatTensor<Self> {
         let client = get_client::<R>(device);
@@ -1436,4 +1436,33 @@ impl<R: RunnerChannel> FloatTensorOps<Self> for BackendRouter<R> {
 
         out
     }
+
+    fn float_unfold(
+        tensor: FloatTensor<Self>,
+        dim: usize,
+        size: usize,
+        step: usize,
+    ) -> FloatTensor<Self> {
+        let client = tensor.client.clone();
+
+        let mut shape = tensor.shape().dims.clone();
+        let d_shape = shape[dim];
+        let windows = calculate_unfold_windows(d_shape, size, step);
+        shape[dim] = windows;
+        shape.push(size);
+
+        let out = client.register_empty_tensor(shape.clone(), tensor.dtype);
+
+        let desc = UnfoldOpIr {
+            input: tensor.into_ir(),
+            out: out.to_ir_out(),
+            dim,
+            size,
+            step,
+        };
+
+        client.register(OperationIr::BaseFloat(BaseOperationIr::Unfold(desc)));
+
+        out
+    }
 }

crates/burn-router/src/ops/op_int.rs

Lines changed: 32 additions & 3 deletions
@@ -2,20 +2,20 @@ use alloc::{vec, vec::Vec};
 use burn_tensor::backend::Backend;
 use core::ops::Range;
 
+use crate::{BackendRouter, RunnerChannel, RunnerClient, get_client};
 use burn_ir::{
     BaseOperationIr, BinaryOpIr, CatOpIr, ClampOpIr, ExpandOpIr, FlipOpIr, GatherOpIr,
     InitOperationIr, IntOperationIr, MaskFillOpIr, MaskWhereOpIr, NumericOperationIr, OperationIr,
     PermuteOpIr, RandomOpIr, ReduceDimOpIr, ReduceDimWithIndicesOpIr, RepeatDimOpIr, ScalarIr,
     ScalarOpIr, ScatterOpIr, SelectAssignOpIr, SelectOpIr, SliceAssignOpIr, SliceOpIr,
-    SwapDimsOpIr, UnaryOpIr,
+    SwapDimsOpIr, UnaryOpIr, UnfoldOpIr,
 };
+use burn_tensor::ops::unfold::calculate_unfold_windows;
 use burn_tensor::ops::{
     BoolTensor, FloatElem, FloatTensor, IntElem, IntTensor, IntTensorOps, binary_ops_shape,
 };
 use burn_tensor::{Device, Distribution, Element, IntDType, Shape, TensorData, TensorMetadata};
 
-use crate::{BackendRouter, RunnerChannel, RunnerClient, get_client};
-
 impl<R: RunnerChannel> IntTensorOps<Self> for BackendRouter<R> {
     fn int_empty(shape: Shape, device: &Device<Self>, dtype: IntDType) -> IntTensor<Self> {
         // Get the runtime client on which to register the operation for execution.
@@ -1416,4 +1416,33 @@ impl<R: RunnerChannel> IntTensorOps<Self> for BackendRouter<R> {
 
         out
     }
+
+    fn int_unfold(
+        tensor: IntTensor<Self>,
+        dim: usize,
+        size: usize,
+        step: usize,
+    ) -> IntTensor<Self> {
+        let client = tensor.client.clone();
+
+        let mut shape = tensor.shape().dims.clone();
+        let d_shape = shape[dim];
+        let windows = calculate_unfold_windows(d_shape, size, step);
+        shape[dim] = windows;
+        shape.push(size);
+
+        let out = client.register_empty_tensor(shape.clone(), tensor.dtype);
+
+        let desc = UnfoldOpIr {
+            input: tensor.into_ir(),
+            out: out.to_ir_out(),
+            dim,
+            size,
+            step,
+        };
+
+        client.register(OperationIr::BaseInt(BaseOperationIr::Unfold(desc)));
+
+        out
+    }
 }

crates/burn-router/src/runner.rs

Lines changed: 18 additions & 0 deletions
@@ -186,6 +186,12 @@
                 let output = B::float_expand(tensor, desc.shape.clone().into());
                 handles.register_float_tensor::<B>(&desc.out.id, output);
             }
+            BaseOperationIr::Unfold(desc) => {
+                let tensor = handles.get_float_tensor::<B>(&desc.input);
+
+                let output = B::float_unfold(tensor, desc.dim, desc.size, desc.step);
+                handles.register_float_tensor::<B>(&desc.out.id, output);
+            }
             BaseOperationIr::Slice(desc) => {
                 let tensor = handles.get_float_tensor::<B>(&desc.tensor);
 
@@ -261,6 +267,12 @@
                 let output = B::int_expand(tensor, desc.shape.clone().into());
                 handles.register_int_tensor::<B>(&desc.out.id, output);
             }
+            BaseOperationIr::Unfold(desc) => {
+                let tensor = handles.get_int_tensor::<B>(&desc.input);
+
+                let output = B::int_unfold(tensor, desc.dim, desc.size, desc.step);
+                handles.register_int_tensor::<B>(&desc.out.id, output);
+            }
             BaseOperationIr::Slice(desc) => {
                 let tensor = handles.get_int_tensor::<B>(&desc.tensor);
 
@@ -332,6 +344,12 @@
                 let output = B::bool_expand(tensor, desc.shape.clone().into());
                 handles.register_bool_tensor::<B>(&desc.out.id, output);
             }
+            BaseOperationIr::Unfold(desc) => {
+                let tensor = handles.get_bool_tensor::<B>(&desc.input);
+
+                let output = B::bool_unfold(tensor, desc.dim, desc.size, desc.step);
+                handles.register_bool_tensor::<B>(&desc.out.id, output);
+            }
             BaseOperationIr::Slice(desc) => {
                 let tensor = handles.get_bool_tensor::<B>(&desc.tensor);
 