Skip to content

Commit 66b8e0b

Browse files
authored
Merge pull request #34 from AllenInstitute/feature/DT-8749-add-support-for-job-def-role
feature/DT 8749 add support for job def role
2 parents 3d94dd0 + 39cd77e commit 66b8e0b

File tree

6 files changed

+2179
-1343
lines changed

6 files changed

+2179
-1343
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ requires-python = ">=3.9"
2020

2121
dependencies = [
2222
"boto3~=1.35",
23-
"aibs-informatics-core~=0.1",
23+
"aibs-informatics-core>=0.2.6,<1",
2424
"typing-extensions~=4.15; python_version < '3.11'",
2525
]
2626

src/aibs_informatics_aws_utils/batch.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,17 @@
2626
KeyValuePairTypeDef,
2727
LinuxParametersTypeDef,
2828
MountPointTypeDef,
29+
RegisterJobDefinitionRequestTypeDef,
2930
RegisterJobDefinitionResponseTypeDef,
3031
ResourceRequirementTypeDef,
3132
RetryStrategyTypeDef,
33+
SubmitJobRequestTypeDef,
34+
SubmitJobResponseTypeDef,
3235
VolumeTypeDef,
3336
)
3437
else:
35-
JobDefinitionTypeType = object
36-
38+
JobDefinitionTypeType = str
39+
RegisterJobDefinitionRequestTypeDef = dict
3740
ContainerOverridesTypeDef = dict
3841
ContainerPropertiesTypeDef = dict
3942
EFSVolumeConfigurationTypeDef = dict
@@ -47,6 +50,8 @@
4750
RegisterJobDefinitionResponseTypeDef = dict
4851
ResourceRequirementTypeDef = dict
4952
RetryStrategyTypeDef = dict
53+
SubmitJobRequestTypeDef = dict
54+
SubmitJobResponseTypeDef = dict
5055
VolumeTypeDef = dict
5156

5257

@@ -183,7 +188,7 @@ def register_job_definition(
183188
tags: Optional[Mapping[str, str]] = None,
184189
propagate_tags: bool = False,
185190
region: Optional[str] = None,
186-
) -> JobDefinitionTypeDef | RegisterJobDefinitionResponseTypeDef:
191+
) -> Union[JobDefinitionTypeDef, RegisterJobDefinitionResponseTypeDef]:
187192
batch = get_batch_client(region=region)
188193

189194
# First we check to make sure that we aren't crearting unnecessary revisions
@@ -195,6 +200,8 @@ def register_job_definition(
195200
if (
196201
latest_container_properties.get("command") == container_properties.get("command")
197202
and latest_container_properties.get("image") == container_properties.get("image")
203+
and latest_container_properties.get("jobRoleArn")
204+
== container_properties.get("jobRoleArn")
198205
and latest.get("parameters") == parameters
199206
and latest.get("type") == job_definition_type
200207
and latest.get("tags") == tags
@@ -205,7 +212,7 @@ def register_job_definition(
205212
"Skipping register new job definition call"
206213
)
207214
return latest
208-
register_job_definition_kwargs = dict(
215+
register_job_definition_kwargs = RegisterJobDefinitionRequestTypeDef(
209216
jobDefinitionName=job_definition_name,
210217
type=job_definition_type,
211218
parameters=parameters or {},
@@ -240,20 +247,23 @@ def get_latest_job_definition(
240247
def submit_job(
241248
job_definition: str,
242249
job_queue: str,
243-
job_name: Optional[JobName] = None,
250+
job_name: Optional[Union[JobName, str]] = None,
244251
env_base: Optional[EnvBase] = None,
245252
region: Optional[str] = None,
246-
):
253+
) -> SubmitJobResponseTypeDef:
247254
batch_client = get_batch_client(region=region)
248255
env_base = env_base or get_env_base()
249256
if job_name is None:
250257
job_name = JobName(f"{env_base}-{sha256_hexdigest()}")
251-
252-
batch_client.submit_job(
258+
else:
259+
job_name = JobName(job_name)
260+
submit_job_kwargs = SubmitJobRequestTypeDef(
253261
jobName=job_name,
254262
jobQueue=job_queue,
255263
jobDefinition=job_definition,
256264
)
265+
response = batch_client.submit_job(**submit_job_kwargs)
266+
return response
257267

258268

259269
@dataclass
@@ -269,6 +279,7 @@ class BatchJobBuilder:
269279
)
270280
mount_points: List[MountPointTypeDef] = field(default_factory=list)
271281
volumes: List[VolumeTypeDef] = field(default_factory=list)
282+
job_role_arn: Optional[str] = field(default=None)
272283
privileged: bool = field(default=False)
273284
linux_parameters: Optional[LinuxParametersTypeDef] = field(default=None)
274285
env_base: EnvBase = field(default_factory=EnvBase.from_env)
@@ -291,6 +302,8 @@ def container_properties(self) -> ContainerPropertiesTypeDef:
291302
)
292303
if self.linux_parameters:
293304
container_props["linuxParameters"] = self.linux_parameters
305+
if self.job_role_arn:
306+
container_props["jobRoleArn"] = self.job_role_arn
294307
return container_props
295308

296309
@property

src/aibs_informatics_aws_utils/dynamodb/conditions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ def deserialize_condition(
255255
def _deserialize_condition(ce: ConditionBaseExpression) -> ConditionBase:
256256
ce_key = (ce.format, ce.operator)
257257
condition_base_cls = cls._CONDITION_BASE_CLASS_LOOKUP[ce_key]
258-
ce_values: list[AttributeBase | ConditionBase] = []
258+
ce_values: list[Union[AttributeBase, ConditionBase]] = []
259259
for ce_value in ce.values:
260260
if isinstance(ce_value, ConditionBaseExpression):
261261
ce_values.append(_deserialize_condition(ce_value))

test/aibs_informatics_aws_utils/test_batch.py

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,20 @@
11
from typing import TYPE_CHECKING, Dict, List
2+
from unittest import mock
3+
4+
from aibs_informatics_core.env import ENV_BASE_KEY_ALIAS, EnvBase, EnvType
5+
from aibs_informatics_core.models.aws.batch import ResourceRequirements
26

37
from aibs_informatics_aws_utils.batch import (
8+
BatchJobBuilder,
49
ContainerPropertiesTypeDef,
510
JobDefinitionTypeDef,
611
RetryStrategyTypeDef,
12+
batch_log_stream_name_to_url,
713
build_retry_strategy,
14+
describe_jobs,
815
get_batch_client,
916
register_job_definition,
17+
submit_job,
1018
to_key_value_pairs,
1119
to_mount_point,
1220
to_resource_requirements,
@@ -187,6 +195,83 @@ def test__build_retry_strategy__builds_without_default_and_custom_retry_configs(
187195
},
188196
)
189197

198+
@mock.patch("aibs_informatics_aws_utils.batch.sha256_hexdigest", return_value="hashvalue")
199+
def test__submit_job__submits_with_minimal_args(self, mock_sha: mock.MagicMock):
200+
with self.stub(self.batch_client) as batch_stubber:
201+
mock_sha.return_value = (
202+
"1ee55eb8c7f4cee6a644c1346db610ba2306547a695a7a76ff28b9a47b829fac" # noqa
203+
)
204+
job_def_name = "test-job-def-name"
205+
job_queue = "test-queue"
206+
expected_job_name = (
207+
"dev-marmotdev-1ee55eb8c7f4cee6a644c1346db610ba2306547a695a7a76ff28b9a47b829fac" # noqa
208+
)
209+
batch_stubber.add_response(
210+
"submit_job",
211+
{
212+
"jobName": expected_job_name,
213+
"jobId": "01234567-89ab-cdef-0123-456789abcdef",
214+
},
215+
{
216+
"jobName": expected_job_name,
217+
"jobQueue": job_queue,
218+
"jobDefinition": job_def_name,
219+
},
220+
)
221+
submit_response = submit_job(
222+
job_definition=job_def_name,
223+
job_queue=job_queue,
224+
env_base=self.env_base,
225+
region=self.DEFAULT_REGION,
226+
)
227+
self.assertEqual(
228+
submit_response,
229+
{
230+
"jobName": expected_job_name,
231+
"jobId": "01234567-89ab-cdef-0123-456789abcdef",
232+
},
233+
)
234+
235+
batch_stubber.assert_no_pending_responses()
236+
237+
@mock.patch("aibs_informatics_aws_utils.batch.sha256_hexdigest", return_value="hashvalue")
238+
def test__submit_job__submits_with_all_args_specified(self, mock_sha: mock.MagicMock):
239+
with self.stub(self.batch_client) as batch_stubber:
240+
mock_sha.return_value = (
241+
"1ee55eb8c7f4cee6a644c1346db610ba2306547a695a7a76ff28b9a47b829fac" # noqa
242+
)
243+
job_def_name = "test-job-def-name"
244+
job_queue = "test-queue"
245+
expected_job_name = "test-job-name"
246+
batch_stubber.add_response(
247+
"submit_job",
248+
{
249+
"jobName": expected_job_name,
250+
"jobId": "01234567-89ab-cdef-0123-456789abcdef",
251+
},
252+
{
253+
"jobName": expected_job_name,
254+
"jobQueue": job_queue,
255+
"jobDefinition": job_def_name,
256+
},
257+
)
258+
submit_response = submit_job(
259+
job_definition=job_def_name,
260+
job_queue=job_queue,
261+
job_name=expected_job_name,
262+
env_base=self.env_base,
263+
region=self.DEFAULT_REGION,
264+
)
265+
self.assertEqual(
266+
submit_response,
267+
{
268+
"jobName": expected_job_name,
269+
"jobId": "01234567-89ab-cdef-0123-456789abcdef",
270+
},
271+
)
272+
batch_stubber.assert_no_pending_responses()
273+
mock_sha.assert_not_called()
274+
190275
def get_container_props(
191276
self,
192277
command: List[str] = [],
@@ -246,6 +331,101 @@ def get_job_def_arn(self, job_def_name: str, revision: int) -> str:
246331
return f"arn:aws:batch:us-west-2:051791135335:job-definition/{job_def_name}:{revision}"
247332

248333

334+
@mock.patch("aibs_informatics_aws_utils.batch.get_region", return_value="us-east-1")
335+
def test__batch_job_builder__container_properties_include_optional_fields(_mock_get_region):
336+
env_base = EnvBase.from_type_and_label(EnvType.DEV, "builder")
337+
resource_requirements = [
338+
{"type": "MEMORY", "value": "8192"},
339+
{"type": "GPU", "value": "1"},
340+
{"type": "VCPU", "value": "2"},
341+
]
342+
builder = BatchJobBuilder(
343+
image="example:latest",
344+
job_definition_name="definition",
345+
job_name="job",
346+
command=["python", "script.py"],
347+
environment={"EXTRA": "value"},
348+
resource_requirements=resource_requirements,
349+
mount_points=[{"containerPath": "/data", "readOnly": False, "sourceVolume": "data"}],
350+
volumes=[{"name": "data", "host": {"sourcePath": "/mnt/data"}}],
351+
job_role_arn="arn:aws:iam::123456789012:role/BatchRole",
352+
privileged=True,
353+
linux_parameters={"initProcessEnabled": True},
354+
env_base=env_base,
355+
)
356+
357+
assert builder.environment[ENV_BASE_KEY_ALIAS] == env_base
358+
assert builder.environment["AWS_REGION"] == "us-east-1"
359+
assert builder.environment["EXTRA"] == "value"
360+
361+
container_props = builder.container_properties
362+
expected_environment = [
363+
{"name": "AWS_REGION", "value": "us-east-1"},
364+
{"name": ENV_BASE_KEY_ALIAS, "value": env_base},
365+
{"name": "EXTRA", "value": "value"},
366+
]
367+
expected_resource_requirements = [
368+
{"type": "GPU", "value": "1"},
369+
{"type": "MEMORY", "value": "8192"},
370+
{"type": "VCPU", "value": "2"},
371+
]
372+
373+
assert container_props["image"] == "example:latest"
374+
assert container_props["command"] == ["python", "script.py"]
375+
assert container_props["privileged"] is True
376+
assert container_props["mountPoints"] == [
377+
{"containerPath": "/data", "readOnly": False, "sourceVolume": "data"}
378+
]
379+
assert container_props["volumes"] == [{"name": "data", "host": {"sourcePath": "/mnt/data"}}]
380+
assert container_props["environment"] == expected_environment
381+
assert container_props["resourceRequirements"] == expected_resource_requirements
382+
assert container_props["linuxParameters"] == {"initProcessEnabled": True}
383+
assert container_props["jobRoleArn"] == "arn:aws:iam::123456789012:role/BatchRole"
384+
assert builder._normalized_resource_requirements() == expected_resource_requirements
385+
386+
387+
@mock.patch("aibs_informatics_aws_utils.batch.get_region", return_value="us-west-2")
388+
def test__batch_job_builder__container_overrides_and_pascal_case(_mock_get_region):
389+
env_base = EnvBase.from_type_and_label(EnvType.TEST, "builder")
390+
builder = BatchJobBuilder(
391+
image="example:latest",
392+
job_definition_name="definition",
393+
job_name="job",
394+
environment={"EXTRA": "value", "NULL": None},
395+
resource_requirements=ResourceRequirements(GPU=2, MEMORY=4096, VCPU=16),
396+
env_base=env_base,
397+
)
398+
399+
expected_resource_requirements = [
400+
{"type": "GPU", "value": "2"},
401+
{"type": "MEMORY", "value": "4096"},
402+
{"type": "VCPU", "value": "16"},
403+
]
404+
expected_environment = [
405+
{"name": "AWS_REGION", "value": "us-west-2"},
406+
{"name": ENV_BASE_KEY_ALIAS, "value": env_base},
407+
{"name": "EXTRA", "value": "value"},
408+
]
409+
410+
container_overrides = builder.container_overrides
411+
assert builder.environment["NULL"] is None
412+
assert container_overrides["resourceRequirements"] == expected_resource_requirements
413+
assert container_overrides["environment"] == expected_environment
414+
assert builder.container_overrides__sfn == {
415+
"Environment": [
416+
{"Name": "AWS_REGION", "Value": "us-west-2"},
417+
{"Name": ENV_BASE_KEY_ALIAS, "Value": env_base},
418+
{"Name": "EXTRA", "Value": "value"},
419+
],
420+
"ResourceRequirements": [
421+
{"Type": "GPU", "Value": "2"},
422+
{"Type": "MEMORY", "Value": "4096"},
423+
{"Type": "VCPU", "Value": "16"},
424+
],
425+
}
426+
assert builder._normalized_resource_requirements() == expected_resource_requirements
427+
428+
249429
def test__to_volume__works():
250430
volume = to_volume("source", "name", None)
251431
expected = {
@@ -283,3 +463,24 @@ def test__to_key_value_pairs__works():
283463

284464
expected = [{"name": "a", "value": "a"}, {"name": "b", "value": None}]
285465
assert key_value_pairs == expected
466+
467+
468+
@mock.patch("aibs_informatics_aws_utils.batch.get_batch_client")
469+
def test__describe_jobs__works(mock_get_batch_client):
470+
mock_client = mock.MagicMock()
471+
mock_get_batch_client.return_value = mock_client
472+
mock_client.describe_jobs.return_value = {"jobs": []}
473+
474+
describe_jobs(job_ids=["job1", "job2"])
475+
mock_client.describe_jobs.assert_called_once_with(jobs=["job1", "job2"])
476+
477+
478+
@mock.patch("aibs_informatics_aws_utils.batch.build_log_stream_url")
479+
def test__batch_log_stream_name_to_url__works(mock_build_log_stream_url):
480+
mock_build_log_stream_url.return_value = "http://example.com"
481+
batch_log_stream_name_to_url(log_stream_name="stream", region="us-west-2")
482+
mock_build_log_stream_url.assert_called_once_with(
483+
log_group_name="/aws/batch/job",
484+
log_stream_name="stream",
485+
region="us-west-2",
486+
)

0 commit comments

Comments
 (0)