Skip to content

Commit 620cef1

Browse files
Merge pull request #4 from clearmatics/exceptions-fix
Exceptions fix - merge scipr-lab/libfqfft#12 (depends on #6)
2 parents 249c88a + 4877aed commit 620cef1

12 files changed

+140
-16
lines changed

libfqfft/evaluation_domain/domains/arithmetic_sequence_domain.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ namespace libfqfft {
2828
FieldT arithmetic_generator;
2929
void do_precomputation();
3030

31+
static bool valid_for_size(const size_t m);
32+
3133
arithmetic_sequence_domain(const size_t m);
3234

3335
void FFT(std::vector<FieldT> &a);

libfqfft/evaluation_domain/domains/arithmetic_sequence_domain.tcc

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,20 @@
2323

2424
namespace libfqfft {
2525

26+
template<typename FieldT>
27+
bool arithmetic_sequence_domain<FieldT>::valid_for_size(const size_t m)
28+
{
29+
if (m <= 1) {
30+
return false;
31+
}
32+
33+
if (FieldT::arithmetic_generator() == FieldT::zero()) {
34+
return false;
35+
}
36+
37+
return true;
38+
}
39+
2640
template<typename FieldT>
2741
arithmetic_sequence_domain<FieldT>::arithmetic_sequence_domain(const size_t m) : evaluation_domain<FieldT>(m)
2842
{
@@ -42,7 +56,7 @@ void arithmetic_sequence_domain<FieldT>::FFT(std::vector<FieldT> &a)
4256

4357
/* Monomial to Newton */
4458
monomial_to_newton_basis(a, this->subproduct_tree, this->m);
45-
59+
4660
/* Newton to Evaluation */
4761
std::vector<FieldT> S(this->m); /* i! * arithmetic_generator */
4862
S[0] = FieldT::one();
@@ -70,7 +84,7 @@ template<typename FieldT>
7084
void arithmetic_sequence_domain<FieldT>::iFFT(std::vector<FieldT> &a)
7185
{
7286
if (a.size() != this->m) throw DomainSizeException("arithmetic: expected a.size() == this->m");
73-
87+
7488
if (!this->precomputation_sentinel) do_precomputation();
7589

7690
/* Interpolation to Newton */
@@ -152,7 +166,7 @@ std::vector<FieldT> arithmetic_sequence_domain<FieldT>::evaluate_all_lagrange_po
152166

153167
std::vector<FieldT> w(this->m);
154168
w[0] = g_vanish.inverse() * (this->arithmetic_generator^(this->m-1));
155-
169+
156170
l[0] = l_vanish * l[0].inverse() * w[0];
157171
for (size_t i = 1; i < this->m; i++)
158172
{

libfqfft/evaluation_domain/domains/basic_radix2_domain.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ class basic_radix2_domain : public evaluation_domain<FieldT> {
2626

2727
FieldT omega;
2828

29+
static bool valid_for_size(const size_t m);
30+
2931
basic_radix2_domain(const size_t m);
3032

3133
void FFT(std::vector<FieldT> &a);

libfqfft/evaluation_domain/domains/basic_radix2_domain.tcc

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,20 @@
2222

2323
namespace libfqfft {
2424

25+
template<typename FieldT>
26+
bool basic_radix2_domain<FieldT>::valid_for_size(const size_t m)
27+
{
28+
if (m <= 1) {
29+
return false;
30+
}
31+
32+
if (!libff::has_root_of_unity<FieldT>(m)) {
33+
return false;
34+
}
35+
36+
return true;
37+
}
38+
2539
template<typename FieldT>
2640
basic_radix2_domain<FieldT>::basic_radix2_domain(const size_t m) : evaluation_domain<FieldT>(m)
2741
{

libfqfft/evaluation_domain/domains/extended_radix2_domain.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ class extended_radix2_domain : public evaluation_domain<FieldT> {
2727
FieldT omega;
2828
FieldT shift;
2929

30+
static bool valid_for_size(const size_t m);
31+
3032
extended_radix2_domain(const size_t m);
3133

3234
void FFT(std::vector<FieldT> &a);

libfqfft/evaluation_domain/domains/extended_radix2_domain.tcc

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,32 @@
1717

1818
namespace libfqfft {
1919

20+
template<typename FieldT>
21+
bool extended_radix2_domain<FieldT>::valid_for_size(const size_t m)
22+
{
23+
if (m <= 1) {
24+
return false;
25+
}
26+
27+
// Will `get_root_of_unity` throw?
28+
if (!std::is_same<FieldT, libff::Double>::value)
29+
{
30+
const size_t logm = libff::log2(m);
31+
32+
if (logm != (FieldT::s + 1)) {
33+
return false;
34+
}
35+
}
36+
37+
size_t small_m = m / 2;
38+
39+
if (!libff::has_root_of_unity<FieldT>(small_m)) {
40+
return false;
41+
}
42+
43+
return true;
44+
}
45+
2046
template<typename FieldT>
2147
extended_radix2_domain<FieldT>::extended_radix2_domain(const size_t m) : evaluation_domain<FieldT>(m)
2248
{

libfqfft/evaluation_domain/domains/geometric_sequence_domain.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ namespace libfqfft {
2727
std::vector<FieldT> geometric_triangular_sequence;
2828
void do_precomputation();
2929

30+
static bool valid_for_size(const size_t m);
31+
3032
geometric_sequence_domain(const size_t m);
3133

3234
void FFT(std::vector<FieldT> &a);

libfqfft/evaluation_domain/domains/geometric_sequence_domain.tcc

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,19 +23,33 @@
2323

2424
namespace libfqfft {
2525

26+
template<typename FieldT>
27+
bool geometric_sequence_domain<FieldT>::valid_for_size(const size_t m)
28+
{
29+
if (m <= 1) {
30+
return false;
31+
}
32+
33+
if (FieldT::geometric_generator() == FieldT::zero()) {
34+
return false;
35+
}
36+
37+
return true;
38+
}
39+
2640
template<typename FieldT>
2741
geometric_sequence_domain<FieldT>::geometric_sequence_domain(const size_t m) : evaluation_domain<FieldT>(m)
2842
{
2943
if (m <= 1) throw InvalidSizeException("geometric(): expected m > 1");
3044
if (FieldT::geometric_generator() == FieldT::zero())
3145
throw InvalidSizeException("geometric(): expected FieldT::geometric_generator() != FieldT::zero()");
32-
46+
3347
precomputation_sentinel = 0;
3448
}
3549

3650
template<typename FieldT>
3751
void geometric_sequence_domain<FieldT>::FFT(std::vector<FieldT> &a)
38-
{
52+
{
3953
if (a.size() != this->m) throw DomainSizeException("geometric: expected a.size() == this->m");
4054

4155
if (!this->precomputation_sentinel) do_precomputation();
@@ -71,7 +85,7 @@ template<typename FieldT>
7185
void geometric_sequence_domain<FieldT>::iFFT(std::vector<FieldT> &a)
7286
{
7387
if (a.size() != this->m) throw DomainSizeException("geometric: expected a.size() == this->m");
74-
88+
7589
if (!this->precomputation_sentinel) do_precomputation();
7690

7791
/* Interpolation to Newton */

libfqfft/evaluation_domain/domains/step_radix2_domain.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ class step_radix2_domain : public evaluation_domain<FieldT> {
2929
FieldT big_omega;
3030
FieldT small_omega;
3131

32+
static bool valid_for_size(const size_t m);
33+
3234
step_radix2_domain(const size_t m);
3335

3436
void FFT(std::vector<FieldT> &a);

libfqfft/evaluation_domain/domains/step_radix2_domain.tcc

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,33 @@
1717

1818
namespace libfqfft {
1919

20+
template<typename FieldT>
21+
bool step_radix2_domain<FieldT>::valid_for_size(const size_t m)
22+
{
23+
if (m <= 1) {
24+
return false;
25+
}
26+
27+
const size_t big_m = 1ul<<(libff::log2(m)-1);
28+
const size_t small_m = m - big_m;
29+
30+
if (small_m != 1ul<<libff::log2(small_m)) {
31+
return false;
32+
}
33+
34+
// omega
35+
if (!libff::has_root_of_unity<FieldT>(1ul<<libff::log2(m))) {
36+
return false;
37+
}
38+
39+
// small_omega
40+
if (!libff::has_root_of_unity<FieldT>(1ul<<libff::log2(small_m))) {
41+
return false;
42+
}
43+
44+
return true;
45+
}
46+
2047
template<typename FieldT>
2148
step_radix2_domain<FieldT>::step_radix2_domain(const size_t m) : evaluation_domain<FieldT>(m)
2249
{
@@ -30,7 +57,7 @@ step_radix2_domain<FieldT>::step_radix2_domain(const size_t m) : evaluation_doma
3057

3158
try { omega = libff::get_root_of_unity<FieldT>(1ul<<libff::log2(m)); }
3259
catch (const std::invalid_argument& e) { throw DomainSizeException(e.what()); }
33-
60+
3461
big_omega = omega.squared();
3562
small_omega = libff::get_root_of_unity<FieldT>(small_m);
3663
}

libfqfft/evaluation_domain/evaluation_domain.hpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#define EVALUATION_DOMAIN_HPP_
2828

2929
#include <vector>
30+
#include <libff/common/double.hpp>
3031

3132
namespace libfqfft {
3233

libfqfft/evaluation_domain/get_evaluation_domain.tcc

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -38,15 +38,33 @@ std::shared_ptr<evaluation_domain<FieldT> > get_evaluation_domain(const size_t m
3838
const size_t small = min_size - big;
3939
const size_t rounded_small = (1ul<<libff::log2(small));
4040

41-
try { result.reset(new basic_radix2_domain<FieldT>(min_size)); }
42-
catch(...) { try { result.reset(new extended_radix2_domain<FieldT>(min_size)); }
43-
catch(...) { try { result.reset(new step_radix2_domain<FieldT>(min_size)); }
44-
catch(...) { try { result.reset(new basic_radix2_domain<FieldT>(big + rounded_small)); }
45-
catch(...) { try { result.reset(new extended_radix2_domain<FieldT>(big + rounded_small)); }
46-
catch(...) { try { result.reset(new step_radix2_domain<FieldT>(big + rounded_small)); }
47-
catch(...) { try { result.reset(new geometric_sequence_domain<FieldT>(min_size)); }
48-
catch(...) { try { result.reset(new arithmetic_sequence_domain<FieldT>(min_size)); }
49-
catch(...) { throw DomainSizeException("get_evaluation_domain: no matching domain"); }}}}}}}}
41+
if (basic_radix2_domain<FieldT>::valid_for_size(min_size)) {
42+
result.reset(new basic_radix2_domain<FieldT>(min_size));
43+
}
44+
else if (extended_radix2_domain<FieldT>::valid_for_size(min_size)) {
45+
result.reset(new extended_radix2_domain<FieldT>(min_size));
46+
}
47+
else if (step_radix2_domain<FieldT>::valid_for_size(min_size)) {
48+
result.reset(new step_radix2_domain<FieldT>(min_size));
49+
}
50+
else if (basic_radix2_domain<FieldT>::valid_for_size(big + rounded_small)) {
51+
result.reset(new basic_radix2_domain<FieldT>(big + rounded_small));
52+
}
53+
else if (extended_radix2_domain<FieldT>::valid_for_size(big + rounded_small)) {
54+
result.reset(new extended_radix2_domain<FieldT>(big + rounded_small));
55+
}
56+
else if (step_radix2_domain<FieldT>::valid_for_size(big + rounded_small)) {
57+
result.reset(new step_radix2_domain<FieldT>(big + rounded_small));
58+
}
59+
else if (geometric_sequence_domain<FieldT>::valid_for_size(min_size)) {
60+
result.reset(new geometric_sequence_domain<FieldT>(min_size));
61+
}
62+
else if (arithmetic_sequence_domain<FieldT>::valid_for_size(min_size)) {
63+
result.reset(new arithmetic_sequence_domain<FieldT>(min_size));
64+
}
65+
else {
66+
throw DomainSizeException("get_evaluation_domain: no matching domain");
67+
}
5068

5169
return result;
5270
}

0 commit comments

Comments
 (0)