|
| 1 | +import numpy as np |
| 2 | +from scipy.ndimage import find_objects |
| 3 | +from cellpose.models import Cellpose |
| 4 | +from cellpose import transforms, dynamics |
| 5 | +from cellpose.utils import fill_holes_and_remove_small_masks |
| 6 | +from mxnet import nd |
| 7 | +import time |
| 8 | +import cv2 |
| 9 | + |
| 10 | +from . import utils |
| 11 | +from .stats import roi_stats |
| 12 | + |
| 13 | +def mask_centers(masks): |
| 14 | + centers = np.zeros((masks.max(), 2), np.int32) |
| 15 | + diams = np.zeros(masks.max(), np.float32) |
| 16 | + slices = find_objects(masks) |
| 17 | + for i,si in enumerate(slices): |
| 18 | + if si is not None: |
| 19 | + sr,sc = si |
| 20 | + ymed, xmed, diam = utils.mask_stats(masks[sr, sc] == (i+1)) |
| 21 | + centers[i] = np.array([ymed, xmed]) |
| 22 | + diams[i] = diam |
| 23 | + return centers, diams |
| 24 | + |
| 25 | +def patch_detect(patches, diam): |
| 26 | + """ anatomical detection of masks from top active frames for putative cell """ |
| 27 | + print('refining masks using cellpose') |
| 28 | + npatches = len(patches) |
| 29 | + ly = patches[0].shape[0] |
| 30 | + model = Cellpose(net_avg=False) |
| 31 | + imgs = np.zeros((npatches, ly, ly, 2), np.float32) |
| 32 | + for i,m in enumerate(patches): |
| 33 | + imgs[i,:,:,0] = transforms.normalize99(m) |
| 34 | + rsz = 30. / diam |
| 35 | + imgs = transforms.resize_image(imgs, rsz=rsz).transpose(0,3,1,2) |
| 36 | + imgs, ysub, xsub = transforms.pad_image_ND(imgs) |
| 37 | + |
| 38 | + pmasks = np.zeros((npatches, ly, ly), np.uint16) |
| 39 | + batch_size = 8 * 224 // ly |
| 40 | + tic=time.time() |
| 41 | + for j in np.arange(0, npatches, batch_size): |
| 42 | + img = nd.array(imgs[j:j+batch_size]) |
| 43 | + y = model.cp.net(img)[0] |
| 44 | + y = y[:, :, ysub[0]:ysub[-1]+1, xsub[0]:xsub[-1]+1] |
| 45 | + y = y.asnumpy() |
| 46 | + for i,yi in enumerate(y): |
| 47 | + cellprob = yi[-1] |
| 48 | + dP = yi[:2] |
| 49 | + niter = 1 / rsz * 200 |
| 50 | + p = dynamics.follow_flows(-1 * dP * (cellprob>0) / 5., |
| 51 | + niter=niter) |
| 52 | + maski = dynamics.get_masks(p, iscell=(cellprob>0), |
| 53 | + flows=dP, threshold=1.0) |
| 54 | + maski = fill_holes_and_remove_small_masks(maski) |
| 55 | + maski = transforms.resize_image(maski, ly, ly, |
| 56 | + interpolation=cv2.INTER_NEAREST) |
| 57 | + pmasks[j+i] = maski |
| 58 | + if j%5==0: |
| 59 | + print('%d / %d masks created in %0.2fs'%(j+batch_size, npatches, time.time()-tic)) |
| 60 | + return pmasks |
| 61 | + |
| 62 | +def refine_masks(stats, patches, seeds, diam, Lyc, Lxc): |
| 63 | + nmasks = len(patches) |
| 64 | + patch_masks = patch_detect(patches, diam) |
| 65 | + ly = patches[0].shape[0] // 2 |
| 66 | + igood = np.zeros(nmasks, np.bool) |
| 67 | + for i, (patch_mask, stat, (yi,xi)) in enumerate(zip(patch_masks, stats, seeds)): |
| 68 | + mask = np.zeros((Lyc, Lxc), np.float32) |
| 69 | + ypix0, xpix0= stat['ypix'], stat['xpix'] |
| 70 | + mask[ypix0, xpix0] = stat['lam'] |
| 71 | + func_mask = utils.square_mask(mask, ly, yi, xi) |
| 72 | + ious = utils.mask_ious(patch_mask.astype(np.uint16), |
| 73 | + (func_mask>0).astype(np.uint16))[0] |
| 74 | + if len(ious)>0 and ious.max() > 0.45: |
| 75 | + mask_id = np.argmax(ious) + 1 |
| 76 | + patch_mask = patch_mask[max(0, ly-yi) : min(2*ly, Lyc+ly-yi), |
| 77 | + max(0, ly-xi) : min(2*ly, Lxc+ly-xi)] |
| 78 | + func_mask = func_mask[max(0, ly-yi) : min(2*ly, Lyc+ly-yi), |
| 79 | + max(0, ly-xi) : min(2*ly, Lxc+ly-xi)] |
| 80 | + ypix0, xpix0 = np.nonzero(patch_mask==mask_id) |
| 81 | + lam0 = func_mask[ypix0, xpix0] |
| 82 | + lam0[lam0<=0] = lam0.min() |
| 83 | + ypix0 = ypix0 + max(0, yi-ly) |
| 84 | + xpix0 = xpix0 + max(0, xi-ly) |
| 85 | + igood[i] = True |
| 86 | + stat['ypix'] = ypix0 |
| 87 | + stat['xpix'] = xpix0 |
| 88 | + stat['lam'] = lam0 |
| 89 | + stat['anatomical'] = True |
| 90 | + else: |
| 91 | + stat['anatomical'] = False |
| 92 | + return stats |
| 93 | + |
| 94 | +def roi_detect(mproj, diameter=None): |
| 95 | + model = Cellpose() |
| 96 | + masks = model.eval(mproj, net_avg=True, channels=[0,0], diameter=diameter, flow_threshold=1.5)[0] |
| 97 | + shape = masks.shape |
| 98 | + _, masks = np.unique(np.int32(masks), return_inverse=True) |
| 99 | + masks = masks.reshape(shape) |
| 100 | + centers, mask_diams = mask_centers(masks) |
| 101 | + median_diam = np.median(mask_diams) |
| 102 | + print('>>>> %d masks detected, median diameter = %0.2f ' % (masks.max(), median_diam)) |
| 103 | + return masks, centers, median_diam, mask_diams.astype(np.int32) |
| 104 | + |
| 105 | +def masks_to_stats(masks, weights): |
| 106 | + stats = [] |
| 107 | + slices = find_objects(masks) |
| 108 | + for i,si in enumerate(slices): |
| 109 | + sr,sc = si |
| 110 | + ypix0, xpix0 = np.nonzero(masks[sr, sc]==(i+1)) |
| 111 | + ypix0 = ypix0.astype(int) + sr.start |
| 112 | + xpix0 = xpix0.astype(int) + sc.start |
| 113 | + stats.append({ |
| 114 | + 'ypix': ypix0, |
| 115 | + 'xpix': xpix0, |
| 116 | + 'lam': weights[ypix0, xpix0], |
| 117 | + 'footprint': 1 |
| 118 | + }) |
| 119 | + return stats |
| 120 | + |
| 121 | +def select_rois(meanImg, weights, Ly, Lx, ymin, xmin): |
| 122 | + masks, centers, median_diam, mask_diams = roi_detect(meanImg) |
| 123 | + stats = masks_to_stats(masks, weights) |
| 124 | + for stat in stats: |
| 125 | + stat['ypix'] += int(ymin) |
| 126 | + stat['xpix'] += int(xmin) |
| 127 | + stats = roi_stats(stats, median_diam, median_diam, Ly, Lx) |
| 128 | + return stats |
| 129 | + |
| 130 | +# def run_assist(): |
| 131 | +# nmasks, diam = 0, None |
| 132 | +# if anatomical: |
| 133 | +# try: |
| 134 | +# print('>>>> CELLPOSE estimating spatial scale and masks as seeds for functional algorithm') |
| 135 | +# from . import anatomical |
| 136 | +# mproj = np.log(np.maximum(1e-3, max_proj / np.maximum(1e-3, mean_img))) |
| 137 | +# masks, centers, diam, mask_diams = anatomical.roi_detect(mproj) |
| 138 | +# nmasks = masks.max() |
| 139 | +# except: |
| 140 | +# print('ERROR importing or running cellpose, continuing without anatomical estimates') |
| 141 | +# if tj < nmasks: |
| 142 | +# yi, xi = centers[tj] |
| 143 | +# ls = mask_diams[tj] |
| 144 | +# imap = np.ravel_multi_index((yi, xi), (Lyc, Lxc)) |
| 145 | +# if nmasks > 0: |
| 146 | +# stats = anatomical.refine_masks(stats, patches, seeds, diam, Lyc, Lxc) |
| 147 | +# for stat in stats: |
| 148 | +# if stat['anatomical']: |
| 149 | +# stat['lam'] *= sdmov[stat['ypix'], stat['xpix']] |
| 150 | + |
| 151 | + |
| 152 | + |
| 153 | + |
0 commit comments