Skip to content

Commit 902f75b

Browse files
committed
Add --fix-tests option to more easily generate YAML tests.
1 parent 5cffe5b commit 902f75b

File tree

4 files changed

+191
-28
lines changed

4 files changed

+191
-28
lines changed

sdks/python/apache_beam/yaml/main.py

+52-12
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
from apache_beam.typehints.schemas import MillisInstant
3131
from apache_beam.yaml import yaml_testing
3232
from apache_beam.yaml import yaml_transform
33+
from apache_beam.yaml import yaml_utils
3334

3435

3536
def _preparse_jinja_flags(argv):
@@ -98,6 +99,10 @@ def _parse_arguments(argv):
9899
action=argparse.BooleanOptionalAction,
99100
help='Run the tests associated with the given pipeline, rather than the '
100101
'pipeline itself.')
102+
parser.add_argument(
103+
'--fix_tests',
104+
action=argparse.BooleanOptionalAction,
105+
help='Update failing test expectations to match the actual ouput.')
101106
parser.add_argument(
102107
'--test_suite',
103108
help='Run the given tests against the given pipeline, rather than the '
@@ -149,24 +154,54 @@ def run_tests(argv=None, exit=True):
149154
pipeline_spec = yaml.load(pipeline_yaml, Loader=yaml_transform.SafeLineLoader)
150155
options = _build_pipeline_options(pipeline_spec, pipeline_args)
151156

152-
test_specs = pipeline_spec.get('tests', [])
153-
if not isinstance(test_specs, list):
154-
raise TypeError('tests attribute must be a list of test specifications')
155157
if known_args.test_suite:
156158
with open(known_args.test_suite) as fin:
157-
more_test_specs = yaml.load(fin, Loader=yaml_transform.SafeLineLoader)
158-
if 'tests' not in more_test_specs or not isinstance(
159-
more_test_specs['tests'], list):
159+
test_suite = yaml.load(fin, Loader=yaml_transform.SafeLineLoader)
160+
if 'tests' not in test_suite or not isinstance(test_suite['tests'], list):
161+
raise TypeError('tests attribute must be a list of test specifications')
162+
test_specs = test_suite['tests']
163+
else:
164+
test_specs = pipeline_spec.get('tests', [])
165+
if not isinstance(test_specs, list):
160166
raise TypeError('tests attribute must be a list of test specifications')
161-
test_specs += more_test_specs['tests']
162167
if not test_specs:
163168
raise RuntimeError('No tests found.')
164169

165170
with _fix_xlang_instant_coding():
166-
suite = unittest.TestSuite()
167-
for test_spec in test_specs:
168-
suite.addTest(_YamlTestCase(pipeline_spec, test_spec, options))
171+
tests = [
172+
_YamlTestCase(pipeline_spec, test_spec, options, known_args.fix_tests)
173+
for test_spec in test_specs
174+
]
175+
suite = unittest.TestSuite(tests)
169176
result = unittest.TextTestRunner().run(suite)
177+
178+
if known_args.fix_tests:
179+
if known_args.test_suite:
180+
path = known_args.test_suite
181+
elif known_args.yaml_pipeline_file:
182+
path = known_args.yaml_pipeline_file
183+
else:
184+
raise RuntimeError('Test fixing only supported for file-backed tests.')
185+
with open(path) as fin:
186+
original_yaml = fin.read()
187+
if path == known_args.yaml_pipeline_file and pipeline_yaml == content:
188+
raise RuntimeError('In-file test fixing not yet supported for templated pipelines.')
189+
updated_spec = yaml.load(original_yaml, Loader=yaml.SafeLoader)
190+
191+
for ix, test in enumerate(tests):
192+
if test.fixes:
193+
test_spec = yaml_transform.SafeLineLoader.strip_metadata(test.spec())
194+
assert test_spec == updated_spec['tests'][ix]
195+
for (loc, name), values in test.fixes.items():
196+
for expectation in updated_spec['tests'][ix][loc]:
197+
if expectation['name'] == name:
198+
expectation['elements'] = sorted(values, key=json.dumps)
199+
break
200+
201+
updated_yaml = yaml_utils.patch_yaml(original_yaml, updated_spec)
202+
with open(path, 'w') as fout:
203+
fout.write(updated_yaml)
204+
170205
if exit:
171206
# emulates unittest.main()
172207
sys.exit(0 if result.wasSuccessful() else 1)
@@ -233,14 +268,16 @@ def constructor(root):
233268

234269

235270
class _YamlTestCase(unittest.TestCase):
236-
def __init__(self, pipeline_spec, test_spec, options):
271+
def __init__(self, pipeline_spec, test_spec, options, fix_tests):
237272
super().__init__()
238273
self._pipeline_spec = pipeline_spec
239274
self._test_spec = test_spec
240275
self._options = options
276+
self._fix_tests = fix_tests
241277

242278
def runTest(self):
243-
yaml_testing.run_test(self._pipeline_spec, self._test_spec, self._options)
279+
self.fixes = yaml_testing.run_test(
280+
self._pipeline_spec, self._test_spec, self._options, self._fix_tests)
244281

245282
def id(self):
246283
return (
@@ -250,6 +287,9 @@ def id(self):
250287
def __str__(self):
251288
return self.id()
252289

290+
def spec(self):
291+
return self._test_spec
292+
253293

254294
if __name__ == '__main__':
255295
import logging

sdks/python/apache_beam/yaml/main_test.py

+23-6
Original file line numberDiff line numberDiff line change
@@ -54,28 +54,28 @@
5454

5555
PASSING_TEST_SUITE = '''
5656
tests:
57-
- name: ExternalTest
57+
- name: ExternalTest # comment
5858
mock_outputs:
5959
- name: Create
6060
elements: ['a', 'b', 'c']
6161
expected_inputs:
6262
- name: WriteToText
6363
elements:
64-
- {element: a}
65-
- {element: b}
66-
- {element: c}
64+
- element: a
65+
- element: b
66+
- element: c
6767
'''
6868

6969
FAILING_TEST_SUITE = '''
7070
tests:
71-
- name: ExternalTest
71+
- name: ExternalTest # comment
7272
mock_outputs:
7373
- name: Create
7474
elements: ['a', 'b', 'c']
7575
expected_inputs:
7676
- name: WriteToText
7777
elements:
78-
- {element: x}
78+
- element: x
7979
'''
8080

8181

@@ -182,6 +182,23 @@ def test_external_test_specs(self):
182182
],
183183
exit=False)
184184

185+
def test_fix_suite(self):
186+
with tempfile.TemporaryDirectory() as tmpdir:
187+
test_suite = os.path.join(tmpdir, 'tests.yaml')
188+
with open(test_suite, 'w') as fout:
189+
fout.write(FAILING_TEST_SUITE)
190+
191+
main.run_tests([
192+
'--yaml_pipeline',
193+
TEST_PIPELINE,
194+
'--test_suite',
195+
test_suite,
196+
'--fix_tests'
197+
],
198+
exit=False)
199+
200+
with open(test_suite) as fin:
201+
self.assertEqual(fin.read(), PASSING_TEST_SUITE)
185202

186203
if __name__ == '__main__':
187204
logging.getLogger().setLevel(logging.INFO)

sdks/python/apache_beam/yaml/yaml_testing.py

+87-10
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import collections
1919
import functools
20+
import uuid
2021
from typing import Dict
2122
from typing import List
2223
from typing import Mapping
@@ -28,17 +29,21 @@
2829
import yaml
2930

3031
import apache_beam as beam
32+
from apache_beam.testing.util import assert_that
33+
from apache_beam.testing.util import equal_to
34+
from apache_beam.yaml import yaml_provider
3135
from apache_beam.yaml import yaml_transform
3236
from apache_beam.yaml import yaml_utils
3337

3438

35-
def run_test(pipeline_spec, test_spec, options=None):
39+
def run_test(pipeline_spec, test_spec, options=None, fix_failures=False):
3640
if isinstance(pipeline_spec, str):
3741
pipeline_spec = yaml.load(pipeline_spec, Loader=yaml_utils.SafeLineLoader)
3842

39-
transform_spec = inject_test_tranforms(
43+
transform_spec, recording_ids = inject_test_tranforms(
4044
yaml_transform.pipeline_as_composite(pipeline_spec['pipeline']),
41-
test_spec)
45+
test_spec,
46+
fix_failures)
4247

4348
allowed_sources = set(test_spec.get('allowed_sources', []) + ['Create'])
4449
for transform in transform_spec['transforms']:
@@ -57,7 +62,20 @@ def run_test(pipeline_spec, test_spec, options=None):
5762
pipeline_spec.get('options', {})))
5863

5964
with beam.Pipeline(options=options) as p:
60-
_ = p | yaml_transform.YamlTransform(transform_spec)
65+
_ = p | yaml_transform.YamlTransform(
66+
transform_spec,
67+
providers={'AssertEqualAndRecord': AssertEqualAndRecord})
68+
69+
if fix_failures:
70+
fixes = {}
71+
for recording_id in recording_ids:
72+
if AssertEqualAndRecord.has_recorded_result(recording_id):
73+
fixes[recording_id[1:]] = [
74+
row._asdict() if isinstance(row, beam.Row) else row
75+
for row in AssertEqualAndRecord.get_recorded_result(recording_id)
76+
]
77+
AssertEqualAndRecord.remove_recorded_result(recording_id)
78+
return fixes
6179

6280

6381
def validate_test_spec(test_spec):
@@ -116,7 +134,7 @@ def validate_test_spec(test_spec):
116134
f'must be a list, got {type(attr_item["elements"])}')
117135

118136

119-
def inject_test_tranforms(spec, test_spec):
137+
def inject_test_tranforms(spec, test_spec, fix_failures):
120138
validate_test_spec(test_spec)
121139
# These are idempotent, so it's OK to do them preemptively.
122140
for phase in [
@@ -140,6 +158,9 @@ def inject_test_tranforms(spec, test_spec):
140158
for mock_output in test_spec.get('mock_outputs', [])
141159
})
142160

161+
recording_id_prefix = str(uuid.uuid4())
162+
recording_ids = []
163+
143164
transforms = []
144165

145166
@functools.cache
@@ -213,38 +234,94 @@ def create_create(name, elements):
213234
},
214235
}
215236

216-
def create_assertion(name, inputs, elements):
237+
def create_assertion(name, inputs, elements, recording_id=None):
217238
return {
218239
'__uuid__': yaml_utils.SafeLineLoader.create_uuid(),
219240
'name': name,
220241
'input': inputs,
221-
'type': 'AssertEqual',
242+
'type': 'AssertEqualAndRecord',
222243
'config': {
223244
'elements': elements,
245+
'recording_id': recording_id,
224246
},
225247
}
226248

227249
for expected_output in test_spec.get('expected_outputs', []):
250+
if fix_failures:
251+
recording_id = (
252+
recording_id_prefix, 'expected_outputs', expected_output['name'])
253+
recording_ids.append(recording_id)
254+
else:
255+
recording_id = None
228256
require_output(expected_output['name'])
229257
transforms.append(
230258
create_assertion(
231259
f'CheckExpectedOutput[{expected_output["name"]}]',
232260
expected_output['name'],
233-
expected_output['elements']))
261+
expected_output['elements'],
262+
recording_id))
234263

235264
for expected_input in test_spec.get('expected_inputs', []):
265+
if fix_failures:
266+
recording_id = (
267+
recording_id_prefix, 'expected_inputs', expected_input['name'])
268+
recording_ids.append(recording_id)
269+
else:
270+
recording_id = None
236271
transform_id = scope.get_transform_id(expected_input['name'])
237272
transforms.append(
238273
create_assertion(
239274
f'CheckExpectedInput[{expected_input["name"]}]',
240275
create_inputs(transform_id),
241-
expected_input['elements']))
276+
expected_input['elements'],
277+
recording_id))
242278

243279
return {
244280
'__uuid__': yaml_utils.SafeLineLoader.create_uuid(),
245281
'type': 'composite',
246282
'transforms': transforms,
247-
}
283+
}, recording_ids
284+
285+
286+
class AssertEqualAndRecord(beam.PTransform):
287+
_recorded_results = {}
288+
289+
@classmethod
290+
def store_recorded_result(cls, recording_id, value):
291+
assert recording_id not in cls._recorded_results
292+
cls._recorded_results[recording_id] = value
293+
294+
@classmethod
295+
def has_recorded_result(cls, recording_id):
296+
return recording_id in cls._recorded_results
297+
298+
@classmethod
299+
def get_recorded_result(cls, recording_id):
300+
return cls._recorded_results[recording_id]
301+
302+
@classmethod
303+
def remove_recorded_result(cls, recording_id):
304+
del cls._recorded_results[recording_id]
305+
306+
def __init__(self, elements, recording_id):
307+
self._elements = elements
308+
self._recording_id = recording_id
309+
310+
def expand(self, pcoll):
311+
equal_to_matcher = equal_to(yaml_provider.dicts_to_rows(self._elements))
312+
313+
def matcher(actual):
314+
try:
315+
equal_to_matcher(actual)
316+
except Exception:
317+
if self._recording_id:
318+
AssertEqualAndRecord.store_recorded_result(
319+
tuple(self._recording_id), actual)
320+
else:
321+
raise
322+
323+
return assert_that(
324+
pcoll | beam.Map(lambda row: beam.Row(**row._asdict())), matcher)
248325

249326

250327
K1 = TypeVar('K1')

sdks/python/apache_beam/yaml/yaml_testing_test.py

+29
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,35 @@ def test_unmocked_inputs(self):
163163
}]
164164
})
165165

166+
def test_fixes(self):
167+
fixes = yaml_testing.run_test(
168+
SIMPLE_PIPELINE,
169+
{
170+
'mock_outputs': [{
171+
'name': 'MyRead',
172+
'elements': [1, 2, 3],
173+
}],
174+
'expected_inputs': [{
175+
'name': 'ToBeExcluded',
176+
'elements': [
177+
{
178+
'element': 1, 'square': 1
179+
},
180+
{
181+
'element': 2, 'square': 4
182+
},
183+
]
184+
}]
185+
},
186+
fix_failures=True)
187+
self.assertEqual(
188+
fixes,
189+
{('expected_inputs', 'ToBeExcluded'): [
190+
dict(element=1, square=1),
191+
dict(element=2, square=4),
192+
dict(element=3, square=9),
193+
]})
194+
166195

167196
if __name__ == '__main__':
168197
logging.getLogger().setLevel(logging.INFO)

0 commit comments

Comments
 (0)