
Commit 2588f1d

Distributed Map Support in AWS Step Functions (#1720)
* support git repos
* remove spurious function
* fix formatting
* handle race condition with local packages metadata
* remove print
* Support distributed map in AWS Step Functions
* add comments
* add plenty of jitters and retries
* add retries
* add s3 path
* fix bucket
* fix
* fix
1 parent d95e7b6 commit 2588f1d

6 files changed: +228 −19 lines

metaflow/metaflow_config.py

+7-1
@@ -268,7 +268,13 @@
 # machine execution logs. This needs to be available when using the
 # `step-functions create --log-execution-history` command.
 SFN_EXECUTION_LOG_GROUP_ARN = from_conf("SFN_EXECUTION_LOG_GROUP_ARN")
-
+# Amazon S3 path for storing the results of AWS Step Functions Distributed Map
+SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH = from_conf(
+    "SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH",
+    os.path.join(DATASTORE_SYSROOT_S3, "sfn_distributed_map_output")
+    if DATASTORE_SYSROOT_S3
+    else None,
+)
 ###
 # Kubernetes configuration
 ###
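
For reference (not part of the commit), a minimal sketch of how the new default resolves, assuming a hypothetical datastore root; like other from_conf settings, the value should also be settable explicitly, e.g. via a METAFLOW_SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH environment variable or the Metaflow config file.

import os

# Hypothetical datastore root; in practice this is the configured
# METAFLOW_DATASTORE_SYSROOT_S3 value.
DATASTORE_SYSROOT_S3 = "s3://my-bucket/metaflow"

# Mirrors the default expression added to metaflow_config.py above.
default_output_path = (
    os.path.join(DATASTORE_SYSROOT_S3, "sfn_distributed_map_output")
    if DATASTORE_SYSROOT_S3
    else None
)
print(default_output_path)  # s3://my-bucket/metaflow/sfn_distributed_map_output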

metaflow/plugins/aws/step_functions/dynamo_db_client.py

+28-6
@@ -1,5 +1,8 @@
 import os
+import time
+
 import requests
+
 from metaflow.metaflow_config import SFN_DYNAMO_DB_TABLE



@@ -25,12 +28,31 @@ def save_foreach_cardinality(self, foreach_split_task_id, foreach_cardinality, t
     def save_parent_task_id_for_foreach_join(
         self, foreach_split_task_id, foreach_join_parent_task_id
     ):
-        return self._client.update_item(
-            TableName=self.name,
-            Key={"pathspec": {"S": foreach_split_task_id}},
-            UpdateExpression="ADD parent_task_ids_for_foreach_join :val",
-            ExpressionAttributeValues={":val": {"SS": [foreach_join_parent_task_id]}},
-        )
+        ex = None
+        for attempt in range(10):
+            try:
+                return self._client.update_item(
+                    TableName=self.name,
+                    Key={"pathspec": {"S": foreach_split_task_id}},
+                    UpdateExpression="ADD parent_task_ids_for_foreach_join :val",
+                    ExpressionAttributeValues={
+                        ":val": {"SS": [foreach_join_parent_task_id]}
+                    },
+                )
+            except self._client.exceptions.ClientError as error:
+                ex = error
+                if (
+                    error.response["Error"]["Code"]
+                    == "ProvisionedThroughputExceededException"
+                ):
+                    # hopefully, enough time for AWS to scale up! otherwise
+                    # ensure sufficient on-demand throughput for dynamo db
+                    # is provisioned ahead of time
+                    sleep_time = min((2**attempt) * 10, 60)
+                    time.sleep(sleep_time)
+                else:
+                    raise
+        raise ex

     def get_parent_task_ids_for_foreach_join(self, foreach_split_task_id):
         response = self._client.get_item(
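
As an aside (not part of the diff), the backoff in save_parent_task_id_for_foreach_join caps each sleep at 60 seconds, so a fully throttled DynamoDB table is retried for roughly eight minutes before the last ClientError is re-raised; a quick sketch of the schedule:

# Sketch: sleep schedule of the retry loop above when every attempt is
# throttled with ProvisionedThroughputExceededException.
schedule = [min((2**attempt) * 10, 60) for attempt in range(10)]
print(schedule)       # [10, 20, 40, 60, 60, 60, 60, 60, 60, 60]
print(sum(schedule))  # 490 seconds of cumulative sleep in the worst case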

metaflow/plugins/aws/step_functions/production_token.py

+1-1
@@ -1,5 +1,5 @@
1-
import os
21
import json
2+
import os
33
import random
44
import string
55
import zlib

metaflow/plugins/aws/step_functions/step_functions.py

+168-4
@@ -15,6 +15,7 @@
     SFN_DYNAMO_DB_TABLE,
     SFN_EXECUTION_LOG_GROUP_ARN,
     SFN_IAM_ROLE,
+    SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH,
 )
 from metaflow.parameters import deploy_time_eval
 from metaflow.util import dict_to_cli_options, to_pascalcase
@@ -52,6 +53,7 @@ def __init__(
         max_workers=None,
         workflow_timeout=None,
         is_project=False,
+        use_distributed_map=False,
     ):
         self.name = name
         self.graph = graph
@@ -70,6 +72,9 @@ def __init__(
         self.max_workers = max_workers
         self.workflow_timeout = workflow_timeout

+        # https://aws.amazon.com/blogs/aws/step-functions-distributed-map-a-serverless-solution-for-large-scale-parallel-data-processing/
+        self.use_distributed_map = use_distributed_map
+
         self._client = StepFunctionsClient()
         self._workflow = self._compile()
         self._cron = self._cron()
@@ -365,17 +370,80 @@ def _visit(node, workflow, exit_node=None):
                     .parameter("SplitParentTaskId.$", "$.JobId")
                     .parameter("Parameters.$", "$.Parameters")
                     .parameter("Index.$", "$$.Map.Item.Value")
-                    .next(node.matching_join)
+                    .next(
+                        "%s_*GetManifest" % iterator_name
+                        if self.use_distributed_map
+                        else node.matching_join
+                    )
                     .iterator(
                         _visit(
                             self.graph[node.out_funcs[0]],
-                            Workflow(node.out_funcs[0]).start_at(node.out_funcs[0]),
+                            Workflow(node.out_funcs[0])
+                            .start_at(node.out_funcs[0])
+                            .mode(
+                                "DISTRIBUTED" if self.use_distributed_map else "INLINE"
+                            ),
                             node.matching_join,
                         )
                     )
                     .max_concurrency(self.max_workers)
-                    .output_path("$.[0]")
+                    # AWS Step Functions has a short coming for DistributedMap at the
+                    # moment that does not allow us to subset the output of for-each
+                    # to just a single element. We have to rely on a rather terrible
+                    # hack and resort to using ResultWriter to write the state to
+                    # Amazon S3 and process it in another task. But, well what can we
+                    # do...
+                    .result_writer(
+                        *(
+                            (
+                                (
+                                    SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH[len("s3://") :]
+                                    if SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH.startswith(
+                                        "s3://"
+                                    )
+                                    else SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH
+                                ).split("/", 1)
+                                + [""]
+                            )[:2]
+                            if self.use_distributed_map
+                            else (None, None)
+                        )
+                    )
+                    .output_path("$" if self.use_distributed_map else "$.[0]")
                 )
+                if self.use_distributed_map:
+                    workflow.add_state(
+                        State("%s_*GetManifest" % iterator_name)
+                        .resource("arn:aws:states:::aws-sdk:s3:getObject")
+                        .parameter("Bucket.$", "$.ResultWriterDetails.Bucket")
+                        .parameter("Key.$", "$.ResultWriterDetails.Key")
+                        .next("%s_*Map" % iterator_name)
+                        .result_selector("Body.$", "States.StringToJson($.Body)")
+                    )
+                    workflow.add_state(
+                        Map("%s_*Map" % iterator_name)
+                        .iterator(
+                            Workflow("%s_*PassWorkflow" % iterator_name)
+                            .mode("DISTRIBUTED")
+                            .start_at("%s_*Pass" % iterator_name)
+                            .add_state(
+                                Pass("%s_*Pass" % iterator_name)
+                                .end()
+                                .parameter("Output.$", "States.StringToJson($.Output)")
+                                .output_path("$.Output")
+                            )
+                        )
+                        .next(node.matching_join)
+                        .max_concurrency(1000)
+                        .item_reader(
+                            JSONItemReader()
+                            .resource("arn:aws:states:::s3:getObject")
+                            .parameter("Bucket.$", "$.Body.DestinationBucket")
+                            .parameter("Key.$", "$.Body.ResultFiles.SUCCEEDED.[0].Key")
+                        )
+                        .output_path("$.[0]")
+                    )
+
                 # Continue the traversal from the matching_join.
                 _visit(self.graph[node.matching_join], workflow, exit_node)
             # We shouldn't ideally ever get here.
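
To make the *(...) argument to .result_writer() above easier to follow, here is a standalone sketch (not part of the commit; the configured path shown is hypothetical) of how the S3 output path is split into the (bucket, prefix) pair handed to ResultWriter:

# Hypothetical configured value; normally SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH.
output_path = "s3://my-bucket/metaflow/sfn_distributed_map_output"
use_distributed_map = True

bucket, prefix = (
    (
        (
            output_path[len("s3://") :]
            if output_path.startswith("s3://")
            else output_path
        ).split("/", 1)
        + [""]  # pad so a bucket-only path still unpacks into two elements
    )[:2]
    if use_distributed_map
    else (None, None)  # keeps result_writer() a no-op for inline maps
)
print(bucket, prefix)  # my-bucket metaflow/sfn_distributed_map_output
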
@@ -444,7 +512,6 @@ def _batch(self, node):
             "metaflow.owner": self.username,
             "metaflow.flow_name": self.flow.name,
             "metaflow.step_name": node.name,
-            "metaflow.run_id.$": "$$.Execution.Name",
             # Unfortunately we can't set the task id here since AWS Step
             # Functions lacks any notion of run-scoped task identifiers. We
             # instead co-opt the AWS Batch job id as the task id. This also
@@ -456,6 +523,10 @@ def _batch(self, node):
             # `$$.State.RetryCount` resolves to an int dynamically and
             # AWS Batch job specification only accepts strings. We handle
             # retries/catch within AWS Batch to get around this limitation.
+            # And, we also cannot set the run id here since the run id maps to
+            # the execution name of the AWS Step Functions State Machine, which
+            # is different when executing inside a distributed map. We set it once
+            # in the start step and move it along to be consumed by all the children.
             "metaflow.version": self.environment.get_environment_info()[
                 "metaflow_version"
             ],
@@ -492,6 +563,12 @@ def _batch(self, node):
             env["METAFLOW_S3_ENDPOINT_URL"] = S3_ENDPOINT_URL

         if node.name == "start":
+            # metaflow.run_id maps to AWS Step Functions State Machine Execution in all
+            # cases except for when within a for-each construct that relies on
+            # Distributed Map. To work around this issue, we pass the run id from the
+            # start step to all subsequent tasks.
+            attrs["metaflow.run_id.$"] = "$$.Execution.Name"
+
             # Initialize parameters for the flow in the `start` step.
             parameters = self._process_parameters()
             if parameters:
@@ -550,6 +627,8 @@ def _batch(self, node):
             env["METAFLOW_SPLIT_PARENT_TASK_ID"] = (
                 "$.Parameters.split_parent_task_id_%s" % node.split_parents[-1]
             )
+            # Inherit the run id from the parent and pass it along to children.
+            attrs["metaflow.run_id.$"] = "$.Parameters.['metaflow.run_id']"
         else:
             # Set appropriate environment variables for runtime replacement.
             if len(node.in_funcs) == 1:
@@ -558,6 +637,8 @@ def _batch(self, node):
                     % node.in_funcs[0]
                 )
                 env["METAFLOW_PARENT_TASK_ID"] = "$.JobId"
+                # Inherit the run id from the parent and pass it along to children.
+                attrs["metaflow.run_id.$"] = "$.Parameters.['metaflow.run_id']"
             else:
                 # Generate the input paths in a quasi-compressed format.
                 # See util.decompress_list for why this is written the way
@@ -567,6 +648,8 @@ def _batch(self, node):
                     "${METAFLOW_PARENT_%s_TASK_ID}" % (idx, idx)
                     for idx, _ in enumerate(node.in_funcs)
                 )
+                # Inherit the run id from the parent and pass it along to children.
+                attrs["metaflow.run_id.$"] = "$.[0].Parameters.['metaflow.run_id']"
                 for idx, _ in enumerate(node.in_funcs):
                     env["METAFLOW_PARENT_%s_TASK_ID" % idx] = "$.[%s].JobId" % idx
                     env["METAFLOW_PARENT_%s_STEP" % idx] = (
@@ -893,6 +976,12 @@ def __init__(self, name):
         tree = lambda: defaultdict(tree)
         self.payload = tree()

+    def mode(self, mode):
+        self.payload["ProcessorConfig"] = {"Mode": mode}
+        if mode == "DISTRIBUTED":
+            self.payload["ProcessorConfig"]["ExecutionType"] = "STANDARD"
+        return self
+
     def start_at(self, start_at):
         self.payload["StartAt"] = start_at
         return self
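
A quick sketch (not part of the commit) of the payload the new Workflow.mode() helper emits, i.e. the ProcessorConfig block that tells Step Functions to run the item processor as a Distributed Map with STANDARD child executions:

import json
from collections import defaultdict

# Mirrors Workflow.mode("DISTRIBUTED") above.
tree = lambda: defaultdict(tree)
payload = tree()
mode = "DISTRIBUTED"
payload["ProcessorConfig"] = {"Mode": mode}
if mode == "DISTRIBUTED":
    payload["ProcessorConfig"]["ExecutionType"] = "STANDARD"

print(json.dumps(payload))
# {"ProcessorConfig": {"Mode": "DISTRIBUTED", "ExecutionType": "STANDARD"}}
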
@@ -940,10 +1029,18 @@ def result_path(self, result_path):
         self.payload["ResultPath"] = result_path
         return self

+    def result_selector(self, name, value):
+        self.payload["ResultSelector"][name] = value
+        return self
+
     def _partition(self):
         # This is needed to support AWS Gov Cloud and AWS CN regions
         return SFN_IAM_ROLE.split(":")[1]

+    def retry_strategy(self, retry_strategy):
+        self.payload["Retry"] = [retry_strategy]
+        return self
+
     def batch(self, job):
         self.resource(
             "arn:%s:states:::batch:submitJob.sync" % self._partition()
@@ -963,6 +1060,19 @@ def batch(self, job):
         # tags may not be present in all scenarios
         if "tags" in job.payload:
             self.parameter("Tags", job.payload["tags"])
+        # set retry strategy for AWS Batch job submission to account for the
+        # measily 50 jobs / second queue admission limit which people can
+        # run into very quickly.
+        self.retry_strategy(
+            {
+                "ErrorEquals": ["Batch.AWSBatchException"],
+                "BackoffRate": 2,
+                "IntervalSeconds": 2,
+                "MaxDelaySeconds": 60,
+                "MaxAttempts": 10,
+                "JitterStrategy": "FULL",
+            }
+        )
         return self

     def dynamo_db(self, table_name, primary_key, values):
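
For context (not part of the diff), assuming the documented Step Functions retry semantics — each retry waits IntervalSeconds scaled by BackoffRate, capped at MaxDelaySeconds, with FULL jitter sampling the actual wait uniformly between zero and that ceiling — the Retry block above allows up to ten resubmissions with these per-attempt ceilings:

# Approximate per-retry delay ceilings for the Batch.AWSBatchException retry above.
interval, backoff, max_delay, max_attempts = 2, 2, 60, 10

ceilings = [min(interval * backoff**n, max_delay) for n in range(max_attempts)]
print(ceilings)  # [2, 4, 8, 16, 32, 60, 60, 60, 60, 60]
# FULL jitter spreads the actual waits between 0 and each ceiling, which helps
# stay under the AWS Batch job submission rate limit mentioned in the comment.
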
@@ -976,6 +1086,26 @@ def dynamo_db(self, table_name, primary_key, values):
         return self


+class Pass(object):
+    def __init__(self, name):
+        self.name = name
+        tree = lambda: defaultdict(tree)
+        self.payload = tree()
+        self.payload["Type"] = "Pass"
+
+    def end(self):
+        self.payload["End"] = True
+        return self
+
+    def parameter(self, name, value):
+        self.payload["Parameters"][name] = value
+        return self
+
+    def output_path(self, output_path):
+        self.payload["OutputPath"] = output_path
+        return self
+
+
 class Parallel(object):
     def __init__(self, name):
         self.name = name
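
A sketch (not part of the commit) of the state definition the new Pass helper produces when chained as in _visit() above, i.e. the *Pass state that unwraps each child workflow's Output:

import json
from collections import defaultdict

tree = lambda: defaultdict(tree)

# Mirrors Pass("...").end().parameter("Output.$", ...).output_path("$.Output").
payload = tree()
payload["Type"] = "Pass"
payload["End"] = True
payload["Parameters"]["Output.$"] = "States.StringToJson($.Output)"
payload["OutputPath"] = "$.Output"

print(json.dumps(payload))
# {"Type": "Pass", "End": true, "Parameters": {"Output.$": "States.StringToJson($.Output)"}, "OutputPath": "$.Output"}
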
@@ -1037,3 +1167,37 @@ def output_path(self, output_path):
     def result_path(self, result_path):
         self.payload["ResultPath"] = result_path
         return self
+
+    def item_reader(self, item_reader):
+        self.payload["ItemReader"] = item_reader.payload
+        return self
+
+    def result_writer(self, bucket, prefix):
+        if bucket is not None and prefix is not None:
+            self.payload["ResultWriter"] = {
+                "Resource": "arn:aws:states:::s3:putObject",
+                "Parameters": {
+                    "Bucket": bucket,
+                    "Prefix": prefix,
+                },
+            }
+        return self
+
+
+class JSONItemReader(object):
+    def __init__(self):
+        tree = lambda: defaultdict(tree)
+        self.payload = tree()
+        self.payload["ReaderConfig"] = {"InputType": "JSON", "MaxItems": 1}
+
+    def resource(self, resource):
+        self.payload["Resource"] = resource
+        return self
+
+    def parameter(self, name, value):
+        self.payload["Parameters"][name] = value
+        return self
+
+    def output_path(self, output_path):
+        self.payload["OutputPath"] = output_path
+        return self
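
Finally, a sketch (not part of the commit) of the JSON fragments the two new Map helpers contribute to the outer *Map state: the ItemReader that pulls the single SUCCEEDED result file out of the Distributed Map manifest, and the ResultWriter produced from a hypothetical bucket and prefix:

import json

# What Map.item_reader(JSONItemReader()...) inlines, per the _visit() code above.
item_reader = {
    "ReaderConfig": {"InputType": "JSON", "MaxItems": 1},
    "Resource": "arn:aws:states:::s3:getObject",
    "Parameters": {
        "Bucket.$": "$.Body.DestinationBucket",
        "Key.$": "$.Body.ResultFiles.SUCCEEDED.[0].Key",
    },
}

# What Map.result_writer("my-bucket", "metaflow/sfn_distributed_map_output")
# would produce; bucket and prefix here are hypothetical examples.
result_writer = {
    "Resource": "arn:aws:states:::s3:putObject",
    "Parameters": {
        "Bucket": "my-bucket",
        "Prefix": "metaflow/sfn_distributed_map_output",
    },
}

print(json.dumps({"ItemReader": item_reader, "ResultWriter": result_writer}, indent=2))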
