Skip to content

Commit 61daa98

Browse files
authored
Merge pull request #6 from aelzeiny/add-executor-queues-to-fargate
Add Executor Queues to Fargate, and improve APIs
2 parents f530c46 + bae3ba2 commit 61daa98

7 files changed

+111
-81
lines changed

.travis.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ python:
44
- "3.7"
55
- "3.8"
66
install:
7-
- pip install apache-airflow boto3 pylint isort
7+
- pip install apache-airflow boto3 pylint isort marshmallow
88
env:
99
- AIRFLOW__BATCH__REGION=us-west-1 AIRFLOW__BATCH__JOB_NAME=some-job-name AIRFLOW__BATCH__JOB_QUEUE=some-job-queue AIRFLOW__BATCH__JOB_DEFINITION=some-job-def AIRFLOW__ECS_FARGATE__REGION=us-west-1 AIRFLOW__ECS_FARGATE__CLUSTER=some-cluster AIRFLOW__ECS_FARGATE__CONTAINER_NAME=some-container-name AIRFLOW__ECS_FARGATE__TASK_DEFINITION=some-task-def AIRFLOW__ECS_FARGATE__LAUNCH_TYPE=FARGATE AIRFLOW__ECS_FARGATE__PLATFORM_VERSION=LATEST AIRFLOW__ECS_FARGATE__ASSIGN_PUBLIC_IP=DISABLED AIRFLOW__ECS_FARGATE__SECURITY_GROUPS=SG1,SG2 AIRFLOW__ECS_FARGATE__SUBNETS=SUB1,SUB2
1010
script:

airflow_aws_executors/batch_executor.py

+25-14
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from airflow.executors.base_executor import BaseExecutor
1010
from airflow.utils.module_loading import import_string
1111
from airflow.utils.state import State
12-
from marshmallow import Schema, fields, post_load
12+
from marshmallow import EXCLUDE, Schema, ValidationError, fields, post_load
1313

1414
CommandType = List[str]
1515
TaskInstanceKeyType = Tuple[Any]
@@ -105,16 +105,17 @@ def _describe_tasks(self, job_ids) -> List[BatchJob]:
105105
for i in range((len(job_ids) // max_batch_size) + 1):
106106
batched_job_ids = job_ids[i * max_batch_size: (i + 1) * max_batch_size]
107107
boto_describe_tasks = self.batch.describe_jobs(jobs=batched_job_ids)
108-
describe_tasks_response = BatchDescribeJobsResponseSchema().load(boto_describe_tasks)
109-
if describe_tasks_response.errors:
108+
try:
109+
describe_tasks_response = BatchDescribeJobsResponseSchema().load(boto_describe_tasks)
110+
except ValidationError as err:
110111
self.log.error('Batch DescribeJobs API Response: %s', boto_describe_tasks)
111112
raise BatchError(
112113
'DescribeJobs API call does not match expected JSON shape. '
113114
'Are you sure that the correct version of Boto3 is installed? {}'.format(
114-
describe_tasks_response.errors
115+
err
115116
)
116117
)
117-
all_jobs.extend(describe_tasks_response.data['jobs'])
118+
all_jobs.extend(describe_tasks_response['jobs'])
118119
return all_jobs
119120

120121
def execute_async(self, key: TaskInstanceKeyType, command: CommandType, queue=None, executor_config=None):
@@ -135,16 +136,17 @@ def _submit_job(self, cmd: CommandType, exec_config: ExecutorConfigType) -> str:
135136
submit_job_api['containerOverrides'].update(exec_config)
136137
submit_job_api['containerOverrides']['command'] = cmd
137138
boto_run_task = self.batch.submit_job(**submit_job_api)
138-
submit_job_response = BatchSubmitJobResponseSchema().load(boto_run_task)
139-
if submit_job_response.errors:
140-
self.log.error('Batch SubmitJob Response: %s', submit_job_response)
139+
try:
140+
submit_job_response = BatchSubmitJobResponseSchema().load(boto_run_task)
141+
except ValidationError as err:
142+
self.log.error('Batch SubmitJob Response: %s', err)
141143
raise BatchError(
142144
'RunTask API call does not match expected JSON shape. '
143145
'Are you sure that the correct version of Boto3 is installed? {}'.format(
144-
submit_job_response.errors
146+
err
145147
)
146148
)
147-
return submit_job_response.data['job_id']
149+
return submit_job_response['job_id']
148150

149151
def end(self, heartbeat_interval=10):
150152
"""
@@ -213,29 +215,38 @@ def __len__(self):
213215
class BatchSubmitJobResponseSchema(Schema):
214216
"""API Response for SubmitJob"""
215217
# The unique identifier for the job.
216-
job_id = fields.String(load_from='jobId', required=True)
218+
job_id = fields.String(data_key='jobId', required=True)
219+
220+
class Meta:
221+
unknown = EXCLUDE
217222

218223

219224
class BatchJobDetailSchema(Schema):
220225
"""API Response for Describe Jobs"""
221226
# The unique identifier for the job.
222-
job_id = fields.String(load_from='jobId', required=True)
227+
job_id = fields.String(data_key='jobId', required=True)
223228
# The current status for the job: 'SUBMITTED', 'PENDING', 'RUNNABLE', 'STARTING', 'RUNNING', 'SUCCEEDED', 'FAILED'
224229
status = fields.String(required=True)
225230
# A short, human-readable string to provide additional details about the current status of the job.
226-
status_reason = fields.String(load_from='statusReason')
231+
status_reason = fields.String(data_key='statusReason')
227232

228233
@post_load
229234
def make_job(self, data, **kwargs):
230-
"""Overwrites marshmallow data property to return an instance of BatchJob instead of a dictionary"""
235+
"""Overwrites marshmallow load() to return an instance of BatchJob instead of a dictionary"""
231236
return BatchJob(**data)
232237

238+
class Meta:
239+
unknown = EXCLUDE
240+
233241

234242
class BatchDescribeJobsResponseSchema(Schema):
235243
"""API Response for Describe Jobs"""
236244
# The list of jobs
237245
jobs = fields.List(fields.Nested(BatchJobDetailSchema), required=True)
238246

247+
class Meta:
248+
unknown = EXCLUDE
249+
239250

240251
class BatchError(Exception):
241252
"""Thrown when something unexpected has occurred within the AWS Batch ecosystem"""

airflow_aws_executors/conf.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,8 @@
2525
from airflow.configuration import conf
2626

2727

28-
def has_option(section, config_name):
28+
def has_option(section, config_name) -> bool:
29+
"""Returns True if configuration has a section and an option"""
2930
if conf.has_option(section, config_name):
3031
config_val = conf.get(section, config_name)
3132
return config_val is not None and config_val != ''

airflow_aws_executors/ecs_fargate_executor.py

+66-36
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,14 @@
1010
from airflow.executors.base_executor import BaseExecutor
1111
from airflow.utils.module_loading import import_string
1212
from airflow.utils.state import State
13-
from marshmallow import Schema, fields, post_load
13+
from marshmallow import EXCLUDE, Schema, ValidationError, fields, post_load
1414

1515
CommandType = List[str]
1616
TaskInstanceKeyType = Tuple[Any]
1717
ExecutorConfigFunctionType = Callable[[CommandType], dict]
18-
EcsFargateQueuedTask = namedtuple('EcsFargateQueuedTask', ('key', 'command', 'executor_config'))
18+
EcsFargateQueuedTask = namedtuple('EcsFargateQueuedTask', ('key', 'command', 'queue', 'executor_config'))
1919
ExecutorConfigType = Dict[str, Any]
20-
EcsFargateTaskInfo = namedtuple('EcsFargateTaskInfo', ('cmd', 'config'))
20+
EcsFargateTaskInfo = namedtuple('EcsFargateTaskInfo', ('cmd', 'queue', 'config'))
2121

2222

2323
class EcsFargateTask:
@@ -147,17 +147,18 @@ def __describe_tasks(self, task_arns):
147147
for i in range((len(task_arns) // self.DESCRIBE_TASKS_BATCH_SIZE) + 1):
148148
batched_task_arns = task_arns[i * self.DESCRIBE_TASKS_BATCH_SIZE: (i + 1) * self.DESCRIBE_TASKS_BATCH_SIZE]
149149
boto_describe_tasks = self.ecs.describe_tasks(tasks=batched_task_arns, cluster=self.cluster)
150-
describe_tasks_response = BotoDescribeTasksSchema().load(boto_describe_tasks)
151-
if describe_tasks_response.errors:
150+
try:
151+
describe_tasks_response = BotoDescribeTasksSchema().load(boto_describe_tasks)
152+
except ValidationError as err:
152153
self.log.error('ECS DescribeTask Response: %s', boto_describe_tasks)
153154
raise EcsFargateError(
154155
'DescribeTasks API call does not match expected JSON shape. '
155156
'Are you sure that the correct version of Boto3 is installed? {}'.format(
156-
describe_tasks_response.errors
157+
err
157158
)
158159
)
159-
all_task_descriptions['tasks'].extend(describe_tasks_response.data['tasks'])
160-
all_task_descriptions['failures'].extend(describe_tasks_response.data['failures'])
160+
all_task_descriptions['tasks'].extend(describe_tasks_response['tasks'])
161+
all_task_descriptions['failures'].extend(describe_tasks_response['failures'])
161162
return all_task_descriptions
162163

163164
def __handle_failed_task(self, task_arn: str, reason: str):
@@ -166,14 +167,14 @@ def __handle_failed_task(self, task_arn: str, reason: str):
166167
ECS/Fargate Cloud. If an API failure occurs the task is simply rescheduled.
167168
"""
168169
task_key = self.active_workers.arn_to_key[task_arn]
169-
task_cmd, exec_info = self.active_workers.info_by_key(task_key)
170+
task_cmd, queue, exec_info = self.active_workers.info_by_key(task_key)
170171
failure_count = self.active_workers.failure_count_by_key(task_key)
171172
if failure_count < self.__class__.MAX_FAILURE_CHECKS:
172173
self.log.warning('Task %s has failed due to %s. '
173174
'Failure %s out of %s occurred on %s. Rescheduling.',
174175
task_key, reason, failure_count, self.__class__.MAX_FAILURE_CHECKS, task_arn)
175176
self.active_workers.increment_failure_count(task_key)
176-
self.pending_tasks.appendleft(EcsFargateQueuedTask(task_key, task_cmd, exec_info))
177+
self.pending_tasks.appendleft(EcsFargateQueuedTask(task_key, task_cmd, queue, exec_info))
177178
else:
178179
self.log.error('Task %s has failed a maximum of %s times. Marking as failed', task_key,
179180
failure_count)
@@ -192,8 +193,8 @@ def attempt_task_runs(self):
192193
failure_reasons = defaultdict(int)
193194
for _ in range(queue_len):
194195
ecs_task = self.pending_tasks.popleft()
195-
task_key, cmd, exec_config = ecs_task
196-
run_task_response = self.__run_task(cmd, exec_config)
196+
task_key, cmd, queue, exec_config = ecs_task
197+
run_task_response = self._run_task(task_key, cmd, queue, exec_config)
197198
if run_task_response['failures']:
198199
for f in run_task_response['failures']:
199200
failure_reasons[f['reason']] += 1
@@ -203,39 +204,53 @@ def attempt_task_runs(self):
203204
raise EcsFargateError('No failures and no tasks provided in response. This should never happen.')
204205
else:
205206
task = run_task_response['tasks'][0]
206-
self.active_workers.add_task(task, task_key, cmd, exec_config)
207+
self.active_workers.add_task(task, task_key, queue, cmd, exec_config)
207208
if failure_reasons:
208209
self.log.debug('Pending tasks failed to launch for the following reasons: %s. Will retry later.',
209210
dict(failure_reasons))
210211

211-
def __run_task(self, cmd: CommandType, exec_config: ExecutorConfigType):
212+
def _run_task(self, task_id: TaskInstanceKeyType, cmd: CommandType, queue: str, exec_config: ExecutorConfigType):
212213
"""
214+
This function is the actual attempt to run a queued-up airflow task. Not to be confused with
215+
execute_async() which inserts tasks into the queue.
213216
The command and executor config will be placed in the container-override section of the JSON request, before
214217
calling Boto3's "run_task" function.
215218
"""
216-
run_task_api = deepcopy(self.run_task_kwargs)
217-
container_override = self.get_container(run_task_api['overrides']['containerOverrides'])
218-
container_override['command'] = cmd
219-
container_override.update(exec_config)
219+
run_task_api = self._run_task_kwargs(task_id, cmd, queue, exec_config)
220220
boto_run_task = self.ecs.run_task(**run_task_api)
221-
run_task_response = BotoRunTaskSchema().load(boto_run_task)
222-
if run_task_response.errors:
223-
self.log.error('ECS RunTask Response: %s', run_task_response)
221+
try:
222+
run_task_response = BotoRunTaskSchema().load(boto_run_task)
223+
except ValidationError as err:
224+
self.log.error('ECS RunTask Response: %s', err)
224225
raise EcsFargateError(
225226
'RunTask API call does not match expected JSON shape. '
226227
'Are you sure that the correct version of Boto3 is installed? {}'.format(
227-
run_task_response.errors
228+
err
228229
)
229230
)
230-
return run_task_response.data
231+
return run_task_response
232+
233+
def _run_task_kwargs(self, task_id: TaskInstanceKeyType, cmd: CommandType,
234+
queue: str, exec_config: ExecutorConfigType) -> dict:
235+
"""
236+
This modifies the standard kwargs to be specific to this task by overriding the airflow command and updating
237+
the container overrides.
238+
239+
One last chance to modify Boto3's "run_task" kwarg params before it gets passed into the Boto3 client.
240+
"""
241+
run_task_api = deepcopy(self.run_task_kwargs)
242+
container_override = self.get_container(run_task_api['overrides']['containerOverrides'])
243+
container_override['command'] = cmd
244+
container_override.update(exec_config)
245+
return run_task_api
231246

232247
def execute_async(self, key: TaskInstanceKeyType, command: CommandType, queue=None, executor_config=None):
233248
"""
234-
Save the task to be executed in the next sync using Boto3's RunTask API
249+
Save the task to be executed in the next sync by inserting the commands into a queue.
235250
"""
236251
if executor_config and ('name' in executor_config or 'command' in executor_config):
237252
raise ValueError('Executor Config should never override "name" or "command"')
238-
self.pending_tasks.append(EcsFargateQueuedTask(key, command, executor_config or {}))
253+
self.pending_tasks.append(EcsFargateQueuedTask(key, command, queue, executor_config or {}))
239254

240255
def end(self, heartbeat_interval=10):
241256
"""
@@ -298,14 +313,14 @@ def __init__(self):
298313
self.key_to_failure_counts: Dict[TaskInstanceKeyType, int] = defaultdict(int)
299314
self.key_to_task_info: Dict[TaskInstanceKeyType, EcsFargateTaskInfo] = {}
300315

301-
def add_task(self, task: EcsFargateTask, airflow_task_key: TaskInstanceKeyType, airflow_cmd: CommandType,
302-
exec_config: ExecutorConfigType):
316+
def add_task(self, task: EcsFargateTask, airflow_task_key: TaskInstanceKeyType, queue: str,
317+
airflow_cmd: CommandType, exec_config: ExecutorConfigType):
303318
"""Adds a task to the collection"""
304319
arn = task.task_arn
305320
self.tasks[arn] = task
306321
self.key_to_arn[airflow_task_key] = arn
307322
self.arn_to_key[arn] = airflow_task_key
308-
self.key_to_task_info[airflow_task_key] = EcsFargateTaskInfo(airflow_cmd, exec_config)
323+
self.key_to_task_info[airflow_task_key] = EcsFargateTaskInfo(airflow_cmd, queue, exec_config)
309324

310325
def update_task(self, task: EcsFargateTask):
311326
"""Updates the state of the given task based on task ARN"""
@@ -366,28 +381,34 @@ class BotoContainerSchema(Schema):
366381
Botocore Serialization Object for ECS 'Container' shape.
367382
Note that there are many more parameters, but the executor only needs the members listed below.
368383
"""
369-
exit_code = fields.Integer(load_from='exitCode')
370-
last_status = fields.String(load_from='lastStatus')
384+
exit_code = fields.Integer(data_key='exitCode')
385+
last_status = fields.String(data_key='lastStatus')
371386
name = fields.String(required=True)
372387

388+
class Meta:
389+
unknown = EXCLUDE
390+
373391

374392
class BotoTaskSchema(Schema):
375393
"""
376394
Botocore Serialization Object for ECS 'Task' shape.
377395
Note that there are many more parameters, but the executor only needs the members listed below.
378396
"""
379-
task_arn = fields.String(load_from='taskArn', required=True)
380-
last_status = fields.String(load_from='lastStatus', required=True)
381-
desired_status = fields.String(load_from='desiredStatus', required=True)
397+
task_arn = fields.String(data_key='taskArn', required=True)
398+
last_status = fields.String(data_key='lastStatus', required=True)
399+
desired_status = fields.String(data_key='desiredStatus', required=True)
382400
containers = fields.List(fields.Nested(BotoContainerSchema), required=True)
383-
started_at = fields.Field(load_from='startedAt')
384-
stopped_reason = fields.String(load_from='stoppedReason')
401+
started_at = fields.Field(data_key='startedAt')
402+
stopped_reason = fields.String(data_key='stoppedReason')
385403

386404
@post_load
387405
def make_task(self, data, **kwargs):
388-
"""Overwrites marshmallow .data property to return an instance of EcsFargateTask instead of a dictionary"""
406+
"""Overwrites marshmallow load() to return an instance of EcsFargateTask instead of a dictionary"""
389407
return EcsFargateTask(**data)
390408

409+
class Meta:
410+
unknown = EXCLUDE
411+
391412

392413
class BotoFailureSchema(Schema):
393414
"""
@@ -396,6 +417,9 @@ class BotoFailureSchema(Schema):
396417
arn = fields.String()
397418
reason = fields.String()
398419

420+
class Meta:
421+
unknown = EXCLUDE
422+
399423

400424
class BotoRunTaskSchema(Schema):
401425
"""
@@ -404,6 +428,9 @@ class BotoRunTaskSchema(Schema):
404428
tasks = fields.List(fields.Nested(BotoTaskSchema), required=True)
405429
failures = fields.List(fields.Nested(BotoFailureSchema), required=True)
406430

431+
class Meta:
432+
unknown = EXCLUDE
433+
407434

408435
class BotoDescribeTasksSchema(Schema):
409436
"""
@@ -412,6 +439,9 @@ class BotoDescribeTasksSchema(Schema):
412439
tasks = fields.List(fields.Nested(BotoTaskSchema), required=True)
413440
failures = fields.List(fields.Nested(BotoFailureSchema), required=True)
414441

442+
class Meta:
443+
unknown = EXCLUDE
444+
415445

416446
class EcsFargateError(Exception):
417447
"""Thrown when something unexpected has occurred within the AWS ECS/Fargate ecosystem"""

setup.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212

1313
setup(
1414
name="airflow-aws-executors",
15-
version="1.0.0",
15+
version="1.1.0",
1616
description=description,
1717
long_description=long_description,
1818
long_description_content_type="text/markdown",
@@ -29,5 +29,5 @@
2929
],
3030
packages=["airflow_aws_executors"],
3131
include_package_data=True,
32-
install_requires=["boto3", "apache-airflow>=1.10.5"]
32+
install_requires=["boto3", "apache-airflow>=1.10.5", "marshmallow>=3"]
3333
)

0 commit comments

Comments (0)