compiler/src/iree/compiler/Codegen/SPIRV/KernelConfig.cpp (57 additions, 22 deletions)
@@ -1805,6 +1805,55 @@ static LogicalResult setSPIRVOpConfig(IREE::GPU::TargetAttr target,
 // Entry Point
 //===----------------------------------------------------------------------===//
 
+/// Find the root operation for the dispatch. The root is the op that will be
+/// tiled and distributed to workgroups; all other ops fuse with it as producers
+/// or consumers.
+///
+/// Priority (all passes iterate in reverse to prefer later ops):
+///   1. Named ops (matmul, conv) or generics with reduction iterators.
+///   2. Any generic op (elementwise).
+///   3. Fill ops.
+static Operation *getRootOperation(ArrayRef<Operation *> computeOps) {
+  Operation *rootOperation = nullptr;
+
+  // Pass 1: named ops or generics with reductions.
+  for (Operation *op : llvm::reverse(computeOps)) {
+    if (auto genericOp = dyn_cast<linalg::GenericOp>(op)) {
+      if (genericOp.getNumLoops() != genericOp.getNumParallelLoops()) {
+        rootOperation = op;
+        break;
+      }
+      continue;
+    }
+    if (!isa<linalg::FillOp>(op) && isa<TilingInterface>(op)) {
+      rootOperation = op;
+      break;
+    }
+  }
+
+  // Pass 2: any generic op (elementwise).
+  if (!rootOperation) {
+    for (Operation *op : llvm::reverse(computeOps)) {
+      if (isa<linalg::GenericOp>(op)) {
+        rootOperation = op;
+        break;
+      }
+    }
+  }
+
+  // Pass 3: fill ops.
+  if (!rootOperation) {
+    for (Operation *op : llvm::reverse(computeOps)) {
+      if (isa<linalg::FillOp>(op)) {
+        rootOperation = op;
+        break;
+      }
+    }
+  }
+
+  return rootOperation;
+}
+
 static LogicalResult setConfigForKernel(IREE::GPU::TargetAttr target,
                                         mlir::FunctionOpInterface funcOp) {
   SmallVector<Operation *> computeOps = getComputeOps(funcOp);
@@ -1813,35 +1862,21 @@ static LogicalResult setConfigForKernel(IREE::GPU::TargetAttr target,
     return success();
   }
 
-  // Try to find a configuration according to a matmul/convolution op, which as
-  // at least one reduction dimension, and use it as the root op. So, skip all
-  // fused parallel producer ops.
-  ArrayRef roots(computeOps);
-  while (roots.size() > 1) {
-    auto linalgOp = dyn_cast<linalg::LinalgOp>(roots.front());
-    if (!linalgOp) {
-      break;
-    }
-    if (linalgOp.getNumParallelLoops() != linalgOp.getNumLoops()) {
-      break;
-    }
-    roots = roots.drop_front();
+  Operation *rootOp = getRootOperation(computeOps);
+  if (!rootOp) {
+    return computeOps.back()->emitOpError(
+        "unable to find root operation in dispatch");
   }
 
-  for (Operation *computeOp : roots) {
-    if (succeeded(setSPIRVOpConfig(target, funcOp, computeOp))) {
-      return success();
-    }
+  if (succeeded(setSPIRVOpConfig(target, funcOp, rootOp))) {
+    return success();
   }
 
-  Operation *computeOp = roots.back();
-  // If there are still no root op, check for any linalg.generic op.
-  if (succeeded(setDefaultOpConfig(target, computeOp))) {
+  if (succeeded(setDefaultOpConfig(target, rootOp))) {
     return success();
   }
 
-  // Check if the op configuration was set.
-  return computeOp->emitOpError(
+  return rootOp->emitOpError(
       "without known roots, the last compute operation in the tiled "
       "loop body is expected to be set as root");
 }
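
To make the three-pass priority concrete, here is a small standalone sketch (a toy model, not IREE code: OpKind and pickRoot are hypothetical stand-ins for the dyn_cast/isa checks on real linalg ops). For the op list of a decomposed softmax dispatch it selects the second reduction, which is exactly what the new test below verifies.

// Toy model of getRootOperation: each compute op is reduced to a tag, and the
// same three reverse scans pick the root, preferring later ops in each pass.
#include <cassert>
#include <cstdio>
#include <vector>

enum class OpKind { Fill, ElementwiseGeneric, ReductionGeneric, NamedOp };

static int pickRoot(const std::vector<OpKind> &ops) {
  // Pass 1: named ops or generics with reduction iterators.
  for (int i = static_cast<int>(ops.size()) - 1; i >= 0; --i)
    if (ops[i] == OpKind::ReductionGeneric || ops[i] == OpKind::NamedOp)
      return i;
  // Pass 2: any generic op (elementwise).
  for (int i = static_cast<int>(ops.size()) - 1; i >= 0; --i)
    if (ops[i] == OpKind::ElementwiseGeneric)
      return i;
  // Pass 3: fill ops.
  for (int i = static_cast<int>(ops.size()) - 1; i >= 0; --i)
    if (ops[i] == OpKind::Fill)
      return i;
  return -1; // No root candidate found.
}

int main() {
  // Decomposed softmax: fill, max reduction, fill, exp-sum reduction, div.
  std::vector<OpKind> softmax = {OpKind::Fill, OpKind::ReductionGeneric,
                                 OpKind::Fill, OpKind::ReductionGeneric,
                                 OpKind::ElementwiseGeneric};
  assert(pickRoot(softmax) == 3); // The last reduction wins, not the first.
  std::printf("root index: %d\n", pickRoot(softmax));
}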
@@ -0,0 +1,40 @@
+// RUN: iree-opt --split-input-file \
+// RUN:   --pass-pipeline='builtin.module(func.func(iree-codegen-decompose-softmax), iree-spirv-select-lowering-strategy-pass)' %s | \
+// RUN:   FileCheck %s
+
+// Verifies that for decomposed softmax (max-reduce, exp-sum-reduce, div), the
+// lowering config is placed on the last reduction (exp-sum) rather than the
+// first (max).
+
+#executable_target_vulkan_spirv_fb = #hal.executable.target<"vulkan-spirv", "vulkan-spirv-fb", {
+  iree_codegen.target_info = #iree_gpu.target<arch = "", features = "spirv:v1.6,cap:Shader", wgp = <
+    compute = fp32|int32, storage = b32, subgroup = none,
+    subgroup_size_choices = [64], max_workgroup_sizes = [512, 512, 512],
+    max_thread_count_per_workgroup = 512, max_workgroup_memory_bytes = 16384,
+    max_workgroup_counts = [65535, 65535, 65535]>>
+}>
+func.func @softmax(%arg0: tensor<10x256x256xf32>) -> tensor<10x256x256xf32>
+    attributes {hal.executable.target = #executable_target_vulkan_spirv_fb} {
+  %0 = tensor.empty() : tensor<10x256x256xf32>
+  %1 = linalg.softmax dimension(2)
+         ins(%arg0 : tensor<10x256x256xf32>)
+         outs(%0 : tensor<10x256x256xf32>) -> tensor<10x256x256xf32>
+  return %1 : tensor<10x256x256xf32>
+}
+
+// The lowering_config should be on the exp-sum reduction (the second generic
+// with a reduction iterator), not on the max reduction (the first).
+
+// CHECK-LABEL: func.func @softmax
+// Max reduction: no lowering_config.
+// CHECK: linalg.generic
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
+// CHECK-NOT: lowering_config
+// Exp-sum reduction: has lowering_config.
+// CHECK: linalg.generic
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "reduction"]
+// CHECK-SAME: lowering_config
+// Div elementwise: no lowering_config.
+// CHECK: linalg.generic
+// CHECK-SAME: iterator_types = ["parallel", "parallel", "parallel"]
+// CHECK-NOT: lowering_config
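
For reference, the decomposition this test relies on is the standard numerically stable softmax: a max reduction, an exp-and-sum reduction, then an elementwise division. A minimal numeric sketch of those three stages (plain C++ for illustration only; the actual pass emits the linalg.generic ops matched above, not this code):

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdio>
#include <vector>

// Softmax over one row, in the same three stages as the decomposed IR.
static std::vector<float> softmaxRow(const std::vector<float> &x) {
  // Stage 1: max reduction (the first reduction generic, no config expected).
  float m = *std::max_element(x.begin(), x.end());
  // Stage 2: exponentiate shifted inputs and sum them (the second reduction
  // generic, which the new root heuristic tags with the lowering config).
  std::vector<float> e(x.size());
  float sum = 0.0f;
  for (std::size_t i = 0; i < x.size(); ++i) {
    e[i] = std::exp(x[i] - m);
    sum += e[i];
  }
  // Stage 3: elementwise division (the trailing parallel generic).
  for (float &v : e)
    v /= sum;
  return e;
}

int main() {
  for (float v : softmaxRow({1.0f, 2.0f, 3.0f}))
    std::printf("%f\n", v); // ~0.0900, 0.2447, 0.6652
}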