11import math
22import uuid
3+ from logging import Logger
4+ from typing import Any
35
46from pymongo import MongoClient
7+ from pymongo .collection import Collection
58
69
710class ASHA :
@@ -12,9 +15,9 @@ def __init__(
1215 min_r : int ,
1316 max_r : int ,
1417 metric_increases : bool ,
15- mongodb_configurations ,
16- _log ,
17- ):
18+ mongodb_configurations : dict [ str , Any ] ,
19+ _log : Logger ,
20+ ) -> None :
1821 """Doc string pretty please ^^
1922
2023 Args:
@@ -46,7 +49,9 @@ def __init__(
4649 self .mongodb_configurations , self .asha_collection_name
4750 )
4851
49- def _get_mongo_collection (self , mongodb_configurations , experiment_name ):
52+ def _get_mongo_collection (
53+ self , mongodb_configurations : dict [str , Any ], experiment_name : str
54+ ) -> Collection :
5055 """
5156 Connecting to the MongoDB, credentials from SEML config
5257 returns connection
@@ -72,7 +77,9 @@ def _get_mongo_collection(self, mongodb_configurations, experiment_name):
7277 )
7378 return collection
7479
75- def save_metric_to_db (self , collection , job_id , stage , metric ):
80+ def save_metric_to_db (
81+ self , collection : Collection , job_id : str , stage : int , metric : float
82+ ) -> None :
7683 """
7784 Insert or update metric for the given job_id and stage in the MongoDB collection.
7885 """
@@ -82,7 +89,7 @@ def save_metric_to_db(self, collection, job_id, stage, metric):
8289 upsert = True ,
8390 )
8491
85- def store_stage_metric (self , stage : int , metric : float ):
92+ def store_stage_metric (self , stage : int , metric : float ) -> bool :
8693 """
8794 Accuracy added and other metrics compaired,
8895 probably should break this into different functions,
@@ -126,7 +133,7 @@ def store_stage_metric(self, stage: int, metric: float):
126133 self .set_status_db ('Completed' )
127134 return should_terminate
128135
129- def metric_in_rungs (self , stage ) :
136+ def metric_in_rungs (self , stage : int ) -> bool :
130137 """
131138 if user wants to check if their stage/resource is in a rung
132139 """
@@ -135,7 +142,9 @@ def metric_in_rungs(self, stage):
135142 else :
136143 return False
137144
138- def get_metric_at_stage_db (self , collection , stage , current_job_id = None ):
145+ def get_metric_at_stage_db (
146+ self , collection : Collection , stage : int , current_job_id : str | None = None
147+ ) -> dict [str , float ]:
139148 """
140149 Retrieve metrics of all jobs at the specified stage from the MongoDB collection.
141150 Returns a dict: {job_id: metric}, excluding current_job_id if provided.
@@ -149,12 +158,16 @@ def get_metric_at_stage_db(self, collection, stage, current_job_id=None):
149158 metrics [job_id ] = metric
150159 return metrics
151160
152- def _print_stage_info (self , stage , metric , other_job_metrics ):
161+ def _print_stage_info (
162+ self , stage : int , metric : float , other_job_metrics : dict [str , float ]
163+ ) -> None :
153164 self ._log .info (f'[Epoch { stage } ] Own metric: { metric } ' )
154165 self ._log .info (f"[Epoch { stage } ] Other jobs' metrics: { other_job_metrics } " )
155166 pass
156167
157- def _job_promotion (self , metric , other_job_metrics , eta ):
168+ def _job_promotion (
169+ self , metric : float , other_job_metrics : dict [str , float ], eta : int | float
170+ ) -> bool :
158171 """
159172 returns cutoff metric at which jobs should be promoted
160173 """
@@ -191,7 +204,7 @@ def _job_promotion(self, metric, other_job_metrics, eta):
191204
192205 return promotion
193206
194- def generate_rungs (self , min_r , eta , max_r ) :
207+ def generate_rungs (self , min_r : int , eta : int | float , max_r : int ) -> list [ int ] :
195208 """
196209 generates rungs at which promotion will be checked
197210 """
@@ -216,7 +229,7 @@ def generate_rungs(self, min_r, eta, max_r):
216229 )
217230 return rungs
218231
219- def set_status_db (self , status ) :
232+ def set_status_db (self , status : str ) -> None :
220233 """
221234 set status in mongodb collection to mark if process is still running
222235 """
@@ -225,7 +238,7 @@ def set_status_db(self, status):
225238 )
226239
227240 # def isbest(self,metric,other_job_metrics):
228- def isbest (self ):
241+ def isbest (self ) -> None :
229242 """
230243 this function doesn't is incorrect,
231244 working on it to use the status to see if all jobs are completed
0 commit comments