// RUN: torch-mlir-opt %s -convert-torch-to-tosa -split-input-file -verify-diagnostics | FileCheck %s --check-prefix=CHECK

// CHECK-LABEL: func.func @torch.aten.tanh$basic(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
@@ -13,6 +13,80 @@ func.func @torch.aten.tanh$basic(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vte

// -----

// Checks that rank-1 operands are reshaped to the NHWC rank expected by
// tosa.conv2d before the convolution is emitted.
// CHECK-LABEL: func.func @conv2d_io_insert_reshape(
// CHECK: %[[SHAPE:.*]] = tosa.const_shape
// CHECK: %[[INPUT_ZP:.*]] = "tosa.const"
// CHECK: %[[WEIGHT_ZP:.*]] = "tosa.const"
// CHECK: %[[R0:.*]] = tosa.reshape %arg0, %[[SHAPE]]
// CHECK: %[[R1:.*]] = tosa.reshape %arg1, %[[SHAPE]]
// CHECK: %[[CONV:.*]] = tosa.conv2d %[[R0]], %[[R1]], %arg2, %[[INPUT_ZP]], %[[WEIGHT_ZP]]
func.func @conv2d_io_insert_reshape(%arg0: tensor<256xf32>, %arg1: tensor<256xf32>, %arg2: tensor<16xf32>) -> tensor<1x1x1x16xf32> {
  %shape = "tosa.const_shape"() {values = dense<[1, 1, 16, 16]> : tensor<4xindex>} : () -> !tosa.shape<4>
  %input_zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
  %weight_zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
  %r0 = "tosa.reshape"(%arg0, %shape) : (tensor<256xf32>, !tosa.shape<4>) -> tensor<1x1x16x16xf32>
  %r1 = "tosa.reshape"(%arg1, %shape) : (tensor<256xf32>, !tosa.shape<4>) -> tensor<1x1x16x16xf32>
  %conv = "tosa.conv2d"(%r0, %r1, %arg2, %input_zp, %weight_zp) {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>, acc_type = f32} : (tensor<1x1x16x16xf32>, tensor<1x1x16x16xf32>, tensor<16xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x1x1x16xf32>
  return %conv : tensor<1x1x1x16xf32>
}
32+
// Checks that input and weight get distinct reshape target shapes
// (NHWC input vs. HWCM weight layout) ahead of tosa.depthwise_conv2d.
// CHECK-LABEL: func.func @depthwise_conv2d_io_insert_reshape(
// CHECK: %[[SHAPE:.*]] = tosa.const_shape
// CHECK: %[[WSHAPE:.*]] = tosa.const_shape
// CHECK: %[[INPUT_ZP:.*]] = "tosa.const"
// CHECK: %[[WEIGHT_ZP:.*]] = "tosa.const"
// CHECK: %[[R0:.*]] = tosa.reshape %arg0, %[[SHAPE]]
// CHECK: %[[R1:.*]] = tosa.reshape %arg1, %[[WSHAPE]]
// CHECK: %[[CONV:.*]] = tosa.depthwise_conv2d %[[R0]], %[[R1]], %arg2, %[[INPUT_ZP]], %[[WEIGHT_ZP]]
func.func @depthwise_conv2d_io_insert_reshape(%arg0: tensor<9xf32>, %arg1: tensor<9xf32>, %arg2: tensor<1xf32>) -> tensor<1x1x1x1xf32> {
  %shape = "tosa.const_shape"() {values = dense<[1, 3, 3, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
  %wshape = "tosa.const_shape"() {values = dense<[3, 3, 1, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
  %input_zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
  %weight_zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
  %r0 = "tosa.reshape"(%arg0, %shape) : (tensor<9xf32>, !tosa.shape<4>) -> tensor<1x3x3x1xf32>
  %r1 = "tosa.reshape"(%arg1, %wshape) : (tensor<9xf32>, !tosa.shape<4>) -> tensor<3x3x1x1xf32>
  %conv = "tosa.depthwise_conv2d"(%r0, %r1, %arg2, %input_zp, %weight_zp) {pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, dilation = array<i64: 1, 1>, acc_type = f32} : (tensor<1x3x3x1xf32>, tensor<3x3x1x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x1x1x1xf32>
  return %conv : tensor<1x1x1x1xf32>
}
51+
// Checks reshape insertion on both operands of tosa.transpose_conv2d.
// NOTE(review): the generic form below carries `pad`/`dilation` attributes in
// addition to `out_pad`/`stride`; preserved as authored — confirm against the
// tosa.transpose_conv2d op definition for the targeted TOSA version.
// CHECK-LABEL: func.func @transpose_conv2d_io_insert_reshape(
// CHECK: %[[SHAPE:.*]] = tosa.const_shape
// CHECK: %[[WSHAPE:.*]] = tosa.const_shape
// CHECK: %[[INPUT_ZP:.*]] = "tosa.const"
// CHECK: %[[WEIGHT_ZP:.*]] = "tosa.const"
// CHECK: %[[R0:.*]] = tosa.reshape %arg0, %[[SHAPE]]
// CHECK: %[[R1:.*]] = tosa.reshape %arg1, %[[WSHAPE]]
// CHECK: %[[CONV:.*]] = tosa.transpose_conv2d %[[R0]], %[[R1]], %arg2, %[[INPUT_ZP]], %[[WEIGHT_ZP]]
func.func @transpose_conv2d_io_insert_reshape(%arg0: tensor<9xf32>, %arg1: tensor<9xf32>, %arg2: tensor<1xf32>) -> tensor<1x5x5x1xf32> {
  %shape = "tosa.const_shape"() {values = dense<[1, 3, 3, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
  %wshape = "tosa.const_shape"() {values = dense<[1, 3, 3, 1]> : tensor<4xindex>} : () -> !tosa.shape<4>
  %input_zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
  %weight_zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
  %r0 = "tosa.reshape"(%arg0, %shape) : (tensor<9xf32>, !tosa.shape<4>) -> tensor<1x3x3x1xf32>
  %r1 = "tosa.reshape"(%arg1, %wshape) : (tensor<9xf32>, !tosa.shape<4>) -> tensor<1x3x3x1xf32>
  %conv = "tosa.transpose_conv2d"(%r0, %r1, %arg2, %input_zp, %weight_zp) {out_pad = array<i64: 0, 0, 0, 0>, stride = array<i64: 1, 1>, acc_type = f32, dilation = array<i64: 1, 1>, pad = array<i64: 0, 0, 0, 0>} : (tensor<1x3x3x1xf32>, tensor<1x3x3x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x5x5x1xf32>
  return %conv : tensor<1x5x5x1xf32>
}
70+
// Checks rank-5 (NDHWC) reshape insertion ahead of tosa.conv3d.
// CHECK-LABEL: func.func @conv3d_io_insert_reshape(
// CHECK: %[[SHAPE:.*]] = tosa.const_shape
// CHECK: %[[WSHAPE:.*]] = tosa.const_shape
// CHECK: %[[INPUT_ZP:.*]] = "tosa.const"
// CHECK: %[[WEIGHT_ZP:.*]] = "tosa.const"
// CHECK: %[[R0:.*]] = tosa.reshape %arg0, %[[SHAPE]]
// CHECK: %[[R1:.*]] = tosa.reshape %arg1, %[[WSHAPE]]
// CHECK: %[[CONV:.*]] = tosa.conv3d %[[R0]], %[[R1]], %arg2, %[[INPUT_ZP]], %[[WEIGHT_ZP]]
func.func @conv3d_io_insert_reshape(%arg0: tensor<64xf32>, %arg1: tensor<1xf32>, %arg2: tensor<1xf32>) -> tensor<1x1x4x4x4xf32> {
  %shape = "tosa.const_shape"() {values = dense<[1, 1, 4, 4, 4]> : tensor<5xindex>} : () -> !tosa.shape<5>
  %wshape = "tosa.const_shape"() {values = dense<[1, 1, 1, 1, 1]> : tensor<5xindex>} : () -> !tosa.shape<5>
  %input_zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
  %weight_zp = "tosa.const"() {values = dense<0.0> : tensor<1xf32>} : () -> tensor<1xf32>
  %r0 = "tosa.reshape"(%arg0, %shape) : (tensor<64xf32>, !tosa.shape<5>) -> tensor<1x1x4x4x4xf32>
  %r1 = "tosa.reshape"(%arg1, %wshape) : (tensor<1xf32>, !tosa.shape<5>) -> tensor<1x1x1x1x1xf32>
  %conv = "tosa.conv3d"(%r0, %r1, %arg2, %input_zp, %weight_zp) {pad = array<i64: 0, 0, 0, 0, 0, 0>, stride = array<i64: 1, 1, 1>, dilation = array<i64: 1, 1, 1>, acc_type = f32} : (tensor<1x1x4x4x4xf32>, tensor<1x1x1x1x1xf32>, tensor<1xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x1x4x4x4xf32>
  return %conv : tensor<1x1x4x4x4xf32>
}
89+
// CHECK-LABEL: func.func @torch.aten.sigmoid$basic(
// CHECK-SAME: %[[ARG:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[ARG_BUILTIN:.*]] = torch_c.to_builtin_tensor %[[ARG]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
@@ -2417,8 +2491,7 @@ func.func @torch.aten.avg_pool2d.divisor_override_unsupported_value(%arg0: !torc
24172491 %0 = torch.prim.ListConstruct %int3 , %int3 : (!torch.int , !torch.int ) -> !torch.list <int >
24182492 %1 = torch.prim.ListConstruct %int1 , %int1 : (!torch.int , !torch.int ) -> !torch.list <int >
24192493 %2 = torch.prim.ListConstruct %int1 , %int1 : (!torch.int , !torch.int ) -> !torch.list <int >
2420- // expected-error @+1 {{failed to legalize operation 'torch.aten.avg_pool2d' that was explicitly marked illegal}}
2421- %3 = torch.aten.avg_pool2d %arg0 , %0 , %1 , %2 , %false , %count_include_pad , %divisor_override : !torch.vtensor <[1 ,192 ,35 ,35 ],f32 >, !torch.list <int >, !torch.list <int >, !torch.list <int >, !torch.bool , !torch.bool , !torch.int -> !torch.vtensor <[1 ,192 ,35 ,35 ],f32 >
2494+ %3 = torch.aten.avg_pool2d %arg0 , %0 , %1 , %2 , %false , %count_include_pad , %divisor_override : !torch.vtensor <[1 ,192 ,35 ,35 ],f32 >, !torch.list <int >, !torch.list <int >, !torch.list <int >, !torch.bool , !torch.bool , !torch.int -> !torch.vtensor <[1 ,192 ,35 ,35 ],f32 > // expected-error {{failed to legalize operation 'torch.aten.avg_pool2d' that was explicitly marked illegal}}
24222495 return %3 : !torch.vtensor <[1 ,192 ,35 ,35 ],f32 >
24232496}
24242497
@@ -2664,8 +2737,7 @@ func.func @torch.aten.index.Tensor_hacked_twin(%arg0: !torch.vtensor<[2,4,2],si6
26642737
26652738func.func @torch.aten.index.Tensor_hacked_twin.dynamic_size (%arg0: !torch.vtensor <[?,4 ],f32 >, %arg1: !torch.vtensor <[?,1 ],si64 >, %arg2: !torch.vtensor <[1 ,4 ],si64 >) -> !torch.vtensor <[?,4 ],f32 > attributes {torch.assume_strict_symbolic_shapes } {
26662739 %0 = torch.prim.ListConstruct %arg1 , %arg2 : (!torch.vtensor <[?,1 ],si64 >, !torch.vtensor <[1 ,4 ],si64 >) -> !torch.list <vtensor >
2667- // expected-error @+1 {{failed to legalize operation 'torch.aten.index.Tensor_hacked_twin' that was explicitly marked illegal}}
2668- %1 = torch.aten.index.Tensor_hacked_twin %arg0 , %0 : !torch.vtensor <[?,4 ],f32 >, !torch.list <vtensor > -> !torch.vtensor <[?,4 ],f32 >
2740+ %1 = torch.aten.index.Tensor_hacked_twin %arg0 , %0 : !torch.vtensor <[?,4 ],f32 >, !torch.list <vtensor > -> !torch.vtensor <[?,4 ],f32 > // expected-error {{failed to legalize operation 'torch.aten.index.Tensor_hacked_twin' that was explicitly marked illegal}}
26692741 return %1 : !torch.vtensor <[?,4 ],f32 >
26702742}
26712743
@@ -4552,8 +4624,7 @@ func.func @torch.aten.empty.memory_format() -> !torch.vtensor<[1,0,256],f32>{
45524624 %none = torch.constant.none
45534625 %cpu = torch.constant.device " cpu"
45544626 %false = torch.constant.bool false
4555- // expected-error @below {{failed to legalize operation 'torch.aten.empty.memory_format' that was explicitly marked illegal}}
4556- %out = torch.aten.empty.memory_format %2452 , %none , %none , %cpu , %false , %none : !torch.list <int >, !torch.none , !torch.none , !torch.Device , !torch.bool , !torch.none -> !torch.vtensor <[1 ,0 ,256 ],f32 >
4627+ %out = torch.aten.empty.memory_format %2452 , %none , %none , %cpu , %false , %none : !torch.list <int >, !torch.none , !torch.none , !torch.Device , !torch.bool , !torch.none -> !torch.vtensor <[1 ,0 ,256 ],f32 > // expected-error {{failed to legalize operation 'torch.aten.empty.memory_format' that was explicitly marked illegal}}
45574628 return %out : !torch.vtensor <[1 ,0 ,256 ],f32 >
45584629}
45594630
0 commit comments