@@ -2,13 +2,15 @@ use alloc::{vec, vec::Vec};
22use burn_tensor:: backend:: Backend ;
33use core:: ops:: Range ;
44
5+ use crate :: { BackendRouter , RunnerChannel , RunnerClient , get_client} ;
56use burn_ir:: {
67 BaseOperationIr , BinaryOpIr , CatOpIr , ClampOpIr , ExpandOpIr , FlipOpIr , FloatOperationIr ,
78 GatherOpIr , InitOperationIr , MaskFillOpIr , MaskWhereOpIr , NumericOperationIr , OperationIr ,
89 PermuteOpIr , RandomOpIr , ReduceDimOpIr , ReduceDimWithIndicesOpIr , RepeatDimOpIr , ScalarIr ,
910 ScalarOpIr , ScatterOpIr , SelectAssignOpIr , SelectOpIr , SliceAssignOpIr , SliceOpIr ,
10- SwapDimsOpIr , UnaryOpIr ,
11+ SwapDimsOpIr , UnaryOpIr , UnfoldOpIr ,
1112} ;
13+ use burn_tensor:: ops:: unfold:: calculate_unfold_windows;
1214use burn_tensor:: ops:: {
1315 BoolTensor , FloatElem , FloatTensor , FloatTensorOps , IntElem , IntTensor , binary_ops_shape,
1416} ;
@@ -17,8 +19,6 @@ use burn_tensor::{
1719 calculate_slice_output_shape,
1820} ;
1921
20- use crate :: { BackendRouter , RunnerChannel , RunnerClient , get_client} ;
21-
2222impl < R : RunnerChannel > FloatTensorOps < Self > for BackendRouter < R > {
2323 fn float_from_data ( data : TensorData , device : & Device < Self > ) -> FloatTensor < Self > {
2424 let client = get_client :: < R > ( device) ;
@@ -1434,4 +1434,33 @@ impl<R: RunnerChannel> FloatTensorOps<Self> for BackendRouter<R> {
14341434
14351435 out
14361436 }
1437+
1438+ fn float_unfold (
1439+ tensor : FloatTensor < Self > ,
1440+ dim : usize ,
1441+ size : usize ,
1442+ step : usize ,
1443+ ) -> FloatTensor < Self > {
1444+ let client = tensor. client . clone ( ) ;
1445+
1446+ let mut shape = tensor. shape ( ) . dims . clone ( ) ;
1447+ let d_shape = shape[ dim] ;
1448+ let windows = calculate_unfold_windows ( d_shape, size, step) ;
1449+ shape[ dim] = windows;
1450+ shape. push ( size) ;
1451+
1452+ let out = client. register_empty_tensor ( shape. clone ( ) , tensor. dtype ) ;
1453+
1454+ let desc = UnfoldOpIr {
1455+ input : tensor. into_ir ( ) ,
1456+ out : out. to_ir_out ( ) ,
1457+ dim,
1458+ size,
1459+ step,
1460+ } ;
1461+
1462+ client. register ( OperationIr :: BaseFloat ( BaseOperationIr :: Unfold ( desc) ) ) ;
1463+
1464+ out
1465+ }
14371466}
0 commit comments