Skip to content

Commit 8a418f5

Browse files
authored
Merge pull request #34290: Refactor main.py so that it can be called by runner.run_async().
2 parents d3b1312 + 0aca8f1 commit 8a418f5

File tree

2 files changed: 51 additions and 25 deletions.

sdks/python/apache_beam/yaml/main.py

Lines changed: 46 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -122,37 +122,59 @@ def _fix_xlang_instant_coding():
122122

123123

124124
def run(argv=None):
125+
options, constructor, display_data = build_pipeline_components_from_argv(argv)
126+
with _fix_xlang_instant_coding():
127+
with beam.Pipeline(options=options, display_data=display_data) as p:
128+
print('Building pipeline...')
129+
constructor(p)
130+
print('Running pipeline...')
131+
132+
133+
def build_pipeline_components_from_argv(argv):
125134
argv = _preparse_jinja_flags(argv)
126135
known_args, pipeline_args = _parse_arguments(argv)
127136
pipeline_template = _pipeline_spec_from_args(known_args)
128137
pipeline_yaml = yaml_transform.expand_jinja(
129138
pipeline_template, known_args.jinja_variables or {})
139+
display_data = {
140+
'yaml': pipeline_yaml,
141+
'yaml_jinja_template': pipeline_template,
142+
'yaml_jinja_variables': json.dumps(known_args.jinja_variables),
143+
}
144+
options, constructor = build_pipeline_components_from_yaml(
145+
pipeline_yaml,
146+
pipeline_args,
147+
known_args.json_schema_validation,
148+
known_args.yaml_pipeline_file,
149+
)
150+
return options, constructor, display_data
151+
152+
153+
def build_pipeline_components_from_yaml(
154+
pipeline_yaml, pipeline_args, validate_schema='generic', pipeline_path=''):
130155
pipeline_spec = yaml.load(pipeline_yaml, Loader=yaml_transform.SafeLineLoader)
131156

132-
with _fix_xlang_instant_coding():
133-
with beam.Pipeline( # linebreak for better yapf formatting
134-
options=beam.options.pipeline_options.PipelineOptions(
135-
pipeline_args,
136-
pickle_library='cloudpickle',
137-
**yaml_transform.SafeLineLoader.strip_metadata(pipeline_spec.get(
138-
'options', {}))),
139-
display_data={'yaml': pipeline_yaml,
140-
'yaml_jinja_template': pipeline_template,
141-
'yaml_jinja_variables': json.dumps(
142-
known_args.jinja_variables)}) as p:
143-
print("Building pipeline...")
144-
if 'resource_hints' in pipeline_spec.get('pipeline', {}):
145-
# Add the declared resource hints to the "root" spec.
146-
p._current_transform().resource_hints.update(
147-
resources.parse_resource_hints(
148-
yaml_transform.SafeLineLoader.strip_metadata(
149-
pipeline_spec['pipeline']['resource_hints'])))
150-
yaml_transform.expand_pipeline(
151-
p,
152-
pipeline_spec,
153-
validate_schema=known_args.json_schema_validation,
154-
pipeline_path=known_args.yaml_pipeline_file)
155-
print("Running pipeline...")
157+
options = beam.options.pipeline_options.PipelineOptions(
158+
pipeline_args,
159+
pickle_library='cloudpickle',
160+
**yaml_transform.SafeLineLoader.strip_metadata(
161+
pipeline_spec.get('options', {})))
162+
163+
def constructor(root):
164+
if 'resource_hints' in pipeline_spec.get('pipeline', {}):
165+
# Add the declared resource hints to the "root" spec.
166+
root._current_transform().resource_hints.update(
167+
resources.parse_resource_hints(
168+
yaml_transform.SafeLineLoader.strip_metadata(
169+
pipeline_spec['pipeline']['resource_hints'])))
170+
yaml_transform.expand_pipeline(
171+
root,
172+
pipeline_spec,
173+
validate_schema=validate_schema,
174+
pipeline_path=pipeline_path,
175+
)
176+
177+
return options, constructor
156178

157179

158180
if __name__ == '__main__':

sdks/python/apache_beam/yaml/yaml_transform.py

Lines changed: 5 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1075,6 +1075,10 @@ def expand_pipeline(
10751075
providers=None,
10761076
validate_schema='generic' if jsonschema is not None else None,
10771077
pipeline_path=''):
1078+
if isinstance(pipeline, beam.pvalue.PBegin):
1079+
root = pipeline
1080+
else:
1081+
root = beam.pvalue.PBegin(pipeline)
10781082
if isinstance(pipeline_spec, str):
10791083
pipeline_spec = yaml.load(pipeline_spec, Loader=SafeLineLoader)
10801084
# TODO(robertwb): It's unclear whether this gives as good of errors, but
@@ -1087,4 +1091,4 @@ def expand_pipeline(
10871091
yaml_provider.merge_providers(
10881092
yaml_provider.parse_providers(
10891093
pipeline_path, pipeline_spec.get('providers', [])),
1090-
providers or {})).expand(beam.pvalue.PBegin(pipeline))
1094+
providers or {})).expand(root)

0 commit comments

Comments (0)