-
Notifications
You must be signed in to change notification settings - Fork 188
Expand file tree
/
Copy pathdataset.py
More file actions
1412 lines (1161 loc) · 61.4 KB
/
dataset.py
File metadata and controls
1412 lines (1161 loc) · 61.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2023 MosaicML Streaming authors
# SPDX-License-Identifier: Apache-2.0
"""A mid-epoch-resumable streaming/caching pytorch IterableDataset."""
import json
import logging
import os
import sys
import warnings
from concurrent.futures import ThreadPoolExecutor, wait
from concurrent.futures._base import Future
from enum import IntEnum
from math import ceil
from threading import Event, Lock
from time import sleep, time_ns
from typing import Any, Dict, Iterator, Optional, Sequence, Tuple, Union
import numpy as np
from filelock import FileLock
from numpy.typing import NDArray
from torch import distributed as dist
from torch.utils.data import IterableDataset
from streaming.array import Array
from streaming.batching import generate_work
from streaming.constant import (BARRIER, BARRIER_FILELOCK, CACHE_FILELOCK, CACHE_USAGE, EPOCH_DATA,
EPOCH_SHAPE, NEXT_EPOCH, RESUME, SHARD_ACCESS_TIMES, SHARD_STATES,
TICK)
from streaming.distributed import maybe_init_dist
from streaming.format import get_index_basename
from streaming.sampling import get_sampling
from streaming.shared import (SharedArray, SharedBarrier, SharedMemory, SharedScalar, _get_path,
get_shm_prefix)
from streaming.spanner import Spanner
from streaming.stream import Stream
from streaming.util import normalize_bytes, normalize_count
from streaming.util.migration import get_keep_packed
from streaming.world import World
# An arbitrary time in the future, used for cold shard eviction.
NEVER = np.iinfo(np.uint64).max
logger = logging.getLogger(__name__)
class _ShardState(IntEnum):
    """The download status of a shard.

    Restrictions:
    - The initial state of INVALID must be zero (shared arrays are zero-initialized, so a
      freshly allocated state slot reads as INVALID).
    - State transitions: REMOTE -> PREPARING -> LOCAL -> REMOTE.
    """

    # The state is allocated (e.g., in an array), but not initialized yet.
    INVALID = 0
    # The shard exists only at the remote source.
    REMOTE = 1
    # The shard is currently being worked on: (a) downloading from remote to local,
    # (b) decompressing zip-only, etc.
    PREPARING = 2
    # Some form of the shard (raw or zip) exists locally (as well as remotely).
    LOCAL = 3
class _IterState(IntEnum):
    """The iter status of an _Iterator.

    Restrictions:
    - State transitions: ITERATING -> EXITING -> EXITED.
    """

    # We are currently iterating through an epoch.
    ITERATING = 0
    # We have been signalled to end the epoch (either we hit end of __iter__, or
    # someone else started a new epoch, of which only one can be valid at a time).
    EXITING = 1
    # All threads have noticed the exit signal and exited.
    EXITED = 2
class _Iterator:
    """State of StreamingDataset __iter__, used to track and coordinate its threads.

    Has methods to implement early exit when a new epoch is started before the last one is done.

    Order of threads: 0 <= yield loop <= ready thread <= download thread <= total.

    Three indices:

    * Download index: points to the sample we are presently downloading, skipping other workers'
      downloads in progress.
    * Ready index: points to the farthest contiguously downloaded sample by any worker on this
      node.
    * Yield index: points to the (downloaded) sample that we are currently yielding.

    Args:
        sample_ids (NDArray[np.int64]): This worker's samples to download and yield.
    """

    # The number of threads (`download`, `ready`, `yield``) to wait on the exits of before
    # returning. The `yield` main thread exits at the end of epoch(s).
    _num_threads_to_exit = 2

    def __init__(self, sample_ids: NDArray[np.int64]) -> None:
        self.sample_ids = sample_ids
        self.total = len(sample_ids)

        # Progress cursors into `sample_ids`, each advanced by its own thread.
        self.prepare_index = 0
        self.ready_index = 0
        self.yield_index = 0
        self.eviction_index = 0

        # Guards `_state` and `_num_exited`.
        self._lock = Lock()
        self._state = 0
        self._num_exited = 0

        # python will attempt to join all threads on shutdown.
        # Here, we register a call to self.non_blocking_exit to run
        # at shutdown to prevent a deadlock.
        # In python version >=3.9 this can be accomplished via
        # threading._register_atexit but not with the atexit module.
        # In older python versions, the atexit module can be used, and
        # threading._register_atexit does not exist.
        #
        # Compare the full version tuple (not just the minor version, which would
        # misclassify any release whose major version is not 3).
        if sys.version_info < (3, 9):
            import atexit
            atexit.register(self.non_blocking_exit)
        else:
            from threading import _register_atexit  # pyright: ignore
            _register_atexit(self.non_blocking_exit)

    def non_blocking_exit(self) -> None:
        """Signal threads to exit without blocking.

        This will be called at process exit.
        """
        with self._lock:
            if self._state == _IterState.ITERATING:
                self._state = _IterState.EXITING

    def exit(self) -> None:
        """Signal threads to exit, wait until they have all exited, then return.

        This is called when the user starts a new epoch without the threads from the previous
        epoch having exited yet.
        """
        # Signal threads to exit.
        with self._lock:
            if self._state == _IterState.ITERATING:
                self._state = _IterState.EXITING
            elif self._state == _IterState.EXITING:
                pass
            elif self._state == _IterState.EXITED:
                return
            else:
                raise RuntimeError(f'Invalid _IterState: {self._state}')

        # Block until they have all exited, updating _state to done.
        while True:
            with self._lock:
                if self._num_exited >= self._num_threads_to_exit:
                    self._state = _IterState.EXITED
                    break
            sleep(TICK)

    def should_exit(self) -> bool:
        """Check if the calling thread should exit.

        Returns:
            bool: Whether to exit.
        """
        with self._lock:
            return self._state in {_IterState.EXITING, _IterState.EXITED}

    def on_exit(self) -> None:
        """Note that a thread has exited."""
        with self._lock:
            self._num_exited += 1
class StreamingDataset(Array, IterableDataset):
"""A mid-epoch-resumable streaming/caching pytorch IterableDataset.
Features elastically deterministic shuffling, which enables fast mid-epoch resumption.
Checkpoints are represented in JSON as follows:
.. code-block:: json
{
"epoch" :"int",
"sample_in_epoch": "int",
"shuffle_seed": "int",
"num_canonical_nodes": "int"
}
StreamingDataset init takes two kinds of arguments:
* What to iterate:
* One or more streams (you must provide either ``streams`` or ``remote``/``local``):
* ``streams``
* ``remote``
* ``local``
* Knobs to control streaming behavior, which, if multiple streams are provided,
become defaults applied to each of them:
* ``split``
* ``download_retry``
* ``download_timeout``
* ``validate_hash``
* ``keep_packed``
* Absolute dataset size, if streams were weighted relatively:
* ``epoch_size``
* How to iterate:
* Shard lifecycle:
* ``predownload``
* ``cache_limit``
* Sampling:
* ``sampling_method``
* ``sampling_granularity``
* Determinism:
* ``partition_algo``
* ``num_canonical_nodes``
* ``batch_size``
* Shuffling:
* ``shuffle``
* ``shuffle_algo``
* ``shuffle_seed``
* ``shuffle_block_size``
* Batching:
* ``batching_method``
Args:
streams (Sequence[Stream], optional): One or more streams to stream/cache samples from,
which may be upsampled or downsampled. StreamingDataset uses either ``streams`` or
``remote``/``local``. Defaults to ``None``.
remote (str, optional): Remote path or directory to download the dataset from. If ``None``,
its data must exist locally. StreamingDataset uses either ``streams`` or
``remote``/``local``. Defaults to ``None``.
local (str, optional): Local working directory to download shards to. This is where shards
are cached while they are being used. Uses a temp directory if not set.
StreamingDataset uses either ``streams`` or ``remote``/``local``. Defaults to ``None``.
split (str, optional): Which dataset split to use, if any. If provided, we stream from/to
the ``split`` subdirs of ``remote`` and ``local``. Defaults to ``None``.
download_retry (int): Number of download re-attempts before giving up. Defaults to ``2``.
download_timeout (float): Number of seconds to wait for a shard to download before raising
an exception. Defaults to ``60``.
validate_hash (str, optional): Optional hash or checksum algorithm to use to validate
shards. Defaults to ``None``.
keep_packed (bool, optional): Whether to keep or drop the packed form of shards after
unpacking, e.g. compressed shards after decompression or Parquet shards after
conversion to MDS. If ``False``, keep iff remote is local or no remote. Defaults to
``None``, which is normalized to ``False``, in order to distinguish setting it on
purpose from receiving the default.
epoch_size (Union[int, str], optional): Number of samples to draw per epoch balanced
across all streams. If ``None``, takes its value from the total number of underlying
samples. Provide this field if you are weighting streams relatively to target a larger
or smaller epoch size. Defaults to ``None``. Can also take in human-readable number
abbreviations (e.g., ``"100k"``, ``"64M"``, ``"77b"``, etc). Defaults to ``None``.
predownload (int, optional): Target number of samples to download per worker in advance
of current sample. Workers will attempt to download ahead by this many samples during,
but not before, training. Recommendation is to provide a value greater than per device
batch size to ensure at-least per device batch size number of samples cached locally.
If ``None``, its value is set to ``8 * batch_size``. Defaults to ``None``.
cache_limit (Union[int, str], optional): Maximum size in bytes of this StreamingDataset's
shard cache. Before downloading a shard, the least recently used resident shard(s)
may be evicted (deleted from the local cache) in order to stay under the limit.
Set to ``None`` to disable shard eviction. Supports integer bytes as well as string
human-readable bytes (e.g., ``100b``, ``64kb``, ``77mb``, and so on). Defaults to
``None``.
sampling_method (str): Which sampling method to use, either ``balanced`` or ``fixed``.
Defaults to ``balanced``.
sampling_granularity (int): When picking samples for a stream's final partial repeat,
how many samples to pick from the same shard at a time (``1`` for evenly balanced
across shards, ``1000`` to pick 1000 samples from the same shard at a time, etc).
Defaults to ``1``.
partition_algo (str): Which partitioning algorithm to use. Defaults to ``relaxed``.
num_canonical_nodes (int, optional): Canonical number of nodes for shuffling with
resumption. The sample space is divided evenly according to the number of canonical
nodes. The higher the value, the more independent non-overlapping paths the
StreamingDataset replicas take through the shards per model replica (increasing data
source diversity). If ``None``, this is interpreted as 64 times the number of physical
nodes of the initial run if ``shuffle_algo`` is ``py1s`` or ``py2s``, and simply the
number of physical nodes of the initial run otherwise. Defaults to ``None``.
.. note::
For sequential sample ordering, set ``shuffle`` to ``False`` and
``num_canonical_nodes`` to the number of physical nodes of the initial run.
batch_size (int, optional): Batch size of its DataLoader, which affects how the dataset is
partitioned over the workers. Defaults to ``None``.
shuffle (bool): Whether to iterate over the samples in randomized order. Defaults to
``False``.
shuffle_algo (str): Which shuffling algorithm to use. Defaults to ``py1e``.
shuffle_seed (int): Seed for deterministic data shuffling. Defaults to ``9176``.
shuffle_block_size (int, optional): Unit of shuffle. A canonical node's samples are split
into blocks of this size, and samples within each block are shuffled. If ``None``, its
value is calculated as ``max(4_000_000 // num_canonical_nodes), 1 << 18)``. Defaults to
``None``.
batching_method (str): Which batching method to use, either ``random``, ``stratified``, or
``per_stream``. Defaults to ``random``.
keep_zip (bool, optional): This argument is deprecated. It has been replaced by
``keep_packed``, which is more general, for which it serves as a fallback. Defaults to
``None``.
"""
def __init__(self,
             *,
             streams: Optional[Sequence[Stream]] = None,
             remote: Optional[str] = None,
             local: Optional[str] = None,
             split: Optional[str] = None,
             download_retry: int = 2,
             download_timeout: float = 60,
             validate_hash: Optional[str] = None,
             keep_packed: Optional[bool] = None,
             epoch_size: Optional[Union[int, str]] = None,
             predownload: Optional[int] = None,
             cache_limit: Optional[Union[int, str]] = None,
             sampling_method: str = 'balanced',
             sampling_granularity: int = 1,
             partition_algo: str = 'relaxed',
             num_canonical_nodes: Optional[int] = None,
             batch_size: Optional[int] = None,
             shuffle: bool = False,
             shuffle_algo: str = 'py1e',
             shuffle_seed: int = 9176,
             shuffle_block_size: Optional[int] = None,
             batching_method: str = 'random',
             keep_zip: Optional[bool] = None) -> None:
    """Validate configuration, load stream indexes, and set up shared state.

    See the class docstring for the meaning of each argument.
    """
    # Global arguments (which do not live in Streams).
    self.predownload = predownload
    self.cache_limit = cache_limit
    self.sampling_method = sampling_method
    self.sampling_granularity = sampling_granularity
    self.partition_algo = partition_algo
    self.num_canonical_nodes = num_canonical_nodes
    self.batch_size = batch_size
    self.shuffle = shuffle
    self.shuffle_algo = shuffle_algo
    self.shuffle_seed = shuffle_seed
    self.shuffle_block_size = shuffle_block_size
    self.batching_method = batching_method

    # Normalize the deprecated `keep_zip` argument into `keep_packed`.
    keep_packed = get_keep_packed(keep_packed, keep_zip)

    # Initialize initial_physical_nodes to None. If we are resuming, then we will set it to
    # the number of physical nodes of the initial run in the _resume function.
    self.initial_physical_nodes = None

    # Check streams vs remote/local: exactly one of the two ways of specifying what to
    # stream must be used.
    if bool(streams) == (bool(remote) or bool(local)):
        raise ValueError(
            'You must provide either `streams` or `remote`/`local`, but not both.')

    # Check sampling method is one of "balanced" or "fixed".
    if self.sampling_method not in ['balanced', 'fixed']:
        raise ValueError(f'Invalid sampling method: {sampling_method}. ' +
                         'Must be one of `balanced` or `fixed`.')

    # Check sampling granularity.
    if self.sampling_granularity <= 0:
        raise ValueError(f'`sampling_granularity` must be a positive integer, but got: ' +
                         f'{self.sampling_granularity}.')

    # Check batching method is one of "random", "stratified", or "per_stream".
    # (Fixed: the error message was missing the closing backtick after `per_stream`.)
    if self.batching_method not in ['random', 'stratified', 'per_stream']:
        raise ValueError(f'Invalid batching method: {batching_method}. ' +
                         'Must be one of `random`, `stratified`, or `per_stream`.')

    # Issue deprecation warning for py1b shuffle algorithm.
    # (Fixed: the previous backslash-continued string literal embedded the source
    # indentation into the emitted warning message.)
    if self.shuffle_algo == 'py1b':
        warnings.warn('The \'py1b\' shuffle algorithm will soon be deprecated. ' +
                      'Please use the more performant \'py1br\' algorithm instead.',
                      DeprecationWarning,
                      stacklevel=2)

    # Check shuffle seed.
    if self.shuffle_seed < 0:
        raise ValueError(f'`shuffle_seed` must be a non-negative integer, but got: ' +
                         f'{self.shuffle_seed}.')

    # Check that predownload is at least per device batch size, and set it if currently
    # `None`. (Fixed: missing space between sentences in the warning message.)
    if self.predownload is not None and self.batch_size is not None and \
            self.predownload < self.batch_size:
        warnings.warn(f'predownload < batch_size ({self.predownload} < {self.batch_size}). ' +
                      f'This may result in slower batch time. Recommendation is to set ' +
                      f'predownload to at-least batch_size.')
    elif self.predownload is None:
        self.predownload = 8 * self.batch_size if self.batch_size is not None else 64

    # Convert epoch size from string to int, if needed. Cannot be negative.
    epoch_size_value = None
    if epoch_size:
        epoch_size_value = normalize_count(epoch_size)
        if epoch_size_value < 0:
            raise ValueError(f'Epoch size cannot be negative. Received {epoch_size_value}.')

    # Initialize torch dist ourselves, if necessary.
    destroy_dist = maybe_init_dist()

    # Initialize the Stream defaults and normalize to a list of Streams.
    if streams:
        default = {
            'remote': remote,
            'local': local,
            'split': split,
            'download_retry': download_retry,
            'download_timeout': download_timeout,
            'validate_hash': validate_hash,
            'keep_packed': keep_packed,
        }
        for stream in streams:
            stream.apply_default(default)
    else:
        default = Stream(remote=remote,
                         local=local,
                         split=split,
                         download_retry=download_retry,
                         download_timeout=download_timeout,
                         validate_hash=validate_hash,
                         keep_packed=keep_packed)
        streams = [default]

    # Validate the stream weighting scheme (relative or absolute) to catch errors before we
    # go to the trouble of loading them.
    Stream.validate_weights(streams)

    # Set streams.
    self.streams = streams
    self.num_streams = len(streams)

    # Initialize the World context.
    #
    # Beware: This information is for the per-rank process. DataLoader worker processes may
    # see different values for these fields. We are saving the rank World here because we
    # cannot instantiate a World inside the StreamingDataset destructor.
    self._rank_world = world = World()

    # Download each stream's index, load their shards, and map streams <-> shards.
    self.num_samples = 0
    self.shards = []
    stream_per_shard = []
    self.shard_offset_per_stream = np.zeros(self.num_streams, np.int64)
    self.shards_per_stream = np.zeros(self.num_streams, np.int64)
    self.sample_offset_per_stream = np.zeros(self.num_streams, np.int64)
    self.samples_per_stream = np.zeros(self.num_streams, np.int64)
    for stream_id, stream in enumerate(self.streams):
        stream_shards = stream.get_shards(world)
        num_stream_samples = sum(map(len, stream_shards))
        if not num_stream_samples:
            index_filename = os.path.join(stream.local, stream.split, get_index_basename())
            raise RuntimeError(f'Stream contains no samples: {index_filename}.')
        stream_per_shard += [stream_id] * len(stream_shards)
        self.shard_offset_per_stream[stream_id] = len(self.shards)
        self.shards_per_stream[stream_id] = len(stream_shards)
        self.sample_offset_per_stream[stream_id] = self.num_samples
        self.samples_per_stream[stream_id] = num_stream_samples
        self.shards += stream_shards
        self.num_samples += num_stream_samples
    self.stream_per_shard = np.array(stream_per_shard, np.int64)
    self.num_shards = len(self.shards)

    # Check that cache limit is possible.
    if self.cache_limit:
        self.cache_limit = normalize_bytes(self.cache_limit)
        min_cache_usage = sum((stream.get_index_size() for stream in streams))
        if self.cache_limit <= min_cache_usage:
            raise ValueError(f'Minimum cache usage ({min_cache_usage} bytes) is larger than ' +
                             f'the cache limit ({self.cache_limit} bytes). Please raise ' +
                             f'`cache_limit`. Recommendation is to provide a `cache_limit` ' +
                             f'as high as possible to avoid thrashing.')
        self.max_shard_size_across_all_streams = max(
            np.array([shard.get_max_size() for shard in self.shards]))
        if self.cache_limit < 4 * self.max_shard_size_across_all_streams:
            raise ValueError(f'Cache limit ({self.cache_limit} bytes) is too low. ' +
                             f'Increase the `cache_limit` to at-least four times the ' +
                             f'largest shard size ({self.max_shard_size_across_all_streams} ' +
                             f'bytes) which includes raw (decompressed) and zip ' +
                             f'(compressed) file size. Recommendation is to provide a ' +
                             f'`cache_limit` as high as possible to avoid thrashing.')

    # Build the shard index (for partitioning and mapping samples to shards).
    self.samples_per_shard = np.array([shard.samples for shard in self.shards], np.int64)
    self.sample_offset_per_shard = self.samples_per_shard.cumsum() - self.samples_per_shard
    self.spanner = Spanner(self.samples_per_shard)

    # Now that we know the number of underlying samples of each stream, derive each stream's
    # true proportion/repeat/choose, as well as the total epoch size.
    self.epoch_size = Stream.apply_weights(self.streams, self.samples_per_stream,
                                           epoch_size_value, self.shuffle_seed)

    # Length (__len__) is the resampled epoch size divided over the number of devices.
    self.length = ceil(self.epoch_size / world.num_ranks)

    # Register/lookup our shared memory prefix and filelock root directory.
    streams_local = [os.path.abspath(os.path.join(x.local, x.split)) for x in streams]
    streams_remote = [
        os.path.join(x.remote, x.split) if x.remote is not None else None for x in streams
    ]
    self._shm_prefix_int, self._locals_shm = get_shm_prefix(streams_local, streams_remote,
                                                            world)
    self._filelock_root = os.path.join(os.path.sep, 'tmp', 'streaming')
    os.makedirs(self._filelock_root, exist_ok=True)

    # Create the shared memory-backed barrier, without its lock, which is unpickleable.
    self._shared_barrier = SharedBarrier(
        os.path.join(self._filelock_root, _get_path(self._shm_prefix_int, BARRIER_FILELOCK)),
        _get_path(self._shm_prefix_int, BARRIER))

    # Epoch counter.
    #
    # Note: we do not assume that the end of __iter__() will ever be reached, so we need to
    # increment the epoch counter at the start of __iter__() instead of at the end, so we
    # need to track what the next epoch is, not the current epoch.
    self._next_epoch = SharedScalar(np.int64, _get_path(self._shm_prefix_int, NEXT_EPOCH))

    # Cache filelock. Protects downloading and evicting shards.
    self._cache_filelock_path = os.path.join(self._filelock_root,
                                             _get_path(self._shm_prefix_int, CACHE_FILELOCK))
    self._cache_filelock: FileLock

    # Cache usage in bytes.
    self._cache_usage = SharedScalar(np.int64, _get_path(self._shm_prefix_int, CACHE_USAGE))

    # Shard states array. Tells if a shard is missing, downloading, or present (eviction
    # happens under the lock).
    self._shard_states = SharedArray(self.num_shards, np.uint8,
                                     _get_path(self._shm_prefix_int, SHARD_STATES))

    # Time of last access per shard. This is used to decide which shard(s) to evict when we
    # run out of space.
    self._shard_access_times = SharedArray(self.num_shards, np.uint64,
                                           _get_path(self._shm_prefix_int, SHARD_ACCESS_TIMES))

    # Initialize shared memory objects (local leader only; other ranks/workers just attach).
    if world.is_local_leader:
        # Set initial epoch (before any resumption).
        self.next_epoch = 0

        # Get cache usage due to streams.
        self.cache_usage = 0
        for stream in self.streams:
            self.cache_usage += stream.get_index_size()

        # Get cache usage due to shards.
        cache_usage_per_shard = np.zeros(self.num_shards, np.int64)
        for stream_id, stream in enumerate(self.streams):
            begin = self.shard_offset_per_stream[stream_id]
            end = begin + self.shards_per_stream[stream_id]
            stream.set_up_local(self.shards[begin:end], cache_usage_per_shard[begin:end])
        self.cache_usage += cache_usage_per_shard.sum()

        # If either raw or zip are present after local dir setup, the shard is considered
        # present for download/eviction logic purposes (may need to decompress upon use).
        for shard_id, size in enumerate(cache_usage_per_shard):
            self._shard_states[shard_id] = _ShardState.LOCAL if size else _ShardState.REMOTE
            self._shard_access_times[shard_id] = time_ns()

    if dist.is_available() and dist.is_initialized():
        dist.barrier()
    if destroy_dist:
        dist.destroy_process_group()

    # Placeholder for a shared memory object where load_state_dict() saves its data to be
    # picked up by __iter__().
    self._resume_shm: SharedMemory

    # Placeholder for an _Iterator which tracks state during __iter__().
    self._iterator: _Iterator

    # For exception handling in __iter__ threads.
    self._executor: ThreadPoolExecutor
    self._event: Event

    del self._shared_barrier.lock  # Remove the lock that makes it unpickleable.
def __del__(self) -> None:
    """Destructor, which releases its local working directories."""
    if hasattr(self, '_locals_shm'):
        try:
            # Zero the leading int32 flag of the registered-locals shm block to release it.
            self._locals_shm.buf[:4] = np.int32(0).tobytes()
        except Exception:
            # Best-effort: the shared memory may already be torn down at interpreter exit.
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit still propagate.
            pass
@property
def size(self) -> int:
    """The size of the dataset in samples.

    This is the total number of underlying samples across all streams (not the
    weighted/resampled epoch size).

    Returns:
        int: Number of samples.
    """
    return self.num_samples
@property
def next_epoch(self) -> int:
    """The next epoch to be iterated, read from shared memory.

    Returns:
        int: Next epoch.
    """
    value = self._next_epoch.get()
    return int(value)

@next_epoch.setter
def next_epoch(self, next_epoch: int) -> None:
    """Write the next epoch to shared memory.

    Args:
        next_epoch (int): Next epoch.
    """
    self._next_epoch.set(next_epoch)
@property
def cache_usage(self) -> int:
    """The current cache usage, read from shared memory.

    Returns:
        int: Cache usage in bytes.
    """
    value = self._cache_usage.get()
    return int(value)

@cache_usage.setter
def cache_usage(self, cache_usage: int) -> None:
    """Write the cache usage to shared memory.

    Args:
        cache_usage (int): Cache usage in bytes.
    """
    self._cache_usage.set(cache_usage)
def __len__(self) -> int:
    """Get the length as a PyTorch IterableDataset.

    Returns:
        int: Dataset length, i.e. the epoch size divided over the number of devices
            (computed once in __init__).
    """
    return self.length
def _set_shuffle_block_size(self):
    """Derive a default shuffle block size if the user did not provide one."""
    if self.shuffle_block_size is not None:
        # Explicitly configured; nothing to do.
        return
    if self.num_canonical_nodes is None:
        self.shuffle_block_size = 1 << 18
    else:
        # Scale inversely with canonical nodes, but never below the 2**18 floor.
        self.shuffle_block_size = max(4_000_000 // self.num_canonical_nodes, 1 << 18)
def _resume(self, world: World, epoch: int) -> Tuple[int, int]:
    """Either resume from checkpoint or start at the beginning.

    Args:
        world (World): World state.
        epoch (int): What epoch we think it is (pre-checkpoint).

    Returns:
        Tuple[int, int]: What epoch this is, and sample offset in that epoch.
    """
    # Get the resume state, if it exists.
    name = _get_path(self._shm_prefix_int, RESUME)
    try:
        shm = SharedMemory(name=name, create=False)
    except FileNotFoundError:
        # There is nothing to resume.
        self._default_num_canonical_nodes(world)
        self._set_shuffle_block_size()
        return epoch, 0

    # SharedMemory buffers may contain additional null bytes at the end, so trim at the
    # first null before decoding.
    buf = bytes(shm.buf)
    index = buf.find(b'\0')
    buf = buf[:index] if index != -1 else buf
    obj = json.loads(buf.decode('utf-8'))

    # Check if the resume state is stale (from an earlier epoch than we are on).
    if obj['epoch'] < epoch:
        self._default_num_canonical_nodes(world)
        self._set_shuffle_block_size()
        return epoch, 0

    # Load the correct resumption meta data.
    epoch = obj['epoch']
    sample_in_epoch = obj['sample_in_epoch']
    self.num_canonical_nodes = obj['num_canonical_nodes']
    self.shuffle_seed = obj['shuffle_seed']
    # Ensure that we are backwards compatible with old checkpoint dataset state, since the
    # 'initial_physical_nodes' key may not be present.
    self.initial_physical_nodes = obj.get('initial_physical_nodes', None)
    self._set_shuffle_block_size()
    return epoch, sample_in_epoch

def _default_num_canonical_nodes(self, world: World) -> None:
    """Default ``num_canonical_nodes`` from the world size if it is unset.

    For the single-shard-at-a-time shuffle algorithms (``py1s``/``py2s``), use 64x the
    number of physical nodes; otherwise, simply the number of physical nodes.
    (Extracted from ``_resume``, where this logic was duplicated in both
    start-from-scratch branches.)

    Args:
        world (World): World state.
    """
    if not self.num_canonical_nodes:
        if self.shuffle_algo in ['py1s', 'py2s']:
            self.num_canonical_nodes = 64 * world.num_nodes
        else:
            self.num_canonical_nodes = world.num_nodes
def _resume_incr_epoch(self, world: World) -> Tuple[int, int]:
    """Start or resume training, pre-incrementing the next epoch.

    Args:
        world (World): World state.

    Returns:
        Tuple[int, int]: What epoch this is, and sample offset in that epoch.
    """
    # The shared barrier's FileLock contains a threading Lock, which is unpickleable, so
    # it is created lazily here instead of in __init__.
    if not hasattr(self._shared_barrier, 'lock'):
        self._shared_barrier.lock = FileLock(self._shared_barrier.filelock_path)

    # Either resume from checkpoint, or start from scratch.
    epoch, sample_in_epoch = self._resume(world, self.next_epoch)

    # Wait for every worker on this node to agree on the epoch above.
    self._shared_barrier(world.workers_per_node)

    # The local leader publishes the new next epoch.
    if world.is_local_leader:
        self.next_epoch = epoch + 1

    return epoch, sample_in_epoch
def state_dict(self, num_samples: int, from_beginning: bool) -> Dict[str, Any]:
    """Get a dict containing training state (called from non-worker process).

    This is called on rank zero.

    Our stock StreamingDataLoader counts samples from start of training (from_beginning=false).
    However, if you are always counting from the start of the epoch, set from_beginning=true.

    Args:
        num_samples (int): The number of samples processed so far in the current epoch.
        from_beginning (int): Whether we are counting samples from the start of this epoch, or
            the start of just this potentially resumed training run this epoch.

    Returns:
        Dict[str, Any]: The state.
    """
    world = World()
    epoch, offset = self._resume(world, self.next_epoch - 1)
    sample_in_epoch = num_samples if from_beginning else offset + num_samples

    # First run ever: record the current number of physical nodes. Otherwise, persist the
    # value that was loaded and set from the resumption state.
    if self.initial_physical_nodes is None:
        initial_physical_nodes = world.num_nodes
    else:
        initial_physical_nodes = self.initial_physical_nodes

    return {
        'epoch': epoch,
        'sample_in_epoch': sample_in_epoch,
        'num_canonical_nodes': self.num_canonical_nodes,
        'shuffle_seed': self.shuffle_seed,
        'initial_physical_nodes': initial_physical_nodes,
    }
def load_state_dict(self, obj: Dict[str, Any]) -> None:
    """Load a dict containing training state (called from non-worker process).

    This is called on each copy of the dataset when resuming.

    We just save the state to shared memory for workers to pick up when __iter__ is next
    called. We use shm because changes to this copy of the dataset wouldn't be picked up by
    persistent workers.

    Args:
        obj (Dict[str, Any]): The state.
    """
    shm_name = _get_path(self._shm_prefix_int, RESUME)
    serialized = json.dumps(obj, sort_keys=True).encode('utf-8')
    # Some platforms choose to allocate chunks of memory based upon that platform's memory
    # page size, hence the exact size of the shared memory block that was returned may be
    # larger than what was requested.
    self._resume_shm = SharedMemory(name=shm_name, size=len(serialized))
    self._resume_shm.buf[:len(serialized)] = serialized
def resample_streams(
        self,
        epoch: int,
        stream_id: Optional[int] = None) -> Tuple[NDArray[np.int64], NDArray[np.int64]]:
    """Perform the up/down-sampling needed to generate the weighted epoch.

    Args:
        epoch (int): What epoch this is for. Used in seeding the sampling RNG.
        stream_id (Optional[int]): Which stream to resample. If ``None``, resample all streams.
            Defaults to ``None``.

    Returns:
        Tuple[NDArray[np.int64], NDArray[np.int64]]: Sampled shard sizes and sample mapping.
    """
    # Initialize random number generator and arrays. If sampling_method is "fixed", the rng
    # seed does not change, resulting in the same samples from each stream each epoch.
    rng = np.random.default_rng(self.shuffle_seed + epoch) \
        if self.sampling_method == 'balanced' \
        else np.random.default_rng(self.shuffle_seed)
    shuffle_units = []
    sample_ids = []
    resampling_streams = range(self.num_streams) if stream_id is None else [stream_id]
    # Iterate over each stream. Use a distinct loop variable so we don't shadow (and
    # clobber) the ``stream_id`` argument.
    for sid in resampling_streams:
        # Stream's shard offset in list of all shards from all streams.
        stream_shard_offset = self.shard_offset_per_stream[sid]
        num_stream_shards = self.shards_per_stream[sid]
        stream_shard_ids = stream_shard_offset + np.arange(num_stream_shards)
        # Calculate choose per stream shard.
        samples_per_stream_shard = self.samples_per_shard[stream_shard_ids]
        # The number of items to choose from each stream, obtained during initialization.
        stream_choose = self.streams[sid].choose
        use_epoch = self.sampling_method == 'balanced'
        choose_per_stream_shard = get_sampling(samples_per_stream_shard, stream_choose,
                                               self.sampling_granularity, self.shuffle_seed,
                                               epoch, use_epoch)
        # Iterate over each shard of this stream.
        for shard_id, shard_samples, shard_choose in zip(stream_shard_ids,
                                                         samples_per_stream_shard,
                                                         choose_per_stream_shard):
            # Compute the full/partial repeat split once; both the shuffle units and the
            # sample IDs below are derived from it.
            num_full_repeats = shard_choose // shard_samples
            remainder = shard_choose % shard_samples
            # Calculate shuffle units for this shard.
            # Shuffle units are lists where each entry is a number of samples to take
            # from the shard. If upsampling a shard with 4 samples by 2.5x,
            # shard_choose will be 10, and shard_shuffle_units will be [4, 4, 2]. If
            # downsampling that same shard by 0.5x, shard_choose will be 2 and
            # shard_shuffle_units will be just [2].
            shard_shuffle_units = [shard_samples] * num_full_repeats
            if remainder:
                shard_shuffle_units.append(remainder)
            shuffle_units.append(shard_shuffle_units)
            # Calculate sample IDs of any full repeats.
            shard_sample_offset = self.sample_offset_per_shard[shard_id]
            if num_full_repeats:
                full_repeat = shard_sample_offset + np.arange(shard_samples)
                sample_ids += [full_repeat] * num_full_repeats
            # Calculate sample IDs of a possible partial repeat. For fixed sampling, this
            # partial repeat chooses the same samples since we have fixed the rng seed.
            if remainder:
                partial_repeat = shard_sample_offset + rng.choice(
                    shard_samples, remainder, False)
                partial_repeat.sort()
                sample_ids.append(partial_repeat)
    shuffle_units = np.concatenate(shuffle_units).astype(np.int64)
    sample_ids = np.concatenate(sample_ids).astype(np.int64)
    return shuffle_units, sample_ids
def _share_work(self, sample_ids: NDArray[np.int64]) -> Tuple[SharedMemory, SharedMemory]:
    """Put an epoch's sample ordering into shared memory.

    Args:
        sample_ids (NDArray[np.int64]): Sample IDs.

    Returns:
        Tuple[SharedMemory, SharedMemory]: Shared memory arrays containing shape and data.
    """
    dims = 5
    # Validate shape.
    if sample_ids.ndim != dims:
        raise ValueError(f'Sample IDs must be of {dims}D shape (num physical nodes, ' +
                         f'ranks per node, workers per rank, batches per worker, ' +
                         f'batch size). Instead, found as {sample_ids.ndim}D shape.')
    # Publish the generated epoch's shape to shared memory.
    shape_bytes = np.array(sample_ids.shape, np.int64).tobytes()
    shape_shm = SharedMemory(name=_get_path(self._shm_prefix_int, EPOCH_SHAPE),
                             create=True,
                             size=len(shape_bytes),
                             auto_cleanup=False)
    shape_shm.buf[:len(shape_bytes)] = shape_bytes
    # Publish the generated epoch's data to shared memory.
    data_bytes = sample_ids.tobytes()
    data_shm = SharedMemory(name=_get_path(self._shm_prefix_int, EPOCH_DATA),
                            create=True,
                            size=len(data_bytes),
                            auto_cleanup=False)
    data_shm.buf[:len(data_bytes)] = data_bytes
    return shape_shm, data_shm
def _attach_work(self) -> Tuple[NDArray[np.int64], SharedMemory, SharedMemory]:
    """Get an epoch's sample ordering from shared memory.

    Returns:
        Tuple[NDArray[np.int64], SharedMemory, SharedMemory]: Sample IDs, plus the shape and
            data shared memory blocks (returned so the caller can clean them up afterward).
    """
    ndim = 5
    # Load the generated epoch shape from shared memory.
    name = _get_path(self._shm_prefix_int, EPOCH_SHAPE)
    size = ndim * np.int64().nbytes
    shape_shm = SharedMemory(name=name, create=False, size=size, auto_cleanup=False)
    # Use ``ndim`` here rather than repeating the magic number 5.
    shape = tuple(np.ndarray(ndim, buffer=shape_shm.buf, dtype=np.int64))
    # Attach to the generated epoch data in shared memory.
    name = _get_path(self._shm_prefix_int, EPOCH_DATA)
    size = int(np.prod(shape)) * np.int64().nbytes
    data_shm = SharedMemory(name=name, create=False, size=size, auto_cleanup=False)
    sample_ids = np.ndarray(shape, buffer=data_shm.buf, dtype=np.int64)
    return sample_ids, shape_shm, data_shm
def _get_work(self, world: World, epoch: int, sample_in_epoch: int) -> NDArray[np.int64]:
    """Get this worker's partition of this epoch's sample space.

    Args:
        world (World): World state.
        epoch (int): Which epoch it is.
        sample_in_epoch (int): Where we are in the epoch.

    Returns:
        NDArray[np.int64]: Our partition of the epoch.
    """
    # Lazily create the shared barrier's FileLock, which contains a threading Lock, which is
    # unpickleable.
    if not hasattr(self._shared_barrier, 'lock'):
        self._shared_barrier.lock = FileLock(self._shared_barrier.filelock_path)
    # Do expensive work that may use a lot of cores/memory just once, in the local leader.
    if world.is_local_leader:
        epoch_sample_ids = generate_work(self.batching_method, self, world, epoch,
                                         sample_in_epoch)
        shape_shm, data_shm = self._share_work(epoch_sample_ids)
        self._shared_barrier(world.workers_per_node)
    else:
        # Non-leaders block at the barrier until the leader has published the work, then
        # attach to it in shared memory.
        self._shared_barrier(world.workers_per_node)
        epoch_sample_ids, shape_shm, data_shm = self._attach_work()
    # Each worker gets their portion of the work.
    worker_sample_ids = epoch_sample_ids[world.node, world.rank_of_node,
                                         world.worker_of_rank].flatten()
    # Wait until every worker has copied out its slice before tearing down the shm blocks.
    self._shared_barrier(world.workers_per_node)
    # Now clean up after ourselves.
    shape_shm.cleanup()
    data_shm.cleanup()
    return worker_sample_ids
def _evict_shard(self, shard_id: int) -> None:
    """Evict the given shard.

    Assumes you hold ``_cache_filelock``, preventing anyone else from modifying the cache. We
    expect that shard deletions are very fast.

    This method is called internally by ``prepare_shard`` to clear space for more downloads.

    Args:
        shard_id (int): Shard to evict.
    """
    # Push this shard's last-access time far into the future so it can never be picked as
    # the coldest shard while it is absent.
    self._shard_access_times[shard_id] = NEVER
    # Mark the shard as needing to be fetched again.
    self._shard_states[shard_id] = _ShardState.REMOTE
    # Drop the shard's files, deducting the freed bytes from our cache accounting.
    freed = self.shards[shard_id].evict()
    self.cache_usage -= freed
    if self.cache_usage < 0:
        raise RuntimeError(f'Negative cache usage: {self.cache_usage}.')
def _evict_coldest_shard(self) -> None:
"""Evict the coldeset (i.e., least recently accessed) shard.
Assumes you hold ``__cache_filelock``, preventing anyone else from modifying the cache. We
expect that shard deletions are very fast.
This method is called internally by ``prepare_shard`` to clear space for more downloads.
"""
while True:
# Find the shard with the oldest last access time.
shard_id = int(self._shard_access_times.numpy().argmin())
# Check the shard's last access time. If it is NEVER, there are no downloaded shards to
# evict. If any shards are currently being downloaded, wait, else raise an error.
if self._shard_access_times[shard_id] == NEVER: