Skip to content

Commit 084bc47

Browse files
authored
[BACKEND] remove workaround in fp4padded alloc size calculation (#6739)
1 parent b2d9ec4 commit 084bc47

2 files changed

Lines changed: 5 additions & 6 deletions

File tree

lib/Dialect/TritonGPU/IR/Dialect.cpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -316,12 +316,7 @@ SmallVector<int64_t> getAllocationShapePerCTA(Attribute layout,
316316
if (auto sharedMMALayout = mlir::dyn_cast<NVMMASharedEncodingAttr>(layout)) {
317317
if (sharedMMALayout.getFp4Padded()) {
318318
auto packedAxis = getOrder(sharedMMALayout, shapeLogical)[0];
319-
if (shape.size() == 3) {
320-
// Take into account multi buffering
321-
shape[1 + packedAxis] *= 2;
322-
} else {
323-
shape[packedAxis] *= 2;
324-
}
319+
shape[packedAxis] *= 2;
325320
}
326321
}
327322
return getShapePerCTA(layout, shape);

test/Analysis/test-allocation.mlir

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,8 @@
2020
#NVMMA_SHARED_32 = #ttg.nvmma_shared<{swizzlingByteWidth = 32, transposed = false, elementBitWidth = 16}>
2121
#NVMMA_SHARED_64 = #ttg.nvmma_shared<{swizzlingByteWidth = 64, transposed = false, elementBitWidth = 16}>
2222
#NVMMA_SHARED_128 = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 16}>
23+
#NVMMA_SHARED_FP4PADDED = #ttg.nvmma_shared<{swizzlingByteWidth = 128, transposed = false, elementBitWidth = 8, fp4Padded = true}>
24+
2325
#smem = #ttg.shared_memory
2426

2527
module attributes {"ttg.num-warps" = 4 : i32, "ttg.num-ctas" = 1 : i32} {
@@ -917,6 +919,8 @@ tt.func @tightly_packed_captures(%arg0: i8, %arg1: i64) {
917919
// expected-remark @below {{nvmma_alignment}}
918920
// expected-remark @below {{size = 1088}}
919921
tt.func @nvmma_alignment(%lb : index, %ub : index, %step : index, %A : !tt.ptr<f16>, %B : !tt.ptr<f16>) {
922+
// expected-remark @below {{offset = 0, size = 128}}
923+
%fp4 = ttg.local_alloc : () -> !ttg.memdesc<8x8xi8, #NVMMA_SHARED_FP4PADDED, #ttg.shared_memory, mutable>
920924
// expected-remark @below {{offset = 0, size = 64}}
921925
%a = ttg.local_alloc : () -> !ttg.memdesc<32xf16, #A_SHARED, #ttg.shared_memory, mutable>
922926
// expected-remark @below {{offset = 128, size = 64}}

0 commit comments

Comments
 (0)