Skip to content

Commit 1ec3fbc

Browse files
committed
[LLVMCPU] Lower explicit workgroup-local allocs through dispatch ABI
Add #iree_codegen.workgroup_local as a codegen memory-space attribute for authored LLVMCPU dispatch allocations that should use HAL workgroup local memory instead of the thread stack. This deliberately matches the GPU-side semantic split: workgroup/shared memory is represented as memref.alloc in a workgroup memory space, while memref.alloca remains private stack scratch. LLVMCPU uses an IREE-specific memory-space attribute instead of #gpu.address_space<workgroup>, but keeps the same alloc/dealloc ownership model. The assignment pass is intentionally narrow: it only handles entry-block memref.alloc operations in HAL executable exports, computes byte ranges with the target data layout and ABI alignment, rejects dynamic or unsupported layouts, and refuses to overwrite existing local-memory assignments or predeclared export requirements. ConvertToLLVM consumes the assigned range by building memref descriptors from the HAL workgroup local-memory pointer. Matching memref.dealloc operations are erased because the storage is owned by the dispatch frame.
1 parent 5bfd638 commit 1ec3fbc

15 files changed

Lines changed: 794 additions & 2 deletions

compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,8 @@ constexpr StringLiteral kUkernelAttrName = "iree_codegen.ukernel";
7373
constexpr StringLiteral kUKernelProviderName = "iree_codegen.ukernel_provider";
7474
constexpr StringLiteral kVectorTileSizesAttrName =
7575
"iree_codegen.vector_tile_sizes";
76+
constexpr StringLiteral kWorkgroupLocalMemoryRangeAttrName =
77+
"iree_codegen.local_memory_range";
7678

7779
//===----------------------------------------------------------------------===//
7880
// Helpers for getting/setting iree_codegen.translation_info attribute on a

compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.td

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,35 @@ def LocalMappingAttr :
160160
}];
161161
}
162162

163+
//===---------------------------------------------------------------------===//
164+
// iree_codegen.workgroup_local memory space attribute
165+
//===---------------------------------------------------------------------===//
166+
167+
def WorkgroupLocalMemoryAttr :
168+
AttrDef<IREECodegen_Dialect, "WorkgroupLocalMemory"> {
169+
let mnemonic = "workgroup_local";
170+
let parameters = (ins);
171+
let assemblyFormat = "";
172+
173+
let description = [{
174+
Memory space attribute indicating that an allocation uses per-workgroup
175+
local memory supplied by the target runtime instead of the thread stack.
176+
177+
LLVMCPU dispatch lowering rewrites `memref.alloc` operations marked with
178+
this memory space to use the HAL dispatch ABI local-memory pointer. Matching
179+
`memref.dealloc` operations are erased because the storage is owned by the
180+
dispatch frame. The allocation layout is assigned before final LLVM
181+
conversion and exported as the dispatch `workgroup_local_memory`
182+
requirement.
183+
184+
Example:
185+
```mlir
186+
%scratch = memref.alloc()
187+
: memref<64x1536xf32, #iree_codegen.workgroup_local>
188+
```
189+
}];
190+
}
191+
163192
//===---------------------------------------------------------------------===//
164193
// iree_codegen.simple_target
165194
//===---------------------------------------------------------------------===//

compiler/src/iree/compiler/Codegen/Dialect/Codegen/IR/test/roundtrip.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,17 @@ func.func private @workgroup_scope_attr_linearize() attributes {
114114

115115
// -----
116116

117+
func.func @workgroup_local_memory_space(
118+
%arg0: memref<64x64xf32, #iree_codegen.workgroup_local>)
119+
-> memref<64x64xf32, #iree_codegen.workgroup_local> {
120+
return %arg0 : memref<64x64xf32, #iree_codegen.workgroup_local>
121+
}
122+
// CHECK-LABEL: func.func @workgroup_local_memory_space(
123+
// CHECK-SAME: memref<64x64xf32, #iree_codegen.workgroup_local>
124+
// CHECK-SAME: -> memref<64x64xf32, #iree_codegen.workgroup_local>
125+
126+
// -----
127+
117128
// Test constraints op with knobs and dims.
118129
func.func @constraints_op(%arg0: index, %arg1: index) {
119130
iree_codegen.smt.constraints target = <set = 0>, pipeline = #iree_gpu.pipeline<VectorDistribute>,

compiler/src/iree/compiler/Codegen/LLVMCPU/BUILD.bazel

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,7 @@ iree_compiler_cc_library(
6060
"LLVMCPU2DScalableTo1DScalable.cpp",
6161
"LLVMCPUAssignConstantOrdinals.cpp",
6262
"LLVMCPUAssignImportOrdinals.cpp",
63+
"LLVMCPUAssignWorkgroupLocalMemory.cpp",
6364
"LLVMCPUCheckIRBeforeLLVMConversion.cpp",
6465
"LLVMCPUEmitVectorizationRemarks.cpp",
6566
"LLVMCPULinkExecutables.cpp",

compiler/src/iree/compiler/Codegen/LLVMCPU/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ iree_cc_library(
5656
"LLVMCPU2DScalableTo1DScalable.cpp"
5757
"LLVMCPUAssignConstantOrdinals.cpp"
5858
"LLVMCPUAssignImportOrdinals.cpp"
59+
"LLVMCPUAssignWorkgroupLocalMemory.cpp"
5960
"LLVMCPUCheckIRBeforeLLVMConversion.cpp"
6061
"LLVMCPUEmitVectorizationRemarks.cpp"
6162
"LLVMCPULinkExecutables.cpp"

compiler/src/iree/compiler/Codegen/LLVMCPU/ConvertToLLVM.cpp

Lines changed: 85 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
#include "iree/compiler/Codegen/Common/PassUtils.h"
88
#include "iree/compiler/Codegen/Common/Transforms.h"
9+
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenAttrs.h"
910
#include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
1011
#include "iree/compiler/Codegen/LLVMCPU/DispatchABI.h"
1112
#include "iree/compiler/Codegen/LLVMCPU/Passes.h"
@@ -332,6 +333,83 @@ struct ConvertHALInterfaceBindingSubspanOp
332333
}
333334
};
334335

336+
/// Rewrites memref.alloc with #iree_codegen.workgroup_local memory space to a
337+
/// memref descriptor backed by HAL dispatch workgroup local memory.
338+
struct ConvertWorkgroupLocalAllocOp
339+
: ConvertOpToLLVMWithABIPattern<memref::AllocOp> {
340+
ConvertWorkgroupLocalAllocOp(HALDispatchABI &abi,
341+
LLVMTypeConverter &typeConverter)
342+
: ConvertOpToLLVMWithABIPattern(abi, typeConverter, /*benefit=*/2) {}
343+
344+
LogicalResult
345+
matchAndRewrite(memref::AllocOp allocOp, memref::AllocOpAdaptor,
346+
ConversionPatternRewriter &rewriter) const override {
347+
MemRefType memRefType = allocOp.getType();
348+
if (!isa_and_nonnull<IREE::Codegen::WorkgroupLocalMemoryAttr>(
349+
memRefType.getMemorySpace())) {
350+
return failure();
351+
}
352+
if (!memRefType.hasStaticShape()) {
353+
return allocOp.emitOpError(
354+
"workgroup local memory allocations must have static shape");
355+
}
356+
SmallVector<int64_t> strides;
357+
int64_t offset = 0;
358+
if (failed(memRefType.getStridesAndOffset(strides, offset)) ||
359+
ShapedType::isDynamic(offset) ||
360+
llvm::any_of(strides, ShapedType::isDynamic)) {
361+
return allocOp.emitOpError(
362+
"workgroup local memory allocations must have static layout");
363+
}
364+
365+
auto rangeAttr = allocOp->getAttrOfType<DenseI64ArrayAttr>(
366+
kWorkgroupLocalMemoryRangeAttrName);
367+
if (!rangeAttr || rangeAttr.size() != 2 || rangeAttr[0] < 0 ||
368+
rangeAttr[1] < 0) {
369+
return allocOp.emitOpError(
370+
"missing valid iree_codegen.local_memory_range annotation");
371+
}
372+
373+
Location loc = allocOp.getLoc();
374+
Value basePtr = abi.loadWorkgroupLocalMemoryPtr(allocOp, rewriter);
375+
Value byteOffset = LLVM::ConstantOp::create(
376+
rewriter, loc, rewriter.getI64Type(), rangeAttr[0]);
377+
Value offsetPtr = LLVM::GEPOp::create(
378+
rewriter, loc, basePtr.getType(), rewriter.getI8Type(), basePtr,
379+
byteOffset, LLVM::GEPNoWrapFlags::inbounds);
380+
381+
MemRefType strippedType =
382+
MemRefType::get(memRefType.getShape(), memRefType.getElementType(),
383+
memRefType.getLayout());
384+
auto desc = MemRefDescriptor::fromStaticShape(
385+
rewriter, loc, *getTypeConverter(), strippedType, offsetPtr);
386+
rewriter.replaceOp(allocOp, {desc});
387+
return success();
388+
}
389+
};
390+
391+
/// Erases deallocations for #iree_codegen.workgroup_local memory. The storage
392+
/// is owned by the HAL dispatch frame and released when the dispatch returns.
393+
struct ConvertWorkgroupLocalDeallocOp
394+
: ConvertOpToLLVMWithABIPattern<memref::DeallocOp> {
395+
ConvertWorkgroupLocalDeallocOp(HALDispatchABI &abi,
396+
LLVMTypeConverter &typeConverter)
397+
: ConvertOpToLLVMWithABIPattern(abi, typeConverter, /*benefit=*/2) {}
398+
399+
LogicalResult
400+
matchAndRewrite(memref::DeallocOp deallocOp, memref::DeallocOpAdaptor,
401+
ConversionPatternRewriter &rewriter) const override {
402+
auto memRefType = dyn_cast<MemRefType>(deallocOp.getMemref().getType());
403+
if (!memRefType ||
404+
!isa_and_nonnull<IREE::Codegen::WorkgroupLocalMemoryAttr>(
405+
memRefType.getMemorySpace())) {
406+
return failure();
407+
}
408+
rewriter.eraseOp(deallocOp);
409+
return success();
410+
}
411+
};
412+
335413
struct InstrumentationEntry {
336414
// !llvm.ptr<i8> pointing at the base of the ringbuffer.
337415
Value basePtr;
@@ -1052,6 +1130,11 @@ void ConvertToLLVMPass::runOnOperation() {
10521130
options.dataLayout = llvm::DataLayout(dataLayoutStr);
10531131
options.overrideIndexBitwidth(options.dataLayout.getPointerSizeInBits());
10541132
LLVMTypeConverter typeConverter(&getContext(), options, &dataLayoutAnalysis);
1133+
typeConverter.addTypeAttributeConversion(
1134+
[](BaseMemRefType, IREE::Codegen::WorkgroupLocalMemoryAttr attr)
1135+
-> TypeConverter::AttributeConversionResult {
1136+
return IntegerAttr::get(IntegerType::get(attr.getContext(), 64), 0);
1137+
});
10551138

10561139
RewritePatternSet patterns(&getContext());
10571140

@@ -1123,6 +1206,8 @@ void ConvertToLLVMPass::runOnOperation() {
11231206
ConvertHALInterfaceWorkgroupCountOp,
11241207
ConvertHALInterfaceConstantLoadOp,
11251208
ConvertHALInterfaceBindingSubspanOp,
1209+
ConvertWorkgroupLocalAllocOp,
1210+
ConvertWorkgroupLocalDeallocOp,
11261211
ConvertHALInstrumentWorkgroupOp,
11271212
ConvertHALInstrumentValueOp,
11281213
ConvertHALInstrumentMemoryLoadOp,

0 commit comments

Comments
 (0)