4444import tree
4545
4646
47- def _get_fake_data_components (decoders , features ):
47+ def _get_fake_data_components (decoders , features , deserialize_method = None ):
4848 """Gets all the components to generate fake data in the tests.
4949
5050 Args:
5151 decoders: The decoders to override, or `None` if no decoding is used.
5252 features: The original features.
53+ deserialize_method: The deserialize method to use.
5354
5455 Returns:
5556 A tuple with the data generator class, the features, the feature specs and
@@ -64,18 +65,27 @@ def _get_fake_data_components(decoders, features):
6465 else :
6566 decoders = decoders # pylint: disable=self-assigning-variable
6667
67- has_nested_dataset = any (
68- isinstance (f , features_lib .Dataset ) for f in features ._flatten (features ) # pylint: disable=protected-access
69- )
70- if decoders is not None or has_nested_dataset :
71- # If a decoder is passed, encode/decode the examples.
72- generator_cls = EncodedRandomFakeGenerator
68+ if deserialize_method == decode .DeserializeMethod .RAW_BYTES :
69+ generator_cls = SerializedRandomFakeGenerator
7370 specs = features .get_serialized_info ()
74- decode_fn = functools .partial (features .decode_example , decoders = decoders )
75- else :
76- generator_cls = RandomFakeGenerator
77- specs = features .get_tensor_info ()
7871 decode_fn = lambda ex : ex # identity
72+ else :
73+ has_nested_dataset = any (
74+ isinstance (f , features_lib .Dataset ) for f in features ._flatten (features ) # pylint: disable=protected-access
75+ )
76+ if (
77+ decoders is not None
78+ or has_nested_dataset
79+ or deserialize_method == decode .DeserializeMethod .DESERIALIZE_NO_DECODE
80+ ):
81+ # If a decoder is passed, encode/decode the examples.
82+ generator_cls = EncodedRandomFakeGenerator
83+ specs = features .get_serialized_info ()
84+ decode_fn = functools .partial (features .decode_example , decoders = decoders )
85+ else :
86+ generator_cls = RandomFakeGenerator
87+ specs = features .get_tensor_info ()
88+ decode_fn = lambda ex : ex # identity
7989 return generator_cls , features , specs , decode_fn
8090
8191
@@ -352,7 +362,9 @@ def mock_as_dataset(self, split, decoders=None, read_config=None, **kwargs):
352362
353363 return ds
354364
355- def mock_as_data_source (self , split , decoders = None , ** kwargs ):
365+ def mock_as_data_source (
366+ self , split , decoders = None , deserialize_method = None , ** kwargs
367+ ):
356368 """Mocks `builder.as_data_source`."""
357369 del kwargs
358370 nonlocal mock_array_record_data_source
@@ -362,7 +374,7 @@ def mock_as_data_source(self, split, decoders=None, **kwargs):
362374 split = {s : s for s in self .info .splits }
363375
364376 generator_cls , features , _ , _ = _get_fake_data_components (
365- decoders , self .info .features
377+ decoders , self .info .features , deserialize_method = deserialize_method
366378 )
367379 generator = generator_cls (features , num_examples )
368380
0 commit comments