Skip to content

Commit b99ecec

Browse files
authored
Merge pull request #44 from AllenInstitute/feature/add-gpu-support-for-sm-batch-fragments
Add GPU support to BatchInvokedLambdaFunction and BatchOperation classes
2 parents 91fe227 + e4eb620 commit b99ecec

File tree

8 files changed

+570
-89
lines changed

8 files changed

+570
-89
lines changed

Makefile

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -150,9 +150,6 @@ format: format-isort format-black ## Run all formatters (black, isort)
150150
pytest: $(INSTALL_STAMP) ## Run test (pytest)
151151
$(VENV_BIN)/pytest -vv --durations=10
152152

153-
tox: $(INSTALL_STAMP) ## Run Test in tox environment
154-
$(VENV_BIN)/tox
155-
156153
test: pytest ## Run Standard Tests
157154

158155
.PHONY: pytest tox test

src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/batch.py

Lines changed: 7 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,9 @@ def from_defaults(
137137
job_queue: str,
138138
image: str,
139139
command: str = "",
140-
memory: str = "1024",
141-
vcpus: str = "1",
140+
memory: Union[str, int] = 1024,
141+
vcpus: Union[str, int] = 1,
142+
gpu: Union[str, int] = 0,
142143
environment: Optional[Mapping[str, str]] = None,
143144
mount_point_configs: Optional[List[MountPointConfiguration]] = None,
144145
job_role_arn: Optional[str] = None,
@@ -147,10 +148,10 @@ def from_defaults(
147148
defaults["command"] = command
148149
defaults["job_queue"] = job_queue
149150
defaults["environment"] = environment or {}
150-
defaults["memory"] = memory
151151
defaults["image"] = image
152-
defaults["vcpus"] = vcpus
153-
defaults["gpu"] = "0"
152+
defaults["memory"] = str(memory)
153+
defaults["vcpus"] = str(vcpus)
154+
defaults["gpu"] = str(gpu)
154155
defaults["platform_capabilities"] = ["EC2"]
155156
defaults["job_role_arn"] = job_role_arn or JsonNull.INSTANCE
156157

@@ -170,10 +171,7 @@ def from_defaults(
170171
environment=sfn.JsonPath.string_at("$.request.environment"),
171172
memory=sfn.JsonPath.string_at("$.request.memory"),
172173
vcpus=sfn.JsonPath.string_at("$.request.vcpus"),
173-
# TODO: Handle GPU parameter better - right now, we cannot handle cases where it is
174-
# not specified. Setting to zero causes issues with the Batch API.
175-
# If it is set to zero, then the json list of resources are not properly set.
176-
# gpu=sfn.JsonPath.string_at("$.request.gpu"),
174+
gpu=sfn.JsonPath.string_at("$.request.gpu"),
177175
mount_points=sfn.JsonPath.string_at("$.request.mount_points"),
178176
volumes=sfn.JsonPath.string_at("$.request.volumes"),
179177
platform_capabilities=sfn.JsonPath.string_at("$.request.platform_capabilities"),
@@ -201,70 +199,3 @@ def from_defaults(
201199

202200
submit_job.definition = start.next(merge).next(submit_job.definition)
203201
return submit_job
204-
205-
206-
class SubmitJobWithDefaultsFragment(EnvBaseStateMachineFragment, AWSBatchMixins):
207-
def __init__(
208-
self,
209-
scope: constructs.Construct,
210-
id: str,
211-
env_base: EnvBase,
212-
job_queue: str,
213-
command: str = "",
214-
memory: str = "1024",
215-
vcpus: str = "1",
216-
environment: Optional[Mapping[str, str]] = None,
217-
mount_point_configs: Optional[List[MountPointConfiguration]] = None,
218-
platform_capabilities: Optional[Union[List[Literal["EC2", "FARGATE"]], str]] = None,
219-
job_role_arn: Optional[str] = None,
220-
):
221-
super().__init__(scope, id, env_base)
222-
defaults: dict[str, Any] = {}
223-
defaults["command"] = command
224-
defaults["job_queue"] = job_queue
225-
defaults["environment"] = environment or {}
226-
defaults["memory"] = memory
227-
defaults["vcpus"] = vcpus
228-
defaults["gpu"] = "0"
229-
defaults["platform_capabilities"] = platform_capabilities or ["EC2"]
230-
defaults["job_role_arn"] = job_role_arn or JsonNull.INSTANCE
231-
232-
if mount_point_configs:
233-
mount_points, volumes = self.convert_to_mount_point_and_volumes(mount_point_configs)
234-
defaults["mount_points"] = mount_points
235-
defaults["volumes"] = volumes
236-
237-
start = sfn.Pass(
238-
self,
239-
"Start",
240-
parameters={
241-
"input": sfn.JsonPath.object_at("$"),
242-
"default": defaults,
243-
},
244-
)
245-
merge_chain = CommonOperation.merge_defaults(
246-
self, f"{id}", defaults=defaults, input_path="$.input", result_path="$.request"
247-
)
248-
249-
submit_job = SubmitJobFragment(
250-
self,
251-
"SubmitJobCore",
252-
env_base=self.env_base,
253-
name="SubmitJobCore",
254-
image=sfn.JsonPath.string_at("$.request.image"),
255-
command=sfn.JsonPath.string_at("$.request.command"),
256-
job_queue=sfn.JsonPath.string_at("$.request.job_queue"),
257-
environment=sfn.JsonPath.string_at("$.request.environment"),
258-
memory=sfn.JsonPath.string_at("$.request.memory"),
259-
vcpus=sfn.JsonPath.string_at("$.request.vcpus"),
260-
# TODO: Handle GPU parameter better - right now, we cannot handle cases where it is
261-
# not specified. Setting to zero causes issues with the Batch API.
262-
# If it is set to zero, then the json list of resources are not properly set.
263-
# gpu=sfn.JsonPath.string_at("$.request.gpu"),
264-
mount_points=sfn.JsonPath.string_at("$.request.mount_points"),
265-
volumes=sfn.JsonPath.string_at("$.request.volumes"),
266-
platform_capabilities=sfn.JsonPath.string_at("$.request.platform_capabilities"),
267-
job_role_arn=sfn.JsonPath.string_at("$.request.job_role_arn"),
268-
).to_single_state()
269-
270-
self.definition = start.next(merge_chain).next(submit_job)

src/aibs_informatics_cdk_lib/constructs_/sfn/fragments/informatics/batch.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ def __init__(
7272
environment: Optional[Mapping[str, str]] = None,
7373
memory: Optional[Union[int, str]] = None,
7474
vcpus: Optional[Union[int, str]] = None,
75+
gpu: Optional[Union[int, str]] = None,
7576
mount_points: Optional[Union[List[MountPointTypeDef], str]] = None,
7677
volumes: Optional[Union[List[VolumeTypeDef], str]] = None,
7778
mount_point_configs: Optional[List[MountPointConfiguration]] = None,
@@ -113,6 +114,7 @@ def __init__(
113114
environment (Mapping[str, str] | None): Additional environment variables to specify. These are added to default environment variables.
114115
memory (int | str | None): Memory in MiB (either int or reference path str). Defaults to None.
115116
vcpus (int | str | None): Number of vCPUs (either int or reference path str). Defaults to None.
117+
gpu (int | str | None): Number of GPUs (either int or reference path str). Defaults to None.
116118
mount_points (List[MountPointTypeDef] | None): List of mount points to add to state machine. Defaults to None.
117119
volumes (List[VolumeTypeDef] | None): List of volumes to add to state machine. Defaults to None.
118120
platform_capabilities (List[Literal["EC2", "FARGATE"]] | str | None): platform capabilities to use. This can be a reference path (e.g. "$.platform_capabilities")
@@ -187,6 +189,7 @@ def __init__(
187189
},
188190
memory=memory,
189191
vcpus=vcpus,
192+
gpu=gpu,
190193
mount_points=mount_points or [],
191194
volumes=volumes or [],
192195
platform_capabilities=platform_capabilities,
@@ -228,8 +231,9 @@ def with_defaults(
228231
payload_path: Optional[str] = None,
229232
overrides_path: Optional[str] = None,
230233
command: Optional[List[str]] = None,
231-
memory: str = "1024",
232-
vcpus: str = "1",
234+
memory: Union[int, str] = "1024",
235+
vcpus: Union[int, str] = "1",
236+
gpu: Union[int, str] = "0",
233237
environment: Optional[Mapping[str, str]] = None,
234238
mount_point_configs: Optional[List[MountPointConfiguration]] = None,
235239
platform_capabilities: Optional[List[Literal["EC2", "FARGATE"]]] = None,
@@ -238,8 +242,9 @@ def with_defaults(
238242
defaults: dict[str, Any] = {}
239243

240244
defaults["job_queue"] = job_queue
241-
defaults["memory"] = memory
242-
defaults["vcpus"] = vcpus
245+
defaults["memory"] = str(memory)
246+
defaults["vcpus"] = str(vcpus)
247+
defaults["gpu"] = str(gpu)
243248
defaults["environment"] = environment or {}
244249
defaults["platform_capabilities"] = platform_capabilities or ["EC2"]
245250
defaults["bucket_name"] = bucket_name
@@ -270,7 +275,7 @@ def with_defaults(
270275
# TODO: Handle GPU parameter better - right now, we cannot handle cases where it is
271276
# not specified. Setting to zero causes issues with the Batch API.
272277
# If it is set to zero, then the json list of resources are not properly set.
273-
# gpu=sfn.JsonPath.string_at("$.merged.gpu"),
278+
gpu=sfn.JsonPath.string_at("$.merged.gpu"),
274279
mount_points=sfn.JsonPath.string_at("$.merged.mount_points"),
275280
volumes=sfn.JsonPath.string_at("$.merged.volumes"),
276281
platform_capabilities=sfn.JsonPath.string_at("$.merged.platform_capabilities"),

src/aibs_informatics_cdk_lib/constructs_/sfn/states/batch.py

Lines changed: 22 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,8 @@ def register_job_definition(
3939
memory: Optional[Union[int, str]] = None,
4040
vcpus: Optional[Union[int, str]] = None,
4141
gpu: Optional[Union[int, str]] = None,
42-
# mount_points: Optional[List[MountPointTypeDef]] = None,
4342
mount_points: Optional[Union[List[MountPointTypeDef], str]] = None,
44-
# volumes: Optional[List[VolumeTypeDef]] = None,
4543
volumes: Optional[Union[List[VolumeTypeDef], str]] = None,
46-
# platform_capabilities: Optional[List[Literal["EC2", "FARGATE"]]] = None,
4744
platform_capabilities: Optional[Union[List[Literal["EC2", "FARGATE"]], str]] = None,
4845
result_path: Optional[str] = "$",
4946
output_path: Optional[str] = "$",
@@ -147,7 +144,17 @@ def register_job_definition(
147144
],
148145
},
149146
)
150-
return start.next(register)
147+
chain = start
148+
if gpu is not None:
149+
chain = chain.next(
150+
sfn.Pass(
151+
scope,
152+
id + " Register Definition Filter Resource Requirements",
153+
input_path=f"{result_path or '$'}.ContainerProperties.ResourceRequirements[?(@.Value != 0 && @.Value != '0')]",
154+
result_path=f"{result_path or '$'}.ContainerProperties.ResourceRequirements",
155+
)
156+
)
157+
return chain.next(register)
151158

152159
@classmethod
153160
def submit_job(
@@ -223,7 +230,17 @@ def submit_job(
223230
],
224231
},
225232
)
226-
return start.next(submit)
233+
chain = start
234+
if gpu is not None:
235+
chain = chain.next(
236+
sfn.Pass(
237+
scope,
238+
id + " SubmitJob Filter Resource Requirements",
239+
input_path=f"{result_path or '$'}.ContainerOverrides.ResourceRequirements[?(@.Value != 0 && @.Value != '0')]",
240+
result_path=f"{result_path or '$'}.ContainerOverrides.ResourceRequirements",
241+
)
242+
)
243+
return chain.next(submit)
227244

228245
@classmethod
229246
def deregister_job_definition(

test/aibs_informatics_cdk_lib/constructs_/sfn/__init__.py

Whitespace-only changes.

test/aibs_informatics_cdk_lib/constructs_/sfn/states/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)