@@ -1349,6 +1349,47 @@ func.func @torch.aten.gather(%arg0: !torch.vtensor<[1,4,3],f32>, %arg1: !torch.v
13491349 return %0 : !torch.vtensor <[1 ,4 ,2 ],f32 >
13501350}
13511351
// -----
// CHECK-LABEL:   func.func @torch.aten.gather$bool(
// CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[1,4,3],i1>,
// CHECK-SAME:      %[[VAL_1:.*]]: !torch.vtensor<[1,4,2],si64>) -> !torch.vtensor<[1,4,2],i1> {
// CHECK:           %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[1,4,2],si64> -> tensor<1x4x2xi64>
// CHECK:           %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,4,3],i1> -> tensor<1x4x3xi1>
// CHECK:           %[[VAL_4:.*]] = torch.constant.int -1
// CHECK:           %[[VAL_5:.*]] = torch.constant.bool false
// CHECK:           %[[VAL_6:.*]] = tosa.cast %[[VAL_2]] : (tensor<1x4x2xi64>) -> tensor<1x4x2xi32>
// CHECK:           %[[VAL_7:.*]] = tosa.const_shape {values = dense<[1, 4, 2, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
// CHECK:           %[[VAL_8:.*]] = tosa.reshape %[[VAL_6]], %[[VAL_7]] : (tensor<1x4x2xi32>, !tosa.shape<4>) -> tensor<1x4x2x1xi32>
// CHECK:           %[[VAL_9:.*]] = "tosa.const"() <{values = dense<0> : tensor<1x4x2x1xi32>}> : () -> tensor<1x4x2x1xi32>
// CHECK:           %[[VAL_10:.*]] = "tosa.const"() <{values = dense<{{\[\[}}{{\[\[}}0], [0]], {{\[\[}}1], [1]], {{\[\[}}2], [2]], {{\[\[}}3], [3]]]]> : tensor<1x4x2x1xi32>}> : () -> tensor<1x4x2x1xi32>
// CHECK:           %[[VAL_11:.*]] = tosa.concat %[[VAL_9]], %[[VAL_10]], %[[VAL_8]] {axis = 3 : i32} : (tensor<1x4x2x1xi32>, tensor<1x4x2x1xi32>, tensor<1x4x2x1xi32>) -> tensor<1x4x2x3xi32>
// CHECK:           %[[VAL_12:.*]] = tosa.const_shape {values = dense<[1, 12, 1]> : tensor<3xindex>} : () -> !tosa.shape<3>
// CHECK:           %[[VAL_13:.*]] = tosa.reshape %[[VAL_3]], %[[VAL_12]] : (tensor<1x4x3xi1>, !tosa.shape<3>) -> tensor<1x12x1xi1>
// CHECK:           %[[VAL_14:.*]] = tosa.cast %[[VAL_13]] : (tensor<1x12x1xi1>) -> tensor<1x12x1xi8>
// CHECK:           %[[VAL_15:.*]] = tosa.const_shape {values = dense<[8, 3]> : tensor<2xindex>} : () -> !tosa.shape<2>
// CHECK:           %[[VAL_16:.*]] = tosa.reshape %[[VAL_11]], %[[VAL_15]] : (tensor<1x4x2x3xi32>, !tosa.shape<2>) -> tensor<8x3xi32>
// CHECK:           %[[VAL_17:.*]] = "tosa.const"() <{values = dense<[12, 3, 1]> : tensor<3xi32>}> : () -> tensor<3xi32>
// CHECK:           %[[VAL_18:.*]] = tosa.const_shape {values = dense<[1, 3]> : tensor<2xindex>} : () -> !tosa.shape<2>
// CHECK:           %[[VAL_19:.*]] = tosa.reshape %[[VAL_17]], %[[VAL_18]] : (tensor<3xi32>, !tosa.shape<2>) -> tensor<1x3xi32>
// CHECK:           %[[VAL_20:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
// CHECK:           %[[VAL_21:.*]] = tosa.mul %[[VAL_16]], %[[VAL_19]], %[[VAL_20]] : (tensor<8x3xi32>, tensor<1x3xi32>, tensor<1xi8>) -> tensor<8x3xi32>
// CHECK:           %[[VAL_22:.*]] = tosa.reduce_sum %[[VAL_21]] {axis = 1 : i32} : (tensor<8x3xi32>) -> tensor<8x1xi32>
// CHECK:           %[[VAL_23:.*]] = tosa.const_shape {values = dense<[1, 8]> : tensor<2xindex>} : () -> !tosa.shape<2>
// CHECK:           %[[VAL_24:.*]] = tosa.reshape %[[VAL_22]], %[[VAL_23]] : (tensor<8x1xi32>, !tosa.shape<2>) -> tensor<1x8xi32>
// CHECK:           %[[VAL_25:.*]] = tosa.gather %[[VAL_14]], %[[VAL_24]] : (tensor<1x12x1xi8>, tensor<1x8xi32>) -> tensor<1x8x1xi8>
// CHECK:           %[[VAL_26:.*]] = tosa.cast %[[VAL_25]] : (tensor<1x8x1xi8>) -> tensor<1x8x1xi1>
// CHECK:           %[[VAL_27:.*]] = tosa.const_shape {values = dense<[1, 4, 2]> : tensor<3xindex>} : () -> !tosa.shape<3>
// CHECK:           %[[VAL_28:.*]] = tosa.reshape %[[VAL_26]], %[[VAL_27]] : (tensor<1x8x1xi1>, !tosa.shape<3>) -> tensor<1x4x2xi1>
// CHECK:           %[[VAL_29:.*]] = torch_c.from_builtin_tensor %[[VAL_28]] : tensor<1x4x2xi1> -> !torch.vtensor<[1,4,2],i1>
// CHECK:           return %[[VAL_29]] : !torch.vtensor<[1,4,2],i1>
// CHECK:         }
// i1 (bool) gather: the lowering casts the boolean self tensor to i8 before
// tosa.gather (which has no i1 profile) and casts the gathered result back.
func.func @torch.aten.gather$bool(%arg0: !torch.vtensor<[1,4,3],i1>, %arg1: !torch.vtensor<[1,4,2],si64>) -> !torch.vtensor<[1,4,2],i1> {
  %int-1 = torch.constant.int -1
  %false = torch.constant.bool false
  %0 = torch.aten.gather %arg0, %int-1, %arg1, %false : !torch.vtensor<[1,4,3],i1>, !torch.int, !torch.vtensor<[1,4,2],si64>, !torch.bool -> !torch.vtensor<[1,4,2],i1>
  return %0 : !torch.vtensor<[1,4,2],i1>
}
1392+
13521393// -----
13531394// CHECK-LABEL: func.func @torch.aten.add$int(
13541395// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[2,2],si32>,
@@ -1422,6 +1463,25 @@ func.func @torch.aten.slice.negative_start(%arg0: !torch.vtensor<[4,65,256],f32>
14221463 return %0 : !torch.vtensor <[4 ,16 ,256 ],f32 >
14231464}
14241465
// -----
// CHECK-LABEL:   func.func @torch.aten.slice.bool_strided(
// CHECK-SAME:      %[[VAL_0:.*]]: !torch.vtensor<[1,64,1],i1>) -> !torch.vtensor<[1,32,1],i1> {
// CHECK:           %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[1,64,1],i1> -> tensor<1x64x1xi1>
// CHECK:           %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor<1x64x1xi1>) -> tensor<1x64x1xi8>
// CHECK:           %[[VAL_3:.*]] = tosa.gather %[[VAL_2]], %{{.*}} : (tensor<1x64x1xi8>, tensor<1x32xi32>) -> tensor<1x32x1xi8>
// CHECK:           %[[VAL_4:.*]] = tosa.cast %[[VAL_3]] : (tensor<1x32x1xi8>) -> tensor<1x32x1xi1>
// CHECK:           %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<1x32x1xi1> -> !torch.vtensor<[1,32,1],i1>
// CHECK:           return %[[VAL_5]] : !torch.vtensor<[1,32,1],i1>
// CHECK:         }
// Strided (step=2) slice of an i1 tensor: lowered via an i8 cast around
// tosa.gather, since the gather path has no i1 profile.
func.func @torch.aten.slice.bool_strided(%arg0: !torch.vtensor<[1,64,1],i1>) -> !torch.vtensor<[1,32,1],i1> {
  %int1 = torch.constant.int 1
  %int0 = torch.constant.int 0
  %int64 = torch.constant.int 64
  %int2 = torch.constant.int 2
  %0 = torch.aten.slice.Tensor %arg0, %int1, %int0, %int64, %int2 : !torch.vtensor<[1,64,1],i1>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[1,32,1],i1>
  return %0 : !torch.vtensor<[1,32,1],i1>
}
1484+
14251485// -----
14261486// CHECK-LABEL: func.func @torch.aten.clamp.min_none(
14271487// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[1,1,128,128],si64>) -> !torch.vtensor<[1,1,128,128],si64> {
0 commit comments