2
2
3
3
# pyre-strict
4
4
import random
5
- from typing import Any , Dict , List , Tuple , Union
5
+ from typing import Any , Dict , Generic , List , Tuple , TypeVar , Union
6
6
7
+ GenericBaselineType = TypeVar ("GenericBaselineType" )
7
8
8
- class ProductBaselines :
9
+
10
+ class ProductBaselines (Generic [GenericBaselineType ]):
9
11
"""
10
12
A Callable Baselines class that returns a sample from the Cartesian product of
11
13
the inputs' available baselines.
@@ -22,10 +24,9 @@ class ProductBaselines:
22
24
23
25
def __init__ (
24
26
self ,
25
- # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
26
27
baseline_values : Union [
27
- List [List [Any ]],
28
- Dict [Union [str , Tuple [str , ...]], List [Any ]],
28
+ List [List [GenericBaselineType ]],
29
+ Dict [Union [str , Tuple [str , ...]], List [GenericBaselineType ]],
29
30
],
30
31
) -> None :
31
32
if isinstance (baseline_values , dict ):
@@ -38,9 +39,10 @@ def __init__(
38
39
self .dict_keys = dict_keys
39
40
self .baseline_values = baseline_values
40
41
41
- # pyre-fixme[3]: Return annotation cannot contain `Any`.
42
- def sample (self ) -> Union [List [Any ], Dict [str , Any ]]:
43
- baselines = [
42
+ def sample (
43
+ self ,
44
+ ) -> Union [List [GenericBaselineType ], Dict [str , GenericBaselineType ]]:
45
+ baselines : List [GenericBaselineType ] = [
44
46
random .choice (baseline_list ) for baseline_list in self .baseline_values
45
47
]
46
48
@@ -50,15 +52,18 @@ def sample(self) -> Union[List[Any], Dict[str, Any]]:
50
52
dict_baselines = {}
51
53
for key , val in zip (self .dict_keys , baselines ):
52
54
if not isinstance (key , tuple ):
53
- key , val = (key ,), (val ,)
55
+ key_tuple , val_tuple = (key ,), (val ,)
56
+ else :
57
+ key_tuple , val_tuple = key , val
54
58
55
- for k , v in zip (key , val ):
59
+ for k , v in zip (key_tuple , val_tuple ):
56
60
dict_baselines [k ] = v
57
61
58
62
return dict_baselines
59
63
60
- # pyre-fixme[3]: Return annotation cannot contain `Any`.
61
- def __call__ (self ) -> Union [List [Any ], Dict [str , Any ]]:
64
+ def __call__ (
65
+ self ,
66
+ ) -> Union [List [GenericBaselineType ], Dict [str , GenericBaselineType ]]:
62
67
"""
63
68
Returns:
64
69
0 commit comments