Skip to content

Commit 58bee85

Browse files
authored
Fix type inference for class level Parameter setter (#1141)
1 parent 5358e32 commit 58bee85

6 files changed

Lines changed: 617 additions & 17 deletions

File tree

.github/workflows/test.yaml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -159,13 +159,18 @@ jobs:
159159
id: pyright
160160
continue-on-error: true
161161
run: pixi run -e type type-pyright
162+
- name: Asserts
163+
id: asserts
164+
continue-on-error: true
165+
run: pixi run -e type type-asserts
162166
- name: Error if any check failed
163-
if: steps.ty.outcome != 'success' || steps.mypy.outcome != 'success' || steps.pyrefly.outcome != 'success' || steps.pyright.outcome != 'success'
167+
if: steps.ty.outcome != 'success' || steps.mypy.outcome != 'success' || steps.pyrefly.outcome != 'success' || steps.pyright.outcome != 'success' || steps.asserts.outcome != 'success'
164168
run: |
165169
echo "TY: ${{ steps.ty.outcome }}"
166170
echo "mypy: ${{ steps.mypy.outcome }}"
167171
echo "Pyrefly: ${{ steps.pyrefly.outcome }}"
168172
echo "Pyright: ${{ steps.pyright.outcome }}"
173+
echo "Asserts: ${{ steps.asserts.outcome }}"
169174
exit 1
170175
171176
pypy_test_suite:

param/parameterized.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1944,7 +1944,7 @@ def __get__(self, obj: Parameterized | None, objtype: type[Parameterized] | None
19441944
return result
19451945

19461946
@instance_descriptor
1947-
def __set__(self, obj: Parameterized, val: _T):
1947+
def __set__(self, obj: Parameterized | None, val: _T):
19481948
"""
19491949
Set the value for this Parameter.
19501950
@@ -2731,7 +2731,7 @@ def _resolve_ref(self_, pobj: Parameter, value: t.Any):
27312731
except Skip:
27322732
value = Undefined
27332733
if is_async and pobj.name:
2734-
async_executor(partial(self_._async_ref, pobj.name, value))
2734+
async_executor(partial(self_._async_ref, pobj.name, t.cast("t.Awaitable[t.Any]", value)))
27352735
value = None
27362736
return ref, deps, value, is_async
27372737

param/parameters.py

Lines changed: 179 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -612,11 +612,11 @@ def __init__(
612612
self._set_instantiate(True)
613613
self._initialize_generator(self.default)
614614

615-
def _initialize_generator(self, gen, obj=None):
615+
def _initialize_generator(self, gen, obj: Parameterized | None = None):
616616
"""Add 'last time' and 'last value' attributes to the generator."""
617617
# Could use a dictionary to hold these things.
618618
if obj is not None and hasattr(obj, "_Dynamic_time_fn"):
619-
gen._Dynamic_time_fn = obj._Dynamic_time_fn
619+
gen._Dynamic_time_fn = obj._Dynamic_time_fn # type: ignore[attribute-access]
620620

621621
gen._Dynamic_last = None
622622
# Would have usede None for this, but can't compare a fixedpoint
@@ -642,7 +642,7 @@ def __get__(
642642
return t.cast("_T", self._produce_value(gen))
643643

644644
@instance_descriptor
645-
def __set__(self, obj: Parameterized, val: _T):
645+
def __set__(self, obj: Parameterized | None, val: _T):
646646
"""
647647
Call the superclass's set and keep this parameter's
648648
instantiate value up to date (dynamic parameters
@@ -653,8 +653,10 @@ def __set__(self, obj: Parameterized, val: _T):
653653
super().__set__(obj, val)
654654

655655
dynamic = callable(val)
656-
if dynamic: self._initialize_generator(val,obj)
657-
if obj is None: self._set_instantiate(dynamic)
656+
if dynamic:
657+
self._initialize_generator(val, obj)
658+
if obj is None:
659+
self._set_instantiate(dynamic)
658660

659661
def _produce_value(self, gen, force: bool = False):
660662
"""
@@ -2218,11 +2220,54 @@ def _validate(self, val):
22182220
self._validate_value(val, self.allow_None)
22192221

22202222

2221-
class Action(Callable):
2223+
class Action(Callable[_T]):
22222224
"""
22232225
A user-provided function that can be invoked like a class or object method using ().
22242226
In a GUI, this might be mapped to a button, but it can be invoked directly as well.
22252227
"""
2228+
2229+
if t.TYPE_CHECKING:
2230+
2231+
@t.overload
2232+
def __init__(
2233+
self: Action[t.Callable[[], t.Any]],
2234+
default: t.Callable[[], t.Any] = lambda: None,
2235+
*,
2236+
allow_None: t.Literal[False] = False,
2237+
doc: str | None = None,
2238+
label: str | None = None,
2239+
precedence: float | None = None,
2240+
instantiate: bool = False,
2241+
constant: bool = False,
2242+
readonly: bool = False,
2243+
pickle_default_value: bool = True,
2244+
per_instance: bool = True,
2245+
allow_refs: bool = False,
2246+
nested_refs: bool = False,
2247+
default_factory: t.Callable[[], t.Any] | None = None,
2248+
metadata: dict[str, t.Any] | None = None,
2249+
) -> None:
2250+
...
2251+
2252+
@t.overload
2253+
def __init__(
2254+
self: Action[t.Callable[[], t.Any] | None],
2255+
default: None = None,
2256+
*,
2257+
allow_None: t.Literal[True] = True,
2258+
**params: Unpack[_ParameterKwargs]
2259+
) -> None:
2260+
...
2261+
2262+
def __init__(self,
2263+
default: t.Callable[[], t.Any] | None = t.cast("t.Callable[[], t.Any] | None", Undefined), # pyrefly: ignore[bad-argument-type]
2264+
*,
2265+
allow_None: bool = t.cast("bool", Undefined), # pyrefly: ignore[bad-argument-type]
2266+
**params: Unpack[_ParameterKwargs]
2267+
) -> None:
2268+
super().__init__(default=default, **params) # type: ignore[misc] # pyrefly: ignore[bad-argument-type]
2269+
self._validate(self.default)
2270+
22262271
# Currently same implementation as Callable, but kept separate to allow different handling in GUIs
22272272

22282273
#-----------------------------------------------------------------------------
@@ -3327,6 +3372,16 @@ def __init__(
33273372
) -> None:
33283373
...
33293374

3375+
@t.overload
3376+
def __init__(
3377+
self: DataFrame[pd.DataFrame | None],
3378+
default: None = None,
3379+
*,
3380+
allow_None: t.Literal[False] = False,
3381+
**kwargs: Unpack[_DataFrameInitKwargs]
3382+
) -> None:
3383+
...
3384+
33303385
def __init__(
33313386
self,
33323387
default: pd.DataFrame | None = t.cast("pd.DataFrame | None", Undefined), # pyrefly: ignore[bad-argument-type]
@@ -3798,7 +3853,7 @@ class Path(Parameter[_T]):
37983853
__slots__ = ['search_paths', 'check_exists']
37993854

38003855
_slot_defaults = dict(
3801-
Parameter._slot_defaults, check_exists=True,
3856+
Parameter._slot_defaults, check_exists=True, search_paths=None
38023857
)
38033858

38043859
search_paths: list[str | PathLike] | None
@@ -3811,9 +3866,9 @@ def __init__(
38113866
self: Path[PathLike | str],
38123867
default: PathLike | str = pathlib.Path(""),
38133868
*,
3814-
allow_None: t.Literal[False] = False,
38153869
search_paths: list[str | PathLike] | None = None,
38163870
check_exists: bool = True,
3871+
allow_None: t.Literal[False] = False,
38173872
doc: str | None = None,
38183873
label: str | None = None,
38193874
precedence: float | None = None,
@@ -3832,7 +3887,7 @@ def __init__(
38323887
@t.overload
38333888
def __init__(
38343889
self: Path[PathLike | str | None],
3835-
default: None = None,
3890+
default: PathLike | str = pathlib.Path(""),
38363891
*,
38373892
allow_None: t.Literal[True] = True,
38383893
**kwargs: Unpack[_PathInitKwargs]
@@ -3842,7 +3897,7 @@ def __init__(
38423897
@t.overload
38433898
def __init__(
38443899
self: Path[PathLike | str | None],
3845-
default: PathLike | str = pathlib.Path(""),
3900+
default: PathLike | str | None = None,
38463901
*,
38473902
allow_None: t.Literal[True] = True,
38483903
**kwargs: Unpack[_PathInitKwargs]
@@ -3910,8 +3965,7 @@ def __getstate__(self):
39103965
return state
39113966

39123967

3913-
3914-
class Filename(Path):
3968+
class Filename(Path[_T]):
39153969
"""
39163970
Parameter that can be set to a string specifying the path of a file.
39173971
@@ -3926,11 +3980,67 @@ class Filename(Path):
39263980
is ``None``).
39273981
"""
39283982

3983+
if t.TYPE_CHECKING:
3984+
3985+
@t.overload
3986+
def __init__(
3987+
self: Filename[PathLike | str],
3988+
default: PathLike | str = pathlib.Path(""),
3989+
*,
3990+
allow_None: t.Literal[False] = False,
3991+
search_paths: list[str | PathLike] | None = None,
3992+
check_exists: bool = True,
3993+
doc: str | None = None,
3994+
label: str | None = None,
3995+
precedence: float | None = None,
3996+
instantiate: bool = False,
3997+
constant: bool = False,
3998+
readonly: bool = False,
3999+
pickle_default_value: bool = True,
4000+
per_instance: bool = True,
4001+
allow_refs: bool = False,
4002+
nested_refs: bool = False,
4003+
default_factory: t.Callable[[], t.Any] | None = None,
4004+
metadata: dict[str, t.Any] | None = None,
4005+
) -> None:
4006+
...
4007+
4008+
@t.overload
4009+
def __init__(
4010+
self: Filename[PathLike | str | None],
4011+
default: PathLike | str = pathlib.Path(""),
4012+
*,
4013+
allow_None: t.Literal[True] = True,
4014+
**kwargs: Unpack[_PathInitKwargs]
4015+
) -> None:
4016+
...
4017+
4018+
@t.overload
4019+
def __init__(
4020+
self: Filename[PathLike | str | None],
4021+
default: PathLike | str | None = None,
4022+
*,
4023+
allow_None: t.Literal[True] = True,
4024+
**kwargs: Unpack[_PathInitKwargs]
4025+
) -> None:
4026+
...
4027+
4028+
def __init__(
4029+
self,
4030+
default: str | PathLike | None = t.cast("str | PathLike | None", Undefined), # pyrefly: ignore[bad-argument-type]
4031+
*,
4032+
allow_None: bool = t.cast("bool", Undefined), # pyrefly: ignore[bad-argument-type]
4033+
**kwargs: Unpack[_PathInitKwargs]
4034+
) -> None:
4035+
super().__init__( # type: ignore[misc, call-overload] # ty: ignore[no-matching-overload]
4036+
default=default, allow_None=allow_None, **kwargs # type: ignore[arg-type]
4037+
)
4038+
39294039
def _resolve(self, path):
39304040
return resolve_path(path=path, path_to_file=True, search_paths=self.search_paths)
39314041

39324042

3933-
class Foldername(Path):
4043+
class Foldername(Path[_T]):
39344044
"""
39354045
Parameter that can be set to a string specifying the path of a folder.
39364046
@@ -3945,6 +4055,62 @@ class Foldername(Path):
39454055
is ``None``).
39464056
"""
39474057

4058+
if t.TYPE_CHECKING:
4059+
4060+
@t.overload
4061+
def __init__(
4062+
self: Foldername[PathLike | str],
4063+
default: PathLike | str = pathlib.Path(""),
4064+
*,
4065+
allow_None: t.Literal[False] = False,
4066+
search_paths: list[str | PathLike] | None = None,
4067+
check_exists: bool = True,
4068+
doc: str | None = None,
4069+
label: str | None = None,
4070+
precedence: float | None = None,
4071+
instantiate: bool = False,
4072+
constant: bool = False,
4073+
readonly: bool = False,
4074+
pickle_default_value: bool = True,
4075+
per_instance: bool = True,
4076+
allow_refs: bool = False,
4077+
nested_refs: bool = False,
4078+
default_factory: t.Callable[[], t.Any] | None = None,
4079+
metadata: dict[str, t.Any] | None = None,
4080+
) -> None:
4081+
...
4082+
4083+
@t.overload
4084+
def __init__(
4085+
self: Foldername[PathLike | str | None],
4086+
default: PathLike | str = pathlib.Path(""),
4087+
*,
4088+
allow_None: t.Literal[True] = True,
4089+
**kwargs: Unpack[_PathInitKwargs]
4090+
) -> None:
4091+
...
4092+
4093+
@t.overload
4094+
def __init__(
4095+
self: Foldername[PathLike | str | None],
4096+
default: PathLike | str | None = None,
4097+
*,
4098+
allow_None: t.Literal[True] = True,
4099+
**kwargs: Unpack[_PathInitKwargs]
4100+
) -> None:
4101+
...
4102+
4103+
def __init__(
4104+
self,
4105+
default: str | PathLike | None = t.cast("str | PathLike | None", Undefined), # pyrefly: ignore[bad-argument-type]
4106+
*,
4107+
allow_None: bool = t.cast("bool", Undefined), # pyrefly: ignore[bad-argument-type]
4108+
**kwargs: Unpack[_PathInitKwargs]
4109+
) -> None:
4110+
super().__init__( # type: ignore[misc, call-overload] # ty: ignore[no-matching-overload]
4111+
default=default, allow_None=allow_None, **kwargs # type: ignore[arg-type]
4112+
)
4113+
39484114
def _resolve(self, path):
39494115
return resolve_path(path=path, path_to_file=False, search_paths=self.search_paths)
39504116

pixi.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,7 @@ lint-install = 'pre-commit install'
205205
# =============== TYPECHECK ===================
206206
# =============================================
207207
[feature.type.dependencies]
208-
ty = "*"
208+
ty = "==0.0.34"
209209
mypy = "*"
210210
numpy = "*"
211211
pandas = "*"
@@ -222,11 +222,13 @@ type-pyrefly = { cmd = 'pyrefly check --output-format min-text --count-errors=0
222222
{ arg = "file", default = "" },
223223
] }
224224
type-pyright = { cmd = 'pyright {{ file }}', args = [{ arg = "file", default = "" }] }
225+
type-asserts = { depends-on = ["install"], cmd = 'pyright --project tests/pyrightconfig-bare.json tests/assert_types.py' }
225226
type-all = { depends-on = [
226227
{ task = "type-ty", args = [ { file = "{{ file }}" } ] },
227228
{ task = "type-mypy", args = [ { file = "{{ file }}" } ] },
228229
{ task = "type-pyrefly", args = [ { file = "{{ file }}" } ] },
229230
{ task = "type-pyright", args = [ { file = "{{ file }}" } ] },
231+
{ task = "type-asserts" },
230232
], args = [
231233
{ arg = "file", default = "" },
232234
] }

0 commit comments

Comments
 (0)