Skip to content

Commit ae55fcf

Browse files
committed
Add type annotations to functions
1 parent 705e543 commit ae55fcf

File tree

1 file changed

+26
-13
lines changed

1 file changed

+26
-13
lines changed

src/seml/utils/asha.py

Lines changed: 26 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
import math
22
import uuid
3+
from logging import Logger
4+
from typing import Any
35

46
from pymongo import MongoClient
7+
from pymongo.collection import Collection
58

69

710
class ASHA:
@@ -12,9 +15,9 @@ def __init__(
1215
min_r: int,
1316
max_r: int,
1417
metric_increases: bool,
15-
mongodb_configurations,
16-
_log,
17-
):
18+
mongodb_configurations: dict[str, Any],
19+
_log: Logger,
20+
) -> None:
1821
"""Doc string pretty please ^^
1922
2023
Args:
@@ -46,7 +49,9 @@ def __init__(
4649
self.mongodb_configurations, self.asha_collection_name
4750
)
4851

49-
def _get_mongo_collection(self, mongodb_configurations, experiment_name):
52+
def _get_mongo_collection(
53+
self, mongodb_configurations: dict[str, Any], experiment_name: str
54+
) -> Collection:
5055
"""
5156
Connecting to the MongoDB, credentials from SEML config
5257
returns connection
@@ -72,7 +77,9 @@ def _get_mongo_collection(self, mongodb_configurations, experiment_name):
7277
)
7378
return collection
7479

75-
def save_metric_to_db(self, collection, job_id, stage, metric):
80+
def save_metric_to_db(
81+
self, collection: Collection, job_id: str, stage: int, metric: float
82+
) -> None:
7683
"""
7784
Insert or update metric for the given job_id and stage in the MongoDB collection.
7885
"""
@@ -82,7 +89,7 @@ def save_metric_to_db(self, collection, job_id, stage, metric):
8289
upsert=True,
8390
)
8491

85-
def store_stage_metric(self, stage: int, metric: float):
92+
def store_stage_metric(self, stage: int, metric: float) -> bool:
8693
"""
8794
Accuracy added and other metrics compaired,
8895
probably should break this into different functions,
@@ -126,7 +133,7 @@ def store_stage_metric(self, stage: int, metric: float):
126133
self.set_status_db('Completed')
127134
return should_terminate
128135

129-
def metric_in_rungs(self, stage):
136+
def metric_in_rungs(self, stage: int) -> bool:
130137
"""
131138
if user wants to check if their stage/resource is in a rung
132139
"""
@@ -135,7 +142,9 @@ def metric_in_rungs(self, stage):
135142
else:
136143
return False
137144

138-
def get_metric_at_stage_db(self, collection, stage, current_job_id=None):
145+
def get_metric_at_stage_db(
146+
self, collection: Collection, stage: int, current_job_id: str | None = None
147+
) -> dict[str, float]:
139148
"""
140149
Retrieve metrics of all jobs at the specified stage from the MongoDB collection.
141150
Returns a dict: {job_id: metric}, excluding current_job_id if provided.
@@ -149,12 +158,16 @@ def get_metric_at_stage_db(self, collection, stage, current_job_id=None):
149158
metrics[job_id] = metric
150159
return metrics
151160

152-
def _print_stage_info(self, stage, metric, other_job_metrics):
161+
def _print_stage_info(
162+
self, stage: int, metric: float, other_job_metrics: dict[str, float]
163+
) -> None:
153164
self._log.info(f'[Epoch {stage}] Own metric: {metric}')
154165
self._log.info(f"[Epoch {stage}] Other jobs' metrics: {other_job_metrics}")
155166
pass
156167

157-
def _job_promotion(self, metric, other_job_metrics, eta):
168+
def _job_promotion(
169+
self, metric: float, other_job_metrics: dict[str, float], eta: int | float
170+
) -> bool:
158171
"""
159172
returns cutoff metric at which jobs should be promoted
160173
"""
@@ -191,7 +204,7 @@ def _job_promotion(self, metric, other_job_metrics, eta):
191204

192205
return promotion
193206

194-
def generate_rungs(self, min_r, eta, max_r):
207+
def generate_rungs(self, min_r: int, eta: int | float, max_r: int) -> list[int]:
195208
"""
196209
generates rungs at which promotion will be checked
197210
"""
@@ -216,7 +229,7 @@ def generate_rungs(self, min_r, eta, max_r):
216229
)
217230
return rungs
218231

219-
def set_status_db(self, status):
232+
def set_status_db(self, status: str) -> None:
220233
"""
221234
set status in mongodb collection to mark if process is still running
222235
"""
@@ -225,7 +238,7 @@ def set_status_db(self, status):
225238
)
226239

227240
# def isbest(self,metric,other_job_metrics):
228-
def isbest(self):
241+
def isbest(self) -> None:
229242
"""
230243
this function doesn't is incorrect,
231244
working on it to use the status to see if all jobs are completed

0 commit comments

Comments
 (0)