|
1 | | -use burn_ir::{ |
2 | | - BaseOperationIr, BinaryOpIr, BoolOperationIr, CatOpIr, ExpandOpIr, FlipOpIr, HandleContainer, |
3 | | - InitOperationIr, OperationIr, PermuteOpIr, RepeatDimOpIr, SliceAssignOpIr, SliceOpIr, |
4 | | - SwapDimsOpIr, TensorIr, UnaryOpIr, |
5 | | -}; |
| 1 | +use std::cmp::max; |
| 2 | +use burn_ir::{BaseOperationIr, BinaryOpIr, BoolOperationIr, CatOpIr, ExpandOpIr, FlipOpIr, HandleContainer, InitOperationIr, OperationIr, PermuteOpIr, RepeatDimOpIr, SliceAssignOpIr, SliceOpIr, SwapDimsOpIr, TensorIr, UnaryOpIr, UnfoldOpIr}; |
6 | 3 | use burn_tensor::{ |
7 | 4 | Device, Element, Shape, Slice, TensorData, TensorMetadata, |
8 | 5 | ops::{BoolTensor, BoolTensorOps, FloatTensor, IntTensor, binary_ops_shape}, |
@@ -744,4 +741,54 @@ impl<B: FusionBackend> BoolTensorOps<Self> for Fusion<B> { |
744 | 741 |
|
745 | 742 | out |
746 | 743 | } |
| 744 | + |
| 745 | + fn bool_unfold(tensor: BoolTensor<Self>, dim: usize, size: usize, step: usize) -> BoolTensor<Self> { |
| 746 | + #[derive(new, Debug)] |
| 747 | + struct UnfoldOps<B: FusionBackend> { |
| 748 | + desc: UnfoldOpIr, |
| 749 | + _b: PhantomData<B>, |
| 750 | + } |
| 751 | + |
| 752 | + impl<B: FusionBackend> Operation<B::FusionRuntime> for UnfoldOps<B> { |
| 753 | + fn execute(&self, handles: &mut HandleContainer<B::Handle>) { |
| 754 | + let input = handles.get_bool_tensor::<B>(&self.desc.input); |
| 755 | + let output = B::bool_unfold( |
| 756 | + input, |
| 757 | + self.desc.dim, |
| 758 | + self.desc.size, |
| 759 | + self.desc.step); |
| 760 | + |
| 761 | + handles.register_bool_tensor::<B>(&self.desc.out.id, output); |
| 762 | + } |
| 763 | + } |
| 764 | + |
| 765 | + let mut streams = OperationStreams::default(); |
| 766 | + streams.tensor(&tensor); |
| 767 | + |
| 768 | + let mut shape = tensor.shape().dims.clone(); |
| 769 | + let d_shape = shape[dim]; |
| 770 | + let windows = max(0, (d_shape - size).div_ceil(step)); |
| 771 | + shape[dim] = windows; |
| 772 | + shape.insert(dim + 1, size); |
| 773 | + |
| 774 | + let out = tensor |
| 775 | + .client |
| 776 | + .tensor_uninitialized(shape.clone(), tensor.dtype); |
| 777 | + |
| 778 | + let desc = UnfoldOpIr { |
| 779 | + input: tensor.into_ir(), |
| 780 | + out: out.to_ir_out(), |
| 781 | + dim: dim, |
| 782 | + size: size, |
| 783 | + step: step, |
| 784 | + }; |
| 785 | + |
| 786 | + out.client.register( |
| 787 | + streams, |
| 788 | + OperationIr::BaseBool(BaseOperationIr::Unfold(desc.clone())), |
| 789 | + UnfoldOps::<B>::new(desc), |
| 790 | + ); |
| 791 | + |
| 792 | + out |
| 793 | + } |
747 | 794 | } |
0 commit comments