@@ -35,7 +35,7 @@ class TunedConfigBySize:
3535
3636
3737class TunedConfigStore :
38- def __init__ (self , profiles : dict [HardwareProfile | None , dict [str , list [TunedConfigBySize ]]]) -> None :
38+ def __init__ (self , profiles : dict [HardwareProfile , dict [str , list [TunedConfigBySize ]]]) -> None :
3939 self ._profiles = profiles
4040
4141 @classmethod
@@ -49,24 +49,15 @@ def load_path(cls, path: str | Path) -> "TunedConfigStore":
4949
5050 @classmethod
5151 def from_payload (cls , payload : Any ) -> "TunedConfigStore" :
52- profiles : dict [HardwareProfile | None , dict [str , list [TunedConfigBySize ]]] = {}
53- if isinstance (payload , list ):
54- profiles [None ] = _configs_by_collective_from_payload ({"allreduce" : payload })
55- return cls (profiles )
56-
5752 if not isinstance (payload , dict ):
58- raise ValueError ("MSCCL++ tuned config must be a JSON object or list" )
59-
60- if "profiles" in payload :
61- raw_profiles = payload ["profiles" ]
62- if not isinstance (raw_profiles , list ):
63- raise ValueError ("MSCCL++ tuned config field 'profiles' must be a list" )
64- for raw_profile in raw_profiles :
65- profile = _profile_from_payload (raw_profile )
66- profiles [profile ] = _configs_by_collective_from_payload (raw_profile .get ("collectives" , {}))
67- return cls (profiles )
68-
69- profiles [None ] = _configs_by_collective_from_payload (payload .get ("collectives" , payload ))
53+ raise ValueError ("MSCCL++ tuned config must be a JSON object" )
54+ raw_profiles = payload .get ("profiles" )
55+ if not isinstance (raw_profiles , list ):
56+ raise ValueError ("MSCCL++ tuned config must contain a 'profiles' list" )
57+ profiles : dict [HardwareProfile , dict [str , list [TunedConfigBySize ]]] = {}
58+ for raw_profile in raw_profiles :
59+ profile = _profile_from_payload (raw_profile )
60+ profiles [profile ] = _configs_by_collective_from_payload (raw_profile .get ("collectives" , {}))
7061 return cls (profiles )
7162
7263 def select (self , profile : HardwareProfile , collective : str , message_size : int ) -> TunedConfig | None :
@@ -89,7 +80,7 @@ def upsert(self, profile: HardwareProfile, collective: str, message_size: int, c
8980 def write_path (self , path : str | Path ) -> None :
9081 profiles_payload : list [dict [str , Any ]] = []
9182 for profile , configs_by_collective in sorted (
92- (( profile , configs ) for profile , configs in self ._profiles .items () if profile is not None ),
83+ self ._profiles .items (),
9384 key = lambda item : (item [0 ].sku is None , item [0 ].sku or "" , item [0 ].scale is None , item [0 ].scale or 0 ),
9485 ):
9586 collectives : dict [str , list [dict [str , Any ]]] = {}
@@ -127,7 +118,7 @@ def _profile_from_payload(raw_profile: Any) -> HardwareProfile:
127118
128119
129120def _matching_profiles (
130- profiles : dict [HardwareProfile | None , dict [str , list [TunedConfigBySize ]]],
121+ profiles : dict [HardwareProfile , dict [str , list [TunedConfigBySize ]]],
131122 runtime_profile : HardwareProfile ,
132123) -> list [tuple [int , dict [str , list [TunedConfigBySize ]]]]:
133124 matches : list [tuple [int , dict [str , list [TunedConfigBySize ]]]] = []
@@ -138,9 +129,7 @@ def _matching_profiles(
138129 return sorted (matches , key = lambda item : item [0 ], reverse = True )
139130
140131
141- def _profile_match_specificity (profile : HardwareProfile | None , runtime_profile : HardwareProfile ) -> int | None :
142- if profile is None :
143- return - 1
132+ def _profile_match_specificity (profile : HardwareProfile , runtime_profile : HardwareProfile ) -> int | None :
144133 specificity = 0
145134 if profile .sku is not None :
146135 if profile .sku != runtime_profile .sku :
0 commit comments