Skip to content

Commit cdd3ba3

Browse files
Fix tests
1 parent 1060842 commit cdd3ba3

File tree

4 files changed

+11
-5
lines changed

4 files changed

+11
-5
lines changed

development/apg_example.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ def example_script_wsi():
108108
# WITH MASK: 34 seconds
109109
generator = TiledAutomaticPromptGenerator(predictor, decoder)
110110
generator.initialize(
111-
data, image_embeddings=image_embeddings, tile_shape=tile_shape, halo=halo, verbose=True, batch_size=24
111+
data, image_embeddings=image_embeddings, tile_shape=tile_shape, halo=halo, verbose=True, batch_size=12
112112
)
113113

114114
# Processing time: 21:12 min
@@ -177,7 +177,7 @@ def debug_wsi():
177177
def main():
178178
# example_script()
179179
# example_script_tiled()
180-
example_script_wsi()
180+
# example_script_wsi()
181181
example_script_3d()
182182
# debug_wsi()
183183

micro_sam/multi_dimensional_segmentation.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,8 @@ def merge_instance_segmentation_3d(
349349

350350
# Extract the overlap between slices.
351351
edges = track_utils.compute_edges_from_overlap(slice_segmentation, verbose=False)
352+
if len(edges) == 0: # Nothing to merge.
353+
return slice_segmentation
352354

353355
uv_ids = np.array([[edge["source"], edge["target"]] for edge in edges])
354356
overlaps = np.array([edge["score"] for edge in edges])

micro_sam/util.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -841,6 +841,7 @@ def _compute_tiled_features_3d(predictor, input_, tile_shape, halo, f, pbar_init
841841
msg = "Compute Image Embeddings 3D tiled"
842842
if mask is None:
843843
n_tiles_total = n_slices * n_tiles_per_plane
844+
tiles_in_mask_per_slice = None
844845
else:
845846
tiles_in_mask_per_slice = {}
846847
for z in range(n_slices):
@@ -867,7 +868,7 @@ def _compute_tiled_features_3d(predictor, input_, tile_shape, halo, f, pbar_init
867868
pbar_update(len(tile_ids))
868869

869870
if mask is not None:
870-
features.attrs["tiles_in_mask"] = tiles_in_mask_per_slice
871+
features.attrs["tiles_in_mask"] = {str(z): per_slice for z, per_slice in tiles_in_mask_per_slice.items()}
871872

872873
_write_embedding_signature(f, input_, predictor, tile_shape, halo, input_size=None, original_size=None)
873874
return features

test/test_automatic_segmentation.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
import unittest
22

33
import numpy as np
4+
import torch
45
from skimage.draw import disk
56
from skimage.measure import label as connected_components
67

78
import micro_sam.util as util
89

10+
HAVE_CUDA = torch.cuda.is_available()
11+
912

1013
class TestAutomaticSegmentation(unittest.TestCase):
1114
model_type = "vit_t" if util.VIT_T_SUPPORT else "vit_b"
@@ -116,7 +119,7 @@ def test_tiled_instance_segmentation_with_decoder_2d(self):
116119
)
117120
self.assertEqual(mask.shape, instances.shape)
118121

119-
@unittest.skip("Skipping long running tests by default.")
122+
@unittest.skipUnless(HAVE_CUDA, "Skipping long running tests unless we have a GPU.")
120123
def test_automatic_mask_generator_3d(self):
121124
from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter
122125

@@ -129,7 +132,7 @@ def test_automatic_mask_generator_3d(self):
129132
)
130133
self.assertEqual(labels.shape, instances.shape)
131134

132-
@unittest.skip("Skipping long running tests by default.")
135+
@unittest.skipUnless(HAVE_CUDA, "Skipping long running tests unless we have a GPU.")
133136
def test_tiled_automatic_mask_generator_3d(self):
134137
from micro_sam.automatic_segmentation import automatic_instance_segmentation, get_predictor_and_segmenter
135138

0 commit comments

Comments
 (0)