Skip to content

Commit 429dc5c

Browse files
adding anatomical segmentation code
1 parent b571bd2 commit 429dc5c

File tree

1 file changed

+153
-0
lines changed

1 file changed

+153
-0
lines changed

suite2p/detection/anatomical.py

Lines changed: 153 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,153 @@
1+
import numpy as np
2+
from scipy.ndimage import find_objects
3+
from cellpose.models import Cellpose
4+
from cellpose import transforms, dynamics
5+
from cellpose.utils import fill_holes_and_remove_small_masks
6+
from mxnet import nd
7+
import time
8+
import cv2
9+
10+
from . import utils
11+
from .stats import roi_stats
12+
13+
def mask_centers(masks):
14+
centers = np.zeros((masks.max(), 2), np.int32)
15+
diams = np.zeros(masks.max(), np.float32)
16+
slices = find_objects(masks)
17+
for i,si in enumerate(slices):
18+
if si is not None:
19+
sr,sc = si
20+
ymed, xmed, diam = utils.mask_stats(masks[sr, sc] == (i+1))
21+
centers[i] = np.array([ymed, xmed])
22+
diams[i] = diam
23+
return centers, diams
24+
25+
def patch_detect(patches, diam):
26+
""" anatomical detection of masks from top active frames for putative cell """
27+
print('refining masks using cellpose')
28+
npatches = len(patches)
29+
ly = patches[0].shape[0]
30+
model = Cellpose(net_avg=False)
31+
imgs = np.zeros((npatches, ly, ly, 2), np.float32)
32+
for i,m in enumerate(patches):
33+
imgs[i,:,:,0] = transforms.normalize99(m)
34+
rsz = 30. / diam
35+
imgs = transforms.resize_image(imgs, rsz=rsz).transpose(0,3,1,2)
36+
imgs, ysub, xsub = transforms.pad_image_ND(imgs)
37+
38+
pmasks = np.zeros((npatches, ly, ly), np.uint16)
39+
batch_size = 8 * 224 // ly
40+
tic=time.time()
41+
for j in np.arange(0, npatches, batch_size):
42+
img = nd.array(imgs[j:j+batch_size])
43+
y = model.cp.net(img)[0]
44+
y = y[:, :, ysub[0]:ysub[-1]+1, xsub[0]:xsub[-1]+1]
45+
y = y.asnumpy()
46+
for i,yi in enumerate(y):
47+
cellprob = yi[-1]
48+
dP = yi[:2]
49+
niter = 1 / rsz * 200
50+
p = dynamics.follow_flows(-1 * dP * (cellprob>0) / 5.,
51+
niter=niter)
52+
maski = dynamics.get_masks(p, iscell=(cellprob>0),
53+
flows=dP, threshold=1.0)
54+
maski = fill_holes_and_remove_small_masks(maski)
55+
maski = transforms.resize_image(maski, ly, ly,
56+
interpolation=cv2.INTER_NEAREST)
57+
pmasks[j+i] = maski
58+
if j%5==0:
59+
print('%d / %d masks created in %0.2fs'%(j+batch_size, npatches, time.time()-tic))
60+
return pmasks
61+
62+
def refine_masks(stats, patches, seeds, diam, Lyc, Lxc):
63+
nmasks = len(patches)
64+
patch_masks = patch_detect(patches, diam)
65+
ly = patches[0].shape[0] // 2
66+
igood = np.zeros(nmasks, np.bool)
67+
for i, (patch_mask, stat, (yi,xi)) in enumerate(zip(patch_masks, stats, seeds)):
68+
mask = np.zeros((Lyc, Lxc), np.float32)
69+
ypix0, xpix0= stat['ypix'], stat['xpix']
70+
mask[ypix0, xpix0] = stat['lam']
71+
func_mask = utils.square_mask(mask, ly, yi, xi)
72+
ious = utils.mask_ious(patch_mask.astype(np.uint16),
73+
(func_mask>0).astype(np.uint16))[0]
74+
if len(ious)>0 and ious.max() > 0.45:
75+
mask_id = np.argmax(ious) + 1
76+
patch_mask = patch_mask[max(0, ly-yi) : min(2*ly, Lyc+ly-yi),
77+
max(0, ly-xi) : min(2*ly, Lxc+ly-xi)]
78+
func_mask = func_mask[max(0, ly-yi) : min(2*ly, Lyc+ly-yi),
79+
max(0, ly-xi) : min(2*ly, Lxc+ly-xi)]
80+
ypix0, xpix0 = np.nonzero(patch_mask==mask_id)
81+
lam0 = func_mask[ypix0, xpix0]
82+
lam0[lam0<=0] = lam0.min()
83+
ypix0 = ypix0 + max(0, yi-ly)
84+
xpix0 = xpix0 + max(0, xi-ly)
85+
igood[i] = True
86+
stat['ypix'] = ypix0
87+
stat['xpix'] = xpix0
88+
stat['lam'] = lam0
89+
stat['anatomical'] = True
90+
else:
91+
stat['anatomical'] = False
92+
return stats
93+
94+
def roi_detect(mproj, diameter=None):
95+
model = Cellpose()
96+
masks = model.eval(mproj, net_avg=True, channels=[0,0], diameter=diameter, flow_threshold=1.5)[0]
97+
shape = masks.shape
98+
_, masks = np.unique(np.int32(masks), return_inverse=True)
99+
masks = masks.reshape(shape)
100+
centers, mask_diams = mask_centers(masks)
101+
median_diam = np.median(mask_diams)
102+
print('>>>> %d masks detected, median diameter = %0.2f ' % (masks.max(), median_diam))
103+
return masks, centers, median_diam, mask_diams.astype(np.int32)
104+
105+
def masks_to_stats(masks, weights):
106+
stats = []
107+
slices = find_objects(masks)
108+
for i,si in enumerate(slices):
109+
sr,sc = si
110+
ypix0, xpix0 = np.nonzero(masks[sr, sc]==(i+1))
111+
ypix0 = ypix0.astype(int) + sr.start
112+
xpix0 = xpix0.astype(int) + sc.start
113+
stats.append({
114+
'ypix': ypix0,
115+
'xpix': xpix0,
116+
'lam': weights[ypix0, xpix0],
117+
'footprint': 1
118+
})
119+
return stats
120+
121+
def select_rois(meanImg, weights, Ly, Lx, ymin, xmin):
122+
masks, centers, median_diam, mask_diams = roi_detect(meanImg)
123+
stats = masks_to_stats(masks, weights)
124+
for stat in stats:
125+
stat['ypix'] += int(ymin)
126+
stat['xpix'] += int(xmin)
127+
stats = roi_stats(stats, median_diam, median_diam, Ly, Lx)
128+
return stats
129+
130+
# def run_assist():
131+
# nmasks, diam = 0, None
132+
# if anatomical:
133+
# try:
134+
# print('>>>> CELLPOSE estimating spatial scale and masks as seeds for functional algorithm')
135+
# from . import anatomical
136+
# mproj = np.log(np.maximum(1e-3, max_proj / np.maximum(1e-3, mean_img)))
137+
# masks, centers, diam, mask_diams = anatomical.roi_detect(mproj)
138+
# nmasks = masks.max()
139+
# except:
140+
# print('ERROR importing or running cellpose, continuing without anatomical estimates')
141+
# if tj < nmasks:
142+
# yi, xi = centers[tj]
143+
# ls = mask_diams[tj]
144+
# imap = np.ravel_multi_index((yi, xi), (Lyc, Lxc))
145+
# if nmasks > 0:
146+
# stats = anatomical.refine_masks(stats, patches, seeds, diam, Lyc, Lxc)
147+
# for stat in stats:
148+
# if stat['anatomical']:
149+
# stat['lam'] *= sdmov[stat['ypix'], stat['xpix']]
150+
151+
152+
153+

0 commit comments

Comments
 (0)