Create a TFRecordDataset based sampler.
TFSimilarity.samplers.TFRecordDatasetSampler(
shard_path: str,
deserialization_fn: Callable,
example_per_class: int = 2,
batch_size: int = 32,
shards_per_cycle: int = None,
compression: Optional[str] = None,
parallelism: int = tf.data.AUTOTUNE,
async_cycle: bool = False,
prefetch_size: Optional[int] = None,
shard_suffix: str = "*.tfrec",
num_repeat: int = -1
) -> tf.data.Dataset
This sampler should be used when working with a TFRecordDataset, or when a dataset is too large to fit in memory and must be stored on disk.

WARNING: This sampler assumes that class examples are contiguous, at least enough that you can read example_per_class of them consecutively. This requirement makes the sampling efficient, and it oftentimes makes dataset construction easier, as there is no need to worry about shuffling. "Somewhat contiguous" means it is fine to have the same class in multiple shards, as long as the examples for that class are contiguous within each shard.
Overall, the sampling works by using tf.data.Dataset.interleave in an unorthodox way: its block_length controls the number of examples per class, while the parallel, non-deterministic version of the API does the sampling efficiently for us. Relying on pure tf.data ops also ensures good compatibility with distribution strategies.
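The interleave trick can be sketched in plain Python. The shard contents below are made-up examples, and the function is a simplified, single-threaded stand-in for tf.data.Dataset.interleave:

```python
from itertools import islice

def interleave(shards, cycle_length, block_length):
    """Round-robin over up to `cycle_length` shard iterators, taking
    `block_length` consecutive items from each before moving on —
    mimicking the deterministic order of tf.data.Dataset.interleave."""
    iterators = [iter(s) for s in shards[:cycle_length]]
    while iterators:
        for it in list(iterators):
            block = list(islice(it, block_length))
            if not block:
                iterators.remove(it)  # shard exhausted
                continue
            yield from block

# Two shards whose examples are contiguous by class: (class, example id).
shard_a = [("cat", 0), ("cat", 1), ("dog", 0), ("dog", 1)]
shard_b = [("fox", 0), ("fox", 1), ("owl", 0), ("owl", 1)]

# block_length == example_per_class == 2: each block is a single class,
# so every window of batch_size items mixes batch_size // 2 classes.
stream = list(interleave([shard_a, shard_b], cycle_length=2, block_length=2))
# → [("cat", 0), ("cat", 1), ("fox", 0), ("fox", 1),
#    ("dog", 0), ("dog", 1), ("owl", 0), ("owl", 1)]
```

The real sampler gets its parallelism and shuffling from tf.data's parallel, non-deterministic interleave rather than from a loop like this.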
| Argument | Description |
|---|---|
| `shard_path` | Directory where the shards are stored. |
| `deserialization_fn` | Function used to deserialize the TFRecord and construct a valid example. |
| `example_per_class` | Number of examples per class in each batch. Defaults to 2. |
| `batch_size` | How many examples in each batch. The number of classes in the batch will be `batch_size // example_per_class`. Defaults to 32. |
| `shards_per_cycle` | How many shards to use concurrently per cycle. Defaults to None, which uses all of them. Can cause a segfault if too many shards are used. |
| `compression` | Which compression was used when creating the dataset: None, "ZLIB", or "GZIP", as specified in the [TFRecordDataset documentation](https://www.tensorflow.org/api_docs/python/tf/data/TFRecordDataset). Defaults to None. |
| `parallelism` | How many parallel calls to make. If not set, lets TensorFlow decide by using `tf.data.AUTOTUNE` (-1). |
| `async_cycle` | If True, create a threadpool of size `batch_size // example_per_class` and fetch inputs from the cycle shards asynchronously; however, in practice, the default single-thread setting is faster. We only recommend setting this to True if it is absolutely necessary. |
| `prefetch_size` | How many batches to prefetch. Defaults to None, in which case 10 is used. |
| `shard_suffix` | Glob pattern used to collect the shard file list. Defaults to "*.tfrec". |
| `num_repeat` | How many times to repeat the dataset. Defaults to -1 (infinite). |
Returns: A `tf.data.Dataset` ready to be consumed by the model.
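To make the pipeline concrete, here is a minimal sketch of the sampler's core using raw tf.data ops. The shard files, labels, and feature key are assumptions for illustration; the sampler itself adds batching, repetition, and prefetching on top of this:

```python
import os
import tempfile

import tensorflow as tf

# Assumed setup: two shards whose examples are stored contiguously by
# class, as the sampler requires.
shard_dir = tempfile.mkdtemp()
shards = {"shard_0.tfrec": [0, 0, 1, 1], "shard_1.tfrec": [2, 2, 3, 3]}
for name, labels in shards.items():
    with tf.io.TFRecordWriter(os.path.join(shard_dir, name)) as writer:
        for y in labels:
            example = tf.train.Example(features=tf.train.Features(feature={
                "label": tf.train.Feature(int64_list=tf.train.Int64List(value=[y])),
            }))
            writer.write(example.SerializeToString())

def deserialization_fn(record):
    # Parse one serialized tf.train.Example back into a label tensor.
    parsed = tf.io.parse_single_example(
        record, {"label": tf.io.FixedLenFeature([], tf.int64)})
    return parsed["label"]

# Core of the sampler: interleave the shard files with
# block_length == example_per_class, so each block holds one class.
example_per_class = 2
files = tf.data.Dataset.list_files(
    os.path.join(shard_dir, "*.tfrec"), shuffle=True)
ds = files.interleave(
    lambda f: tf.data.TFRecordDataset(f).map(deserialization_fn),
    cycle_length=2,
    block_length=example_per_class,
)
batch = next(iter(ds.batch(4)))  # 2 classes x 2 examples per class
```

Each batch of 4 contains two blocks of 2 examples, one block per class, which is exactly the structure contrastive losses expect.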