Skip to content

Commit c23c175

Browse files
authored
modify submission.json format; will contain machine information in submission.json (#152)
* modify submission.json format; will contain machine information in submission.json * fix bug in machine deserialize; add unittest * only try one line
1 parent ac5e8d6 commit c23c175

12 files changed

+102
-14
lines changed

dpdispatcher/dp_cloud_server_context.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ def __init__ (self,
2626
*args,
2727
**kwargs,
2828
):
29+
self.init_local_root = local_root
30+
self.init_remote_root = remote_root
2931
self.temp_local_root = os.path.abspath(local_root)
3032
self.remote_profile = remote_profile
3133
email = remote_profile.get("email", None)

dpdispatcher/hdfs_context.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@ def __init__(self,
1414
):
1515

1616
assert(type(local_root) == str)
17+
self.init_local_root = local_root
18+
self.init_remote_root = remote_root
1719
self.temp_local_root = os.path.abspath(local_root)
1820
self.temp_remote_root = remote_root
1921
self.remote_profile = remote_profile

dpdispatcher/lazy_local_context.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@ def __init__ (self,
3131
remote_profile:
3232
"""
3333
assert(type(local_root) == str)
34+
self.init_local_root = local_root
35+
self.init_remote_root = remote_root
3436
self.temp_local_root = os.path.abspath(local_root)
3537
self.temp_remote_root = os.path.abspath(local_root)
3638
self.remote_profile = remote_profile

dpdispatcher/local_context.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,8 @@ def __init__(self,
5555
remote_profile:
5656
"""
5757
assert(type(local_root) == str)
58+
self.init_local_root = local_root
59+
self.init_remote_root = remote_root
5860
self.temp_local_root = os.path.abspath(local_root)
5961
self.temp_remote_root = os.path.abspath(remote_root)
6062
self.remote_profile = remote_profile

dpdispatcher/machine.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,26 @@ def load_from_dict(cls, machine_dict):
124124
machine = machine_class(context=context)
125125
return machine
126126

127+
def serialize(self, if_empty_remote_profile=False):
128+
machine_dict = {}
129+
machine_dict['batch_type'] = self.__class__.__name__
130+
machine_dict['context_type'] = self.context.__class__.__name__
131+
machine_dict['local_root'] = self.context.init_local_root
132+
machine_dict['remote_root'] = self.context.init_remote_root
133+
if not if_empty_remote_profile:
134+
machine_dict['remote_profile'] = self.context.remote_profile
135+
else:
136+
machine_dict['remote_profile'] = {}
137+
return machine_dict
138+
139+
def __eq__(self, other):
140+
return self.serialize() == other.serialize()
141+
142+
@classmethod
143+
def deserialize(cls, machine_dict):
144+
machine = cls.load_from_dict(machine_dict=machine_dict)
145+
return machine
146+
127147
def check_status(self, job) :
128148
raise NotImplementedError('abstract method check_status should be implemented by derived class')
129149

dpdispatcher/ssh_context.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -186,9 +186,12 @@ def __init__ (self,
186186
**kwargs,
187187
):
188188
assert(type(local_root) == str)
189+
self.init_local_root = local_root
190+
self.init_remote_root = remote_root
189191
self.temp_local_root = os.path.abspath(local_root)
190192
assert os.path.isabs(remote_root), f"remote_root must be a abspath"
191193
self.temp_remote_root = remote_root
194+
self.remote_profile = remote_profile
192195

193196
# self.job_uuid = None
194197
self.clean_asynchronously = clean_asynchronously
@@ -258,6 +261,12 @@ def bind_submission(self, submission):
258261
# self.remote_root = os.path.join(self.temp_remote_root, self.submission.submission_hash, self.submission.work_base )
259262
self.remote_root = pathlib.PurePath(os.path.join(self.temp_remote_root, self.submission.submission_hash)).as_posix()
260263

264+
sftp = self.ssh_session.ssh.open_sftp()
265+
try:
266+
sftp.mkdir(self.remote_root)
267+
except OSError:
268+
pass
269+
261270
# self.job_uuid = submission.submission_hash
262271
# dlog.debug("debug:SSHContext.bind_submission"
263272
# "{submission.submission_hash}; {self.local_root}; {self.remote_root")

dpdispatcher/submission.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,14 @@ def deserialize(cls, submission_dict, machine=None):
8888
backward_common_files=submission_dict['backward_common_files'])
8989
submission.belonging_jobs = [Job.deserialize(job_dict=job_dict) for job_dict in submission_dict['belonging_jobs']]
9090
submission.submission_hash = submission.get_hash()
91-
submission.bind_machine(machine=machine)
91+
if machine is not None:
92+
submission.bind_machine(machine=machine)
93+
else:
94+
machine = Machine.deserialize(machine_dict=submission_dict['machine'])
95+
submission.bind_machine(machine)
9296
return submission
9397

94-
def serialize(self, if_static=False, if_none_local_root=False):
98+
def serialize(self, if_static=False):
9599
"""convert the Submission class instance to a dictionary.
96100
97101
Parameters
@@ -105,11 +109,17 @@ def serialize(self, if_static=False, if_none_local_root=False):
105109
the dictionary converted from the Submission class instance
106110
"""
107111
submission_dict = {}
108-
if if_none_local_root:
109-
submission_dict['local_root'] = None
110-
else:
111-
submission_dict['local_root'] = self.local_root
112+
# if if_none_local_root:
113+
# submission_dict['local_root'] = None
114+
# else:
115+
# submission_dict['local_root'] = self.local_root
116+
112117
submission_dict['work_base'] = self.work_base
118+
machine = getattr(self, 'machine', None)
119+
if machine is None:
120+
submission_dict['machine'] = {}
121+
else:
122+
submission_dict['machine'] = machine.serialize()
113123
submission_dict['resources'] = self.resources.serialize()
114124
submission_dict['forward_common_files'] = self.forward_common_files
115125
submission_dict['backward_common_files'] = self.backward_common_files
@@ -333,7 +343,7 @@ def try_recover_from_json(self):
333343
if self == submission:
334344
self.belonging_jobs = submission.belonging_jobs
335345
self.bind_machine(machine=self.machine)
336-
dlog.info(f"Find old submission; recover from json; "
346+
dlog.info(f"Find old submission; recover submission from json file;"
337347
f"submission.submission_hash:{submission.submission_hash}; "
338348
f"machine.context.remote_root:{self.machine.context.remote_root}; "
339349
f"submission.work_base:{submission.work_base};")

tests/jsons/submission.json

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
{
2-
"local_root": null,
32
"work_base": "0_md/",
3+
"machine": {
4+
"batch_type": "PBS",
5+
"context_type": "LocalContext",
6+
"local_root": "test_pbs_dir/",
7+
"remote_root": "tmp_pbs_dir/",
8+
"remote_profile": {}
9+
},
410
"resources": {
511
"number_node": 1,
612
"cpu_per_node": 4,
@@ -128,4 +134,4 @@
128134
}
129135
}
130136
]
131-
}
137+
}

tests/script_gen_json.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,10 +21,12 @@
2121
with open('jsons/resources.json', 'w') as f:
2222
json.dump(resources_dict, f, indent=4)
2323

24-
submission_dict = SampleClass.get_sample_submission_dict()
24+
pbs = SampleClass.get_sample_pbs_local_context()
25+
submission = SampleClass.get_sample_submission()
26+
submission.bind_machine(machine=pbs)
2527
assert os.path.isfile('jsons/submission.json') is False
2628
with open('jsons/submission.json', 'w') as f:
27-
json.dump(submission_dict, f, indent=4)
29+
json.dump(submission.serialize(), f, indent=4)
2830

2931
job_dict = SampleClass.get_sample_job_dict()
3032
assert os.path.isfile('jsons/job.json') is False

tests/test_class_machine.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
import os,sys,json,glob,shutil,uuid,time
2+
from socket import gaierror
3+
import unittest
4+
from unittest.mock import MagicMock, patch, PropertyMock
5+
6+
7+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..')))
8+
__package__ = 'tests'
9+
from .context import LocalContext
10+
from .context import BaseContext
11+
from .context import PBS
12+
from .context import JobStatus
13+
from .context import LazyLocalContext, LocalContext, SSHContext
14+
from .context import LSF, Slurm, PBS, Shell
15+
from .context import Machine
16+
from .context import dargs
17+
from .context import DistributedShell, HDFSContext
18+
from .sample_class import SampleClass
19+
from dargs.dargs import ArgumentKeyError
20+
21+
class TestMachineInit(unittest.TestCase):
22+
def setUp(self):
23+
self.maxDiff = None
24+
25+
def test_machine_serialize_deserialize(self):
26+
pbs = SampleClass.get_sample_pbs_local_context()
27+
self.assertEqual(pbs, Machine.deserialize(pbs.serialize()))
28+
29+
def test_machine_load_from_dict(self):
30+
pbs = SampleClass.get_sample_pbs_local_context()
31+
self.assertEqual(pbs, PBS.load_from_dict(pbs.serialize()))

0 commit comments

Comments
 (0)