|
1 | | -use alloc::vec; |
2 | | -use core::num::NonZeroUsize; |
3 | | - |
4 | | -use super::{conv, pool, unfold::unfold4d_using_conv2d}; |
| 1 | +use super::{conv, pool}; |
| 2 | +use crate::ops::unfold::unfold4d_using_conv2d; |
5 | 3 | use crate::{ |
6 | 4 | Shape, TensorMetadata, |
7 | 5 | backend::Backend, |
8 | 6 | ops::{FloatTensor, IntTensor}, |
9 | 7 | }; |
| 8 | +use alloc::vec; |
| 9 | +use core::num::NonZeroUsize; |
10 | 10 |
|
11 | 11 | /// Gradient computed during the backward pass for each tensor used by [conv2d](ModuleOps::conv2d). |
12 | 12 | #[derive(new)] |
@@ -579,14 +579,36 @@ pub trait ModuleOps<B: Backend> { |
579 | 579 | /// |
580 | 580 | /// # Shapes |
581 | 581 | /// |
582 | | - /// x: `[batch_size, channels_in, height, width]`, |
583 | | - /// returns: `[batch_size, channels_in * kernel_size_1 * kernel_size_2, number of blocks]`, |
| 582 | + /// * x: ``[batch_size, channels_in, height, width]``, |
| 583 | + /// * returns: ``[batch_size, channels_in * kernel_size_1 * kernel_size_2, number of blocks]``, |
584 | 584 | fn unfold4d( |
585 | 585 | x: FloatTensor<B>, |
586 | 586 | kernel_size: [usize; 2], |
587 | 587 | options: UnfoldOptions, |
588 | 588 | ) -> FloatTensor<B> { |
589 | | - unfold4d_using_conv2d::<B>(x, kernel_size, options) |
| | + // Fast path: with zero padding and unit dilation, unfold is expressible as
| | + // two view-unfolds + permute + reshape, avoiding the conv2d-based fallback.
| 589 | + if options.padding == [0, 0] && options.dilation == [1, 1] { |
| 590 | + let blocks = B::float_unfold(x, 2, kernel_size[0], options.stride[0]); |
| 591 | + let blocks = B::float_unfold(blocks, 3, kernel_size[1], options.stride[1]); |
| 592 | + |
| 593 | + // layout after the two unfolds: batch, channels, h_blocks, w_blocks, h_kern, w_kern |
| 594 | + |
| 595 | + let blocks = B::float_permute(blocks, &[0, 1, 4, 5, 2, 3]); |
| 596 | + let shape = &blocks.shape().dims; |
| 597 | + |
| 598 | + // layout after permute: batch, channels, h_kern, w_kern, h_blocks, w_blocks |
| 599 | + |
| | + // Collapse to the documented output shape:
| | + // [batch, channels_in * kernel_size_1 * kernel_size_2, number of blocks].
| 600 | + B::float_reshape( |
| 601 | + blocks, |
| 602 | + [ |
| 603 | + shape[0], |
| 604 | + shape[1] * shape[2] * shape[3], |
| 605 | + shape[4] * shape[5], |
| 606 | + ] |
| 607 | + .into(), |
| 608 | + ) |
| 609 | + } else { |
| | + // General case: non-zero padding or dilation > 1 still needs the
| | + // conv2d-with-identity-kernel implementation.
| 610 | + unfold4d_using_conv2d::<B>(x, kernel_size, options) |
| 611 | + } |
590 | 612 | } |
591 | 613 |
|
592 | 614 | /// One dimensional avg pooling. |
|
0 commit comments