Skip to content

Commit e1392cb

Browse files
bug with tbin in rrr prediction
1 parent a3e9df1 commit e1392cb

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

facemap/neural_prediction/prediction_utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ def rrr_prediction(
238238
rank = min(min_dim, rank)
239239
corrf = np.zeros((rank, n_feats))
240240
varexpf = np.zeros((rank, n_feats))
241-
varexp = np.zeros((rank, 2)) if tbin != 0 else np.zeros((rank, 1))
241+
varexp = np.zeros((rank, 2)) if (tbin is not None and tbin > 1) else np.zeros((rank, 1))
242242
Y_pred_test = np.zeros((len(itest), n_feats))
243243
for r in range(rank):
244244
Y_pred_test = X[itest] @ B[:, : r + 1] @ A[:, : r + 1].T
@@ -268,8 +268,8 @@ def rrr_prediction(
268268
itest,
269269
A.cpu().numpy(),
270270
B.cpu().numpy(),
271-
varexpf,
272-
corrf,
271+
varexpf.squeeze(),
272+
corrf.squeeze(),
273273
)
274274

275275

0 commit comments

Comments
 (0)