Skip to content

Commit a48918f

Browse files
committed
feat: add subregisters config and allow mixed-precision fpu computations
1 parent 0449df8 commit a48918f

8 files changed

Lines changed: 194 additions & 48 deletions

File tree

include/bleach/lifter/instr-impl.hpp

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -52,6 +52,7 @@ class instr_impl final : private std::vector<instruction> {
5252
std::string stack_pointer;
5353
std::vector<constant_reg> const_regs;
5454
std::vector<regclass> regclasses;
55+
std::unordered_map<std::string, std::string> subregisters;
5556

5657
public:
5758
instr_impl() = default;
@@ -64,13 +65,16 @@ class instr_impl final : private std::vector<instruction> {
6465
using vector::emplace_back;
6566
using vector::push_back;
6667

67-
StringRef get_stack_pointer() const { return stack_pointer; }
68+
StringRef get_stack_pointer() const & { return stack_pointer; }
6869

69-
auto &get_const_regs() { return const_regs; }
70-
auto &get_const_regs() const { return const_regs; }
70+
auto &get_const_regs() & { return const_regs; }
71+
auto &get_const_regs() const & { return const_regs; }
7172

72-
auto &get_regclasses() const { return regclasses; }
73-
auto &get_regclasses() { return regclasses; }
73+
auto &get_regclasses() const & { return regclasses; }
74+
auto &get_regclasses() & { return regclasses; }
75+
76+
auto &get_subregs() const & { return subregisters; }
77+
auto &get_subregs() & { return subregisters; }
7478

7579
void set_stack_pointer(std::string sp) { stack_pointer = std::move(sp); }
7680

include/bleach/lifter/lifter.hpp

Lines changed: 19 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,9 @@
99
#include <llvm/Support/Regex.h>
1010
#include <llvm/Target/TargetMachine.h>
1111

12+
#include <iostream>
1213
#include <map>
14+
#include <set>
1315
#include <unordered_set>
1416

1517
namespace llvm {
@@ -56,24 +58,25 @@ class mbb2bb final
5658
}
5759
};
5860

59-
class register_class final : private std::unordered_set<unsigned> {
61+
class register_class final : private std::map<unsigned, std::vector<unsigned>> {
6062
std::string name;
6163
llvm::Regex regex;
6264
const TargetRegisterClass *rclass = nullptr;
6365
unsigned reg_bitsize = 0;
6466

6567
public:
66-
using unordered_set::begin;
67-
using unordered_set::contains;
68-
using unordered_set::empty;
69-
using unordered_set::end;
70-
using unordered_set::insert;
71-
using unordered_set::size;
68+
using map::at;
69+
using map::begin;
70+
using map::contains;
71+
using map::empty;
72+
using map::end;
73+
using map::insert;
74+
using map::size;
7275

7376
register_class(std::string &&rcname, llvm::Regex &&rx)
7477
: name(std::move(rcname)), regex(std::move(rx)) {}
7578

76-
void add_reg(unsigned reg) { insert(reg); }
79+
void add_reg(unsigned reg) { try_emplace(reg); }
7780

7881
auto &get_regex() const & { return regex; }
7982

@@ -104,8 +107,14 @@ class register_stats final : std::vector<register_class> {
104107
}
105108

106109
auto &get_register_class_for(unsigned reg) const {
107-
auto found =
108-
find_if(*this, [reg](auto &rclass) { return rclass.contains(reg); });
110+
auto found = ranges::find_if(*this, [reg](auto &rclass) {
111+
if (rclass.contains(reg))
112+
return true;
113+
auto subreg = ranges::find_if(rclass, [reg](auto &entry) {
114+
return is_contained(entry.second, reg);
115+
});
116+
return subreg != rclass.end();
117+
});
109118
if (found == end()) {
110119
throw std::invalid_argument("Unknown register encountered");
111120
}

lib/lifter/instr-impl.cpp

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@
77
#include <llvm/Support/SourceMgr.h>
88
#include <llvm/Support/raw_ostream.h>
99

10+
#include <stdexcept>
1011
#include <yaml-cpp/emitter.h>
1112
#include <yaml-cpp/emittermanip.h>
1213
#include <yaml-cpp/yaml.h>
@@ -134,6 +135,12 @@ instr_impl load_from_yaml(std::string yaml, llvm::LLVMContext &ctx) {
134135
for (auto &&node : instrs_conf["register-classes"])
135136
instrs.get_regclasses().push_back(
136137
{node.first.as<std::string>(), node.second.as<std::string>()});
138+
for (auto &&node : instrs_conf["subregisters"]) {
139+
auto [it, inserted] = instrs.get_subregs().try_emplace(
140+
node.first.as<std::string>(), node.second.as<std::string>());
141+
if (!inserted)
142+
throw std::invalid_argument("duplicated subregister class encountered");
143+
}
137144
for (auto &&node : instrs_conf["instructions"]) {
138145
auto norm = node.as<normalized_instruction>();
139146
instrs.push_back(denormalize_instruction(norm, ctx));

lib/lifter/lifter.cpp

Lines changed: 42 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@
44

55
#include "symtab-ir.h"
66

7+
#include <algorithm>
78
#include <iterator>
89
#include <llvm/CodeGen/MachineFunction.h>
910
#include <llvm/CodeGen/MachineModuleInfo.h>
@@ -173,7 +174,8 @@ void materialize_registers(MachineFunction &mf, Function &func, reg2vals &rmap,
173174
ArrayRef<Value *>{const_zero, arr_idx},
174175
rclass.get_name());
175176
auto sorted_regs = std::set<unsigned>();
176-
ranges::copy(rclass, std::inserter(sorted_regs, sorted_regs.end()));
177+
ranges::transform(rclass, std::inserter(sorted_regs, sorted_regs.end()),
178+
[](auto &&entry) { return entry.first; });
177179
for (auto &&[reg_idx, reg] : sorted_regs | views::enumerate) {
178180
auto *array_reg_idx = ConstantInt::get(ctx, APInt(32, reg_idx));
179181
auto *reg_src_addr = builder.CreateInBoundsGEP(
@@ -186,6 +188,10 @@ void materialize_registers(MachineFunction &mf, Function &func, reg2vals &rmap,
186188
Type::getIntNTy(ctx, rclass.get_register_size()), reg_src_addr);
187189
builder.CreateStore(reg_val, reg_dst_addr);
188190
rmap.try_emplace(reg, reg_dst_addr);
191+
auto &&subregs = rclass.at(reg);
192+
ranges::for_each(subregs, [&](std::integral auto r) {
193+
rmap.try_emplace(r, reg_dst_addr);
194+
});
189195
}
190196
}
191197
}
@@ -430,8 +436,9 @@ void assign_register_classes(const TargetSubtargetInfo &stinfo,
430436
assert(rinfo);
431437
for (auto &&[idx, reg_class] : enumerate(stats)) {
432438
auto found = find_if(rinfo->regclasses(), [&reg_class](auto &rclass) {
433-
return all_of(reg_class,
434-
[&rclass](auto reg) { return is_contained(*rclass, reg); });
439+
return all_of(reg_class, [&rclass](auto &&reg) {
440+
return is_contained(*rclass, reg.first);
441+
});
435442
});
436443
if (found == rinfo->regclass_end()) {
437444
throw std::runtime_error(
@@ -453,17 +460,48 @@ register_stats collect_register_stats(const instr_impl &instr, Module &m,
453460
throw std::runtime_error(
454461
"register-classes were not specified in input YAML");
455462
register_stats stats(rclasses.begin(), rclasses.end());
463+
const MCRegisterInfo *rinfo = nullptr;
464+
auto *stinfo = mmi.getTarget().getSubtargetImpl(*m.begin());
456465
for (auto &f : m) {
457466
auto &mf = mmi.getOrCreateMachineFunction(f);
458467
collect_register_stats_for(mf, mmi.getTarget(), stats);
468+
rinfo = stinfo->getRegisterInfo();
459469
}
460470
assert(!m.empty());
461-
auto *stinfo = mmi.getTarget().getSubtargetImpl(*m.begin());
462471
assert(stinfo);
463472
assign_register_classes(*stinfo, stats);
464473
stats.erase(std::remove_if(stats.begin(), stats.end(),
465474
[](auto &rclass) { return rclass.empty(); }),
466475
stats.end());
476+
for (auto &&[subregclass, aliasto] : instr.get_subregs()) {
477+
auto found = ranges::find_if(stats, [&subregclass](auto &rclass) {
478+
return rclass.get_name() == subregclass;
479+
});
480+
if (found != stats.end()) {
481+
auto found_aliasto = ranges::find_if(stats, [&aliasto](auto &rclass) {
482+
return rclass.get_name() == aliasto;
483+
});
484+
if (found_aliasto != stats.end()) {
485+
for (auto &&[entry, subreg] : views::zip(*found_aliasto, *found)) {
486+
auto &&[_, subregs] = entry;
487+
subregs.push_back(subreg.first);
488+
}
489+
stats.erase(found);
490+
}
491+
}
492+
}
493+
494+
for (auto &&regclass : stats) {
495+
std::cerr << "\nREGCLASS: " << regclass.get_name() << '\n';
496+
for (auto &&r : regclass) {
497+
std::cerr << rinfo->getName(r.first) << ": [ ";
498+
for (auto subreg : r.second)
499+
std::cerr << rinfo->getName(subreg) << " ";
500+
std::cerr << "]\n";
501+
}
502+
std::cerr << '\n';
503+
}
504+
467505
return stats;
468506
}
469507

test/integration/inputs/riscv-mixed-fp.c

Lines changed: 11 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
#include <STATE>
2+
#include <math.h>
23
#include <stdint.h>
34
#include <stdio.h>
45
#include <string.h>
@@ -29,6 +30,14 @@ double ref_calc1(double a, double b, int x) {
2930

3031
// Thin wrapper: forwards a, b and x to ref_calc1 and returns its result.
double reference(double a, double b, int x) { return ref_calc1(a, b, x); }
3132

33+
// Tolerant comparison of two doubles.
// Returns 1 when both values are NaN, 0 when exactly one of them is NaN,
// and otherwise 1 iff |a - b| <= 0.1 (an absolute, not relative, tolerance).
int are_equal(double a, double b) {
34+
if (isnan(a) && isnan(b))
35+
return 1;
36+
if (isnan(a) || isnan(b))
37+
return 0;
38+
return fabs(a - b) <= 0.1;
39+
}
40+
3241
int main() {
3342
printf("Hello from main\n");
3443
// random values
@@ -43,5 +52,6 @@ int main() {
4352
double res2 = bitcast_to_double(regs.FPR[10]);
4453
printf("result: %lf\n", res2);
4554
printf("reference: %lf\n", res1);
46-
return res1 - res2 > 0.1;
55+
printf("diff: %lf\n", res1 - res2);
56+
return !are_equal(res1, res2);
4757
}
test/integration/inputs/riscv-mixed-precision-fp.c

Lines changed: 57 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,57 @@
1+
#include <STATE>
2+
#include <math.h>
3+
#include <stdint.h>
4+
#include <stdio.h>
5+
#include <string.h>
6+
7+
struct register_state regs = {};
8+
9+
// Returns the raw bit pattern of v as an int64_t (no numeric conversion).
// The bytes are copied with memcpy; sizeof(double) bytes are copied, so this
// relies on double and int64_t having the same size.
int64_t bitcast_to_int(double v) {
10+
int64_t res;
11+
memcpy(&res, &v, sizeof(double));
12+
return res;
13+
}
14+
15+
// Inverse of bitcast_to_int: reinterprets the bits of v as a double by
// copying sizeof(double) bytes with memcpy (no numeric conversion).
double bitcast_to_double(int64_t v) {
16+
double res;
17+
memcpy(&res, &v, sizeof(double));
18+
return res;
19+
}
20+
21+
// This should be in sync with the lit test that includes this file via %S/inputs/riscv-mixed-precision-fp.c
22+
// TODO: automate this
23+
// Second stage of the reference computation: returns prod + sum - x.
// prod arrives as a float and is promoted to double by the addition with
// sum; x is converted to double explicitly, so the arithmetic is in double.
double ref_calc2(float prod, double sum, int x) {
24+
return prod + sum - (double)x;
25+
}
26+
27+
// First stage of the reference computation. All three argument expressions
// are evaluated in single precision (a and b are floats) and then passed to
// ref_calc2 as: float product a * b, the sum a + b widened to double, and
// x + a / b converted from float to int.
// NOTE(review): with the inputs main() uses in this file (a ~ 4.1e12,
// b ~ 3.3e-7) the quotient a / b is about 1.2e19, which does not fit in a
// 32- or 64-bit int; converting an out-of-range float to int is undefined
// behaviour in C — confirm this is intended.
double ref_calc1(float a, float b, int x) {
28+
return ref_calc2(a * b, (a + b), x + a / b);
29+
}
30+
31+
// Thin wrapper around ref_calc1. The double arguments a and b are narrowed
// to float at the call because ref_calc1 takes float parameters.
double reference(double a, double b, int x) { return ref_calc1(a, b, x); }
32+
33+
// Tolerant comparison of two doubles, used by main for the exit status.
// Returns 1 when both values are NaN, 0 when exactly one of them is NaN,
// and otherwise 1 iff |a - b| <= 0.1 (an absolute, not relative, tolerance).
int are_equal(double a, double b) {
34+
if (isnan(a) && isnan(b))
35+
return 1;
36+
if (isnan(a) || isnan(b))
37+
return 0;
38+
return fabs(a - b) <= 0.1;
39+
}
40+
41+
// Test driver: stores the inputs into the global register state `regs`,
// calls top_small on it, reads the result back from regs.FPR[10] and
// compares it with reference(). The exit status is 0 when are_equal accepts
// the pair and 1 otherwise.
int main() {
42+
// random values that cause difference between computation in floats and
43+
// doubles
44+
double arg_1 = 4113444444434.33;
45+
double arg_2 = 0.0000003344;
46+
int arg_3 = 5;
47+
// Floating-point inputs are stored as raw bit patterns, not converted.
// presumably FPR[10]/FPR[11]/GPR[10] are the slots for the RISC-V argument
// registers fa0/fa1/a0 — TODO confirm against the generated state header.
regs.FPR[10] = bitcast_to_int(arg_1);
48+
regs.FPR[11] = bitcast_to_int(arg_2);
49+
regs.GPR[10] = arg_3;
50+
// top_small is not defined in this file; here it is called with a pointer
// to the register state (presumably declared by the STATE header).
top_small(&regs);
51+
double res1 = reference(arg_1, arg_2, arg_3);
52+
// The result is read back from FPR[10] as a raw double bit pattern.
double res2 = bitcast_to_double(regs.FPR[10]);
53+
printf("result: %lf\n", res2);
54+
printf("reference: %lf\n", res1);
55+
printf("diff: %lf\n", res1 - res2);
56+
return !are_equal(res1, res2);
57+
}
Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,18 @@
1+
// RUN: %config-gen-path/config-gen.rb --march rv64imfd \
2+
// RUN: --template-dir %config-gen-path/templates -o %t.yaml
3+
// RUN: riscv64-unknown-linux-gnu-clang %s -O2 -c -o %t.out -march=rv64imfd
4+
// RUN: %bin/llvm-bleach %t.out --instructions %t.yaml -o %t.ll \
5+
// RUN: --state-struct-file=%t.state.h
6+
// RUN: sed 's|STATE|%t.state.h|g' %S/inputs/riscv-mixed-precision-fp.c > \
7+
// RUN: %t.main.c
8+
// RUN: clang %t.main.c %t.ll -o %t.native.out -g \
9+
// RUN: -fsanitize=address,undefined
10+
// RUN: %t.native.out
11+
12+
// Returns prod + sum - x, evaluated in double: the float prod is promoted by
// the addition with sum, and x is converted to double explicitly.
double calc2(float prod, double sum, int x) { return prod + sum - (double)x; }
13+
14+
// Evaluates a * b, a + b and x + a / b in single precision and passes them
// to calc2 as its float, double and int arguments respectively.
// NOTE(review): the third argument is a float-to-int conversion, which is
// undefined behaviour in C when the value is outside the range of int —
// confirm the test inputs keep x + a / b in range.
double calc1(float a, float b, int x) {
15+
return calc2(a * b, (a + b), x + a / b);
16+
}
17+
18+
// Top-level function of this test: forwards to calc1. The double arguments
// a and b are narrowed to float at the call, which is where double and
// single precision mix.
double top_small(double a, double b, int x) { return calc1(a, b, x); }

0 commit comments

Comments
 (0)