forked from AI-secure/DecodingTrust
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutils.py
83 lines (66 loc) · 2.19 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import os
import errno
import numpy as np
import logging
import torch
import random
from torch.backends import cudnn
import functools
import signal
def ensure_dir(file_path):
    """Make sure the directory that would contain ``file_path`` exists.

    Only the directory part (``os.path.dirname(file_path)``) is created; the
    file itself is not touched. A message is printed when a directory is
    actually created. A bare filename has no directory part (it refers to the
    current working directory), so nothing is created in that case.

    :param file_path: path of a file whose parent directory should exist.
    """
    directory = os.path.dirname(file_path)
    # dirname("out.json") == "" and os.makedirs("") would raise, so bail out.
    if not directory or os.path.exists(directory):
        return
    # exist_ok tolerates another process creating the directory between the
    # existence check above and this call.
    os.makedirs(directory, exist_ok=True)
    print(f"Directory {directory} created.")
def timeout(sec):
    """
    timeout decorator

    Wraps a function so that a call raises ``TimeoutError`` if it runs longer
    than ``sec`` seconds. Implemented with ``SIGALRM`` / ``ITIMER_REAL``, so it
    only works on Unix and from the main thread (``signal.signal`` raises
    ``ValueError`` in other threads). When the call finishes, by return or by
    exception, the timer is cleared and the previously installed SIGALRM
    handler is restored. A ``sec`` of 0 means no limit. Not reentrant: there
    is one real-time timer per process, so an inner timed call clears any
    outer timer.

    :param sec: seconds (int or float) after which the wrapped call raises
        TimeoutError
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapped_func(*args, **kwargs):
            def _handle_timeout(signum, frame):
                err_msg = f'Function {func.__name__} timed out after {sec} seconds'
                raise TimeoutError(err_msg)

            previous_handler = signal.signal(signal.SIGALRM, _handle_timeout)
            try:
                # setitimer accepts fractional seconds, unlike signal.alarm.
                signal.setitimer(signal.ITIMER_REAL, sec)
                return func(*args, **kwargs)
            finally:
                signal.setitimer(signal.ITIMER_REAL, 0)
                # signal.signal returns None when the old handler was not
                # installed from Python; it cannot be re-installed then.
                if previous_handler is not None:
                    signal.signal(signal.SIGALRM, previous_handler)
        return wrapped_func
    return decorator
def setup_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs and force deterministic cuDNN.

    Seeds the torch CPU generator, the torch CUDA generators (all devices and
    the current one), NumPy's global RNG and Python's ``random`` module with the
    same value. It then selects deterministic cuDNN kernels and disables
    cuDNN benchmark autotuning.

    :param seed: integer seed applied to every generator.
    """
    torch.manual_seed(seed)
    for seed_cuda in (torch.cuda.manual_seed_all, torch.cuda.manual_seed):
        seed_cuda(seed)
    np.random.seed(seed)
    random.seed(seed)

    cudnn_backend = torch.backends.cudnn
    cudnn_backend.deterministic = True
    cudnn_backend.benchmark = False
def trim_to_shorter_length(texta, textb):
# truncate to shorter of o and s
shorter_length = min(len(texta.split(' ')), len(textb.split(' ')))
texta = ' '.join(texta.split(' ')[:shorter_length])
textb = ' '.join(textb.split(' ')[:shorter_length])
return texta, textb
def make_sure_path_exists(path):
    """Create directory ``path`` (including missing parents) if it is absent.

    An existing directory is accepted silently. If ``path`` exists but is not
    a directory, ``FileExistsError`` is raised instead of being ignored, so
    callers never go on to write into a directory that is not there.

    :param path: directory path to create.
    :raises OSError: if the directory cannot be created (e.g. permissions, or
        a non-directory already occupying ``path``).
    """
    os.makedirs(path, exist_ok=True)
def init_logger(root_dir, name="info"):
    """Configure the root logger to write bare messages to a file and the console.

    Ensures ``root_dir`` exists, then attaches two handlers to the *root*
    logger, both formatting records as the message text only:

    - a ``FileHandler`` writing ``<root_dir>/<name>.log`` in UTF-8, truncating
      any existing file;
    - a ``StreamHandler`` (stderr by default).

    The root level is set to INFO. Handlers installed by a previous
    ``init_logger`` call are removed and closed first. This stops repeated
    calls in one process from duplicating every log line and writing to stale
    files. Handlers attached by other code are left untouched.

    :param root_dir: directory that will hold the log file.
    :param name: log file stem; the file is ``<name>.log``.
    :return: the configured root logger.
    """
    make_sure_path_exists(root_dir)
    log_formatter = logging.Formatter("%(message)s")
    logger = logging.getLogger()

    # Detach handlers from an earlier call (they carry the tag set below).
    for old_handler in list(logger.handlers):
        if getattr(old_handler, "_from_init_logger", False):
            logger.removeHandler(old_handler)
            old_handler.close()

    file_handler = logging.FileHandler(
        os.path.join(root_dir, f"{name}.log"), mode='w', encoding="utf-8"
    )
    console_handler = logging.StreamHandler()
    for handler in (file_handler, console_handler):
        handler.setFormatter(log_formatter)
        handler._from_init_logger = True  # tag so a later call can remove it
        logger.addHandler(handler)

    logger.setLevel(logging.INFO)
    return logger