diff --git a/ecs_deploy/ecs.py b/ecs_deploy/ecs.py index cd8e669..566f825 100644 --- a/ecs_deploy/ecs.py +++ b/ecs_deploy/ecs.py @@ -3,6 +3,7 @@ import re import copy from collections import defaultdict +import itertools import logging import click_log @@ -40,6 +41,11 @@ def read_env_file(container_name, file): return tuple(env_vars) +def chunks(items, size=1): + for i in range(0, len(items), size): + yield items[i:i + size] + + class EcsClient(object): def __init__(self, access_key_id=None, secret_access_key=None, region=None, profile=None, session_token=None): @@ -71,13 +77,28 @@ def describe_task_definition(self, task_definition_arn): ) def list_tasks(self, cluster_name, service_name): - return self.boto.list_tasks( - cluster=cluster_name, - serviceName=service_name + tasks_paginator = self.boto.get_paginator(u"list_tasks") + return list( + itertools.chain.from_iterable( + res["taskArns"] + for res in tasks_paginator.paginate( + cluster=cluster_name, serviceName=service_name + ) + ) ) def describe_tasks(self, cluster_name, task_arns): - return self.boto.describe_tasks(cluster=cluster_name, tasks=task_arns) + return list( + itertools.chain.from_iterable( + res["tasks"] + for res in map( + lambda chunk: self.boto.describe_tasks( + cluster=cluster_name, tasks=chunk + ), + chunks(task_arns, 100), + ) + ) + ) def register_task_definition(self, family, containers, volumes, role_arn, execution_role_arn, runtime_platform, tags, @@ -1317,11 +1338,11 @@ def is_deployed(self, service): cluster_name=service.cluster, service_name=service.name ) - if not running_tasks[u'taskArns']: + if not running_tasks: return service.desired_count == 0 running_count = self.get_running_tasks_count( service=service, - task_arns=running_tasks[u'taskArns'] + task_arns=running_tasks ) return service.desired_count == running_count @@ -1331,7 +1352,7 @@ def get_running_tasks_count(self, service, task_arns): cluster_name=self._cluster_name, task_arns=task_arns ) - for task in tasks_details[u'tasks']: + for task in tasks_details: arn = task[u'taskDefinitionArn'] status = task[u'lastStatus'] if arn == service.task_definition and status == u'RUNNING': diff --git a/tests/test_ecs.py b/tests/test_ecs.py index 6ed7815..222d7b5 100644 --- a/tests/test_ecs.py +++ b/tests/test_ecs.py @@ -393,21 +393,13 @@ u'test-task': RESPONSE_TASK_DEFINITION_2, } -RESPONSE_LIST_TASKS_2 = { - u"taskArns": [TASK_ARN_1, TASK_ARN_2] -} +RESPONSE_LIST_TASKS_2 = [TASK_ARN_1, TASK_ARN_2] -RESPONSE_LIST_TASKS_1 = { - u"taskArns": [TASK_ARN_1] -} +RESPONSE_LIST_TASKS_1 = [TASK_ARN_1] -RESPONSE_LIST_TASKS_0 = { - u"taskArns": [] -} +RESPONSE_LIST_TASKS_0 = [] -RESPONSE_DESCRIBE_TASKS = { - u"tasks": [PAYLOAD_TASK_1, PAYLOAD_TASK_2] -} +RESPONSE_DESCRIBE_TASKS = [PAYLOAD_TASK_1, PAYLOAD_TASK_2] @pytest.fixture() @@ -1077,7 +1069,7 @@ def test_client_describe_unknown_task_definition(client): def test_client_list_tasks(client): client.list_tasks(u'test-cluster', u'test-service') - client.boto.list_tasks.assert_called_once_with(cluster=u'test-cluster', serviceName=u'test-service') + client.boto.get_paginator.assert_called_once_with(u'list_tasks') def test_client_describe_tasks(client):