@@ -84,6 +84,8 @@ class BatchDecorator(StepDecorator):
84
84
Path to tmpfs mount for this step. Defaults to /metaflow_temp.
85
85
inferentia : int, default 0
86
86
Number of Inferentia chips required for this step.
87
+ trainium : int, default None
88
+ Alias for inferentia. Specify only one of the two; setting both raises an error.
87
89
efa : int, default 0
88
90
Number of elastic fabric adapter network devices to attach to container
89
91
ephemeral_storage: int, default None
@@ -104,6 +106,7 @@ class BatchDecorator(StepDecorator):
104
106
"max_swap" : None ,
105
107
"swappiness" : None ,
106
108
"inferentia" : None ,
109
+ "trainium" : None , # alias for inferentia
107
110
"efa" : None ,
108
111
"host_volumes" : None ,
109
112
"efs_volumes" : None ,
@@ -151,6 +154,21 @@ def __init__(self, attributes=None, statically_defined=False):
151
154
self .attributes ["image" ],
152
155
)
153
156
157
+ # Alias trainium to inferentia and check that both are not in use.
158
+ if (
159
+ self .attributes ["inferentia" ] is not None
160
+ and self .attributes ["trainium" ] is not None
161
+ ):
162
+ raise BatchException (
163
+ "only specify a value for 'inferentia' or 'trainium', not both."
164
+ )
165
+
166
+ if self .attributes ["trainium" ] is not None :
167
+ self .attributes ["inferentia" ] = self .attributes ["trainium" ]
168
+
169
+ # Remove the alias key (any value it held was copied to 'inferentia' above) so it is not passed on.
170
+ self .attributes .pop ("trainium" , None )
171
+
154
172
# Refer https://github.com/Netflix/metaflow/blob/master/docs/lifecycle.png
155
173
# to understand where these functions are invoked in the lifecycle of a
156
174
# Metaflow flow.
0 commit comments