Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 25 additions & 13 deletions tensorflow_datasets/testing/mocking.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,12 +44,13 @@
import tree


def _get_fake_data_components(decoders, features):
def _get_fake_data_components(decoders, features, deserialize_method=None):
"""Gets all the components to generate fake data in the tests.

Args:
decoders: The decoders to override, or `None` if no decoding is used.
features: The original features.
deserialize_method: The deserialize method to use.

Returns:
A tuple with the data generator class, the features, the feature specs and
Expand All @@ -64,18 +65,27 @@ def _get_fake_data_components(decoders, features):
else:
decoders = decoders # pylint: disable=self-assigning-variable

has_nested_dataset = any(
isinstance(f, features_lib.Dataset) for f in features._flatten(features) # pylint: disable=protected-access
)
if decoders is not None or has_nested_dataset:
# If a decoder is passed, encode/decode the examples.
generator_cls = EncodedRandomFakeGenerator
if deserialize_method == decode.DeserializeMethod.RAW_BYTES:
generator_cls = SerializedRandomFakeGenerator
specs = features.get_serialized_info()
decode_fn = functools.partial(features.decode_example, decoders=decoders)
else:
generator_cls = RandomFakeGenerator
specs = features.get_tensor_info()
decode_fn = lambda ex: ex # identity
else:
has_nested_dataset = any(
isinstance(f, features_lib.Dataset) for f in features._flatten(features) # pylint: disable=protected-access
)
if (
decoders is not None
or has_nested_dataset
or deserialize_method == decode.DeserializeMethod.DESERIALIZE_NO_DECODE
):
# If a decoder is passed, encode/decode the examples.
generator_cls = EncodedRandomFakeGenerator
specs = features.get_serialized_info()
decode_fn = functools.partial(features.decode_example, decoders=decoders)
else:
generator_cls = RandomFakeGenerator
specs = features.get_tensor_info()
decode_fn = lambda ex: ex # identity
return generator_cls, features, specs, decode_fn


Expand Down Expand Up @@ -352,7 +362,9 @@ def mock_as_dataset(self, split, decoders=None, read_config=None, **kwargs):

return ds

def mock_as_data_source(self, split, decoders=None, **kwargs):
def mock_as_data_source(
self, split, decoders=None, deserialize_method=None, **kwargs
):
"""Mocks `builder.as_data_source`."""
del kwargs
nonlocal mock_array_record_data_source
Expand All @@ -362,7 +374,7 @@ def mock_as_data_source(self, split, decoders=None, **kwargs):
split = {s: s for s in self.info.splits}

generator_cls, features, _, _ = _get_fake_data_components(
decoders, self.info.features
decoders, self.info.features, deserialize_method=deserialize_method
)
generator = generator_cls(features, num_examples)

Expand Down
9 changes: 9 additions & 0 deletions tensorflow_datasets/testing/mocking_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,15 @@ def test_mock_data_source():
)
assert isinstance(data_source[0]['image'], bytes)

# Without deserializing the examples
data_source = tfds.data_source(
'imagenet2012',
split='train',
deserialize_method=tfds.decode.DeserializeMethod.RAW_BYTES,
)
assert len(data_source) == 10
assert isinstance(data_source[0], bytes)


def test_mock_multiple_data_source():
with tfds.testing.mock_data(num_examples=10):
Expand Down
Loading