Skip to content

Commit 78e38e6

Browse files
author
The TensorFlow Datasets Authors
committed
Handle the argument deserialize_method in mocked data sources.
PiperOrigin-RevId: 861755107
1 parent 7f40717 commit 78e38e6

File tree

2 files changed

+34
-13
lines changed

2 files changed

+34
-13
lines changed

tensorflow_datasets/testing/mocking.py

Lines changed: 25 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -44,12 +44,13 @@
4444
import tree
4545

4646

47-
def _get_fake_data_components(decoders, features):
47+
def _get_fake_data_components(decoders, features, deserialize_method=None):
4848
"""Gets all the components to generate fake data in the tests.
4949
5050
Args:
5151
decoders: The decoders to override, or `None` if no decoding is used.
5252
features: The original features.
53+
deserialize_method: The `decode.DeserializeMethod` to emulate, or `None` to keep the default behavior. `RAW_BYTES` selects the serialized generator.
5354
5455
Returns:
5556
A tuple with the data generator class, the features, the feature specs and
@@ -64,18 +65,27 @@ def _get_fake_data_components(decoders, features):
6465
else:
6566
decoders = decoders # pylint: disable=self-assigning-variable
6667

67-
has_nested_dataset = any(
68-
isinstance(f, features_lib.Dataset) for f in features._flatten(features) # pylint: disable=protected-access
69-
)
70-
if decoders is not None or has_nested_dataset:
71-
# If a decoder is passed, encode/decode the examples.
72-
generator_cls = EncodedRandomFakeGenerator
68+
if deserialize_method == decode.DeserializeMethod.RAW_BYTES:
69+
generator_cls = SerializedRandomFakeGenerator
7370
specs = features.get_serialized_info()
74-
decode_fn = functools.partial(features.decode_example, decoders=decoders)
75-
else:
76-
generator_cls = RandomFakeGenerator
77-
specs = features.get_tensor_info()
7871
decode_fn = lambda ex: ex # identity
72+
else:
73+
has_nested_dataset = any(
74+
isinstance(f, features_lib.Dataset) for f in features._flatten(features) # pylint: disable=protected-access
75+
)
76+
if (
77+
decoders is not None
78+
or has_nested_dataset
79+
or deserialize_method == decode.DeserializeMethod.DESERIALIZE_NO_DECODE
80+
):
81+
# If a decoder is passed, a nested dataset is present, or deserialization without decoding is requested, encode/decode the examples.
82+
generator_cls = EncodedRandomFakeGenerator
83+
specs = features.get_serialized_info()
84+
decode_fn = functools.partial(features.decode_example, decoders=decoders)
85+
else:
86+
generator_cls = RandomFakeGenerator
87+
specs = features.get_tensor_info()
88+
decode_fn = lambda ex: ex # identity
7989
return generator_cls, features, specs, decode_fn
8090

8191

@@ -352,7 +362,9 @@ def mock_as_dataset(self, split, decoders=None, read_config=None, **kwargs):
352362

353363
return ds
354364

355-
def mock_as_data_source(self, split, decoders=None, **kwargs):
365+
def mock_as_data_source(
366+
self, split, decoders=None, deserialize_method=None, **kwargs
367+
):
356368
"""Mocks `builder.as_data_source`."""
357369
del kwargs
358370
nonlocal mock_array_record_data_source
@@ -362,7 +374,7 @@ def mock_as_data_source(self, split, decoders=None, **kwargs):
362374
split = {s: s for s in self.info.splits}
363375

364376
generator_cls, features, _, _ = _get_fake_data_components(
365-
decoders, self.info.features
377+
decoders, self.info.features, deserialize_method=deserialize_method
366378
)
367379
generator = generator_cls(features, num_examples)
368380

tensorflow_datasets/testing/mocking_test.py

Lines changed: 9 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -372,6 +372,15 @@ def test_mock_data_source():
372372
)
373373
assert isinstance(data_source[0]['image'], bytes)
374374

375+
# Without deserializing the examples
376+
data_source = tfds.data_source(
377+
'imagenet2012',
378+
split='train',
379+
deserialize_method=tfds.decode.DeserializeMethod.RAW_BYTES,
380+
)
381+
assert len(data_source) == 10
382+
assert isinstance(data_source[0], bytes)
383+
375384

376385
def test_mock_multiple_data_source():
377386
with tfds.testing.mock_data(num_examples=10):

0 commit comments

Comments
 (0)