Skip to content

Commit 5b6f993

Browse files
committed
Add logic to automatically create regression tests.
1 parent 902f75b commit 5b6f993

File tree

2 files changed

+183
-12
lines changed

2 files changed

+183
-12
lines changed

Diff for: sdks/python/apache_beam/yaml/yaml_testing.py

+152-12
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import collections
1919
import functools
20+
import random
2021
import uuid
2122
from typing import Dict
2223
from typing import List
@@ -40,8 +41,10 @@ def run_test(pipeline_spec, test_spec, options=None, fix_failures=False):
4041
if isinstance(pipeline_spec, str):
4142
pipeline_spec = yaml.load(pipeline_spec, Loader=yaml_utils.SafeLineLoader)
4243

44+
pipeline_spec = _preprocess_for_testing(pipeline_spec)
45+
4346
transform_spec, recording_ids = inject_test_tranforms(
44-
yaml_transform.pipeline_as_composite(pipeline_spec['pipeline']),
47+
pipeline_spec,
4548
test_spec,
4649
fix_failures)
4750

@@ -71,13 +74,28 @@ def run_test(pipeline_spec, test_spec, options=None, fix_failures=False):
7174
for recording_id in recording_ids:
7275
if AssertEqualAndRecord.has_recorded_result(recording_id):
7376
fixes[recording_id[1:]] = [
74-
row._asdict() if isinstance(row, beam.Row) else row
77+
_try_row_as_dict(row)
7578
for row in AssertEqualAndRecord.get_recorded_result(recording_id)
7679
]
7780
AssertEqualAndRecord.remove_recorded_result(recording_id)
7881
return fixes
7982

8083

84+
def _preprocess_for_testing(pipeline_spec):
  """Normalizes the ``pipeline`` section of a spec before test injection.

  The ``pipeline_spec['pipeline']`` entry is wrapped via
  ``yaml_transform.pipeline_as_composite`` and then passed, in order, through
  a fixed sequence of ``yaml_transform`` preprocessing phases.

  Args:
    pipeline_spec: mapping that contains a ``'pipeline'`` entry.

  Returns:
    The transform spec produced by the last preprocessing phase.
  """
  composite = yaml_transform.pipeline_as_composite(pipeline_spec['pipeline'])
  # These are idempotent, so it's OK to do them preemptively.
  phases = (
      yaml_transform.ensure_transforms_have_types,
      yaml_transform.preprocess_source_sink,
      yaml_transform.preprocess_chain,
      yaml_transform.tag_explicit_inputs,
      yaml_transform.normalize_inputs_outputs,
  )
  # Fold the phases over the spec: each phase consumes the previous result.
  return functools.reduce(
      lambda current, phase: yaml_transform.apply_phase(phase, current),
      phases,
      composite)
97+
98+
8199
def validate_test_spec(test_spec):
82100
if not isinstance(test_spec, dict):
83101
raise TypeError(
@@ -136,16 +154,6 @@ def validate_test_spec(test_spec):
136154

137155
def inject_test_tranforms(spec, test_spec, fix_failures):
138156
validate_test_spec(test_spec)
139-
# These are idempotent, so it's OK to do them preemptively.
140-
for phase in [
141-
yaml_transform.ensure_transforms_have_types,
142-
yaml_transform.preprocess_source_sink,
143-
yaml_transform.preprocess_chain,
144-
yaml_transform.tag_explicit_inputs,
145-
yaml_transform.normalize_inputs_outputs,
146-
]:
147-
spec = yaml_transform.apply_phase(phase, spec)
148-
149157
scope = yaml_transform.LightweightScope(spec['transforms'])
150158

151159
mocked_inputs_by_id = {
@@ -324,6 +332,131 @@ def matcher(actual):
324332
pcoll | beam.Map(lambda row: beam.Row(**row._asdict())), matcher)
325333

326334

335+
def create_test(
    pipeline_spec, options=None, max_num_inputs=40, min_num_outputs=3):
  """Builds a regression test spec for a pipeline from its observed behavior.

  Up to ``max_num_inputs`` elements are recorded from each detected input
  transform. The pipeline is then run through ``run_test`` with
  ``fix_failures=True`` on a random sample of those elements, doubling the
  sample size until every detected output transform received at least
  ``min_num_outputs`` elements or the sample size reaches ``max_num_inputs``.

  Args:
    pipeline_spec: the pipeline specification, either as a YAML string or as
      an already-loaded mapping with a ``'pipeline'`` entry.
    options: pipeline options to run with. When None, options are built from
      the spec's ``'options'`` entry with ``pickle_library='cloudpickle'``.
    max_num_inputs: upper bound on the number of elements recorded from, and
      sampled for, each input transform.
    min_num_outputs: number of elements desired at every output transform;
      also the initial per-input sample size.

  Returns:
    A test spec dict with ``'mock_outputs'`` (the sampled input elements) and
    ``'expected_inputs'`` (the elements recorded at each output transform).

  Raises:
    ValueError: if no output transforms are detected, or if a transform that
      needs to be referenced has no name and shares its type with another
      transform.
  """
  if isinstance(pipeline_spec, str):
    pipeline_spec = yaml.load(pipeline_spec, Loader=yaml_utils.SafeLineLoader)

  transform_spec = _preprocess_for_testing(pipeline_spec)

  if options is None:
    options = beam.options.pipeline_options.PipelineOptions(
        pickle_library='cloudpickle',
        **yaml_transform.SafeLineLoader.strip_metadata(
            pipeline_spec.get('options', {})))

  def get_name(transform):
    # Unnamed transforms are referred to by type, which must then be unique.
    if 'name' in transform:
      return transform['name']
    same_type_count = sum(
        1 for t in transform_spec['transforms']
        if t['type'] == transform['type'])
    if same_type_count > 1:
      raise ValueError(f'Ambiguous unnamed transform {transform["type"]}')
    return transform['type']

  # Anything other than a Create for which empty_if_explicitly_empty() of its
  # input is falsy is treated as a source to mock (presumably: transforms
  # that consume no inputs -- confirm against yaml_transform).
  input_transforms = [
      t for t in transform_spec['transforms'] if t['type'] != 'Create' and
      not yaml_transform.empty_if_explicitly_empty(t.get('input', []))
  ]

  mock_outputs = [{
      'name': get_name(t),
      'elements': [
          _try_row_as_dict(row) for row in _first_n(t, options, max_num_inputs)
      ],
  } for t in input_transforms]

  # Outputs are LogForTesting transforms, transforms for which
  # empty_if_explicitly_empty() of their output is truthy, and transforms
  # whose type starts with 'Write'.
  output_transforms = [
      t for t in transform_spec['transforms'] if t['type'] == 'LogForTesting' or
      yaml_transform.empty_if_explicitly_empty(t.get('output', [])) or
      t['type'].startswith('Write')
  ]

  expected_inputs = [{
      'name': get_name(t),
      'elements': [],
  } for t in output_transforms]

  if not expected_inputs:
    # TODO: Optionally take this as a parameter.
    raise ValueError('No output transforms detected.')

  num_inputs = min_num_outputs
  while True:
    test_spec = {
        'mock_outputs': [{
            'name': t['name'],
            'elements': random.sample(
                t['elements'], min(len(t['elements']), num_inputs)),
        } for t in mock_outputs],
        'expected_inputs': expected_inputs,
    }
    fixes = run_test(pipeline_spec, test_spec, options, fix_failures=True)
    if len(fixes) < len(output_transforms):
      # At least one output recorded nothing at all.
      actual_output_size = 0
    else:
      actual_output_size = min(len(e) for e in fixes.values())
    # Using >= for the cap avoids a redundant extra run when
    # min_num_outputs exceeds max_num_inputs.
    if actual_output_size >= min_num_outputs or num_inputs >= max_num_inputs:
      break
    num_inputs = min(2 * num_inputs, max_num_inputs)

  # Replace the empty placeholders with what the final run actually produced.
  for expected_input in test_spec['expected_inputs']:
    if ('expected_inputs', expected_input['name']) in fixes:
      expected_input['elements'] = fixes['expected_inputs',
                                         expected_input['name']]

  return test_spec
412+
413+
414+
class _DoneException(Exception):
  """Raised to abort a pipeline once enough elements have been recorded."""
416+
417+
418+
class RecordElements(beam.PTransform):
  """PTransform that records up to ``n`` elements of its input PCollection.

  Recorded elements are kept in a class-level registry keyed by a unique
  per-instance id, and are retrieved (and dropped) with ``get_and_remove``.
  Once ``n`` elements have been recorded, the next element seen raises
  ``_DoneException``.
  """
  # Maps each instance's unique id to the list of elements recorded for it.
  _recorded_results = collections.defaultdict(list)

  def __init__(self, n):
    self._id = str(uuid.uuid4())
    self._n = n

  def get_and_remove(self):
    """Returns the elements recorded for this instance and forgets them."""
    return RecordElements._recorded_results.pop(self._id, [])

  def expand(self, pcoll):
    def record(element):
      captured = RecordElements._recorded_results[self._id]
      if len(captured) >= self._n:
        # Quota already reached: signal that no more elements are needed.
        raise _DoneException()
      captured.append(element)

    return pcoll | beam.Map(record)
439+
440+
441+
def _first_n(transform_spec, options, n):
  """Runs a single transform spec and returns up to ``n`` of its elements.

  The spec is expanded with ``yaml_transform.YamlTransform`` in a fresh
  pipeline whose output is fed to a ``RecordElements(n)`` recorder. The
  recorder raises ``_DoneException`` once it has seen more than ``n``
  elements, which is treated here as normal completion.

  Args:
    transform_spec: the transform specification to execute.
    options: pipeline options used to construct the ``beam.Pipeline``.
    n: maximum number of elements to record.

  Returns:
    The list of recorded elements.

  Raises:
    Exception: any pipeline failure that is not attributable to
      ``_DoneException`` is re-raised unchanged.
  """
  recorder = RecordElements(n)
  try:
    with beam.Pipeline(options=options) as p:
      _ = (
          p
          | yaml_transform.YamlTransform(
              transform_spec,
              providers={'AssertEqualAndRecord': AssertEqualAndRecord})
          | recorder)
  except _DoneException:
    pass
  except Exception as exn:
    # Runners don't always raise a faithful exception type.
    if '_DoneException' not in str(exn):
      # Drop any partial recording so the class-level registry doesn't keep
      # growing when the pipeline genuinely fails.
      recorder.get_and_remove()
      raise
  return recorder.get_and_remove()
458+
459+
327460
K1 = TypeVar('K1')
328461
K2 = TypeVar('K2')
329462
V = TypeVar('V')
@@ -337,3 +470,10 @@ def _composite_key_to_nested(
337470
for (k1, k2), v in d.items():
338471
nested[k1][k2] = v
339472
return nested
473+
474+
475+
def _try_row_as_dict(row):
  """Returns ``row._asdict()`` when available, otherwise ``row`` unchanged.

  Any ``AttributeError`` raised while calling ``_asdict`` results in the
  original object being returned.
  """
  try:
    converted = row._asdict()
  except AttributeError:
    converted = row
  return converted

Diff for: sdks/python/apache_beam/yaml/yaml_testing_test.py

+31
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616
#
1717

1818
import logging
19+
import os
20+
import tempfile
1921
import unittest
2022

2123
from apache_beam.yaml import yaml_testing
@@ -192,6 +194,35 @@ def test_fixes(self):
192194
dict(element=3, square=9),
193195
]})
194196

197+
def test_create(self):
  """create_test() yields a spec that run_test() accepts for a CSV pipeline.

  A 1000-row CSV is written to a temporary directory; one row in five has
  ``a == 1`` and passes the Filter, so at least 5 outputs are reachable
  within the 100-input cap.
  """
  with tempfile.TemporaryDirectory() as tmpdir:
    # Keep the fixture inside the temporary directory so the test leaves
    # nothing behind in the current working directory.
    input_path = os.path.join(tmpdir, 'input.csv')
    with open(input_path, 'w') as fout:
      fout.write('a,b,c\n')
      for ix in range(1000):
        fout.write(f'{ix % 5},{ix},Ccc\n')
    pipeline = f'''
        pipeline:
          type: chain
          transforms:
            - type: ReadFromCsv
              config:
                path: {input_path}
            - type: Filter
              config:
                language: python
                keep: a == 1
            - type: WriteToSink
        '''
    test_spec = yaml_testing.create_test(
        pipeline, max_num_inputs=100, min_num_outputs=5)

    self.assertEqual(len(test_spec['mock_outputs']), 1)
    self.assertEqual(len(test_spec['expected_inputs']), 1)
    self.assertGreaterEqual(len(test_spec['expected_inputs'][0]['elements']), 5)
    # The generated spec must itself pass when replayed.
    yaml_testing.run_test(pipeline, test_spec)
225+
195226

196227
if __name__ == '__main__':
197228
logging.getLogger().setLevel(logging.INFO)

0 commit comments

Comments
 (0)