Skip to content

Commit c34d297

Browse files
committed
Add tests, cleanups
1 parent 7a935f5 commit c34d297

File tree

4 files changed

+68
-25
lines changed

4 files changed

+68
-25
lines changed

src/prefect/core/task.py

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -32,23 +32,6 @@
3232
from prefect.core import Edge # pylint: disable=W0611
3333

3434
VAR_KEYWORD = inspect.Parameter.VAR_KEYWORD
35-
EXTRA_CALL_PARAMETERS = [
36-
inspect.Parameter(
37-
"mapped", inspect.Parameter.KEYWORD_ONLY, default=False, annotation=bool
38-
),
39-
inspect.Parameter(
40-
"task_args", inspect.Parameter.KEYWORD_ONLY, default=None, annotation=dict
41-
),
42-
inspect.Parameter(
43-
"upstream_tasks",
44-
inspect.Parameter.KEYWORD_ONLY,
45-
default=None,
46-
annotation=Iterable[Any],
47-
),
48-
inspect.Parameter(
49-
"flow", inspect.Parameter.KEYWORD_ONLY, default=None, annotation="Flow",
50-
),
51-
]
5235

5336

5437
def _validate_run_signature(run: Callable) -> None:
@@ -433,7 +416,7 @@ def copy(self, **task_args: Any) -> "Task":
433416
return new
434417

435418
@property
436-
def __signature__(self):
419+
def __signature__(self) -> inspect.Signature:
437420
"""Dynamically generate the signature, replacing ``*args``/``**kwargs``
438421
with parameters from ``run``"""
439422
if not hasattr(self, "_cached_signature"):
@@ -1200,3 +1183,12 @@ def serialize(self) -> Dict[str, Any]:
12001183
- dict representing this parameter
12011184
"""
12021185
return prefect.serialization.task.ParameterSchema().dump(self)
1186+
1187+
1188+
# All keyword-only arguments to Task.__call__, used for dynamically generating
1189+
# Signature objects for Task objects
1190+
EXTRA_CALL_PARAMETERS = [
1191+
p
1192+
for p in inspect.Signature.from_callable(Task.__call__).parameters.values()
1193+
if p.kind == inspect.Parameter.KEYWORD_ONLY
1194+
]

src/prefect/tasks/core/function.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
"""
22
The tasks in this module can be used to represent arbitrary functions.
33
4-
In general, users will not instantiate these tasks by hand; they will automatically be
5-
applied when users apply the `@task` decorator.
4+
In general, users will not instantiate these tasks by hand; they will
5+
automatically be applied when users apply the `@task` decorator.
66
"""
77

88
from typing import Any, Callable
@@ -11,19 +11,23 @@
1111

1212

1313
class _DocProxy(object):
14-
def __init__(self, source):
15-
self._source = source
14+
"""A descriptor that proxies through the docstring for the wrapped task as
15+
the docstring for a `FunctionTask` instance."""
16+
17+
def __init__(self, cls_doc):
18+
self._cls_doc = cls_doc
1619

1720
def __get__(self, obj, cls):
1821
if obj is None:
19-
return self._source
22+
return self._cls_doc
2023
else:
21-
return obj.run.__doc__
24+
return getattr(obj.run, "__doc__", None) or self._cls_doc
2225

2326

2427
class FunctionTask(prefect.Task):
2528
__doc__ = _DocProxy(
26-
"""A convenience Task for functionally creating Task instances with
29+
"""
30+
A convenience Task for functionally creating Task instances with
2731
arbitrary callable `run` methods.
2832
2933
Args:
@@ -58,3 +62,8 @@ def __init__(self, fn: Callable, name: str = None, **kwargs: Any):
5862
self.run = fn
5963

6064
super().__init__(name=name, **kwargs)
65+
66+
@property
67+
def __wrapped__(self):
68+
"""Propogates information about the wrapped function"""
69+
return self.run

tests/core/test_task.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import inspect
12
import logging
23
import uuid
34
from datetime import timedelta
@@ -232,6 +233,23 @@ def test_class_instantiation_raises_helpful_warning_for_unsupported_callables(se
232233
with pytest.raises(ValueError, match="This function can not be inspected"):
233234
task(zip)
234235

236+
def test_task_signature_generation(self):
237+
class Test(Task):
238+
def run(self, x: int, y: bool, z: int = 1):
239+
pass
240+
241+
t = Test()
242+
243+
sig = inspect.signature(t)
244+
# signature is a superset of the `run` method
245+
for k, p in inspect.signature(t.run).parameters.items():
246+
assert sig.parameters[k] == p
247+
# extra kwonly args to __call__ also in sig
248+
assert set(sig.parameters).issuperset(
249+
{"mapped", "task_args", "upstream_tasks", "flow"}
250+
)
251+
assert sig.return_annotation == "Task"
252+
235253
def test_create_task_with_and_without_cache_for(self):
236254
t1 = Task()
237255
assert t1.cache_validator is never_use

tests/tasks/test_core.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,30 @@ def my_fn():
4646
f = FunctionTask(fn=my_fn, name="test")
4747
assert f.name == "test"
4848

49+
def test_function_task_docstring(self):
50+
def my_fn():
51+
"""An example docstring."""
52+
pass
53+
54+
# Original docstring available on class
55+
assert "FunctionTask" in FunctionTask.__doc__
56+
57+
# Wrapped function is docstring on instance
58+
f = FunctionTask(fn=my_fn)
59+
assert f.__doc__ == my_fn.__doc__
60+
61+
# Except when no docstring on wrapped function
62+
f = FunctionTask(fn=lambda x: x + 1)
63+
assert "FunctionTask" in f.__doc__
64+
65+
def test_function_task_sets__wrapped__(self):
66+
def my_fn():
67+
"""An example function"""
68+
pass
69+
70+
t = FunctionTask(fn=my_fn)
71+
assert t.__wrapped__ == my_fn
72+
4973

5074
class TestCollections:
5175
def test_list_returns_a_list(self):

0 commit comments

Comments
 (0)