11"""Generate dataset"""
22
3- from typing import Dict , List , TypedDict , Tuple , Any
3+ from typing import Dict , List , TypedDict , Tuple , Any , Optional
4+ from enum import Enum
45import os
56import sys
67from concurrent .futures import ThreadPoolExecutor , as_completed
2223import h5py
2324
2425from args .parser import parse_args , Arguments
25- from utils .constants import dataset_path , dataset_file , images_h5_file
26+ from utils .constants import dataset_path , dataset_file , images_h5_file , images_gen_checkpoint_file
2627from utils .datatypes import FilePath , df_schema , Dimensions
2728from utils .image import transform_image
2829from utils .colors import Colors
3435States = List [int ]
3536Measurements = List [int ]
3637
+ class Stages(Enum):
+     """Enum for dataset generation stages"""
+     GEN_IMAGES = "gen"
+     DUPLICATES = "duplicates"
+     TRANSFORM = "transform"
+
+ class Checkpoint:
+     """Handle checkpoints for dataset generation."""
+
+     def __init__(self, path: Optional[FilePath]):
+         self._path = path
+
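+         # Defaults for a fresh run: first stage, index 0, no pending files.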
+         self._stage = Stages.GEN_IMAGES
+         self._index = 0
+         self._files: List[FilePath] = []
+
+         # Check the checkpoint file and load its data, if one exists.
+
+         if self._path is None:
+             print("%sNo checkpoint was provided!%s" % (Colors.YELLOWFG, Colors.ENDC))
+             return
+
+         if not os.path.exists(self._path):
+             print("%sCheckpoint file %s doesn't exist!%s" % (Colors.YELLOWFG, self._path, Colors.ENDC))
+             return
+
+         print(
+             "%sLoading checkpoint from: %s...%s"
+             % (Colors.MAGENTABG, self._path, Colors.ENDC)
+         )
+
+         with open(self._path, "r") as file:
+             data = json.load(file)
+             stage = data.get("stage")
+             self._stage = Stages.GEN_IMAGES if stage is None else Stages(stage)
+             self._index = data.get("index") or 0
+             self._files = data.get("files") or []
+
+     @property
+     def stage(self) -> Stages:
+         """Get the checkpoint generation stage."""
+         return self._stage
+
+     @stage.setter
+     def stage(self, value: Stages):
+         """Update the stage."""
+         self._stage = value
+
+     @property
+     def index(self) -> int:
+         """Get the checkpoint generation index."""
+         return self._index
+
+     @index.setter
+     def index(self, value: int):
+         """Update the index."""
+         self._index = value
+
+     @property
+     def files(self) -> List[FilePath]:
+         """Get the duplicated files still to remove."""
+         return self._files
+
+     @files.setter  # type: ignore
+     def files(self, value: List[FilePath]):
+         """Set the files to delete."""
+         self._files = value
+
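+     # On disk this is plain JSON, e.g. (illustrative values and path):
+     #   {"stage": "duplicates", "index": 512, "files": ["images/img_3.png"]}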
+     def save(self):
+         """Persist the checkpoint to disk as JSON."""
+         if self._path is None:  # no checkpoint file configured; nothing to persist
+             return
+
+         print("%sSaving checkpoint at: %s%s" % (Colors.GREENBG, self._path, Colors.ENDC))
+         with open(self._path, "w") as file:
+             data = {
+                 "stage": self._stage.value,
+                 "index": self._index,
+                 "files": self._files,
+             }
+             json.dump(data, file)
+
class CircuitResult(TypedDict):
    """Type for circuit results"""

@@ -132,6 +212,7 @@ def generate_images(
    shots: int,
    dataset_size: int,
    total_threads: int,
+     checkpoint: Checkpoint,
):
    """
    Generate multiple images and save a dataframe with information about them.
@@ -147,7 +228,7 @@ def generate_images(
    base_dataset_path = dataset_path(target_folder)

    with tqdm(total=dataset_size) as progress:
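+         # Resume from the last checkpointed index instead of starting at 0.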
-         index = 0
+         index = checkpoint.index
        while index < dataset_size:
            args = []

@@ -177,43 +258,54 @@ def generate_images(

            save_df(df, dataset_file_path)

-             # remove df from memory to open avoid excess
+             # remove df from memory to avoid an excess
            # of memory usage
            del df
            gc.collect()

            progress.update(total_threads)

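+             # Persist progress after each batch so an interrupted run can resume.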
+             checkpoint.index = index
+             checkpoint.save()
+

- def remove_duplicated_files(target_folder: FilePath):
+ def remove_duplicated_files(target_folder: FilePath, checkpoint: Checkpoint):
    """Remove images that are duplicated based on their hash"""
+     print("%sRemoving duplicated images%s" % (Colors.GREENBG, Colors.ENDC))
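+     # Compute the duplicate list once and checkpoint it; a resumed run skips straight to deletion.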
-     dataset_file_path = dataset_file(target_folder)
+     if not checkpoint.files:  # empty list: duplicates not yet computed
+         dataset_file_path = dataset_file(target_folder)
+
+         df = open_csv(dataset_file_path)
+         clean_df = df.unique(maintain_order=True, subset=["hash"])
+         save_df(clean_df, dataset_file_path)

-     df = open_csv(dataset_file_path)
-     clean_df = df.unique(maintain_order=True, subset=["hash"])
-     clean_df_indexes = clean_df.get_column("index")
+         clean_df_indexes = clean_df.get_column("index")
+         duplicated_files = df.filter(~pl.col("index").is_in(clean_df_indexes))["file"]

-     duplicated_values = df.filter(~pl.col("index").is_in(clean_df_indexes))
+         checkpoint.files = duplicated_files.to_list()
+         checkpoint.save()

    print("%sDeleting duplicated files%s" % (Colors.GREENBG, Colors.ENDC))
-     for row in tqdm(duplicated_values.iter_rows(named=True)):
-         file = row["file"]
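+     # Iterate over a copy, since entries are removed from checkpoint.files inside the loop.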
+     for file in tqdm(list(checkpoint.files)):
        os.remove(file)

-     save_df(clean_df, dataset_file_path)
+         checkpoint.files.remove(file)
+         checkpoint.save()


- def transform_images(target_folder: FilePath, new_dim: Dimensions):
+
+ def transform_images(target_folder: FilePath, new_dim: Dimensions, checkpoint: Checkpoint):
208299 """Normalize images and save them into a h5 file"""
209300 print ("%sTransforming images%s" % (Colors .GREENBG , Colors .ENDC ))
210301
211302 df = open_csv (dataset_file (target_folder ))
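+     # Skip rows that were already transformed in a previous run.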
+     df = df.slice(offset=checkpoint.index)

    max_width, max_height = new_dim

-     image_i = 0
-     with h5py.File(images_h5_file(target_folder), "w") as file:
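+     # Resume the image index and open in append mode so earlier datasets are kept.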
+     image_i = checkpoint.index
+     with h5py.File(images_h5_file(target_folder), "a") as file:
        for row in tqdm(df.iter_rows(named=True)):
            image_path = row["file"]

@@ -222,6 +314,8 @@ def transform_images(target_folder: FilePath, new_dim: Dimensions):
            file.create_dataset(f"{image_i}", data=tensor)

            image_i += 1
+             checkpoint.index = image_i
+             checkpoint.save()


def crate_dataset_folder(base_folder: FilePath):
@@ -268,18 +362,28 @@ def main(args: Arguments):

    start_df(args.target_folder)

-     generate_images(
-         args.target_folder,
-         args.n_qubits,
-         args.max_gates,
-         args.shots,
-         args.dataset_size,
-         args.threads,
-     )
+     checkpoint = Checkpoint(images_gen_checkpoint_file(args.target_folder))
+
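+     # Run each stage only if the checkpoint hasn't already advanced past it.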
+     if checkpoint.stage == Stages.GEN_IMAGES:
+         generate_images(
+             args.target_folder,
+             args.n_qubits,
+             args.max_gates,
+             args.shots,
+             args.dataset_size,
+             args.threads,
+             checkpoint,
+         )
+
+         checkpoint.stage = Stages.DUPLICATES
+         checkpoint.index = 0
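+         # The new stage is persisted by the next checkpoint.save() in the following step.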

-     remove_duplicated_files(args.target_folder)
+     if checkpoint.stage == Stages.DUPLICATES:
+         remove_duplicated_files(args.target_folder, checkpoint)
+         checkpoint.stage = Stages.TRANSFORM
+         checkpoint.index = 0

-     transform_images(args.target_folder, args.new_image_dim)
+     transform_images(args.target_folder, args.new_image_dim, checkpoint)


if __name__ == "__main__":