Skip to content

Commit c437575

Browse files
committed
avoid using return cast
return cast is equivalent to `type: ignore` on the result of the expression, but more expensive and can cause errors e.g. preventing access to traits during __del__ in process teardown
1 parent 808b361 commit c437575

File tree

1 file changed

+60
-54
lines changed

1 file changed

+60
-54
lines changed

traitlets/traitlets.py

+60-54
Original file line numberDiff line numberDiff line change
@@ -169,8 +169,8 @@ class TraitError(Exception):
169169
# -----------------------------------------------------------------------------
170170

171171

172-
def isidentifier(s: t.Any) -> bool:
173-
return t.cast(bool, s.isidentifier())
172+
def isidentifier(s: str) -> bool:
173+
return s.isidentifier()
174174

175175

176176
def _safe_literal_eval(s: str) -> t.Any:
@@ -293,13 +293,21 @@ class link:
293293

294294
updating = False
295295

296-
def __init__(self, source: t.Any, target: t.Any, transform: t.Any = None) -> None:
296+
def __init__(
297+
self, source: t.Any, target: t.Any, transform: t.Iterable[FuncT] | None = None
298+
) -> None:
297299
_validate_link(source, target)
298300
self.source, self.target = source, target
299-
self._transform, self._transform_inv = transform if transform else (lambda x: x,) * 2
300-
301+
if transform:
302+
self._transform, self._transform_inv = transform # type:ignore[method-assign]
301303
self.link()
302304

305+
def _transform(self, x: T) -> T:
306+
"""default transform: no-op"""
307+
return x
308+
309+
_transform_inv = _transform
310+
303311
def link(self) -> None:
304312
try:
305313
setattr(
@@ -597,12 +605,12 @@ def default(self, obj: t.Any = None) -> G | None:
597605
in the same way that dynamic defaults defined by ``@default`` are.
598606
"""
599607
if self.default_value is not Undefined:
600-
return t.cast(G, self.default_value)
608+
return self.default_value # type:ignore[no-any-return]
601609
elif hasattr(self, "make_dynamic_default"):
602-
return t.cast(G, self.make_dynamic_default())
610+
return self.make_dynamic_default() # type:ignore[no-any-return]
603611
else:
604612
# Undefined will raise in TraitType.get
605-
return t.cast(G, self.default_value)
613+
return self.default_value # type:ignore[no-any-return]
606614

607615
def get_default_value(self) -> G | None:
608616
"""DEPRECATED: Retrieve the static default value for this trait.
@@ -613,7 +621,7 @@ def get_default_value(self) -> G | None:
613621
DeprecationWarning,
614622
stacklevel=2,
615623
)
616-
return t.cast(G, self.default_value)
624+
return self.default_value # type:ignore[no-any-return]
617625

618626
def init_default_value(self, obj: t.Any) -> G | None:
619627
"""DEPRECATED: Set the static default value for the trait type."""
@@ -658,12 +666,12 @@ def get(self, obj: HasTraits, cls: type[t.Any] | None = None) -> G | None:
658666
type="default",
659667
)
660668
)
661-
return t.cast(G, value)
669+
return value # type:ignore[no-any-return]
662670
except Exception as e:
663671
# This should never be reached.
664672
raise TraitError("Unexpected error in TraitType: default value not set properly") from e
665673
else:
666-
return t.cast(G, value)
674+
return value # type:ignore[no-any-return]
667675

668676
@t.overload
669677
def __get__(self, obj: None, cls: type[t.Any]) -> Self:
@@ -684,7 +692,7 @@ def __get__(self, obj: HasTraits | None, cls: type[t.Any]) -> Self | G:
684692
if obj is None:
685693
return self
686694
else:
687-
return t.cast(G, self.get(obj, cls)) # the G should encode the Optional
695+
return self.get(obj, cls) # type:ignore[return-value]
688696

689697
def set(self, obj: HasTraits, value: S) -> None:
690698
new_value = self._validate(obj, value)
@@ -722,7 +730,7 @@ def _validate(self, obj: t.Any, value: t.Any) -> G | None:
722730
value = self.validate(obj, value)
723731
if obj._cross_validation_lock is False:
724732
value = self._cross_validate(obj, value)
725-
return t.cast(G, value)
733+
return value # type:ignore[no-any-return]
726734

727735
def _cross_validate(self, obj: t.Any, value: t.Any) -> G | None:
728736
if self.name in obj._trait_validators:
@@ -738,7 +746,7 @@ def _cross_validate(self, obj: t.Any, value: t.Any) -> G | None:
738746
"use @validate decorator instead.",
739747
)
740748
value = cross_validate(value, self)
741-
return t.cast(G, value)
749+
return value # type:ignore[no-any-return]
742750

743751
def __or__(self, other: TraitType[t.Any, t.Any]) -> Union:
744752
if isinstance(other, Union):
@@ -1142,7 +1150,7 @@ def compatible_observer(
11421150
)
11431151
return func(self, change)
11441152

1145-
return t.cast(FuncT, compatible_observer)
1153+
return compatible_observer # type:ignore[return-value]
11461154

11471155

11481156
def validate(*names: Sentinel | str) -> ValidateHandler:
@@ -1894,7 +1902,7 @@ def trait_defaults(self, *names: str, **metadata: t.Any) -> dict[str, t.Any] | S
18941902
raise TraitError(f"'{n}' is not a trait of '{type(self).__name__}' instances")
18951903

18961904
if len(names) == 1 and len(metadata) == 0:
1897-
return t.cast(Sentinel, self._get_trait_default_generator(names[0])(self))
1905+
return self._get_trait_default_generator(names[0])(self) # type:ignore[no-any-return]
18981906

18991907
trait_names = self.trait_names(**metadata)
19001908
trait_names.extend(names)
@@ -2144,7 +2152,7 @@ def validate(self, obj: t.Any, value: t.Any) -> G:
21442152
) from e
21452153
try:
21462154
if issubclass(value, self.klass): # type:ignore[arg-type]
2147-
return t.cast(G, value)
2155+
return value # type:ignore[no-any-return]
21482156
except Exception:
21492157
pass
21502158

@@ -2306,7 +2314,7 @@ def validate(self, obj: t.Any, value: t.Any) -> T | None:
23062314
if self.allow_none and value is None:
23072315
return value
23082316
if isinstance(value, self.klass): # type:ignore[arg-type]
2309-
return t.cast(T, value)
2317+
return value # type:ignore[no-any-return]
23102318
else:
23112319
self.error(obj, value)
23122320

@@ -2338,7 +2346,7 @@ def default_value_repr(self) -> str:
23382346
return repr(self.make_dynamic_default())
23392347

23402348
def from_string(self, s: str) -> T | None:
2341-
return t.cast(T, _safe_literal_eval(s))
2349+
return _safe_literal_eval(s) # type:ignore[no-any-return]
23422350

23432351

23442352
class ForwardDeclaredMixin:
@@ -2635,12 +2643,12 @@ def __init__(
26352643
def validate(self, obj: t.Any, value: t.Any) -> G:
26362644
if not isinstance(value, int):
26372645
self.error(obj, value)
2638-
return t.cast(G, _validate_bounds(self, obj, value))
2646+
return _validate_bounds(self, obj, value) # type:ignore[no-any-return]
26392647

26402648
def from_string(self, s: str) -> G:
26412649
if self.allow_none and s == "None":
2642-
return t.cast(G, None)
2643-
return t.cast(G, int(s))
2650+
return None # type:ignore[return-value]
2651+
return int(s) # type:ignore[return-value]
26442652

26452653
def subclass_init(self, cls: type[t.Any]) -> None:
26462654
pass # fully opt out of instance_init
@@ -2691,7 +2699,7 @@ def validate(self, obj: t.Any, value: t.Any) -> G:
26912699
value = int(value)
26922700
except Exception:
26932701
self.error(obj, value)
2694-
return t.cast(G, _validate_bounds(self, obj, value))
2702+
return _validate_bounds(self, obj, value) # type:ignore[no-any-return]
26952703

26962704

26972705
Long, CLong = Int, CInt
@@ -2753,12 +2761,12 @@ def validate(self, obj: t.Any, value: t.Any) -> G:
27532761
value = float(value)
27542762
if not isinstance(value, float):
27552763
self.error(obj, value)
2756-
return t.cast(G, _validate_bounds(self, obj, value))
2764+
return _validate_bounds(self, obj, value) # type:ignore[return-value]
27572765

27582766
def from_string(self, s: str) -> G:
27592767
if self.allow_none and s == "None":
2760-
return t.cast(G, None)
2761-
return t.cast(G, float(s))
2768+
return None # type:ignore[return-value]
2769+
return float(s) # type:ignore[return-value]
27622770

27632771
def subclass_init(self, cls: type[t.Any]) -> None:
27642772
pass # fully opt out of instance_init
@@ -2809,7 +2817,7 @@ def validate(self, obj: t.Any, value: t.Any) -> G:
28092817
value = float(value)
28102818
except Exception:
28112819
self.error(obj, value)
2812-
return t.cast(G, _validate_bounds(self, obj, value))
2820+
return _validate_bounds(self, obj, value) # type:ignore[no-any-return]
28132821

28142822

28152823
class Complex(TraitType[complex, t.Union[complex, float, int]]):
@@ -2935,18 +2943,18 @@ def __init__(
29352943

29362944
def validate(self, obj: t.Any, value: t.Any) -> G:
29372945
if isinstance(value, str):
2938-
return t.cast(G, value)
2946+
return value # type:ignore[return-value]
29392947
if isinstance(value, bytes):
29402948
try:
2941-
return t.cast(G, value.decode("ascii", "strict"))
2949+
return value.decode("ascii", "strict") # type:ignore[return-value]
29422950
except UnicodeDecodeError as e:
29432951
msg = "Could not decode {!r} for unicode trait '{}' of {} instance."
29442952
raise TraitError(msg.format(value, self.name, class_of(obj))) from e
29452953
self.error(obj, value)
29462954

29472955
def from_string(self, s: str) -> G:
29482956
if self.allow_none and s == "None":
2949-
return t.cast(G, None)
2957+
return None # type:ignore[return-value]
29502958
s = os.path.expanduser(s)
29512959
if len(s) >= 2:
29522960
# handle deprecated "1"
@@ -2960,7 +2968,7 @@ def from_string(self, s: str) -> G:
29602968
DeprecationWarning,
29612969
stacklevel=2,
29622970
)
2963-
return t.cast(G, s)
2971+
return s # type:ignore[return-value]
29642972

29652973
def subclass_init(self, cls: type[t.Any]) -> None:
29662974
pass # fully opt out of instance_init
@@ -3008,7 +3016,7 @@ def __init__(
30083016

30093017
def validate(self, obj: t.Any, value: t.Any) -> G:
30103018
try:
3011-
return t.cast(G, str(value))
3019+
return str(value) # type:ignore[return-value]
30123020
except Exception:
30133021
self.error(obj, value)
30143022

@@ -3091,22 +3099,22 @@ def __init__(
30913099

30923100
def validate(self, obj: t.Any, value: t.Any) -> G:
30933101
if isinstance(value, bool):
3094-
return t.cast(G, value)
3102+
return value # type:ignore[return-value]
30953103
elif isinstance(value, int):
30963104
if value == 1:
3097-
return t.cast(G, True)
3105+
return True # type:ignore[return-value]
30983106
elif value == 0:
3099-
return t.cast(G, False)
3107+
return False # type:ignore[return-value]
31003108
self.error(obj, value)
31013109

31023110
def from_string(self, s: str) -> G:
31033111
if self.allow_none and s == "None":
3104-
return t.cast(G, None)
3112+
return None # type:ignore[return-value]
31053113
s = s.lower()
31063114
if s in {"true", "1"}:
3107-
return t.cast(G, True)
3115+
return True # type:ignore[return-value]
31083116
elif s in {"false", "0"}:
3109-
return t.cast(G, False)
3117+
return False # type:ignore[return-value]
31103118
else:
31113119
raise ValueError("%r is not 1, 0, true, or false")
31123120

@@ -3163,7 +3171,7 @@ def __init__(
31633171

31643172
def validate(self, obj: t.Any, value: t.Any) -> G:
31653173
try:
3166-
return t.cast(G, bool(value))
3174+
return bool(value) # type:ignore[return-value]
31673175
except Exception:
31683176
self.error(obj, value)
31693177

@@ -3220,7 +3228,7 @@ def __init__(
32203228

32213229
def validate(self, obj: t.Any, value: t.Any) -> G:
32223230
if self.values and value in self.values:
3223-
return t.cast(G, value)
3231+
return value # type:ignore[no-any-return]
32243232
self.error(obj, value)
32253233

32263234
def _choices_str(self, as_rst: bool = False) -> str:
@@ -3247,7 +3255,7 @@ def from_string(self, s: str) -> G:
32473255
try:
32483256
return self.validate(None, s)
32493257
except TraitError:
3250-
return t.cast(G, _safe_literal_eval(s))
3258+
return _safe_literal_eval(s) # type:ignore[no-any-return]
32513259

32523260
def subclass_init(self, cls: type[t.Any]) -> None:
32533261
pass # fully opt out of instance_init
@@ -3275,7 +3283,7 @@ def validate(self, obj: t.Any, value: t.Any) -> G:
32753283
for v in self.values or []:
32763284
assert isinstance(v, str)
32773285
if v.lower() == value.lower():
3278-
return t.cast(G, v)
3286+
return v # type:ignore[return-value]
32793287
self.error(obj, value)
32803288

32813289
def _info(self, as_rst: bool = False) -> str:
@@ -3479,14 +3487,12 @@ def validate(self, obj: t.Any, value: t.Any) -> T | None:
34793487
if value is None:
34803488
return value
34813489

3482-
value = self.validate_elements(obj, value)
3483-
3484-
return t.cast(T, value)
3490+
return self.validate_elements(obj, value) # type:ignore[no-any-return]
34853491

34863492
def validate_elements(self, obj: t.Any, value: t.Any) -> T | None:
34873493
validated = []
34883494
if self._trait is None or isinstance(self._trait, Any):
3489-
return t.cast(T, value)
3495+
return value # type:ignore[no-any-return]
34903496
for v in value:
34913497
try:
34923498
v = self._trait._validate(obj, v)
@@ -3553,7 +3559,7 @@ def from_string_list(self, s_list: list[str]) -> T | None:
35533559
else:
35543560
# backward-compat: allow item_from_string to ignore index arg
35553561
def item_from_string(s: str, index: int | None = None) -> T | str:
3556-
return t.cast(T, self.item_from_string(s))
3562+
return self.item_from_string(s)
35573563

35583564
return self.klass( # type:ignore[call-arg]
35593565
[item_from_string(s, index=idx) for idx, s in enumerate(s_list)]
@@ -3565,15 +3571,15 @@ def item_from_string(self, s: str, index: int | None = None) -> T | str:
35653571
Evaluated when parsing CLI configuration from a string
35663572
"""
35673573
if self._trait:
3568-
return t.cast(T, self._trait.from_string(s))
3574+
return self._trait.from_string(s) # type:ignore[no-any-return]
35693575
else:
35703576
return s
35713577

35723578

35733579
class List(Container[t.List[T]]):
35743580
"""An instance of a Python list."""
35753581

3576-
klass = list # type:ignore[assignment]
3582+
klass = list
35773583
_cast_types: t.Any = (tuple,)
35783584

35793585
def __init__(
@@ -4051,7 +4057,7 @@ def from_string(self, s: str) -> dict[K, V] | None:
40514057
if not isinstance(s, str):
40524058
raise TypeError(f"from_string expects a string, got {s!r} of type {type(s)}")
40534059
try:
4054-
return t.cast("dict[K, V]", self.from_string_list([s]))
4060+
return self.from_string_list([s]) # type:ignore[no-any-return]
40554061
except Exception:
40564062
test = _safe_literal_eval(s)
40574063
if isinstance(test, dict):
@@ -4109,7 +4115,7 @@ def item_from_string(self, s: str) -> dict[K, V]:
41094115
value_trait = (self._per_key_traits or {}).get(key, self._value_trait)
41104116
if value_trait:
41114117
value = value_trait.from_string(value)
4112-
return t.cast("dict[K, V]", {key: value})
4118+
return {key: value} # type:ignore[dict-item]
41134119

41144120

41154121
class TCPAddress(TraitType[G, S]):
@@ -4165,17 +4171,17 @@ def validate(self, obj: t.Any, value: t.Any) -> G:
41654171
if isinstance(value[0], str) and isinstance(value[1], int):
41664172
port = value[1]
41674173
if port >= 0 and port <= 65535:
4168-
return t.cast(G, value)
4174+
return value # type:ignore[return-value]
41694175
self.error(obj, value)
41704176

41714177
def from_string(self, s: str) -> G:
41724178
if self.allow_none and s == "None":
4173-
return t.cast(G, None)
4179+
return None # type:ignore[return-value]
41744180
if ":" not in s:
41754181
raise ValueError("Require `ip:port`, got %r" % s)
41764182
ip, port_str = s.split(":", 1)
41774183
port = int(port_str)
4178-
return t.cast(G, (ip, port))
4184+
return (ip, port) # type:ignore[return-value]
41794185

41804186

41814187
class CRegExp(TraitType["re.Pattern[t.Any]", t.Union["re.Pattern[t.Any]", str]]):

0 commit comments

Comments
 (0)