Skip to content

Commit 90c6363

Browse files
alangenfeldbrentjericho
authored andcommitted
Probably about 6 years overdue, add a decorator that inserts a wrapping function that performs `check` calls based on annotations. ## How I Tested These Changes added test
1 parent cbf5a1a commit 90c6363

File tree

2 files changed

+253
-0
lines changed

2 files changed

+253
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,102 @@
1+
import inspect
2+
from abc import ABC
3+
from functools import update_wrapper
4+
from types import MethodType
5+
6+
import dagster._check as check
7+
from dagster._check.builder import EvalContext, build_check_call_str
8+
9+
10+
class CheckedFnWrapper(ABC):
11+
"""Output of the @checked decorator, this class holds a reference to the decorated function
12+
and upon first invocation compiles a wrapping function that performs run time type checks on
13+
annotated inputs.
14+
15+
This class is not directly instantiated, but instead dynamic subclasses are created for each
16+
callsite allowing the __call__ method to be replaced[1] with the compiled function to achieve
17+
a single stack frame of over head for this decorator in best cases scenarios.
18+
19+
20+
[1] __call__ can not be replaced on instances, only classes.
21+
"""
22+
23+
def __init__(self, fn):
24+
self._target_fn = fn
25+
self._eval_ctx = EvalContext.capture_from_frame(
26+
2,
27+
add_to_local_ns={},
28+
)
29+
30+
def __get__(self, instance, _=None):
31+
"""Allow the decorated function to be bound to instances to support class methods."""
32+
if instance:
33+
return MethodType(self, instance)
34+
return self
35+
36+
def __call__(self, *args, **kwargs):
37+
signature = inspect.signature(self._target_fn)
38+
lines = []
39+
inputs = []
40+
41+
for name, param in signature.parameters.items():
42+
if param.annotation != param.empty:
43+
param_str = build_check_call_str(
44+
ttype=param.annotation,
45+
name=name,
46+
eval_ctx=self._eval_ctx,
47+
)
48+
else:
49+
param_str = param.name
50+
51+
if param.kind in (param.KEYWORD_ONLY, param.POSITIONAL_OR_KEYWORD):
52+
param_str = f"{param.name}={param_str}"
53+
inputs.append(param.name)
54+
lines.append(param_str)
55+
56+
lazy_imports_str = "\n ".join(
57+
f"from {module} import {t}" for t, module in self._eval_ctx.lazy_imports.items()
58+
)
59+
60+
param_block = ",\n ".join(lines)
61+
inputs_block = ",\n ".join(inputs)
62+
63+
checked_fn_name = f"__checked_{self._target_fn.__name__}"
64+
65+
fn_str = f"""
66+
def {checked_fn_name}(
67+
__checked_wrapper,
68+
{inputs_block}
69+
):
70+
{lazy_imports_str}
71+
return __checked_wrapper._target_fn(
72+
{param_block}
73+
)
74+
"""
75+
76+
if "check" not in self._eval_ctx.global_ns:
77+
self._eval_ctx.global_ns["check"] = check
78+
79+
call = self._eval_ctx.compile_fn(
80+
fn_str,
81+
fn_name=checked_fn_name,
82+
)
83+
84+
self.__class__.__call__ = call
85+
return call(self, *args, **kwargs)
86+
87+
88+
def checked(fn):
89+
"""Decorator for adding runtime type checking based on type annotations."""
90+
# if nothing can be checked, return the original fn
91+
annotations = getattr(fn, "__annotations__", None)
92+
if not annotations or (len(annotations) == 1 and set(annotations.keys()) == {"return"}):
93+
return fn
94+
95+
# make a dynamic subclass to be able to hot swap __call__ post compilation
96+
class _DynamicCheckedFnWrapper(CheckedFnWrapper): ...
97+
98+
checked_fn = _DynamicCheckedFnWrapper(fn)
99+
return update_wrapper(
100+
wrapper=checked_fn,
101+
wrapped=fn,
102+
)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
from typing import TYPE_CHECKING, Annotated, Optional
2+
3+
import pytest
4+
from dagster._check import CheckError
5+
from dagster._check.builder import ImportFrom
6+
from dagster._check.decorator import checked
7+
8+
if TYPE_CHECKING:
9+
from dagster._core.test_utils import TestType
10+
11+
12+
def test_basic():
13+
@checked
14+
def foo(): ...
15+
16+
foo()
17+
18+
@checked
19+
def bar(i: int): ...
20+
21+
bar(1)
22+
bar(i=1)
23+
with pytest.raises(CheckError):
24+
bar("1") # type: ignore
25+
with pytest.raises(CheckError):
26+
bar(i="1") # type: ignore
27+
28+
29+
class Thing: ...
30+
31+
32+
def test_many():
33+
@checked
34+
def big(
35+
name: str,
36+
nick_names: list[str],
37+
age: int,
38+
cool: bool,
39+
thing: Optional[Thing],
40+
other_thing: Thing,
41+
percent: float,
42+
o_s: Optional[str],
43+
o_n: Optional[int],
44+
o_f: Optional[float],
45+
o_b: Optional[bool],
46+
foos: list[Annotated["TestType", ImportFrom("dagster._core.test_utils")]],
47+
):
48+
return True
49+
50+
assert big(
51+
name="dude",
52+
nick_names=[
53+
"foo",
54+
"bar",
55+
"biz",
56+
],
57+
age=42,
58+
cool=False,
59+
thing=None,
60+
other_thing=Thing(),
61+
percent=0.5,
62+
o_s="x",
63+
o_n=3,
64+
o_f=None,
65+
o_b=None,
66+
foos=[],
67+
)
68+
69+
with pytest.raises(CheckError):
70+
assert big(
71+
name="dude",
72+
nick_names=[
73+
"foo",
74+
"bar",
75+
"biz",
76+
],
77+
age=42,
78+
cool=False,
79+
thing=None,
80+
other_thing=Thing(),
81+
percent=0.5,
82+
o_s="x",
83+
o_n=3,
84+
o_f="surprise_not_float", # type: ignore
85+
o_b=None,
86+
foos=[],
87+
)
88+
89+
90+
def test_no_op():
91+
def foo(): ...
92+
93+
c_foo = checked(foo)
94+
assert c_foo is foo
95+
96+
def bar() -> None: ...
97+
98+
c_bar = checked(bar)
99+
assert c_bar is bar
100+
101+
102+
def test_star():
103+
@checked
104+
def baz(*, i: int): ...
105+
106+
baz(i=1)
107+
with pytest.raises(CheckError):
108+
baz(i="1") # type: ignore
109+
110+
111+
def test_partial():
112+
@checked
113+
def foo(a, b, c: int): ...
114+
115+
foo(1, 2, 3)
116+
117+
118+
def test_class():
119+
class Foo:
120+
@checked
121+
def me(self):
122+
return self
123+
124+
@checked
125+
def yell(self, word: str):
126+
return word
127+
128+
@staticmethod
129+
@checked
130+
def holler(word: str):
131+
return word
132+
133+
@classmethod
134+
@checked
135+
def scream(cls, word: str):
136+
return word
137+
138+
f = Foo()
139+
f.me()
140+
141+
f.yell("HI")
142+
with pytest.raises(CheckError):
143+
f.yell(3) # type: ignore
144+
145+
Foo.holler("hi")
146+
with pytest.raises(CheckError):
147+
Foo.holler(3) # type: ignore
148+
149+
Foo.scream("hi")
150+
with pytest.raises(CheckError):
151+
Foo.scream(3) # type: ignore

0 commit comments

Comments
 (0)