diff --git a/src/solvers/flattening/bv_utils.cpp b/src/solvers/flattening/bv_utils.cpp index 31dd12fea20..84451458fcd 100644 --- a/src/solvers/flattening/bv_utils.cpp +++ b/src/solvers/flattening/bv_utils.cpp @@ -8,6 +8,10 @@ Author: Daniel Kroening, kroening@kroening.com #include "bv_utils.h" +#include + +#include +#include #include #include @@ -780,8 +784,9 @@ bvt bv_utilst::dadda_tree(const std::vector<bvt> &pps) // been observed to go up by 5%-10%, and on some models even by 20%. // #define WALLACE_TREE // Dadda's reduction scheme. This yields a smaller formula size than Wallace -// trees (and also the default addition scheme), but remains disabled as it -// isn't consistently more performant either. +// trees (and also the default addition scheme), but isn't consistently more +// performant with simple partial-product generation. Only in combination with +// higher-radix multipliers does it appear to perform better. // #define DADDA_TREE // The following examples demonstrate the performance differences (with a @@ -916,16 +921,87 @@ bvt bv_utilst::dadda_tree(const std::vector<bvt> &pps) // our multiplier that's not using a tree reduction scheme, but aren't uniformly // better either. +// Higher-radix multipliers pre-compute partial products for groups of bits: +// radix-4 uses groups of 2 bits, radix-8 groups of 3 bits, and radix-16 groups +// of 4 bits. Performance data for these variants combined with different +// (tree) reduction schemes is recorded at +// https://tinyurl.com/multiplier-comparison. The data suggests that radix-8 +// with Dadda's reduction yields the most consistent performance improvement +// while not regressing substantially across the matrix of benchmarks and +// solvers (CaDiCaL and MiniSat2). +// #define RADIX_MULTIPLIER 8 +// #define USE_KARATSUBA +// #define USE_TOOM_COOK +#define USE_SCHOENHAGE_STRASSEN +#ifdef RADIX_MULTIPLIER +# define DADDA_TREE +#endif + +#ifdef RADIX_MULTIPLIER
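+// unsigned_multiply_by_3 below computes op + (op << 1), truncated to the width +// of op, directly in CNF and without an explicit carry chain: bit i of the +// result is op[i] XOR op[i - 1] XOR carry_i, and the carry into position i can +// be recovered from bits that are already available: when op[i - 1] equals +// op[i - 2] (tracked as prev_bit), the carry equals op[i - 1]; otherwise it is +// the negation of the previous result bit. As a sanity check, op = 0101 (5) +// yields 1111 (15).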
+static bvt unsigned_multiply_by_3(propt &prop, const bvt &op) +{ + PRECONDITION(prop.cnf_handled_well()); + PRECONDITION(!op.empty()); + + bvt result; + result.reserve(op.size()); + + result.push_back(op[0]); + literalt prev_bit = const_literal(false); + + for(std::size_t i = 1; i < op.size(); ++i) + { + literalt sum = prop.new_variable(); + + prop.lcnf({sum, !op[i - 1], !op[i], !prev_bit}); + prop.lcnf({sum, !op[i - 1], !op[i], result.back()}); + prop.lcnf({sum, op[i - 1], op[i], !prev_bit, result.back()}); + prop.lcnf({sum, !op[i - 1], op[i], prev_bit, !result.back()}); + prop.lcnf({sum, op[i - 1], !op[i], !result.back()}); + prop.lcnf({sum, op[i - 1], !op[i], prev_bit}); + + prop.lcnf({!sum, !op[i - 1], op[i], !prev_bit}); + prop.lcnf({!sum, !op[i - 1], op[i], result.back()}); + prop.lcnf({!sum, !op[i - 1], !op[i], prev_bit, !result.back()}); + + prop.lcnf({!sum, op[i - 1], op[i], !result.back()}); + prop.lcnf({!sum, op[i - 1], op[i], prev_bit}); + prop.lcnf({!sum, op[i - 1], !op[i], !prev_bit, result.back()}); + + prop.lcnf({!sum, op[i], prev_bit, result.back()}); + prop.lcnf({!sum, op[i], !prev_bit, !result.back()}); + + result.push_back(sum); + prev_bit = op[i - 1]; + } + + return result; +} +#endif + bvt bv_utilst::unsigned_multiplier(const bvt &_op0, const bvt &_op1) { - bvt op0=_op0, op1=_op1; + PRECONDITION(!_op0.empty()); + PRECONDITION(!_op1.empty()); - if(is_constant(op1)) - std::swap(op0, op1); + if(_op1.size() == 1) + { + bvt product; + product.reserve(_op0.size()); + for(const auto &lit : _op0) + product.push_back(prop.land(lit, _op1.front())); + return product; + } - // build the usual quadratic number of partial products + // store partial products std::vector<bvt> pps; - pps.reserve(op0.size()); + pps.reserve(_op0.size()); + + bvt op0 = _op0, op1 = _op1; + +#ifndef RADIX_MULTIPLIER + if(is_constant(op1)) + std::swap(op0, op1); for(std::size_t bit=0; bit<op1.size(); bit++) { if(op1[bit]!=const_literal(false)) { // zeros according to weight bvt pp=zeros(bit); pp.reserve(op0.size()); for(std::size_t idx=bit; idx<op0.size(); idx++) pp.push_back(prop.land(op1[bit], op0[idx-bit])); pps.push_back(std::move(pp)); } } +#else +# if RADIX_MULTIPLIER == 4 +# define RADIX_GROUP_SIZE 2 +# elif RADIX_MULTIPLIER == 8 +# define RADIX_GROUP_SIZE 3 +# elif RADIX_MULTIPLIER == 16 +# define RADIX_GROUP_SIZE 4 +# endif + + optionalt<bvt> times_three_opt; + auto times_three = [this, &times_three_opt, &op0]() -> const bvt & { + if(!times_three_opt.has_value()) + { +# if 1 + if(prop.cnf_handled_well()) + times_three_opt = unsigned_multiply_by_3(prop, op0); + else +# endif + times_three_opt = add(op0, shift(op0, shiftt::SHIFT_LEFT, 1)); + } + return *times_three_opt; + }; + +# if RADIX_MULTIPLIER >= 8 + optionalt<bvt> times_five_opt, times_seven_opt; + auto times_five = [this, &times_five_opt, &op0]() -> const bvt & { + if(!times_five_opt.has_value()) + times_five_opt = add(op0, shift(op0, shiftt::SHIFT_LEFT, 2)); + return *times_five_opt; + }; + auto times_seven = + [this, &times_seven_opt, &op0, &times_three]() -> const bvt & { + if(!times_seven_opt.has_value()) + times_seven_opt = add(times_three(), shift(op0, shiftt::SHIFT_LEFT, 2)); + return *times_seven_opt; + }; +# endif + +# if RADIX_MULTIPLIER == 16 + optionalt<bvt> times_nine_opt, times_eleven_opt, times_thirteen_opt, + times_fifteen_opt; + auto times_nine = [this, &times_nine_opt, &op0]() -> const bvt & { + if(!times_nine_opt.has_value()) + times_nine_opt = add(op0, shift(op0, shiftt::SHIFT_LEFT, 3)); + return *times_nine_opt; + }; + auto times_eleven = + [this, &times_eleven_opt, &op0, &times_three]() -> const bvt & { + if(!times_eleven_opt.has_value()) + times_eleven_opt = add(times_three(), shift(op0, shiftt::SHIFT_LEFT, 3)); + return *times_eleven_opt; + }; + auto times_thirteen = + [this, &times_thirteen_opt, &op0, &times_five]() -> const bvt & { + if(!times_thirteen_opt.has_value()) + times_thirteen_opt = add(times_five(), shift(op0, shiftt::SHIFT_LEFT, 3)); + return *times_thirteen_opt; + }; + auto times_fifteen = + [this, &times_fifteen_opt, &op0, &times_seven]() -> const bvt & { + if(!times_fifteen_opt.has_value()) + times_fifteen_opt = add(times_seven(), shift(op0, shiftt::SHIFT_LEFT, 3)); + return *times_fifteen_opt; + }; +# endif + + for(std::size_t op1_idx = 0; op1_idx + RADIX_GROUP_SIZE - 1 < op1.size(); + op1_idx += RADIX_GROUP_SIZE) + { + const literalt &bit0 = op1[op1_idx]; + const literalt &bit1 = op1[op1_idx + 1]; +# if RADIX_MULTIPLIER >= 8 + const literalt &bit2 = op1[op1_idx + 2]; +# if RADIX_MULTIPLIER == 16 + const literalt &bit3 = op1[op1_idx + 3]; +# endif +# endif + bvt partial_sum; + + if( + bit0.is_constant() && bit1.is_constant() +# if RADIX_MULTIPLIER >= 8 + && bit2.is_constant() +# if RADIX_MULTIPLIER == 16 + && bit3.is_constant() +# endif +# endif + ) + { + if(bit0.is_false()) // *0 + { + if(bit1.is_false()) // *00 + { +# if RADIX_MULTIPLIER >= 8 + if(bit2.is_false()) // *000 + { +# if RADIX_MULTIPLIER == 16 + if(bit3.is_false()) // 0000 + continue; + else // 1000 + partial_sum = shift(op0, shiftt::SHIFT_LEFT, op1_idx + 3); +# else + continue; +# endif + } + else // *100 + { +# if RADIX_MULTIPLIER == 16 + if(bit3.is_false()) // 0100 + partial_sum = shift(op0, shiftt::SHIFT_LEFT, op1_idx + 2); + else // 1100 + partial_sum = + shift(times_three(), shiftt::SHIFT_LEFT, op1_idx + 2); +# else + partial_sum = shift(op0, shiftt::SHIFT_LEFT, op1_idx + 2); +# endif + } +# else + continue; +# endif + } + else // *10 + { +# if RADIX_MULTIPLIER >= 8 + if(bit2.is_false()) // *010 + { +# if RADIX_MULTIPLIER == 16 + if(bit3.is_false()) // 0010 + partial_sum = shift(op0, shiftt::SHIFT_LEFT, op1_idx + 1); + else // 
1010 + partial_sum = + shift(times_five(), shiftt::SHIFT_LEFT, op1_idx + 1); +# else + partial_sum = shift(op0, shiftt::SHIFT_LEFT, op1_idx + 1); +# endif + } + else // *110 + { +# if RADIX_MULTIPLIER == 16 + if(bit3.is_false()) // 0110 + partial_sum = + shift(times_three(), shiftt::SHIFT_LEFT, op1_idx + 1); + else // 1110 + partial_sum = + shift(times_seven(), shiftt::SHIFT_LEFT, op1_idx + 1); +# else + partial_sum = shift(times_three(), shiftt::SHIFT_LEFT, op1_idx + 1); +# endif + } +# else + partial_sum = shift(op0, shiftt::SHIFT_LEFT, op1_idx + 1); +# endif + } + } + else // *1 + { + if(bit1.is_false()) // *01 + { +# if RADIX_MULTIPLIER >= 8 + if(bit2.is_false()) // *001 + { +# if RADIX_MULTIPLIER == 16 + if(bit3.is_false()) // 0001 + partial_sum = shift(op0, shiftt::SHIFT_LEFT, op1_idx); + else // 1001 + partial_sum = shift(times_nine(), shiftt::SHIFT_LEFT, op1_idx); +# else + partial_sum = shift(op0, shiftt::SHIFT_LEFT, op1_idx); +# endif + } + else // *101 + { +# if RADIX_MULTIPLIER == 16 + if(bit3.is_false()) // 0101 + partial_sum = shift(times_five(), shiftt::SHIFT_LEFT, op1_idx); + else // 1101 + partial_sum = + shift(times_thirteen(), shiftt::SHIFT_LEFT, op1_idx); +# else + partial_sum = shift(times_five(), shiftt::SHIFT_LEFT, op1_idx); +# endif + } +# else + partial_sum = shift(op0, shiftt::SHIFT_LEFT, op1_idx); +# endif + } + else // *11 + { +# if RADIX_MULTIPLIER >= 8 + if(bit2.is_false()) // *011 + { +# if RADIX_MULTIPLIER == 16 + if(bit3.is_false()) // 0011 + partial_sum = shift(times_three(), shiftt::SHIFT_LEFT, op1_idx); + else // 1011 + partial_sum = shift(times_eleven(), shiftt::SHIFT_LEFT, op1_idx); +# else + partial_sum = shift(times_three(), shiftt::SHIFT_LEFT, op1_idx); +# endif + } + else // *111 + { +# if RADIX_MULTIPLIER == 16 + if(bit3.is_false()) // 0111 + partial_sum = shift(times_seven(), shiftt::SHIFT_LEFT, op1_idx); + else // 1111 + partial_sum = shift(times_fifteen(), shiftt::SHIFT_LEFT, op1_idx); +# else + partial_sum = shift(times_seven(), shiftt::SHIFT_LEFT, op1_idx); +# endif + } +# else + partial_sum = shift(times_three(), shiftt::SHIFT_LEFT, op1_idx); +# endif + } + } + } + else + { + partial_sum = bvt(op1_idx, const_literal(false)); + for(std::size_t op0_idx = 0; op0_idx + op1_idx < op0.size(); ++op0_idx) + { +# if RADIX_MULTIPLIER == 4 + if(prop.cnf_handled_well()) + { + literalt partial_sum_bit = prop.new_variable(); + partial_sum.push_back(partial_sum_bit); + + // 00 + prop.lcnf({bit0, bit1, !partial_sum_bit}); + // 01 -> sum = _op0 + prop.lcnf({!bit0, bit1, !partial_sum_bit, _op0[op0_idx]}); + prop.lcnf({!bit0, bit1, partial_sum_bit, !_op0[op0_idx]}); + // 10 -> sum = (_op0 << 1) + if(op0_idx == 0) + prop.lcnf({bit0, !bit1, !partial_sum_bit}); + else + { + prop.lcnf({bit0, !bit1, !partial_sum_bit, _op0[op0_idx - 1]}); + prop.lcnf({bit0, !bit1, partial_sum_bit, !_op0[op0_idx - 1]}); + } + // 11 -> sum = times_three + prop.lcnf({!bit0, !bit1, !partial_sum_bit, times_three()[op0_idx]}); + prop.lcnf({!bit0, !bit1, partial_sum_bit, !times_three()[op0_idx]}); + } + else + { + partial_sum.push_back(prop.lselect( + !bit1, + prop.land(bit0, op0[op0_idx]), // 0x + prop.lselect( // 1x + !bit0, + op0_idx == 0 ? 
const_literal(false) : op0[op0_idx - 1], + times_three()[op0_idx]))); + } +# elif RADIX_MULTIPLIER == 8 + if(prop.cnf_handled_well()) + { + literalt partial_sum_bit = prop.new_variable(); + partial_sum.push_back(partial_sum_bit); + + // 000 + prop.lcnf({bit0, bit1, bit2, !partial_sum_bit}); + // 001 -> sum = _op0 + prop.lcnf({!bit0, bit1, bit2, !partial_sum_bit, _op0[op0_idx]}); + prop.lcnf({!bit0, bit1, bit2, partial_sum_bit, !_op0[op0_idx]}); + // 010 -> sum = (_op0 << 1) + if(op0_idx == 0) + prop.lcnf({bit0, !bit1, bit2, !partial_sum_bit}); + else + { + prop.lcnf({bit0, !bit1, bit2, !partial_sum_bit, _op0[op0_idx - 1]}); + prop.lcnf({bit0, !bit1, bit2, partial_sum_bit, !_op0[op0_idx - 1]}); + } + // 011 -> sum = times_three + prop.lcnf( + {!bit0, !bit1, bit2, !partial_sum_bit, times_three()[op0_idx]}); + prop.lcnf( + {!bit0, !bit1, bit2, partial_sum_bit, !times_three()[op0_idx]}); + // 100 -> sum = (_op0 << 2) + if(op0_idx == 0 || op0_idx == 1) + prop.lcnf({bit0, bit1, !bit2, !partial_sum_bit}); + else + { + prop.lcnf({bit0, bit1, !bit2, !partial_sum_bit, _op0[op0_idx - 2]}); + prop.lcnf({bit0, bit1, !bit2, partial_sum_bit, !_op0[op0_idx - 2]}); + } + // 101 -> sum = times_five + prop.lcnf( + {!bit0, bit1, !bit2, !partial_sum_bit, times_five()[op0_idx]}); + prop.lcnf( + {!bit0, bit1, !bit2, partial_sum_bit, !times_five()[op0_idx]}); + // 110 -> sum = (times_three << 1) + if(op0_idx == 0) + prop.lcnf({bit0, !bit1, !bit2, !partial_sum_bit}); + else + { + prop.lcnf( + {bit0, + !bit1, + !bit2, + !partial_sum_bit, + times_three()[op0_idx - 1]}); + prop.lcnf( + {bit0, + !bit1, + !bit2, + partial_sum_bit, + !times_three()[op0_idx - 1]}); + } + // 111 -> sum = times_seven + prop.lcnf( + {!bit0, !bit1, !bit2, !partial_sum_bit, times_seven()[op0_idx]}); + prop.lcnf( + {!bit0, !bit1, !bit2, partial_sum_bit, !times_seven()[op0_idx]}); + } + else + { + partial_sum.push_back(prop.lselect( + !bit2, + prop.lselect( // 0* + !bit1, + prop.land(bit0, op0[op0_idx]), // 00x + prop.lselect( // 01x + !bit0, + op0_idx == 0 ? const_literal(false) : op0[op0_idx - 1], + times_three()[op0_idx])), + prop.lselect( // 1* + !bit1, + prop.lselect( // 10x + !bit0, + op0_idx <= 1 ? const_literal(false) : op0[op0_idx - 2], + times_five()[op0_idx]), + prop.lselect( // 11x + !bit0, + op0_idx == 0 ? 
const_literal(false) + : times_three()[op0_idx - 1], + times_seven()[op0_idx])))); + } +# elif RADIX_MULTIPLIER == 16 + if(prop.cnf_handled_well()) + { + literalt partial_sum_bit = prop.new_variable(); + partial_sum.push_back(partial_sum_bit); + + // 0000 + prop.lcnf({bit0, bit1, bit2, bit3, !partial_sum_bit}); + // 0001 -> sum = op0 + prop.lcnf({!bit0, bit1, bit2, bit3, !partial_sum_bit, op0[op0_idx]}); + prop.lcnf({!bit0, bit1, bit2, bit3, partial_sum_bit, !op0[op0_idx]}); + // 0010 -> sum = (op0 << 1) + if(op0_idx == 0) + prop.lcnf({bit0, !bit1, bit2, bit3, !partial_sum_bit}); + else + { + prop.lcnf( + {bit0, !bit1, bit2, bit3, !partial_sum_bit, op0[op0_idx - 1]}); + prop.lcnf( + {bit0, !bit1, bit2, bit3, partial_sum_bit, !op0[op0_idx - 1]}); + } + // 0011 -> sum = times_three + prop.lcnf( + {!bit0, + !bit1, + bit2, + bit3, + !partial_sum_bit, + times_three()[op0_idx]}); + prop.lcnf( + {!bit0, + !bit1, + bit2, + bit3, + partial_sum_bit, + !times_three()[op0_idx]}); + // 0100 -> sum = (op0 << 2) + if(op0_idx == 0 || op0_idx == 1) + prop.lcnf({bit0, bit1, !bit2, bit3, !partial_sum_bit}); + else + { + prop.lcnf( + {bit0, bit1, !bit2, bit3, !partial_sum_bit, op0[op0_idx - 2]}); + prop.lcnf( + {bit0, bit1, !bit2, bit3, partial_sum_bit, !op0[op0_idx - 2]}); + } + // 0101 -> sum = times_five + prop.lcnf( + {!bit0, + bit1, + !bit2, + bit3, + !partial_sum_bit, + times_five()[op0_idx]}); + prop.lcnf( + {!bit0, + bit1, + !bit2, + bit3, + partial_sum_bit, + !times_five()[op0_idx]}); + // 0110 -> sum = (times_three << 1) + if(op0_idx == 0) + prop.lcnf({bit0, !bit1, !bit2, bit3, !partial_sum_bit}); + else + { + prop.lcnf( + {bit0, + !bit1, + !bit2, + bit3, + !partial_sum_bit, + times_three()[op0_idx - 1]}); + prop.lcnf( + {bit0, + !bit1, + !bit2, + bit3, + partial_sum_bit, + !times_three()[op0_idx - 1]}); + } + // 0111 -> sum = times_seven + prop.lcnf( + {!bit0, + !bit1, + !bit2, + bit3, + !partial_sum_bit, + times_seven()[op0_idx]}); + prop.lcnf( + {!bit0, + !bit1, + !bit2, + bit3, + partial_sum_bit, + !times_seven()[op0_idx]}); + + // 1000 -> sum = (op0 << 3) + if(op0_idx == 0 || op0_idx == 1 || op0_idx == 2) + prop.lcnf({bit0, bit1, bit2, !bit3, !partial_sum_bit}); + else + { + prop.lcnf( + {bit0, bit1, bit2, !bit3, !partial_sum_bit, op0[op0_idx - 3]}); + prop.lcnf( + {bit0, bit1, bit2, !bit3, partial_sum_bit, !op0[op0_idx - 3]}); + } + // 1001 -> sum = times_nine + prop.lcnf( + {!bit0, + bit1, + bit2, + !bit3, + !partial_sum_bit, + times_nine()[op0_idx]}); + prop.lcnf( + {!bit0, + bit1, + bit2, + !bit3, + partial_sum_bit, + !times_nine()[op0_idx]}); + // 1010 -> sum = (times_five << 1) + if(op0_idx == 0) + prop.lcnf({bit0, !bit1, bit2, !bit3, !partial_sum_bit}); + else + { + prop.lcnf( + {bit0, + !bit1, + bit2, + !bit3, + !partial_sum_bit, + times_five()[op0_idx - 1]}); + prop.lcnf( + {bit0, + !bit1, + bit2, + !bit3, + partial_sum_bit, + !times_five()[op0_idx - 1]}); + } + // 1011 -> sum = times_eleven + prop.lcnf( + {!bit0, + !bit1, + bit2, + !bit3, + !partial_sum_bit, + times_eleven()[op0_idx]}); + prop.lcnf( + {!bit0, + !bit1, + bit2, + !bit3, + partial_sum_bit, + !times_eleven()[op0_idx]}); + // 1100 -> sum = (times_three << 2) + if(op0_idx == 0 || op0_idx == 1) + prop.lcnf({bit0, bit1, !bit2, !bit3, !partial_sum_bit}); + else + { + prop.lcnf( + {bit0, + bit1, + !bit2, + !bit3, + !partial_sum_bit, + times_three()[op0_idx - 2]}); + prop.lcnf( + {bit0, + bit1, + !bit2, + !bit3, + partial_sum_bit, + !times_three()[op0_idx - 2]}); + } + // 1101 -> sum = times_thirteen + prop.lcnf( + {!bit0, + 
bit1, + !bit2, + !bit3, + !partial_sum_bit, + times_thirteen()[op0_idx]}); + prop.lcnf( + {!bit0, + bit1, + !bit2, + !bit3, + partial_sum_bit, + !times_thirteen()[op0_idx]}); + // 1110 -> sum = (times_seven << 1) + if(op0_idx == 0) + prop.lcnf({bit0, !bit1, !bit2, !bit3, !partial_sum_bit}); + else + { + prop.lcnf( + {bit0, + !bit1, + !bit2, + !bit3, + !partial_sum_bit, + times_seven()[op0_idx - 1]}); + prop.lcnf( + {bit0, + !bit1, + !bit2, + !bit3, + partial_sum_bit, + !times_seven()[op0_idx - 1]}); + } + // 1111 -> sum = times_fifteen + prop.lcnf( + {!bit0, + !bit1, + !bit2, + !bit3, + !partial_sum_bit, + times_fifteen()[op0_idx]}); + prop.lcnf( + {!bit0, + !bit1, + !bit2, + !bit3, + partial_sum_bit, + !times_fifteen()[op0_idx]}); + } + else + { + partial_sum.push_back(prop.lselect( + !bit3, + prop.lselect( // 0* + !bit2, + prop.lselect( // 00* + !bit1, + prop.land(bit0, op0[op0_idx]), // 000x + prop.lselect( // 001x + !bit0, + op0_idx == 0 ? const_literal(false) : op0[op0_idx - 1], + times_three()[op0_idx])), + prop.lselect( // 01* + !bit1, + prop.lselect( // 010x + !bit0, + op0_idx <= 1 ? const_literal(false) : op0[op0_idx - 2], + times_five()[op0_idx]), + prop.lselect( // 011x + !bit0, + op0_idx == 0 ? const_literal(false) + : times_three()[op0_idx - 1], + times_seven()[op0_idx]))), + prop.lselect( // 1* + !bit2, + prop.lselect( // 10* + !bit1, + prop.lselect( // 100x + !bit0, + op0_idx <= 2 ? const_literal(false) : op0[op0_idx - 3], + times_nine()[op0_idx]), + prop.lselect( // 101x + !bit0, + op0_idx == 0 ? const_literal(false) + : times_five()[op0_idx - 1], + times_eleven()[op0_idx])), + prop.lselect( // 11* + !bit1, + prop.lselect( // 110x + !bit0, + op0_idx <= 1 ? const_literal(false) + : times_three()[op0_idx - 2], + times_thirteen()[op0_idx]), + prop.lselect( // 111x + !bit0, + op0_idx == 0 ? const_literal(false) + : times_seven()[op0_idx - 1], + times_fifteen()[op0_idx]))))); + } +# else +# error Unsupported radix +# endif + } + } + + pps.push_back(std::move(partial_sum)); + } + + if(op1.size() % RADIX_GROUP_SIZE == 1) + { + if(op0.size() == op1.size()) + { + if(pps.empty()) + pps.push_back(bvt(op0.size(), const_literal(false))); + + // This is the partial product of the MSB of op1 with op0, which is all + // zeros except for (possibly) the MSB. Since we don't need to account for + // any carry out of adding this partial product, we just need to compute + // the sum of the MSB of one of the partial products and this partial + // product, which is an xor of just those bits. 
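+ // For example, with 9-bit operands and radix-4 (RADIX_GROUP_SIZE == 2), + // op1[8] is the only leftover bit; its partial product op0 * op1[8] << 8 + // contributes just bit 8, i.e. op0[0] AND op1[8], once truncated to the + // result width.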
+ pps.back().back() = + prop.lxor(pps.back().back(), prop.land(op0[0], op1.back())); + } + else + { + bvt partial_sum = bvt(op1.size() - 1, const_literal(false)); + for(const auto &lit : op0) + { + partial_sum.push_back(prop.land(lit, op1.back())); + if(partial_sum.size() == op0.size()) + break; + } + pps.push_back(std::move(partial_sum)); + } + } +# if RADIX_MULTIPLIER >= 8 + else if(op1.size() % RADIX_GROUP_SIZE == 2) + { + const literalt &bit0 = op1[op1.size() - 2]; + const literalt &bit1 = op1[op1.size() - 1]; + + bvt partial_sum = bvt(op1.size() - 2, const_literal(false)); + for(std::size_t op0_idx = 0; op0_idx < 2; ++op0_idx) + { + if(prop.cnf_handled_well()) + { + literalt partial_sum_bit = prop.new_variable(); + partial_sum.push_back(partial_sum_bit); + // 00 + prop.lcnf({bit0, bit1, !partial_sum_bit}); + // 01 -> sum = op0 + prop.lcnf({!bit0, bit1, !partial_sum_bit, op0[op0_idx]}); + prop.lcnf({!bit0, bit1, partial_sum_bit, !op0[op0_idx]}); + // 10 -> sum = (op0 << 1) + if(op0_idx == 0) + prop.lcnf({bit0, !bit1, !partial_sum_bit}); + else + { + prop.lcnf({bit0, !bit1, !partial_sum_bit, op0[op0_idx - 1]}); + prop.lcnf({bit0, !bit1, partial_sum_bit, !op0[op0_idx - 1]}); + } + // 11 -> sum = times_three + prop.lcnf({!bit0, !bit1, !partial_sum_bit, times_three()[op0_idx]}); + prop.lcnf({!bit0, !bit1, partial_sum_bit, !times_three()[op0_idx]}); + } + else + { + partial_sum.push_back(prop.lselect( + !bit1, + prop.land(bit0, op0[op0_idx]), // 0x + prop.lselect( // 1x + !bit0, + op0_idx == 0 ? const_literal(false) : op0[op0_idx - 1], + times_three()[op0_idx]))); + } + } + + pps.push_back(std::move(partial_sum)); + } +# endif +# if RADIX_MULTIPLIER == 16 + else if(op1.size() % RADIX_GROUP_SIZE == 3) + { + const literalt &bit0 = op1[op1.size() - 3]; + const literalt &bit1 = op1[op1.size() - 2]; + const literalt &bit2 = op1[op1.size() - 1]; + + bvt partial_sum = bvt(op1.size() - 3, const_literal(false)); + for(std::size_t op0_idx = 0; op0_idx < 3; ++op0_idx) + { + if(prop.cnf_handled_well()) + { + literalt partial_sum_bit = prop.new_variable(); + partial_sum.push_back(partial_sum_bit); + // 000 + prop.lcnf({bit0, bit1, bit2, !partial_sum_bit}); + // 001 -> sum = op0 + prop.lcnf({!bit0, bit1, bit2, !partial_sum_bit, op0[op0_idx]}); + prop.lcnf({!bit0, bit1, bit2, partial_sum_bit, !op0[op0_idx]}); + // 010 -> sum = (op0 << 1) + if(op0_idx == 0) + prop.lcnf({bit0, !bit1, bit2, !partial_sum_bit}); + else + { + prop.lcnf({bit0, !bit1, bit2, !partial_sum_bit, op0[op0_idx - 1]}); + prop.lcnf({bit0, !bit1, bit2, partial_sum_bit, !op0[op0_idx - 1]}); + } + // 011 -> sum = times_three + prop.lcnf( + {!bit0, !bit1, bit2, !partial_sum_bit, times_three()[op0_idx]}); + prop.lcnf( + {!bit0, !bit1, bit2, partial_sum_bit, !times_three()[op0_idx]}); + // 100 -> sum = (op0 << 2) + if(op0_idx == 0 || op0_idx == 1) + prop.lcnf({bit0, bit1, !bit2, !partial_sum_bit}); + else + { + prop.lcnf({bit0, bit1, !bit2, !partial_sum_bit, op0[op0_idx - 2]}); + prop.lcnf({bit0, bit1, !bit2, partial_sum_bit, !op0[op0_idx - 2]}); + } + // 101 -> sum = times_five + prop.lcnf( + {!bit0, bit1, !bit2, !partial_sum_bit, times_five()[op0_idx]}); + prop.lcnf( + {!bit0, bit1, !bit2, partial_sum_bit, !times_five()[op0_idx]}); + // 110 -> sum = (times_three << 1) + if(op0_idx == 0) + prop.lcnf({bit0, !bit1, !bit2, !partial_sum_bit}); + else + { + prop.lcnf( + {bit0, !bit1, !bit2, !partial_sum_bit, times_three()[op0_idx - 1]}); + prop.lcnf( + {bit0, !bit1, !bit2, partial_sum_bit, !times_three()[op0_idx - 1]}); + } + // 111 -> sum = 
times_seven + prop.lcnf( + {!bit0, !bit1, !bit2, !partial_sum_bit, times_seven()[op0_idx]}); + prop.lcnf( + {!bit0, !bit1, !bit2, partial_sum_bit, !times_seven()[op0_idx]}); + } + else + { + partial_sum.push_back(prop.lselect( + !bit2, + prop.lselect( // 0* + !bit1, + prop.land(bit0, op0[op0_idx]), // 00x + prop.lselect( // 01x + !bit0, + op0_idx == 0 ? const_literal(false) : op0[op0_idx - 1], + times_three()[op0_idx])), + prop.lselect( // 1* + !bit1, + prop.lselect( // 10x + !bit0, + op0_idx <= 1 ? const_literal(false) : op0[op0_idx - 2], + times_five()[op0_idx]), + prop.lselect( // 11x + !bit0, + op0_idx == 0 ? const_literal(false) : times_three()[op0_idx - 1], + times_seven()[op0_idx])))); + } + } + + pps.push_back(std::move(partial_sum)); + } +# endif +#endif if(pps.empty()) return zeros(op0.size()); @@ -961,6 +1828,600 @@ bvt bv_utilst::unsigned_multiplier(const bvt &_op0, const bvt &_op1) } } +bvt bv_utilst::unsigned_karatsuba_multiplier(const bvt &_op0, const bvt &_op1) +{ + if(_op0.size() != _op1.size()) + return unsigned_multiplier(_op0, _op1); + + const std::size_t op_size = _op0.size(); + // only use this approach for powers of two + if(op_size == 0 || (op_size & (op_size - 1)) != 0) + return unsigned_multiplier(_op0, _op1); + + const std::size_t half_op_size = op_size >> 1; + + // The need to use a full multiplier for z_0 means that we will not actually + // accomplish a reduction in bit width. + bvt x0{_op0.begin(), _op0.begin() + half_op_size}; + x0.resize(op_size, const_literal(false)); + bvt x1{_op0.begin() + half_op_size, _op0.end()}; + // x1.resize(op_size, const_literal(false)); + bvt y0{_op1.begin(), _op1.begin() + half_op_size}; + y0.resize(op_size, const_literal(false)); + bvt y1{_op1.begin() + half_op_size, _op1.end()}; + // y1.resize(op_size, const_literal(false)); + + bvt z0 = unsigned_multiplier(x0, y0); + bvt z2 = unsigned_karatsuba_multiplier(x1, y1); + + bvt z0_half{z0.begin(), z0.begin() + half_op_size}; + bvt z2_plus_z0 = add(z2, z0_half); + z2_plus_z0.resize(half_op_size); + + bvt x0_half{x0.begin(), x0.begin() + half_op_size}; + bvt xdiff = add(x0_half, x1); + // xdiff.resize(half_op_size); + bvt y0_half{y0.begin(), y0.begin() + half_op_size}; + bvt ydiff = add(y1, y0_half); + // ydiff.resize(half_op_size); + + bvt z1 = sub(unsigned_karatsuba_multiplier(xdiff, ydiff), z2_plus_z0); + for(std::size_t i = 0; i < half_op_size; ++i) + z1.insert(z1.begin(), const_literal(false)); + // result.insert(result.end(), z1.begin(), z1.end()); + + // z1.resize(op_size); + z0.resize(op_size); + return add(z0, z1); +}
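+ +// Karatsuba over k = op_size bits, computed mod 2^k: writing x = x1*2^h + x0 +// and y = y1*2^h + y0 with h = k/2, the full product is +// z0 + (x0*y1 + x1*y0)*2^h + z2*2^(2h) for z0 = x0*y0 and z2 = x1*y1. Mod 2^k +// the z2*2^(2h) term vanishes and the middle coefficient is only needed mod +// 2^h, which is why z1 above is computed over h bits as +// (x0 + x1)*(y0 + y1) - z2 - z0. For example, for k = 8, x = 213 = 13*16 + 5 +// and y = 15: z0 = 75, z2 = 0, z1 = 18*15 - 0 - 75 = 195 = 3 (mod 16), and +// 75 + 3*16 = 123 = 213*15 mod 256.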
+ +bvt bv_utilst::unsigned_toom_cook_multiplier(const bvt &_op0, const bvt &_op1) +{ + PRECONDITION(!_op0.empty()); + PRECONDITION(!_op1.empty()); + + if(_op1.size() == 1) + return unsigned_multiplier(_op0, _op1); + + // break up _op0, _op1 into groups of at most GROUP_SIZE bits + PRECONDITION(_op0.size() == _op1.size()); +#define GROUP_SIZE 8 + const std::size_t d_bits = + 2 * GROUP_SIZE + + 2 * address_bits((_op0.size() + GROUP_SIZE - 1) / GROUP_SIZE); + std::vector<bvt> a, b, c_ops, d; + for(std::size_t i = 0; i < _op0.size(); i += GROUP_SIZE) + { + std::size_t u = std::min(i + GROUP_SIZE, _op0.size()); + a.emplace_back(_op0.begin() + i, _op0.begin() + u); + b.emplace_back(_op1.begin() + i, _op1.begin() + u); + + c_ops.push_back(zeros(i)); + d.push_back(prop.new_variables(d_bits)); + c_ops.back().insert(c_ops.back().end(), d.back().begin(), d.back().end()); + c_ops.back() = zero_extension(c_ops.back(), _op0.size()); + } + for(std::size_t i = a.size(); i < 2 * a.size() - 1; ++i) + { + d.push_back(prop.new_variables(d_bits)); + } + + // r(0) + bvt r_0 = d[0]; + prop.l_set_to_true(equal( + r_0, + unsigned_multiplier( + zero_extension(a[0], r_0.size()), zero_extension(b[0], r_0.size())))); + + for(std::size_t j = 1; j < a.size(); ++j) + { + // r(2^(j-1)) + bvt r_j = zero_extension( + d[0], std::min(_op0.size(), d[0].size() + (j - 1) * (d.size() - 1))); + for(std::size_t i = 1; i < d.size(); ++i) + { + r_j = add( + r_j, + shift( + zero_extension(d[i], r_j.size()), shiftt::SHIFT_LEFT, (j - 1) * i)); + } + + bvt a_even = zero_extension(a[0], r_j.size()); + for(std::size_t i = 2; i < a.size(); i += 2) + { + a_even = add( + a_even, + shift( + zero_extension(a[i], a_even.size()), + shiftt::SHIFT_LEFT, + (j - 1) * i)); + } + bvt a_odd = zero_extension(a[1], r_j.size()); + for(std::size_t i = 3; i < a.size(); i += 2) + { + a_odd = add( + a_odd, + shift( + zero_extension(a[i], a_odd.size()), + shiftt::SHIFT_LEFT, + (j - 1) * (i - 1))); + } + bvt b_even = zero_extension(b[0], r_j.size()); + for(std::size_t i = 2; i < b.size(); i += 2) + { + b_even = add( + b_even, + shift( + zero_extension(b[i], b_even.size()), + shiftt::SHIFT_LEFT, + (j - 1) * i)); + } + bvt b_odd = zero_extension(b[1], r_j.size()); + for(std::size_t i = 3; i < b.size(); i += 2) + { + b_odd = add( + b_odd, + shift( + zero_extension(b[i], b_odd.size()), + shiftt::SHIFT_LEFT, + (j - 1) * (i - 1))); + } + + prop.l_set_to_true(equal( + r_j, + unsigned_multiplier( + add(a_even, shift(a_odd, shiftt::SHIFT_LEFT, j - 1)), + add(b_even, shift(b_odd, shiftt::SHIFT_LEFT, j - 1))))); + + // r(-2^(j-1)) + bvt r_minus_j = zero_extension( + d[0], std::min(_op0.size(), d[0].size() + (j - 1) * (d.size() - 1))); + for(std::size_t i = 1; i < d.size(); ++i) + { + if(i % 2 == 1) + { + r_minus_j = sub( + r_minus_j, + shift( + zero_extension(d[i], r_minus_j.size()), + shiftt::SHIFT_LEFT, + (j - 1) * i)); + } + else + { + r_minus_j = add( + r_minus_j, + shift( + zero_extension(d[i], r_minus_j.size()), + shiftt::SHIFT_LEFT, + (j - 1) * i)); + } + } + + prop.l_set_to_true(equal( + r_minus_j, + unsigned_multiplier( + sub(a_even, shift(a_odd, shiftt::SHIFT_LEFT, j - 1)), + sub(b_even, shift(b_odd, shiftt::SHIFT_LEFT, j - 1))))); + } + + if(c_ops.empty()) + return zeros(_op0.size()); + else + { +#ifdef WALLACE_TREE + return wallace_tree(c_ops); +#elif defined(DADDA_TREE) + return dadda_tree(c_ops); +#else + bvt product = c_ops.front(); + + for(auto it = std::next(c_ops.begin()); it != c_ops.end(); ++it) + product = add(product, *it); + + return product; +#endif + } +}
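+ +// Note on the scheme above: rather than interpolating the product polynomial +// explicitly, the coefficient groups d are introduced as fresh variables and +// constrained (via l_set_to_true) so that evaluating them at the points 0, +// 2^(j - 1) and -2^(j - 1) matches the product of the operands evaluated at +// the same points; the partial results collected in c_ops then assemble the +// product from these coefficients.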
+ +static std::string beautify(const bvt &bv) +{ + for(const auto &v : bv) + { + if(!v.is_constant()) + { + std::ostringstream oss; + oss << bv; + return oss.str(); + } + } + + std::string result; + std::size_t number = 0; + for(std::size_t i = 0; i < bv.size(); ++i) + { + if(result.size() % 5 == 4) + result = std::string(" ") + result; + result = std::string(bv[i].is_false() ? "0" : "1") + result; + + if(bv[i].is_true()) + number += (std::size_t)1 << i; + } + + return result + " (" + std::to_string(number) + ")"; +} + +bvt bv_utilst::unsigned_schoenhage_strassen_multiplier( + const bvt &a, + const bvt &b) +{ + PRECONDITION(a.size() == b.size()); + + // Running examples: we want to multiply 213 by 15 as 8- or 9-bit integers. + // That is, we seek to multiply 11010101 (011010101) by 00001111 (000001111). + // ^bit 7 ^bit 0 + // The expected result is 123 as both an 8-bit and 9-bit result (001111011). + + // We compute the result modulo a Fermat number F_m = 2^2^m + 1. The maximum + // result when multiplying a by b (with their sizes being the same per the + // precondition above) is at most 2^(2*op_size) - 1. + // TODO: we don't actually need a full multiplier, a result with up to op_size + // bits is sufficient for our purposes. + // Hence we require 2^2^m >= 2^(2*op_size), i.e., 2^m >= 2*op_size, or + // m >= log_2(op_size) + 1. + // For our examples m will be 4 and 5, respectively, with Fermat numbers + // 2^16 + 1 and 2^32 + 1. + const std::size_t m = address_bits(a.size()) + 1; + std::cerr << "m: " << m << std::endl; + + // Extend the bit width to 2^(m + 1), i.e., op_size (rounded up to the next + // power of 2) * 4. For our examples, extended bit widths will be 32 and 64. + PRECONDITION(sizeof(std::size_t) * CHAR_BIT > m + 1); + const std::size_t two_to_m_plus_1 = (std::size_t)1 << (m + 1); + std::cerr << "a: " << beautify(a) << std::endl; + std::cerr << "b: " << beautify(b) << std::endl; + bvt a_ext = zero_extension(a, two_to_m_plus_1); + bvt b_ext = zero_extension(b, two_to_m_plus_1); + + // We need to distinguish whether m is even or odd: + // m = 2n - 1 for odd m and m = 2n - 2 for even m. + // For our 8-bit inputs we have m = 4 and, therefore, n = 3. + // For our 9-bit inputs we have m = 5 and, therefore, n = 3. + const std::size_t n = m % 2 == 1 ? (m + 1) / 2 : m / 2 + 1; + std::cerr << "n: " << n << std::endl; + + // For even m create 2^n chunks (of 2^(n - 1) bits each) from a_ext, b_ext + // (for our 8-bit inputs we have chunk_size = 4 with num_chunks = 8). + // For odd m create 2^(n + 1) chunks (of 2^(n - 1) bits each) from a_ext, + // b_ext; a_0 will be bit positions 0 through to 2^(n - 1) - 1, + // a_{2^(n + 1) - 1} will be bit positions up to 2^(m + 1) - 1. + // For our 9-bit inputs we have chunk_size = 4 with num_chunks = 16. + const std::size_t chunk_size = (std::size_t)1 << (n - 1); + const std::size_t num_chunks = two_to_m_plus_1 / chunk_size; + CHECK_RETURN( + num_chunks == + (m % 2 == 1 ? (std::size_t)1 << (n + 1) : (std::size_t)1 << n)); + std::cerr << "chunk_size: " << chunk_size << std::endl; + std::cerr << "num_chunks: " << num_chunks << std::endl; + std::cerr << "address_bits(num_chunks): " << address_bits(num_chunks) + << std::endl; + + std::vector<bvt> a_rho, b_sigma; + a_rho.reserve(num_chunks); + b_sigma.reserve(num_chunks); + for(std::size_t i = 0; i < num_chunks; ++i) + { + a_rho.emplace_back( + a_ext.begin() + i * chunk_size, a_ext.begin() + (i + 1) * chunk_size); + b_sigma.emplace_back( + b_ext.begin() + i * chunk_size, b_ext.begin() + (i + 1) * chunk_size); + } + // For our example we now have + // a_rho = [ 0101, 1101, 0000, ..., 0000 ] + // b_sigma = [ 1111, 0000, 0000, ..., 0000 ]
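+ // That is, 213 = 5 + 13 * 16 splits into a_0 = 0101 (5) and a_1 = 1101 (13), + // while 15 occupies b_0 alone; all higher chunks are zero padding.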
+ + // Compute gamma_r = \sum_{i + j = r} a_i * b_j with bit width 3n + 5, for r + // ranging from 0 to 2^(n + 2) - 1 (to 2^(n + 1) - 1 when m is even). + // For our example this will be additions/multiplications of width 14 + // (implying that school book multiplication would be cheaper, as is the case + // for all operand lengths below 32 bits). + // TODO: all subsequent steps seem to be using mod 2^(n + 2) (mod 2^(n + 1) + // when m is even), so it may be sufficient to do this over n + 2 bits instead + // of 3n + 5. + std::vector<bvt> gamma_tau{num_chunks * 2, zeros(3 * n + 5)}; + for(std::size_t tau = 0; tau < num_chunks * 2; ++tau) + { + for(std::size_t rho = tau < num_chunks ? 0 : tau - num_chunks + 1; + rho < num_chunks && rho <= tau; + ++rho) + { + const std::size_t sigma = tau - rho; + gamma_tau[tau] = add( + gamma_tau[tau], + unsigned_multiplier( + zero_extension(a_rho[rho], 3 * n + 5), + zero_extension(b_sigma[sigma], 3 * n + 5))); + } + } + // For our example we obtain + // gamma_tau = [ 00 0000 0100 1011, 00 0000 1100 0011, 0.... ] + + // Compute c_tau over bit width n + 2 (n + 1 when m is even) as gamma_tau + + // gamma_{tau + 2^(n + 1)} (gamma_{tau + 2^n} when m is even). + std::vector<bvt> c_tau; + c_tau.reserve(num_chunks); + for(std::size_t tau = 0; tau < num_chunks; ++tau) + { + c_tau.push_back(add(gamma_tau[tau], gamma_tau[tau + num_chunks])); + c_tau.back().resize(address_bits(num_chunks) + 1); + std::cerr << "c_tau[" << tau << "]: " << beautify(c_tau[tau]) << std::endl; + } + // For our example we obtain + // c_tau = [ 01011, 00011, 0... ] + + // Compute z_j = c_j - c_{j + 2^n} (mod 2^(n + 2)) (mod 2^(n + 1) and c_{j + + // 2^(n - 1)} when m is even) + std::vector<bvt> z_j; + z_j.reserve(num_chunks / 2); + for(std::size_t j = 0; j < num_chunks / 2; ++j) + z_j.push_back(sub(c_tau[j], c_tau[j + num_chunks / 2])); + // For our example we have z_j = c_tau as all elements beyond the second one + // are zeros. + + // Compute z_j mod F_n using number-theoretic transform with omega = 2 for + // odd m and omega = 4 for even m. + // For our examples we have F_n = 2^2^n + 1 = 257 with 2 being a 2^(n + 1)-th + // root of unity, i.e., 2^16 \equiv 1 (mod 257) (with 4 being a 2^n-th root of + // unity, i.e., 4^8 \equiv 1 (mod 257)). The DFT table for omega = 2 would be + // 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 1 + // 1 2 4 8 16 32 64 128 -1 -2 -4 -8 -16 -32 -64 -128 + // 1 4 16 64 -1 -4 -16 -64 1 4 16 64 -1 -4 -16 -64 + // 1 8 64 -2 -16 -128 4 32 -1 -8 -64 2 16 128 -4 -32 + // 1 16 -1 -16 1 16 -1 -16 1 16 -1 -16 1 16 -1 -16 + // 1 32 -4 -128 16 -2 -64 8 -1 -32 4 128 -16 2 64 -8 + // 1 64 -16 4 -1 -64 16 -4 1 64 -16 4 -1 -64 16 -4 + // 1 128 -64 32 -16 8 -4 2 -1 -128 64 -32 16 -8 4 -2 + // 1 -1 1 -1 1 -1 1 -1 1 -1 1 -1 1 -1 1 -1 + // 1 -2 4 -8 16 -32 64 -128 -1 2 -4 8 -16 32 -64 128 + // 1 -4 16 -64 -1 4 -16 64 1 -4 16 -64 -1 4 -16 64 + // 1 -8 64 2 -16 128 4 -32 -1 8 -64 -2 16 -128 -4 32 + // 1 -16 -1 16 1 -16 -1 16 1 -16 -1 16 1 -16 -1 16 + // 1 -32 -4 128 16 2 -64 -8 -1 32 4 -128 -16 -2 64 8 + // 1 -64 -16 -4 -1 64 16 4 1 -64 -16 -4 -1 64 16 4 + // 1 -128 -64 -32 -16 -8 -4 -2 -1 128 64 32 16 8 4 2 + // For fast NTT (less than O(n^2)) use Cooley-Tukey for NTT, then perform + // element-wise multiplication, and finally apply Gentleman-Sande for + // inverse NTT.
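+ + // Residues mod F_n are represented in 2^(n + 1) = 2 * 2^n bits below. Since + // 2^(2 * 2^n) - 1 = (2^2^n - 1) * (2^2^n + 1) is a multiple of F_n, all + // arithmetic may equally be done mod 2^(2 * 2^n) - 1: addition becomes + // addition with end-around carry (cyclic_add below), and multiplication by + // a power 2^s of the root of unity becomes ROTATE_LEFT by s, both of which + // preserve the residue mod F_n.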
+ + // Addition mod F_n with overflow + auto cyclic_add = [this](const bvt &x, const bvt &y) { + PRECONDITION(x.size() == y.size()); + + auto result_with_overflow = adder(x, y, const_literal(false)); + if(result_with_overflow.second.is_false()) + return result_with_overflow.first; + + return add( + result_with_overflow.first, + zero_extension(bvt{1, result_with_overflow.second}, x.size())); + }; + + // Compute NTT + std::vector<bvt> a_j, b_j; + a_j.reserve(num_chunks); + b_j.reserve(num_chunks); + for(std::size_t j = 0; j < num_chunks; ++j) + { + // All NTT steps are mod F_n, i.e., mod 2^2^n + 1, which implies we need + // 2^(n + 1) bits to represent numbers + a_j.push_back(zero_extension(a_rho[j], (std::size_t)1 << (n + 1))); + b_j.push_back(zero_extension(b_sigma[j], (std::size_t)1 << (n + 1))); + } + // Use in-place iterative Cooley-Tukey + std::vector<bvt> Aa, Ab; + Aa.reserve(num_chunks); + Ab.reserve(num_chunks); + // In the following we use k represented as bits k_{n - 1}...k_0 and + // j_0...j_{n - 1}, i.e., the most-significant bit of k is k_{n - 1} while the + // MSB for j is j_0. + for(std::size_t k = 0; k < num_chunks; ++k) + { + // reverse n (n - 1 if m is even) bits of k + std::size_t j = 0; + for(std::size_t nu = 0; nu < address_bits(num_chunks); ++nu) + { + j <<= 1; // the initial shift has no effect + j |= (k & (1 << nu)) >> nu; + } + Aa.push_back(a_j[j]); + Ab.push_back(b_j[j]); + }
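+ // Bit-reversal example: for num_chunks = 8, position k = 3 (binary 011) + // receives input element j = 6 (binary 110); feeding the butterflies in + // bit-reversed order makes the Cooley-Tukey stages produce their outputs in + // natural order.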
+ for(std::size_t nu = 1; nu <= address_bits(num_chunks); ++nu) + { + const std::size_t bit_nu = (std::size_t)1 << (nu - 1); + std::size_t bits_up_to_nu = 0; + for(std::size_t i = 0; i < nu - 1; ++i) + bits_up_to_nu |= 1 << i; + + // we only need odd ones + for(std::size_t k = 1; k < num_chunks; k += 2) + { + if((k & bit_nu) == 0) + continue; + + bvt Aa_nu_bit_is_zero = Aa[k & ~bit_nu]; + bvt Ab_nu_bit_is_zero = Ab[k & ~bit_nu]; + + const std::size_t chi = (k & bits_up_to_nu) + << (address_bits(num_chunks) - 1 - (nu - 1)); + const std::size_t omega = m % 2 == 1 ? 2 : 4; + const std::size_t shift_dist = chi * omega / 2; + + if(nu > 1) // no need to update even indices + { + Aa[k & ~bit_nu] = cyclic_add( + Aa_nu_bit_is_zero, shift(Aa[k], shiftt::ROTATE_LEFT, shift_dist)); + Ab[k & ~bit_nu] = cyclic_add( + Ab_nu_bit_is_zero, shift(Ab[k], shiftt::ROTATE_LEFT, shift_dist)); + std::cerr << "Aa[" << nu << "](" << (k & ~bit_nu) + << "): " << beautify(Aa[k & ~bit_nu]) << std::endl; +#if 0 + std::cerr << "Ab[" << nu << "](" << (k & ~bit_nu) + << "): " << beautify(Ab[k & ~bit_nu]) << std::endl; +#endif + } + + // subtraction mod F_n is addition of the subtrahend cyclically shifted 2^n + // positions to the left + const std::size_t shift_dist_for_sub = shift_dist + ((std::size_t)1 << n); + Aa[k] = cyclic_add( + Aa_nu_bit_is_zero, + shift(Aa[k], shiftt::ROTATE_LEFT, shift_dist_for_sub)); + Ab[k] = cyclic_add( + Ab_nu_bit_is_zero, + shift(Ab[k], shiftt::ROTATE_LEFT, shift_dist_for_sub)); + std::cerr << "Aa[" << nu << "](" << k << "): " << beautify(Aa[k]) + << std::endl; +#if 0 + std::cerr << "Ab[" << nu << "](" << k << "): " << beautify(Ab[k]) + << std::endl; +#endif + } + } + + // Compute u - v if u >= v, and u - v + 2^2^n + 1 otherwise + auto reduce_to_mod_F_n = [this](const bvt &x) { + const std::size_t two_to_power_of_n = x.size() / 2; + // std::cerr << "two_to_power_of_n: " << two_to_power_of_n << std::endl; + const bvt u = + zero_extension(bvt{x.begin(), x.begin() + two_to_power_of_n}, x.size()); + // std::cerr << "u: " << beautify(u) << std::endl; + const bvt v = + zero_extension(bvt{x.begin() + two_to_power_of_n, x.end()}, x.size()); + // std::cerr << "v: " << beautify(v) << std::endl; + bvt two_to_power_of_two_to_power_of_n_plus_1 = build_constant(1, x.size()); + two_to_power_of_two_to_power_of_n_plus_1[two_to_power_of_n] = + const_literal(true); + const bvt u_ext = select( + unsigned_less_than(u, v), + add(u, two_to_power_of_two_to_power_of_n_plus_1), + u); + // std::cerr << "u_ext: " << beautify(u_ext) << std::endl; + return sub(u_ext, v); + }; + + std::vector<bvt> a_hat_k{num_chunks, bvt{}}, b_hat_k{num_chunks, bvt{}}; + // Reduce by F_n + for(std::size_t j = 1; j < num_chunks; j += 2) + { + a_hat_k[j] = reduce_to_mod_F_n(Aa[j]); + std::cerr << "a_hat_k[" << j << "]: " << beautify(a_hat_k[j]) << std::endl; + b_hat_k[j] = reduce_to_mod_F_n(Ab[j]); + std::cerr << "b_hat_k[" << j << "]: " << beautify(b_hat_k[j]) << std::endl; + } + + // Compute point-wise multiplication + std::vector<bvt> c_hat_k{num_chunks, bvt{}}; + for(std::size_t j = 1; j < num_chunks; j += 2) + { + c_hat_k[j] = unsigned_multiplier(a_hat_k[j], b_hat_k[j]); + std::cerr << "c_hat_k[" << j << "]: " << beautify(c_hat_k[j]) << std::endl; + }
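+ + // The inverse transform below mirrors the forward butterflies; the extra + // ROTATE_RIGHT by one position divides by 2 mod F_n, which folds the 1/N + // scaling of the inverse NTT into the individual Gentleman-Sande stages + // instead of applying it at the end.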
+ + // Apply inverse NTT + for(std::size_t nu = address_bits(num_chunks) - 1; nu > 0; --nu) + { + const std::size_t bit_nu_plus_1 = (std::size_t)1 << nu; + std::size_t bits_up_to_nu_plus_1 = 0; + for(std::size_t i = 0; i < nu; ++i) + bits_up_to_nu_plus_1 |= 1 << i; + + // we only need odd ones + for(std::size_t k = 1; k < num_chunks; k += 2) + { + if((k & bit_nu_plus_1) == 0) + continue; + + bvt c_hat_k_nu_plus_1_bit_is_zero = c_hat_k[k & ~bit_nu_plus_1]; + + c_hat_k[k & ~bit_nu_plus_1] = shift( + cyclic_add(c_hat_k_nu_plus_1_bit_is_zero, c_hat_k[k]), + shiftt::ROTATE_RIGHT, + 1); + std::cerr << "c_hat_k[" << nu << "](" << (k & ~bit_nu_plus_1) + << "): " << beautify(c_hat_k[k & ~bit_nu_plus_1]) << std::endl; + + const std::size_t chi = (k & bits_up_to_nu_plus_1) + << (address_bits(num_chunks) - 1 - nu); + const std::size_t omega = m % 2 == 1 ? 2 : 4; + const std::size_t shift_dist = chi * omega / 2 + 1; + std::cerr << "SHIFT: " << shift_dist << std::endl; + + c_hat_k[k] = shift( + cyclic_add( + c_hat_k_nu_plus_1_bit_is_zero, + shift(c_hat_k[k], shiftt::ROTATE_LEFT, (std::size_t)1 << n)), + shiftt::ROTATE_RIGHT, + shift_dist); + std::cerr << "c_hat_k[" << nu << "](" << k + << "): " << beautify(c_hat_k[k]) << std::endl; + } + } + // Reduce by F_n + std::vector<bvt> z_j_mod_F_n; + z_j_mod_F_n.reserve(num_chunks / 2); + for(std::size_t j = 0; j < num_chunks / 2; ++j) + { + // reverse n - 1 (n - 2 if m is even) bits of j + std::size_t k = 0; + for(std::size_t nu = 0; nu < address_bits(num_chunks) - 1; ++nu) + { + k |= (j & (1 << nu)) >> nu; + k <<= 1; + } + k |= 1; + std::cerr << "j " << j << " maps to " << k << std::endl; + z_j_mod_F_n.push_back(reduce_to_mod_F_n(c_hat_k[k])); + std::cerr << "z_j_mod_F_n[" << j << "]: " << beautify(z_j_mod_F_n[j]) + << std::endl; + } + + // Compute the final coefficients as eta + delta * F_n where delta = eta - xi, + // with eta the residue mod F_n (z_j_mod_F_n) and xi the residue mod + // 2^(n + 2) (z_j). + for(std::size_t j = 0; j < num_chunks / 2; ++j) + { + bvt eta = z_j_mod_F_n[j]; + std::cerr << "eta[" << j << "]: " << beautify(eta) << std::endl; + bvt xi = z_j[j]; + std::cerr << "xi[" << j << "]: " << beautify(xi) << std::endl; + // TODO: couldn't we do this over just xi.size() bits instead? + bvt delta = sub(eta, zero_extension(xi, eta.size())); + delta.resize(xi.size()); + std::cerr << "delta[" << j << "]: " << beautify(delta) << std::endl; + z_j[j] = add( + zero_extension(eta, two_to_m_plus_1), + add( + shift( + zero_extension(delta, two_to_m_plus_1), + shiftt::SHIFT_LEFT, + (std::size_t)1 << n), + zero_extension(delta, two_to_m_plus_1))); + std::cerr << "z_j[" << j << "]: " << beautify(z_j[j]) << std::endl; + } + + bvt result = zeros(two_to_m_plus_1); + for(std::size_t j = 0; j < num_chunks / 2; ++j) + { + if(chunk_size * j >= a.size()) + break; + result = add(result, shift(z_j[j], shiftt::SHIFT_LEFT, chunk_size * j)); + } + std::cerr << "result: " << beautify(result) << std::endl; + result.resize(a.size()); + std::cerr << "result resized: " << beautify(result) << std::endl; + + return result; +} + bvt bv_utilst::unsigned_multiplier_no_overflow( const bvt &op0, const bvt &op1) @@ -1011,7 +2472,15 @@ bvt bv_utilst::signed_multiplier(const bvt &op0, const bvt &op1) bvt neg0=cond_negate(op0, sign0); bvt neg1=cond_negate(op1, sign1); +#ifdef USE_KARATSUBA + bvt result = unsigned_karatsuba_multiplier(neg0, neg1); +#elif defined(USE_TOOM_COOK) + bvt result = unsigned_toom_cook_multiplier(neg0, neg1); +#elif defined(USE_SCHOENHAGE_STRASSEN) + bvt result = unsigned_schoenhage_strassen_multiplier(neg0, neg1); +#else bvt result=unsigned_multiplier(neg0, neg1); +#endif literalt result_sign=prop.lxor(sign0, sign1); @@ -1079,7 +2548,18 @@ bvt bv_utilst::multiplier( switch(rep) { case representationt::SIGNED: return signed_multiplier(op0, op1); +#ifdef USE_KARATSUBA + case representationt::UNSIGNED: + return unsigned_karatsuba_multiplier(op0, op1); +#elif defined(USE_TOOM_COOK) + case representationt::UNSIGNED: + return unsigned_toom_cook_multiplier(op0, op1); +#elif defined(USE_SCHOENHAGE_STRASSEN) + case representationt::UNSIGNED: + return unsigned_schoenhage_strassen_multiplier(op0, op1); +#else case representationt::UNSIGNED: return unsigned_multiplier(op0, op1); +#endif } UNREACHABLE; diff --git a/src/solvers/flattening/bv_utils.h b/src/solvers/flattening/bv_utils.h index a9a195257a6..87e41856781 100644 --- a/src/solvers/flattening/bv_utils.h +++ 
b/src/solvers/flattening/bv_utils.h @@ -79,6 +79,9 @@ class bv_utilst bvt shift(const bvt &op, const shiftt shift, const bvt &distance); bvt unsigned_multiplier(const bvt &op0, const bvt &op1); + bvt unsigned_karatsuba_multiplier(const bvt &op0, const bvt &op1); + bvt unsigned_toom_cook_multiplier(const bvt &op0, const bvt &op1); + bvt unsigned_schoenhage_strassen_multiplier(const bvt &a, const bvt &b); bvt signed_multiplier(const bvt &op0, const bvt &op1); bvt multiplier(const bvt &op0, const bvt &op1, representationt rep); bvt multiplier_no_overflow(