@@ -142,6 +142,52 @@ def __sortNodesByExecutionTime(self, nodes: list):
142142 # return sorted(nodes, key=lambda n: self.__getNumberOfSlowRanksOnNode(n))
143143 return sorted (node_times , key = lambda t : node_times [t ])
144144
145+ def __sortNodesByMaxRankExecutionTime (self , nodes : list ):
146+ """
147+ Takes in a list of node names and sorts them based on total execution time.
148+ The fastest nodes will be first, and the slowest will be last.
149+ """
150+ node_times = {}
151+ for r , n in self .__rank_to_node_map .items ():
152+ if n in nodes :
153+ if n not in node_times :
154+ node_times [n ] = 0.0
155+ if self .__rank_times [r ] > node_times [n ]:
156+ node_times [n ] = self .__rank_times [r ]
157+ # Alternative:
158+ # return sorted(nodes, key=lambda n: self.__getNumberOfSlowRanksOnNode(n))
159+ return sorted (node_times , key = lambda t : node_times [t ])
160+
161+ def __sortNodesByNodeDevFromAvgExecutionTime (self , nodes : list ):
162+ """
163+ Takes in a list of node names and sorts them based on how much they deviate
164+ from the average total execution time.
165+ """
166+ node_times = {}
167+ for r , n in self .__rank_to_node_map .items ():
168+ if n in nodes :
169+ if n not in node_times :
170+ node_times [n ] = 0.0
171+ node_times [n ] += self .__rank_times [r ]
172+ avg = np .mean (list (node_times .values ()))
173+ return sorted (node_times , key = lambda t : abs (node_times [t ]- avg ))
174+
175+ def __sortNodesByRankDevFromAvgExecutionTime (self , nodes : list ):
176+ """
177+ Takes in a list of node names and sorts them based on how much they deviate
178+ from the average total execution time.
179+ """
180+ avg = np .mean (list (self .__rank_times .values ()))
181+ node_dev_times = {}
182+ for r , n in self .__rank_to_node_map .items ():
183+ if n in nodes :
184+ if n not in node_dev_times :
185+ node_dev_times [n ] = 0.0
186+ this_dev_time = abs (self .__rank_times [r ]- avg )
187+ if this_dev_time > node_dev_times [n ]:
188+ node_dev_times [n ] = this_dev_time
189+ return sorted (node_dev_times , key = lambda t : node_dev_times [t ])
190+
145191 def __findHighOutliers (self , data ):
146192 """
147193 Finds data points that are some percentage (given by self.__threshold_pct)
@@ -386,7 +432,8 @@ def createHostfile(self):
386432 elif num_good_nodes > self .__num_nodes :
387433 n_nodes_to_drop = num_good_nodes - self .__num_nodes
388434 assert n_nodes_to_drop > 0 , f"Cannot drop { n_nodes_to_drop } "
389- sorted_nodes = self .__sortNodesByExecutionTime (good_node_names )
435+ #sorted_nodes = self.__sortNodesByExecutionTime(good_node_names)
436+ sorted_nodes = self .__sortNodesByMaxRankExecutionTime (good_node_names )
390437 print (
391438 f"Since the SlowNodeDetector originally found { num_good_nodes } good node{ s } , "
392439 f"but only { self .__num_nodes } are needed, the following nodes will also be "
0 commit comments