Skip to content

Commit ba89ffe

Browse files
committed
Add throw error on norm value for fft routines
1 parent caf365c commit ba89ffe

2 files changed

Lines changed: 17 additions & 5 deletions

File tree

pythran/pythonic/numpy/fft/irfft.hpp

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,13 @@ namespace numpy
8888

8989
// These functions help handle None inputs for default values without relying on the C++ default mechanism.
9090
bool testNorm(types::none_type param) { return false; }
91-
bool testNorm(types::str param) { return param == "ortho"; }
91+
bool testNorm(types::str param) {
92+
if (param == "ortho") return 1;
93+
else {
94+
throw types::ValueError("norm should be None or \"ortho\"");
95+
return 0;
96+
}
97+
}
9298

9399
long testLong(types::none_type param, long def_val) { return def_val;}
94100
long testLong(long N, long def_val) { return N;}
@@ -106,10 +112,10 @@ namespace numpy
106112
if (axis >= LN) throw types::ValueError("axis out of bounds1");
107113
if (axis <= -LN-1) throw types::ValueError("axis out of bounds");
108114
// Handle None for NFFT. Map -1 -> N-1 etc...
109-
long idx = (axis+N)%N;
110-
long def_val = 2*(sutils::array(in_array.shape())[idx] - 1);
115+
axis = (axis+N)%N;
116+
long def_val = 2*(sutils::array(in_array.shape())[axis] - 1);
111117
long NFFT = testLong(_NFFT,def_val);
112-
if (axis != -1 && axis != N - 1) {
118+
if (axis != N - 1) {
113119
// Swap axis if the FFT must be computed on an axis that's not the last
114120
// one.
115121
auto swapped_array = swapaxes(in_array, axis, N - 1);

pythran/pythonic/numpy/fft/rfft.hpp

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,13 @@ namespace numpy
9292

9393
// These functions help handle None inputs for default values without relying on the C++ default mechanism.
9494
bool testNorm(types::none_type param) { return false; }
95-
bool testNorm(types::str param) { return param == "ortho"; }
95+
bool testNorm(types::str param) {
96+
if (param == "ortho") return 1;
97+
else {
98+
throw types::ValueError("norm should be None or \"ortho\"");
99+
return 0;
100+
}
101+
}
96102

97103
long testLong(types::none_type param, long def_val) { return def_val;}
98104
long testLong(long N, long def_val) { return N;}

0 commit comments

Comments
 (0)