Skip to content

Commit cff90e3

Browse files
committed
x86_64: implement select of register mask
1 parent 0ef3250 commit cff90e3

File tree

2 files changed

+167
-23
lines changed

2 files changed

+167
-23
lines changed

src/arch/x86_64/CodeGen.zig

+147-23
Original file line number | Diff line number | Diff line change
@@ -97980,16 +97980,150 @@ fn airSelect(self: *CodeGen, inst: Air.Inst.Index) !void {
9798097980
switch (pred_mcv) {
9798197981
.register => |pred_reg| switch (pred_reg.class()) {
9798297982
.general_purpose => {},
97983-
.sse => if (need_xmm0 and pred_reg.id() != comptime Register.xmm0.id()) {
97984-
try self.register_manager.getKnownReg(.xmm0, null);
97985-
try self.genSetReg(.xmm0, pred_ty, pred_mcv, .{});
97986-
break :mask .xmm0;
97987-
} else break :mask if (has_blend)
97988-
pred_reg
97983+
.sse => if (elem_ty.toIntern() == .bool_type)
97984+
if (need_xmm0 and pred_reg.id() != comptime Register.xmm0.id()) {
97985+
try self.register_manager.getKnownReg(.xmm0, null);
97986+
try self.genSetReg(.xmm0, pred_ty, pred_mcv, .{});
97987+
break :mask .xmm0;
97988+
} else break :mask if (has_blend)
97989+
pred_reg
97990+
else
97991+
try self.copyToTmpRegister(pred_ty, pred_mcv)
9798997992
else
97990-
try self.copyToTmpRegister(pred_ty, pred_mcv),
97993+
return self.fail("TODO implement airSelect for {}", .{ty.fmt(pt)}),
9799197994
else => unreachable,
9799297995
},
97996+
.register_mask => |pred_reg_mask| {
97997+
if (pred_reg_mask.info.scalar.bitSize(self.target) != 8 * elem_abi_size)
97998+
return self.fail("TODO implement airSelect for {}", .{ty.fmt(pt)});
97999+
98000+
const mask_reg: Register = if (need_xmm0 and pred_reg_mask.reg.id() != comptime Register.xmm0.id()) mask_reg: {
98001+
try self.register_manager.getKnownReg(.xmm0, null);
98002+
try self.genSetReg(.xmm0, ty, .{ .register = pred_reg_mask.reg }, .{});
98003+
break :mask_reg .xmm0;
98004+
} else pred_reg_mask.reg;
98005+
const mask_alias = registerAlias(mask_reg, abi_size);
98006+
const mask_lock = self.register_manager.lockRegAssumeUnused(mask_reg);
98007+
defer self.register_manager.unlockReg(mask_lock);
98008+
98009+
const lhs_mcv = try self.resolveInst(extra.lhs);
98010+
const lhs_lock = switch (lhs_mcv) {
98011+
.register => |lhs_reg| self.register_manager.lockRegAssumeUnused(lhs_reg),
98012+
else => null,
98013+
};
98014+
defer if (lhs_lock) |lock| self.register_manager.unlockReg(lock);
98015+
98016+
const rhs_mcv = try self.resolveInst(extra.rhs);
98017+
const rhs_lock = switch (rhs_mcv) {
98018+
.register => |rhs_reg| self.register_manager.lockReg(rhs_reg),
98019+
else => null,
98020+
};
98021+
defer if (rhs_lock) |lock| self.register_manager.unlockReg(lock);
98022+
98023+
const order = has_blend != pred_reg_mask.info.inverted;
98024+
const reuse_mcv, const other_mcv = if (order)
98025+
.{ rhs_mcv, lhs_mcv }
98026+
else
98027+
.{ lhs_mcv, rhs_mcv };
98028+
const dst_mcv: MCValue = if (reuse_mcv.isRegister() and self.reuseOperand(
98029+
inst,
98030+
if (order) extra.rhs else extra.lhs,
98031+
@intFromBool(order),
98032+
reuse_mcv,
98033+
)) reuse_mcv else if (has_avx)
98034+
.{ .register = try self.register_manager.allocReg(inst, abi.RegisterClass.sse) }
98035+
else
98036+
try self.copyToRegisterWithInstTracking(inst, ty, reuse_mcv);
98037+
const dst_reg = dst_mcv.getReg().?;
98038+
const dst_alias = registerAlias(dst_reg, abi_size);
98039+
const dst_lock = self.register_manager.lockReg(dst_reg);
98040+
defer if (dst_lock) |lock| self.register_manager.unlockReg(lock);
98041+
98042+
const mir_tag = @as(?Mir.Inst.FixedTag, if ((pred_reg_mask.info.kind == .all and
98043+
elem_ty.toIntern() != .f32_type and elem_ty.toIntern() != .f64_type) or pred_reg_mask.info.scalar == .byte)
98044+
if (has_avx)
98045+
.{ .vp_b, .blendv }
98046+
else if (has_blend)
98047+
.{ .p_b, .blendv }
98048+
else if (pred_reg_mask.info.kind == .all)
98049+
.{ .p_, undefined }
98050+
else
98051+
null
98052+
else if ((pred_reg_mask.info.kind == .all and (elem_ty.toIntern() != .f64_type or !self.hasFeature(.sse2))) or
98053+
pred_reg_mask.info.scalar == .dword)
98054+
if (has_avx)
98055+
.{ .v_ps, .blendv }
98056+
else if (has_blend)
98057+
.{ ._ps, .blendv }
98058+
else if (pred_reg_mask.info.kind == .all)
98059+
.{ ._ps, undefined }
98060+
else
98061+
null
98062+
else if (pred_reg_mask.info.kind == .all or pred_reg_mask.info.scalar == .qword)
98063+
if (has_avx)
98064+
.{ .v_pd, .blendv }
98065+
else if (has_blend)
98066+
.{ ._pd, .blendv }
98067+
else if (pred_reg_mask.info.kind == .all)
98068+
.{ ._pd, undefined }
98069+
else
98070+
null
98071+
else
98072+
null) orelse return self.fail("TODO implement airSelect for {}", .{ty.fmt(pt)});
98073+
if (has_avx) {
98074+
const rhs_alias = if (reuse_mcv.isRegister())
98075+
registerAlias(reuse_mcv.getReg().?, abi_size)
98076+
else rhs: {
98077+
try self.genSetReg(dst_reg, ty, reuse_mcv, .{});
98078+
break :rhs dst_alias;
98079+
};
98080+
if (other_mcv.isBase()) try self.asmRegisterRegisterMemoryRegister(
98081+
mir_tag,
98082+
dst_alias,
98083+
rhs_alias,
98084+
try other_mcv.mem(self, .{ .size = self.memSize(ty) }),
98085+
mask_alias,
98086+
) else try self.asmRegisterRegisterRegisterRegister(
98087+
mir_tag,
98088+
dst_alias,
98089+
rhs_alias,
98090+
registerAlias(if (other_mcv.isRegister())
98091+
other_mcv.getReg().?
98092+
else
98093+
try self.copyToTmpRegister(ty, other_mcv), abi_size),
98094+
mask_alias,
98095+
);
98096+
} else if (has_blend) if (other_mcv.isBase()) try self.asmRegisterMemoryRegister(
98097+
mir_tag,
98098+
dst_alias,
98099+
try other_mcv.mem(self, .{ .size = self.memSize(ty) }),
98100+
mask_alias,
98101+
) else try self.asmRegisterRegisterRegister(
98102+
mir_tag,
98103+
dst_alias,
98104+
registerAlias(if (other_mcv.isRegister())
98105+
other_mcv.getReg().?
98106+
else
98107+
try self.copyToTmpRegister(ty, other_mcv), abi_size),
98108+
mask_alias,
98109+
) else {
98110+
try self.asmRegisterRegister(.{ mir_tag[0], .@"and" }, dst_alias, mask_alias);
98111+
if (other_mcv.isBase()) try self.asmRegisterMemory(
98112+
.{ mir_tag[0], .andn },
98113+
mask_alias,
98114+
try other_mcv.mem(self, .{ .size = .fromSize(abi_size) }),
98115+
) else try self.asmRegisterRegister(
98116+
.{ mir_tag[0], .andn },
98117+
mask_alias,
98118+
if (other_mcv.isRegister())
98119+
other_mcv.getReg().?
98120+
else
98121+
try self.copyToTmpRegister(ty, other_mcv),
98122+
);
98123+
try self.asmRegisterRegister(.{ mir_tag[0], .@"or" }, dst_alias, mask_alias);
98124+
}
98125+
break :result dst_mcv;
98126+
},
9799398127
else => {},
9799498128
}
9799598129
const mask_reg: Register = if (need_xmm0) mask_reg: {
@@ -98192,7 +98326,7 @@ fn airSelect(self: *CodeGen, inst: Air.Inst.Index) !void {
9819298326
const dst_lock = self.register_manager.lockReg(dst_reg);
9819398327
defer if (dst_lock) |lock| self.register_manager.unlockReg(lock);
9819498328

98195-
const mir_tag = @as(?Mir.Inst.FixedTag, switch (ty.childType(zcu).zigTypeTag(zcu)) {
98329+
const mir_tag = @as(?Mir.Inst.FixedTag, switch (elem_ty.zigTypeTag(zcu)) {
9819698330
else => null,
9819798331
.int => switch (abi_size) {
9819898332
0 => unreachable,
@@ -98208,7 +98342,7 @@ fn airSelect(self: *CodeGen, inst: Air.Inst.Index) !void {
9820898342
null,
9820998343
else => null,
9821098344
},
98211-
.float => switch (ty.childType(zcu).floatBits(self.target.*)) {
98345+
.float => switch (elem_ty.floatBits(self.target.*)) {
9821298346
else => unreachable,
9821398347
16, 80, 128 => null,
9821498348
32 => switch (vec_len) {
@@ -98262,30 +98396,20 @@ fn airSelect(self: *CodeGen, inst: Air.Inst.Index) !void {
9826298396
try self.copyToTmpRegister(ty, lhs_mcv), abi_size),
9826398397
mask_alias,
9826498398
) else {
98265-
const mir_fixes = @as(?Mir.Inst.Fixes, switch (elem_ty.zigTypeTag(zcu)) {
98266-
else => null,
98267-
.int => .p_,
98268-
.float => switch (elem_ty.floatBits(self.target.*)) {
98269-
32 => ._ps,
98270-
64 => ._pd,
98271-
16, 80, 128 => null,
98272-
else => unreachable,
98273-
},
98274-
}) orelse return self.fail("TODO implement airSelect for {}", .{ty.fmt(pt)});
98275-
try self.asmRegisterRegister(.{ mir_fixes, .@"and" }, dst_alias, mask_alias);
98399+
try self.asmRegisterRegister(.{ mir_tag[0], .@"and" }, dst_alias, mask_alias);
9827698400
if (rhs_mcv.isBase()) try self.asmRegisterMemory(
98277-
.{ mir_fixes, .andn },
98401+
.{ mir_tag[0], .andn },
9827898402
mask_alias,
9827998403
try rhs_mcv.mem(self, .{ .size = .fromSize(abi_size) }),
9828098404
) else try self.asmRegisterRegister(
98281-
.{ mir_fixes, .andn },
98405+
.{ mir_tag[0], .andn },
9828298406
mask_alias,
9828398407
if (rhs_mcv.isRegister())
9828498408
rhs_mcv.getReg().?
9828598409
else
9828698410
try self.copyToTmpRegister(ty, rhs_mcv),
9828798411
);
98288-
try self.asmRegisterRegister(.{ mir_fixes, .@"or" }, dst_alias, mask_alias);
98412+
try self.asmRegisterRegister(.{ mir_tag[0], .@"or" }, dst_alias, mask_alias);
9828998413
}
9829098414
break :result dst_mcv;
9829198415
};

test/behavior/select.zig

+20
Original file line number | Diff line number | Diff line change
@@ -66,3 +66,23 @@ fn selectArrays() !void {
6666
const xyz = @select(f32, x, y, z);
6767
try expect(mem.eql(f32, &@as([4]f32, xyz), &[4]f32{ 0.0, 312.1, -145.9, -3381.233 }));
6868
}
69+
70+
// Behavior test: `@select` whose predicate is itself the result of a vector
// comparison (`lhs < rhs`) rather than a separately built bool vector.
test "@select compare result" {
71+
// Skipped on the self-hosted riscv64 and wasm backends, and on the self-hosted
// x86_64 backend unless the object format is ELF or Mach-O.
if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
72+
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest;
73+
if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest;
74+
75+
const S = struct {
76+
// Element-wise minimum of two vectors of type V: picks `lhs` where
// `lhs < rhs`, otherwise `rhs`.
fn min(comptime V: type, lhs: V, rhs: V) V {
77+
return @select(@typeInfo(V).vector.child, lhs < rhs, lhs, rhs);
78+
}
79+
80+
// Checks `min` on a 4 x f32 and a 2 x f64 vector; every lane of the result
// must equal the expected vector (`@reduce(.And, ...)`).
fn doTheTest() !void {
81+
try expect(@reduce(.And, min(@Vector(4, f32), .{ -1, 2, -3, 4 }, .{ 1, -2, 3, -4 }) == @Vector(4, f32){ -1, -2, -3, -4 }));
82+
try expect(@reduce(.And, min(@Vector(2, f64), .{ -1, 2 }, .{ 1, -2 }) == @Vector(2, f64){ -1, -2 }));
83+
}
84+
};
85+
86+
// Run the same checks once at runtime and once at comptime.
try S.doTheTest();
87+
try comptime S.doTheTest();
88+
}

0 commit comments

Comments
 (0)