Skip to content

Commit e5f2af2

Browse files
Vivek Miglanifacebook-github-bot
Vivek Miglani
authored andcommitted
Fix baselines utils pyre fix me issues
Differential Revision: D67706854
1 parent 2bc4321 commit e5f2af2

File tree

1 file changed

+17
-12
lines changed

1 file changed

+17
-12
lines changed

captum/attr/_utils/baselines.py

+17-12
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@
22

33
# pyre-strict
44
import random
5-
from typing import Any, Dict, List, Tuple, Union
5+
from typing import Any, Dict, Generic, List, Tuple, TypeVar, Union
66

7+
GenericBaselineType = TypeVar("GenericBaselineType")
78

8-
class ProductBaselines:
9+
10+
class ProductBaselines(Generic[GenericBaselineType]):
911
"""
1012
A Callable Baselines class that returns a sample from the Cartesian product of
1113
the inputs' available baselines.
@@ -22,10 +24,9 @@ class ProductBaselines:
2224

2325
def __init__(
2426
self,
25-
# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
2627
baseline_values: Union[
27-
List[List[Any]],
28-
Dict[Union[str, Tuple[str, ...]], List[Any]],
28+
List[List[GenericBaselineType]],
29+
Dict[Union[str, Tuple[str, ...]], List[GenericBaselineType]],
2930
],
3031
) -> None:
3132
if isinstance(baseline_values, dict):
@@ -38,9 +39,10 @@ def __init__(
3839
self.dict_keys = dict_keys
3940
self.baseline_values = baseline_values
4041

41-
# pyre-fixme[3]: Return annotation cannot contain `Any`.
42-
def sample(self) -> Union[List[Any], Dict[str, Any]]:
43-
baselines = [
42+
def sample(
43+
self,
44+
) -> Union[List[GenericBaselineType], Dict[str, GenericBaselineType]]:
45+
baselines: List[GenericBaselineType] = [
4446
random.choice(baseline_list) for baseline_list in self.baseline_values
4547
]
4648

@@ -50,15 +52,18 @@ def sample(self) -> Union[List[Any], Dict[str, Any]]:
5052
dict_baselines = {}
5153
for key, val in zip(self.dict_keys, baselines):
5254
if not isinstance(key, tuple):
53-
key, val = (key,), (val,)
55+
key_tuple, val_tuple = (key,), (val,)
56+
else:
57+
key_tuple, val_tuple = key, val
5458

55-
for k, v in zip(key, val):
59+
for k, v in zip(key_tuple, val_tuple):
5660
dict_baselines[k] = v
5761

5862
return dict_baselines
5963

60-
# pyre-fixme[3]: Return annotation cannot contain `Any`.
61-
def __call__(self) -> Union[List[Any], Dict[str, Any]]:
64+
def __call__(
65+
self,
66+
) -> Union[List[GenericBaselineType], Dict[str, GenericBaselineType]]:
6267
"""
6368
Returns:
6469

0 commit comments

Comments
 (0)