|
25 | 25 |
|
26 | 26 | from .slurm_command_gen_strategy import SlurmCommandGenStrategy |
27 | 27 | from .slurm_job import SlurmJob |
28 | | -from .slurm_system import SlurmJobMetadata, SlurmSystem |
| 28 | +from .slurm_metadata import SlurmJobMetadata, SlurmStepMetadata |
| 29 | +from .slurm_system import SlurmSystem |
29 | 30 |
|
30 | 31 |
|
31 | 32 | class SlurmRunner(BaseRunner): |
@@ -59,25 +60,46 @@ def _submit_test(self, tr: TestRun) -> SlurmJob: |
59 | 60 | logging.info(f"Submitted slurm job: {job_id}") |
60 | 61 | return SlurmJob(tr, id=job_id) |
61 | 62 |
|
62 | | - async def job_completion_callback(self, job: BaseJob) -> None: |
63 | | - self.store_job_metadata(job) |
def on_job_completion(self, job: BaseJob) -> None:
    """Handle a finished job by persisting its metadata.

    Emits a debug log entry naming the job, then hands the job to
    ``store_job_metadata``.

    Args:
        job: The completed job; it is treated as a ``SlurmJob`` for
            type-checking purposes only (``cast`` does no runtime check).
    """
    logging.debug(f"Job completion callback for job {job.id}")
    slurm_job = cast(SlurmJob, job)
    self.store_job_metadata(slurm_job)
64 | 66 |
|
65 | | - def store_job_metadata(self, job): |
66 | | - jb = cast(SlurmJob, job) |
67 | | - system = cast(SlurmSystem, self.system) |
68 | | - cmd_gen = cast(SlurmCommandGenStrategy, jb.test_run.test.test_template.command_gen_strategy) |
69 | | - res = None if self.mode == "dry-run" else system.get_job_status(jb) |
70 | | - job_name, job_state, time_sec = "unknown", "UNKNOWN", 0 |
71 | | - if res: |
72 | | - job_name, job_state, time_sec = res[0], res[1], int(res[2]) |
73 | | - job_meta = SlurmJobMetadata( |
74 | | - job_id=int(jb.id), |
75 | | - job_name=job_name, |
76 | | - job_state=job_state, |
77 | | - elapsed_time_sec=time_sec, |
78 | | - srun_cmd=cmd_gen.gen_srun_command(jb.test_run), |
79 | | - test_cmd=" ".join(cmd_gen.generate_test_command({}, {}, jb.test_run)), |
def _mock_job_metadata(self) -> SlurmStepMetadata:
    """Return a placeholder step-metadata record with fixed values.

    ``store_job_metadata`` uses this record in place of a real status
    lookup when the runner is in ``"dry-run"`` mode.
    """
    placeholder_fields = {
        "job_id": 0,
        "step_id": "",
        "name": "unknown",
        "state": "UNKNOWN",
        "exit_code": "0",
        "start_time": "",
        "end_time": "",
        "elapsed_time_sec": 0,
        "submit_line": "dry-run test",
    }
    return SlurmStepMetadata(**placeholder_fields)
| 79 | + |
def _get_job_metadata(
    self, job: SlurmJob, steps_metadata: list[SlurmStepMetadata]
) -> tuple[Path, SlurmJobMetadata]:
    """Assemble the job-level metadata record and the path it is stored at.

    The first entry of ``steps_metadata`` supplies the job-level name,
    state, exit code, timestamps and elapsed time; every later entry is
    attached as ``job_steps``.

    Args:
        job: The Slurm job whose test run provides the output directory
            and the inputs for command generation.
        steps_metadata: Step records for the job. Must contain at least
            one entry.

    Returns:
        A ``(path, metadata)`` pair, where ``path`` is ``slurm-job.toml``
        inside the test run's output directory.

    Raises:
        IndexError: If ``steps_metadata`` is empty.
    """
    test_run = job.test_run
    cmd_gen = cast(SlurmCommandGenStrategy, test_run.test.test_template.command_gen_strategy)
    metadata_path = test_run.output_path / "slurm-job.toml"

    job_id = int(job.id)
    # The leading record describes the job as a whole.
    primary = steps_metadata[0]

    job_metadata = SlurmJobMetadata(
        job_id=job_id,
        name=primary.name,
        state=primary.state,
        exit_code=primary.exit_code,
        start_time=primary.start_time,
        end_time=primary.end_time,
        elapsed_time_sec=primary.elapsed_time_sec,
        job_steps=steps_metadata[1:],
        srun_cmd=cmd_gen.gen_srun_command(job.test_run),
        test_cmd=" ".join(cmd_gen.generate_test_command({}, {}, job.test_run)),
        job_root=job.test_run.output_path.absolute(),
    )
    return metadata_path, job_metadata
81 | 97 |
|
82 | | - with open(jb.test_run.output_path / "slurm-job.toml", "w") as job_file: |
def store_job_metadata(self, job: SlurmJob) -> None:
    """Collect metadata for ``job`` and write it to its ``slurm-job.toml``.

    In ``"dry-run"`` mode no status lookup is made and a placeholder
    record from ``_mock_job_metadata`` is used. Otherwise the step records
    come from ``system.get_job_status(job)``.

    Args:
        job: The Slurm job whose metadata should be stored.
    """
    system = cast(SlurmSystem, self.system)

    if self.mode == "dry-run":
        steps_metadata = [self._mock_job_metadata()]
    else:
        steps_metadata = system.get_job_status(job)
        if not steps_metadata:
            # _get_job_metadata indexes the first record, so an empty status
            # result would raise IndexError and lose the metadata file.
            # Fall back to the placeholder record instead.
            logging.warning(f"No status information found for job {job.id}; storing placeholder metadata")
            steps_metadata = [self._mock_job_metadata()]

    slurm_job_file, job_meta = self._get_job_metadata(job, steps_metadata)

    logging.debug(f"Storing job metadata for job {job.id} to {slurm_job_file}")
    # TOML documents must be UTF-8; do not rely on the locale's default encoding.
    with slurm_job_file.open("w", encoding="utf-8") as job_file:
        toml.dump(job_meta.model_dump(), job_file)
0 commit comments