@@ -41,7 +41,8 @@ hal.executable.variant @cuda target(<"cuda", "cuda-nvptx-fb">) {
 // CHECK-DAG: %[[CST_ACC:.+]] = arith.constant dense<0.000000e+00> : vector<1x1x1xf32>
 // CHECK-DAG: gpu.thread_id x
 // CHECK: %[[R0:.+]] = scf.for %{{.*}} = %c0 to %c2560 step %c256 iter_args(%[[A0:.+]] = %[[CST_ACC]]) -> (vector<1x1x1xf32>) {
-// CHECK: %[[V:.+]] = vector.transfer_read {{.*}} : memref<512x10240xf32, {{.*}}>, vector<1x4xf32>
+// CHECK: memref.expand_shape {{.*}} : memref<1x1024xf32, {{.*}}> into memref<1x256x4xf32, {{.*}}>
+// CHECK: %[[V:.+]] = vector.transfer_read {{.*}} : memref<1x256x4xf32, {{.*}}>, vector<1x4xf32>
 // CHECK: %[[STRIDED:.+]] = vector.insert_strided_slice %[[V]], {{.*}} : vector<1x4xf32> into vector<1x1x1x1x1x4xf32>
 // CHECK: %[[REDUCE:.+]] = vector.multi_reduction <add>, %[[STRIDED]], %[[CST_ACC]] [1, 3, 5] : vector<1x1x1x1x1x4xf32> to vector<1x1x1xf32>
 // CHECK: %[[ADD:.+]] = arith.addf %[[REDUCE]], %[[A0]] : vector<1x1x1xf32>
@@ -104,7 +105,8 @@ hal.executable.variant @cuda target(<"cuda", "cuda-nvptx-fb">) {
 // CHECK: func.func @warp_reduction_broadcast_dispatch()
 // CHECK-SAME: translation_info = #[[TRANSLATION_INFO]]
 // CHECK: scf.for {{.*}} -> (vector<1x1x1xf32>) {
-// CHECK: vector.transfer_read {{.*}} : memref<512x10240xf32, {{.*}}>, vector<1x4xf32>
+// CHECK: memref.expand_shape {{.*}} : memref<1x1024xf32, {{.*}}> into memref<1x256x4xf32, {{.*}}>
+// CHECK: vector.transfer_read {{.*}} : memref<1x256x4xf32, {{.*}}>, vector<1x4xf32>
 // CHECK: vector.multi_reduction <add>, {{.*}} [1, 3, 5] : vector<1x1x1x1x1x4xf32> to vector<1x1x1xf32>
 // CHECK: arith.addf {{.*}} : vector<1x1x1xf32>
 // CHECK: scf.yield
@@ -144,7 +146,8 @@ hal.executable.variant @cuda target(<"cuda", "cuda-nvptx-fb">) {
 // CHECK: func.func @softmax()
 // CHECK-SAME: translation_info = #[[TRANSLATION_INFO]]
 // CHECK: scf.for {{.*}} -> (vector<1x1x1xf32>) {
-// CHECK: vector.transfer_read {{.*}} : memref<12x128x40960xf32, {{.*}}>, vector<1x4xf32>
+// CHECK: memref.expand_shape {{.*}} : memref<1x1x4096xf32, {{.*}}> into memref<1x1x1024x4xf32, {{.*}}>
+// CHECK: vector.transfer_read {{.*}} : memref<1x1x1024x4xf32, {{.*}}>, vector<1x4xf32>
 // CHECK: vector.multi_reduction <maxnumf>, {{.*}} {{.*}} : vector<1x1x1x1x1x4xf32> to vector<1x1x1xf32>
 // CHECK: arith.maxnumf {{.*}} : vector<1x1x1xf32>
 // CHECK: scf.yield
@@ -201,7 +204,8 @@ hal.executable.variant @cuda target(<"cuda", "cuda-nvptx-fb">) {
 // CHECK: func.func @softmax_singlesubgroup()
 // CHECK-SAME: translation_info = #[[TRANSLATION_INFO]]
 // CHECK: scf.for {{.*}} -> (vector<1x1x1xf32>) {
-// CHECK: vector.transfer_read {{.*}} : memref<12x256x40960xf32, {{.*}}>, vector<1x4xf32>
+// CHECK: memref.expand_shape {{.*}} : memref<1x1x128xf32, {{.*}}> into memref<1x1x32x4xf32, {{.*}}>
+// CHECK: vector.transfer_read {{.*}} : memref<1x1x32x4xf32, {{.*}}>, vector<1x4xf32>
 // CHECK: vector.multi_reduction <maxnumf>, {{.*}} {{.*}} : vector<1x1x1x1x1x4xf32> to vector<1x1x1xf32>
 // CHECK: arith.maxnumf {{.*}} : vector<1x1x1xf32>
 // CHECK: scf.yield
@@ -518,7 +522,10 @@ hal.executable private @i4_dequant_matvec {
 // CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
 // CHECK-DAG: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<1x1x1xf16>
 // CHECK: scf.for %{{.+}} = %[[C0]] to %[[C32]] step %[[C1]] iter_args(%{{.*}} = %[[CST]]) -> (vector<1x1x1xf16>)
-// CHECK: vector.transfer_read {{.*}} : memref<4096x32x128xi4, {{.*}}>, vector<1x4xi4>
+// CHECK: memref.expand_shape {{.*}} : memref<1x128xf16, {{.*}}> into memref<1x32x4xf16, {{.*}}>
+// CHECK: memref.expand_shape {{.*}} : memref<1x1x128xi4, {{.*}}> into memref<1x1x32x4xi4, {{.*}}>
+// CHECK: vector.transfer_read {{.*}} : memref<1x32x4xf16, {{.*}}>, vector<1x4xf16>
+// CHECK: vector.transfer_read {{.*}} : memref<1x1x32x4xi4, {{.*}}>, vector<1x4xi4>
 // CHECK: arith.extui %{{.*}} : vector<1x1x1x1x1x4xi4> to vector<1x1x1x1x1x4xi32>
 // CHECK: arith.uitofp %{{.*}} : vector<1x1x1x1x1x4xi32> to vector<1x1x1x1x1x4xf16>
 // CHECK: arith.subf %{{.*}}, %{{.*}} : vector<1x1x1x1x1x4xf16>