Skip to content

Commit 9d1f443

Browse files
committed
Implement defaults
1 parent 7a94ea3 commit 9d1f443

6 files changed

Lines changed: 121 additions & 77 deletions

File tree

pythran/pythonic/include/numpy/fft/irfft.hpp

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,9 @@ namespace numpy
1111
namespace fft
1212
{
1313

14-
template <class T, class pS>
14+
template <class T, class pS, typename U, typename V, typename W>
1515
types::ndarray<double, types::array<long, std::tuple_size<pS>::value>>
16-
irfft(types::ndarray<T, pS> const &, long NFFT = -1, long axis = -1,
17-
types::str renorm = "");
16+
irfft(types::ndarray<T, pS> const &, U NFFT, V axis, W renorm);
1817

1918
NUMPY_EXPR_TO_NDARRAY0_DECL(irfft);
2019
DEFINE_FUNCTOR(pythonic::numpy::fft, irfft);

pythran/pythonic/include/numpy/fft/rfft.hpp

Lines changed: 2 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,10 @@ namespace numpy
1212
{
1313

1414
// I'm sure there's a better way to do this.
15-
template <class T, class pS, typename U>
15+
template <class T, class pS, typename U, typename V, typename W>
1616
types::ndarray<std::complex<typename std::common_type<T, double>::type>,
1717
types::array<long, std::tuple_size<pS>::value>>
18-
rfft(types::ndarray<T, pS> const &input, long NFFT = -1, long axis = -1,
19-
U renorm = types::str(""));
20-
21-
template <class T, class pS>
22-
types::ndarray<std::complex<typename std::common_type<T, double>::type>,
23-
types::array<long, std::tuple_size<pS>::value>>
24-
rfft(types::ndarray<T, pS> const &input, long NFFT = -1, long axis = -1);
18+
rfft(types::ndarray<T, pS> const &input, U NFFT, V axis, W renorm);
2519

2620
NUMPY_EXPR_TO_NDARRAY0_DECL(rfft);
2721
DEFINE_FUNCTOR(pythonic::numpy::fft, rfft);

pythran/pythonic/numpy/fft/irfft.hpp

Lines changed: 40 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -86,16 +86,49 @@ namespace numpy
8686
return out_array;
8787
}
8888

89-
template <class T, class pS>
89+
// These functions help handle None inputs for default values without
90+
// relying on the C++ default mechanism.
91+
bool testNorm(types::none_type param)
92+
{
93+
return false;
94+
}
95+
bool testNorm(types::str param)
96+
{
97+
if (param == "ortho")
98+
return 1;
99+
else {
100+
throw types::ValueError("norm should be None or \"ortho\"");
101+
return 0;
102+
}
103+
}
104+
105+
long testLong(types::none_type param, long def_val)
106+
{
107+
return def_val;
108+
}
109+
long testLong(long N, long def_val)
110+
{
111+
return N;
112+
}
113+
114+
template <class T, class pS, typename U, typename V, typename W>
90115
types::ndarray<double, types::array<long, std::tuple_size<pS>::value>>
91-
irfft(types::ndarray<T, pS> const &in_array, long NFFT, long axis,
92-
types::str normalize)
116+
irfft(types::ndarray<T, pS> const &in_array, U _NFFT, V _axis, W renorm)
93117
{
94118
auto constexpr N = std::tuple_size<pS>::value;
95-
bool norm = (normalize == "ortho");
96-
if (NFFT == -1)
97-
NFFT = 2 * (std::get<N - 1>(in_array.shape()) - 1);
98-
if (axis != -1 && axis != N - 1) {
119+
bool norm = testNorm(renorm);
120+
// Handle None for axis input.
121+
long axis = testLong(_axis, -1);
122+
long LN = (long)N;
123+
if (axis >= LN)
124+
throw types::ValueError("axis out of bounds1");
125+
if (axis <= -LN - 1)
126+
throw types::ValueError("axis out of bounds");
127+
// Handle None for NFFT. Map -1 -> N-1 etc...
128+
axis = (axis + N) % N;
129+
long def_val = 2 * (sutils::array(in_array.shape())[axis] - 1);
130+
long NFFT = testLong(_NFFT, def_val);
131+
if (axis != N - 1) {
99132
// Swap axis if the FFT must be computed on an axis that's not the last
100133
// one.
101134
auto swapped_array = swapaxes(in_array, axis, N - 1);

pythran/pythonic/numpy/fft/rfft.hpp

Lines changed: 32 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -90,36 +90,50 @@ namespace numpy
9090
return out_array;
9191
}
9292

93-
template <class T, class pS>
94-
types::ndarray<std::complex<typename std::common_type<T, double>::type>,
95-
types::array<long, std::tuple_size<pS>::value>>
96-
rfft(types::ndarray<T, pS> const &in_array, long NFFT, long axis)
93+
// These functions help handle None inputs for default values without
94+
// relying on the C++ default mechanism.
95+
bool testNorm(types::none_type param)
9796
{
98-
return rfft(in_array, NFFT, axis, "");
97+
return false;
98+
}
99+
bool testNorm(types::str param)
100+
{
101+
if (param == "ortho")
102+
return 1;
103+
else {
104+
throw types::ValueError("norm should be None or \"ortho\"");
105+
return 0;
106+
}
99107
}
100108

101-
// This is kludgy, and I'm sure there's a better way to do this. Jeanl
102-
bool testThis(types::none_type param)
109+
long testLong(types::none_type param, long def_val)
103110
{
104-
return false;
111+
return def_val;
105112
}
106-
bool testThis(types::str param)
113+
long testLong(long N, long def_val)
107114
{
108-
return param == "ortho";
115+
return N;
109116
}
110117

111-
template <class T, class pS, typename U>
118+
template <class T, class pS, typename U, typename V, typename W>
112119
types::ndarray<std::complex<typename std::common_type<T, double>::type>,
113120
types::array<long, std::tuple_size<pS>::value>>
114-
rfft(types::ndarray<T, pS> const &in_array, long NFFT, long axis,
115-
U normalize)
121+
rfft(types::ndarray<T, pS> const &in_array, U _NFFT, V _axis, W normalize)
116122
{
117-
bool norm = testThis(normalize);
123+
bool norm = testNorm(normalize);
118124
auto constexpr N = std::tuple_size<pS>::value;
119-
120-
if (NFFT == -1)
121-
NFFT = std::get<N - 1>(in_array.shape());
122-
if (axis != -1 && axis != N - 1) {
125+
// Handle None for axis input.
126+
long axis = testLong(_axis, -1);
127+
// Handle None for NFFT. Map -1 -> N-1 etc...
128+
long LN = (long)N;
129+
if (axis >= LN)
130+
throw types::ValueError("axis out of bounds1");
131+
if (axis <= -LN - 1)
132+
throw types::ValueError("axis out of bounds");
133+
axis = (axis + LN) % LN;
134+
long def_val = sutils::array(in_array.shape())[axis];
135+
long NFFT = testLong(_NFFT, def_val);
136+
if (axis != LN - 1) {
123137
// Swap axis if the FFT must be computed on an axis that's not the last
124138
// one.
125139
auto swapped_array = swapaxes(in_array, axis, N - 1);

pythran/tables.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3869,8 +3869,12 @@ def partialsum(seq):
38693869
signature=_numpy_float_unary_op_float_signature
38703870
),
38713871
"fft": {
3872-
"rfft": FunctionIntr(args=(), global_effects=True),
3873-
"irfft": FunctionIntr(args=(), global_effects=True),
3872+
"rfft": FunctionIntr(args=('input','NFFT','axis','norm'),
3873+
defaults=(None,None,None),
3874+
global_effects=True),
3875+
"irfft": FunctionIntr(args=('input','NFFT','axis','norm'),
3876+
defaults=(None,None,None),
3877+
global_effects=True),
38743878
},
38753879
"random": {
38763880
"binomial": FunctionIntr(args=('n', 'p', 'size'),

pythran/transformations/remove_named_arguments.py

Lines changed: 39 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def handle_keywords(self, func, node, offset=0):
5050
defaults = func.args.defaults
5151
keywords = {func_argument_names[kw.arg]: kw.value
5252
for kw in node.keywords}
53-
node.args.extend([None] * (1 + max(keywords.keys()) - len(node.args)))
53+
node.args.extend([None] * (nargs - len(node.args)))
5454

5555
replacements = {}
5656
for index, arg in enumerate(node.args):
@@ -59,45 +59,45 @@ def handle_keywords(self, func, node, offset=0):
5959
replacements[index] = deepcopy(keywords[index])
6060
else: # must be a default value
6161
replacements[index] = deepcopy(defaults[index - nargs])
62+
6263
return replacements
6364

6465
def visit_Call(self, node):
65-
if node.keywords:
66-
self.update = True
67-
68-
aliases = self.aliases[node.func]
69-
assert aliases, "at least one alias"
70-
71-
# all aliases should have the same structural type...
72-
# call to self.handle_keywords raises an exception otherwise
73-
try:
74-
replacements = {}
75-
for func_alias in aliases:
76-
handle_special_calls(func_alias, node)
77-
78-
if func_alias is None: # aliasing computation failed
79-
pass
80-
elif isinstance(func_alias, ast.Call): # nested function
81-
# func_alias looks like functools.partial(foo, a)
82-
# so we reorder using alias for 'foo'
83-
offset = len(func_alias.args) - 1
84-
call = func_alias.args[0]
85-
for func_alias in self.aliases[call]:
86-
replacements = self.handle_keywords(func_alias,
87-
node, offset)
88-
else:
89-
replacements = self.handle_keywords(func_alias, node)
90-
91-
# if we reach this point, we should have a replacement
92-
# candidate, or nothing structural typing issues would have
93-
# raised an exception in handle_keywords
94-
if replacements:
95-
for index, value in replacements.items():
96-
node.args[index] = value
97-
node.keywords = []
98-
99-
except KeyError as ve:
100-
err = ("function uses an unknown (or unsupported) keyword "
101-
"argument `{}`".format(ve.args[0]))
102-
raise PythranSyntaxError(err, node)
66+
self.update = True
67+
68+
aliases = self.aliases[node.func]
69+
assert aliases, "at least one alias"
70+
71+
# all aliases should have the same structural type...
72+
# call to self.handle_keywords raises an exception otherwise
73+
try:
74+
replacements = {}
75+
for func_alias in aliases:
76+
handle_special_calls(func_alias, node)
77+
78+
if func_alias is None: # aliasing computation failed
79+
pass
80+
elif isinstance(func_alias, ast.Call): # nested function
81+
# func_alias looks like functools.partial(foo, a)
82+
# so we reorder using alias for 'foo'
83+
offset = len(func_alias.args) - 1
84+
call = func_alias.args[0]
85+
for func_alias in self.aliases[call]:
86+
replacements = self.handle_keywords(func_alias,
87+
node, offset)
88+
else:
89+
replacements = self.handle_keywords(func_alias, node)
90+
91+
# if we reach this point, we should have a replacement
92+
# candidate, or nothing structural typing issues would have
93+
# raised an exception in handle_keywords
94+
if replacements:
95+
for index, value in replacements.items():
96+
node.args[index] = value
97+
node.keywords = []
98+
99+
except KeyError as ve:
100+
err = ("function uses an unknown (or unsupported) keyword "
101+
"argument `{}`".format(ve.args[0]))
102+
raise PythranSyntaxError(err, node)
103103
return self.generic_visit(node)

0 commit comments

Comments
 (0)