|
5 | 5 | from typing import Any, Dict |
6 | 6 | from scipy.ndimage import find_objects, gaussian_filter |
7 | 7 | from cellpose.models import CellposeModel |
8 | | -from cellpose import transforms, dynamics |
| 8 | +from cellpose import transforms, dynamics, core |
9 | 9 | from cellpose.utils import fill_holes_and_remove_small_masks |
10 | 10 | from cellpose.transforms import normalize99 |
11 | 11 | import time |
@@ -34,7 +34,7 @@ def patch_detect(patches, diam): |
34 | 34 | print("refining masks using cellpose") |
35 | 35 | npatches = len(patches) |
36 | 36 | ly = patches[0].shape[0] |
37 | | - model = CellposeModel() |
| 37 | + model = CellposeModel(gpu=True if core.use_gpu() else False) |
38 | 38 | imgs = np.zeros((npatches, ly, ly, 2), np.float32) |
39 | 39 | for i, m in enumerate(patches): |
40 | 40 | imgs[i, :, :, 0] = transforms.normalize99(m) |
@@ -104,7 +104,7 @@ def roi_detect(mproj, diameter=None, cellprob_threshold=0.0, flow_threshold=0.4, |
104 | 104 | if diameter == 0: |
105 | 105 | diameter = None |
106 | 106 | pretrained_model = "cpsam" if pretrained_model is None else pretrained_model |
107 | | - model = CellposeModel(pretrained_model=pretrained_model) |
| 107 | + model = CellposeModel(pretrained_model=pretrained_model, gpu=True if core.use_gpu() else False) |
108 | 108 | masks = model.eval(mproj, diameter=diameter, |
109 | 109 | cellprob_threshold=cellprob_threshold, |
110 | 110 | flow_threshold=flow_threshold)[0] |
|
0 commit comments