     SFN_DYNAMO_DB_TABLE,
     SFN_EXECUTION_LOG_GROUP_ARN,
     SFN_IAM_ROLE,
+    SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH,
 )
 from metaflow.parameters import deploy_time_eval
 from metaflow.util import dict_to_cli_options, to_pascalcase
@@ -52,6 +53,7 @@ def __init__(
         max_workers=None,
         workflow_timeout=None,
         is_project=False,
+        use_distributed_map=False,
     ):
         self.name = name
         self.graph = graph
@@ -70,6 +72,9 @@ def __init__(
         self.max_workers = max_workers
         self.workflow_timeout = workflow_timeout
 
+        # https://aws.amazon.com/blogs/aws/step-functions-distributed-map-a-serverless-solution-for-large-scale-parallel-data-processing/
+        self.use_distributed_map = use_distributed_map
+
         self._client = StepFunctionsClient()
         self._workflow = self._compile()
         self._cron = self._cron()
@@ -365,17 +370,80 @@ def _visit(node, workflow, exit_node=None):
                     .parameter("SplitParentTaskId.$", "$.JobId")
                     .parameter("Parameters.$", "$.Parameters")
                     .parameter("Index.$", "$$.Map.Item.Value")
-                    .next(node.matching_join)
+                    .next(
+                        "%s_*GetManifest" % iterator_name
+                        if self.use_distributed_map
+                        else node.matching_join
+                    )
                     .iterator(
                         _visit(
                             self.graph[node.out_funcs[0]],
-                            Workflow(node.out_funcs[0]).start_at(node.out_funcs[0]),
+                            Workflow(node.out_funcs[0])
+                            .start_at(node.out_funcs[0])
+                            .mode(
+                                "DISTRIBUTED" if self.use_distributed_map else "INLINE"
+                            ),
                             node.matching_join,
                         )
                     )
                     .max_concurrency(self.max_workers)
-                    .output_path("$.[0]")
+                    # AWS Step Functions has a shortcoming for DistributedMap at the
+                    # moment that does not allow us to subset the output of for-each
+                    # to just a single element. We have to rely on a rather terrible
+                    # hack and resort to using ResultWriter to write the state to
+                    # Amazon S3 and process it in another task. But, well what can we
+                    # do...
+                    .result_writer(
+                        *(
+                            (
+                                (
+                                    SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH[len("s3://") :]
+                                    if SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH.startswith(
+                                        "s3://"
+                                    )
+                                    else SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH
+                                ).split("/", 1)
+                                + [""]
+                            )[:2]
+                            if self.use_distributed_map
+                            else (None, None)
+                        )
+                    )
+                    .output_path("$" if self.use_distributed_map else "$.[0]")
                 )
+                if self.use_distributed_map:
+                    workflow.add_state(
+                        State("%s_*GetManifest" % iterator_name)
+                        .resource("arn:aws:states:::aws-sdk:s3:getObject")
+                        .parameter("Bucket.$", "$.ResultWriterDetails.Bucket")
+                        .parameter("Key.$", "$.ResultWriterDetails.Key")
+                        .next("%s_*Map" % iterator_name)
+                        .result_selector("Body.$", "States.StringToJson($.Body)")
+                    )
+                    workflow.add_state(
+                        Map("%s_*Map" % iterator_name)
+                        .iterator(
+                            Workflow("%s_*PassWorkflow" % iterator_name)
+                            .mode("DISTRIBUTED")
+                            .start_at("%s_*Pass" % iterator_name)
+                            .add_state(
+                                Pass("%s_*Pass" % iterator_name)
+                                .end()
+                                .parameter("Output.$", "States.StringToJson($.Output)")
+                                .output_path("$.Output")
+                            )
+                        )
+                        .next(node.matching_join)
+                        .max_concurrency(1000)
+                        .item_reader(
+                            JSONItemReader()
+                            .resource("arn:aws:states:::s3:getObject")
+                            .parameter("Bucket.$", "$.Body.DestinationBucket")
+                            .parameter("Key.$", "$.Body.ResultFiles.SUCCEEDED.[0].Key")
+                        )
+                        .output_path("$.[0]")
+                    )
+
             # Continue the traversal from the matching_join.
             _visit(self.graph[node.matching_join], workflow, exit_node)
         # We shouldn't ideally ever get here.
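(For orientation: the sketch below is a hand-written approximation of the Amazon States Language shape this wiring aims to produce when use_distributed_map is enabled. Step names ("train" as the foreach step, "join" as its matching join), the bucket/prefix, and MaxConcurrency of 100 are hypothetical; field placement is approximate — the Map helpers in this file emit the legacy "Iterator" key, which newer ASL documentation calls "ItemProcessor". It is not output copied from the PR.)

import json

# Hand-written illustration; names and values are placeholders.
states = {
    "*train": {
        "Type": "Map",
        "MaxConcurrency": 100,  # stands in for max_workers
        "Iterator": {
            "ProcessorConfig": {"Mode": "DISTRIBUTED", "ExecutionType": "STANDARD"},
            "StartAt": "train",
            "States": {"train": {"Type": "Task", "End": True}},
        },
        # ResultWriter spills the full map output to S3, because a distributed
        # map's output cannot be trimmed inline the way "$.[0]" trims an INLINE map.
        "ResultWriter": {
            "Resource": "arn:aws:states:::s3:putObject",
            "Parameters": {"Bucket": "my-bucket", "Prefix": "sfn-map-output"},
        },
        "OutputPath": "$",
        "Next": "*train_*GetManifest",
    },
    "*train_*GetManifest": {
        # Fetches the manifest that ResultWriter produced and parses its body.
        "Type": "Task",
        "Resource": "arn:aws:states:::aws-sdk:s3:getObject",
        "Parameters": {
            "Bucket.$": "$.ResultWriterDetails.Bucket",
            "Key.$": "$.ResultWriterDetails.Key",
        },
        "ResultSelector": {"Body.$": "States.StringToJson($.Body)"},
        "Next": "*train_*Map",
    },
    "*train_*Map": {
        # Reads a single element of the SUCCEEDED result file and passes it
        # through, so the join receives the same single-element shape that the
        # INLINE map's "$.[0]" output used to provide.
        "Type": "Map",
        "MaxConcurrency": 1000,
        "ItemReader": {
            "Resource": "arn:aws:states:::s3:getObject",
            "ReaderConfig": {"InputType": "JSON", "MaxItems": 1},
            "Parameters": {
                "Bucket.$": "$.Body.DestinationBucket",
                "Key.$": "$.Body.ResultFiles.SUCCEEDED.[0].Key",
            },
        },
        "Iterator": {
            "ProcessorConfig": {"Mode": "DISTRIBUTED", "ExecutionType": "STANDARD"},
            "StartAt": "*train_*Pass",
            "States": {
                "*train_*Pass": {
                    "Type": "Pass",
                    "Parameters": {"Output.$": "States.StringToJson($.Output)"},
                    "OutputPath": "$.Output",
                    "End": True,
                }
            },
        },
        "OutputPath": "$.[0]",
        "Next": "join",
    },
}
print(json.dumps(states, indent=2))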
@@ -444,7 +512,6 @@ def _batch(self, node):
             "metaflow.owner": self.username,
             "metaflow.flow_name": self.flow.name,
             "metaflow.step_name": node.name,
-            "metaflow.run_id.$": "$$.Execution.Name",
             # Unfortunately we can't set the task id here since AWS Step
             # Functions lacks any notion of run-scoped task identifiers. We
             # instead co-opt the AWS Batch job id as the task id. This also
@@ -456,6 +523,10 @@ def _batch(self, node):
             # `$$.State.RetryCount` resolves to an int dynamically and
             # AWS Batch job specification only accepts strings. We handle
             # retries/catch within AWS Batch to get around this limitation.
+            # We also cannot set the run id here, since the run id maps to the
+            # execution name of the AWS Step Functions state machine, which is
+            # different when executing inside a distributed map. Instead, we set
+            # it once in the start step and pass it along for all children to consume.
             "metaflow.version": self.environment.get_environment_info()[
                 "metaflow_version"
             ],
@@ -492,6 +563,12 @@ def _batch(self, node):
             env["METAFLOW_S3_ENDPOINT_URL"] = S3_ENDPOINT_URL
 
         if node.name == "start":
+            # metaflow.run_id maps to the AWS Step Functions state machine
+            # execution name in all cases, except when a task runs within a
+            # for-each construct that relies on Distributed Map. To work around
+            # this, we pass the run id from the start step to all subsequent tasks.
+            attrs["metaflow.run_id.$"] = "$$.Execution.Name"
+
             # Initialize parameters for the flow in the `start` step.
             parameters = self._process_parameters()
             if parameters:
@@ -550,6 +627,8 @@ def _batch(self, node):
             env["METAFLOW_SPLIT_PARENT_TASK_ID"] = (
                 "$.Parameters.split_parent_task_id_%s" % node.split_parents[-1]
             )
+            # Inherit the run id from the parent and pass it along to children.
+            attrs["metaflow.run_id.$"] = "$.Parameters.['metaflow.run_id']"
         else:
             # Set appropriate environment variables for runtime replacement.
             if len(node.in_funcs) == 1:
@@ -558,6 +637,8 @@ def _batch(self, node):
                     % node.in_funcs[0]
                 )
                 env["METAFLOW_PARENT_TASK_ID"] = "$.JobId"
+                # Inherit the run id from the parent and pass it along to children.
+                attrs["metaflow.run_id.$"] = "$.Parameters.['metaflow.run_id']"
             else:
                 # Generate the input paths in a quasi-compressed format.
                 # See util.decompress_list for why this is written the way
@@ -567,6 +648,8 @@ def _batch(self, node):
                     "${METAFLOW_PARENT_%s_TASK_ID}" % (idx, idx)
                     for idx, _ in enumerate(node.in_funcs)
                 )
+                # Inherit the run id from the parent and pass it along to children.
+                attrs["metaflow.run_id.$"] = "$.[0].Parameters.['metaflow.run_id']"
                 for idx, _ in enumerate(node.in_funcs):
                     env["METAFLOW_PARENT_%s_TASK_ID" % idx] = "$.[%s].JobId" % idx
                     env["METAFLOW_PARENT_%s_STEP" % idx] = (
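(To summarize the run-id plumbing added across the hunks above: the start step binds metaflow.run_id once from the top-level execution name, and every downstream task re-binds it from the parameters its parent passes along rather than from $$.Execution.Name, which would resolve to a different, child execution name inside a Distributed Map. The snippet below is only an illustrative sketch of that decision table — the dict describing a step is hypothetical and not the PR's data model.)

# Illustrative only: which JSONPath ends up bound to "metaflow.run_id.$" per step type.
def run_id_binding(step):
    if step["name"] == "start":
        # Bound once from the state machine execution name.
        return "$$.Execution.Name"
    if step["is_foreach_join"]:
        # Foreach joins read the run id forwarded by the split parent.
        return "$.Parameters.['metaflow.run_id']"
    if step["num_parents"] == 1:
        # Linear steps inherit it from their single parent's output Parameters.
        return "$.Parameters.['metaflow.run_id']"
    # Static joins receive a list of parent outputs; every element carries the same run id.
    return "$.[0].Parameters.['metaflow.run_id']"

assert run_id_binding({"name": "start"}) == "$$.Execution.Name"
assert (
    run_id_binding({"name": "end", "is_foreach_join": False, "num_parents": 3})
    == "$.[0].Parameters.['metaflow.run_id']"
)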
@@ -893,6 +976,12 @@ def __init__(self, name):
         tree = lambda: defaultdict(tree)
         self.payload = tree()
 
+    def mode(self, mode):
+        self.payload["ProcessorConfig"] = {"Mode": mode}
+        if mode == "DISTRIBUTED":
+            self.payload["ProcessorConfig"]["ExecutionType"] = "STANDARD"
+        return self
+
     def start_at(self, start_at):
         self.payload["StartAt"] = start_at
         return self
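(A quick check of what the new mode() helper should leave in the workflow payload; the class below is a stripped-down stand-in for illustration, not the real Workflow class.)

from collections import defaultdict

# Minimal stand-in, just enough to exercise mode().
class _Workflow(object):
    def __init__(self):
        tree = lambda: defaultdict(tree)
        self.payload = tree()

    def mode(self, mode):
        self.payload["ProcessorConfig"] = {"Mode": mode}
        if mode == "DISTRIBUTED":
            # Distributed map iterations run as standalone STANDARD executions.
            self.payload["ProcessorConfig"]["ExecutionType"] = "STANDARD"
        return self

print(_Workflow().mode("DISTRIBUTED").payload["ProcessorConfig"])
# {'Mode': 'DISTRIBUTED', 'ExecutionType': 'STANDARD'}
print(_Workflow().mode("INLINE").payload["ProcessorConfig"])
# {'Mode': 'INLINE'}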
@@ -940,10 +1029,18 @@ def result_path(self, result_path):
         self.payload["ResultPath"] = result_path
         return self
 
+    def result_selector(self, name, value):
+        self.payload["ResultSelector"][name] = value
+        return self
+
     def _partition(self):
         # This is needed to support AWS Gov Cloud and AWS CN regions
         return SFN_IAM_ROLE.split(":")[1]
 
+    def retry_strategy(self, retry_strategy):
+        self.payload["Retry"] = [retry_strategy]
+        return self
+
     def batch(self, job):
         self.resource(
             "arn:%s:states:::batch:submitJob.sync" % self._partition()
@@ -963,6 +1060,19 @@ def batch(self, job):
         # tags may not be present in all scenarios
         if "tags" in job.payload:
             self.parameter("Tags", job.payload["tags"])
+        # Set a retry strategy for AWS Batch job submission to account for the
+        # measly 50 jobs/second queue admission limit, which people can run
+        # into very quickly.
+        self.retry_strategy(
+            {
+                "ErrorEquals": ["Batch.AWSBatchException"],
+                "BackoffRate": 2,
+                "IntervalSeconds": 2,
+                "MaxDelaySeconds": 60,
+                "MaxAttempts": 10,
+                "JitterStrategy": "FULL",
+            }
+        )
         return self
 
     def dynamo_db(self, table_name, primary_key, values):
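(For reference, a hand-written sketch of the Retry block the retry_strategy() call above is meant to attach to the batch:submitJob.sync task, with a rough back-of-the-envelope on the backoff schedule. The values mirror the diff; the arithmetic is only an approximation and ignores jitter.)

import json

# Hand-written illustration of the "Retry" field the state payload should carry
# after retry_strategy() runs; numbers mirror the values set in the diff.
retry = [
    {
        "ErrorEquals": ["Batch.AWSBatchException"],
        "BackoffRate": 2,
        "IntervalSeconds": 2,
        "MaxDelaySeconds": 60,
        "MaxAttempts": 10,
        "JitterStrategy": "FULL",
    }
]

# Worst-case wait before retry n is roughly min(2 * 2 ** (n - 1), 60) seconds
# before jitter, so the 10 retries spread over several minutes instead of
# hammering the ~50 jobs/second Batch queue admission limit.
print(json.dumps({"Retry": retry}, indent=2))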
@@ -976,6 +1086,26 @@ def dynamo_db(self, table_name, primary_key, values):
         return self
 
 
+class Pass(object):
+    def __init__(self, name):
+        self.name = name
+        tree = lambda: defaultdict(tree)
+        self.payload = tree()
+        self.payload["Type"] = "Pass"
+
+    def end(self):
+        self.payload["End"] = True
+        return self
+
+    def parameter(self, name, value):
+        self.payload["Parameters"][name] = value
+        return self
+
+    def output_path(self, output_path):
+        self.payload["OutputPath"] = output_path
+        return self
+
+
 class Parallel(object):
     def __init__(self, name):
         self.name = name
@@ -1037,3 +1167,37 @@ def output_path(self, output_path):
     def result_path(self, result_path):
         self.payload["ResultPath"] = result_path
         return self
+
+    def item_reader(self, item_reader):
+        self.payload["ItemReader"] = item_reader.payload
+        return self
+
+    def result_writer(self, bucket, prefix):
+        if bucket is not None and prefix is not None:
+            self.payload["ResultWriter"] = {
+                "Resource": "arn:aws:states:::s3:putObject",
+                "Parameters": {
+                    "Bucket": bucket,
+                    "Prefix": prefix,
+                },
+            }
+        return self
+
+
+class JSONItemReader(object):
+    def __init__(self):
+        tree = lambda: defaultdict(tree)
+        self.payload = tree()
+        self.payload["ReaderConfig"] = {"InputType": "JSON", "MaxItems": 1}
+
+    def resource(self, resource):
+        self.payload["Resource"] = resource
+        return self
+
+    def parameter(self, name, value):
+        self.payload["Parameters"][name] = value
+        return self
+
+    def output_path(self, output_path):
+        self.payload["OutputPath"] = output_path
+        return self
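(The bucket/prefix pair handed to result_writer() comes from the rather dense expression in _visit above. Below, the same expression is run on a hypothetical SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH to show the split and the ResultWriter block it should yield; the trailing + [""] keeps the unpacking safe when the configured path is a bare bucket with no prefix.)

# Illustrative walk-through with a hypothetical S3 path; not code from the PR.
SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH = "s3://my-bucket/metaflow/sfn-map-output"

bucket, prefix = (
    (
        SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH[len("s3://") :]
        if SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH.startswith("s3://")
        else SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH
    ).split("/", 1)
    + [""]  # pad so a bucket-only path still unpacks into (bucket, "")
)[:2]

print(bucket, prefix)  # my-bucket metaflow/sfn-map-output

# The ResultWriter block result_writer(bucket, prefix) would then place in the
# Map state's payload:
result_writer = {
    "Resource": "arn:aws:states:::s3:putObject",
    "Parameters": {"Bucket": bucket, "Prefix": prefix},
}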