Skip to content

Commit 676b861

Browse files
[CIR][ABI][AArch64][Lowering] Fix the callsite for nested unions (#1169)
For example, the following reaches ["NYI"](https://github.com/llvm/clangir/blob/c8b626d49e7f306052b2e6d3ce60b1f689d37cb5/clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerFunction.cpp#L348) when lowering to AArch64: ``` typedef struct { union { struct { char a, b; }; char c; }; } A; void foo(A a) {} void bar() { A a; foo(a); } ``` Currently, the value of the struct becomes a bitcast operation, so we can simply extend `findAlloca` to be able to trace the source alloca properly, then use that for the [coercion](https://github.com/llvm/clangir/blob/c8b626d49e7f306052b2e6d3ce60b1f689d37cb5/clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerFunction.cpp#L341) through memory. I have also added a test for this case.
1 parent 2b03d94 commit 676b861

File tree

2 files changed

+56
-0
lines changed

2 files changed

+56
-0
lines changed

clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerFunction.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,8 @@ cir::AllocaOp findAlloca(mlir::Operation *op) {
225225
return findAlloca(vals[0].getDefiningOp());
226226
} else if (auto load = mlir::dyn_cast<cir::LoadOp>(op)) {
227227
return findAlloca(load.getAddr().getDefiningOp());
228+
} else if (auto cast = mlir::dyn_cast<cir::CastOp>(op)) {
229+
return findAlloca(cast.getSrc().getDefiningOp());
228230
}
229231

230232
return {};

clang/test/CIR/CallConvLowering/AArch64/aarch64-cc-structs.c

+54
Original file line numberDiff line numberDiff line change
@@ -276,3 +276,57 @@ typedef struct {
276276
// LLVM: %[[#V3:]] = load ptr, ptr %[[#V2]], align 8
277277
// LLVM: ret void
278278
void pass_cat(CAT a) {}
279+
280+
typedef struct {
281+
union {
282+
struct {
283+
char a, b;
284+
};
285+
char c;
286+
};
287+
} NESTED_U;
288+
289+
// CHECK: cir.func @pass_nested_u(%arg0: !u64i
290+
// CHECK: %[[#V0:]] = cir.alloca !ty_NESTED_U, !cir.ptr<!ty_NESTED_U>, [""] {alignment = 4 : i64}
291+
// CHECK: %[[#V1:]] = cir.cast(integral, %arg0 : !u64i), !u16i
292+
// CHECK: %[[#V2:]] = cir.cast(bitcast, %[[#V0]] : !cir.ptr<!ty_NESTED_U>
293+
// CHECK: cir.store %[[#V1]], %[[#V2]] : !u16i
294+
// CHECK: cir.return
295+
296+
// LLVM: @pass_nested_u(i64 %[[#V0:]]
297+
// LLVM: %[[#V2:]] = alloca %struct.NESTED_U, i64 1, align 4
298+
// LLVM: %[[#V3:]] = trunc i64 %[[#V0]] to i16
299+
// LLVM: store i16 %[[#V3]], ptr %[[#V2]], align 2
300+
// LLVM: ret void
301+
void pass_nested_u(NESTED_U a) {}
302+
303+
// CHECK: cir.func no_proto @call_nested_u()
304+
// CHECK: %[[#V0:]] = cir.alloca !ty_NESTED_U, !cir.ptr<!ty_NESTED_U>
305+
// CHECK: %[[#V1:]] = cir.alloca !u64i, !cir.ptr<!u64i>, ["tmp"] {alignment = 8 : i64}
306+
// CHECK: %[[#V2:]] = cir.load %[[#V0]] : !cir.ptr<!ty_NESTED_U>, !ty_NESTED_U
307+
// CHECK: %[[#V3:]] = cir.cast(bitcast, %[[#V0]] : !cir.ptr<!ty_NESTED_U>)
308+
// CHECK: %[[#V4:]] = cir.load %[[#V3]]
309+
// CHECK: %[[#V5:]] = cir.cast(bitcast, %[[#V3]]
310+
// CHECK: %[[#V6:]] = cir.load %[[#V5]]
311+
// CHECK: %[[#V7:]] = cir.cast(bitcast, %[[#V0]] : !cir.ptr<!ty_NESTED_U>), !cir.ptr<!void>
312+
// CHECK: %[[#V8:]] = cir.cast(bitcast, %[[#V1]] : !cir.ptr<!u64i>), !cir.ptr<!void>
313+
// CHECK: %[[#V9:]] = cir.const #cir.int<2> : !u64i
314+
// CHECK: cir.libc.memcpy %[[#V9]] bytes from %[[#V7]] to %[[#V8]] : !u64i, !cir.ptr<!void> -> !cir.ptr<!void>
315+
// CHECK: %[[#V10:]] = cir.load %[[#V1]] : !cir.ptr<!u64i>, !u64i
316+
// CHECK: cir.call @pass_nested_u(%[[#V10]]) : (!u64i) -> ()
317+
// CHECK: cir.return
318+
319+
// LLVM: void @call_nested_u()
320+
// LLVM: %[[#V1:]] = alloca %struct.NESTED_U, i64 1, align 1
321+
// LLVM: %[[#V2:]] = alloca i64, i64 1, align 8
322+
// LLVM: %[[#V3:]] = load %struct.NESTED_U, ptr %[[#V1]], align 1
323+
// LLVM: %[[#V4:]] = load %union.anon.0, ptr %[[#V1]], align 1
324+
// LLVM: %[[#V5:]] = load %struct.anon.1, ptr %[[#V1]], align 1
325+
// LLVM: call void @llvm.memcpy.p0.p0.i64(ptr %[[#V2]], ptr %[[#V1]], i64 2, i1 false)
326+
// LLVM: %[[#V6:]] = load i64, ptr %[[#V2]], align 8
327+
// LLVM: call void @pass_nested_u(i64 %[[#V6]])
328+
// LLVM: ret void
329+
void call_nested_u() {
330+
NESTED_U a;
331+
pass_nested_u(a);
332+
}

0 commit comments

Comments
 (0)