forked from junyanz/interactive-deep-colorization
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutil.py
82 lines (67 loc) · 2.31 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import numpy as np
import time
import datetime
def check_value(inds, val):
    """Return True if ``inds`` holds exactly one element and it equals ``val``.

    Good for pre-processing function inputs, e.g. detecting a sentinel value
    that may arrive as a bare scalar, a one-element sequence, or a
    one-element array.

    INPUTS
        inds    scalar, sequence, or np.ndarray to inspect
        val     value the single element is compared against
    OUTPUTS
        bool    True when ``inds`` has size 1 and its sole element == ``val``;
                False otherwise (including zero or several elements)
    """
    arr = np.asarray(inds)
    if arr.size != 1:
        return False
    # Compare the extracted element rather than ``inds`` itself, so a
    # one-element list (``[5] == 5`` is False in Python) behaves the same as
    # a one-element array or a bare scalar.
    return bool(arr.reshape(-1)[0] == val)
def flatten_nd_array(pts_nd, axis=1):
    """Flatten an N-d array into a 2-d array, keeping ``axis`` as the columns.

    All axes other than ``axis`` are moved to the front (in their original
    order) and collapsed into the row dimension; ``axis`` becomes the last
    (column) dimension.

    INPUTS
        pts_nd   N0 x N1 x ... x Nd array
        axis     integer; negative values count from the last axis
    OUTPUTS
        pts_flt  (product of N_i for i != axis) x N_axis array
    RAISES
        ValueError if ``axis`` is out of range for ``pts_nd``
    """
    ndim = pts_nd.ndim
    ax = int(axis)
    # Accept negative axes the way numpy does; otherwise e.g. -1 is never
    # removed from the non-axis set and the transpose below gets too many axes.
    if ax < 0:
        ax += ndim
    if not 0 <= ax < ndim:
        raise ValueError('axis %d is out of bounds for array of dimension %d'
                         % (axis, ndim))
    shp = np.array(pts_nd.shape)
    nax = np.setdiff1d(np.arange(ndim), [ax])  # non-axis indices, ascending
    npts = int(np.prod(shp[nax]))
    axorder = np.append(nax, ax)  # non-axis dims first, chosen axis last
    return pts_nd.transpose(axorder).reshape(npts, shp[ax])
def unflatten_2d_array(pts_flt, pts_nd, axis=1, squeeze=False):
    """Reshape a 2-d array back to the N-d layout of ``pts_nd``.

    Rows of ``pts_flt`` are taken to enumerate the non-``axis`` dimensions of
    ``pts_nd`` in C order (the layout obtained by transposing ``axis`` to the
    end and reshaping); the columns are put back at position ``axis``.

    INPUTS
        pts_flt  (product of N_i for i != axis) x M array
        pts_nd   N0 x N1 x ... x Nd array (only its shape/ndim are used)
        axis     integer; negative values count from the last axis
        squeeze  bool; if True, M must be 1 and that axis is dropped
    OUTPUTS
        pts_out  N0 x ... x M x ... x Nd array (M at position ``axis``),
                 or the same without that axis when ``squeeze`` is True
    RAISES
        ValueError if ``axis`` is out of range for ``pts_nd``
    """
    ndim = pts_nd.ndim
    ax = int(axis)
    # Accept negative axes the way numpy does; otherwise e.g. -1 is never
    # removed from the non-axis set and the axis permutation is wrong.
    if ax < 0:
        ax += ndim
    if not 0 <= ax < ndim:
        raise ValueError('axis %d is out of bounds for array of dimension %d'
                         % (axis, ndim))
    shp = np.array(pts_nd.shape)
    nax = np.setdiff1d(np.arange(ndim), [ax])  # non-axis indices, ascending
    new_shp = shp[nax].tolist()
    if squeeze:
        # The single column is dropped, so only the non-axis dims remain.
        axorder = nax
    else:
        axorder = np.append(nax, ax)
        new_shp.append(pts_flt.shape[1])
    # argsort of the forward permutation gives the inverse permutation.
    axorder_rev = np.argsort(axorder)
    return pts_flt.reshape(new_shp).transpose(axorder_rev)
def na():
    """Return ``np.newaxis`` (which is ``None``), for use in indexing
    expressions that insert a length-1 dimension, e.g. ``x[na(), :]``."""
    new_axis = np.newaxis
    return new_axis
class Timer():
    """Simple stopwatch based on ``time.time()``.

    The clock starts when the object is constructed and is restarted by
    ``tic()``.
    """

    def __init__(self):
        # Timestamp (seconds, from time.time()) of the last start/restart.
        self.cur_t = time.time()

    def tic(self):
        """Restart the stopwatch."""
        self.cur_t = time.time()

    def toc(self):
        """Return seconds elapsed since construction or the last ``tic()``."""
        return time.time() - self.cur_t

    def tocStr(self, t=-1):
        """Format a duration (rounded to ms, shown to centiseconds) as text.

        INPUTS
            t    duration in seconds; the default sentinel -1 means "use the
                 time elapsed since the last tic()" (so a genuine duration of
                 exactly -1 s cannot be formatted)
        OUTPUTS
            str  ``str(datetime.timedelta)`` trimmed to two decimals,
                 e.g. '0:00:01.23', '0:00:02.00', '1 day, 1:00:00.00'
        """
        if t == -1:
            t = self.toc()
        # float() because np.round on an int yields np.int64, which
        # datetime.timedelta does not accept.
        delta = datetime.timedelta(seconds=float(np.round(t, 3)))
        text = str(delta)
        # str(timedelta) omits the fractional part when microseconds are zero
        # ('0:00:02' rather than '0:00:02.000000'); pad it so that trimming the
        # last four characters always removes only sub-centisecond digits.
        if delta.microseconds == 0:
            text += '.000000'
        return text[:-4]