Skip to content

Commit 6b4e64c

Browse files
committed
Use GPU If available for Cellpose
1 parent ee83979 commit 6b4e64c

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

suite2p/detection/anatomical.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from typing import Any, Dict
66
from scipy.ndimage import find_objects, gaussian_filter
77
from cellpose.models import CellposeModel
8-
from cellpose import transforms, dynamics
8+
from cellpose import transforms, dynamics, core
99
from cellpose.utils import fill_holes_and_remove_small_masks
1010
from cellpose.transforms import normalize99
1111
import time
@@ -34,7 +34,7 @@ def patch_detect(patches, diam):
3434
print("refining masks using cellpose")
3535
npatches = len(patches)
3636
ly = patches[0].shape[0]
37-
model = CellposeModel()
37+
model = CellposeModel(gpu=True if core.use_gpu() else False)
3838
imgs = np.zeros((npatches, ly, ly, 2), np.float32)
3939
for i, m in enumerate(patches):
4040
imgs[i, :, :, 0] = transforms.normalize99(m)
@@ -104,7 +104,7 @@ def roi_detect(mproj, diameter=None, cellprob_threshold=0.0, flow_threshold=0.4,
104104
if diameter == 0:
105105
diameter = None
106106
pretrained_model = "cpsam" if pretrained_model is None else pretrained_model
107-
model = CellposeModel(pretrained_model=pretrained_model)
107+
model = CellposeModel(pretrained_model=pretrained_model, gpu=True if core.use_gpu() else False)
108108
masks = model.eval(mproj, diameter=diameter,
109109
cellprob_threshold=cellprob_threshold,
110110
flow_threshold=flow_threshold)[0]

0 commit comments

Comments
 (0)