typing.py
#!/usr/bin/env python3
# pyre-strict
from collections import UserDict
from typing import (
    List,
    Literal,
    Optional,
    overload,
    Protocol,
    Tuple,
    TYPE_CHECKING,
    TypeVar,
    Union,
)

from torch import Tensor
from torch.nn import Module
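
# Generic type variables used by Captum's attribution APIs: inputs and outputs
# may be a single Tensor or a tuple of Tensors, and some utilities accept either
# a single Module or a list of Modules.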
TensorOrTupleOfTensorsGeneric = TypeVar(
    "TensorOrTupleOfTensorsGeneric", Tensor, Tuple[Tensor, ...]
)
# pyre-fixme[24]: Generic type `tuple` expects at least 1 type parameter.
TupleOrTensorOrBoolGeneric = TypeVar("TupleOrTensorOrBoolGeneric", Tuple, Tensor, bool)
ModuleOrModuleList = TypeVar("ModuleOrModuleList", Module, List[Module])
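
# Output target(s) with respect to which attributions are computed: a single
# index, a tuple of indices, a tensor of indices, or a per-example list of these.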
TargetType = Union[None, int, Tuple[int, ...], Tensor, List[Tuple[int, ...]], List[int]]
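
# Baseline(s) used as the reference point for attribution: a scalar, a Tensor,
# or a tuple with one entry per input tensor.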
BaselineTupleType = Union[None, Tuple[Union[Tensor, int, float], ...]]
BaselineType = Union[None, Tensor, int, float, BaselineTupleType]
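
# Nested Python lists of floats standing in for tensors of 1 to 5 dimensions.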
TensorLikeList1D = List[float]
TensorLikeList2D = List[TensorLikeList1D]
TensorLikeList3D = List[TensorLikeList2D]
TensorLikeList4D = List[TensorLikeList3D]
TensorLikeList5D = List[TensorLikeList4D]
TensorLikeList = Union[
    TensorLikeList1D,
    TensorLikeList2D,
    TensorLikeList3D,
    TensorLikeList4D,
    TensorLikeList5D,
]

try:
    # Subscripted slice syntax is not supported in previous Python versions,
    # falling back to slice type.
    SliceIntType = slice[int, int, int]
except TypeError:
    # pyre-fixme[24]: Generic type `slice` expects 3 type parameters.
    SliceIntType = slice  # type: ignore

# Return type of a tokenizer call (cf. Hugging Face `BatchEncoding`, which is a
# UserDict). Subscripted UserDict is not supported at runtime on Python >=3.7
# and <3.9, so the parametrized form is only used under TYPE_CHECKING.
if TYPE_CHECKING:
    BatchEncodingType = UserDict[Union[int, str], object]
else:
    BatchEncodingType = UserDict


class TokenizerLike(Protocol):
    """A protocol for tokenizer-like objects that can be used with Captum
    LLM attribution methods."""

    @overload
    def encode(
        self, text: str, add_special_tokens: bool = ..., return_tensors: None = ...
    ) -> List[int]: ...

    @overload
    def encode(
        self,
        text: str,
        add_special_tokens: bool = ...,
        return_tensors: Literal["pt"] = ...,
    ) -> Tensor: ...

    def encode(
        self,
        text: str,
        add_special_tokens: bool = True,
        return_tensors: Optional[str] = None,
    ) -> Union[List[int], Tensor]: ...

    def decode(self, token_ids: Tensor) -> str: ...

    @overload
    def convert_ids_to_tokens(self, token_ids: List[int]) -> List[str]: ...

    @overload
    def convert_ids_to_tokens(self, token_ids: int) -> str: ...

    def convert_ids_to_tokens(
        self, token_ids: Union[List[int], int]
    ) -> Union[List[str], str]: ...

    @overload
    def convert_tokens_to_ids(self, tokens: str) -> int: ...

    @overload
    def convert_tokens_to_ids(self, tokens: List[str]) -> List[int]: ...

    def convert_tokens_to_ids(
        self, tokens: Union[List[str], str]
    ) -> Union[List[int], int]: ...

    def __call__(
        self,
        text: Optional[Union[str, List[str], List[List[str]]]] = None,
        add_special_tokens: bool = True,
        return_offsets_mapping: bool = False,
    ) -> BatchEncodingType: ...
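
# Illustrative usage (assumption: a Hugging Face tokenizer from the optional
# `transformers` package structurally satisfies this protocol):
#
#     from transformers import AutoTokenizer
#
#     tok: TokenizerLike = AutoTokenizer.from_pretrained("gpt2")
#     ids = tok.encode("hello world", return_tensors=None)     # List[int]
#     pt_ids = tok.encode("hello world", return_tensors="pt")  # Tensor, shape (1, seq_len)
#     tokens = tok.convert_ids_to_tokens(ids)                  # List[str]
#     text = tok.decode(pt_ids[0])                             # str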