
Commit 61a229f

Merge pull request #1134 from computational-cell-analytics/dev
Release 1.7
2 parents 32f07b0 + baedeb8 commit 61a229f

31 files changed · +4978 −578 lines changed

development/apg_example.py

Lines changed: 186 additions & 0 deletions
@@ -0,0 +1,186 @@
import os
import time

import h5py
import napari

from micro_sam.sample_data import sample_data_hela_2d
from micro_sam.instance_segmentation import (
    TiledAutomaticPromptGenerator, AutomaticPromptGenerator, get_predictor_and_decoder
)
from micro_sam.util import precompute_image_embeddings
from elf.wrapper.resized_volume import ResizedVolume


# TODO example with a custom prompt function
def example_script():
    im = sample_data_hela_2d()[0][0]

    predictor, decoder = get_predictor_and_decoder(model_type="vit_b_lm")
    image_embeddings = precompute_image_embeddings(predictor, im, save_path="x.zarr")
    generator = AutomaticPromptGenerator(predictor, decoder)
    generator.initialize(im, image_embeddings=image_embeddings)
    segmentation = generator.generate(intersection_over_min=True)

    v = napari.Viewer()
    v.add_image(im)
    v.add_labels(segmentation)
    napari.run()


def example_script_tiled():
    im = sample_data_hela_2d()[0][0]

    tile_shape, halo = (256, 256), (64, 64)
    predictor, decoder = get_predictor_and_decoder(model_type="vit_b_lm")
    image_embeddings = precompute_image_embeddings(predictor, im, tile_shape=tile_shape, halo=halo, save_path="y.zarr")
    generator = TiledAutomaticPromptGenerator(predictor, decoder)
    generator.initialize(im, image_embeddings=image_embeddings, tile_shape=tile_shape, halo=halo, verbose=True)
    segmentation = generator.generate(intersection_over_min=False)

    v = napari.Viewer()
    v.add_image(im)
    v.add_labels(segmentation)
    napari.run()


def _require_wsi_data():
    out_path = "./data/wsi.h5"
    if os.path.exists(out_path):
        return out_path

    from micro_sam.sample_data import fetch_wholeslide_histopathology_example_data
    from patho_sam.io.util import read_wsi

    example_data = fetch_wholeslide_histopathology_example_data("./data")
    data = read_wsi(example_data)
    shape = data.shape[:2]

    with h5py.File(out_path, "w") as f:
        f.create_dataset("data/s0", data=data, compression="gzip")
        for level in range(1, 5):
            ds_shape = tuple(sh // (2 ** level) for sh in shape)
            print(level, ds_shape)
            data = read_wsi(example_data, scale=ds_shape)
            f.create_dataset(f"data/s{level}", data=data, compression="gzip")

    os.remove(example_data)
    return out_path


def _require_mask(path, level=4, bg_threshold=240, window=15, majority_threshold=0.3):
    mask_key = f"mask/s{level}"
    with h5py.File(path, "a") as f:
        full_shape = f["data/s0"].shape[:2]
        if mask_key in f:
            mask = f[mask_key][:]
        else:
            from scipy.ndimage import uniform_filter
            image = f[f"data/s{level}"][:]
            mask = (image > bg_threshold).all(axis=-1)
            mask = uniform_filter(mask.astype("float"), size=window)
            mask = ~(mask >= majority_threshold)
            f.create_dataset(mask_key, data=mask, compression="gzip")

    resized_mask = ResizedVolume(mask, shape=full_shape, order=0)
    return resized_mask


def example_script_wsi():
    data_path = _require_wsi_data()
    mask = _require_mask(data_path)

    tile_shape, halo = (768, 768), (64, 64)
    predictor, decoder = get_predictor_and_decoder(model_type="vit_b_histopathology")

    with h5py.File(data_path, "r") as f:
        data = f["data/s0"][:]
    print("Run prediction for WSI of shape:", data.shape)

    # Processing time: 10:34 min (batch size 24 on an A100 with 80 GB)
    # WITH MASK: 3:33 min (+ some further optimizations)
    embed_path = "./data/embeds.zarr"
    image_embeddings = precompute_image_embeddings(
        predictor, data, tile_shape=tile_shape, halo=halo, save_path=embed_path, batch_size=24, ndim=2, mask=mask,
    )

    # Processing time: 03:14 min (batch size 24 on an A100 with 80 GB)
    # WITH MASK: 34 seconds
    generator = TiledAutomaticPromptGenerator(predictor, decoder)
    generator.initialize(
        data, image_embeddings=image_embeddings, tile_shape=tile_shape, halo=halo, verbose=True, batch_size=12
    )

    # Processing time: 21:12 min
    # Out of this 18:09 for the batched prediction, the rest for pre/post-processing.
    # WITH MASK: 19:59 min (total time).
    print("Start generate ...")
    t0 = time.time()
    seg = generator.generate(batch_size=32, optimize_memory=True)
    print("Generate took:", time.time() - t0, "s")
    print(seg.shape)

    # Save the segmentation to check the result.
    with h5py.File("./data/seg.h5", "w") as f:
        f.create_dataset("seg", data=seg, compression="gzip")


def example_script_3d():
    data_path = "./data/N_522_final_crop_ds2.h5"
    with h5py.File(data_path, "r") as f:
        data = f["raw"][:]
        mask = f["label"][:] > 0

    tile_shape, halo = (512, 512), (64, 64)
    predictor, decoder = get_predictor_and_decoder(model_type="vit_b_lm")

    embed_path = "./data/embeds_3d.zarr"
    image_embeddings = precompute_image_embeddings(
        predictor, data, tile_shape=tile_shape, halo=halo, save_path=embed_path, batch_size=12, ndim=3, mask=mask,
    )

    z = 50
    generator = TiledAutomaticPromptGenerator(predictor, decoder)
    generator.initialize(
        data[z], image_embeddings=image_embeddings, tile_shape=tile_shape,
        halo=halo, verbose=True, batch_size=12, mask=mask, i=z,
    )
    seg = generator.generate(batch_size=12, optimize_memory=True)

    with h5py.File(f"./data/seg_z{z}.h5", "w") as f:
        f.create_dataset("seg", data=seg, compression="gzip")


def debug_wsi():
    from micro_sam.inference import _stitch_segmentation
    from nifty.tools import blocking
    from tqdm import tqdm

    print("Load data for debugging ....")
    masks = []
    with h5py.File("./debug.h5", mode="r") as f:
        tile_ids = f["tile_ids"][:]
        g = f["masks"]
        for tile_id in tqdm(tile_ids, desc="Load masks"):
            masks.append(g[str(tile_id)][:])

        halo = f.attrs["halo"]
        shape = f.attrs["shape"]
        tile_shape = f.attrs["tile_shape"]

    tiling = blocking([0, 0], shape, tile_shape)
    print("Start stitching ...")
    seg = _stitch_segmentation(masks, tile_ids, tiling, halo, output_shape=shape)
    print(seg.shape)


def main():
    # example_script()
    # example_script_tiled()
    # example_script_wsi()
    example_script_3d()
    # debug_wsi()


if __name__ == "__main__":
    main()
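
For readers who want to try the new API on their own data, here is a minimal sketch distilled from example_script() above. It only reuses calls that appear in this file (get_predictor_and_decoder, precompute_image_embeddings, AutomaticPromptGenerator); the imageio import and the input path are placeholders, not part of this commit.

import imageio.v3 as imageio

from micro_sam.instance_segmentation import AutomaticPromptGenerator, get_predictor_and_decoder
from micro_sam.util import precompute_image_embeddings

# Load your own 2D image (placeholder path; any 2D array works).
image = imageio.imread("/path/to/image.tif")

# Same call sequence as example_script(): predictor + decoder, cached embeddings,
# then automatic prompt generation for instance segmentation.
predictor, decoder = get_predictor_and_decoder(model_type="vit_b_lm")
embeddings = precompute_image_embeddings(predictor, image, save_path="embeddings.zarr")

generator = AutomaticPromptGenerator(predictor, decoder)
generator.initialize(image, image_embeddings=embeddings)
segmentation = generator.generate(intersection_over_min=True)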

environment.yaml

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ dependencies:
   - pytorch >=2.5
   - segment-anything
   - torchvision
-  - torch_em >=0.7.10
+  - torch_em >=0.8
   - tqdm
   - timm
   - trackastra
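
The only dependency change here is the torch_em lower bound moving from 0.7.10 to 0.8. A small, hedged sanity check that an existing environment satisfies the new pin; it assumes the distribution is registered under the name torch_em and that the packaging library is available, neither of which is declared in this file.

from importlib.metadata import version
from packaging.version import Version

# "torch_em" as a distribution name is an assumption; adjust if your install differs.
installed = Version(version("torch_em"))
assert installed >= Version("0.8"), f"torch_em {installed} does not satisfy the >=0.8 pin"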

micro_sam/__version__.py

Lines changed: 1 addition & 1 deletion
@@ -1 +1 @@
-__version__ = "1.6.2"
+__version__ = "1.7.0"
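
A quick way to confirm an installed copy picked up the version bump; this imports the exact module changed above, so it only assumes the package is importable in your environment.

from micro_sam.__version__ import __version__

print(__version__)  # should print 1.7.0 after this release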
