17
17
18
18
import collections
19
19
import functools
20
+ import random
20
21
import uuid
21
22
from typing import Dict
22
23
from typing import List
@@ -40,8 +41,10 @@ def run_test(pipeline_spec, test_spec, options=None, fix_failures=False):
40
41
if isinstance (pipeline_spec , str ):
41
42
pipeline_spec = yaml .load (pipeline_spec , Loader = yaml_utils .SafeLineLoader )
42
43
44
+ pipeline_spec = _preprocess_for_testing (pipeline_spec )
45
+
43
46
transform_spec , recording_ids = inject_test_tranforms (
44
- yaml_transform . pipeline_as_composite ( pipeline_spec [ 'pipeline' ]) ,
47
+ pipeline_spec ,
45
48
test_spec ,
46
49
fix_failures )
47
50
@@ -71,13 +74,28 @@ def run_test(pipeline_spec, test_spec, options=None, fix_failures=False):
71
74
for recording_id in recording_ids :
72
75
if AssertEqualAndRecord .has_recorded_result (recording_id ):
73
76
fixes [recording_id [1 :]] = [
74
- row . _asdict () if isinstance ( row , beam . Row ) else row
77
+ _try_row_as_dict ( row )
75
78
for row in AssertEqualAndRecord .get_recorded_result (recording_id )
76
79
]
77
80
AssertEqualAndRecord .remove_recorded_result (recording_id )
78
81
return fixes
79
82
80
83
84
def _preprocess_for_testing(pipeline_spec):
  """Expands a raw pipeline spec into the normalized form the tests expect.

  The composite pipeline is folded through the standard normalization
  phases.  These phases are idempotent, so applying them preemptively here
  is safe even if they are applied again later.
  """
  phases = (
      yaml_transform.ensure_transforms_have_types,
      yaml_transform.preprocess_source_sink,
      yaml_transform.preprocess_chain,
      yaml_transform.tag_explicit_inputs,
      yaml_transform.normalize_inputs_outputs,
  )
  return functools.reduce(
      lambda spec, phase: yaml_transform.apply_phase(phase, spec),
      phases,
      yaml_transform.pipeline_as_composite(pipeline_spec['pipeline']))
97
+
98
+
81
99
def validate_test_spec (test_spec ):
82
100
if not isinstance (test_spec , dict ):
83
101
raise TypeError (
@@ -136,16 +154,6 @@ def validate_test_spec(test_spec):
136
154
137
155
def inject_test_tranforms (spec , test_spec , fix_failures ):
138
156
validate_test_spec (test_spec )
139
- # These are idempotent, so it's OK to do them preemptively.
140
- for phase in [
141
- yaml_transform .ensure_transforms_have_types ,
142
- yaml_transform .preprocess_source_sink ,
143
- yaml_transform .preprocess_chain ,
144
- yaml_transform .tag_explicit_inputs ,
145
- yaml_transform .normalize_inputs_outputs ,
146
- ]:
147
- spec = yaml_transform .apply_phase (phase , spec )
148
-
149
157
scope = yaml_transform .LightweightScope (spec ['transforms' ])
150
158
151
159
mocked_inputs_by_id = {
@@ -324,6 +332,131 @@ def matcher(actual):
324
332
pcoll | beam .Map (lambda row : beam .Row (** row ._asdict ())), matcher )
325
333
326
334
335
def create_test(
    pipeline_spec, options=None, max_num_inputs=40, min_num_outputs=3):
  """Auto-creates a test spec (mock inputs + expected outputs) for a pipeline.

  Real elements are recorded from the pipeline's input transforms, then the
  pipeline is re-run on progressively larger random samples of them until
  every output transform sees at least `min_num_outputs` elements (or the
  `max_num_inputs` cap is reached).

  Args:
    pipeline_spec: A yaml string, or already-parsed dict, describing the
      pipeline under test.
    options: Optional PipelineOptions; if None, derived from the spec's
      own 'options' block.
    max_num_inputs: Maximum number of elements to record per input transform.
    min_num_outputs: Desired minimum number of elements reaching each
      output transform.

  Returns:
    A test-spec dict with 'mock_outputs' and 'expected_inputs' entries,
    suitable for passing to run_test.

  Raises:
    ValueError: If no output transforms are detected, or an unnamed
      transform's type is ambiguous.
  """
  if isinstance(pipeline_spec, str):
    pipeline_spec = yaml.load(pipeline_spec, Loader=yaml_utils.SafeLineLoader)

  transform_spec = _preprocess_for_testing(pipeline_spec)

  if options is None:
    # NOTE(review): strip_metadata is accessed via yaml_transform here but
    # the loader above comes from yaml_utils — confirm both resolve to the
    # same SafeLineLoader.
    options = beam.options.pipeline_options.PipelineOptions(
        pickle_library='cloudpickle',
        **yaml_transform.SafeLineLoader.strip_metadata(
            pipeline_spec.get('options', {})))

  def get_name(transform):
    # Resolves a transform to a unique name, falling back to its type when
    # that type occurs only once in the pipeline.
    if 'name' in transform:
      return transform['name']
    else:
      if sum(1 for t in transform_spec['transforms']
             if t['type'] == transform['type']) > 1:
        # Bug fix: the f-prefix was missing, so the literal braces were
        # emitted instead of the transform type.
        raise ValueError(f'Ambiguous unnamed transform {transform["type"]}')
      return transform['type']

  # Transforms with no (explicitly empty counts too) inputs act as sources.
  input_transforms = [
      t for t in transform_spec['transforms'] if t['type'] != 'Create' and
      not yaml_transform.empty_if_explicitly_empty(t.get('input', []))
  ]

  # Record up to max_num_inputs real elements from each source transform.
  mock_outputs = [{
      'name': get_name(t),
      'elements': [
          _try_row_as_dict(row) for row in _first_n(t, options, max_num_inputs)
      ],
  } for t in input_transforms]

  output_transforms = [
      t for t in transform_spec['transforms'] if t['type'] == 'LogForTesting' or
      yaml_transform.empty_if_explicitly_empty(t.get('output', [])) or
      t['type'].startswith('Write')
  ]

  # Expected elements start empty; run_test(fix_failures=True) fills them in.
  expected_inputs = [{
      'name': get_name(t),
      'elements': [],
  } for t in output_transforms]

  if not expected_inputs:
    # TODO: Optionally take this as a parameter.
    raise ValueError('No output transforms detected.')

  num_inputs = min_num_outputs
  while True:
    test_spec = {
        'mock_outputs': [{
            'name': t['name'],
            'elements': random.sample(
                t['elements'], min(len(t['elements']), num_inputs)),
        } for t in mock_outputs],
        'expected_inputs': expected_inputs,
    }
    fixes = run_test(pipeline_spec, test_spec, options, fix_failures=True)
    if len(fixes) < len(output_transforms):
      # Some output transform received nothing at all.
      actual_output_size = 0
    else:
      actual_output_size = min(len(e) for e in fixes.values())
    if actual_output_size >= min_num_outputs:
      break
    elif num_inputs == max_num_inputs:
      break
    else:
      # Not enough outputs yet; double the sample size (up to the cap).
      num_inputs = min(2 * num_inputs, max_num_inputs)

  for expected_input in test_spec['expected_inputs']:
    if ('expected_inputs', expected_input['name']) in fixes:
      expected_input['elements'] = fixes['expected_inputs',
                                         expected_input['name']]

  return test_spec
412
+
413
+
414
+ class _DoneException (Exception ):
415
+ pass
416
+
417
+
418
class RecordElements(beam.PTransform):
  """Captures up to n elements of a PCollection in an in-process buffer.

  Recorded elements live in a class-level registry keyed by a per-instance
  uuid, so they can be retrieved after the pipeline finishes (within the
  same process).  Once n elements have been captured, any further element
  triggers a _DoneException.
  """

  _recorded_results = collections.defaultdict(list)

  def __init__(self, n):
    self._n = n
    self._id = str(uuid.uuid4())

  def get_and_remove(self):
    """Returns the captured elements and drops them from the registry."""
    captured = RecordElements._recorded_results[self._id]
    del RecordElements._recorded_results[self._id]
    return captured

  def expand(self, pcoll):
    def record(element):
      buffer = RecordElements._recorded_results[self._id]
      if len(buffer) >= self._n:
        # Buffer is full — signal that we're done consuming input.
        raise _DoneException()
      buffer.append(element)

    return pcoll | beam.Map(record)
439
+
440
+
441
def _first_n(transform_spec, options, n):
  """Runs the given transform in its own pipeline, returning up to n elements.

  A RecordElements sink captures the output; once n elements have been seen
  the pipeline is cut short via _DoneException, which is treated as success.
  """
  recorder = RecordElements(n)
  try:
    with beam.Pipeline(options=options) as p:
      _ = (
          p
          | yaml_transform.YamlTransform(
              transform_spec,
              providers={'AssertEqualAndRecord': AssertEqualAndRecord})
          | recorder)
  except _DoneException:
    pass
  except Exception as exn:
    # Runners don't always raise a faithful exception type, so fall back
    # to matching our sentinel in the message text.
    if '_DoneException' not in str(exn):
      raise
  return recorder.get_and_remove()
458
+
459
+
327
460
K1 = TypeVar ('K1' )
328
461
K2 = TypeVar ('K2' )
329
462
V = TypeVar ('V' )
@@ -337,3 +470,10 @@ def _composite_key_to_nested(
337
470
for (k1 , k2 ), v in d .items ():
338
471
nested [k1 ][k2 ] = v
339
472
return nested
473
+
474
+
475
+ def _try_row_as_dict (row ):
476
+ try :
477
+ return row ._asdict ()
478
+ except AttributeError :
479
+ return row
0 commit comments