
Commit 8b064a9

add s3 path
1 parent b7df18a commit 8b064a9

3 files changed (+122 -10)

metaflow/metaflow_config.py (+7 -1)
@@ -268,7 +268,13 @@
 # machine execution logs. This needs to be available when using the
 # `step-functions create --log-execution-history` command.
 SFN_EXECUTION_LOG_GROUP_ARN = from_conf("SFN_EXECUTION_LOG_GROUP_ARN")
-
+# Amazon S3 path for storing the results of AWS Step Functions Distributed Map
+SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH = from_conf(
+    "SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH",
+    os.path.join(DATASTORE_SYSROOT_S3, "sfn_distributed_map_output")
+    if DATASTORE_SYSROOT_S3
+    else None,
+)
 ###
 # Kubernetes configuration
 ###
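
For context, the new setting simply nests the Distributed Map output under the S3 datastore root when one is configured. A minimal sketch of how the default expression resolves, using a hypothetical datastore root rather than Metaflow's own config machinery:

import os

# Hypothetical datastore root; in a real deployment this would come from the
# Metaflow configuration (e.g. METAFLOW_DATASTORE_SYSROOT_S3), not be hard-coded.
DATASTORE_SYSROOT_S3 = "s3://my-bucket/metaflow"

# Mirrors the default expression added above: place the Distributed Map output
# next to the datastore root, or leave it unset when no S3 datastore is configured.
default_path = (
    os.path.join(DATASTORE_SYSROOT_S3, "sfn_distributed_map_output")
    if DATASTORE_SYSROOT_S3
    else None
)
print(default_path)  # s3://my-bucket/metaflow/sfn_distributed_map_output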

metaflow/plugins/aws/step_functions/dynamo_db_client.py (+1 -0)
@@ -1,5 +1,6 @@
 import os
 import time
+
 import requests

 from metaflow.metaflow_config import SFN_DYNAMO_DB_TABLE

metaflow/plugins/aws/step_functions/step_functions.py (+114 -9)
@@ -15,6 +15,7 @@
     SFN_DYNAMO_DB_TABLE,
     SFN_EXECUTION_LOG_GROUP_ARN,
     SFN_IAM_ROLE,
+    SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH,
 )
 from metaflow.parameters import deploy_time_eval
 from metaflow.util import dict_to_cli_options, to_pascalcase
@@ -369,7 +370,11 @@ def _visit(node, workflow, exit_node=None):
                 .parameter("SplitParentTaskId.$", "$.JobId")
                 .parameter("Parameters.$", "$.Parameters")
                 .parameter("Index.$", "$$.Map.Item.Value")
-                .next(node.matching_join)
+                .next(
+                    "%s_*GetManifest" % iterator_name
+                    if self.use_distributed_map
+                    else node.matching_join
+                )
                 .iterator(
                     _visit(
                         self.graph[node.out_funcs[0]],
@@ -382,8 +387,54 @@ def _visit(node, workflow, exit_node=None):
                     )
                 )
                 .max_concurrency(self.max_workers)
-                .output_path("$.[0]")
+                # AWS Step Functions has a shortcoming for DistributedMap at the
+                # moment that does not allow us to subset the output of for-each
+                # to just a single element. We have to rely on a rather terrible
+                # hack and resort to using ResultWriter to write the state to
+                # Amazon S3 and process it in another task. But, well what can we
+                # do...
+                .result_writer(
+                    *(
+                        SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH.rsplit("/", 1)
+                        if self.use_distributed_map
+                        else ()
+                    )
+                )
+                .output_path("$" if self.use_distributed_map else "$.[0]")
             )
+            if self.use_distributed_map:
+                workflow.add_state(
+                    State("%s_*GetManifest" % iterator_name)
+                    .resource("arn:aws:states:::aws-sdk:s3:getObject")
+                    .parameter("Bucket.$", "$.ResultWriterDetails.Bucket")
+                    .parameter("Key.$", "$.ResultWriterDetails.Key")
+                    .next("%s_*Map" % iterator_name)
+                    .result_selector("Body.$", "States.StringToJson($.Body)")
+                )
+                workflow.add_state(
+                    Map("%s_*Map" % iterator_name)
+                    .iterator(
+                        Workflow("%s_*PassWorkflow" % iterator_name)
+                        .mode("DISTRIBUTED")
+                        .start_at("%s_*Pass" % iterator_name)
+                        .add_state(
+                            Pass("%s_*Pass" % iterator_name)
+                            .end()
+                            .parameter("Output.$", "States.StringToJson($.Output)")
+                            .output_path("$.Output")
+                        )
+                    )
+                    .next(node.matching_join)
+                    .max_concurrency(1000)
+                    .item_reader(
+                        JSONItemReader()
+                        .resource("arn:aws:states:::s3:getObject")
+                        .parameter("Bucket.$", "$.Body.DestinationBucket")
+                        .parameter("Key.$", "$.Body.ResultFiles.SUCCEEDED.[0].Key")
+                    )
+                    .output_path("$.[0]")
+                )
+
             # Continue the traversal from the matching_join.
             _visit(self.graph[node.matching_join], workflow, exit_node)
             # We shouldn't ideally ever get here.
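
For orientation, the workaround above appears to work like this: when Distributed Map is enabled, the foreach Map's ResultWriter dumps the child results (a manifest plus result files) to the configured S3 path, a follow-up `*GetManifest` task fetches that manifest, and a second, inner Map uses an ItemReader pointed at the first SUCCEEDED result file so a Pass state can unwrap each child's Output before control reaches the join. A rough sketch, not generated output, of the Amazon States Language the `*GetManifest` builder chain is meant to produce; the foreach step name `train` is a hypothetical example:

# Approximate ASL for the "train_*GetManifest" task built above; field values
# follow the builder calls in the diff, the state name is made up for illustration.
get_manifest = {
    "Type": "Task",
    "Resource": "arn:aws:states:::aws-sdk:s3:getObject",
    "Parameters": {
        "Bucket.$": "$.ResultWriterDetails.Bucket",
        "Key.$": "$.ResultWriterDetails.Key",
    },
    # The manifest body comes back as a string, so it is parsed into JSON here.
    "ResultSelector": {"Body.$": "States.StringToJson($.Body)"},
    "Next": "train_*Map",
}
print(get_manifest)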
@@ -508,7 +559,6 @@ def _batch(self, node):
             # Distributed Map. To work around this issue, we pass the run id from the
             # start step to all subsequent tasks.
             attrs["metaflow.run_id.$"] = "$$.Execution.Name"
-            attrs["run_id.$"] = "$$.Execution.Name"

             # Initialize parameters for the flow in the `start` step.
             parameters = self._process_parameters()
@@ -569,8 +619,7 @@ def _batch(self, node):
                     "$.Parameters.split_parent_task_id_%s" % node.split_parents[-1]
                 )
                 # Inherit the run id from the parent and pass it along to children.
-                attrs["metaflow.run_id.$"] = "$.Parameters.run_id"
-                attrs["run_id.$"] = "$.Parameters.run_id"
+                attrs["metaflow.run_id.$"] = "$.Parameters.['metaflow.run_id']"
             else:
                 # Set appropriate environment variables for runtime replacement.
                 if len(node.in_funcs) == 1:
@@ -580,8 +629,7 @@ def _batch(self, node):
                     )
                     env["METAFLOW_PARENT_TASK_ID"] = "$.JobId"
                     # Inherit the run id from the parent and pass it along to children.
-                    attrs["metaflow.run_id.$"] = "$.Parameters.run_id"
-                    attrs["run_id.$"] = "$.Parameters.run_id"
+                    attrs["metaflow.run_id.$"] = "$.Parameters.['metaflow.run_id']"
                 else:
                     # Generate the input paths in a quasi-compressed format.
                     # See util.decompress_list for why this is written the way
@@ -592,8 +640,7 @@ def _batch(self, node):
                         for idx, _ in enumerate(node.in_funcs)
                     )
                     # Inherit the run id from the parent and pass it along to children.
-                    attrs["metaflow.run_id.$"] = "$.[0].Parameters.run_id"
-                    attrs["run_id.$"] = "$.[0].Parameters.run_id"
+                    attrs["metaflow.run_id.$"] = "$.[0].Parameters.['metaflow.run_id']"
                     for idx, _ in enumerate(node.in_funcs):
                         env["METAFLOW_PARENT_%s_TASK_ID" % idx] = "$.[%s].JobId" % idx
                         env["METAFLOW_PARENT_%s_STEP" % idx] = (
@@ -973,6 +1020,10 @@ def result_path(self, result_path):
         self.payload["ResultPath"] = result_path
         return self

+    def result_selector(self, name, value):
+        self.payload["ResultSelector"][name] = value
+        return self
+
     def _partition(self):
         # This is needed to support AWS Gov Cloud and AWS CN regions
         return SFN_IAM_ROLE.split(":")[1]
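
A brief usage sketch of the new `result_selector` helper, assuming this commit is applied and using the existing `State` builder from the same module; the state name is a hypothetical example:

from metaflow.plugins.aws.step_functions.step_functions import State

# Attach a ResultSelector entry, as the *GetManifest state does above.
state = State("train_*GetManifest").result_selector(
    "Body.$", "States.StringToJson($.Body)"
)
# state.payload["ResultSelector"] now holds
# {"Body.$": "States.StringToJson($.Body)"}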
@@ -1026,6 +1077,26 @@ def dynamo_db(self, table_name, primary_key, values):
         return self


+class Pass(object):
+    def __init__(self, name):
+        self.name = name
+        tree = lambda: defaultdict(tree)
+        self.payload = tree()
+        self.payload["Type"] = "Pass"
+
+    def end(self):
+        self.payload["End"] = True
+        return self
+
+    def parameter(self, name, value):
+        self.payload["Parameters"][name] = value
+        return self
+
+    def output_path(self, output_path):
+        self.payload["OutputPath"] = output_path
+        return self
+
+
 class Parallel(object):
     def __init__(self, name):
         self.name = name
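
A quick sketch of how the new `Pass` builder is used by the inner map above, assuming the commit is applied; the state name is hypothetical, and the payload comment shows the Amazon States Language fragment it accumulates (stored internally as a nested defaultdict):

from metaflow.plugins.aws.step_functions.step_functions import Pass

pass_state = (
    Pass("train_*Pass")
    .parameter("Output.$", "States.StringToJson($.Output)")
    .output_path("$.Output")
    .end()
)
# pass_state.payload resembles:
# {"Type": "Pass",
#  "Parameters": {"Output.$": "States.StringToJson($.Output)"},
#  "OutputPath": "$.Output",
#  "End": True}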
@@ -1087,3 +1158,37 @@ def output_path(self, output_path):
     def result_path(self, result_path):
         self.payload["ResultPath"] = result_path
         return self
+
+    def item_reader(self, item_reader):
+        self.payload["ItemReader"] = item_reader.payload
+        return self
+
+    def result_writer(self, bucket=None, prefix=None):
+        if bucket is not None and prefix is not None:
+            self.payload["ResultWriter"] = {
+                "Resource": "arn:aws:states:::s3:putObject",
+                "Parameters": {
+                    "Bucket": bucket,
+                    "Prefix": prefix,
+                },
+            }
+        return self
+
+
+class JSONItemReader(object):
+    def __init__(self):
+        tree = lambda: defaultdict(tree)
+        self.payload = tree()
+        self.payload["ReaderConfig"] = {"InputType": "JSON", "MaxItems": 1}
+
+    def resource(self, resource):
+        self.payload["Resource"] = resource
+        return self
+
+    def parameter(self, name, value):
+        self.payload["Parameters"][name] = value
+        return self
+
+    def output_path(self, output_path):
+        self.payload["OutputPath"] = output_path
+        return self
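
Finally, a minimal sketch of what the `rsplit("/", 1)` splat in the foreach chain hands to `result_writer`, and the ResultWriter block it produces; the configured S3 path below is a hypothetical value:

# Hypothetical configured value of SFN_S3_DISTRIBUTED_MAP_OUTPUT_PATH.
output_path = "s3://my-bucket/metaflow/sfn_distributed_map_output"

# The foreach chain above calls .result_writer(*output_path.rsplit("/", 1)).
bucket, prefix = output_path.rsplit("/", 1)
result_writer = {
    "Resource": "arn:aws:states:::s3:putObject",
    "Parameters": {"Bucket": bucket, "Prefix": prefix},
}
print(bucket)  # s3://my-bucket/metaflow
print(prefix)  # sfn_distributed_map_output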
