Skip to content

Commit c597159

Browse files
committed
Segmentation refactored, now for the point tracking
1 parent a111770 commit c597159

9 files changed

Lines changed: 283 additions & 128 deletions

source/GUI/ImageViewerWidget.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,10 @@ def convertImage(self, image):
5252

5353
h, w, ch = image.shape
5454
bytesPerLine = ch * w
55-
return QImage(image.data, w, h, bytesPerLine, QImage.Format_BGR888)
55+
return QImage(image.copy().data, w, h, bytesPerLine, QImage.Format_BGR888)
5656

5757
def updateImage(self, image, widget):
58-
widget.setPixmap(QPixmap.fromImage(self.convertImage(image).scaledToWidth(128)))
58+
widget.setPixmap(QPixmap.fromImage(self.convertImage(image)))
5959

6060
def getWidget(self, key):
6161
return self.imageDICT[key]

source/GUI/MainMenuWidget.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ def __init__(self, viewer_palette, parent=None):
2828
"Segmentation",
2929
[
3030
("Koc et al", "checkbox", False),
31-
("Neural Segmentation", "checkbox", False),
32-
("Silicone Segmentation", "checkbox", True),
31+
("Neural Segmentation", "checkbox", True),
32+
("Silicone Segmentation", "checkbox", False),
3333
],
3434
)
3535
self.addSubMenu(

source/KocSegmentation.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -198,9 +198,9 @@ def generateSegmentationData(self):
198198
self.generate()
199199

200200
for i, image in enumerate(self.images):
201-
self._segmentations.append(self.segmentImage(image))
202-
self._glottal_outlines.append(self.computeGlottalOutline(i))
203-
self._glottal_midlines.append(self.computeGlottalMidline(i))
201+
self.segmentations.append(self.segmentImage(image))
202+
self.glottal_outlines.append(self.computeGlottalOutline(i))
203+
self.glottal_midlines.append(self.computeGlottalMidline(i))
204204

205205
self.closedGlottisIndex = self.estimateClosedGlottis()
206206
self.openGlottisIndex = self.estimateOpenGlottis()

source/NeuralSegmentation.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,125 @@ def forward(self, x):
7070

7171

7272

73+
class DoubleConvB(nn.Module):
    """Two consecutive Conv2d -> BatchNorm2d -> ReLU stages.

    Both convolutions use a 3x3 kernel with stride 1 and padding 1 and have
    no bias term. The first maps ``in_channels`` to ``out_channels``; the
    second maps ``out_channels`` to ``out_channels``.

    Args:
        in_channels: channel count expected by the first convolution.
        out_channels: channel count produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()

        # Layers are created in the same order as they are applied so the
        # indices inside ``self.conv`` (and the state-dict keys) stay stable.
        stages = []
        for stage_in in (in_channels, out_channels):
            stages.append(
                nn.Conv2d(stage_in, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
            )
            stages.append(nn.BatchNorm2d(out_channels))
            stages.append(nn.ReLU(inplace=True))

        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through the six-layer stack and return the result."""
        return self.conv(x)
86+
87+
88+
class Decoder(nn.Module):
    """Upsampling path that reads skip connections stored on an encoder.

    For each entry of ``features`` (visited last to first) a
    ``ConvTranspose2d`` and a ``DoubleConvB`` are appended to ``self.ups``,
    so even indices hold the upsampling layers and odd indices hold the
    convolution blocks.

    Args:
        encoder: object whose ``skip_connections`` attribute is indexed
            during ``forward``.
        out_channels: stored on the instance; not otherwise used here.
        features: per-level channel counts.
    """

    def __init__(self, encoder, out_channels, features):
        super().__init__()
        self.ups = nn.ModuleList()
        self.encoder = encoder
        self.out_channels = out_channels

        for width in reversed(features):
            upsample = nn.ConvTranspose2d(width * 2, width, kernel_size=2, stride=2)
            refine = DoubleConvB(width * 2, width)
            self.ups.extend([upsample, refine])

    def forward(self, x):
        """Upsample ``x`` level by level, merging the encoder's skip tensors.

        At each level the upsampled tensor is resized to the skip tensor's
        trailing two dimensions when their shapes differ, concatenated with
        the skip tensor along dim 1 (skip first), and passed through the
        level's ``DoubleConvB``.
        """
        layers = list(self.ups)
        level_pairs = zip(layers[0::2], layers[1::2])

        for level, (upsample, refine) in enumerate(level_pairs):
            x = upsample(x)
            skip = self.encoder.skip_connections[level]

            if x.shape != skip.shape:
                x = TF.resize(x, size=skip.shape[2:])

            merged = torch.cat((skip, x), dim=1)
            x = refine(merged)

        return x
112+
113+
114+
class Encoder(nn.Module):
    """Downsampling path built from ``DoubleConvB`` blocks and 2x2 max pooling.

    One ``DoubleConvB`` is created per entry of ``features``; each block's
    input width is the previous block's output width, starting from
    ``in_channels``.

    Args:
        in_channels: channel count of the input tensor.
        features: per-level output channel counts.
    """

    def __init__(self, in_channels, features):
        super().__init__()
        self.downs = nn.ModuleList()
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.in_channels = in_channels

        # Chain the blocks: every block consumes what the previous one emits.
        previous_width = in_channels
        for width in features:
            self.downs.append(DoubleConvB(previous_width, width))
            previous_width = width

    def forward(self, x):
        """Run all levels and return the pooled output of the last one.

        Side effect: ``self.skip_connections`` is replaced with the pre-pooling
        outputs of every level, ordered from the last level to the first.
        """
        collected = self.skip_connections = []

        for stage in self.downs:
            x = stage(x)
            collected.append(x)
            x = self.pool(x)

        # Deepest level first, which is the order the decoder indexes them in.
        collected.reverse()

        return x
136+
137+
138+
139+
class UNETNew(nn.Module):
    """U-Net composed of an ``Encoder``, a ``DoubleConv`` bottleneck, a
    ``Decoder`` and a final 1x1 convolution.

    Args:
        config: subscriptable configuration with the optional keys
            ``'in_channels'`` (default 3), ``'out_channels'`` (default 4) and
            ``'features'`` (default ``[32, 64, 128, 256, 512]``). ``None``
            selects every default.
        state_dict: optional dict in the layout returned by
            ``get_statedict``; loaded via ``load_from_dict`` when truthy.
        pretrain: when true, gradient computation is disabled for all
            encoder parameters.
        device: accepted for interface compatibility; not used by this class.
    """

    _DEFAULT_IN_CHANNELS = 3
    _DEFAULT_OUT_CHANNELS = 4
    _DEFAULT_FEATURES = (32, 64, 128, 256, 512)

    def __init__(self, config=None, state_dict=None, pretrain=False, device="cuda"):
        super().__init__()

        in_channels = self._config_value(config, 'in_channels', self._DEFAULT_IN_CHANNELS)
        out_channels = self._config_value(config, 'out_channels', self._DEFAULT_OUT_CHANNELS)
        features = self._config_value(config, 'features', list(self._DEFAULT_FEATURES))

        self.bottleneck_size = features[-1] * 2

        # Creation order is kept (encoder, decoder, bottleneck, final_conv) so
        # parameter initialisation and submodule ordering stay the same.
        self.encoder = Encoder(in_channels, features)
        self.decoder = Decoder(self.encoder, out_channels, features)
        self.bottleneck = DoubleConv(features[-1], self.bottleneck_size)
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

        if state_dict:
            self.load_from_dict(state_dict)

        if pretrain:
            # requires_grad_ is a method; assigning to it (as the previous code
            # did) replaces the method and leaves every parameter trainable.
            self.encoder.requires_grad_(False)

    @staticmethod
    def _config_value(config, key, default):
        """Return ``config[key]``, or ``default`` if the key is absent or
        ``config`` is not subscriptable (e.g. ``None``)."""
        try:
            return config[key]
        except (KeyError, TypeError):
            return default

    def get_statedict(self):
        """Return the four sub-module state dicts keyed by component name.

        NOTE(review): ``Decoder`` stores the encoder as an attribute, so the
        ``"Decoder"`` entry presumably also carries the encoder's weights.
        """
        return {"Encoder": self.encoder.state_dict(),
                "Bottleneck": self.bottleneck.state_dict(),
                "Decoder": self.decoder.state_dict(),
                "LastConv": self.final_conv.state_dict()}

    def load_from_dict(self, dict):
        """Load weights from a dict in the layout of ``get_statedict``.

        Encoder, bottleneck and decoder entries are mandatory. The final
        convolution is best-effort: if its entry is missing or does not match
        the layer, a message is printed and the layer keeps its current
        weights.
        """
        self.encoder.load_state_dict(dict["Encoder"])
        self.bottleneck.load_state_dict(dict["Bottleneck"])
        self.decoder.load_state_dict(dict["Decoder"])

        try:
            self.final_conv.load_state_dict(dict["LastConv"])
        except (KeyError, RuntimeError):
            # KeyError: no "LastConv" entry; RuntimeError: key/shape mismatch
            # reported by load_state_dict.
            print("Final conv not initialized.")

    def forward(self, x):
        """Apply encoder, bottleneck, decoder and the final 1x1 convolution."""
        x = self.encoder(x)
        x = self.bottleneck(x)
        x = self.decoder(x)

        return self.final_conv(x)
189+
190+
191+
73192
class NeuralSegmentator(BaseSegmentator):
74193
def __init__(self, images, path="assets/model.pth.tar"):
75194
super().__init__(images)

source/Viewer.py

Lines changed: 26 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
import helper
1414
import igl
1515
import KocSegmentation
16+
import kornia
1617
import Laser
1718
import Mesh
1819
import NeuralSegmentation
@@ -24,6 +25,7 @@
2425
import SiliconeSegmentation
2526
import SiliconeSurfaceReconstruction
2627
import surface_reconstruction
28+
import torch
2729
import Triangulation
2830
import VoronoiRHC
2931
from GraphWidget import GraphWidget
@@ -158,11 +160,18 @@ def __init__(self):
158160
self.timer_thread.start()
159161
self.image_timer_thread.start()
160162

163+
path = "/media/nu94waro/Windows_C/save/datasets/HLEDataset/dataset"
161164
self.loadData(
162165
"assets/camera_calibration.json",
163166
"assets/laser_calibration.json",
164167
"assets/example_vid.avi",
165168
)
169+
'''
170+
self.loadData(
171+
os.path.join(path, "camera_calibration.json"),
172+
os.path.join(path, "laser_calibration.json"),
173+
os.path.join(path, "MK/MK.avi"),
174+
)'''
166175

167176
self._reconstruction_pipeline = reconstruction_pipeline.ReconstructionPipeline(
168177
self.camera,
@@ -442,70 +451,36 @@ def loadData(self, camera_path, laser_path, video_path):
442451

443452
self.images_set = True
444453

454+
455+
456+
457+
445458
def segmentImages(self):
446-
if self.menu_widget.widget().getSubmenuValue("Segmentation", "Koc et al"):
447-
self.segmentator = KocSegmentation.KocSegmentator(self.images)
448-
elif self.menu_widget.widget().getSubmenuValue("Segmentation", "Neural Segmentation"):
449-
self.segmentator = NeuralSegmentation.NeuralSegmentator(self.images)
459+
segmentator: feature_estimation.FeatureEstimator = None
460+
if self.menu_widget.widget().getSubmenuValue("Segmentation", "Neural Segmentation"):
461+
segmentator = feature_estimation.NeuralFeatureEstimator("bla")
450462
elif self.menu_widget.widget().getSubmenuValue("Segmentation", "Silicone Segmentation"):
451-
self.segmentator = SiliconeSegmentation.SiliconeSegmentator(self.images)
463+
segmentator = feature_estimation.SiliconeFeatureEstimator()
452464
else:
453465
print("Please choose a Segmentation Algorithm")
454466

455-
x, w, y, h = self.segmentator.getROI()
456-
self.roi = self.segmentator.getROIImage()
467+
self._reconstruction_pipeline.set_feature_estimator(segmentator)
468+
469+
images = torch.from_numpy(np.stack(self.images)).to("cuda")
470+
segmentator.compute_features(images)
457471

458472
segmentations = list()
459-
laserdots = list()
460-
461-
for index in range(len(self.segmentator)):
462-
base_image = self.segmentator.getImage(index).copy()
473+
feature_images = segmentator.create_feature_images()
463474

464-
segmentation_image = self.segmentator.getSegmentation(index).copy()
465-
gml_a, gml_b = self.segmentator.getGlottalMidline(index)
475+
for feature_image in feature_images:
476+
segmentations.append(feature_image.permute(1, 2, 0).detach().cpu().numpy().astype(np.uint8))
466477

467-
segmentation_image = cv2.cvtColor(segmentation_image, cv2.COLOR_GRAY2BGR)
468-
469-
cv2.rectangle(
470-
segmentation_image,
471-
(x, y),
472-
(x + w, y + h),
473-
color=(255, 0, 0),
474-
thickness=2,
475-
)
476-
try:
477-
cv2.line(
478-
segmentation_image,
479-
gml_a.astype(np.int32),
480-
gml_b.astype(np.int32),
481-
color=(125, 125, 0),
482-
thickness=2,
483-
)
484-
except:
485-
pass
486-
segmentations.append(
487-
cv2.cvtColor(base_image, cv2.COLOR_GRAY2BGR) | segmentation_image
488-
)
489-
490-
laserdot_image = self.segmentator.getLocalMaxima(index).copy()
491-
laserdot_image = cv2.dilate(laserdot_image, np.ones((3, 3)))
492-
laserdot_image = np.where(laserdot_image > 0, 255, 0).astype(np.uint8)
493-
laserdot_image = cv2.cvtColor(laserdot_image, cv2.COLOR_GRAY2BGR)
494-
laserdot_image[:, :, [0, 2]] = 0
495-
laserdots.append(
496-
cv2.cvtColor(base_image, cv2.COLOR_GRAY2BGR) | laserdot_image
497-
)
498-
499-
glottal_area_waveform = [
500-
len(self.segmentator.getSegmentation(index).nonzero()[0])
501-
for index in range(len(self.segmentator))
502-
]
503478
self.graph_widget.updateGraph(
504-
glottal_area_waveform, self.graph_widget.glottal_seg_graph
479+
segmentator.glottalAreaWaveform().tolist(), self.graph_widget.glottal_seg_graph
505480
)
506481

507482
self.segmentations = segmentations
508-
self.laserdots = laserdots
483+
self.laserdots = segmentations
509484

510485
def buildCorrespondences(self):
511486
min_search_space = float(

source/cv.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -105,10 +105,10 @@ def compute_segmentation_outline(segmentation: torch.tensor, kernel_size=3, bord
105105
Returns:
106106
border: (B, 1, H, W) tensor of borders
107107
"""
108-
kernel = torch.ones((1, 1, kernel_size, kernel_size), device=segmentation.device)
108+
kernel = torch.ones((kernel_size, kernel_size), device=segmentation.device)
109109

110-
dilated = kornia.morphology.dilation(segmentation, kernel)
111-
eroded = kornia.morphology.erosion(segmentation, kernel)
110+
dilated = kornia.morphology.dilation(segmentation.unsqueeze(0).unsqueeze(0).float(), kernel).squeeze()
111+
eroded = kornia.morphology.erosion(segmentation.unsqueeze(0).unsqueeze(0).float(), kernel).squeeze()
112112

113113
if border_type == "both":
114114
border = dilated - eroded
@@ -145,9 +145,9 @@ def windows_out_of_bounds(indices, image_size, pad):
145145
def extractWindow(batch, indices, window_size=7, device="cuda"):
146146
# Clean Windows, such that no image boundaries are hit
147147

148-
batch_index = indices[:, 0].int()
149-
y = indices[:, 2].floor().int()
150-
x = indices[:, 1].floor().int()
148+
batch_index = indices[:, 0].long()
149+
y = indices[:, 2].long()
150+
x = indices[:, 1].long()
151151

152152
y = windows_out_of_bounds(y, batch.shape[1], window_size // 2)
153153
x = windows_out_of_bounds(x, batch.shape[2], window_size // 2)

0 commit comments

Comments
 (0)