@@ -2,20 +2,20 @@ use alloc::{vec, vec::Vec};
22use burn_tensor:: backend:: Backend ;
33use core:: ops:: Range ;
44
5+ use crate :: { BackendRouter , RunnerChannel , RunnerClient , get_client} ;
56use burn_ir:: {
67 BaseOperationIr , BinaryOpIr , CatOpIr , ClampOpIr , ExpandOpIr , FlipOpIr , FloatOperationIr ,
78 GatherOpIr , InitOperationIr , MaskFillOpIr , MaskWhereOpIr , NumericOperationIr , OperationIr ,
89 PermuteOpIr , RandomOpIr , ReduceDimOpIr , ReduceDimWithIndicesOpIr , RepeatDimOpIr , ScalarIr ,
910 ScalarOpIr , ScatterOpIr , SelectAssignOpIr , SelectOpIr , SliceAssignOpIr , SliceOpIr ,
10- SwapDimsOpIr , UnaryOpIr ,
11+ SwapDimsOpIr , UnaryOpIr , UnfoldOpIr ,
1112} ;
13+ use burn_tensor:: ops:: unfold:: calculate_unfold_windows;
1214use burn_tensor:: ops:: {
1315 BoolTensor , FloatElem , FloatTensor , FloatTensorOps , IntElem , IntTensor , binary_ops_shape,
1416} ;
1517use burn_tensor:: { Device , Distribution , Element , FloatDType , Shape , TensorData , TensorMetadata } ;
1618
17- use crate :: { BackendRouter , RunnerChannel , RunnerClient , get_client} ;
18-
1919impl < R : RunnerChannel > FloatTensorOps < Self > for BackendRouter < R > {
2020 fn float_from_data ( data : TensorData , device : & Device < Self > ) -> FloatTensor < Self > {
2121 let client = get_client :: < R > ( device) ;
@@ -1436,4 +1436,33 @@ impl<R: RunnerChannel> FloatTensorOps<Self> for BackendRouter<R> {
14361436
14371437 out
14381438 }
1439+
1440+ fn float_unfold (
1441+ tensor : FloatTensor < Self > ,
1442+ dim : usize ,
1443+ size : usize ,
1444+ step : usize ,
1445+ ) -> FloatTensor < Self > {
1446+ let client = tensor. client . clone ( ) ;
1447+
1448+ let mut shape = tensor. shape ( ) . dims . clone ( ) ;
1449+ let d_shape = shape[ dim] ;
1450+ let windows = calculate_unfold_windows ( d_shape, size, step) ;
1451+ shape[ dim] = windows;
1452+ shape. push ( size) ;
1453+
1454+ let out = client. register_empty_tensor ( shape. clone ( ) , tensor. dtype ) ;
1455+
1456+ let desc = UnfoldOpIr {
1457+ input : tensor. into_ir ( ) ,
1458+ out : out. to_ir_out ( ) ,
1459+ dim,
1460+ size,
1461+ step,
1462+ } ;
1463+
1464+ client. register ( OperationIr :: BaseFloat ( BaseOperationIr :: Unfold ( desc) ) ) ;
1465+
1466+ out
1467+ }
14391468}
0 commit comments