Commit b59efe0

Optimize unfold4d implementation for zero-padding and unit-dilation cases. Update imports.
1 parent eba688f commit b59efe0

File tree

1 file changed: +29 -7 lines

  • crates/burn-tensor/src/tensor/ops/modules

crates/burn-tensor/src/tensor/ops/modules/base.rs

Lines changed: 29 additions & 7 deletions
@@ -1,12 +1,12 @@
-use alloc::vec;
-use core::num::NonZeroUsize;
-
-use super::{conv, pool, unfold::unfold4d_using_conv2d};
+use super::{conv, pool};
+use crate::ops::unfold::unfold4d_using_conv2d;
 use crate::{
     Shape, TensorMetadata,
     backend::Backend,
     ops::{FloatTensor, IntTensor},
 };
+use alloc::vec;
+use core::num::NonZeroUsize;
 
 /// Gradient computed during the backward pass for each tensor used by [conv2d](ModuleOps::conv2d).
 #[derive(new)]
@@ -579,14 +579,36 @@ pub trait ModuleOps<B: Backend> {
     ///
     /// # Shapes
     ///
-    /// x: `[batch_size, channels_in, height, width]`,
-    /// returns: `[batch_size, channels_in * kernel_size_1 * kernel_size_2, number of blocks]`,
+    /// * x: ``[batch_size, channels_in, height, width]``,
+    /// * returns: ``[batch_size, channels_in * kernel_size_1 * kernel_size_2, number of blocks]``,
     fn unfold4d(
         x: FloatTensor<B>,
         kernel_size: [usize; 2],
         options: UnfoldOptions,
     ) -> FloatTensor<B> {
-        unfold4d_using_conv2d::<B>(x, kernel_size, options)
+        if options.padding == [0, 0] && options.dilation == [1, 1] {
+            let blocks = B::float_unfold(x, 2, kernel_size[0], options.stride[0]);
+            let blocks = B::float_unfold(blocks, 3, kernel_size[1], options.stride[1]);
+
+            // batch, channels, h_blocks, w_blocks, h_kern, w_kern
+
+            let blocks = B::float_permute(blocks, &[0, 1, 4, 5, 2, 3]);
+            let shape = &blocks.shape().dims;
+
+            // batch, channels, h_kern, w_kern, h_blocks, w_blocks
+
+            B::float_reshape(
+                blocks,
+                [
+                    shape[0],
+                    shape[1] * shape[2] * shape[3],
+                    shape[4] * shape[5],
+                ]
+                .into(),
+            )
+        } else {
+            unfold4d_using_conv2d::<B>(x, kernel_size, options)
+        }
     }
 
     /// One dimensional avg pooling.
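
For context on what the new branch does: when padding is [0, 0] and dilation is [1, 1], two sliding-window unfolds along the height and width dimensions yield a [batch, channels, h_blocks, w_blocks, h_kern, w_kern] view, and the permute plus reshape rearrange it into the documented [batch_size, channels_in * kernel_size_1 * kernel_size_2, number of blocks] layout without running a convolution. The snippet below is a minimal standalone sketch of that shape logic in plain Rust, assuming float_unfold(x, dim, size, step) follows PyTorch-style Tensor::unfold semantics (a new trailing dimension of length size, sliding over dim with the given step); the helper unfold4d_naive and its loop-based layout are illustrative only, not part of burn's API.

// Minimal sketch, not burn code: a loop-based reference for the fast path's
// output layout under padding = [0, 0] and dilation = [1, 1].
// `unfold4d_naive` is a hypothetical helper used only for illustration.

/// input [b, c, h, w] (row-major) -> output [b, c * k0 * k1, l_h * l_w].
fn unfold4d_naive(
    x: &[f32],
    (b, c, h, w): (usize, usize, usize, usize),
    [k0, k1]: [usize; 2],
    [s0, s1]: [usize; 2],
) -> (Vec<f32>, [usize; 3]) {
    // Blocks per spatial dim for zero padding and unit dilation.
    let (l_h, l_w) = ((h - k0) / s0 + 1, (w - k1) / s1 + 1);
    let mut out = vec![0.0; b * c * k0 * k1 * l_h * l_w];
    for bi in 0..b {
        for ci in 0..c {
            for (ki, kj) in (0..k0).flat_map(|i| (0..k1).map(move |j| (i, j))) {
                for bh in 0..l_h {
                    for bw in 0..l_w {
                        // Row ordering matches the permute to
                        // [batch, channels, h_kern, w_kern, h_blocks, w_blocks]
                        // followed by the reshape in the diff above.
                        let row = (ci * k0 + ki) * k1 + kj;
                        let col = bh * l_w + bw;
                        let src = ((bi * c + ci) * h + bh * s0 + ki) * w + (bw * s1 + kj);
                        out[(bi * c * k0 * k1 + row) * (l_h * l_w) + col] = x[src];
                    }
                }
            }
        }
    }
    (out, [b, c * k0 * k1, l_h * l_w])
}

fn main() {
    // Worked example: x = [1, 2, 4, 4], kernel = [2, 2], stride = [1, 1]
    // -> 3 * 3 = 9 blocks and 2 * 2 * 2 = 8 rows, i.e. output shape [1, 8, 9].
    let (b, c, h, w) = (1, 2, 4, 4);
    let x: Vec<f32> = (0..(b * c * h * w)).map(|v| v as f32).collect();
    let (_, shape) = unfold4d_naive(&x, (b, c, h, w), [2, 2], [1, 1]);
    assert_eq!(shape, [1, 8, 9]);
}

Here (h - k0) / s0 + 1 blocks per spatial dimension is the zero-padding, unit-dilation special case of the usual (dim + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1 formula; anything outside that case still goes through unfold4d_using_conv2d. The payoff of the fast path is that it avoids materializing convolution weights, and depending on the backend the two unfolds can be cheap windowed views, leaving only the final reshape to move data.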
