Skip to content

Commit dd153c9

Browse files
erwei-xilinx and claude committed
[multi-gpu] Phase 2: collapse 3 wrap_* helpers into one wrap_bytes
The previous wrap_data / wrap_flags / wrap_bases helpers each hand-built an LLVM memref descriptor struct (!llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>), hardcoding the in-flight memref-to-LLVM ABI three times. An upstream descriptor-layout change would silently break all three.

Collapse to a single wrap_bytes(ptr, size_bytes) -> memref<?xi8> that builds the descriptor once. Use sites apply memref.view to retype the byte buffer:

    %data_bytes  = wrap_bytes(%data_ptr, %c1024_bytes)
    %data_m      = memref.view %data_bytes[%c0][] : memref<?xi8> to memref<256xf32>
    %flags_bytes = wrap_bytes(%flags_ptr, %c16_bytes)
    %flags_m     = memref.view %flags_bytes[%c0][] : memref<?xi8> to memref<4xi32>
    %bases_bytes = wrap_bytes(%bases_devptr, %bases_size)
    %bases       = memref.view %bases_bytes[%c0][%world_idx] : memref<?xi8> to memref<?xindex>
    ; verify_buf is wrapped the same way at the consumer

The struct-type literal now appears in exactly one place. memref.view is a standard upstream op with its own well-tested lowering.

Validated on rad-mi325x-1: W=2/4/8 all PASS.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 6897ab8 commit dd153c9

1 file changed

Lines changed: 17 additions & 39 deletions

File tree

test/gpu/symmetric_heap_dma/air_sym_handwritten.mlir

Lines changed: 17 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -137,38 +137,11 @@ module attributes {gpu.container_module} {
137137
}
138138

139139
// ---- Helpers ----------------------------------------------------------
140-
// Build a static-shape memref descriptor over a raw runtime ptr.
141-
// Phase 4's AIRSymmetricAllocToMgpuPass will do this automatically.
142-
func.func private @wrap_data(%ptr : !llvm.ptr) -> memref<256xf32> {
143-
%c0_i64 = arith.constant 0 : i64
144-
%c1_i64 = arith.constant 1 : i64
145-
%c256_i64 = arith.constant 256 : i64
146-
%d0 = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
147-
%d1 = llvm.insertvalue %ptr, %d0[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
148-
%d2 = llvm.insertvalue %ptr, %d1[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
149-
%d3 = llvm.insertvalue %c0_i64, %d2[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
150-
%d4 = llvm.insertvalue %c256_i64, %d3[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
151-
%d5 = llvm.insertvalue %c1_i64, %d4[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
152-
%m = builtin.unrealized_conversion_cast %d5 : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<256xf32>
153-
return %m : memref<256xf32>
154-
}
155-
156-
func.func private @wrap_flags(%ptr : !llvm.ptr) -> memref<4xi32> {
157-
%c0_i64 = arith.constant 0 : i64
158-
%c1_i64 = arith.constant 1 : i64
159-
%c4_i64 = arith.constant 4 : i64
160-
%d0 = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
161-
%d1 = llvm.insertvalue %ptr, %d0[0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
162-
%d2 = llvm.insertvalue %ptr, %d1[1] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
163-
%d3 = llvm.insertvalue %c0_i64, %d2[2] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
164-
%d4 = llvm.insertvalue %c4_i64, %d3[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
165-
%d5 = llvm.insertvalue %c1_i64, %d4[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
166-
%m = builtin.unrealized_conversion_cast %d5 : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<4xi32>
167-
return %m : memref<4xi32>
168-
}
169-
170-
// Wrap a runtime ptr (heap_bases table) as memref<?xindex>.
171-
func.func private @wrap_bases(%ptr : !llvm.ptr, %size : i64) -> memref<?xindex> {
140+
// Single ABI-leaking helper: wrap a raw runtime !llvm.ptr as a 1-D byte
141+
// memref. All typed views below derive from this via memref.view, so the
142+
// hand-built LLVM-struct descriptor literal lives in exactly one place.
143+
// Phase 4's AIRSymmetricAllocToMgpuPass will replace this entirely.
144+
func.func private @wrap_bytes(%ptr : !llvm.ptr, %size : i64) -> memref<?xi8> {
172145
%c0_i64 = arith.constant 0 : i64
173146
%c1_i64 = arith.constant 1 : i64
174147
%d0 = llvm.mlir.poison : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
@@ -178,8 +151,8 @@ module attributes {gpu.container_module} {
178151
%d4 = llvm.insertvalue %size, %d3[3, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
179152
%d5 = llvm.insertvalue %c1_i64, %d4[4, 0] : !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)>
180153
%m = builtin.unrealized_conversion_cast %d5
181-
: !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<?xindex>
182-
return %m : memref<?xindex>
154+
: !llvm.struct<(ptr, ptr, i64, array<1 x i64>, array<1 x i64>)> to memref<?xi8>
155+
return %m : memref<?xi8>
183156
}
184157

185158
// ---- main ------------------------------------------------------------
@@ -226,8 +199,11 @@ module attributes {gpu.container_module} {
226199

227200
func.call @mgpuBarrier() : () -> () // flags init visible to all ranks
228201

229-
%data_m = func.call @wrap_data(%data_ptr) : (!llvm.ptr) -> memref<256xf32>
230-
%flags_m = func.call @wrap_flags(%flags_ptr) : (!llvm.ptr) -> memref<4xi32>
202+
%c0_view = arith.constant 0 : index
203+
%data_bytes = func.call @wrap_bytes(%data_ptr, %c1024_bytes) : (!llvm.ptr, i64) -> memref<?xi8>
204+
%flags_bytes = func.call @wrap_bytes(%flags_ptr, %c16_bytes) : (!llvm.ptr, i64) -> memref<?xi8>
205+
%data_m = memref.view %data_bytes[%c0_view][] : memref<?xi8> to memref<256xf32>
206+
%flags_m = memref.view %flags_bytes[%c0_view][] : memref<?xi8> to memref<4xi32>
231207

232208
// mgpuGetHeapBases() returns a HOST pointer; GPU can't deref it, so
233209
// copy to device. TODO(airgpu): make heap_bases device-accessible
@@ -240,8 +216,9 @@ module attributes {gpu.container_module} {
240216
: (i64, !llvm.ptr, i1) -> !llvm.ptr
241217
func.call @mgpuMemcpy(%bases_devptr, %bases_host, %bases_size, %nullptr)
242218
: (!llvm.ptr, !llvm.ptr, i64, !llvm.ptr) -> ()
243-
%bases = func.call @wrap_bases(%bases_devptr, %world_i64)
244-
: (!llvm.ptr, i64) -> memref<?xindex>
219+
%bases_bytes = func.call @wrap_bytes(%bases_devptr, %bases_size) : (!llvm.ptr, i64) -> memref<?xi8>
220+
%world_idx = arith.index_cast %world_i64 : i64 to index
221+
%bases = memref.view %bases_bytes[%c0_view][%world_idx] : memref<?xi8> to memref<?xindex>
245222

246223
%is_solo = arith.cmpi sle, %world, %c1_i32 : i32
247224
scf.if %is_solo {
@@ -268,7 +245,8 @@ module attributes {gpu.container_module} {
268245
scf.if %is_consumer {
269246
%verify_ptr = func.call @mgpuMemAlloc(%c1024_bytes, %nullptr, %false)
270247
: (i64, !llvm.ptr, i1) -> !llvm.ptr
271-
%verify_m = func.call @wrap_data(%verify_ptr) : (!llvm.ptr) -> memref<256xf32>
248+
%verify_bytes = func.call @wrap_bytes(%verify_ptr, %c1024_bytes) : (!llvm.ptr, i64) -> memref<?xi8>
249+
%verify_m = memref.view %verify_bytes[%c0_view][] : memref<?xi8> to memref<256xf32>
272250
gpu.launch_func @sym_kernels::@consumer
273251
blocks in (%c1, %c1, %c1)
274252
threads in (%c256, %c1, %c1)

0 commit comments

Comments
 (0)