1
+ import numpy as np
2
+
3
+
4
+ def pnorm (p ):
5
+ if not isinstance (p , (list , tuple )):
6
+ raise ValueError (f'probability map { p } must be of type (list, tuple), not { type (p )} ' )
7
+ ptot = np .sum (p )
8
+ if not np .allclose (ptot , 1 ):
9
+ p = [i / ptot for i in p ]
10
+ return p
11
+
12
+
13
+ def multinomial (num_samples , p ):
14
+ valid_p = pnorm (p )
15
+ res = np .random .multinomial (num_samples , valid_p )
16
+ return res
17
+
18
+
19
+ class Sampler (object ):
20
+ r"""Base class for all Samplers.
21
+ Every Sampler subclass has to provide an :meth:`__iter__` method, providing a
22
+ way to iterate over indices of dataset elements, and a :meth:`__len__` method
23
+ that returns the length of the returned iterators.
24
+ .. note:: The :meth:`__len__` method isn't strictly required by
25
+ :class:`~torch.utils.data.DataLoader`, but is expected in any
26
+ calculation involving the length of a :class:`~torch.utils.data.DataLoader`.
27
+ """
28
+
29
+ def __init__ (self , data_source ):
30
+ pass
31
+
32
+ def __iter__ (self ):
33
+ raise NotImplementedError
34
+
35
+ # NOTE [ Lack of Default `__len__` in Python Abstract Base Classes ]
36
+ #
37
+ # Many times we have an abstract class representing a collection/iterable of
38
+ # data, e.g., `torch.utils.data.Sampler`, with its subclasses optionally
39
+ # implementing a `__len__` method. In such cases, we must make sure to not
40
+ # provide a default implementation, because both straightforward default
41
+ # implementations have their issues:
42
+ #
43
+ # + `return NotImplemented`:
44
+ # Calling `len(subclass_instance)` raises:
45
+ # TypeError: 'NotImplementedType' object cannot be interpreted as an integer
46
+ #
47
+ # + `raise NotImplementedError()`:
48
+ # This prevents triggering some fallback behavior. E.g., the built-in
49
+ # `list(X)` tries to call `len(X)` first, and executes a different code
50
+ # path if the method is not found or `NotImplemented` is returned, while
51
+ # raising an `NotImplementedError` will propagate and and make the call
52
+ # fail where it could have use `__iter__` to complete the call.
53
+ #
54
+ # Thus, the only two sensible things to do are
55
+ #
56
+ # + **not** provide a default `__len__`.
57
+ #
58
+ # + raise a `TypeError` instead, which is what Python uses when users call
59
+ # a method that is not defined on an object.
60
+ # (@ssnl verifies that this works on at least Python 3.7.)
61
+
62
+
63
+ class SequentialSampler (Sampler ):
64
+ r"""Samples elements sequentially, always in the same order.
65
+ Arguments:
66
+ data_source (Dataset): dataset to sample from
67
+ """
68
+
69
+ def __init__ (self , data_source ):
70
+ self .data_source = data_source
71
+
72
+ def __iter__ (self ):
73
+ return iter (self .data_source .keys ())
74
+
75
+ def __len__ (self ):
76
+ return len (self .data_source )
77
+
78
+
79
+ class RandomSampler (Sampler ):
80
+ r"""Samples elements randomly. If without replacement, then sample from a shuffled dataset.
81
+ If with replacement, then user can specify :attr:`num_samples` to draw.
82
+ Arguments:
83
+ data_source (Dataset): dataset to sample from
84
+ replacement (bool): samples are drawn with replacement if ``True``, default=``False``
85
+ num_samples (int): number of samples to draw, default=`len(dataset)`. This argument
86
+ is supposed to be specified only when `replacement` is ``True``.
87
+ """
88
+
89
+ def __init__ (self , data_source , replacement = False , num_samples = None ):
90
+ self .data_source = data_source
91
+ self .replacement = replacement
92
+ self ._num_samples = num_samples
93
+
94
+ if not isinstance (self .replacement , bool ):
95
+ raise ValueError ("replacement should be a boolean value, but got "
96
+ "replacement={}" .format (self .replacement ))
97
+
98
+ if self ._num_samples is not None and not replacement :
99
+ raise ValueError ("With replacement=False, num_samples should not be specified, "
100
+ "since a random permute will be performed." )
101
+
102
+ if not isinstance (self .num_samples , int ) or self .num_samples <= 0 :
103
+ raise ValueError ("num_samples should be a positive integer "
104
+ "value, but got num_samples={}" .format (self .num_samples ))
105
+
106
+ @property
107
+ def num_samples (self ):
108
+ # dataset size might change at runtime
109
+ if self ._num_samples is None :
110
+ return len (self .data_source )
111
+ return self ._num_samples
112
+
113
+ def __iter__ (self ):
114
+ n = len (self .data_source )
115
+ keys = list (self .data_source .keys ())
116
+ if self .replacement :
117
+ choose = np .random .randint (low = 0 , high = n , size = (self .num_samples ,), dtype = np .int64 ).tolist ()
118
+ return (keys [x ] for x in choose )
119
+ choose = np .random .permutation (self .num_samples )
120
+ return (keys [x ] for x in choose )
121
+
122
+ def __len__ (self ):
123
+ return self .num_samples
124
+
125
+
126
+ class SubsetRandomSampler (Sampler ):
127
+ r"""Samples elements randomly from a given list of indices, without replacement.
128
+ Arguments:
129
+ indices (sequence): a sequence of indices
130
+ """
131
+
132
+ def __init__ (self , indices ):
133
+ self .indices = indices
134
+
135
+ def __iter__ (self ):
136
+ choose = np .random .permutation (len (self .indices ))
137
+ return (self .indices [x ] for x in choose )
138
+
139
+ def __len__ (self ):
140
+ return len (self .indices )
141
+
142
+
143
+ class WeightedRandomSampler (Sampler ):
144
+ r"""Samples elements from ``[0,..,len(weights)-1]`` with given probabilities (weights).
145
+ Args:
146
+ weights (sequence) : a sequence of weights, not necessary summing up to one
147
+ num_samples (int): number of samples to draw
148
+ replacement (bool): if ``True``, samples are drawn with replacement.
149
+ If not, they are drawn without replacement, which means that when a
150
+ sample index is drawn for a row, it cannot be drawn again for that row.
151
+ Example:
152
+ >>> list(WeightedRandomSampler([0.1, 0.9, 0.4, 0.7, 3.0, 0.6], 5, replacement=True))
153
+ [0, 0, 0, 1, 0]
154
+ >>> list(WeightedRandomSampler([0.9, 0.4, 0.05, 0.2, 0.3, 0.1], 5, replacement=False))
155
+ [0, 1, 4, 3, 2]
156
+ """
157
+
158
+ def __init__ (self , weights , num_samples ):
159
+ if not isinstance (num_samples , int ) or isinstance (num_samples , bool ) or \
160
+ num_samples <= 0 :
161
+ raise ValueError ("num_samples should be a positive integer "
162
+ "value, but got num_samples={}" .format (num_samples ))
163
+ self .weights = tuple (weights )
164
+ self .num_samples = num_samples
165
+
166
+ def __iter__ (self ):
167
+ return iter (multinomial (self .num_samples , self .weights ))
168
+
169
+ def __len__ (self ):
170
+ return self .num_samples
171
+
172
+
173
+ class BatchSampler (Sampler ):
174
+ r"""Wraps another sampler to yield a mini-batch of indices.
175
+ Args:
176
+ sampler (Sampler): Base sampler.
177
+ batch_size (int): Size of mini-batch.
178
+ drop_last (bool): If ``True``, the sampler will drop the last batch if
179
+ its size would be less than ``batch_size``
180
+ Example:
181
+ >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=False))
182
+ [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
183
+ >>> list(BatchSampler(SequentialSampler(range(10)), batch_size=3, drop_last=True))
184
+ [[0, 1, 2], [3, 4, 5], [6, 7, 8]]
185
+ """
186
+
187
+ def __init__ (self , sampler , batch_size , drop_last ):
188
+ if not isinstance (sampler , Sampler ):
189
+ raise ValueError ("sampler should be an instance of "
190
+ "torch.utils.data.Sampler, but got sampler={}"
191
+ .format (sampler ))
192
+ if not isinstance (batch_size , int ) or isinstance (batch_size , bool ) or \
193
+ batch_size <= 0 :
194
+ raise ValueError ("batch_size should be a positive integer value, "
195
+ "but got batch_size={}" .format (batch_size ))
196
+ if not isinstance (drop_last , bool ):
197
+ raise ValueError ("drop_last should be a boolean value, but got "
198
+ "drop_last={}" .format (drop_last ))
199
+ self .sampler = sampler
200
+ self .batch_size = batch_size
201
+ self .drop_last = drop_last
202
+
203
+ def __iter__ (self ):
204
+ batch = []
205
+ for idx in self .sampler :
206
+ batch .append (idx )
207
+ if len (batch ) == self .batch_size :
208
+ yield batch
209
+ batch = []
210
+ if len (batch ) > 0 and not self .drop_last :
211
+ yield batch
212
+
213
+ def __len__ (self ):
214
+ if self .drop_last :
215
+ return len (self .sampler ) // self .batch_size
216
+ else :
217
+ return (len (self .sampler ) + self .batch_size - 1 ) // self .batch_size
0 commit comments