Skip to content

Commit 1a1476e

Browse files
Dan-Elimpelchat04
authored andcommitted
Use rasterio and fiona libraries instead of gdal and ogr libraries (#67)
* Start Rasterio migration * Remove GDAL for rasterio and fiona * Replace all the calls from GDAL and OGR to fiona and rasterio * Remove soms dead code and also code cleanup * Delete unused files * Adjust the travis tool * Adjust the write array * Adjust the Reame and modify unit8 instead of float32 in rasterize * Update images_to_samples.py * Update inference.py
1 parent 62b3a6e commit 1a1476e

File tree

5 files changed

+112
-103
lines changed

5 files changed

+112
-103
lines changed

Diff for: .travis.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ install:
1010
- conda update -q conda
1111
- conda info -a
1212

13-
- conda create -q -n ci_env python=3.6 pytorch-cpu=0.4.0 torchvision ruamel_yaml h5py gdal scikit-image scikit-learn -c pytorch
13+
- conda create -q -n ci_env python=3.6 pytorch-cpu=0.4.0 torchvision ruamel_yaml h5py scikit-image scikit-learn fiona rasterio -c pytorch
1414
- source activate ci_env
1515
before_script:
1616
- unzip ./data/massachusetts_buildings.zip -d ./data

Diff for: README.md

+4-3
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,8 @@ After installing the required computing environment (see next section), one need
1616
- pytorch 0.4.1
1717
- torchvision 0.2.1
1818
- numpy
19-
- gdal
19+
- rasterio
20+
- fiona
2021
- ruamel_yaml
2122
- scikit-image
2223
- scikit-learn
@@ -28,12 +29,12 @@ After installing the required computing environment (see next section), one need
2829
1. Using conda, you can set and activate your python environment with the following commands:
2930
With GPU:
3031
```shell
31-
conda create -p YOUR_PATH python=3.6 pytorch=0.4.0 torchvision cuda80 ruamel_yaml h5py gdal=2.2.2 scikit-image scikit-learn=0.20 -c pytorch
32+
conda create -p YOUR_PATH python=3.6 pytorch=0.4.0 torchvision cuda80 ruamel_yaml h5py fiona rasterio scikit-image scikit-learn=0.20 -c pytorch
3233
source activate YOUR_ENV
3334
```
3435
CPU only:
3536
```shell
36-
conda create -p YOUR_PATH python=3.6 pytorch-cpu=0.4.0 torchvision ruamel_yaml h5py gdal=2.2.2 scikit-image scikit-learn=0.20 -c pytorch
37+
conda create -p YOUR_PATH python=3.6 pytorch-cpu=0.4.0 torchvision ruamel_yaml h5py fiona rasterio scikit-image scikit-learn=0.20 -c pytorch
3738
source activate YOUR_ENV
3839
```
3940
1. Set your parameters in the `config.yaml` (see section bellow)

Diff for: images_to_samples.py

+58-46
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import argparse
2-
import csv
32
import os
43
import numpy as np
54
import h5py
65
import warnings
7-
from osgeo import gdal, osr, ogr
8-
from utils import read_parameters, create_new_raster_from_base, assert_band_number, image_reader_as_array, \
6+
import fiona
7+
import rasterio
8+
from rasterio import features
9+
10+
from utils import read_parameters, assert_band_number, image_reader_as_array, \
911
create_or_empty_folder, validate_num_classes, read_csv
1012

1113
try:
@@ -53,12 +55,12 @@ def resize_datasets(hdf5_file):
5355
hdf5_file['map_img'].resize(new_size, axis=0)
5456

5557

56-
def samples_preparation(sat_img, ref_img, sample_size, dist_samples, samples_count, num_classes, samples_file, dataset,
57-
background_switch):
58+
def samples_preparation(in_img_array, label_array, sample_size, dist_samples, samples_count, num_classes, samples_file,
59+
dataset, background_switch):
5860
"""Extract and write samples from input image and reference image
5961
Args:
60-
sat_img: Path and name to the input image
61-
ref_img: path and name to the reference image
62+
sat_img: num py array of to the input image
63+
ref_img: num py array the reference image
6264
sample_size: Size (in pixel) of the samples to create
6365
dist_samples: Distance (in pixel) between samples in both images
6466
samples_count: Current number of samples created (will be appended and return)
@@ -69,8 +71,6 @@ def samples_preparation(sat_img, ref_img, sample_size, dist_samples, samples_cou
6971
"""
7072

7173
# read input and reference images as array
72-
in_img_array = image_reader_as_array(sat_img)
73-
label_array = image_reader_as_array(ref_img)
7474

7575
h, w, num_bands = in_img_array.shape
7676

@@ -110,23 +110,38 @@ def samples_preparation(sat_img, ref_img, sample_size, dist_samples, samples_cou
110110
return samples_count, num_classes
111111

112112

113-
def vector_to_raster(vector_file, attribute_name, new_raster):
113+
def vector_to_raster(vector_file, input_image, attribute_name):
114114
"""Function to rasterize vector data.
115115
Args:
116116
vector_file: Path and name of reference GeoPackage
117+
input_image: Path and name of the input raster image
117118
attribute_name: Attribute containing the pixel value to write
118-
new_raster: Raster file where the info will be written
119+
120+
Return
121+
num py array of the burned image
119122
"""
120-
source_ds = ogr.Open(vector_file)
121-
source_layer = source_ds.GetLayer()
122-
name_lyr = source_layer.GetLayerDefn().GetName()
123-
rev_lyr = source_ds.ExecuteSQL("SELECT * FROM " + name_lyr + " ORDER BY " + attribute_name + " ASC")
124123

125-
gdal.RasterizeLayer(new_raster, [1], rev_lyr, options=["ATTRIBUTE=%s" % attribute_name])
124+
# Extract vector features to burn in the raster image
125+
with fiona.open(vector_file, 'r') as src:
126+
lst_vector = [vector for vector in src]
127+
128+
# Sort feature in order to priorize the burning in the raster image (ex: vegetation before roads...)
129+
lst_vector.sort(key=lambda vector : vector['properties'][attribute_name])
130+
lst_vector_tuple = [(vector['geometry'], int(vector['properties'][attribute_name])) for vector in lst_vector]
131+
132+
# Open input raster image to have access to number of rows, column, crs...
133+
with rasterio.open(input_image, 'r') as src:
134+
burned_raster = rasterio.features.rasterize( (vector_tuple for vector_tuple in lst_vector_tuple),
135+
fill = 0,
136+
out_shape=src.shape,
137+
transform=src.transform,
138+
dtype=np.uint8)
126139

140+
return burned_raster
127141

128-
def main(bucket_name, data_path, samples_size, num_classes, number_of_bands, csv_file, samples_dist,
129-
remove_background, mask_input_image, mask_reference):
142+
143+
def main( bucket_name, data_path, samples_size, num_classes, number_of_bands, csv_file, samples_dist,
144+
remove_background, mask_input_image, mask_reference):
130145
gpkg_file = []
131146
if bucket_name:
132147
s3 = boto3.resource('s3')
@@ -135,10 +150,8 @@ def main(bucket_name, data_path, samples_size, num_classes, number_of_bands, csv
135150
list_data_prep = read_csv('samples_prep.csv')
136151
if data_path:
137152
final_samples_folder = os.path.join(data_path, "samples")
138-
final_out_label_folder = os.path.join(data_path, "label")
139153
else:
140154
final_samples_folder = "samples"
141-
final_out_label_folder = "label"
142155
samples_folder = "samples"
143156
out_label_folder = "label"
144157

@@ -157,17 +170,14 @@ def main(bucket_name, data_path, samples_size, num_classes, number_of_bands, csv
157170
val_hdf5 = h5py.File(os.path.join(samples_folder, "val_samples.hdf5"), "w")
158171

159172
trn_hdf5.create_dataset("sat_img", (0, samples_size, samples_size, number_of_bands), np.float32,
160-
maxshape=(None, samples_size, samples_size, number_of_bands))
173+
maxshape=(None, samples_size, samples_size, number_of_bands))
161174
trn_hdf5.create_dataset("map_img", (0, samples_size, samples_size), np.uint8,
162-
maxshape=(None, samples_size, samples_size))
175+
maxshape=(None, samples_size, samples_size))
163176
val_hdf5.create_dataset("sat_img", (0, samples_size, samples_size, number_of_bands), np.float32,
164-
maxshape=(None, samples_size, samples_size, number_of_bands))
177+
maxshape=(None, samples_size, samples_size, number_of_bands))
165178
val_hdf5.create_dataset("map_img", (0, samples_size, samples_size), np.uint8,
166-
maxshape=(None, samples_size, samples_size))
179+
maxshape=(None, samples_size, samples_size))
167180
for info in list_data_prep:
168-
img_name = os.path.basename(info['tif']).split('.')[0]
169-
tmp_label_name = os.path.join(out_label_folder, img_name + "_label_tmp.tif")
170-
label_name = os.path.join(out_label_folder, img_name + "_label.tif")
171181

172182
if bucket_name:
173183
bucket.download_file(info['tif'], "Images/" + info['tif'].split('/')[-1])
@@ -178,38 +188,33 @@ def main(bucket_name, data_path, samples_size, num_classes, number_of_bands, csv
178188
info['gpkg'] = info['gpkg'].split('/')[-1]
179189
assert_band_number(info['tif'], number_of_bands)
180190

181-
value_field = info['attribute_name']
182-
validate_num_classes(info['gpkg'], num_classes, value_field)
183-
184-
# Mask zeros from input image into label raster.
185-
if mask_reference:
186-
tmp_label_raster = create_new_raster_from_base(info['tif'], tmp_label_name, 1)
187-
vector_to_raster(info['gpkg'], info['attribute_name'], tmp_label_raster)
188-
tmp_label_raster = None
191+
# Read the input raster image
192+
np_input_image = image_reader_as_array(info['tif'])
189193

190-
masked_array = mask_image(image_reader_as_array(info['tif']), image_reader_as_array(tmp_label_name))
191-
create_new_raster_from_base(info['tif'], label_name, 1, masked_array)
194+
# Validate the number of class in the vector file
195+
validate_num_classes(info['gpkg'], num_classes, info['attribute_name'])
192196

193-
os.remove(tmp_label_name)
197+
# Burn vector file in a raster file
198+
np_label_raster = vector_to_raster(info['gpkg'], info['tif'], info['attribute_name'])
194199

195-
else:
196-
label_raster = create_new_raster_from_base(info['tif'], label_name, 1)
197-
vector_to_raster(info['gpkg'], info['attribute_name'], label_raster)
198-
label_raster = None
200+
# Mask the zeros from input image into label raster.
201+
if mask_reference:
202+
np_label_raster = mask_image(np_input_image, np_label_raster)
199203

200-
# Mask zeros from label raster into input image.
204+
# Mask zeros from label raster into input image otherwise use original image
201205
if mask_input_image:
202-
masked_img = mask_image(image_reader_as_array(label_name), image_reader_as_array(info['tif']))
203-
create_new_raster_from_base(label_name, info['tif'], number_of_bands, masked_img)
206+
np_input_image = mask_image(np_label_raster, np_input_image)
204207

205208
if info['dataset'] == 'trn':
206209
out_file = trn_hdf5
207210
elif info['dataset'] == 'val':
208211
out_file = val_hdf5
209212

210-
number_samples, number_classes = samples_preparation(info['tif'], label_name, samples_size, samples_dist,
213+
np_label_raster = np.reshape(np_label_raster, (np_label_raster.shape[0], np_label_raster.shape[1], 1))
214+
number_samples, number_classes = samples_preparation(np_input_image, np_label_raster, samples_size, samples_dist,
211215
number_samples, number_classes, out_file, info['dataset'],
212216
remove_background)
217+
213218
print(info['tif'])
214219
print(number_samples)
215220
out_file.flush()
@@ -234,6 +239,10 @@ def main(bucket_name, data_path, samples_size, num_classes, number_of_bands, csv
234239
args = parser.parse_args()
235240
params = read_parameters(args.ParamFile)
236241

242+
import time
243+
start_time = time.time()
244+
245+
237246
main(params['global']['bucket_name'],
238247
params['global']['data_path'],
239248
params['global']['samples_size'],
@@ -244,3 +253,6 @@ def main(bucket_name, data_path, samples_size, num_classes, number_of_bands, csv
244253
params['sample']['remove_background'],
245254
params['sample']['mask_input_image'],
246255
params['sample']['mask_reference'])
256+
257+
print ("Elapsed time:{}".format(time.time() - start_time))
258+

Diff for: inference.py

+26-2
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66
import time
77
import argparse
88
import heapq
9+
import rasterio
910
from PIL import Image
1011
import torchvision
1112
from models.model_choice import net, maxpool_level
12-
from utils import read_parameters, create_new_raster_from_base, assert_band_number, load_from_checkpoint, \
13+
from utils import read_parameters, assert_band_number, load_from_checkpoint, \
1314
image_reader_as_array, read_csv
1415

1516
try:
@@ -83,7 +84,7 @@ def main(bucket, work_folder, img_list, weights_file_name, model, number_of_band
8384
print()
8485
else:
8586
sem_seg_results = sem_seg_inference(bucket, model, img['tif'], overlay)
86-
create_new_raster_from_base(local_img, inference_image, 1, sem_seg_results)
87+
create_new_raster_from_base(local_img, inference_image, sem_seg_results)
8788
print(f"Semantic segmentation of image {img_name} completed")
8889

8990
if bucket:
@@ -102,6 +103,29 @@ def main(bucket, work_folder, img_list, weights_file_name, model, number_of_band
102103
print('Inference completed in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
103104

104105

106+
def create_new_raster_from_base(input_raster, output_raster, write_array):
107+
"""Function to use info from input raster to create new one.
108+
Args:
109+
input_raster: input raster path and name
110+
output_raster: raster name and path to be created with info from input
111+
write_array (optional): array to write into the new raster
112+
113+
Return:
114+
none
115+
"""
116+
117+
with rasterio.open(input_raster, 'r') as src:
118+
with rasterio.open( output_raster, 'w',
119+
driver=src.driver,
120+
width=src.width,
121+
height=src.height,
122+
count=1,
123+
crs=src.crs,
124+
dtype=np.uint8,
125+
transform=src.transform) as dst:
126+
dst.write(write_array[:,:,0], 1)
127+
128+
105129
def sem_seg_inference(bucket, model, image, overlay):
106130
"""Inference on images using semantic segmentation
107131
Args:

Diff for: utils.py

+23-51
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,11 @@
11
import torch
22
# import torch should be first. Unclear issue, mentionned here: https://github.com/pytorch/pytorch/issues/2083
33
import os
4-
import subprocess
54
import numpy as np
6-
# import matplotlib.pyplot as plt
7-
import gdal
5+
import rasterio
86
import warnings
97
from ruamel_yaml import YAML
10-
from osgeo import gdal, ogr
8+
import fiona
119
import csv
1210
try:
1311
import boto3
@@ -43,35 +41,6 @@ def read_parameters(param_file):
4341
return params
4442

4543

46-
def create_new_raster_from_base(input_raster, output_raster, band_count, write_array=None):
47-
"""Function to use info from input raster to create new one.
48-
Args:
49-
input_raster: input raster path and name
50-
output_raster: raster name and path to be created with info from input
51-
band_count: number of bands in the input raster
52-
write_array (optional): array to write into the new raster
53-
"""
54-
input_image = gdal.Open(input_raster)
55-
src = input_image
56-
cols = src.RasterXSize
57-
rows = src.RasterYSize
58-
projection = src.GetProjection()
59-
geotransform = src.GetGeoTransform()
60-
61-
new_raster = gdal.GetDriverByName('GTiff').Create(output_raster, cols, rows, band_count, gdal.GDT_Byte)
62-
new_raster.SetProjection(projection)
63-
new_raster.SetGeoTransform(geotransform)
64-
65-
for band_num in range(0, band_count):
66-
band = new_raster.GetRasterBand(band_num + 1)
67-
band.SetNoDataValue(-9999)
68-
# Write array if provided. If not, the image is filled with NoDataValues
69-
if write_array is not None:
70-
band.WriteArray(write_array[:, :, band_num])
71-
band.FlushCache()
72-
return new_raster
73-
74-
7544
def assert_band_number(in_image, band_count_yaml):
7645
"""Verify if provided image has the same number of bands as described in the .yaml
7746
Args:
@@ -117,36 +86,39 @@ def image_reader_as_array(file_name):
11786
"""Read an image from a file and return a 3d array (h,w,c)
11887
Args:
11988
file_name: full file path of the image
120-
"""
121-
raster = gdal.Open(file_name)
122-
band_num = raster.RasterCount
123-
band = raster.GetRasterBand(1)
124-
rows, columns = (band.XSize, band.YSize)
12589
126-
np_array = np.empty([columns, rows, band_num], dtype=np.float32)
90+
Return:
91+
numm_py_array of the image read
92+
"""
12793

128-
for i in range(0, band_num):
129-
band = raster.GetRasterBand(i + 1)
130-
arr = band.ReadAsArray()
131-
np_array[:, :, i] = arr
94+
with rasterio.open(file_name, 'r') as src:
95+
np_array = np.empty([src.height, src.width, src.count], dtype=np.float32)
96+
for i in range(src.count):
97+
band = src.read(i+1) # Bands starts at 1 in rasterio not 0
98+
np_array[:, :, i] = band
13299

133100
return np_array
134101

135102

136103
def validate_num_classes(vector_file, num_classes, value_field):
137-
"""Validate that the number of classes in the .shp corresponds to the expected number
104+
"""Validate that the number of classes in the vector file corresponds to the expected number
138105
Args:
139106
vector_file: full file path of the vector image
140107
num_classes: number of classes set in config.yaml
141108
value_field: name of the value field representing the required classes in the vector image file
109+
110+
Return:
111+
None
142112
"""
143-
source_ds = ogr.Open(vector_file)
144-
source_layer = source_ds.GetLayer()
145-
name_lyr = source_layer.GetLayerDefn().GetName()
146-
vector_classes = source_ds.ExecuteSQL("SELECT DISTINCT " + value_field + " FROM " + name_lyr).GetFeatureCount()
147-
if vector_classes + 1 != num_classes:
148-
raise ValueError('The number of classes in the yaml.config (%d) is different than the number of classes in '
149-
'the file %s (%d)' % (num_classes, vector_file, vector_classes))
113+
114+
distinct_att = set()
115+
with fiona.open(vector_file, 'r') as src:
116+
for feature in src:
117+
distinct_att.add(feature['properties'][value_field]) # Use property of set to store unique values
118+
119+
if len(distinct_att)+1 != num_classes:
120+
raise ValueError('The number of classes in the yaml.config {} is different than the number of classes in '
121+
'the file {} {}'.format (num_classes, vector_file, str(list(distinct_att))))
150122

151123

152124
def list_s3_subfolders(bucket, data_path):

0 commit comments

Comments
 (0)