forked from ClownsharkBatwing/RES4LYF
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path helper.py
More file actions
38 lines (23 loc) · 712 Bytes
/
helper.py
File metadata and controls
38 lines (23 loc) · 712 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
import re
import torch
def get_extra_options_kv(key, default, extra_options):
    """Return the raw string value assigned to ``key`` in ``extra_options``.

    ``extra_options`` is free-form text containing ``key=value`` pairs
    (whitespace around ``=`` is allowed). The value is the run of characters
    from ``[a-zA-Z0-9_.+-]`` immediately after the ``=``. It is returned as a
    string, without conversion.

    Args:
        key: Option name to look up. It is matched literally, so regex
            metacharacters such as ``.`` are escaped. It only matches as a
            whole name: ``eta`` does not match inside ``s_eta=...`` or
            ``beta=...``.
        default: Returned unchanged when ``key`` is not present.
        extra_options: Text to search. ``None`` or ``""`` yields ``default``.

    Returns:
        The value string of the first matching ``key=value`` pair, or
        ``default``.
    """
    if not extra_options:
        return default
    # (?<!\w) prevents the key from matching as the tail of a longer option
    # name (e.g. "eta" inside "s_eta"); re.escape makes the key literal.
    pattern = rf"(?<!\w){re.escape(str(key))}\s*=\s*([a-zA-Z0-9_.+-]+)"
    match = re.search(pattern, extra_options)
    return match.group(1) if match else default
def extra_options_flag(flag, extra_options):
    """Report whether ``flag`` occurs anywhere in ``extra_options``.

    ``flag`` is formatted to a string and used as a regular-expression
    pattern. It may match at any position, including inside a longer word.

    Returns:
        ``True`` if the pattern is found, otherwise ``False``.
    """
    # NOTE(review): because this is an unanchored regex search, a flag like
    # "eta" also matches "beta" -- confirm callers rely on that.
    pattern = f"{flag}"
    found = re.search(pattern, extra_options)
    return found is not None
def safe_get_nested(d, keys, default=None):
    """Walk nested dicts along ``keys`` and return the value found.

    Args:
        d: Root object; each step requires the current object to be a dict.
        keys: Iterable of keys applied in order. An empty iterable returns
            ``d`` itself.
        default: Returned when a key is missing or when a non-dict is reached
            before all keys are consumed.

    Returns:
        The value at the end of the key path, or ``default``.
    """
    # A private sentinel distinguishes "key absent" from a stored None.
    # Stopping on the first missing key keeps the walk from continuing into
    # ``default`` itself (which would return data from inside the default).
    missing = object()
    for key in keys:
        if not isinstance(d, dict):
            return default
        d = d.get(key, missing)
        if d is missing:
            return default
    return d
def initialize_or_scale(tensor, value, steps):
    """Scale an existing tensor by ``value``, or build a constant one.

    Args:
        tensor: Tensor to scale, or ``None``.
        value: Multiplier when ``tensor`` is given; otherwise the fill value.
        steps: Length of the 1-D tensor created when ``tensor`` is ``None``.

    Returns:
        ``value * tensor`` if ``tensor`` is not ``None``; otherwise a new
        tensor of shape ``(steps,)`` filled with ``value``.
    """
    if tensor is not None:
        return value * tensor
    return torch.full((steps,), value)