import argparse
- import csv
import os
import numpy as np
import h5py
import warnings
- from osgeo import gdal, osr, ogr
- from utils import read_parameters, create_new_raster_from_base, assert_band_number, image_reader_as_array, \
+ import fiona
+ import rasterio
+ from rasterio import features
+
+ from utils import read_parameters, assert_band_number, image_reader_as_array, \
      create_or_empty_folder, validate_num_classes, read_csv

try:
@@ -53,12 +55,12 @@ def resize_datasets(hdf5_file):
    hdf5_file['map_img'].resize(new_size, axis=0)


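Note on the resizable datasets: resize_datasets() and the create_dataset() calls further down rely on h5py datasets created with maxshape=(None, ...), which can be grown along the first axis one sample at a time. A minimal, self-contained sketch of that mechanism (file name and sizes are made up, not part of this change):

import numpy as np
import h5py

# Create a dataset whose first axis is unlimited, like "sat_img"/"map_img" below.
with h5py.File("demo_samples.hdf5", "w") as demo:
    dset = demo.create_dataset("sat_img", (0, 256, 256, 3), np.float32,
                               maxshape=(None, 256, 256, 3))
    for _ in range(5):
        dset.resize(dset.shape[0] + 1, axis=0)        # grow by one sample
        dset[-1, ...] = np.random.rand(256, 256, 3)   # write the new sample
    print(dset.shape)  # (5, 256, 256, 3)
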
- def samples_preparation(sat_img, ref_img, sample_size, dist_samples, samples_count, num_classes, samples_file, dataset,
-                         background_switch):
+ def samples_preparation(in_img_array, label_array, sample_size, dist_samples, samples_count, num_classes, samples_file,
+                         dataset, background_switch):
"""Extract and write samples from input image and reference image
59
61
Args:
-        sat_img: Path and name to the input image
-        ref_img: path and name to the reference image
+        sat_img: numpy array of the input image
+        ref_img: numpy array of the reference image
        sample_size: Size (in pixel) of the samples to create
        dist_samples: Distance (in pixel) between samples in both images
        samples_count: Current number of samples created (will be appended and return)
@@ -69,8 +71,6 @@ def samples_preparation(sat_img, ref_img, sample_size, dist_samples, samples_cou
    """

    # read input and reference images as array
-    in_img_array = image_reader_as_array(sat_img)
-    label_array = image_reader_as_array(ref_img)

    h, w, num_bands = in_img_array.shape
@@ -110,23 +110,38 @@ def samples_preparation(sat_img, ref_img, sample_size, dist_samples, samples_cou
    return samples_count, num_classes


- def vector_to_raster(vector_file, attribute_name, new_raster):
+ def vector_to_raster(vector_file, input_image, attribute_name):
    """Function to rasterize vector data.
    Args:
        vector_file: Path and name of reference GeoPackage
+        input_image: Path and name of the input raster image
        attribute_name: Attribute containing the pixel value to write
-        new_raster: Raster file where the info will be written
+
+    Return:
+        numpy array of the burned image
    """
-    source_ds = ogr.Open(vector_file)
-    source_layer = source_ds.GetLayer()
-    name_lyr = source_layer.GetLayerDefn().GetName()
-    rev_lyr = source_ds.ExecuteSQL("SELECT * FROM " + name_lyr + " ORDER BY " + attribute_name + " ASC")

-    gdal.RasterizeLayer(new_raster, [1], rev_lyr, options=["ATTRIBUTE=%s" % attribute_name])
+    # Extract vector features to burn in the raster image
+    with fiona.open(vector_file, 'r') as src:
+        lst_vector = [vector for vector in src]
+
+    # Sort features in order to prioritize the burning in the raster image (ex: vegetation before roads...)
+    lst_vector.sort(key=lambda vector: vector['properties'][attribute_name])
+    lst_vector_tuple = [(vector['geometry'], int(vector['properties'][attribute_name])) for vector in lst_vector]
+
+    # Open the input raster image to access its number of rows, columns, crs...
+    with rasterio.open(input_image, 'r') as src:
+        burned_raster = rasterio.features.rasterize((vector_tuple for vector_tuple in lst_vector_tuple),
+                                                     fill=0,
+                                                     out_shape=src.shape,
+                                                     transform=src.transform,
+                                                     dtype=np.uint8)

+    return burned_raster

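The rewritten vector_to_raster() now returns the burned labels as a numpy array instead of writing into a GDAL raster. As a quick way to inspect the result, here is a hedged usage sketch (file names and the attribute key are placeholders, not part of this changeset) that writes the array back to a single-band GeoTIFF with rasterio:

import numpy as np
import rasterio

# Hypothetical file names and attribute, for illustration only.
burned = vector_to_raster("reference.gpkg", "image.tif", "class")

with rasterio.open("image.tif") as src:
    meta = src.meta.copy()
meta.update(count=1, dtype=rasterio.uint8, nodata=None)

with rasterio.open("image_label.tif", "w", **meta) as dst:
    dst.write(burned, 1)

print(np.unique(burned))  # class values actually burned into the raster
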
- def main(bucket_name, data_path, samples_size, num_classes, number_of_bands, csv_file, samples_dist,
-          remove_background, mask_input_image, mask_reference):
+
+ def main(bucket_name, data_path, samples_size, num_classes, number_of_bands, csv_file, samples_dist,
+          remove_background, mask_input_image, mask_reference):
    gpkg_file = []
    if bucket_name:
        s3 = boto3.resource('s3')
@@ -135,10 +150,8 @@ def main(bucket_name, data_path, samples_size, num_classes, number_of_bands, csv
        list_data_prep = read_csv('samples_prep.csv')
    if data_path:
        final_samples_folder = os.path.join(data_path, "samples")
-        final_out_label_folder = os.path.join(data_path, "label")
    else:
        final_samples_folder = "samples"
-        final_out_label_folder = "label"
    samples_folder = "samples"
    out_label_folder = "label"
@@ -157,17 +170,14 @@ def main(bucket_name, data_path, samples_size, num_classes, number_of_bands, csv
    val_hdf5 = h5py.File(os.path.join(samples_folder, "val_samples.hdf5"), "w")

    trn_hdf5.create_dataset("sat_img", (0, samples_size, samples_size, number_of_bands), np.float32,
-                           maxshape=(None, samples_size, samples_size, number_of_bands))
+                            maxshape=(None, samples_size, samples_size, number_of_bands))
    trn_hdf5.create_dataset("map_img", (0, samples_size, samples_size), np.uint8,
-                           maxshape=(None, samples_size, samples_size))
+                            maxshape=(None, samples_size, samples_size))
    val_hdf5.create_dataset("sat_img", (0, samples_size, samples_size, number_of_bands), np.float32,
-                           maxshape=(None, samples_size, samples_size, number_of_bands))
+                            maxshape=(None, samples_size, samples_size, number_of_bands))
    val_hdf5.create_dataset("map_img", (0, samples_size, samples_size), np.uint8,
-                           maxshape=(None, samples_size, samples_size))
+                            maxshape=(None, samples_size, samples_size))
    for info in list_data_prep:
-        img_name = os.path.basename(info['tif']).split('.')[0]
-        tmp_label_name = os.path.join(out_label_folder, img_name + "_label_tmp.tif")
-        label_name = os.path.join(out_label_folder, img_name + "_label.tif")

        if bucket_name:
            bucket.download_file(info['tif'], "Images/" + info['tif'].split('/')[-1])
@@ -178,38 +188,33 @@ def main(bucket_name, data_path, samples_size, num_classes, number_of_bands, csv
            info['gpkg'] = info['gpkg'].split('/')[-1]
        assert_band_number(info['tif'], number_of_bands)

-        value_field = info['attribute_name']
-        validate_num_classes(info['gpkg'], num_classes, value_field)
-
-        # Mask zeros from input image into label raster.
-        if mask_reference:
-            tmp_label_raster = create_new_raster_from_base(info['tif'], tmp_label_name, 1)
-            vector_to_raster(info['gpkg'], info['attribute_name'], tmp_label_raster)
-            tmp_label_raster = None
+        # Read the input raster image
+        np_input_image = image_reader_as_array(info['tif'])

-            masked_array = mask_image(image_reader_as_array(info['tif']), image_reader_as_array(tmp_label_name))
-            create_new_raster_from_base(info['tif'], label_name, 1, masked_array)
+        # Validate the number of classes in the vector file
+        validate_num_classes(info['gpkg'], num_classes, info['attribute_name'])

-            os.remove(tmp_label_name)
+        # Burn vector file in a raster file
+        np_label_raster = vector_to_raster(info['gpkg'], info['tif'], info['attribute_name'])

-        else:
-            label_raster = create_new_raster_from_base(info['tif'], label_name, 1)
-            vector_to_raster(info['gpkg'], info['attribute_name'], label_raster)
-            label_raster = None
+        # Mask the zeros from input image into label raster.
+        if mask_reference:
+            np_label_raster = mask_image(np_input_image, np_label_raster)

-        # Mask zeros from label raster into input image.
+        # Mask zeros from label raster into input image, otherwise use the original image
        if mask_input_image:
-            masked_img = mask_image(image_reader_as_array(label_name), image_reader_as_array(info['tif']))
-            create_new_raster_from_base(label_name, info['tif'], number_of_bands, masked_img)
+            np_input_image = mask_image(np_label_raster, np_input_image)

        if info['dataset'] == 'trn':
            out_file = trn_hdf5
        elif info['dataset'] == 'val':
            out_file = val_hdf5

-        number_samples, number_classes = samples_preparation(info['tif'], label_name, samples_size, samples_dist,
+        np_label_raster = np.reshape(np_label_raster, (np_label_raster.shape[0], np_label_raster.shape[1], 1))
+        number_samples, number_classes = samples_preparation(np_input_image, np_label_raster, samples_size, samples_dist,
                                                              number_samples, number_classes, out_file, info['dataset'],
                                                              remove_background)
+
        print(info['tif'])
        print(number_samples)
        out_file.flush()
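mask_image() comes from utils and is not shown in this diff; from the two call sites above it takes the array providing the mask first and the array being masked second. A rough numpy sketch of that kind of zero-masking, stated as an assumption rather than the actual utils implementation:

import numpy as np

def mask_image_sketch(masking_array, masked_array):
    """Hypothetical stand-in for utils.mask_image: zero out every pixel of
    masked_array where masking_array is zero on all of its bands."""
    if masking_array.ndim == 3:
        zeros = np.all(masking_array == 0, axis=-1)
    else:
        zeros = (masking_array == 0)
    out = masked_array.copy()
    out[zeros] = 0
    return out
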
@@ -234,6 +239,10 @@ def main(bucket_name, data_path, samples_size, num_classes, number_of_bands, csv
    args = parser.parse_args()
    params = read_parameters(args.ParamFile)

+    import time
+    start_time = time.time()
+
+
    main(params['global']['bucket_name'],
         params['global']['data_path'],
         params['global']['samples_size'],
@@ -244,3 +253,6 @@ def main(bucket_name, data_path, samples_size, num_classes, number_of_bands, csv
         params['sample']['remove_background'],
         params['sample']['mask_input_image'],
         params['sample']['mask_reference'])
+
+    print("Elapsed time:{}".format(time.time() - start_time))
+
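For context, read_csv() (from utils) is assumed to return one dict per CSV line with at least the keys looked up in main(): 'tif', 'gpkg', 'attribute_name' and 'dataset'. A hypothetical example of what list_data_prep might look like (paths and attribute values are made up):

# Hypothetical content of list_data_prep as returned by read_csv(csv_file);
# the keys match the lookups in main(), the values are made up.
list_data_prep = [
    {'tif': 'images/area1.tif', 'gpkg': 'labels/area1.gpkg',
     'attribute_name': 'class', 'dataset': 'trn'},
    {'tif': 'images/area2.tif', 'gpkg': 'labels/area2.gpkg',
     'attribute_name': 'class', 'dataset': 'val'},
]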