-
Notifications
You must be signed in to change notification settings - Fork 384
/
Copy pathutils.py
257 lines (208 loc) · 10.1 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
from contextlib import contextmanager
import hashlib
import math
from pathlib import Path
import shutil
import urllib
import warnings
import torch
from torch import optim, Tensor
from torchvision.transforms import functional as TF
from typing import Union
def from_pil_image(x):
    """Converts from a PIL image to a tensor with values in [-1, 1].

    ``TF.to_tensor`` produces a channels-first float tensor scaled to [0, 1],
    which is then mapped affinely onto [-1, 1].

    Args:
        x: A PIL image (anything ``TF.to_tensor`` accepts).

    Returns:
        A channels-first tensor with values in [-1, 1].
    """
    x = TF.to_tensor(x)
    if x.ndim == 2:
        # Insert the missing channel axis at the FRONT (1 x H x W). Appending
        # it (H x W x 1) would break the channels-first layout.
        x = x[None]
    return x * 2 - 1
def to_pil_image(x):
    """Converts a tensor with values in [-1, 1] into a PIL image.

    A 4-D input must hold exactly one item along its leading axis (checked
    with an assert), and that axis is dropped. Afterwards, a leading axis of
    size 1 is squeezed away (presumably a single channel; confirm the
    caller's layout). Values are clamped to [-1, 1] and rescaled to [0, 1]
    before ``TF.to_pil_image`` is applied.
    """
    if x.ndim == 4:
        assert x.shape[0] == 1
        x = x[0]
    if x.shape[0] == 1:
        # The condition guarantees size 1, so squeeze(0) is equivalent to x[0].
        x = x.squeeze(0)
    unit_range = x.clamp(-1, 1).add(1).div(2)
    return TF.to_pil_image(unit_range)
def hf_datasets_augs_helper(examples, transform, image_key, mode='RGB'):
    """Apply an image transform to one HuggingFace Datasets batch.

    Each image in ``examples[image_key]`` is converted with
    ``image.convert(mode)`` and then passed through ``transform``, in order.

    Returns:
        A new dict mapping ``image_key`` to the list of transformed images.
    """
    transformed = []
    for img in examples[image_key]:
        converted = img.convert(mode)
        transformed.append(transform(converted))
    return {image_key: transformed}
def append_dims(x, target_dims):
    """Add trailing singleton axes to ``x`` until it has ``target_dims`` dims.

    The result is produced by ``None``-indexing, so it is a view of ``x``.

    Raises:
        ValueError: If ``x`` already has more than ``target_dims`` dimensions.
    """
    missing = target_dims - x.ndim
    if missing >= 0:
        return x[(...,) + (None,) * missing]
    raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
def n_params(module, trainable_only=False):
    """Returns the number of parameter elements in a module.

    By default every parameter is counted, whether or not it requires
    gradients. Pass ``trainable_only=True`` to count only parameters with
    ``requires_grad`` set.

    Args:
        module: A ``torch.nn.Module``.
        trainable_only: If True, skip parameters that do not require grad.

    Returns:
        The total ``numel()`` of the selected parameters, as an int.
    """
    return sum(p.numel() for p in module.parameters()
               if p.requires_grad or not trainable_only)
def _sha256_file(path, chunk_size=1 << 20):
    """Return the hex SHA-256 digest of the file at ``path``, read in chunks."""
    sha = hashlib.sha256()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            sha.update(chunk)
    return sha.hexdigest()


def download_file(path, url, digest=None):
    """Downloads a file if it does not exist, optionally checking its SHA-256 hash.

    The download is streamed into a sibling ``<name>.part`` file. That file is
    moved onto ``path`` only once it is complete and, when ``digest`` is given,
    its hash matches. An interrupted or corrupt download is therefore removed
    instead of being left where a later call would treat it as finished.

    Args:
        path: Destination file path (str or os.PathLike). Missing parent
            directories are created.
        url: URL passed to ``urllib.request.urlopen``.
        digest: Optional expected SHA-256 hex digest. It is checked for both
            fresh downloads and files that already exist.

    Returns:
        The destination as a ``pathlib.Path``.

    Raises:
        OSError: If the file's SHA-256 does not match ``digest``. An existing
            file is left in place. A fresh download is discarded.
    """
    # A bare ``import urllib`` does not guarantee ``urllib.request`` is loaded.
    import urllib.request

    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    if path.exists():
        if digest is not None and _sha256_file(path) != digest:
            raise OSError(f'hash of {path} (url: {url}) failed to validate')
        return path
    partial = path.with_name(path.name + '.part')
    try:
        with urllib.request.urlopen(url) as response, open(partial, 'wb') as f:
            shutil.copyfileobj(response, f)
        if digest is not None and _sha256_file(partial) != digest:
            raise OSError(f'hash of {path} (url: {url}) failed to validate')
        partial.replace(path)
    except BaseException:
        # Never leave a truncated or unverified download behind, even on Ctrl-C.
        partial.unlink(missing_ok=True)
        raise
    return path
@contextmanager
def train_mode(model, mode=True):
    """Context manager that sets ``model.train(mode)`` for the duration of a block.

    The ``training`` flag of every module in ``model.modules()`` is recorded
    on entry and written back on exit, even if the block raises. Mixed
    per-submodule modes are therefore restored exactly. The value yielded is
    whatever ``model.train(mode)`` returns (the model itself for
    ``torch.nn.Module``).
    """
    saved_flags = [submodule.training for submodule in model.modules()]
    try:
        switched = model.train(mode)
        yield switched
    finally:
        for index, submodule in enumerate(model.modules()):
            submodule.training = saved_flags[index]
def eval_mode(model):
    """Context manager that puts ``model`` and its submodules into eval mode.

    This is shorthand for ``train_mode(model, False)``, so each submodule's
    previous ``training`` flag is restored when the block exits.
    """
    return train_mode(model, mode=False)
@torch.no_grad()
def ema_update(model, averaged_model, decay):
    """Fold ``model``'s current parameters into the EMA copy ``averaged_model``.

    Call it after each optimizer step. Every parameter is updated in place as
    ``avg = decay * avg + (1 - decay) * live``. Buffers are not averaged; they
    are copied over verbatim. Both models must expose identical parameter
    names and identical buffer names (checked with asserts).
    """
    live_params = dict(model.named_parameters())
    ema_params = dict(averaged_model.named_parameters())
    assert live_params.keys() == ema_params.keys()
    weight = 1 - decay
    for key in live_params:
        ema_params[key].mul_(decay).add_(live_params[key], alpha=weight)

    live_buffers = dict(model.named_buffers())
    ema_buffers = dict(averaged_model.named_buffers())
    assert live_buffers.keys() == ema_buffers.keys()
    for key in live_buffers:
        ema_buffers[key].copy_(live_buffers[key])
class EMAWarmup:
    """Implements an EMA warmup using an inverse decay schedule.

    If inv_gamma=1 and power=1, implements a simple average. inv_gamma=1, power=2/3 are
    good values for models you plan to train for a million or more steps (reaches decay
    factor 0.999 at 31.6K steps, 0.9999 at 1M steps), inv_gamma=1, power=3/4 for models
    you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999 at
    215.4k steps).

    Before ``start_at`` the decay rate is 0, so no averaging takes place. From
    ``start_at`` onwards it is
    ``1 - (1 + (last_epoch - start_at) / inv_gamma) ** -power``, clamped to
    ``[min_value, max_value]``.

    Args:
        inv_gamma (float): Inverse multiplicative factor of EMA warmup. Default: 1.
        power (float): Exponential factor of EMA warmup. Default: 1.
        min_value (float): The minimum EMA decay rate. Default: 0.
        max_value (float): The maximum EMA decay rate. Default: 1.
        start_at (int): The epoch to start averaging at. Default: 0.
        last_epoch (int): The index of last epoch. Default: 0.
    """

    def __init__(self, inv_gamma=1., power=1., min_value=0., max_value=1., start_at=0,
                 last_epoch=0):
        self.inv_gamma = inv_gamma
        self.power = power
        self.min_value = min_value
        self.max_value = max_value
        self.start_at = start_at
        self.last_epoch = last_epoch

    def state_dict(self):
        """Returns the state of the class as a :class:`dict`."""
        return dict(self.__dict__.items())

    def load_state_dict(self, state_dict):
        """Loads the class's state.

        Args:
            state_dict (dict): EMA warmup state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def get_value(self):
        """Gets the current EMA decay rate.

        Returns 0 before ``start_at`` regardless of ``min_value``, so averaging
        really does start at ``start_at``.
        """
        epoch = self.last_epoch - self.start_at
        if epoch < 0:
            return 0.
        value = 1 - (1 + epoch / self.inv_gamma) ** -self.power
        return min(self.max_value, max(self.min_value, value))

    def step(self):
        """Updates the step count."""
        self.last_epoch += 1
class InverseLR(optim.lr_scheduler._LRScheduler):
    """Implements an inverse decay learning rate schedule with an optional exponential
    warmup. When last_epoch=-1, sets initial lr as lr.

    inv_gamma is the number of steps/epochs required for the learning rate to decay to
    (1 / 2)**power of its original value.

    The learning rate at step ``t`` is
    ``(1 - warmup ** (t + 1)) * max(min_lr, base_lr * (1 + t / inv_gamma) ** -power)``.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        inv_gamma (float): Inverse multiplicative factor of learning rate decay. Default: 1.
        power (float): Exponential factor of learning rate decay. Default: 1.
        warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
            Default: 0.
        min_lr (float): The minimum learning rate. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Raises:
        ValueError: If ``warmup`` is outside [0, 1).
    """

    def __init__(self, optimizer, inv_gamma=1., power=1., warmup=0., min_lr=0.,
                 last_epoch=-1, verbose=False):
        # These attributes must exist before the base constructor runs, because
        # the base constructor takes an initial step that calls get_lr().
        self.inv_gamma = inv_gamma
        self.power = power
        if not 0. <= warmup < 1:
            raise ValueError('Invalid value for warmup')
        self.warmup = warmup
        self.min_lr = min_lr
        # Recent PyTorch releases deprecate the scheduler ``verbose`` argument
        # and warn whenever it is passed, even as False. Forward it only when
        # the caller actually asked for verbose output.
        if verbose:
            super().__init__(optimizer, last_epoch, verbose)
        else:
            super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.")
        return self._get_closed_form_lr()

    def _get_closed_form_lr(self):
        # Exponential warmup factor: 1 - warmup**(t+1), which tends to 1 as t grows.
        warmup = 1 - self.warmup ** (self.last_epoch + 1)
        lr_mult = (1 + self.last_epoch / self.inv_gamma) ** -self.power
        # min_lr bounds the decayed rate; the warmup factor is applied on top of it.
        return [warmup * max(self.min_lr, base_lr * lr_mult)
                for base_lr in self.base_lrs]
class ExponentialLR(optim.lr_scheduler._LRScheduler):
    """Implements an exponential learning rate schedule with an optional exponential
    warmup. When last_epoch=-1, sets initial lr as lr. Decays the learning rate
    continuously by decay (default 0.5) every num_steps steps.

    The learning rate at step ``t`` is
    ``(1 - warmup ** (t + 1)) * max(min_lr, base_lr * (decay ** (1 / num_steps)) ** t)``.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        num_steps (float): The number of steps to decay the learning rate by decay in.
        decay (float): The factor by which to decay the learning rate every num_steps
            steps. Default: 0.5.
        warmup (float): Exponential warmup factor (0 <= warmup < 1, 0 to disable)
            Default: 0.
        min_lr (float): The minimum learning rate. Default: 0.
        last_epoch (int): The index of last epoch. Default: -1.
        verbose (bool): If ``True``, prints a message to stdout for
            each update. Default: ``False``.

    Raises:
        ValueError: If ``warmup`` is outside [0, 1).
    """

    def __init__(self, optimizer, num_steps, decay=0.5, warmup=0., min_lr=0.,
                 last_epoch=-1, verbose=False):
        # These attributes must exist before the base constructor runs, because
        # the base constructor takes an initial step that calls get_lr().
        self.num_steps = num_steps
        self.decay = decay
        if not 0. <= warmup < 1:
            raise ValueError('Invalid value for warmup')
        self.warmup = warmup
        self.min_lr = min_lr
        # Recent PyTorch releases deprecate the scheduler ``verbose`` argument
        # and warn whenever it is passed, even as False. Forward it only when
        # the caller actually asked for verbose output.
        if verbose:
            super().__init__(optimizer, last_epoch, verbose)
        else:
            super().__init__(optimizer, last_epoch)

    def get_lr(self):
        if not self._get_lr_called_within_step:
            warnings.warn("To get the last learning rate computed by the scheduler, "
                          "please use `get_last_lr()`.")
        return self._get_closed_form_lr()

    def _get_closed_form_lr(self):
        # Exponential warmup factor: 1 - warmup**(t+1), which tends to 1 as t grows.
        warmup = 1 - self.warmup ** (self.last_epoch + 1)
        # Per-step factor decay**(1/num_steps), so the rate shrinks by `decay`
        # over every num_steps steps.
        lr_mult = (self.decay ** (1 / self.num_steps)) ** self.last_epoch
        return [warmup * max(self.min_lr, base_lr * lr_mult)
                for base_lr in self.base_lrs]
def rand_log_normal(shape, loc=0., scale=1., device='cpu', dtype=torch.float32):
    """Draw log-normal samples, i.e. ``exp(loc + scale * z)`` with ``z`` ~ N(0, 1).

    ``z`` is drawn with ``torch.randn`` using the given ``shape``, ``device``
    and ``dtype``.
    """
    z = torch.randn(shape, device=device, dtype=dtype)
    return z.mul(scale).add(loc).exp()
def rand_log_logistic(shape, loc=0., scale=1., min_value=0., max_value=float('inf'), device='cpu', dtype=torch.float32):
    """Draws samples from an optionally truncated log-logistic distribution.

    This uses inverse-CDF sampling in float64. The bounds are mapped through
    the logistic CDF of their logs, and a uniform draw between the two CDF
    values is mapped back through ``logit``. The result is then
    exponentiated and cast to ``dtype``.
    """
    def _log_cdf(bound):
        # Logistic CDF evaluated at log(bound), computed in float64.
        bound = torch.as_tensor(bound, device=device, dtype=torch.float64)
        return torch.sigmoid((bound.log() - loc) / scale)

    lo_cdf = _log_cdf(min_value)
    hi_cdf = _log_cdf(max_value)
    u = torch.rand(shape, device=device, dtype=torch.float64) * (hi_cdf - lo_cdf) + lo_cdf
    log_samples = torch.logit(u) * scale + loc
    return log_samples.exp().to(dtype)
def rand_log_uniform(shape, min_value, max_value, device='cpu', dtype=torch.float32):
    """Draw samples whose logarithm is uniform between log(min_value) and log(max_value).

    The bounds are converted with ``math.log``, so they must be positive
    Python numbers.
    """
    log_lo, log_hi = math.log(min_value), math.log(max_value)
    u = torch.rand(shape, device=device, dtype=dtype)
    return u.mul(log_hi - log_lo).add(log_lo).exp()
def quantize(quanta: Tensor, candidate: Union[int, float, Tensor]) -> Tensor:
    """Rounds `candidate` to the nearest element in `quanta`.

    Args:
        quanta: Tensor of allowed values, searched along dim 0 (expected to
            be 1-D).
        candidate: A Python number, or a tensor of any shape whose elements
            are each rounded independently.

    Returns:
        For a scalar candidate, the nearest entry of ``quanta`` (0-d for a 1-D
        ``quanta``). For a tensor candidate, a tensor of ``candidate``'s shape
        holding each element's nearest entry of ``quanta``.
    """
    candidate_ndim = candidate.ndim if isinstance(candidate, Tensor) else 0
    # Give quanta one trailing singleton axis per candidate axis. The
    # subtraction then builds a (len(quanta), *candidate.shape) distance table
    # instead of pairing quanta and candidate elements positionally.
    distances = (quanta[(...,) + (None,) * candidate_ndim] - candidate).abs()
    return quanta[torch.argmin(distances, dim=0)]