Skip to content

Commit 718949b

Browse files
committed
Refactor JobStatus
1 parent fec7bd7 commit 718949b

3 files changed

Lines changed: 10 additions & 11 deletions

File tree

src/swiss_ai_model_launch/launchers/firecrest_launcher.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -170,7 +170,7 @@ async def get_job_status(self, job_id: int) -> JobStatus:
170170
jobid=str(job_id),
171171
# account=self.account, # TODO
172172
)
173-
return JobStatus(str(job_info[0]["status"]["state"]))
173+
return JobStatus.from_str(str(job_info[0]["status"]["state"]))
174174

175175
async def get_job_logs(self, job_id: int) -> tuple[str, str]:
176176
log_dir = Path(self._get_working_dir()) / "logs" / str(job_id)

src/swiss_ai_model_launch/launchers/launcher.py

Lines changed: 7 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,13 @@ class JobStatus(Enum):
1111
TIMEOUT = "TIMEOUT"
1212
UNKNOWN = "UNKNOWN"
1313

14+
@classmethod
15+
def from_str(cls, state: str) -> "JobStatus":
16+
try:
17+
return cls(state)
18+
except ValueError:
19+
return cls.UNKNOWN
20+
1421

1522
class Launcher(ABC):
1623
def __init__(

src/swiss_ai_model_launch/launchers/slurm_launcher.py

Lines changed: 2 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -19,14 +19,6 @@
1919

2020
_APP_WORKING_DIRECTORY = ".sml"
2121

22-
_SLURM_STATE_MAP: dict[str, JobStatus] = {
23-
"PENDING": JobStatus.PENDING,
24-
"CONFIGURING": JobStatus.PENDING,
25-
"RUNNING": JobStatus.RUNNING,
26-
"COMPLETING": JobStatus.RUNNING,
27-
"TIMEOUT": JobStatus.TIMEOUT,
28-
}
29-
3022

3123
class SlurmLauncher(Launcher):
3224
def __init__(
@@ -177,7 +169,7 @@ async def get_job_status(self, job_id: int) -> JobStatus:
177169
state = stdout.decode().strip()
178170

179171
if state:
180-
return _SLURM_STATE_MAP.get(state, JobStatus.UNKNOWN)
172+
return JobStatus.from_str(state)
181173

182174
# Job not in squeue — check sacct for terminal state
183175
proc = await asyncio.create_subprocess_exec(
@@ -194,7 +186,7 @@ async def get_job_status(self, job_id: int) -> JobStatus:
194186
stdout, _ = await proc.communicate()
195187
lines = [line.strip() for line in stdout.decode().splitlines() if line.strip()]
196188
if lines:
197-
return _SLURM_STATE_MAP.get(lines[0].split()[0], JobStatus.UNKNOWN)
189+
return JobStatus.from_str(lines[0].split()[0])
198190

199191
return JobStatus.UNKNOWN
200192

0 commit comments

Comments
 (0)