Skip to content

Commit 48192bc

Browse files
allowing for different signs for motMask and movMask PCs
1 parent cd689f4 commit 48192bc

File tree

1 file changed

+23
-24
lines changed

1 file changed

+23
-24
lines changed

tests/test_svd_output.py

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
from facemap import process
1010

11-
r_tol, a_tol = 1, 1 # 1e-2, 1
11+
r_tol, a_tol = 1e-2, 1e-1 # 1e-2, 1
1212

1313

1414
def test_output_single_video(data_dir, video_names, expected_output_dir):
@@ -115,32 +115,31 @@ def check_frames(test_output, expected_output):
115115

116116

117117
def check_U(test_output, expected_output):
118-
motionMask = np.allclose(
119-
test_output["motMask"][0],
120-
expected_output["motMask"][0],
121-
rtol=r_tol + 5,
122-
atol=a_tol + 5,
123-
)
124-
movieMask = np.allclose(
125-
test_output["movMask"][0], expected_output["movMask"][0], rtol=r_tol, atol=a_tol
126-
)
127-
motionMask_reshape = np.allclose(
128-
test_output["motMask_reshape"][0],
129-
expected_output["motMask_reshape"][0],
130-
rtol=r_tol + 5,
131-
atol=a_tol + 5,
132-
)
133-
movMask_reshape = np.allclose(
134-
test_output["movMask_reshape"][0],
135-
expected_output["movMask_reshape"][0],
136-
rtol=r_tol,
137-
atol=a_tol,
138-
)
118+
nPCs = test_output["motSVD"][0].shape[1]
119+
motionMask_pos = [np.allclose(test_output["motMask"][0][:,i],
120+
expected_output["motMask"][0][:,i],
121+
rtol=r_tol, atol=a_tol) for i in range(nPCs)]
122+
motionMask_neg = [np.allclose(test_output["motMask"][0][:,i],
123+
-1 * expected_output["motMask"][0][:,i],
124+
rtol=r_tol, atol=a_tol) for i in range(nPCs)]
125+
motionMask = np.array(motionMask_pos) | np.array(motionMask_neg)
126+
motionMask = np.all(motionMask)
127+
128+
movieMask_pos = [np.allclose(test_output["movMask"][0][:,i],
129+
expected_output["movMask"][0][:,i],
130+
rtol=r_tol, atol=a_tol) for i in range(nPCs)]
131+
movieMask_neg = [np.allclose(test_output["movMask"][0][:,i],
132+
-1 * expected_output["movMask"][0][:,i],
133+
rtol=r_tol, atol=a_tol) for i in range(nPCs)]
134+
movieMask = np.array(movieMask_pos) | np.array(movieMask_neg)
135+
movieMask = np.all(movieMask)
136+
motionMask_reshape = test_output["motMask_reshape"][0].shape == expected_output["motMask_reshape"][0].shape
137+
movieMask_reshape = test_output["movMask_reshape"][0].shape == expected_output["movMask_reshape"][0].shape
139138
print("motionMask", motionMask)
140139
print("movieMask", movieMask)
141140
print("motionMask_reshape", motionMask_reshape)
142-
print("movMask_reshape", movMask_reshape)
143-
return motionMask and movieMask and motionMask_reshape and movMask_reshape
141+
print("movMask_reshape", movieMask_reshape)
142+
return motionMask and movieMask and motionMask_reshape and movieMask_reshape
144143

145144

146145
def check_V(test_output, expected_output):

0 commit comments

Comments
 (0)