@@ -15,6 +15,7 @@
     SFN_DYNAMO_DB_TABLE,
     SFN_EXECUTION_LOG_GROUP_ARN,
     SFN_IAM_ROLE,
+    SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH,
 )
 from metaflow.parameters import deploy_time_eval
 from metaflow.util import dict_to_cli_options, to_pascalcase
@@ -369,7 +370,11 @@ def _visit(node, workflow, exit_node=None):
                     .parameter("SplitParentTaskId.$", "$.JobId")
                     .parameter("Parameters.$", "$.Parameters")
                     .parameter("Index.$", "$$.Map.Item.Value")
-                    .next(node.matching_join)
+                    .next(
+                        "%s_*GetManifest" % iterator_name
+                        if self.use_distributed_map
+                        else node.matching_join
+                    )
                     .iterator(
                         _visit(
                             self.graph[node.out_funcs[0]],
@@ -382,8 +387,54 @@ def _visit(node, workflow, exit_node=None):
                         )
                     )
                     .max_concurrency(self.max_workers)
-                    .output_path("$.[0]")
+                    # AWS Step Functions DistributedMap currently does not allow
+                    # us to subset the output of the for-each to just a single
+                    # element. We have to rely on a rather terrible hack and
+                    # resort to using ResultWriter to write the state to Amazon
+                    # S3 and process it in another task. But, well, what can we
+                    # do...
+                    .result_writer(
+                        *(
+                            SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH.rsplit("/", 1)
+                            if self.use_distributed_map
+                            else ()
+                        )
+                    )
+                    .output_path("$" if self.use_distributed_map else "$.[0]")
                 )
+                if self.use_distributed_map:
+                    workflow.add_state(
+                        State("%s_*GetManifest" % iterator_name)
+                        .resource("arn:aws:states:::aws-sdk:s3:getObject")
+                        .parameter("Bucket.$", "$.ResultWriterDetails.Bucket")
+                        .parameter("Key.$", "$.ResultWriterDetails.Key")
+                        .next("%s_*Map" % iterator_name)
+                        .result_selector("Body.$", "States.StringToJson($.Body)")
+                    )
+                    workflow.add_state(
+                        Map("%s_*Map" % iterator_name)
+                        .iterator(
+                            Workflow("%s_*PassWorkflow" % iterator_name)
+                            .mode("DISTRIBUTED")
+                            .start_at("%s_*Pass" % iterator_name)
+                            .add_state(
+                                Pass("%s_*Pass" % iterator_name)
+                                .end()
+                                .parameter("Output.$", "States.StringToJson($.Output)")
+                                .output_path("$.Output")
+                            )
+                        )
+                        .next(node.matching_join)
+                        .max_concurrency(1000)
+                        .item_reader(
+                            JSONItemReader()
+                            .resource("arn:aws:states:::s3:getObject")
+                            .parameter("Bucket.$", "$.Body.DestinationBucket")
+                            .parameter("Key.$", "$.Body.ResultFiles.SUCCEEDED.[0].Key")
+                        )
+                        .output_path("$.[0]")
+                    )
+
                 # Continue the traversal from the matching_join.
                 _visit(self.graph[node.matching_join], workflow, exit_node)
                 # We shouldn't ideally ever get here.
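Note on the block above: with `use_distributed_map`, the foreach Map can no longer hand the join just a single element, so the data flow becomes outer Map → ResultWriter (S3) → `*GetManifest` → inner `*Map`/`*Pass` → join. A rough sketch of the payloads being passed around, assuming the documented Distributed Map ResultWriter manifest layout (all concrete values below are hypothetical):

```python
# Output of the outer foreach Map once a ResultWriter is attached; *GetManifest
# reads exactly these two fields via $.ResultWriterDetails.Bucket / .Key.
map_output = {
    "MapRunArn": "arn:aws:states:us-west-2:123456789012:mapRun/...",  # hypothetical
    "ResultWriterDetails": {
        "Bucket": "my-metaflow-bucket",                    # hypothetical
        "Key": "metaflow/sfn/abc123/manifest.json",        # hypothetical
    },
}

# Body of that manifest.json after States.StringToJson($.Body). The inner *Map's
# ItemReader points at ResultFiles.SUCCEEDED.[0].Key with MaxItems=1, so the join
# still receives a single element, mirroring the old .output_path("$.[0]").
manifest_body = {
    "DestinationBucket": "my-metaflow-bucket",
    "ResultFiles": {
        "FAILED": [],
        "PENDING": [],
        "SUCCEEDED": [{"Key": "metaflow/sfn/abc123/SUCCEEDED_0.json", "Size": 4096}],
    },
}
```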
@@ -508,7 +559,6 @@ def _batch(self, node):
             # Distributed Map. To work around this issue, we pass the run id from the
             # start step to all subsequent tasks.
             attrs["metaflow.run_id.$"] = "$$.Execution.Name"
-            attrs["run_id.$"] = "$$.Execution.Name"

             # Initialize parameters for the flow in the `start` step.
             parameters = self._process_parameters()
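For context on the value kept above: `$$` addresses the Step Functions context object, so `$$.Execution.Name` resolves to the state machine execution name at runtime and doubles as the Metaflow run id. A minimal sketch (execution name is hypothetical):

```python
# What goes into the start step's Parameters...
attrs = {"metaflow.run_id.$": "$$.Execution.Name"}
# ...and roughly what a child task sees after Step Functions resolves it:
resolved = {"Parameters": {"metaflow.run_id": "mapstate-5f7a9b"}}  # hypothetical
# Downstream states read it back with "$.Parameters.['metaflow.run_id']" (see the
# hunks below), which is why the duplicate plain "run_id" attribute is dropped.
```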
@@ -569,8 +619,7 @@ def _batch(self, node):
                     "$.Parameters.split_parent_task_id_%s" % node.split_parents[-1]
                 )
                 # Inherit the run id from the parent and pass it along to children.
-                attrs["metaflow.run_id.$"] = "$.Parameters.run_id"
-                attrs["run_id.$"] = "$.Parameters.run_id"
+                attrs["metaflow.run_id.$"] = "$.Parameters.['metaflow.run_id']"
             else:
                 # Set appropriate environment variables for runtime replacement.
                 if len(node.in_funcs) == 1:
@@ -580,8 +629,7 @@ def _batch(self, node):
                     )
                     env["METAFLOW_PARENT_TASK_ID"] = "$.JobId"
                     # Inherit the run id from the parent and pass it along to children.
-                    attrs["metaflow.run_id.$"] = "$.Parameters.run_id"
-                    attrs["run_id.$"] = "$.Parameters.run_id"
+                    attrs["metaflow.run_id.$"] = "$.Parameters.['metaflow.run_id']"
                 else:
                     # Generate the input paths in a quasi-compressed format.
                     # See util.decompress_list for why this is written the way
@@ -592,8 +640,7 @@ def _batch(self, node):
                         for idx, _ in enumerate(node.in_funcs)
                     )
                     # Inherit the run id from the parent and pass it along to children.
-                    attrs["metaflow.run_id.$"] = "$.[0].Parameters.run_id"
-                    attrs["run_id.$"] = "$.[0].Parameters.run_id"
+                    attrs["metaflow.run_id.$"] = "$.[0].Parameters.['metaflow.run_id']"
                     for idx, _ in enumerate(node.in_funcs):
                         env["METAFLOW_PARENT_%s_TASK_ID" % idx] = "$.[%s].JobId" % idx
                         env["METAFLOW_PARENT_%s_STEP" % idx] = (
@@ -973,6 +1020,10 @@ def result_path(self, result_path):
         self.payload["ResultPath"] = result_path
         return self

+    def result_selector(self, name, value):
+        self.payload["ResultSelector"][name] = value
+        return self
+
     def _partition(self):
         # This is needed to support AWS Gov Cloud and AWS CN regions
         return SFN_IAM_ROLE.split(":")[1]
@@ -1026,6 +1077,26 @@ def dynamo_db(self, table_name, primary_key, values):
         return self


+class Pass(object):
+    def __init__(self, name):
+        self.name = name
+        tree = lambda: defaultdict(tree)
+        self.payload = tree()
+        self.payload["Type"] = "Pass"
+
+    def end(self):
+        self.payload["End"] = True
+        return self
+
+    def parameter(self, name, value):
+        self.payload["Parameters"][name] = value
+        return self
+
+    def output_path(self, output_path):
+        self.payload["OutputPath"] = output_path
+        return self
+
+
 class Parallel(object):
     def __init__(self, name):
         self.name = name
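For orientation, the `Pass` builder above just accumulates a plain ASL Pass state in its defaultdict payload. Chained the way the `*Pass` state is built earlier in this diff, it comes out roughly as (the state name is hypothetical):

```python
# Pass("foo_*Pass").end()
#     .parameter("Output.$", "States.StringToJson($.Output)")
#     .output_path("$.Output")
# accumulates approximately this payload:
{
    "Type": "Pass",
    "End": True,
    "Parameters": {"Output.$": "States.StringToJson($.Output)"},
    "OutputPath": "$.Output",
}
```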
@@ -1087,3 +1158,37 @@ def output_path(self, output_path):
     def result_path(self, result_path):
         self.payload["ResultPath"] = result_path
         return self
+
+    def item_reader(self, item_reader):
+        self.payload["ItemReader"] = item_reader.payload
+        return self
+
+    def result_writer(self, bucket=None, prefix=None):
+        if bucket is not None and prefix is not None:
+            self.payload["ResultWriter"] = {
+                "Resource": "arn:aws:states:::s3:putObject",
+                "Parameters": {
+                    "Bucket": bucket,
+                    "Prefix": prefix,
+                },
+            }
+        return self
+
+
+class JSONItemReader(object):
+    def __init__(self):
+        tree = lambda: defaultdict(tree)
+        self.payload = tree()
+        self.payload["ReaderConfig"] = {"InputType": "JSON", "MaxItems": 1}
+
+    def resource(self, resource):
+        self.payload["Resource"] = resource
+        return self
+
+    def parameter(self, name, value):
+        self.payload["Parameters"][name] = value
+        return self
+
+    def output_path(self, output_path):
+        self.payload["OutputPath"] = output_path
+        return self
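Similarly, the new Map helpers above serialize into the ASL `ItemReader`/`ResultWriter` blocks. Wired up as in the foreach handling earlier in this diff, they come out roughly as follows (the bucket and prefix values are hypothetical, standing in for `SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH.rsplit("/", 1)`):

```python
# ItemReader produced by JSONItemReader().resource(...).parameter(...):
item_reader = {
    "ReaderConfig": {"InputType": "JSON", "MaxItems": 1},
    "Resource": "arn:aws:states:::s3:getObject",
    "Parameters": {
        "Bucket.$": "$.Body.DestinationBucket",
        "Key.$": "$.Body.ResultFiles.SUCCEEDED.[0].Key",
    },
}

# ResultWriter attached to the outer foreach Map when use_distributed_map is set:
result_writer = {
    "Resource": "arn:aws:states:::s3:putObject",
    "Parameters": {"Bucket": "my-metaflow-bucket", "Prefix": "metaflow/sfn-map-output"},
}
```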