
TFSimilarity.samplers.TFRecordDatasetSampler

Create a TFRecordDataset-based sampler.

TFSimilarity.samplers.TFRecordDatasetSampler(
    shard_path: str,
    deserialization_fn: Callable,
    example_per_class: int = 2,
    batch_size: int = 32,
    shards_per_cycle: int = None,
    compression: Optional[str] = None,
    parallelism: int = tf.data.AUTOTUNE,
    file_parallelism: int = 1,
    prefetch_size: Optional[int] = None,
    shard_suffix: str = "*.tfrec"
) -> tf.data.Dataset

Use this sampler when working with a TFDataset or with a large dataset that needs to be stored on disk.

WARNING: This sampler assumes that the examples of each class are contiguous, at least enough that you can read example_per_class of them consecutively. This requirement makes sampling efficient and often makes dataset construction easier, since there is no need to worry about shuffling. "Somewhat contiguous" means it is fine to have the same class in multiple shards, as long as the examples of that class are contiguous within each shard.
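Before writing shards, the contiguity assumption can be checked in plain Python. The sketch below is a hypothetical helper, not part of TF Similarity; it assumes the sampler reads each shard in consecutive blocks of example_per_class records starting from the first record, which is how interleave's block_length consumes a single input in its deterministic form.

```python
from typing import Hashable, Sequence


def blocks_are_single_class(labels: Sequence[Hashable], example_per_class: int) -> bool:
    """Check one shard's labels (in file order) against the contiguity assumption.

    Hypothetical helper: assumes records are consumed in consecutive blocks of
    `example_per_class`, starting at the first record of the shard. Every full
    block must contain a single class; a trailing partial block is ignored.
    """
    if example_per_class < 1:
        raise ValueError("example_per_class must be >= 1")
    num_full_blocks = len(labels) // example_per_class
    for b in range(num_full_blocks):
        block = labels[b * example_per_class:(b + 1) * example_per_class]
        if len(set(block)) != 1:
            return False
    return True


print(blocks_are_single_class([0, 0, 1, 1, 2, 2], example_per_class=2))  # True
print(blocks_are_single_class([0, 0, 0, 1, 1, 1], example_per_class=2))  # False: block [0, 1]
```

Runs whose length is a multiple of example_per_class always pass. The second shard fails because class 0 has three examples, so its second block mixes classes 0 and 1.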

Overall, sampling is implemented by using tf.data.Dataset.interleave in an unorthodox way: its block_length controls the number of examples per class, and its parallel, non-deterministic mode does the sampling efficiently. Relying purely on tf.data ops also ensures good compatibility with distribution strategies.
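To make the interleave trick concrete, here is a pure-Python model of the deterministic round-robin order that tf.data.Dataset.interleave follows: it keeps cycle_length shards open, takes up to block_length consecutive records from each in turn, and replaces a shard with the next unopened one when it runs out. The `interleave` function is hypothetical and does not use TensorFlow, and mapping shards_per_cycle to cycle_length is an assumption. The sampler itself uses the non-deterministic mode, so real batches may order the blocks differently; the block structure is what this sketch illustrates.

```python
from typing import Iterable, Iterator, List, Sequence, TypeVar

T = TypeVar("T")


def interleave(shards: Sequence[Iterable[T]], cycle_length: int, block_length: int) -> List[T]:
    """Deterministic pure-Python model of tf.data.Dataset.interleave ordering."""
    if cycle_length < 1 or block_length < 1:
        raise ValueError("cycle_length and block_length must be >= 1")
    pending = iter(shards)  # shards not opened yet
    slots: List[Iterator[T]] = []
    # Open up to cycle_length shards at the start.
    while len(slots) < cycle_length:
        shard = next(pending, None)
        if shard is None:
            break
        slots.append(iter(shard))
    out: List[T] = []
    i = 0
    while slots:
        exhausted = False
        # Take up to block_length consecutive records from the current shard.
        for _ in range(block_length):
            try:
                out.append(next(slots[i]))
            except StopIteration:
                exhausted = True
                break
        if exhausted:
            shard = next(pending, None)
            if shard is None:
                # No shards left to open: drop this slot and keep cycling.
                slots.pop(i)
                if slots:
                    i %= len(slots)
                continue
            # Replace the finished shard with the next unopened one.
            slots[i] = iter(shard)
        i = (i + 1) % len(slots)
    return out


# Each "shard" holds class labels stored contiguously, two examples per class.
shards = [[0, 0, 1, 1], [2, 2, 3, 3], [4, 4, 5, 5]]
stream = interleave(shards, cycle_length=3, block_length=2)
batch_size = 6
batches = [stream[k:k + batch_size] for k in range(0, len(stream), batch_size)]
print(batches)  # [[0, 0, 2, 2, 4, 4], [1, 1, 3, 3, 5, 5]]
```

Each batch of 6 holds 3 classes with 2 examples each, i.e. batch_size // example_per_class classes. The model is intended to reproduce the deterministic example in the tf.data.Dataset.interleave documentation (five inputs of six identical records, cycle_length=2, block_length=4).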

Args

shard_path Directory where the shards are stored.
deserialization_fn Function used to deserialize a TFRecord and construct a valid example.
example_per_class Number of examples per class in each batch. Defaults to 2.
batch_size How many examples are in each batch. The number of classes in the batch will be batch_size // example_per_class. Defaults to 32.
shards_per_cycle How many shards to use concurrently per cycle. Defaults to None, which uses all of them. Too many concurrent shards can cause a segmentation fault.
compression Which compression was used when creating the dataset: None, "ZLIB", or "GZIP", as specified in the [TFRecordDataset documentation](https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset). Defaults to None.
parallelism How many parallel calls to make. If not set, TensorFlow decides by using tf.data.AUTOTUNE (-1).
file_parallelism How many shards to read in parallel; increase this if IO bound. Defaults to 1.
prefetch_size How many batches to prefetch. Defaults to 10.
shard_suffix Glob pattern used to collect the list of shard files. Defaults to "*.tfrec".
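The arguments above combine in a simple way: shard_path and shard_suffix select the shard files, shards_per_cycle=None falls back to all of them, and each batch holds batch_size // example_per_class classes. The sketch below is a hypothetical helper (`resolve_sampler_settings` is not part of TF Similarity) that makes this arithmetic explicit; it uses Python's standard glob module in place of whatever file listing the library performs internally.

```python
import glob
import os
from typing import List, Optional, Tuple


def resolve_sampler_settings(
    shard_path: str,
    batch_size: int = 32,
    example_per_class: int = 2,
    shards_per_cycle: Optional[int] = None,
    shard_suffix: str = "*.tfrec",
) -> Tuple[List[str], int, int]:
    """Hypothetical helper: returns (shard_files, cycle_length, classes_per_batch)."""
    # Collect shard files with the glob pattern, sorted for a stable order.
    shard_files = sorted(glob.glob(os.path.join(shard_path, shard_suffix)))
    if not shard_files:
        raise FileNotFoundError(f"no files match {shard_suffix!r} in {shard_path!r}")
    # shards_per_cycle=None means "use all shards concurrently".
    cycle_length = shards_per_cycle or len(shard_files)
    # Per the Args above, a batch holds batch_size // example_per_class classes.
    classes_per_batch = batch_size // example_per_class
    return shard_files, cycle_length, classes_per_batch
```

With the defaults (batch_size=32, example_per_class=2), each batch holds 16 classes; a directory with two .tfrec files and shards_per_cycle=None gives a cycle length of 2.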

Returns

A tf.data.Dataset ready to be consumed by the model.