@@ -1242,6 +1242,143 @@ def opset_13(cls, graph, node, **kw):
12421242 outputs = node .output ('Out' ))
12431243
12441244
1245+ @op_mapper ('unfold' )
1246+ class Unfold ():
1247+ support_opset_version_range = (11 , 15 )
1248+
1249+ @classmethod
1250+ def opset_11 (cls , graph , node , ** kw ):
1251+
1252+ strides = node .attr ('strides' )
1253+ stride_h = strides [0 ]
1254+ stride_w = strides [1 ]
1255+
1256+ paddings = node .attr ('paddings' )
1257+ padding_h_1 = paddings [0 ]
1258+ padding_w_1 = paddings [1 ]
1259+ padding_h_2 = paddings [2 ]
1260+ padding_w_2 = paddings [3 ]
1261+
1262+ dilations = node .attr ('dilations' )
1263+ dilation_h = dilations [0 ]
1264+ dilation_w = dilations [1 ]
1265+
1266+ kernel_sizes = node .attr ('kernel_sizes' )
1267+ kernel_h = kernel_sizes [0 ]
1268+ kernel_w = kernel_sizes [1 ]
1269+
1270+ input_w = mapper_helper .shape_helper (graph , node .input ('X' , 0 ), 3 )
1271+ blocks_row_indices_node = cls ._get_im2col_indices_along_dim (
1272+ graph , node , 2 , kernel_h , dilation_h , padding_h_1 , padding_h_2 ,
1273+ stride_h )
1274+ blocks_col_indices_node = cls ._get_im2col_indices_along_dim (
1275+ graph , node , 3 , kernel_w , dilation_w , padding_w_1 , padding_w_2 ,
1276+ stride_w )
1277+
1278+ output_shape = cls ._get_im2col_output_shape (graph , node , kernel_h ,
1279+ kernel_w )
1280+ padded_input = cls ._get_im2col_padded_input (
1281+ graph , node , padding_h_1 , padding_h_2 , padding_w_1 , padding_w_2 )
1282+
1283+ output = graph .make_node (
1284+ 'Gather' , inputs = [padded_input , blocks_row_indices_node ], axis = 2 )
1285+
1286+ output = graph .make_node (
1287+ 'Gather' , inputs = [output , blocks_col_indices_node ], axis = 4 )
1288+ output = graph .make_node (
1289+ 'Transpose' , inputs = [output ], perm = [0 , 1 , 2 , 4 , 3 , 5 ])
1290+
1291+ graph .make_node (
1292+ 'Reshape' , inputs = [output , output_shape ], outputs = node .output ('Y' ))
1293+
1294+ @classmethod
1295+ def _get_im2col_indices_along_dim (cls , graph , node , index , kernel_size_d ,
1296+ dilation_d , padding_d_1 , padding_d_2 ,
1297+ stride_d ):
1298+ input_shape = node .input_shape ('X' , 0 )
1299+ if input_shape [index ] == - 1 :
1300+ input_d_node = mapper_helper .shape_helper (graph ,
1301+ node .input ('X' , 0 ), index )
1302+
1303+ padding_d_node = graph .make_node (
1304+ 'Constant' ,
1305+ dtype = dtypes .ONNX .INT64 ,
1306+ value = [padding_d_1 + padding_d_2 ])
1307+ blocks_d_node = graph .make_node (
1308+ 'Add' , inputs = [input_d_node , padding_d_node ])
1309+
1310+ dilation_kernel_size_node = graph .make_node (
1311+ 'Constant' ,
1312+ dtype = dtypes .ONNX .INT64 ,
1313+ value = [dilation_d * (kernel_size_d - 1 )])
1314+ blocks_d_node = graph .make_node (
1315+ 'Sub' , inputs = [blocks_d_node , dilation_kernel_size_node ])
1316+
1317+ zero_node = graph .make_node (
1318+ 'Constant' , dtype = dtypes .ONNX .INT64 , value = [0 ])
1319+ stride_node = graph .make_node (
1320+ 'Constant' , dtype = dtypes .ONNX .INT64 , value = [stride_d ])
1321+ blocks_d_indices_node = graph .make_node (
1322+ 'Range' , inputs = [zero_node , blocks_d_node , stride_node ])
1323+ else :
1324+ end = input_shape [
1325+ index ] + padding_d_1 + padding_d_2 - dilation_d * (kernel_size_d
1326+ - 1 )
1327+ stride = stride_d
1328+ blocks_d_indices = np .arange (0 , end , stride )
1329+ blocks_d_indices_node = graph .make_node (
1330+ 'Constant' ,
1331+ dtype = dtypes .ONNX .INT64 ,
1332+ value = blocks_d_indices .flatten ().tolist ())
1333+
1334+ kernel_grid = np .arange (0 , kernel_size_d * dilation_d , dilation_d )
1335+ kernel_grid_node = graph .make_node (
1336+ 'Constant' ,
1337+ dtype = dtypes .ONNX .INT64 ,
1338+ value = kernel_grid .flatten ().tolist ())
1339+
1340+ shape_node = graph .make_node (
1341+ 'Constant' , dtype = dtypes .ONNX .INT64 , value = [- 1 , 1 ])
1342+ kernel_mask_node = graph .make_node (
1343+ 'Reshape' , inputs = [kernel_grid_node , shape_node ])
1344+
1345+ block_mask_node = graph .make_node (
1346+ 'Add' , inputs = [blocks_d_indices_node , kernel_mask_node ])
1347+ return block_mask_node
1348+
1349+ @classmethod
1350+ def _get_im2col_output_shape (cls , graph , node , kernel_h , kernel_w ):
1351+ batch_dim = mapper_helper .shape_helper (graph , node .input ('X' , 0 ), 0 )
1352+ channel_dim = mapper_helper .shape_helper (graph , node .input ('X' , 0 ), 1 )
1353+
1354+ constant_node = graph .make_node (
1355+ 'Constant' , dtype = dtypes .ONNX .INT64 , value = [kernel_h * kernel_w ])
1356+ channel_unfolded = graph .make_node (
1357+ 'Mul' , inputs = [channel_dim , constant_node ])
1358+
1359+ concat_const_node = graph .make_node (
1360+ 'Constant' , dtype = dtypes .ONNX .INT64 , value = [- 1 ])
1361+ result_node = graph .make_node (
1362+ 'Concat' ,
1363+ inputs = [batch_dim , channel_unfolded , concat_const_node ],
1364+ axis = 0 )
1365+
1366+ return result_node
1367+
1368+ @classmethod
1369+ def _get_im2col_padded_input (cls , graph , node , padding_h_1 , padding_h_2 ,
1370+ padding_w_1 , padding_w_2 ):
1371+ pad_const_node = graph .make_node (
1372+ 'Constant' ,
1373+ dtype = dtypes .ONNX .INT64 ,
1374+ value = [
1375+ 0 , 0 , padding_h_1 , padding_w_1 , 0 , 0 , padding_h_2 , padding_w_2
1376+ ])
1377+ result_node = graph .make_node (
1378+ 'Pad' , inputs = [node .input ('X' , 0 ), pad_const_node ])
1379+ return result_node
1380+
1381+
12451382@op_mapper ('softmax_with_cross_entropy' )
12461383class SoftmaxCrossEntropyLoss ():
12471384 support_opset_version_range = (12 , 15 )
0 commit comments