|
1 | 1 | from typing import TYPE_CHECKING, Dict, List |
| 2 | +from unittest import mock |
| 3 | + |
| 4 | +from aibs_informatics_core.env import ENV_BASE_KEY_ALIAS, EnvBase, EnvType |
| 5 | +from aibs_informatics_core.models.aws.batch import ResourceRequirements |
2 | 6 |
|
3 | 7 | from aibs_informatics_aws_utils.batch import ( |
| 8 | + BatchJobBuilder, |
4 | 9 | ContainerPropertiesTypeDef, |
5 | 10 | JobDefinitionTypeDef, |
6 | 11 | RetryStrategyTypeDef, |
| 12 | + batch_log_stream_name_to_url, |
7 | 13 | build_retry_strategy, |
| 14 | + describe_jobs, |
8 | 15 | get_batch_client, |
9 | 16 | register_job_definition, |
| 17 | + submit_job, |
10 | 18 | to_key_value_pairs, |
11 | 19 | to_mount_point, |
12 | 20 | to_resource_requirements, |
@@ -187,6 +195,83 @@ def test__build_retry_strategy__builds_without_default_and_custom_retry_configs( |
187 | 195 | }, |
188 | 196 | ) |
189 | 197 |
|
| 198 | + @mock.patch("aibs_informatics_aws_utils.batch.sha256_hexdigest", return_value="hashvalue") |
| 199 | + def test__submit_job__submits_with_minimal_args(self, mock_sha: mock.MagicMock): |
| 200 | + with self.stub(self.batch_client) as batch_stubber: |
| 201 | + mock_sha.return_value = ( |
| 202 | + "1ee55eb8c7f4cee6a644c1346db610ba2306547a695a7a76ff28b9a47b829fac" # noqa |
| 203 | + ) |
| 204 | + job_def_name = "test-job-def-name" |
| 205 | + job_queue = "test-queue" |
| 206 | + expected_job_name = ( |
| 207 | + "dev-marmotdev-1ee55eb8c7f4cee6a644c1346db610ba2306547a695a7a76ff28b9a47b829fac" # noqa |
| 208 | + ) |
| 209 | + batch_stubber.add_response( |
| 210 | + "submit_job", |
| 211 | + { |
| 212 | + "jobName": expected_job_name, |
| 213 | + "jobId": "01234567-89ab-cdef-0123-456789abcdef", |
| 214 | + }, |
| 215 | + { |
| 216 | + "jobName": expected_job_name, |
| 217 | + "jobQueue": job_queue, |
| 218 | + "jobDefinition": job_def_name, |
| 219 | + }, |
| 220 | + ) |
| 221 | + submit_response = submit_job( |
| 222 | + job_definition=job_def_name, |
| 223 | + job_queue=job_queue, |
| 224 | + env_base=self.env_base, |
| 225 | + region=self.DEFAULT_REGION, |
| 226 | + ) |
| 227 | + self.assertEqual( |
| 228 | + submit_response, |
| 229 | + { |
| 230 | + "jobName": expected_job_name, |
| 231 | + "jobId": "01234567-89ab-cdef-0123-456789abcdef", |
| 232 | + }, |
| 233 | + ) |
| 234 | + |
| 235 | + batch_stubber.assert_no_pending_responses() |
| 236 | + |
| 237 | + @mock.patch("aibs_informatics_aws_utils.batch.sha256_hexdigest", return_value="hashvalue") |
| 238 | + def test__submit_job__submits_with_all_args_specified(self, mock_sha: mock.MagicMock): |
| 239 | + with self.stub(self.batch_client) as batch_stubber: |
| 240 | + mock_sha.return_value = ( |
| 241 | + "1ee55eb8c7f4cee6a644c1346db610ba2306547a695a7a76ff28b9a47b829fac" # noqa |
| 242 | + ) |
| 243 | + job_def_name = "test-job-def-name" |
| 244 | + job_queue = "test-queue" |
| 245 | + expected_job_name = "test-job-name" |
| 246 | + batch_stubber.add_response( |
| 247 | + "submit_job", |
| 248 | + { |
| 249 | + "jobName": expected_job_name, |
| 250 | + "jobId": "01234567-89ab-cdef-0123-456789abcdef", |
| 251 | + }, |
| 252 | + { |
| 253 | + "jobName": expected_job_name, |
| 254 | + "jobQueue": job_queue, |
| 255 | + "jobDefinition": job_def_name, |
| 256 | + }, |
| 257 | + ) |
| 258 | + submit_response = submit_job( |
| 259 | + job_definition=job_def_name, |
| 260 | + job_queue=job_queue, |
| 261 | + job_name=expected_job_name, |
| 262 | + env_base=self.env_base, |
| 263 | + region=self.DEFAULT_REGION, |
| 264 | + ) |
| 265 | + self.assertEqual( |
| 266 | + submit_response, |
| 267 | + { |
| 268 | + "jobName": expected_job_name, |
| 269 | + "jobId": "01234567-89ab-cdef-0123-456789abcdef", |
| 270 | + }, |
| 271 | + ) |
| 272 | + batch_stubber.assert_no_pending_responses() |
| 273 | + mock_sha.assert_not_called() |
| 274 | + |
190 | 275 | def get_container_props( |
191 | 276 | self, |
192 | 277 | command: List[str] = [], |
@@ -246,6 +331,101 @@ def get_job_def_arn(self, job_def_name: str, revision: int) -> str: |
246 | 331 | return f"arn:aws:batch:us-west-2:051791135335:job-definition/{job_def_name}:{revision}" |
247 | 332 |
|
248 | 333 |
|
| 334 | +@mock.patch("aibs_informatics_aws_utils.batch.get_region", return_value="us-east-1") |
| 335 | +def test__batch_job_builder__container_properties_include_optional_fields(_mock_get_region): |
| 336 | + env_base = EnvBase.from_type_and_label(EnvType.DEV, "builder") |
| 337 | + resource_requirements = [ |
| 338 | + {"type": "MEMORY", "value": "8192"}, |
| 339 | + {"type": "GPU", "value": "1"}, |
| 340 | + {"type": "VCPU", "value": "2"}, |
| 341 | + ] |
| 342 | + builder = BatchJobBuilder( |
| 343 | + image="example:latest", |
| 344 | + job_definition_name="definition", |
| 345 | + job_name="job", |
| 346 | + command=["python", "script.py"], |
| 347 | + environment={"EXTRA": "value"}, |
| 348 | + resource_requirements=resource_requirements, |
| 349 | + mount_points=[{"containerPath": "/data", "readOnly": False, "sourceVolume": "data"}], |
| 350 | + volumes=[{"name": "data", "host": {"sourcePath": "/mnt/data"}}], |
| 351 | + job_role_arn="arn:aws:iam::123456789012:role/BatchRole", |
| 352 | + privileged=True, |
| 353 | + linux_parameters={"initProcessEnabled": True}, |
| 354 | + env_base=env_base, |
| 355 | + ) |
| 356 | + |
| 357 | + assert builder.environment[ENV_BASE_KEY_ALIAS] == env_base |
| 358 | + assert builder.environment["AWS_REGION"] == "us-east-1" |
| 359 | + assert builder.environment["EXTRA"] == "value" |
| 360 | + |
| 361 | + container_props = builder.container_properties |
| 362 | + expected_environment = [ |
| 363 | + {"name": "AWS_REGION", "value": "us-east-1"}, |
| 364 | + {"name": ENV_BASE_KEY_ALIAS, "value": env_base}, |
| 365 | + {"name": "EXTRA", "value": "value"}, |
| 366 | + ] |
| 367 | + expected_resource_requirements = [ |
| 368 | + {"type": "GPU", "value": "1"}, |
| 369 | + {"type": "MEMORY", "value": "8192"}, |
| 370 | + {"type": "VCPU", "value": "2"}, |
| 371 | + ] |
| 372 | + |
| 373 | + assert container_props["image"] == "example:latest" |
| 374 | + assert container_props["command"] == ["python", "script.py"] |
| 375 | + assert container_props["privileged"] is True |
| 376 | + assert container_props["mountPoints"] == [ |
| 377 | + {"containerPath": "/data", "readOnly": False, "sourceVolume": "data"} |
| 378 | + ] |
| 379 | + assert container_props["volumes"] == [{"name": "data", "host": {"sourcePath": "/mnt/data"}}] |
| 380 | + assert container_props["environment"] == expected_environment |
| 381 | + assert container_props["resourceRequirements"] == expected_resource_requirements |
| 382 | + assert container_props["linuxParameters"] == {"initProcessEnabled": True} |
| 383 | + assert container_props["jobRoleArn"] == "arn:aws:iam::123456789012:role/BatchRole" |
| 384 | + assert builder._normalized_resource_requirements() == expected_resource_requirements |
| 385 | + |
| 386 | + |
| 387 | +@mock.patch("aibs_informatics_aws_utils.batch.get_region", return_value="us-west-2") |
| 388 | +def test__batch_job_builder__container_overrides_and_pascal_case(_mock_get_region): |
| 389 | + env_base = EnvBase.from_type_and_label(EnvType.TEST, "builder") |
| 390 | + builder = BatchJobBuilder( |
| 391 | + image="example:latest", |
| 392 | + job_definition_name="definition", |
| 393 | + job_name="job", |
| 394 | + environment={"EXTRA": "value", "NULL": None}, |
| 395 | + resource_requirements=ResourceRequirements(GPU=2, MEMORY=4096, VCPU=16), |
| 396 | + env_base=env_base, |
| 397 | + ) |
| 398 | + |
| 399 | + expected_resource_requirements = [ |
| 400 | + {"type": "GPU", "value": "2"}, |
| 401 | + {"type": "MEMORY", "value": "4096"}, |
| 402 | + {"type": "VCPU", "value": "16"}, |
| 403 | + ] |
| 404 | + expected_environment = [ |
| 405 | + {"name": "AWS_REGION", "value": "us-west-2"}, |
| 406 | + {"name": ENV_BASE_KEY_ALIAS, "value": env_base}, |
| 407 | + {"name": "EXTRA", "value": "value"}, |
| 408 | + ] |
| 409 | + |
| 410 | + container_overrides = builder.container_overrides |
| 411 | + assert builder.environment["NULL"] is None |
| 412 | + assert container_overrides["resourceRequirements"] == expected_resource_requirements |
| 413 | + assert container_overrides["environment"] == expected_environment |
| 414 | + assert builder.container_overrides__sfn == { |
| 415 | + "Environment": [ |
| 416 | + {"Name": "AWS_REGION", "Value": "us-west-2"}, |
| 417 | + {"Name": ENV_BASE_KEY_ALIAS, "Value": env_base}, |
| 418 | + {"Name": "EXTRA", "Value": "value"}, |
| 419 | + ], |
| 420 | + "ResourceRequirements": [ |
| 421 | + {"Type": "GPU", "Value": "2"}, |
| 422 | + {"Type": "MEMORY", "Value": "4096"}, |
| 423 | + {"Type": "VCPU", "Value": "16"}, |
| 424 | + ], |
| 425 | + } |
| 426 | + assert builder._normalized_resource_requirements() == expected_resource_requirements |
| 427 | + |
| 428 | + |
249 | 429 | def test__to_volume__works(): |
250 | 430 | volume = to_volume("source", "name", None) |
251 | 431 | expected = { |
@@ -283,3 +463,24 @@ def test__to_key_value_pairs__works(): |
283 | 463 |
|
284 | 464 | expected = [{"name": "a", "value": "a"}, {"name": "b", "value": None}] |
285 | 465 | assert key_value_pairs == expected |
| 466 | + |
| 467 | + |
| 468 | +@mock.patch("aibs_informatics_aws_utils.batch.get_batch_client") |
| 469 | +def test__describe_jobs__works(mock_get_batch_client): |
| 470 | + mock_client = mock.MagicMock() |
| 471 | + mock_get_batch_client.return_value = mock_client |
| 472 | + mock_client.describe_jobs.return_value = {"jobs": []} |
| 473 | + |
| 474 | + describe_jobs(job_ids=["job1", "job2"]) |
| 475 | + mock_client.describe_jobs.assert_called_once_with(jobs=["job1", "job2"]) |
| 476 | + |
| 477 | + |
| 478 | +@mock.patch("aibs_informatics_aws_utils.batch.build_log_stream_url") |
| 479 | +def test__batch_log_stream_name_to_url__works(mock_build_log_stream_url): |
| 480 | + mock_build_log_stream_url.return_value = "http://example.com" |
| 481 | + batch_log_stream_name_to_url(log_stream_name="stream", region="us-west-2") |
| 482 | + mock_build_log_stream_url.assert_called_once_with( |
| 483 | + log_group_name="/aws/batch/job", |
| 484 | + log_stream_name="stream", |
| 485 | + region="us-west-2", |
| 486 | + ) |
0 commit comments