17
17
18
18
import collections
19
19
import functools
20
+ import uuid
20
21
from typing import Dict
21
22
from typing import List
22
23
from typing import Mapping
28
29
import yaml
29
30
30
31
import apache_beam as beam
32
+ from apache_beam .testing .util import assert_that
33
+ from apache_beam .testing .util import equal_to
34
+ from apache_beam .yaml import yaml_provider
31
35
from apache_beam .yaml import yaml_transform
32
36
from apache_beam .yaml import yaml_utils
33
37
34
38
35
- def run_test (pipeline_spec , test_spec , options = None ):
39
+ def run_test (pipeline_spec , test_spec , options = None , fix_failures = False ):
36
40
if isinstance (pipeline_spec , str ):
37
41
pipeline_spec = yaml .load (pipeline_spec , Loader = yaml_utils .SafeLineLoader )
38
42
39
- transform_spec = inject_test_tranforms (
43
+ transform_spec , recording_ids = inject_test_tranforms (
40
44
yaml_transform .pipeline_as_composite (pipeline_spec ['pipeline' ]),
41
- test_spec )
45
+ test_spec ,
46
+ fix_failures )
42
47
43
48
allowed_sources = set (test_spec .get ('allowed_sources' , []) + ['Create' ])
44
49
for transform in transform_spec ['transforms' ]:
@@ -57,7 +62,20 @@ def run_test(pipeline_spec, test_spec, options=None):
57
62
pipeline_spec .get ('options' , {})))
58
63
59
64
with beam .Pipeline (options = options ) as p :
60
- _ = p | yaml_transform .YamlTransform (transform_spec )
65
+ _ = p | yaml_transform .YamlTransform (
66
+ transform_spec ,
67
+ providers = {'AssertEqualAndRecord' : AssertEqualAndRecord })
68
+
69
+ if fix_failures :
70
+ fixes = {}
71
+ for recording_id in recording_ids :
72
+ if AssertEqualAndRecord .has_recorded_result (recording_id ):
73
+ fixes [recording_id [1 :]] = [
74
+ row ._asdict () if isinstance (row , beam .Row ) else row
75
+ for row in AssertEqualAndRecord .get_recorded_result (recording_id )
76
+ ]
77
+ AssertEqualAndRecord .remove_recorded_result (recording_id )
78
+ return fixes
61
79
62
80
63
81
def validate_test_spec (test_spec ):
@@ -116,7 +134,7 @@ def validate_test_spec(test_spec):
116
134
f'must be a list, got { type (attr_item ["elements" ])} ' )
117
135
118
136
119
- def inject_test_tranforms (spec , test_spec ):
137
+ def inject_test_tranforms (spec , test_spec , fix_failures ):
120
138
validate_test_spec (test_spec )
121
139
# These are idempotent, so it's OK to do them preemptively.
122
140
for phase in [
@@ -140,6 +158,9 @@ def inject_test_tranforms(spec, test_spec):
140
158
for mock_output in test_spec .get ('mock_outputs' , [])
141
159
})
142
160
161
+ recording_id_prefix = str (uuid .uuid4 ())
162
+ recording_ids = []
163
+
143
164
transforms = []
144
165
145
166
@functools .cache
@@ -213,38 +234,94 @@ def create_create(name, elements):
213
234
},
214
235
}
215
236
216
def create_assertion(name, inputs, elements, recording_id=None):
  """Build the transform spec for one equality check.

  Args:
    name: Name given to the generated assertion transform.
    inputs: Input reference(s) the assertion reads from.
    elements: Expected elements the input is compared against.
    recording_id: Optional identifier forwarded to the assertion transform's
      config; ``None`` (the default) when no recording is requested.

  Returns:
    A dict with ``__uuid__``, ``name``, ``input``, ``type`` and ``config``
    entries describing the assertion transform.
  """
  return {
      '__uuid__': yaml_utils.SafeLineLoader.create_uuid(),
      'name': name,
      'input': inputs,
      # Must match, character for character, the key under which run_test
      # registers the provider ('AssertEqualAndRecord', no padding).
      'type': 'AssertEqualAndRecord',
      'config': {
          'elements': elements,
          'recording_id': recording_id,
      },
  }
226
248
227
249
for expected_output in test_spec .get ('expected_outputs' , []):
250
+ if fix_failures :
251
+ recording_id = (
252
+ recording_id_prefix , 'expected_outputs' , expected_output ['name' ])
253
+ recording_ids .append (recording_id )
254
+ else :
255
+ recording_id = None
228
256
require_output (expected_output ['name' ])
229
257
transforms .append (
230
258
create_assertion (
231
259
f'CheckExpectedOutput[{ expected_output ["name" ]} ]' ,
232
260
expected_output ['name' ],
233
- expected_output ['elements' ]))
261
+ expected_output ['elements' ],
262
+ recording_id ))
234
263
235
264
for expected_input in test_spec .get ('expected_inputs' , []):
265
+ if fix_failures :
266
+ recording_id = (
267
+ recording_id_prefix , 'expected_inputs' , expected_input ['name' ])
268
+ recording_ids .append (recording_id )
269
+ else :
270
+ recording_id = None
236
271
transform_id = scope .get_transform_id (expected_input ['name' ])
237
272
transforms .append (
238
273
create_assertion (
239
274
f'CheckExpectedInput[{ expected_input ["name" ]} ]' ,
240
275
create_inputs (transform_id ),
241
- expected_input ['elements' ]))
276
+ expected_input ['elements' ],
277
+ recording_id ))
242
278
243
279
return {
244
280
'__uuid__' : yaml_utils .SafeLineLoader .create_uuid (),
245
281
'type' : 'composite' ,
246
282
'transforms' : transforms ,
247
- }
283
+ }, recording_ids
284
+
285
+
286
class AssertEqualAndRecord(beam.PTransform):
  """Asserts that a PCollection equals ``elements``, optionally recording.

  When constructed with a truthy ``recording_id``, a failed comparison does
  not raise. Instead, the actual elements are stored in a class-level registry
  under ``tuple(recording_id)`` and can be read back with
  ``get_recorded_result``. Without a ``recording_id`` the comparison failure
  is re-raised unchanged.
  """
  # Registry shared by every instance: recording id (tuple) -> actual elements.
  # NOTE(review): being class-level state, this presumably only works when the
  # matcher runs in the same process as the caller reading it -- confirm.
  _recorded_results = {}

  @classmethod
  def store_recorded_result(cls, recording_id, value):
    """Store ``value`` under ``recording_id``; each id may be stored once."""
    assert recording_id not in cls._recorded_results, (
        f'A result was already recorded for {recording_id!r}')
    cls._recorded_results[recording_id] = value

  @classmethod
  def has_recorded_result(cls, recording_id):
    """Return True if a result is currently stored for ``recording_id``."""
    return recording_id in cls._recorded_results

  @classmethod
  def get_recorded_result(cls, recording_id):
    """Return the stored result; raises ``KeyError`` if there is none."""
    return cls._recorded_results[recording_id]

  @classmethod
  def remove_recorded_result(cls, recording_id):
    """Drop the stored result; raises ``KeyError`` if there is none."""
    del cls._recorded_results[recording_id]

  def __init__(self, elements, recording_id):
    """Initialize the transform.

    Args:
      elements: Expected elements, converted with
        ``yaml_provider.dicts_to_rows`` before comparison.
      recording_id: Identifier under which a mismatch is recorded, or a falsy
        value to let mismatches raise.
    """
    # Run the PTransform base initializer instead of relying on class-level
    # defaults for base-class state.
    super().__init__()
    self._elements = elements
    self._recording_id = recording_id

  def expand(self, pcoll):
    """Attach the (possibly recording) equality assertion to ``pcoll``."""
    equal_to_matcher = equal_to(yaml_provider.dicts_to_rows(self._elements))

    def matcher(actual):
      try:
        equal_to_matcher(actual)
      except Exception:
        if self._recording_id:
          # Normalize to a tuple so the id is hashable as a dict key.
          # NOTE(review): presumably the id can arrive as a list after passing
          # through the transform config -- confirm.
          AssertEqualAndRecord.store_recorded_result(
              tuple(self._recording_id), actual)
        else:
          raise

    # Each element is rebuilt as a plain beam.Row from its _asdict() view
    # before matching.
    # NOTE(review): presumably this normalizes row types so they compare equal
    # to the rows produced by dicts_to_rows -- confirm.
    return assert_that(
        pcoll | beam.Map(lambda row: beam.Row(**row._asdict())), matcher)
248
325
249
326
250
327
K1 = TypeVar ('K1' )
0 commit comments