from airflow.executors.base_executor import BaseExecutor
from airflow.utils.module_loading import import_string
from airflow.utils.state import State
-from marshmallow import Schema, fields, post_load
+from marshmallow import EXCLUDE, Schema, ValidationError, fields, post_load

CommandType = List[str]
TaskInstanceKeyType = Tuple[Any]
ExecutorConfigFunctionType = Callable[[CommandType], dict]
-EcsFargateQueuedTask = namedtuple('EcsFargateQueuedTask', ('key', 'command', 'executor_config'))
+EcsFargateQueuedTask = namedtuple('EcsFargateQueuedTask', ('key', 'command', 'queue', 'executor_config'))
ExecutorConfigType = Dict[str, Any]
-EcsFargateTaskInfo = namedtuple('EcsFargateTaskInfo', ('cmd', 'config'))
+EcsFargateTaskInfo = namedtuple('EcsFargateTaskInfo', ('cmd', 'queue', 'config'))


class EcsFargateTask:
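
The new imports track the marshmallow 3 API: load() now returns the deserialized data directly and raises ValidationError when the payload does not match, instead of handing back a result object with .data and .errors as marshmallow 2 did, and undeclared keys must now be explicitly tolerated with EXCLUDE. A minimal sketch of the pattern this diff migrates to (the schema and payload below are illustrative, not taken from the executor):

from marshmallow import EXCLUDE, Schema, ValidationError, fields


class TaskArnSchema(Schema):
    # Hypothetical schema, for illustration only.
    task_arn = fields.String(data_key='taskArn', required=True)

    class Meta:
        # Silently drop keys the schema does not declare; marshmallow 3 raises on them by default.
        unknown = EXCLUDE


try:
    # marshmallow 3: load() returns the deserialized dict (or the @post_load result) directly ...
    result = TaskArnSchema().load({'taskArn': 'arn:aws:ecs:example:task/abc', 'lastStatus': 'RUNNING'})
except ValidationError as err:
    # ... and reports schema mismatches by raising, rather than via a .errors attribute.
    print(err.messages)
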
@@ -147,17 +147,18 @@ def __describe_tasks(self, task_arns):
        for i in range((len(task_arns) // self.DESCRIBE_TASKS_BATCH_SIZE) + 1):
            batched_task_arns = task_arns[i * self.DESCRIBE_TASKS_BATCH_SIZE: (i + 1) * self.DESCRIBE_TASKS_BATCH_SIZE]
            boto_describe_tasks = self.ecs.describe_tasks(tasks=batched_task_arns, cluster=self.cluster)
-            describe_tasks_response = BotoDescribeTasksSchema().load(boto_describe_tasks)
-            if describe_tasks_response.errors:
+            try:
+                describe_tasks_response = BotoDescribeTasksSchema().load(boto_describe_tasks)
+            except ValidationError as err:
                self.log.error('ECS DescribeTask Response: %s', boto_describe_tasks)
                raise EcsFargateError(
                    'DescribeTasks API call does not match expected JSON shape. '
                    'Are you sure that the correct version of Boto3 is installed? {}'.format(
-                        describe_tasks_response.errors
+                        err
                    )
                )
-            all_task_descriptions['tasks'].extend(describe_tasks_response.data['tasks'])
-            all_task_descriptions['failures'].extend(describe_tasks_response.data['failures'])
+            all_task_descriptions['tasks'].extend(describe_tasks_response['tasks'])
+            all_task_descriptions['failures'].extend(describe_tasks_response['failures'])
        return all_task_descriptions

    def __handle_failed_task(self, task_arn: str, reason: str):
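
For context on the unchanged loop above: DescribeTasks only accepts a limited number of ARNs per call, so the executor walks the ARNs in DESCRIBE_TASKS_BATCH_SIZE-sized slices. A standalone sketch of that floor-division/slice arithmetic, with an assumed batch size (the real constant is defined on the executor class):

from typing import List

DESCRIBE_TASKS_BATCH_SIZE = 99  # assumed value for illustration; the ECS API caps DescribeTasks at 100 ARNs per call


def batch_arns(task_arns: List[str]) -> List[List[str]]:
    """Split ARNs into DescribeTasks-sized chunks using the same index arithmetic as __describe_tasks."""
    batches = []
    for i in range((len(task_arns) // DESCRIBE_TASKS_BATCH_SIZE) + 1):
        chunk = task_arns[i * DESCRIBE_TASKS_BATCH_SIZE: (i + 1) * DESCRIBE_TASKS_BATCH_SIZE]
        if chunk:  # the final slice is empty when len(task_arns) is an exact multiple of the batch size
            batches.append(chunk)
    return batches


assert len(batch_arns(['arn-%d' % i for i in range(150)])) == 2
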
@@ -166,14 +167,14 @@ def __handle_failed_task(self, task_arn: str, reason: str):
        ECS/Fargate Cloud. If an API failure occurs the task is simply rescheduled.
        """
        task_key = self.active_workers.arn_to_key[task_arn]
-        task_cmd, exec_info = self.active_workers.info_by_key(task_key)
+        task_cmd, queue, exec_info = self.active_workers.info_by_key(task_key)
        failure_count = self.active_workers.failure_count_by_key(task_key)
        if failure_count < self.__class__.MAX_FAILURE_CHECKS:
            self.log.warning('Task %s has failed due to %s. '
                             'Failure %s out of %s occurred on %s. Rescheduling.',
                             task_key, reason, failure_count, self.__class__.MAX_FAILURE_CHECKS, task_arn)
            self.active_workers.increment_failure_count(task_key)
-            self.pending_tasks.appendleft(EcsFargateQueuedTask(task_key, task_cmd, exec_info))
+            self.pending_tasks.appendleft(EcsFargateQueuedTask(task_key, task_cmd, queue, exec_info))
        else:
            self.log.error('Task %s has failed a maximum of %s times. Marking as failed', task_key,
                           failure_count)
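
A failed task is now requeued with its queue preserved; the retry behaviour itself is unchanged, with the task going back on the front of the pending deque until MAX_FAILURE_CHECKS is reached. A minimal, self-contained sketch of that bookkeeping (the limit, key, and command values are made up for illustration):

from collections import defaultdict, deque, namedtuple

EcsFargateQueuedTask = namedtuple('EcsFargateQueuedTask', ('key', 'command', 'queue', 'executor_config'))

MAX_FAILURE_CHECKS = 3  # assumed limit for illustration
pending_tasks = deque()
failure_counts = defaultdict(int)


def handle_failed_task(task):
    """Requeue a failed task at the head of the deque until the retry limit is reached."""
    if failure_counts[task.key] < MAX_FAILURE_CHECKS:
        failure_counts[task.key] += 1
        # appendleft() puts the task at the front, so the next sync retries it before newer submissions.
        pending_tasks.appendleft(task)
    else:
        print('Task %s has failed a maximum of %s times. Marking as failed' % (task.key, failure_counts[task.key]))


handle_failed_task(EcsFargateQueuedTask(('example_dag', 'example_task', 1), ['airflow', 'run'], 'default', {}))
assert pending_tasks[0].queue == 'default'
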
@@ -192,8 +193,8 @@ def attempt_task_runs(self):
        failure_reasons = defaultdict(int)
        for _ in range(queue_len):
            ecs_task = self.pending_tasks.popleft()
-            task_key, cmd, exec_config = ecs_task
-            run_task_response = self.__run_task(cmd, exec_config)
+            task_key, cmd, queue, exec_config = ecs_task
+            run_task_response = self._run_task(task_key, cmd, queue, exec_config)
            if run_task_response['failures']:
                for f in run_task_response['failures']:
                    failure_reasons[f['reason']] += 1
@@ -203,39 +204,53 @@
                raise EcsFargateError('No failures and no tasks provided in response. This should never happen.')
            else:
                task = run_task_response['tasks'][0]
-                self.active_workers.add_task(task, task_key, cmd, exec_config)
+                self.active_workers.add_task(task, task_key, queue, cmd, exec_config)
        if failure_reasons:
            self.log.debug('Pending tasks failed to launch for the following reasons: %s. Will retry later.',
                           dict(failure_reasons))

-    def __run_task(self, cmd: CommandType, exec_config: ExecutorConfigType):
+    def _run_task(self, task_id: TaskInstanceKeyType, cmd: CommandType, queue: str, exec_config: ExecutorConfigType):
        """
+        This function is the actual attempt to run a queued-up Airflow task. Not to be confused with
+        execute_async(), which only inserts tasks into the queue.
        The command and executor config will be placed in the container-override section of the JSON request, before
        calling Boto3's "run_task" function.
        """
-        run_task_api = deepcopy(self.run_task_kwargs)
-        container_override = self.get_container(run_task_api['overrides']['containerOverrides'])
-        container_override['command'] = cmd
-        container_override.update(exec_config)
+        run_task_api = self._run_task_kwargs(task_id, cmd, queue, exec_config)
        boto_run_task = self.ecs.run_task(**run_task_api)
-        run_task_response = BotoRunTaskSchema().load(boto_run_task)
-        if run_task_response.errors:
-            self.log.error('ECS RunTask Response: %s', run_task_response)
+        try:
+            run_task_response = BotoRunTaskSchema().load(boto_run_task)
+        except ValidationError as err:
+            self.log.error('ECS RunTask Response: %s', err)
            raise EcsFargateError(
                'RunTask API call does not match expected JSON shape. '
                'Are you sure that the correct version of Boto3 is installed? {}'.format(
-                    run_task_response.errors
+                    err
                )
            )
-        return run_task_response.data
+        return run_task_response
+
+    def _run_task_kwargs(self, task_id: TaskInstanceKeyType, cmd: CommandType,
+                         queue: str, exec_config: ExecutorConfigType) -> dict:
+        """
+        This modifies the standard kwargs to be specific to this task by overriding the Airflow command and updating
+        the container overrides.
+
+        One last chance to modify Boto3's "run_task" kwarg params before they get passed into the Boto3 client.
+        """
+        run_task_api = deepcopy(self.run_task_kwargs)
+        container_override = self.get_container(run_task_api['overrides']['containerOverrides'])
+        container_override['command'] = cmd
+        container_override.update(exec_config)
+        return run_task_api

    def execute_async(self, key: TaskInstanceKeyType, command: CommandType, queue=None, executor_config=None):
        """
-        Save the task to be executed in the next sync using Boto3's RunTask API
+        Save the task to be executed in the next sync by inserting the commands into a queue.
        """
        if executor_config and ('name' in executor_config or 'command' in executor_config):
            raise ValueError('Executor Config should never override "name" or "command"')
-        self.pending_tasks.append(EcsFargateQueuedTask(key, command, executor_config or {}))
+        self.pending_tasks.append(EcsFargateQueuedTask(key, command, queue, executor_config or {}))

    def end(self, heartbeat_interval=10):
        """
@@ -298,14 +313,14 @@ def __init__(self):
        self.key_to_failure_counts: Dict[TaskInstanceKeyType, int] = defaultdict(int)
        self.key_to_task_info: Dict[TaskInstanceKeyType, EcsFargateTaskInfo] = {}

-    def add_task(self, task: EcsFargateTask, airflow_task_key: TaskInstanceKeyType, airflow_cmd: CommandType,
-                 exec_config: ExecutorConfigType):
+    def add_task(self, task: EcsFargateTask, airflow_task_key: TaskInstanceKeyType, queue: str,
+                 airflow_cmd: CommandType, exec_config: ExecutorConfigType):
        """Adds a task to the collection"""
        arn = task.task_arn
        self.tasks[arn] = task
        self.key_to_arn[airflow_task_key] = arn
        self.arn_to_key[arn] = airflow_task_key
-        self.key_to_task_info[airflow_task_key] = EcsFargateTaskInfo(airflow_cmd, exec_config)
+        self.key_to_task_info[airflow_task_key] = EcsFargateTaskInfo(airflow_cmd, queue, exec_config)

    def update_task(self, task: EcsFargateTask):
        """Updates the state of the given task based on task ARN"""
@@ -366,28 +381,34 @@ class BotoContainerSchema(Schema):
    Botocore Serialization Object for ECS 'Container' shape.
    Note that there are many more parameters, but the executor only needs the members listed below.
    """
-    exit_code = fields.Integer(load_from='exitCode')
-    last_status = fields.String(load_from='lastStatus')
+    exit_code = fields.Integer(data_key='exitCode')
+    last_status = fields.String(data_key='lastStatus')
    name = fields.String(required=True)

+    class Meta:
+        unknown = EXCLUDE
+

class BotoTaskSchema(Schema):
    """
    Botocore Serialization Object for ECS 'Task' shape.
    Note that there are many more parameters, but the executor only needs the members listed below.
    """
-    task_arn = fields.String(load_from='taskArn', required=True)
-    last_status = fields.String(load_from='lastStatus', required=True)
-    desired_status = fields.String(load_from='desiredStatus', required=True)
+    task_arn = fields.String(data_key='taskArn', required=True)
+    last_status = fields.String(data_key='lastStatus', required=True)
+    desired_status = fields.String(data_key='desiredStatus', required=True)
    containers = fields.List(fields.Nested(BotoContainerSchema), required=True)
-    started_at = fields.Field(load_from='startedAt')
-    stopped_reason = fields.String(load_from='stoppedReason')
+    started_at = fields.Field(data_key='startedAt')
+    stopped_reason = fields.String(data_key='stoppedReason')

    @post_load
    def make_task(self, data, **kwargs):
-        """Overwrites marshmallow .data property to return an instance of EcsFargateTask instead of a dictionary"""
+        """Overrides marshmallow's load() to return an instance of EcsFargateTask instead of a dictionary"""
        return EcsFargateTask(**data)

+    class Meta:
+        unknown = EXCLUDE
+

class BotoFailureSchema(Schema):
    """
@@ -396,6 +417,9 @@ class BotoFailureSchema(Schema):
    arn = fields.String()
    reason = fields.String()

+    class Meta:
+        unknown = EXCLUDE
+

class BotoRunTaskSchema(Schema):
    """
@@ -404,6 +428,9 @@ class BotoRunTaskSchema(Schema):
    tasks = fields.List(fields.Nested(BotoTaskSchema), required=True)
    failures = fields.List(fields.Nested(BotoFailureSchema), required=True)

+    class Meta:
+        unknown = EXCLUDE
+

class BotoDescribeTasksSchema(Schema):
    """
@@ -412,6 +439,9 @@ class BotoDescribeTasksSchema(Schema):
    tasks = fields.List(fields.Nested(BotoTaskSchema), required=True)
    failures = fields.List(fields.Nested(BotoFailureSchema), required=True)

+    class Meta:
+        unknown = EXCLUDE
+

class EcsFargateError(Exception):
    """Thrown when something unexpected has occurred within the AWS ECS/Fargate ecosystem"""