-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathutils.py
More file actions
91 lines (77 loc) · 2.96 KB
/
utils.py
File metadata and controls
91 lines (77 loc) · 2.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
"""This file contain common utility functions."""
import os
import random
from datetime import datetime
from typing import Any, Union
from zoneinfo import ZoneInfo

import numpy as np
import torch
import torch.backends.cudnn as cudnn
from pytz import timezone
from tqdm import tqdm
# tqdm.pandas()
from transformers import set_seed
def get_curr_time() -> str:
    """Return the current US/Pacific date and time as 'DD/MM/YYYY HH:MM:SS'."""
    # zoneinfo (stdlib, Python 3.9+) replaces the deprecated pytz workflow.
    # Passing the tz directly to now() builds an aware datetime in one step
    # instead of the naive-then-astimezone round trip.
    return datetime.now(ZoneInfo('US/Pacific')).strftime("%d/%m/%Y %H:%M:%S")
class Logger:
    """Write messages to the terminal and, optionally, to a log file.

    If ``filename`` is None, messages are only printed to stdout;
    otherwise they are also appended to ``output_dir/filename``.
    """

    def __init__(self, output_dir: str = None, filename: str = None) -> None:
        if filename is not None:
            # Guard against output_dir=None: os.path.join(None, f) raises
            # TypeError, while join('', f) degrades gracefully to f.
            self.log = os.path.join(output_dir or '', filename)

    def write(self, message: Any, show_time: bool = True) -> None:
        """Print *message*, optionally prefixed with the current time,
        and append it to the log file when one was configured."""
        message = str(message)
        if show_time:
            # If message starts with \n, keep the \n before the timestamp.
            if message.startswith('\n'):
                message = '\n' + get_curr_time() + ' >> ' + message[1:]
            else:
                message = get_curr_time() + ' >> ' + message
        print(message)
        if hasattr(self, 'log'):
            # Append mode so messages accumulate across write() calls.
            with open(self.log, 'a', encoding='utf-8') as f:
                f.write(message + '\n')
def set_all_seeds(seed: int) -> None:
    """Seed every random number generator used by the project."""
    # Python's and NumPy's global RNGs.
    random.seed(seed)
    np.random.seed(seed)
    # PyTorch CPU RNG, plus every GPU RNG when CUDA devices exist.
    torch.manual_seed(seed)
    if torch.cuda.device_count() > 0:
        torch.cuda.manual_seed_all(seed)
    # NOTE(review): benchmark=True lets cuDNN auto-tune and pick possibly
    # non-deterministic algorithms, which can undermine exact
    # reproducibility despite the seeding above — confirm this is intended.
    cudnn.benchmark = True
    # HuggingFace transformers' own seeding helper.
    set_seed(seed)
class CycleIndex:
    """Cycle endlessly through batches of training indices,
    reshuffling after each epoch.

    Bug fixed vs. the original: NumPy slices are *views*, and the original
    shuffled ``self.indices`` in place while such a view was still live —
    both when a batch exactly ended an epoch and in the wraparound branch —
    so the returned batch reflected the post-shuffle order (some samples
    skipped, others repeated within an epoch). Slices are now copied
    before any reshuffle.
    """

    def __init__(self, indices: Union[int, list], batch_size: int,
                 shuffle: bool = True) -> None:
        if isinstance(indices, int):
            indices = np.arange(indices)
        # Own copy: never shuffle the caller's list/array in place.
        self.indices = np.array(indices)
        self.num_samples = len(self.indices)
        self.batch_size = batch_size
        self.pointer = 0  # position of the next batch's first sample
        if shuffle:
            np.random.shuffle(self.indices)
        self.shuffle = shuffle

    def get_batch_ind(self):
        """Return the indices for the next batch as a 1-D ndarray."""
        start, end = self.pointer, self.pointer + self.batch_size
        # If we have a full batch within this epoch, then get it.
        if end <= self.num_samples:
            # Copy BEFORE any reshuffle so the batch keeps this epoch's order.
            batch = self.indices[start:end].copy()
            if end == self.num_samples:
                self.pointer = 0
                if self.shuffle:
                    np.random.shuffle(self.indices)
            else:
                self.pointer = end
            return batch
        # Otherwise, fill the batch with samples from the next epoch.
        last_batch_indices_incomplete = self.indices[start:].copy()
        remaining = self.batch_size - (self.num_samples - start)
        self.pointer = remaining
        if self.shuffle:
            np.random.shuffle(self.indices)
        return np.concatenate((last_batch_indices_incomplete,
                               self.indices[:remaining]))