Skip to content

Commit 91de0b6

Browse files
authored
add 'trainium' alias for inferentia to batch decorator (#1762)
1 parent 7e265ac commit 91de0b6

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

metaflow/plugins/aws/batch/batch_decorator.py

+18
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,8 @@ class BatchDecorator(StepDecorator):
8484
Path to tmpfs mount for this step. Defaults to /metaflow_temp.
8585
inferentia : int, default 0
8686
Number of Inferentia chips required for this step.
87+
trainium : int, default None
88+
Alias for inferentia. Use only one of the two.
8789
efa : int, default 0
8890
Number of elastic fabric adapter network devices to attach to container
8991
ephemeral_storage: int, default None
@@ -104,6 +106,7 @@ class BatchDecorator(StepDecorator):
104106
"max_swap": None,
105107
"swappiness": None,
106108
"inferentia": None,
109+
"trainium": None, # alias for inferentia
107110
"efa": None,
108111
"host_volumes": None,
109112
"efs_volumes": None,
@@ -151,6 +154,21 @@ def __init__(self, attributes=None, statically_defined=False):
151154
self.attributes["image"],
152155
)
153156

157+
# Alias trainium to inferentia and check that both are not in use.
158+
if (
159+
self.attributes["inferentia"] is not None
160+
and self.attributes["trainium"] is not None
161+
):
162+
raise BatchException(
163+
"only specify a value for 'inferentia' or 'trainium', not both."
164+
)
165+
166+
if self.attributes["trainium"] is not None:
167+
self.attributes["inferentia"] = self.attributes["trainium"]
168+
169+
# clean up the alias attribute so it is not passed on.
170+
self.attributes.pop("trainium", None)
171+
154172
# Refer https://github.com/Netflix/metaflow/blob/master/docs/lifecycle.png
155173
# to understand where these functions are invoked in the lifecycle of a
156174
# Metaflow flow.

0 commit comments

Comments
 (0)