Skip to content

Commit ee8ef34

Browse files
authored
Merge pull request #132 from saxbophone/josh/43-bit-shift-divmod-opt
Optimise multiplication and division by powers of two
2 parents 078d221 + 8d8387a commit ee8ef34

File tree

5 files changed

+115
-25
lines changed

5 files changed

+115
-25
lines changed

arby/include/arby/Nat.hpp

Lines changed: 50 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -552,6 +552,11 @@ namespace com::saxbophone::arby {
552552
_digits = product._digits;
553553
return *this; // return the result by reference
554554
}
555+
private: // private helper methods for multiplication operator
556+
constexpr bool is_power_of_2() const {
557+
return *this == Nat(1) << (bit_length() - 1);
558+
}
559+
public:
555560
/**
556561
* @brief Multiplication operator for Nat
557562
* @param lhs,rhs operands for the multiplication
@@ -562,35 +567,43 @@ namespace com::saxbophone::arby {
562567
// init product to zero
563568
Nat product;
564569
// either operand being zero always results in zero, so only run the algorithm if they're both non-zero
565-
if (not (lhs._digits.front() == 0 or rhs._digits.front() == 0)) {
566-
// multiply each digit from lhs with each digit from rhs
567-
std::size_t l = 0; // manual indices to track which digit we are on,
568-
std::size_t r = 0; // as codlili's iterators are not random-access
569-
for (auto lhs_digit : lhs._digits) {
570-
// reset r index as it cycles through multiple times
571-
r = 0;
572-
for (auto rhs_digit : rhs._digits) {
573-
// cast lhs to OverflowType to make sure both operands get promoted to avoid wrap-around overflow
574-
OverflowType multiplication = (OverflowType)lhs_digit * rhs_digit;
575-
// create a new Nat with this intermediate result and add trailing places as needed
576-
Nat intermediate = multiplication;
577-
// we need to remap the indices as the digits are stored big-endian
578-
std::size_t shift_amount = (lhs._digits.size() - 1 - l) + (rhs._digits.size() - 1 - r);
579-
// add that many trailing zeroes to intermediate's digits
580-
intermediate._digits.push_back(shift_amount, 0);
581-
// finally, add it to lhs as an accumulator
582-
product += intermediate;
583-
// increment manual indices
584-
r++;
585-
}
586-
l++;
570+
if (lhs._digits.front() == 0 or rhs._digits.front() == 0) {
571+
return product;
572+
}
573+
// optimisation using bitshifting when multiplying by binary powers
574+
if (rhs.is_power_of_2()) {
575+
return lhs << (rhs.bit_length() - 1);
576+
} else if (lhs.is_power_of_2()) {
577+
return rhs * lhs;
578+
}
579+
// multiply each digit from lhs with each digit from rhs
580+
std::size_t l = 0; // manual indices to track which digit we are on,
581+
std::size_t r = 0; // as codlili's iterators are not random-access
582+
for (auto lhs_digit : lhs._digits) {
583+
// reset r index as it cycles through multiple times
584+
r = 0;
585+
for (auto rhs_digit : rhs._digits) {
586+
// cast lhs to OverflowType to make sure both operands get promoted to avoid wrap-around overflow
587+
OverflowType multiplication = (OverflowType)lhs_digit * rhs_digit;
588+
// create a new Nat with this intermediate result and add trailing places as needed
589+
Nat intermediate = multiplication;
590+
// we need to remap the indices as the digits are stored big-endian
591+
std::size_t shift_amount = (lhs._digits.size() - 1 - l) + (rhs._digits.size() - 1 - r);
592+
// add that many trailing zeroes to intermediate's digits
593+
intermediate._digits.push_back(shift_amount, 0);
594+
// finally, add it to lhs as an accumulator
595+
product += intermediate;
596+
// increment manual indices
597+
r++;
587598
}
599+
l++;
588600
}
589601
product._validate_digits();
590602
return product;
591603
}
592604
private: // private helper methods for Nat::divmod()
593605
// function that shifts up rhs to be just big enough to be smaller than lhs
606+
// TODO: rewrite this to use bit-shifting for speed
594607
static constexpr Nat get_max_shift(const Nat& lhs, const Nat& rhs) {
595608
// how many places can we shift rhs left until it's the same width as lhs?
596609
std::size_t wiggle_room = lhs._digits.size() - rhs._digits.size();
@@ -638,6 +651,18 @@ namespace com::saxbophone::arby {
638651
if (rhs._digits.front() == 0) {
639652
throw std::domain_error("division by zero");
640653
}
654+
if (lhs._digits.front() == 0) { return {lhs, lhs}; } // zero shortcut
655+
// optimisation using bitshifting when dividing by binary powers
656+
if (rhs.is_power_of_2()) {
657+
auto width = rhs.bit_length();
658+
// the remainder is the digits that are shifted out, so bitmask for them
659+
auto bitmask = (Nat(1) << (width - 1)) - 1;
660+
Nat quotient = lhs >> (width - 1);
661+
Nat remainder = lhs & bitmask;
662+
quotient._validate_digits();
663+
remainder._validate_digits();
664+
return {quotient, remainder};
665+
}
641666
// this will gradually accumulate the calculated quotient
642667
Nat quotient;
643668
// this will gradually decrement with each subtraction
@@ -912,7 +937,9 @@ namespace com::saxbophone::arby {
912937
if (_digits.empty()) {
913938
_digits = {0};
914939
}
915-
_validate_digits(); // TODO: remove when satisfied not required
940+
// needed in some cases, probably when the intial whole-digit shift leaves a small value which then turns 0
941+
_remove_leading_zeroes();
942+
_validate_digits();
916943
return *this;
917944
}
918945
/**

tests/Nat/bit_shifting.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ TEST_CASE("arby::Nat left bit-shift", "[bit-shifting]") {
1919
{0b10001011_nat, 0, 0b10001011_nat},
2020
{0b10101110001_nat, 4, 0b101011100010000_nat},
2121
{0b1_nat, 70, 0b10000000000000000000000000000000000000000000000000000000000000000000000_nat},
22+
{0xfeed_nat, 32, 0xfeed00000000_nat},
2223
}
2324
)
2425
);
@@ -38,6 +39,7 @@ TEST_CASE("arby::Nat left bit-shift assignment", "[bit-shifting]") {
3839
{0b10001011_nat, 0, 0b10001011_nat},
3940
{0b10101110001_nat, 4, 0b101011100010000_nat},
4041
{0b1_nat, 70, 0b10000000000000000000000000000000000000000000000000000000000000000000000_nat},
42+
{0xfeed_nat, 32, 0xfeed00000000_nat},
4143
}
4244
)
4345
);
@@ -56,7 +58,8 @@ TEST_CASE("arby::Nat right bit-shift", "[bit-shifting]") {
5658
{0b10000000110100000000011101101000_nat, 54, 0b0_nat},
5759
{0b10011001010_nat, 0, 0b10011001010_nat},
5860
{0b1101011000011000_nat, 8, 0b11010110_nat},
59-
{0b11111111111111111111111111111111111111111111111111111111111111111111111111111111_nat, 70, 0b1111111111_nat}
61+
{0b11111111111111111111111111111111111111111111111111111111111111111111111111111111_nat, 70, 0b1111111111_nat},
62+
{0xfeedface1_nat, 32, 0xf_nat},
6063
}
6164
)
6265
);
@@ -76,7 +79,8 @@ TEST_CASE("arby::Nat right bit-shift assignment", "[bit-shifting]") {
7679
{0b10000000110100000000011101101000_nat, 54, 0b0_nat},
7780
{0b10011001010_nat, 0, 0b10011001010_nat},
7881
{0b1101011000011000_nat, 8, 0b11010110_nat},
79-
{0b11111111111111111111111111111111111111111111111111111111111111111111111111111111_nat, 70, 0b1111111111_nat}
82+
{0b11111111111111111111111111111111111111111111111111111111111111111111111111111111_nat, 70, 0b1111111111_nat},
83+
{0xfeedface1_nat, 32, 0xf_nat},
8084
}
8185
)
8286
);

tests/Nat/divmod.cpp

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
#include <catch2/catch.hpp>
66

7+
#include <arby/math.hpp>
78
#include <arby/Nat.hpp>
89

910
using namespace com::saxbophone;
@@ -207,3 +208,34 @@ TEST_CASE("Failing division", "[divmod]") {
207208
CHECK(quotient == 8123889139_nat);
208209
CHECK(remainder == 1892371893_nat);
209210
}
211+
212+
// regression tests for dividing by powers of two
213+
214+
// std::pow() is not accurate for large powers and we need exactness
215+
// TODO: put this in a helper function accessible to all tests
216+
static uintmax_t integer_pow(uintmax_t base, uintmax_t exponent) {
217+
// 1 to the power of anything is always 1
218+
if (base == 1) {
219+
return 1;
220+
}
221+
uintmax_t power = 1;
222+
for (uintmax_t i = 0; i < exponent; i++) {
223+
power *= base;
224+
}
225+
return power;
226+
}
227+
228+
TEST_CASE("divmod of arby::Nat by a power of two", "[divmod]") {
229+
uintmax_t power = GENERATE(range((uintmax_t)0, (uintmax_t)std::numeric_limits<uintmax_t>::digits));
230+
uintmax_t denominator = integer_pow(2, power);
231+
uintmax_t numerator = GENERATE_COPY(take(100, random(denominator, std::numeric_limits<uintmax_t>::max())));
232+
233+
CAPTURE(numerator, denominator);
234+
235+
auto [quotient, remainder] = arby::Nat::divmod(numerator, denominator);
236+
237+
CAPTURE(numerator, denominator, quotient, remainder);
238+
239+
CHECK(quotient == numerator / denominator);
240+
CHECK(remainder == numerator % denominator);
241+
}

tests/Nat/multiplication.cpp

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,29 @@ TEST_CASE("arby::Nat multiplication by arby::Nat", "[multiplication]") {
8888

8989
CHECK((uintmax_t)(lhs * rhs) == product);
9090
}
91+
92+
// regression tests for multiplying by powers of two
93+
94+
// std::pow() is not accurate for large powers and we need exactness
95+
// TODO: put this in a helper function accessible to all tests
96+
static uintmax_t integer_pow(uintmax_t base, uintmax_t exponent) {
97+
// 1 to the power of anything is always 1
98+
if (base == 1) {
99+
return 1;
100+
}
101+
uintmax_t power = 1;
102+
for (uintmax_t i = 0; i < exponent; i++) {
103+
power *= base;
104+
}
105+
return power;
106+
}
107+
108+
TEST_CASE("multiply arby::Nat by a power of two", "[multiplication]") {
109+
uintmax_t power = GENERATE(range((uintmax_t)0, (uintmax_t)std::numeric_limits<uintmax_t>::digits / 2));
110+
uintmax_t rhs = integer_pow(2, power);
111+
uintmax_t lhs = GENERATE_COPY(take(100, random((uintmax_t)0, rhs)));
112+
113+
auto product = arby::Nat(lhs) * arby::Nat(rhs);
114+
115+
CHECK(product == lhs * rhs);
116+
}

tests/math_support/pow.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ TEST_CASE("Zero raised to the power of any non-zero arby::Nat returns 0", "[math
2323
}
2424

2525
// std::pow() is not accurate for large powers and we need exactness
26+
// TODO: put this in a helper function accessible to all tests
2627
static uintmax_t integer_pow(uintmax_t base, uintmax_t exponent) {
2728
// 1 to the power of anything is always 1
2829
if (base == 1) {

0 commit comments

Comments
 (0)