Skip to content

Commit c4a3ecb

Browse files
authored
Merge pull request #1184 from rishant3441/cellpose4-support
Adds Cellpose-SAM Support. Verified that it fixes both #1180 and #1178.
2 parents ddb8efd + 39f4ec0 commit c4a3ecb

File tree

5 files changed

+44
-16
lines changed

5 files changed

+44
-16
lines changed

README.md

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ Pachitariu, M., Stringer, C., Schröder, S., Dipoppa, M., Rossi, L. F., Carandin
5050

5151
## Local installation
5252

53+
If you are using a GPU, make sure its drivers and the cuda libraries are correctly installed.
54+
5355
1. Install an [Anaconda](https://www.anaconda.com/download/) distribution of Python -- Choose **Python 3.8** and your operating system. Note you might need to use an anaconda prompt if you did not add anaconda to the path.
5456
2. Open an anaconda prompt / command prompt with `conda` for **python 3** in the path
5557
3. Create a new environment with `conda create --name suite2p python=3.9`.
@@ -76,8 +78,28 @@ This package relies on the awesomeness of [pyqtgraph](http://pyqtgraph.org/), [P
7678

7779
The software has been heavily tested on Windows 10 and Ubuntu 18.04, and less well tested on Mac OS. Please post an [issue](https://github.com/MouseLand/suite2p/issues) if you have installation problems.
7880

81+
### GPU version (CUDA) on Windows or Linux
82+
83+
If you plan on running Cellpose-SAM (anatomical ROI detection), you may want to install a GPU version of *torch*. To use your NVIDIA GPU with python, you will need to make sure the NVIDIA driver for your GPU is installed, check out this [website](https://www.nvidia.com/Download/index.aspx?lang=en-us) to download it. You can also install the CUDA toolkit, or use the pytorch cudatoolkit (installed below with conda). If you have trouble with the below install, we recommend installing the CUDA toolkit yourself, choosing one of the 11.x releases [here](https://developer.nvidia.com/cuda-toolkit-archive).
84+
85+
With the latest versions of pytorch on Linux, as long as the NVIDIA drivers are installed, the GPU version is installed by default with pip. You can check if the GPU support is working by opening the GUI. If the GPU is working then the `GPU` box will be checked and the `CUDA` version will be displayed in the command line.
86+
87+
If it's not working, we will need to remove the CPU version of torch:
88+
~~~
89+
pip uninstall torch
90+
~~~
91+
92+
To install the GPU version of torch, follow the instructions [here](https://pytorch.org/get-started/locally/). The pip or conda installs should work across platforms, you will need torch and torchvision, e.g. for windows + cuda 12.6 the command is
93+
~~~
94+
pip3 install torch torchvision --index-url https://download.pytorch.org/whl/cu126
95+
~~~
96+
97+
Info on how to install several older versions is available [here](https://pytorch.org/get-started/previous-versions/). After install you can check `conda list` for `pytorch`, and its version info should have `cuXX.X`, not `cpu`.
98+
7999
### Installing the latest github version of the code
80100

101+
If you are using a GPU, make sure its drivers and the cuda libraries are correctly installed.
102+
81103
The simplest way is
82104
~~~
83105
pip install git+https://github.com/MouseLand/suite2p.git
@@ -90,6 +112,8 @@ If you want to download and edit the code, and use that version,
90112

91113
### Installation for developers
92114

115+
If you are using a GPU, make sure its drivers and the cuda libraries are correctly installed.
116+
93117
1. Clone the repository and `cd suite2p` in an anaconda prompt / command prompt with `conda` for **python 3** in the path
94118
2. Run `conda create --name suite2p python=3.9`
95119
3. To activate this new environment, run `conda activate suite2p` (you will have to activate every time you want to run suite2p)

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
"matplotlib",
1111
"scipy>=1.9.0",
1212
"scikit-learn",
13-
"cellpose<=3.1.1.2",
13+
"cellpose>=4.0.1",
1414
"scanimage-tiff-reader>=1.4.1"
1515
]
1616

suite2p/default_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -130,10 +130,10 @@ def default_ops():
130130
0, # run cellpose to get masks on 1: max_proj / mean_img; 2: mean_img; 3: mean_img enhanced, 4: max_proj
131131
"diameter": 0, # use diameter for cellpose, if 0 estimate diameter
132132
"cellprob_threshold": 0.0, # cellprob_threshold for cellpose
133-
"flow_threshold": 1.5, # flow_threshold for cellpose
133+
"flow_threshold": 0.4, # flow_threshold for cellpose
134134
"spatial_hp_cp": 0, # high-pass image spatially by a multiple of the diameter
135135
"pretrained_model":
136-
"cyto", # path to pretrained model or model type string in Cellpose (can be user model)
136+
"cpsam", # path to pretrained model or model type string in Cellpose (can be user model)
137137

138138
# classification parameters
139139
"soma_crop":

suite2p/detection/anatomical.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
import numpy as np
55
from typing import Any, Dict
66
from scipy.ndimage import find_objects, gaussian_filter
7-
from cellpose.models import CellposeModel, Cellpose
8-
from cellpose import transforms, dynamics
7+
from cellpose.models import CellposeModel
8+
from cellpose import transforms, dynamics, core
99
from cellpose.utils import fill_holes_and_remove_small_masks
1010
from cellpose.transforms import normalize99
1111
import time
@@ -34,7 +34,7 @@ def patch_detect(patches, diam):
3434
print("refining masks using cellpose")
3535
npatches = len(patches)
3636
ly = patches[0].shape[0]
37-
model = Cellpose()
37+
model = CellposeModel(gpu=True if core.use_gpu() else False)
3838
imgs = np.zeros((npatches, ly, ly, 2), np.float32)
3939
for i, m in enumerate(patches):
4040
imgs[i, :, :, 0] = transforms.normalize99(m)
@@ -46,7 +46,7 @@ def patch_detect(patches, diam):
4646
batch_size = 8 * 224 // ly
4747
tic = time.time()
4848
for j in np.arange(0, npatches, batch_size):
49-
y = model.cp.network(imgs[j:j + batch_size])[0]
49+
y = model.net(imgs[j:j + batch_size])[0]
5050
y = y[:, :, ysub[0]:ysub[-1] + 1, xsub[0]:xsub[-1] + 1]
5151
y = y.asnumpy()
5252
for i, yi in enumerate(y):
@@ -99,14 +99,13 @@ def refine_masks(stats, patches, seeds, diam, Lyc, Lxc):
9999
return stats
100100

101101

102-
def roi_detect(mproj, diameter=None, cellprob_threshold=0.0, flow_threshold=1.5,
102+
def roi_detect(mproj, diameter=None, cellprob_threshold=0.0, flow_threshold=0.4,
103103
pretrained_model=None):
104-
pretrained_model = "cyto3" if pretrained_model is None else pretrained_model
105-
if not os.path.exists(pretrained_model):
106-
model = Cellpose(model_type=pretrained_model)
107-
else:
108-
model = CellposeModel(pretrained_model=pretrained_model)
109-
masks = model.eval(mproj, channels=[0, 0], diameter=diameter,
104+
if diameter == 0:
105+
diameter = None
106+
pretrained_model = "cpsam" if pretrained_model is None else pretrained_model
107+
model = CellposeModel(pretrained_model=pretrained_model, gpu=True if core.use_gpu() else False)
108+
masks = model.eval(mproj, diameter=diameter,
110109
cellprob_threshold=cellprob_threshold,
111110
flow_threshold=flow_threshold)[0]
112111
shape = masks.shape

suite2p/detection/chan2detect.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from scipy.ndimage import gaussian_filter
66
from ..extraction import masks
77
from . import utils
8+
import traceback
89
"""
910
identify cells with channel 2 brightness (aka red cells)
1011
@@ -90,7 +91,10 @@ def cellpose_overlap(stats, mimg2):
9091
ypix0, xpix0 = stats[i]["ypix"], stats[i]["xpix"]
9192
smask[ypix0, xpix0] = 1
9293
ious = utils.mask_ious(masks, smask)[0]
93-
iou = ious.max()
94+
if ious.size > 0:
95+
iou = ious.max()
96+
else:
97+
iou = 0.0
9498
redstats[
9599
i,
96100
] = np.array([iou > 0.25, iou]) #this had the wrong dimension
@@ -112,10 +116,11 @@ def detect(ops, stats):
112116
try:
113117
print(">>>> CELLPOSE estimating masks in anatomical channel")
114118
redstats, masks = cellpose_overlap(stats, mimg2)
115-
except:
119+
except Exception as e:
116120
print(
117121
"ERROR importing or running cellpose, continuing without anatomical estimates"
118122
)
123+
traceback.print_exc()
119124

120125
if redstats is None:
121126
redstats = intensity_ratio(ops, stats)

0 commit comments

Comments
 (0)