Skip to content

Commit 192713f

Browse files
authored
Merge pull request #301 from NVIDIA/am/upd-jobs
Refactor Job classes
2 parents d61a046 + db69757 commit 192713f

File tree

12 files changed: 49 additions and 151 deletions.

12 files changed: 49 additions and 151 deletions.

src/cloudai/_core/base_job.py

Lines changed: 6 additions & 63 deletions
Original file line number | Diff line number | Diff line change
@@ -14,73 +14,16 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17-
from pathlib import Path
17+
from dataclasses import dataclass, field
1818
from typing import Union
1919

20-
from .system import System
2120
from .test_scenario import TestRun
2221

2322

23+
@dataclass
2424
class BaseJob:
25-
"""
26-
Base class for representing a job created by executing a test.
25+
"""Base class for representing a job created by executing a test."""
2726

28-
Attributes
29-
id (Union[str, int]): The unique identifier of the job.
30-
mode (str): The mode of the job (e.g., 'run', 'dry-run').
31-
system (System): The system in which the job is running.
32-
test_run (TestRun): The TestRun instance associated with this job.
33-
output_path (Path): The path where the job's output is stored.
34-
terminated_by_dependency (bool): Flag to indicate if the job was terminated due to a dependency.
35-
"""
36-
37-
def __init__(self, mode: str, system: System, test_run: TestRun):
38-
"""
39-
Initialize a BaseJob instance.
40-
41-
Args:
42-
mode (str): The mode of the job (e.g., 'run', 'dry-run').
43-
system (System): The system in which the job is running.
44-
test_run (TestRun): The TestRun instance associated with this job.
45-
"""
46-
self.id: Union[str, int] = 0
47-
self.mode: str = mode
48-
self.system: System = system
49-
self.test_run: TestRun = test_run
50-
self.output_path: Path = test_run.output_path
51-
self.terminated_by_dependency: bool = False
52-
53-
def is_running(self) -> bool:
54-
"""
55-
Check if the specified job is currently running.
56-
57-
Returns
58-
bool: True if the job is running, False otherwise.
59-
"""
60-
if self.mode == "dry-run":
61-
return True
62-
return self.system.is_job_running(self)
63-
64-
def is_completed(self) -> bool:
65-
"""
66-
Check if a job is completed.
67-
68-
Returns
69-
bool: True if the job is completed, False otherwise.
70-
"""
71-
if self.mode == "dry-run":
72-
return True
73-
return self.system.is_job_completed(self)
74-
75-
def increment_iteration(self):
76-
"""Increment the iteration count of the associated test."""
77-
self.test_run.current_iteration += 1
78-
79-
def __repr__(self) -> str:
80-
"""
81-
Return a string representation of the BaseJob instance.
82-
83-
Returns
84-
str: String representation of the job.
85-
"""
86-
return f"BaseJob(id={self.id}, mode={self.mode}, system={self.system.name}, test={self.test_run.test.name})"
27+
test_run: TestRun
28+
id: Union[str, int]
29+
terminated_by_dependency: bool = field(default=False, init=False)

src/cloudai/_core/base_runner.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -204,7 +204,7 @@ async def check_start_post_init_dependencies(self):
204204
items = list(self.testrun_to_job_map.items())
205205

206206
for tr, job in items:
207-
if job.is_running():
207+
if self.system.is_job_running(job):
208208
await self.check_and_schedule_start_post_init_dependent_tests(tr)
209209

210210
async def check_and_schedule_start_post_init_dependent_tests(self, started_test_run: TestRun):
@@ -279,7 +279,7 @@ async def monitor_jobs(self) -> int:
279279
successful_jobs_count = 0
280280

281281
for job in list(self.jobs):
282-
if job.is_completed():
282+
if self.system.is_job_completed(job):
283283
await self.job_completion_callback(job)
284284

285285
if self.mode == "dry-run":
@@ -322,7 +322,7 @@ def get_job_status(self, job: BaseJob) -> JobStatusResult:
322322
Returns:
323323
JobStatusResult: The result containing the job status and an optional error message.
324324
"""
325-
return job.test_run.test.test_template.get_job_status(job.output_path)
325+
return job.test_run.test.test_template.get_job_status(job.test_run.output_path)
326326

327327
async def handle_job_completion(self, completed_job: BaseJob):
328328
"""
@@ -335,7 +335,7 @@ async def handle_job_completion(self, completed_job: BaseJob):
335335

336336
self.jobs.remove(completed_job)
337337
del self.testrun_to_job_map[completed_job.test_run]
338-
completed_job.increment_iteration()
338+
completed_job.test_run.current_iteration += 1
339339
if not completed_job.terminated_by_dependency and completed_job.test_run.has_more_iterations():
340340
msg = f"Re-running job for iteration {completed_job.test_run.current_iteration}"
341341
logging.info(msg)

src/cloudai/runner/kubernetes/kubernetes_job.py

Lines changed: 7 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -15,42 +15,14 @@
1515
# limitations under the License.
1616

1717

18-
from cloudai import BaseJob, System, TestRun
18+
from dataclasses import dataclass
1919

20+
from cloudai import BaseJob
2021

21-
class KubernetesJob(BaseJob):
22-
"""
23-
A job class for execution on a Kubernetes system.
24-
25-
Attributes
26-
mode (str): The mode of the job (e.g., 'run', 'dry-run').
27-
system (System): The system in which the job is running.
28-
test_run (TestRun): The test instance associated with this job.
29-
name (str): The name of the job.
30-
kind (str): The kind of the job.
31-
"""
32-
33-
def __init__(self, mode: str, system: System, test_run: TestRun, name: str, kind: str):
34-
"""
35-
Initialize a KubernetesJob instance.
3622

37-
Args:
38-
mode (str): The mode of the job (e.g., 'run', 'dry-run').
39-
system (System): The system in which the job is running.
40-
test_run (TestRun): The test instance associated with this job.
41-
name (str): The name of the job.
42-
kind (str): The kind of the job.
43-
"""
44-
super().__init__(mode, system, test_run)
45-
self.id = name
46-
self.name = name
47-
self.kind = kind
48-
49-
def __repr__(self) -> str:
50-
"""
51-
Return a string representation of the KubernetesJob instance.
23+
@dataclass
24+
class KubernetesJob(BaseJob):
25+
"""A job class for execution on a Kubernetes system."""
5226

53-
Returns
54-
str: String representation of the job.
55-
"""
56-
return f"KubernetesJob(name={self.name}, test={self.test_run.test.name}, " f"kind={self.kind})"
27+
kind: str
28+
name: str

src/cloudai/runner/kubernetes/kubernetes_runner.py

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -47,7 +47,7 @@ def _submit_test(self, tr: TestRun) -> KubernetesJob:
4747
k8s_system: KubernetesSystem = cast(KubernetesSystem, self.system)
4848
job_name = k8s_system.create_job(job_spec)
4949

50-
return KubernetesJob(self.mode, self.system, tr, job_name, job_kind)
50+
return KubernetesJob(tr, id=job_name, name=job_name, kind=job_kind)
5151

5252
async def job_completion_callback(self, job: BaseJob) -> None:
5353
"""
@@ -58,7 +58,7 @@ async def job_completion_callback(self, job: BaseJob) -> None:
5858
"""
5959
k8s_system: KubernetesSystem = cast(KubernetesSystem, self.system)
6060
k_job = cast(KubernetesJob, job)
61-
k8s_system.store_logs_for_job(k_job.name, k_job.output_path)
61+
k8s_system.store_logs_for_job(k_job.name, k_job.test_run.output_path)
6262
k8s_system.delete_job(k_job.name, k_job.kind)
6363

6464
def kill_job(self, job: BaseJob) -> None:
@@ -70,5 +70,5 @@ def kill_job(self, job: BaseJob) -> None:
7070
"""
7171
k8s_system: KubernetesSystem = cast(KubernetesSystem, self.system)
7272
k_job = cast(KubernetesJob, job)
73-
k8s_system.store_logs_for_job(k_job.name, k_job.output_path)
73+
k8s_system.store_logs_for_job(k_job.name, k_job.test_run.output_path)
7474
k8s_system.delete_job(k_job.name, k_job.kind)

src/cloudai/runner/slurm/slurm_job.py

Lines changed: 5 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -14,28 +14,13 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17-
from typing import Union
17+
from dataclasses import dataclass
1818

19-
from cloudai import BaseJob, System, TestRun
19+
from cloudai import BaseJob
2020

2121

22+
@dataclass
2223
class SlurmJob(BaseJob):
23-
"""
24-
A job class for execution on a Slurm system.
24+
"""A job class for execution on a Slurm system."""
2525

26-
Attributes
27-
id (Union[str, int]): The unique identifier of the job.
28-
"""
29-
30-
def __init__(self, mode: str, system: System, test_run: TestRun, job_id: Union[str, int]):
31-
BaseJob.__init__(self, mode, system, test_run)
32-
self.id = job_id
33-
34-
def __repr__(self) -> str:
35-
"""
36-
Return a string representation of the SlurmJob instance.
37-
38-
Returns
39-
str: String representation of the job.
40-
"""
41-
return f"SlurmJob(id={self.id}, test={self.test_run.test.name})"
26+
pass

src/cloudai/runner/slurm/slurm_runner.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -68,5 +68,5 @@ def _submit_test(self, tr: TestRun) -> SlurmJob:
6868
stderr=stderr,
6969
message="Failed to retrieve job ID from command output.",
7070
)
71-
logging.info(f"Submitted slurm job: {job_id}")
72-
return SlurmJob(self.mode, self.system, tr, job_id)
71+
logging.info(f"Submitted slurm job: {job_id}")
72+
return SlurmJob(tr, id=job_id)

src/cloudai/runner/standalone/standalone_job.py

Lines changed: 5 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -14,28 +14,13 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17-
from typing import Union
17+
from dataclasses import dataclass
1818

19-
from cloudai import BaseJob, System, TestRun
19+
from cloudai import BaseJob
2020

2121

22+
@dataclass
2223
class StandaloneJob(BaseJob):
23-
"""
24-
A job class for standalone execution.
24+
"""A job class for standalone execution."""
2525

26-
Attributes
27-
id (Union[str, int]): The unique identifier of the job.
28-
"""
29-
30-
def __init__(self, mode: str, system: System, test_run: TestRun, job_id: Union[str, int]):
31-
BaseJob.__init__(self, mode, system, test_run)
32-
self.id = job_id
33-
34-
def __repr__(self) -> str:
35-
"""
36-
Return a string representation of the StandaloneJob instance.
37-
38-
Returns
39-
str: String representation of the job.
40-
"""
41-
return f"StandaloneJob(id={self.id}, test={self.test_run.test.name})"
26+
pass

src/cloudai/runner/standalone/standalone_runner.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -68,4 +68,4 @@ def _submit_test(self, tr: TestRun) -> StandaloneJob:
6868
stderr="",
6969
message="Failed to retrieve job ID from command output.",
7070
)
71-
return StandaloneJob(self.mode, self.system, tr, job_id)
71+
return StandaloneJob(tr, id=job_id)

src/cloudai/schema/test_template/nemo_launcher/slurm_command_gen_strategy.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -69,7 +69,7 @@ def gen_exec_command(self, tr: TestRun) -> str:
6969
self.final_cmd_args.update(
7070
{
7171
"base_results_dir": str(tr.output_path.absolute()),
72-
"launcher_scripts_path": str((repo_path / tdef.cmd_args.launcher_script).parent),
72+
"launcher_scripts_path": str((repo_path / tdef.cmd_args.launcher_script).parent.absolute()),
7373
}
7474
)
7575

tests/slurm_command_gen_strategy/test_nemo_launcher_slurm_command_gen_strategy.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -212,6 +212,10 @@ def test_log_command_to_file(
212212
test_run.output_path = tmp_path / "output_dir"
213213
test_run.output_path.mkdir()
214214

215+
repo_path = (tmp_path / "repo").relative_to(tmp_path)
216+
tdef: NeMoLauncherTestDefinition = cast(NeMoLauncherTestDefinition, test_run.test.test_definition)
217+
tdef.python_executable.git_repo.installed_path = repo_path
218+
tdef.python_executable.venv_path = repo_path.parent / f"{repo_path.name}-venv"
215219
cmd_gen_strategy.gen_exec_command(test_run)
216220

217221
written_content = mock_file().write.call_args[0][0]
@@ -221,6 +225,11 @@ def test_log_command_to_file(
221225
assert "TEST_VAR_1=value1" in written_content, "Logged command should contain environment variables"
222226
assert "training.trainer.num_nodes=2" in written_content, "Command should contain the number of nodes"
223227

228+
assert str((tdef.python_executable.venv_path / "bin" / "python").absolute()) in written_content
229+
assert (
230+
f"launcher_scripts_path={(repo_path / tdef.cmd_args.launcher_script).parent.absolute()} " in written_content
231+
)
232+
224233
def test_no_line_breaks_in_executed_command(
225234
self, cmd_gen_strategy: NeMoLauncherSlurmCommandGenStrategy, test_run: TestRun, tmp_path: Path
226235
) -> None:

0 commit comments

Comments
 (0)