Skip to content

Commit 6615e31

Browse files
author
Sylvain MARIE
committed
Fixed bug when the signature of the function to create contains non-locally available type hints. Fixes #32
1 parent 17fd618 commit 6615e31

File tree

3 files changed

+84
-17
lines changed

3 files changed

+84
-17
lines changed

makefun/main.py

Lines changed: 57 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,12 @@ def _is_generator_func(func_impl):
246246
return isgeneratorfunction(func_impl)
247247

248248

249-
class DefaultHolder:
249+
class _SymbolRef:
250+
"""
251+
A class used to protect signature default values and type hints when the local context would not be able
252+
to evaluate them properly when the new function is created. In this case we store them under a known name,
253+
we add that name to the locals(), and we use this symbol that has a repr() equal to the name.
254+
"""
250255
__slots__ = 'varname'
251256

252257
def __init__(self, varname):
@@ -267,22 +272,16 @@ def get_signature_string(func_name, func_signature, evaldict):
267272
# protect the parameters if needed
268273
new_params = []
269274
for p_name, p in func_signature.parameters.items():
270-
if p.default is not Parameter.empty and not isinstance(p.default, (int, str, float, bool)):
271-
# check if the repr() of the default value is equal to itself.
272-
needs_protection = True
273-
try:
274-
deflt = eval(repr(p.default))
275-
needs_protection = deflt != p.default
276-
except SyntaxError:
277-
pass
278-
279-
# if we have any problem, we need to protect the default value
280-
if needs_protection:
281-
# store the object in the evaldict and insert name
282-
varname = "DEFAULT_%s" % p_name
283-
evaldict[varname] = p.default
284-
p = Parameter(p.name, kind=p.kind, default=DefaultHolder(varname), annotation=p.annotation)
275+
# if default value can not be evaluated, protect it
276+
default_needs_protection = _signature_symbol_needs_protection(p.default, evaldict)
277+
new_default = _protect_signature_symbol(p.default, default_needs_protection, "DEFAULT_%s" % p_name, evaldict)
285278

279+
# if type hint can not be evaluated, protect it
280+
annotation_needs_protection = _signature_symbol_needs_protection(p.annotation, evaldict)
281+
new_annotation = _protect_signature_symbol(p.annotation, annotation_needs_protection, "HINT_%s" % p_name, evaldict)
282+
283+
# replace the parameter with the possibly new default and hint
284+
p = Parameter(p.name, kind=p.kind, default=new_default, annotation=new_annotation)
286285
new_params.append(p)
287286

288287
# copy signature object
@@ -292,6 +291,48 @@ def get_signature_string(func_name, func_signature, evaldict):
292291
return "%s%s:" % (func_name, s)
293292

294293

294+
def _signature_symbol_needs_protection(symbol, evaldict):
295+
"""
296+
Helper method for signature symbols (defaults, type hints) protection.
297+
298+
Returns True if the given symbol needs to be protected - that is, if its repr() can not be correctly evaluated with current evaldict.
299+
:param symbol:
300+
:return:
301+
"""
302+
if symbol is not None and symbol is not Parameter.empty and not isinstance(symbol, (int, str, float, bool)):
303+
# check if the repr() of the default value is equal to itself.
304+
try:
305+
deflt = eval(repr(symbol), evaldict)
306+
needs_protection = deflt != symbol
307+
except SyntaxError:
308+
needs_protection = True
309+
else:
310+
needs_protection = False
311+
312+
return needs_protection
313+
314+
315+
def _protect_signature_symbol(val, needs_protection, varname, evaldict):
316+
"""
317+
Helper method for signature symbols (defaults, type hints) protection.
318+
319+
Returns either `val`, or a protection symbol. In that case the protection symbol
320+
is created with name `varname` and inserted into `evaldict`
321+
322+
:param val:
323+
:param needs_protection:
324+
:param varname:
325+
:param evaldict:
326+
:return:
327+
"""
328+
if needs_protection:
329+
# store the object in the evaldict and insert name
330+
evaldict[varname] = val
331+
return _SymbolRef(varname)
332+
else:
333+
return val
334+
335+
295336
def get_signature_from_string(func_sig_str, evaldict):
296337
"""
297338
Creates a `Signature` object from the given function signature string.

makefun/tests/_test_py35.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,16 @@ async def my_native_coroutine_handler(sleep_time):
99
return sleep_time
1010

1111
return my_native_coroutine_handler
12+
13+
14+
def make_ref_function():
15+
"""Returns a function with a type hint that is locally defined """
16+
17+
# the symbol is defined here, so it is not seen outside
18+
class A:
19+
pass
20+
21+
def ref(a: A):
22+
pass
23+
24+
return ref

makefun/tests/test_advanced.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import logging
22
import sys
3-
from copy import copy, deepcopy
43

54
import pytest
65

@@ -128,3 +127,17 @@ def g(self):
128127
# our mod
129128
assert C.D.g.__qualname__ == 'test_qualname_when_nested.<locals>.C.D.g'
130129
assert str(signature(C.D.g)) == "(self, a)"
130+
131+
132+
@pytest.mark.skipif(sys.version_info < (3, 5), reason="requires python 3.5 or higher (non-comment type hints)")
133+
def test_type_hint_error():
134+
""" Test for https://github.com/smarie/python-makefun/issues/32 """
135+
136+
from makefun.tests._test_py35 import make_ref_function
137+
ref_f = make_ref_function()
138+
139+
@wraps(ref_f)
140+
def foo(a):
141+
return a
142+
143+
assert foo(10) == 10

0 commit comments

Comments
 (0)