Skip to content

Commit

Permalink
Add tests, cleanups
Browse files Browse the repository at this point in the history
  • Loading branch information
jcrist committed Jun 1, 2020
1 parent 7a935f5 commit c34d297
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 25 deletions.
28 changes: 10 additions & 18 deletions src/prefect/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,23 +32,6 @@
from prefect.core import Edge # pylint: disable=W0611

VAR_KEYWORD = inspect.Parameter.VAR_KEYWORD
EXTRA_CALL_PARAMETERS = [
inspect.Parameter(
"mapped", inspect.Parameter.KEYWORD_ONLY, default=False, annotation=bool
),
inspect.Parameter(
"task_args", inspect.Parameter.KEYWORD_ONLY, default=None, annotation=dict
),
inspect.Parameter(
"upstream_tasks",
inspect.Parameter.KEYWORD_ONLY,
default=None,
annotation=Iterable[Any],
),
inspect.Parameter(
"flow", inspect.Parameter.KEYWORD_ONLY, default=None, annotation="Flow",
),
]


def _validate_run_signature(run: Callable) -> None:
Expand Down Expand Up @@ -433,7 +416,7 @@ def copy(self, **task_args: Any) -> "Task":
return new

@property
def __signature__(self):
def __signature__(self) -> inspect.Signature:
"""Dynamically generate the signature, replacing ``*args``/``**kwargs``
with parameters from ``run``"""
if not hasattr(self, "_cached_signature"):
Expand Down Expand Up @@ -1200,3 +1183,12 @@ def serialize(self) -> Dict[str, Any]:
- dict representing this parameter
"""
return prefect.serialization.task.ParameterSchema().dump(self)


# All keyword-only arguments to Task.__call__, used for dynamically generating
# Signature objects for Task objects
EXTRA_CALL_PARAMETERS = [
p
for p in inspect.Signature.from_callable(Task.__call__).parameters.values()
if p.kind == inspect.Parameter.KEYWORD_ONLY
]
23 changes: 16 additions & 7 deletions src/prefect/tasks/core/function.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
"""
The tasks in this module can be used to represent arbitrary functions.
In general, users will not instantiate these tasks by hand; they will automatically be
applied when users apply the `@task` decorator.
In general, users will not instantiate these tasks by hand; they will
automatically be applied when users apply the `@task` decorator.
"""

from typing import Any, Callable
Expand All @@ -11,19 +11,23 @@


class _DocProxy(object):
def __init__(self, source):
self._source = source
"""A descriptor that proxies through the docstring for the wrapped task as
the docstring for a `FunctionTask` instance."""

def __init__(self, cls_doc):
self._cls_doc = cls_doc

def __get__(self, obj, cls):
if obj is None:
return self._source
return self._cls_doc
else:
return obj.run.__doc__
return getattr(obj.run, "__doc__", None) or self._cls_doc


class FunctionTask(prefect.Task):
__doc__ = _DocProxy(
"""A convenience Task for functionally creating Task instances with
"""
A convenience Task for functionally creating Task instances with
arbitrary callable `run` methods.
Args:
Expand Down Expand Up @@ -58,3 +62,8 @@ def __init__(self, fn: Callable, name: str = None, **kwargs: Any):
self.run = fn

super().__init__(name=name, **kwargs)

@property
def __wrapped__(self):
"""Propogates information about the wrapped function"""
return self.run
18 changes: 18 additions & 0 deletions tests/core/test_task.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import inspect
import logging
import uuid
from datetime import timedelta
Expand Down Expand Up @@ -232,6 +233,23 @@ def test_class_instantiation_raises_helpful_warning_for_unsupported_callables(se
with pytest.raises(ValueError, match="This function can not be inspected"):
task(zip)

def test_task_signature_generation(self):
class Test(Task):
def run(self, x: int, y: bool, z: int = 1):
pass

t = Test()

sig = inspect.signature(t)
# signature is a superset of the `run` method
for k, p in inspect.signature(t.run).parameters.items():
assert sig.parameters[k] == p
# extra kwonly args to __call__ also in sig
assert set(sig.parameters).issuperset(
{"mapped", "task_args", "upstream_tasks", "flow"}
)
assert sig.return_annotation == "Task"

def test_create_task_with_and_without_cache_for(self):
t1 = Task()
assert t1.cache_validator is never_use
Expand Down
24 changes: 24 additions & 0 deletions tests/tasks/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,30 @@ def my_fn():
f = FunctionTask(fn=my_fn, name="test")
assert f.name == "test"

def test_function_task_docstring(self):
def my_fn():
"""An example docstring."""
pass

# Original docstring available on class
assert "FunctionTask" in FunctionTask.__doc__

# Wrapped function is docstring on instance
f = FunctionTask(fn=my_fn)
assert f.__doc__ == my_fn.__doc__

# Except when no docstring on wrapped function
f = FunctionTask(fn=lambda x: x + 1)
assert "FunctionTask" in f.__doc__

def test_function_task_sets__wrapped__(self):
def my_fn():
"""An example function"""
pass

t = FunctionTask(fn=my_fn)
assert t.__wrapped__ == my_fn


class TestCollections:
def test_list_returns_a_list(self):
Expand Down

0 comments on commit c34d297

Please sign in to comment.