@@ -137,38 +137,11 @@ module attributes {gpu.container_module} {
137137 }
138138
139139 // ---- Helpers ----------------------------------------------------------
140- // Build a static-shape memref descriptor over a raw runtime ptr.
141- // Phase 4's AIRSymmetricAllocToMgpuPass will do this automatically.
142- func.func private @wrap_data (%ptr : !llvm.ptr ) -> memref <256 xf32 > {
143- %c0_i64 = arith.constant 0 : i64
144- %c1_i64 = arith.constant 1 : i64
145- %c256_i64 = arith.constant 256 : i64
146- %d0 = llvm.mlir.poison : !llvm.struct <(ptr , ptr , i64 , array <1 x i64 >, array <1 x i64 >)>
147- %d1 = llvm.insertvalue %ptr , %d0 [0 ] : !llvm.struct <(ptr , ptr , i64 , array <1 x i64 >, array <1 x i64 >)>
148- %d2 = llvm.insertvalue %ptr , %d1 [1 ] : !llvm.struct <(ptr , ptr , i64 , array <1 x i64 >, array <1 x i64 >)>
149- %d3 = llvm.insertvalue %c0_i64 , %d2 [2 ] : !llvm.struct <(ptr , ptr , i64 , array <1 x i64 >, array <1 x i64 >)>
150- %d4 = llvm.insertvalue %c256_i64 , %d3 [3 , 0 ] : !llvm.struct <(ptr , ptr , i64 , array <1 x i64 >, array <1 x i64 >)>
151- %d5 = llvm.insertvalue %c1_i64 , %d4 [4 , 0 ] : !llvm.struct <(ptr , ptr , i64 , array <1 x i64 >, array <1 x i64 >)>
152- %m = builtin.unrealized_conversion_cast %d5 : !llvm.struct <(ptr , ptr , i64 , array <1 x i64 >, array <1 x i64 >)> to memref <256 xf32 >
153- return %m : memref <256 xf32 >
154- }
155-
156- func.func private @wrap_flags (%ptr : !llvm.ptr ) -> memref <4 xi32 > {
157- %c0_i64 = arith.constant 0 : i64
158- %c1_i64 = arith.constant 1 : i64
159- %c4_i64 = arith.constant 4 : i64
160- %d0 = llvm.mlir.poison : !llvm.struct <(ptr , ptr , i64 , array <1 x i64 >, array <1 x i64 >)>
161- %d1 = llvm.insertvalue %ptr , %d0 [0 ] : !llvm.struct <(ptr , ptr , i64 , array <1 x i64 >, array <1 x i64 >)>
162- %d2 = llvm.insertvalue %ptr , %d1 [1 ] : !llvm.struct <(ptr , ptr , i64 , array <1 x i64 >, array <1 x i64 >)>
163- %d3 = llvm.insertvalue %c0_i64 , %d2 [2 ] : !llvm.struct <(ptr , ptr , i64 , array <1 x i64 >, array <1 x i64 >)>
164- %d4 = llvm.insertvalue %c4_i64 , %d3 [3 , 0 ] : !llvm.struct <(ptr , ptr , i64 , array <1 x i64 >, array <1 x i64 >)>
165- %d5 = llvm.insertvalue %c1_i64 , %d4 [4 , 0 ] : !llvm.struct <(ptr , ptr , i64 , array <1 x i64 >, array <1 x i64 >)>
166- %m = builtin.unrealized_conversion_cast %d5 : !llvm.struct <(ptr , ptr , i64 , array <1 x i64 >, array <1 x i64 >)> to memref <4 xi32 >
167- return %m : memref <4 xi32 >
168- }
169-
170- // Wrap a runtime ptr (heap_bases table) as memref<?xindex>.
171- func.func private @wrap_bases (%ptr : !llvm.ptr , %size : i64 ) -> memref <?xindex > {
140+ // Single ABI-leaking helper: wrap a raw runtime !llvm.ptr as a 1-D byte
141+ // memref. All typed views below derive from this via memref.view, so the
142+ // hand-built LLVM-struct descriptor literal lives in exactly one place.
143+ // Phase 4's AIRSymmetricAllocToMgpuPass will replace this entirely.
144+ func.func private @wrap_bytes (%ptr : !llvm.ptr , %size : i64 ) -> memref <?xi8 > {
172145 %c0_i64 = arith.constant 0 : i64
173146 %c1_i64 = arith.constant 1 : i64
174147 %d0 = llvm.mlir.poison : !llvm.struct <(ptr , ptr , i64 , array <1 x i64 >, array <1 x i64 >)>
@@ -178,8 +151,8 @@ module attributes {gpu.container_module} {
178151 %d4 = llvm.insertvalue %size , %d3 [3 , 0 ] : !llvm.struct <(ptr , ptr , i64 , array <1 x i64 >, array <1 x i64 >)>
179152 %d5 = llvm.insertvalue %c1_i64 , %d4 [4 , 0 ] : !llvm.struct <(ptr , ptr , i64 , array <1 x i64 >, array <1 x i64 >)>
180153 %m = builtin.unrealized_conversion_cast %d5
181- : !llvm.struct <(ptr , ptr , i64 , array <1 x i64 >, array <1 x i64 >)> to memref <?x index >
182- return %m : memref <?x index >
154+ : !llvm.struct <(ptr , ptr , i64 , array <1 x i64 >, array <1 x i64 >)> to memref <?x i8 >
155+ return %m : memref <?x i8 >
183156 }
184157
185158 // ---- main ------------------------------------------------------------
@@ -226,8 +199,11 @@ module attributes {gpu.container_module} {
226199
227200 func.call @mgpuBarrier () : () -> () // flags init visible to all ranks
228201
229- %data_m = func.call @wrap_data (%data_ptr ) : (!llvm.ptr ) -> memref <256 xf32 >
230- %flags_m = func.call @wrap_flags (%flags_ptr ) : (!llvm.ptr ) -> memref <4 xi32 >
202+ %c0_view = arith.constant 0 : index
203+ %data_bytes = func.call @wrap_bytes (%data_ptr , %c1024_bytes ) : (!llvm.ptr , i64 ) -> memref <?xi8 >
204+ %flags_bytes = func.call @wrap_bytes (%flags_ptr , %c16_bytes ) : (!llvm.ptr , i64 ) -> memref <?xi8 >
205+ %data_m = memref.view %data_bytes [%c0_view ][] : memref <?xi8 > to memref <256 xf32 >
206+ %flags_m = memref.view %flags_bytes [%c0_view ][] : memref <?xi8 > to memref <4 xi32 >
231207
232208 // mgpuGetHeapBases() returns a HOST pointer; GPU can't deref it, so
233209 // copy to device. TODO(airgpu): make heap_bases device-accessible
@@ -240,8 +216,9 @@ module attributes {gpu.container_module} {
240216 : (i64 , !llvm.ptr , i1 ) -> !llvm.ptr
241217 func.call @mgpuMemcpy (%bases_devptr , %bases_host , %bases_size , %nullptr )
242218 : (!llvm.ptr , !llvm.ptr , i64 , !llvm.ptr ) -> ()
243- %bases = func.call @wrap_bases (%bases_devptr , %world_i64 )
244- : (!llvm.ptr , i64 ) -> memref <?xindex >
219+ %bases_bytes = func.call @wrap_bytes (%bases_devptr , %bases_size ) : (!llvm.ptr , i64 ) -> memref <?xi8 >
220+ %world_idx = arith.index_cast %world_i64 : i64 to index
221+ %bases = memref.view %bases_bytes [%c0_view ][%world_idx ] : memref <?xi8 > to memref <?xindex >
245222
246223 %is_solo = arith.cmpi sle , %world , %c1_i32 : i32
247224 scf.if %is_solo {
@@ -268,7 +245,8 @@ module attributes {gpu.container_module} {
268245 scf.if %is_consumer {
269246 %verify_ptr = func.call @mgpuMemAlloc (%c1024_bytes , %nullptr , %false )
270247 : (i64 , !llvm.ptr , i1 ) -> !llvm.ptr
271- %verify_m = func.call @wrap_data (%verify_ptr ) : (!llvm.ptr ) -> memref <256 xf32 >
248+ %verify_bytes = func.call @wrap_bytes (%verify_ptr , %c1024_bytes ) : (!llvm.ptr , i64 ) -> memref <?xi8 >
249+ %verify_m = memref.view %verify_bytes [%c0_view ][] : memref <?xi8 > to memref <256 xf32 >
272250 gpu.launch_func @sym_kernels ::@consumer
273251 blocks in (%c1 , %c1 , %c1 )
274252 threads in (%c256 , %c1 , %c1 )
0 commit comments