Skip to content

Minor optimization of initial reference frame selector function #1160

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 14 additions & 12 deletions suite2p/registration/register.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,17 +72,19 @@ def compute_crop(xoff: int, yoff: int, corrXY, th_badframes, badframes, maxregsh
return badframes, yrange, xrange


def pick_initial_reference(frames: np.ndarray):
def pick_initial_reference(frames: np.ndarray, k=20) -> np.ndarray:
""" computes the initial reference image

the seed frame is the frame with the largest correlations with other frames;
the average of the seed frame with its top 20 correlated pairs is the
inital reference frame returned
the average of the seed frame with its top k correlated pairs is the
initial reference frame returned

Parameters
----------
frames : 3D array, int16
size [frames x Ly x Lx], frames from binary
k : int, optional
number of top correlations to average, by default 20

Returns
-------
Expand All @@ -91,16 +93,16 @@ def pick_initial_reference(frames: np.ndarray):

"""
nimg, Ly, Lx = frames.shape
frames = np.reshape(frames, (nimg, -1)).astype("float32")
frames = frames - np.reshape(frames.mean(axis=1), (nimg, 1))
cc = np.matmul(frames, frames.T)
ndiag = np.sqrt(np.diag(cc))
cc = cc / np.outer(ndiag, ndiag)
CCsort = -np.sort(-cc, axis=1)
bestCC = np.mean(CCsort[:, 1:20], axis=1)
frames = np.reshape(frames, (nimg, -1)).astype("float32") # flatten frames
frames = frames - np.reshape(frames.mean(axis=1), (nimg, 1)) # subtract mean
cc = np.matmul(frames, frames.T) # correlation matrix (nimg x nimg)
ndiag = np.sqrt(np.diag(cc)) # norm of each frame
cc = cc / np.outer(ndiag, ndiag) # normalize by norm of each frame
CCpartsort = np.partition(cc, -(k+1), axis=1)[:, -k:-1] # skip the self-correlation
bestCC = np.mean(CCpartsort, axis=1) # mean of top k-1 correlations for each frame
imax = np.argmax(bestCC)
indsort = np.argsort(-cc[imax, :])
refImg = np.mean(frames[indsort[0:20], :], axis=0)
indpartsort = np.argpartition(cc[imax, :], -k)[-k:] # top k correlations for seed frame
refImg = np.mean(frames[indpartsort, :], axis=0)
refImg = np.reshape(refImg, (Ly, Lx))
return refImg

Expand Down