Skip to content

Commit 2c75644

Browse files
ymoisanmpelchat04
authored andcommitted
PyTorch 1.0.1 and rasterio, fiona (#72)
* Edits for PyTorch 1.0.x * interpolate replaces upsample
1 parent 1a1476e commit 2c75644

File tree

5 files changed

+24
-14
lines changed

5 files changed

+24
-14
lines changed

Diff for: inference.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def sem_seg_inference(bucket, model, image, overlay):
186186
col_to = col + chunk_size - overlay
187187

188188
useful_sem_seg = segmentation[overlay:chunk_size - overlay, overlay:chunk_size - overlay]
189-
output_np[row_from:row_to, col_from:col_to, 0] = useful_sem_seg
189+
output_np[row_from:row_to, col_from:col_to, 0] = useful_sem_seg.cpu()
190190

191191
# Resize the output array to the size of the input image and write it
192192
output_np = output_np[overlay:h + overlay, overlay:w + overlay]

Diff for: metrics.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def report_classification(pred, label, batch_size, metrics_dict):
5353
"""Computes precision, recall and f-score for each class and average of all classes.
5454
http://scikit-learn.org/stable/modules/generated/sklearn.metrics.classification_report.html
5555
"""
56-
class_report = classification_report(label, pred, output_dict=True)
56+
class_report = classification_report(label.cpu(), pred.cpu(), output_dict=True)
5757

5858
class_score = {}
5959
for key, value in class_report.items():
@@ -75,6 +75,6 @@ def iou(pred, target, batch_size, metrics_dict):
7575
"""Calculate the intersection over union (or Jaccard index) between two datasets.
7676
The Jaccard distance (or dissimilarity) would be 1-iou.
7777
"""
78-
iou = jaccard_similarity_score(target, pred, normalize=True)
78+
iou = jaccard_similarity_score(target.cpu(), pred.cpu(), normalize=True)
7979
metrics_dict['iou'].update(iou, batch_size)
8080
return metrics_dict

Diff for: models/checkpointed_unet.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import torch
1+
import torch, utils
22
# import torch should be first. Unclear issue, mentioned here: https://github.com/pytorch/pytorch/issues/2083
33
from torch import nn
44
from torch.utils.checkpoint import checkpoint_sequential
@@ -44,15 +44,14 @@ def forward(self, input_data):
4444
modules = get_modules(self.encoding_block)
4545
return checkpoint_sequential(modules, segments, input_data)
4646

47-
4847
class DecodingBlock(nn.Module):
4948
"""Module in the decoding section of the UNet"""
5049

5150
def __init__(self, in_size, out_size, batch_norm=False, upsampling=True):
5251
super().__init__()
5352
up_modules = []
5453
if upsampling:
55-
self.up = nn.Sequential(nn.Upsample(mode='bilinear', scale_factor=2),
54+
self.up = nn.Sequential(utils.Interpolate(mode='bilinear', scale_factor=2),
5655
nn.Conv2d(in_size, out_size, kernel_size=1))
5756
self.upsampling = True
5857
else:
@@ -71,7 +70,7 @@ def forward(self, input1, input2):
7170
output2 = checkpoint_sequential(self.up_modules, segments, input2)
7271
else:
7372
output2 = self.up(input2)
74-
output1 = nn.functional.upsample(input1, output2.size()[2:], mode='bilinear', align_corners=True)
73+
output1 = nn.functional.interpolate(input1, output2.size()[2:], mode='bilinear', align_corners=True)
7574
return checkpoint_sequential(self.conv_modules, segments, torch.cat([output1, output2], 1))
7675

7776

@@ -118,7 +117,7 @@ def forward(self, input_data):
118117
decode2 = self.decode2(conv2, decode3)
119118
decode1 = self.decode1(conv1, decode2)
120119

121-
final = nn.functional.upsample(self.final(decode1), input_data.size()[2:], mode='bilinear')
120+
final = nn.functional.interpolate(self.final(decode1), input_data.size()[2:], mode='bilinear')
122121
return final
123122

124123

@@ -157,7 +156,7 @@ def forward(self, input_data):
157156
decode2 = self.decode2(conv2, decode3)
158157
decode1 = self.decode1(conv1, decode2)
159158

160-
final = nn.functional.upsample(self.final(decode1), input_data.size()[2:], mode='bilinear', align_corners=True)
159+
final = nn.functional.interpolate(self.final(decode1), input_data.size()[2:], mode='bilinear', align_corners=True)
161160
return final
162161

163162

Diff for: models/unet.py

+5-5
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import torch
1+
import torch, utils
22
from torch import nn
33

44

@@ -48,7 +48,7 @@ class DecodingBlock(nn.Module):
4848
def __init__(self, in_size, out_size, batch_norm=False, upsampling=True):
4949
super().__init__()
5050
if upsampling:
51-
self.up = nn.Sequential(nn.Upsample(mode='bilinear', scale_factor=2),
51+
self.up = nn.Sequential(utils.Interpolate(mode='bilinear', scale_factor=2),
5252
nn.Conv2d(in_size, out_size, kernel_size=1))
5353
else:
5454
self.up = nn.ConvTranspose2d(in_size, out_size, kernel_size=2, stride=2)
@@ -57,7 +57,7 @@ def __init__(self, in_size, out_size, batch_norm=False, upsampling=True):
5757

5858
def forward(self, input1, input2):
5959
output2 = self.up(input2)
60-
output1 = nn.functional.upsample(input1, output2.size()[2:], mode='bilinear', align_corners=True)
60+
output1 = nn.functional.interpolate(input1, output2.size()[2:], mode='bilinear', align_corners=True)
6161
return self.conv(torch.cat([output1, output2], 1))
6262

6363

@@ -102,7 +102,7 @@ def forward(self, input_data):
102102
decode2 = self.decode2(conv2, decode3)
103103
decode1 = self.decode1(conv1, decode2)
104104

105-
final = nn.functional.upsample(self.final(decode1), input_data.size()[2:], mode='bilinear')
105+
final = nn.functional.interpolate(self.final(decode1), input_data.size()[2:], mode='bilinear')
106106

107107
return final
108108

@@ -142,6 +142,6 @@ def forward(self, input_data):
142142
decode2 = self.decode2(conv2, decode3)
143143
decode1 = self.decode1(conv1, decode2)
144144

145-
final = nn.functional.upsample(self.final(decode1), input_data.size()[2:], mode='bilinear', align_corners=True)
145+
final = nn.functional.interpolate(self.final(decode1), input_data.size()[2:], mode='bilinear', align_corners=True)
146146

147147
return final

Diff for: utils.py

+11
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,17 @@
1313
warnings.warn('The boto3 library counldn\'t be imported. Ignore if not using AWS s3 buckets', ImportWarning)
1414
pass
1515

16+
class Interpolate(torch.nn.Module):
17+
def __init__(self, mode, scale_factor):
18+
super(Interpolate, self).__init__()
19+
self.interp = torch.nn.functional.interpolate
20+
self.scale_factor = scale_factor
21+
self.mode = mode
22+
23+
def forward(self, x):
24+
x = self.interp(x, scale_factor=self.scale_factor, mode=self.mode, align_corners=False)
25+
return x
26+
1627

1728
def create_or_empty_folder(folder):
1829
"""Empty an existing folder or create it if it doesn't exist.

0 commit comments

Comments
 (0)