diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
index 1ccac17b77f6..45d54f8fc6a3 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp
@@ -1805,6 +1805,55 @@ static LogicalResult setSPIRVOpConfig(IREE::GPU::TargetAttr target,
 // Entry Point
 //===----------------------------------------------------------------------===//
 
+/// Find the root operation for the dispatch. The root is the op that will be
+/// tiled and distributed to workgroups; all other ops fuse with it as producers
+/// or consumers.
+///
+/// Priority (all passes iterate in reverse to prefer later ops):
+/// 1. Named ops (matmul, conv) or generics with reduction iterators.
+/// 2. Any generic op (elementwise).
+/// 3. Fill ops.
+static Operation *getRootOperation(ArrayRef<Operation *> computeOps) {
+  Operation *rootOperation = nullptr;
+
+  // Pass 1: named ops or generics with reductions.
+  for (Operation *op : llvm::reverse(computeOps)) {
+    if (auto genericOp = dyn_cast<linalg::GenericOp>(op)) {
+      if (genericOp.getNumLoops() != genericOp.getNumParallelLoops()) {
+        rootOperation = op;
+        break;
+      }
+      continue;
+    }
+    if (!isa<linalg::FillOp>(op) && isa<linalg::LinalgOp>(op)) {
+      rootOperation = op;
+      break;
+    }
+  }
+
+  // Pass 2: any generic op (elementwise).
+  if (!rootOperation) {
+    for (Operation *op : llvm::reverse(computeOps)) {
+      if (isa<linalg::GenericOp>(op)) {
+        rootOperation = op;
+        break;
+      }
+    }
+  }
+
+  // Pass 3: fill ops.
+  if (!rootOperation) {
+    for (Operation *op : llvm::reverse(computeOps)) {
+      if (isa<linalg::FillOp>(op)) {
+        rootOperation = op;
+        break;
+      }
+    }
+  }
+
+  return rootOperation;
+}
+
 static LogicalResult setConfigForKernel(IREE::GPU::TargetAttr target,
                                         mlir::FunctionOpInterface funcOp) {
   SmallVector<Operation *> computeOps = getComputeOps(funcOp);
@@ -1813,35 +1862,21 @@ static LogicalResult setConfigForKernel(IREE::GPU::TargetAttr target,
     return success();
   }
 
-  // Try to find a configuration according to a matmul/convolution op, which as
-  // at least one reduction dimension, and use it as the root op. So, skip all
-  // fused parallel producer ops.
-  ArrayRef<Operation *> roots(computeOps);
-  while (roots.size() > 1) {
-    auto linalgOp = dyn_cast<linalg::LinalgOp>(roots.front());
-    if (!linalgOp) {
-      break;
-    }
-    if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops()) {
-      break;
-    }
-    roots = roots.drop_front();
+  Operation *rootOp = getRootOperation(computeOps);
+  if (!rootOp) {
+    return computeOps.back()->emitOpError(
+        "unable to find root operation in dispatch");
   }
 
-  for (Operation *computeOp : roots) {
-    if (succeeded(setSPIRVOpConfig(target, funcOp, computeOp))) {
-      return success();
-    }
+  if (succeeded(setSPIRVOpConfig(target, funcOp, rootOp))) {
+    return success();
   }
 
-  Operation *computeOp = roots.back();
-  // If there are still no root op, check for any linalg.generic op.
-  if (succeeded(setDefaultOpConfig(target, computeOp))) {
+  if (succeeded(setDefaultOpConfig(target, rootOp))) {
     return success();
   }
 
-  // Check if the op configuration was set.
-  return computeOp->emitOpError(
+  return rootOp->emitOpError(
       "without known roots, the last compute operation in the tiled "
       "loop body is expected to be set as root");
 }
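[review note] To make the pass-1 reverse scan above concrete: on a decomposed softmax dispatch the scan meets the trailing elementwise div first, skips it (an all-parallel generic), and stops at the exp-sum reduction, so the earlier max reduction is never chosen. The standalone C++ toy below sketches that three-pass priority; Op, pickRoot, and the boolean flags are illustrative stand-ins for the real linalg type checks, not IREE or MLIR APIs.

#include <cassert>
#include <cstdio>
#include <vector>

// Toy stand-in for a compute op; the flags play the role of the
// linalg::GenericOp / linalg::FillOp / linalg::LinalgOp checks above.
struct Op {
  const char *name;
  bool isGeneric;    // plays linalg.generic
  bool isFill;       // plays linalg.fill
  bool isNamed;      // plays a named linalg op (matmul, conv, ...)
  bool hasReduction; // has at least one reduction iterator
};

// Mirrors the three-pass priority of getRootOperation.
const Op *pickRoot(const std::vector<Op> &ops) {
  // Pass 1: named ops, or generics with a reduction; reverse prefers later ops.
  for (auto it = ops.rbegin(); it != ops.rend(); ++it) {
    if (it->isGeneric) {
      if (it->hasReduction)
        return &*it;
      continue; // an all-parallel generic is not a pass-1 root
    }
    if (it->isNamed && !it->isFill)
      return &*it;
  }
  // Pass 2: any generic op (elementwise).
  for (auto it = ops.rbegin(); it != ops.rend(); ++it)
    if (it->isGeneric)
      return &*it;
  // Pass 3: fill ops.
  for (auto it = ops.rbegin(); it != ops.rend(); ++it)
    if (it->isFill)
      return &*it;
  return nullptr;
}

int main() {
  // Decomposed softmax: fill, max-reduce, fill, exp-sum-reduce, div.
  std::vector<Op> softmax = {
      {"fill0", false, true, false, false},
      {"max", true, false, false, true},
      {"fill1", false, true, false, false},
      {"exp-sum", true, false, false, true},
      {"div", true, false, false, false},
  };
  // The reverse scan skips "div" (all-parallel) and picks "exp-sum",
  // never reaching "max".
  assert(pickRoot(softmax) == &softmax[3]);
  std::puts(pickRoot(softmax)->name);
}

Passes 2 and 3 only fire when pass 1 finds nothing, e.g. a dispatch containing only elementwise generics, or only fills.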
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel b/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel
index 47d4524d055d..411adde67bac 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/BUILD.bazel
@@ -33,6 +33,7 @@ iree_lit_test_suite(
             "config_default_matmul.mlir",
             "config_default_misc.mlir",
             "config_default_reduction.mlir",
+            "config_default_softmax.mlir",
             "config_default_sub_byte_types.mlir",
             "config_mali_conv.mlir",
             "config_mali_matmul.mlir",
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt b/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
index 6ce2eccaf525..aa163cd3959a 100644
--- a/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/CMakeLists.txt
@@ -28,6 +28,7 @@ iree_lit_test_suite(
     "config_default_matmul.mlir"
     "config_default_misc.mlir"
     "config_default_reduction.mlir"
+    "config_default_softmax.mlir"
    "config_default_sub_byte_types.mlir"
     "config_mali_conv.mlir"
     "config_mali_matmul.mlir"
diff --git a/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_softmax.mlir b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_softmax.mlir
new file mode 100644
index 000000000000..866bf9635e40
--- /dev/null
+++ b/compiler/src/iree/compiler/Codegen/SPIRV/test/config_default_softmax.mlir
@@ -0,0 +1,40 @@
+// RUN: iree-opt --split-input-file \
+// RUN:   --pass-pipeline='builtin.module(func.func(iree-codegen-decompose-softmax), iree-spirv-select-lowering-strategy-pass)' %s | \
+// RUN:   FileCheck %s
+
+// Verifies that for decomposed softmax (max-reduce, exp-sum-reduce, div), the
+// lowering config is placed on the last reduction (exp-sum) rather than the
+// first (max).
+
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+  iree_codegen.target_info = #iree_gpu.target<...>
+}>
+func.func @softmax(%arg0: tensor<10x256x256xf32>) -> tensor<10x256x256xf32>
+    attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+  %0 = tensor.empty() : tensor<10x256x256xf32>
+  %1 = linalg.softmax dimension(2)
+      ins(%arg0 : tensor<10x256x256xf32>)
+      outs(%0 : tensor<10x256x256xf32>) -> tensor<10x256x256xf32>
+  return %1 : tensor<10x256x256xf32>
+}
+
+// The lowering_config should be on the exp-sum reduction (the second generic
+// with a reduction iterator), not on the max reduction (the first).
+
+// CHECK-LABEL: func.func @softmax
+// Max reduction: no lowering_config.
+// CHECK:      linalg.generic
+// CHECK-SAME:   iterator_types = ["parallel", "parallel", "reduction"]
+// CHECK-NOT:    lowering_config
+// Exp-sum reduction: has lowering_config.
+// CHECK:      linalg.generic
+// CHECK-SAME:   iterator_types = ["parallel", "parallel", "reduction"]
+// CHECK-SAME:   lowering_config
+// Div elementwise: no lowering_config.
+// CHECK:      linalg.generic
+// CHECK-SAME:   iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK-NOT:    lowering_config
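[review note] On why the CHECK lines pin the config to the second reduction: the deleted heuristic in setConfigForKernel dropped the leading run of all-parallel linalg ops and then tried candidates front to back, so on this dispatch it stopped at the max reduction. A standalone sketch of that old scan (oldPick and the toy Op are illustrative, not IREE APIs):

#include <cassert>
#include <cstdio>
#include <vector>

// Toy op for illustration: is it a linalg op, and are all its loops parallel?
struct Op {
  const char *name;
  bool isLinalg;
  bool allParallel;
};

// Sketch of the deleted heuristic: drop the leading run of all-parallel
// linalg ops, then take the first remaining op as the root candidate.
const Op *oldPick(const std::vector<Op> &ops) {
  size_t i = 0;
  while (i + 1 < ops.size() && ops[i].isLinalg && ops[i].allParallel)
    ++i;
  return &ops[i];
}

int main() {
  // Decomposed softmax: fill, max-reduce, fill, exp-sum-reduce, div.
  std::vector<Op> softmax = {
      {"fill0", true, true},
      {"max", true, false},
      {"fill1", true, true},
      {"exp-sum", true, false},
      {"div", true, true},
  };
  // The front scan stops at the first reduction, "max"; the new reverse
  // scan in getRootOperation picks "exp-sum" instead, which is what the
  // CHECK lines above assert.
  assert(oldPick(softmax) == &softmax[1]);
  std::puts(oldPick(softmax)->name);
}

Scanning in reverse keys the configuration off the op closest to the dispatch result, i.e. the op the remaining producers fuse into, which matches the "prefer later ops" note in the getRootOperation doc comment.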