Skip to content

Commit

Permalink
Cleanup in get_enum_from_fn (#8852)
Browse files Browse the repository at this point in the history
Co-authored-by: Nicolas Hug <[email protected]>
  • Loading branch information
GdoongMathew and NicolasHug authored Feb 19, 2025
1 parent 8bed9d8 commit b5c7443
Showing 1 changed file with 3 additions and 4 deletions.
7 changes: 3 additions & 4 deletions torchvision/models/_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from functools import partial
from inspect import signature
from types import ModuleType
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Set, Type, TypeVar, Union
from typing import Any, Callable, Dict, get_args, Iterable, List, Mapping, Optional, Set, Type, TypeVar, Union

from torch import nn

Expand Down Expand Up @@ -168,14 +168,13 @@ def _get_enum_from_fn(fn: Callable) -> Type[WeightsEnum]:
if "weights" not in sig.parameters:
raise ValueError("The method is missing the 'weights' argument.")

ann = signature(fn).parameters["weights"].annotation
ann = sig.parameters["weights"].annotation
weights_enum = None
if isinstance(ann, type) and issubclass(ann, WeightsEnum):
weights_enum = ann
else:
# handle cases like Union[Optional, T]
# TODO: Replace ann.__args__ with typing.get_args(ann) after python >= 3.8
for t in ann.__args__: # type: ignore[union-attr]
for t in get_args(ann): # type: ignore[union-attr]
if isinstance(t, type) and issubclass(t, WeightsEnum):
weights_enum = t
break
Expand Down

0 comments on commit b5c7443

Please sign in to comment.