Skip to content

Commit ca60cab

Browse files
Fixes:
- Update reduction test after llvm/llvm-project@5a221c3.

Reverts:
- Dropped local revert of llvm/llvm-project#169614 due to #22649.

---------

Signed-off-by: Muzammiluddin Syed <muzasyed@amd.com>
1 parent 75e6a7a commit ca60cab

2 files changed

Lines changed: 13 additions & 6 deletions

File tree

compiler/src/iree/compiler/Codegen/LLVMGPU/test/reduction_pipeline_cuda.mlir

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,8 @@ hal.executable.variant @cuda target(<"cuda", "cuda-nvptx-fb">) {
4141
// CHECK-DAG: %[[CST_ACC:.+]] = arith.constant dense<0.000000e+00> : vector<1x1x1xf32>
4242
// CHECK-DAG: gpu.thread_id x
4343
// CHECK: %[[R0:.+]] = scf.for %{{.*}} = %c0 to %c2560 step %c256 iter_args(%[[A0:.+]] = %[[CST_ACC]]) -> (vector<1x1x1xf32>) {
44-
// CHECK: %[[V:.+]] = vector.transfer_read {{.*}} : memref<512x10240xf32, {{.*}}>, vector<1x4xf32>
44+
// CHECK: memref.expand_shape {{.*}} : memref<1x1024xf32, {{.*}}> into memref<1x256x4xf32, {{.*}}>
45+
// CHECK: %[[V:.+]] = vector.transfer_read {{.*}} : memref<1x256x4xf32, {{.*}}>, vector<1x4xf32>
4546
// CHECK: %[[STRIDED:.+]] = vector.insert_strided_slice %[[V]], {{.*}} : vector<1x4xf32> into vector<1x1x1x1x1x4xf32>
4647
// CHECK: %[[REDUCE:.+]] = vector.multi_reduction <add>, %[[STRIDED]], %[[CST_ACC]] [1, 3, 5] : vector<1x1x1x1x1x4xf32> to vector<1x1x1xf32>
4748
// CHECK: %[[ADD:.+]] = arith.addf %[[REDUCE]], %[[A0]] : vector<1x1x1xf32>
@@ -104,7 +105,8 @@ hal.executable.variant @cuda target(<"cuda", "cuda-nvptx-fb">) {
104105
// CHECK: func.func @warp_reduction_broadcast_dispatch()
105106
// CHECK-SAME: translation_info = #[[TRANSLATION_INFO]]
106107
// CHECK: scf.for {{.*}} -> (vector<1x1x1xf32>) {
107-
// CHECK: vector.transfer_read {{.*}} : memref<512x10240xf32, {{.*}}>, vector<1x4xf32>
108+
// CHECK: memref.expand_shape {{.*}} : memref<1x1024xf32, {{.*}}> into memref<1x256x4xf32, {{.*}}>
109+
// CHECK: vector.transfer_read {{.*}} : memref<1x256x4xf32, {{.*}}>, vector<1x4xf32>
108110
// CHECK: vector.multi_reduction <add>, {{.*}} [1, 3, 5] : vector<1x1x1x1x1x4xf32> to vector<1x1x1xf32>
109111
// CHECK: arith.addf {{.*}} : vector<1x1x1xf32>
110112
// CHECK: scf.yield
@@ -144,7 +146,8 @@ hal.executable.variant @cuda target(<"cuda", "cuda-nvptx-fb">) {
144146
// CHECK: func.func @softmax()
145147
// CHECK-SAME: translation_info = #[[TRANSLATION_INFO]]
146148
// CHECK: scf.for {{.*}} -> (vector<1x1x1xf32>) {
147-
// CHECK: vector.transfer_read {{.*}} : memref<12x128x40960xf32, {{.*}}>, vector<1x4xf32>
149+
// CHECK: memref.expand_shape {{.*}} : memref<1x1x4096xf32, {{.*}}> into memref<1x1x1024x4xf32, {{.*}}>
150+
// CHECK: vector.transfer_read {{.*}} : memref<1x1x1024x4xf32, {{.*}}>, vector<1x4xf32>
148151
// CHECK: vector.multi_reduction <maxnumf>, {{.*}} {{.*}} : vector<1x1x1x1x1x4xf32> to vector<1x1x1xf32>
149152
// CHECK: arith.maxnumf {{.*}} : vector<1x1x1xf32>
150153
// CHECK: scf.yield
@@ -201,7 +204,8 @@ hal.executable.variant @cuda target(<"cuda", "cuda-nvptx-fb">) {
201204
// CHECK: func.func @softmax_singlesubgroup()
202205
// CHECK-SAME: translation_info = #[[TRANSLATION_INFO]]
203206
// CHECK: scf.for {{.*}} -> (vector<1x1x1xf32>) {
204-
// CHECK: vector.transfer_read {{.*}} : memref<12x256x40960xf32, {{.*}}>, vector<1x4xf32>
207+
// CHECK: memref.expand_shape {{.*}} : memref<1x1x128xf32, {{.*}}> into memref<1x1x32x4xf32, {{.*}}>
208+
// CHECK: vector.transfer_read {{.*}} : memref<1x1x32x4xf32, {{.*}}>, vector<1x4xf32>
205209
// CHECK: vector.multi_reduction <maxnumf>, {{.*}} {{.*}} : vector<1x1x1x1x1x4xf32> to vector<1x1x1xf32>
206210
// CHECK: arith.maxnumf {{.*}} : vector<1x1x1xf32>
207211
// CHECK: scf.yield
@@ -518,7 +522,10 @@ hal.executable private @i4_dequant_matvec {
518522
// CHECK-DAG: %[[C1:.+]] = arith.constant 1 : index
519523
// CHECK-DAG: %[[CST:.+]] = arith.constant dense<0.000000e+00> : vector<1x1x1xf16>
520524
// CHECK: scf.for %{{.+}} = %[[C0]] to %[[C32]] step %[[C1]] iter_args(%{{.*}} = %[[CST]]) -> (vector<1x1x1xf16>)
521-
// CHECK: vector.transfer_read {{.*}} : memref<4096x32x128xi4, {{.*}}>, vector<1x4xi4>
525+
// CHECK: memref.expand_shape {{.*}} : memref<1x128xf16, {{.*}}> into memref<1x32x4xf16, {{.*}}>
526+
// CHECK: memref.expand_shape {{.*}} : memref<1x1x128xi4, {{.*}}> into memref<1x1x32x4xi4, {{.*}}>
527+
// CHECK: vector.transfer_read {{.*}} : memref<1x32x4xf16, {{.*}}>, vector<1x4xf16>
528+
// CHECK: vector.transfer_read {{.*}} : memref<1x1x32x4xi4, {{.*}}>, vector<1x4xi4>
522529
// CHECK: arith.extui %{{.*}} : vector<1x1x1x1x1x4xi4> to vector<1x1x1x1x1x4xi32>
523530
// CHECK: arith.uitofp %{{.*}} : vector<1x1x1x1x1x4xi32> to vector<1x1x1x1x1x4xf16>
524531
// CHECK: arith.subf %{{.*}}, %{{.*}} : vector<1x1x1x1x1x4xf16>

third_party/llvm-project

Submodule llvm-project updated 1430 files

0 commit comments

Comments (0)