Skip to content

Commit fbe997c

Browse files
committed
Add --create_test option to yaml main.py.
1 parent 5b6f993 commit fbe997c

File tree

3 files changed

+86
-35
lines changed

3 files changed

+86
-35
lines changed

sdks/python/apache_beam/yaml/main.py

+58-33
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import argparse
1919
import contextlib
2020
import json
21+
import os
2122
import sys
2223
import unittest
2324

@@ -103,6 +104,12 @@ def _parse_arguments(argv):
103104
'--fix_tests',
104105
action=argparse.BooleanOptionalAction,
105106
help='Update failing test expectations to match the actual ouput.')
107+
parser.add_argument(
108+
'--create_test',
109+
action=argparse.BooleanOptionalAction,
110+
help='Automatically creates a regression test for the given pipeline, '
111+
'adding it to the pipeline spec or test suite dependon on whether '
112+
'--test_suite is given.')
106113
parser.add_argument(
107114
'--test_suite',
108115
help='Run the given tests against the given pipeline, rather than the '
@@ -154,51 +161,69 @@ def run_tests(argv=None, exit=True):
154161
pipeline_spec = yaml.load(pipeline_yaml, Loader=yaml_transform.SafeLineLoader)
155162
options = _build_pipeline_options(pipeline_spec, pipeline_args)
156163

157-
if known_args.test_suite:
158-
with open(known_args.test_suite) as fin:
159-
test_suite = yaml.load(fin, Loader=yaml_transform.SafeLineLoader)
160-
if 'tests' not in test_suite or not isinstance(test_suite['tests'], list):
161-
raise TypeError('tests attribute must be a list of test specifications')
162-
test_specs = test_suite['tests']
164+
if known_args.create_test and not known_args.fix_tests:
165+
result = unittest.TestResult()
163166
else:
164-
test_specs = pipeline_spec.get('tests', [])
165-
if not isinstance(test_specs, list):
166-
raise TypeError('tests attribute must be a list of test specifications')
167-
if not test_specs:
168-
raise RuntimeError('No tests found.')
169-
170-
with _fix_xlang_instant_coding():
171-
tests = [
172-
_YamlTestCase(pipeline_spec, test_spec, options, known_args.fix_tests)
173-
for test_spec in test_specs
174-
]
175-
suite = unittest.TestSuite(tests)
176-
result = unittest.TextTestRunner().run(suite)
177-
178-
if known_args.fix_tests:
167+
if known_args.test_suite:
168+
with open(known_args.test_suite) as fin:
169+
test_suite = yaml.load(fin, Loader=yaml_transform.SafeLineLoader) or {}
170+
if 'tests' not in test_suite or not isinstance(test_suite['tests'], list):
171+
raise TypeError('tests attribute must be a list of test specifications')
172+
test_specs = test_suite['tests']
173+
else:
174+
test_specs = pipeline_spec.get('tests', [])
175+
if not isinstance(test_specs, list):
176+
raise TypeError('tests attribute must be a list of test specifications')
177+
if not test_specs:
178+
raise RuntimeError('No tests found.')
179+
180+
with _fix_xlang_instant_coding():
181+
tests = [
182+
_YamlTestCase(
183+
pipeline_spec, test_spec, options, known_args.fix_tests)
184+
for test_spec in test_specs
185+
]
186+
suite = unittest.TestSuite(tests)
187+
result = unittest.TextTestRunner().run(suite)
188+
189+
if known_args.fix_tests or known_args.create_test:
179190
if known_args.test_suite:
180191
path = known_args.test_suite
192+
if not os.path.exists(path) and known_args.create_test:
193+
with open(path, 'w') as fout:
194+
fout.write('tests: []')
181195
elif known_args.yaml_pipeline_file:
182196
path = known_args.yaml_pipeline_file
183197
else:
184198
raise RuntimeError('Test fixing only supported for file-backed tests.')
185199
with open(path) as fin:
186200
original_yaml = fin.read()
187201
if path == known_args.yaml_pipeline_file and pipeline_yaml == content:
188-
raise RuntimeError('In-file test fixing not yet supported for templated pipelines.')
189-
updated_spec = yaml.load(original_yaml, Loader=yaml.SafeLoader)
190-
191-
for ix, test in enumerate(tests):
192-
if test.fixes:
193-
test_spec = yaml_transform.SafeLineLoader.strip_metadata(test.spec())
194-
assert test_spec == updated_spec['tests'][ix]
195-
for (loc, name), values in test.fixes.items():
196-
for expectation in updated_spec['tests'][ix][loc]:
197-
if expectation['name'] == name:
198-
expectation['elements'] = sorted(values, key=json.dumps)
199-
break
202+
raise RuntimeError(
203+
'In-file test fixing not yet supported for templated pipelines.')
204+
updated_spec = yaml.load(original_yaml, Loader=yaml.SafeLoader) or {}
205+
206+
if known_args.fix_tests:
207+
for ix, test in enumerate(tests):
208+
if test.fixes:
209+
test_spec = yaml_transform.SafeLineLoader.strip_metadata(test.spec())
210+
assert test_spec == updated_spec['tests'][ix]
211+
for (loc, name), values in test.fixes.items():
212+
for expectation in updated_spec['tests'][ix][loc]:
213+
if expectation['name'] == name:
214+
expectation['elements'] = sorted(values, key=json.dumps)
215+
break
216+
217+
if known_args.create_test:
218+
if 'tests' not in updated_spec:
219+
updated_spec['tests'] = []
220+
updated_spec['tests'].append(
221+
yaml_testing.create_test(pipeline_spec, options))
200222

201223
updated_yaml = yaml_utils.patch_yaml(original_yaml, updated_spec)
224+
import pprint
225+
pprint.pprint(updated_spec, sort_dicts=False)
226+
print(updated_yaml)
202227
with open(path, 'w') as fout:
203228
fout.write(updated_yaml)
204229

sdks/python/apache_beam/yaml/main_test.py

+26
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,32 @@ def test_fix_suite(self):
200200
with open(test_suite) as fin:
201201
self.assertEqual(fin.read(), PASSING_TEST_SUITE)
202202

203+
def test_create_test(self):
204+
with tempfile.TemporaryDirectory() as tmpdir:
205+
test_suite = os.path.join(tmpdir, 'tests.yaml')
206+
with open(test_suite, 'w') as fout:
207+
fout.write('')
208+
209+
main.run_tests([
210+
'--yaml_pipeline',
211+
TEST_PIPELINE.replace('ELEMENT', 'x'),
212+
'--test_suite',
213+
test_suite,
214+
'--create_test'
215+
],
216+
exit=False)
217+
218+
with open(test_suite) as fin:
219+
self.assertEqual(fin.read(), '''
220+
tests:
221+
- mock_outputs: []
222+
expected_inputs:
223+
- name: WriteToText
224+
elements:
225+
- element: x
226+
'''.lstrip())
227+
228+
203229
if __name__ == '__main__':
204230
logging.getLogger().setLevel(logging.INFO)
205231
unittest.main()

sdks/python/apache_beam/yaml/yaml_testing.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -347,12 +347,12 @@ def create_test(
347347

348348
def get_name(transform):
349349
if 'name' in transform:
350-
return transform['name']
350+
return str(transform['name'])
351351
else:
352352
if sum(1 for t in transform_spec['transforms']
353353
if t['type'] == transform['type']) > 1:
354354
raise ValueError('Ambiguous unnamed transform {transform["type"]}')
355-
return transform['type']
355+
return str(transform['type'])
356356

357357
input_transforms = [
358358
t for t in transform_spec['transforms'] if t['type'] != 'Create' and

0 commit comments

Comments
 (0)