fineTuneSam.py
import os
import glob
import warnings

warnings.filterwarnings("ignore")

import imageio.v3 as imageio
import torch
from matplotlib import pyplot as plt
from PIL import Image

import micro_sam.training as sam_training
from micro_sam.automatic_segmentation import get_predictor_and_segmenter, automatic_instance_segmentation
from torch_em.data import MinInstanceSampler
from torch_em.util.debug import check_loader
from torch_em.util.util import get_random_colors
# Update the paths for your image and segmentation mask directories.
image_dir = "test/imgs"
segmentation_dir = "test/masks"

# Collect all supported image formats (tif, png, jpg).
# Sort both lists so that images and masks line up pairwise.
image_paths = sorted(
    glob.glob(os.path.join(image_dir, "*.tif"))
    + glob.glob(os.path.join(image_dir, "*.png"))
    + glob.glob(os.path.join(image_dir, "*.jpg"))
)
mask_paths = sorted(
    glob.glob(os.path.join(segmentation_dir, "*_masks.tif"))
    + glob.glob(os.path.join(segmentation_dir, "*_masks.png"))
    + glob.glob(os.path.join(segmentation_dir, "*_masks.jpg"))
)
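
# Sanity check (an addition to the original script): the loaders below pair
# images and masks by list position, so the two sorted lists must have the
# same length and matching order.
assert len(image_paths) == len(mask_paths), (
    f"Found {len(image_paths)} images but {len(mask_paths)} masks."
)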
# Convert all RGB images to grayscale in place to ensure the expected input shape.
for image_path in image_paths:
    img = Image.open(image_path)
    if img.mode != "L":  # convert to grayscale if not already
        img = img.convert("L")
        img.save(image_path)

# Note: converting masks to 8-bit grayscale assumes the instance IDs fit into
# 0-255; RGB-encoded label masks would be corrupted by this conversion.
for mask_path in mask_paths:
    mask = Image.open(mask_path)
    if mask.mode != "L":
        mask = mask.convert("L")
        mask.save(mask_path)
# Training parameters.
batch_size = 1  # adjust the batch size if needed
patch_shape = (512, 512)  # the patch shape, assuming 2D images

# Enable training for instance segmentation.
train_instance_segmentation = True

# The sampler ensures at least one foreground instance per sampled input.
sampler = MinInstanceSampler(min_size=25)
# Create the data loader for training.
train_loader = sam_training.default_sam_loader(
    raw_paths=image_paths,  # explicitly pass the list of image paths
    raw_key=None,  # None because the paths are given directly
    label_paths=mask_paths,  # explicitly pass the list of mask paths
    label_key=None,  # None because the paths are given directly
    with_segmentation_decoder=train_instance_segmentation,
    patch_shape=patch_shape,
    batch_size=batch_size,
    is_seg_dataset=True,
    shuffle=True,
    raw_transform=sam_training.identity,
    sampler=sampler,
)
# Create the data loader for validation.
# NOTE: this reuses the full training data; see the split sketch below for
# building a proper held-out validation set.
val_loader = sam_training.default_sam_loader(
    raw_paths=image_paths,
    raw_key=None,
    label_paths=mask_paths,
    label_key=None,
    with_segmentation_decoder=train_instance_segmentation,
    patch_shape=patch_shape,
    batch_size=batch_size,
    is_seg_dataset=True,
    shuffle=True,
    raw_transform=sam_training.identity,
    sampler=sampler,
)
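
# Since train and validation currently see the exact same data, the
# validation loss cannot detect overfitting. A minimal split sketch (an
# assumption, not part of the original script) would reserve e.g. 20% of
# the pairs for validation before building the loaders:
#
#     n_val = max(1, len(image_paths) // 5)
#     train_images, val_images = image_paths[n_val:], image_paths[:n_val]
#     train_masks, val_masks = mask_paths[n_val:], mask_paths[:n_val]
#
# and pass the train_* lists to train_loader and the val_* lists to val_loader.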
# check_loader(train_loader, 4, plt=True)  # uncomment to visually inspect a few training samples
n_objects_per_batch = 5  # the number of objects per batch that will be sampled
device = "cuda" if torch.cuda.is_available() else "cpu"  # the device used for training
print(device)
n_epochs = 5  # how long we train (in epochs)
model_type = "vit_b"  # the Segment Anything model variant to fine-tune
checkpoint_name = "sam_hela"  # the name under which the checkpoint is saved
root_dir = ""
sam_training.train_sam(
    name=checkpoint_name,
    save_root=os.path.join(root_dir, "models"),
    model_type=model_type,
    train_loader=train_loader,
    val_loader=val_loader,
    n_epochs=n_epochs,
    n_objects_per_batch=n_objects_per_batch,
    with_segmentation_decoder=train_instance_segmentation,
    device=device,
)
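
# ---------------------------------------------------------------------------
# Inference sketch (an addition, not part of the original training script):
# it assumes micro_sam's default checkpoint layout, i.e. that train_sam saves
# the best model to <save_root>/checkpoints/<checkpoint_name>/best.pt, and it
# uses the first training image as a placeholder example; substitute your own.
# ---------------------------------------------------------------------------
best_checkpoint = os.path.join(root_dir, "models", "checkpoints", checkpoint_name, "best.pt")

# Build a predictor and an instance segmenter from the fine-tuned weights.
predictor, segmenter = get_predictor_and_segmenter(
    model_type=model_type,
    checkpoint=best_checkpoint,
    device=device,
)

# Run automatic instance segmentation on a single 2D image.
example_image = image_paths[0]  # placeholder: any of your grayscale images
prediction = automatic_instance_segmentation(
    predictor=predictor,
    segmenter=segmenter,
    input_path=example_image,
    ndim=2,
)

# Quick side-by-side visualization of the input and the predicted instances.
fig, axes = plt.subplots(1, 2, figsize=(10, 5))
axes[0].imshow(imageio.imread(example_image), cmap="gray")
axes[0].set_title("input")
axes[1].imshow(prediction, cmap=get_random_colors(prediction), interpolation="nearest")
axes[1].set_title("prediction")
plt.show()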