Skip to content

Commit 5e4f680

Browse files
authored
Update otter.py (#359)
* Update otter.py
1 parent 97fff75 commit 5e4f680

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

netZooPy/otter/otter.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,9 @@ def otter(W, P, C, lam=0.035, gamma=0.335, Iter=60, eta=0.00001, bexp=1):
4848

4949
C = C / np.trace(C)
5050
W = W / np.sqrt(np.trace(W @ W.T))
51-
P = P / np.trace(P)
51+
diagP = np.trace(P)
52+
if diagP > 0:
53+
P = P / np.trace(P)
5254

5355
P = P * (-(1 - lam)) + gamma * np.identity(t)
5456
C = C * (-lam)
@@ -138,4 +140,4 @@ def otter_gpu(W, P, C, lam=0.035, gamma=0.335, Iter=60, eta=0.00001, bexp=1):
138140

139141
W1 = cp.asnumpy(W)
140142
del W,C,P
141-
return W1
143+
return W1

0 commit comments

Comments
 (0)