
# TFSimilarity.samplers.TFRecordDatasetSampler

Create a TFRecordDataset based sampler.

```python
TFSimilarity.samplers.TFRecordDatasetSampler(
    shard_path: str,
    deserialization_fn: Callable,
    example_per_class: int = 2,
    batch_size: int = 32,
    shards_per_cycle: int = None,
    compression: Optional[str] = None,
    parallelism: int = tf.data.AUTOTUNE,
    async_cycle: bool = False,
    prefetch_size: Optional[int] = None,
    shard_suffix: str = "*.tfrec",
    num_repeat: int = -1
) -> tf.data.Dataset
```

Use this sampler when working with tf.data pipelines, or when the dataset is too large to fit in memory and must be stored in shard files.

WARNING: This sampler assumes that examples of the same class are contiguous, at least enough that `example_per_class` of them can be read consecutively. This requirement makes the sampling efficient and often simplifies dataset construction, since there is no need to worry about shuffling. "Somewhat contiguous" means it is fine for the same class to appear in multiple shards, as long as the examples of that class are contiguous within each shard.
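As a plain-Python sketch (the helper name is hypothetical, not part of the library), the "somewhat contiguous" requirement amounts to checking that every run of identical consecutive labels in a shard is at least `example_per_class` long:

```python
from itertools import groupby

def is_somewhat_contiguous(labels, example_per_class=2):
    """Check that every run of identical consecutive labels in a shard
    is long enough to yield `example_per_class` examples at once."""
    return all(
        len(list(run)) >= example_per_class
        for _, run in groupby(labels)
    )

# The same class may appear in several runs (or several shards),
# as long as each individual run is long enough.
print(is_somewhat_contiguous([0, 0, 1, 1, 0, 0], example_per_class=2))  # True
print(is_somewhat_contiguous([0, 1, 0, 1], example_per_class=2))        # False
```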

Overall, the sampling process works by using `tf.data.Dataset.interleave` in an unorthodox way: its `block_length` controls the number of examples per class, and the parallel, non-deterministic version of the API does the sampling efficiently for us. Relying on pure tf.data ops also ensures good compatibility with distribution strategies.
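The effect of `block_length` can be illustrated in plain Python (a simplified, deterministic sketch; the function name is hypothetical): cycle over the shards, emitting a block of `block_length` consecutive items from each shard per visit. Because each shard stores class-contiguous examples, each block contributes `example_per_class` examples of a single class to the batch.

```python
def block_interleave(shards, block_length=2):
    """Round-robin over shards, emitting `block_length` consecutive
    items from each shard per visit -- a simplified, deterministic
    analogue of tf.data.Dataset.interleave(..., block_length=...)."""
    iters = [iter(s) for s in shards]
    out = []
    while iters:
        for it in list(iters):
            for _ in range(block_length):
                try:
                    out.append(next(it))
                except StopIteration:
                    iters.remove(it)  # shard exhausted, drop it
                    break
    return out

# Two shards, each with class-contiguous examples: every block of 2
# yields 2 examples of one class, so a batch of 4 covers 2 classes.
print(block_interleave([["a1", "a2", "a3", "a4"],
                        ["b1", "b2", "b3", "b4"]], block_length=2))
# ['a1', 'a2', 'b1', 'b2', 'a3', 'a4', 'b3', 'b4']
```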

## Args

- `shard_path`: Directory where the shards are stored.
- `deserialization_fn`: Function used to deserialize the TFRecord and construct a valid example.
- `example_per_class`: Number of examples per class in each batch. Defaults to 2.
- `batch_size`: How many examples in each batch. The number of classes per batch will be `batch_size // example_per_class`. Defaults to 32.
- `shards_per_cycle`: How many shards to use concurrently per cycle. Defaults to None, which means all of them. Can cause a segfault if too many shards are used.
- `compression`: Which compression was used when creating the dataset: None, "ZLIB", or "GZIP", as specified in the [TFRecordDataset documentation](https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset). Defaults to None.
- `parallelism`: How many parallel calls to perform. If not set, lets TensorFlow decide by using `tf.data.AUTOTUNE` (-1).
- `async_cycle`: If True, create a threadpool of size `batch_size // example_per_class` and fetch inputs from the cycle shards asynchronously; however, in practice, the default single-threaded setting is faster. We only recommend setting this to True if it is absolutely necessary.
- `prefetch_size`: How many batches to prefetch. Defaults to None, in which case 10 is used.
- `shard_suffix`: Glob pattern used to collect the list of shard files. Defaults to "*.tfrec".
- `num_repeat`: How many times to repeat the dataset. Defaults to -1 (repeat indefinitely).

## Returns

A `tf.data.Dataset` ready to be consumed by the model.
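A minimal end-to-end sketch of a `deserialization_fn`, assuming a schema with `"image"` and `"label"` features (both names and shapes are illustrative, not prescribed by the library). The snippet writes one tiny class-contiguous shard and sanity-checks the function with a plain `tf.data.TFRecordDataset`; the same function is what you would pass to `TFRecordDatasetSampler`, as shown in the comment.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical schema; adapt feature names and shapes to your data.
FEATURES = {
    "image": tf.io.FixedLenFeature([32 * 32 * 3], tf.float32),
    "label": tf.io.FixedLenFeature([], tf.int64),
}

def deserialization_fn(serialized_example):
    """Parse one serialized tf.train.Example into an (image, label) pair."""
    ex = tf.io.parse_single_example(serialized_example, FEATURES)
    image = tf.reshape(ex["image"], (32, 32, 3))
    return image, ex["label"]

# Write a tiny shard to demonstrate the expected class-contiguous layout.
shard_dir = tempfile.mkdtemp()
with tf.io.TFRecordWriter(os.path.join(shard_dir, "shard-0.tfrec")) as w:
    for label in (0, 0, 1, 1):  # examples of each class are contiguous
        ex = tf.train.Example(features=tf.train.Features(feature={
            "image": tf.train.Feature(float_list=tf.train.FloatList(
                value=[0.0] * (32 * 32 * 3))),
            "label": tf.train.Feature(int64_list=tf.train.Int64List(
                value=[label])),
        }))
        w.write(ex.SerializeToString())

# Sanity-check deserialization_fn with a plain TFRecordDataset; with
# TensorFlow Similarity installed, the equivalent sampler call would be:
#
#   ds = TFSimilarity.samplers.TFRecordDatasetSampler(
#       shard_path=shard_dir,
#       deserialization_fn=deserialization_fn,
#       example_per_class=2,
#       batch_size=4,
#   )
ds = tf.data.TFRecordDataset(
    tf.io.gfile.glob(os.path.join(shard_dir, "*.tfrec"))
).map(deserialization_fn)
for image, label in ds.take(1):
    print(image.shape, label.numpy())  # (32, 32, 3) 0
```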