Skip to content

Commit 821cc1c

Browse files
committed
added checkpoints for gen-images
1 parent e4f40da commit 821cc1c

File tree

2 files changed

+131
-26
lines changed

2 files changed

+131
-26
lines changed

generate/dataset.py

Lines changed: 130 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Generate dataset"""
22

3-
from typing import Dict, List, TypedDict, Tuple, Any
3+
from typing import Dict, List, TypedDict, Tuple, Any, Optional
4+
from enum import Enum
45
import os
56
import sys
67
from concurrent.futures import ThreadPoolExecutor, as_completed
@@ -22,7 +23,7 @@
2223
import h5py
2324

2425
from args.parser import parse_args, Arguments
25-
from utils.constants import dataset_path, dataset_file, images_h5_file
26+
from utils.constants import dataset_path, dataset_file, images_h5_file, images_gen_checkpoint_file
2627
from utils.datatypes import FilePath, df_schema, Dimensions
2728
from utils.image import transform_image
2829
from utils.colors import Colors
@@ -34,6 +35,85 @@
3435
States = List[int]
3536
Measurements = List[int]
3637

38+
class Stages(Enum):
    """Dataset generation stages, in execution order.

    The string values are what gets serialized into the checkpoint
    file, so they must remain stable across versions.
    """

    GEN_IMAGES = "gen"         # generate the raw circuit images
    DUPLICATES = "duplicates"  # drop images duplicated by hash
    TRANSFORM = "transform"    # normalize images into the h5 file
43+
44+
class Checkpoint:
    """Class to handle generate data checkpoints.

    Persists the current generation stage, a resume index and the list
    of files still pending deletion as JSON at ``path``, so an
    interrupted run can pick up where it left off. A ``None`` path or a
    missing file yields a fresh default checkpoint (stage GEN_IMAGES,
    index 0, no files).
    """

    def __init__(self, path: Optional[FilePath]):
        self._path = path

        # Defaults for a fresh run (used when no checkpoint is loaded).
        self._stage = Stages.GEN_IMAGES
        self._index = 0
        self._files: List[FilePath] = []

        # Check file and get the data
        if self._path is None:
            print("%sNo Checkpoint was provided!%s" % (Colors.YELLOWFG, Colors.ENDC))
            return

        if not os.path.exists(self._path):
            print("%sCheckpoint file %s doesn't exist!%s" % (Colors.YELLOWFG, self._path, Colors.ENDC))
            return

        print(
            "%sLoading checkpoint from: %s...%s"
            % (Colors.MAGENTABG, self._path, Colors.ENDC)
        )

        with open(self._path, "r") as file:
            data = json.load(file)
            stage = data.get("stage")
            # Tolerate partial/older checkpoint files: fall back to the
            # same defaults as a fresh run for any missing key.
            self._stage = Stages.GEN_IMAGES if stage is None else Stages(stage)
            self._index = data.get("index") or 0
            self._files = data.get("files") or []

    @property
    def stage(self) -> Stages:
        """Get checkpoint generation stage."""
        return self._stage

    @stage.setter
    def stage(self, value: Stages):
        """Update stage."""
        self._stage = value

    @property
    def index(self) -> int:
        """Get checkpoint generation index (next item to process)."""
        return self._index

    @index.setter
    def index(self, value: int):
        """Update index."""
        self._index = value

    @property
    def files(self) -> List[FilePath]:
        """Get duplicated files still pending removal."""
        return self._files

    @files.setter  # type: ignore
    def files(self, value: List[FilePath]):
        """Set files to delete."""
        self._files = value

    def save(self):
        """Persist the current stage/index/files to the checkpoint file."""
        if self._path is None:
            # Checkpointing is disabled for this run (no path given);
            # without this guard, open(None, "w") would raise TypeError.
            return
        print("%sSaving checkpoint at: %s%s" % (Colors.GREENBG, self._path, Colors.ENDC))
        with open(self._path, "w") as file:
            data = {
                "stage": self._stage.value,
                "index": self._index,
                "files": self._files,
            }
            # BUG FIX: json.dump's signature is (obj, fp); the original
            # passed (file, data) reversed, raising TypeError on every
            # save because the file object is not JSON-serializable.
            json.dump(data, file)
116+
37117
class CircuitResult(TypedDict):
38118
"""Type for circuit results"""
39119

@@ -132,6 +212,7 @@ def generate_images(
132212
shots: int,
133213
dataset_size: int,
134214
total_threads: int,
215+
checkpoint: Checkpoint
135216
):
136217
"""
137218
Generate multiple images and saves a dataframe with information about them.
@@ -147,7 +228,7 @@ def generate_images(
147228
base_dataset_path = dataset_path(target_folder)
148229

149230
with tqdm(total=dataset_size) as progress:
150-
index = 0
231+
index = checkpoint.index
151232
while index < dataset_size:
152233
args = []
153234

@@ -177,43 +258,54 @@ def generate_images(
177258

178259
save_df(df, dataset_file_path)
179260

180-
# remove df from memory to open avoid excess
261+
# remove df from memory to open avoid excessive
181262
# of memory usage
182263
del df
183264
gc.collect()
184265

185266
progress.update(total_threads)
186267

268+
checkpoint.index = index
269+
checkpoint.save()
270+
187271

188-
def remove_duplicated_files(target_folder: FilePath):
272+
def remove_duplicated_files(target_folder: FilePath, checkpoint: Checkpoint):
    """Remove images that are duplicated based on their hash.

    On a fresh run (empty ``checkpoint.files``) the dataset csv is
    deduplicated by the "hash" column and the files belonging to the
    dropped rows are recorded in the checkpoint; the deletion loop then
    removes them one by one, checkpointing after each delete so an
    interrupted run resumes without recomputing the duplicates.
    """
    print("%sRemoving duplicated images%s" % (Colors.GREENBG, Colors.ENDC))

    if not checkpoint.files:  # empty list: first run, compute the duplicates
        dataset_file_path = dataset_file(target_folder)

        df = open_csv(dataset_file_path)
        clean_df = df.unique(maintain_order=True, subset=["hash"])
        save_df(clean_df, dataset_file_path)

        clean_df_indexes = clean_df.get_column("index")
        duplicated_files = df.filter(~pl.col("index").is_in(clean_df_indexes))["file"]

        checkpoint.files = duplicated_files.to_list()
        checkpoint.save()

    print("%sDeleting duplicated files%s" % (Colors.GREENBG, Colors.ENDC))
    # BUG FIX: the original iterated checkpoint.files while calling
    # .remove() on the same list, which skips every other entry and
    # leaves half the duplicates on disk. Iterate a snapshot instead.
    for file in tqdm(list(checkpoint.files)):
        try:
            os.remove(file)
        except FileNotFoundError:
            # Already deleted by a previous run that was interrupted
            # between os.remove and checkpoint.save — safe to skip.
            pass
        checkpoint.files.remove(file)
        checkpoint.save()
205295

206296

207-
def transform_images(target_folder: FilePath, new_dim: Dimensions):
297+
298+
def transform_images(target_folder: FilePath, new_dim: Dimensions, checkpoint:Checkpoint):
208299
"""Normalize images and save them into a h5 file"""
209300
print("%sTransforming images%s" % (Colors.GREENBG, Colors.ENDC))
210301

211302
df = open_csv(dataset_file(target_folder))
303+
df.slice(offset=checkpoint.index)
212304

213305
max_width, max_height = new_dim
214306

215-
image_i = 0
216-
with h5py.File(images_h5_file(target_folder), "w") as file:
307+
image_i = checkpoint.index
308+
with h5py.File(images_h5_file(target_folder), "a") as file:
217309
for row in tqdm(df.iter_rows(named=True)):
218310
image_path = row["file"]
219311

@@ -222,6 +314,8 @@ def transform_images(target_folder: FilePath, new_dim: Dimensions):
222314
file.create_dataset(f"{image_i}", data=tensor)
223315

224316
image_i += 1
317+
checkpoint.index = image_i
318+
checkpoint.save()
225319

226320

227321
def crate_dataset_folder(base_folder: FilePath):
@@ -268,18 +362,28 @@ def main(args: Arguments):
268362

269363
start_df(args.target_folder)
270364

271-
generate_images(
272-
args.target_folder,
273-
args.n_qubits,
274-
args.max_gates,
275-
args.shots,
276-
args.dataset_size,
277-
args.threads,
278-
)
365+
checkpoint = Checkpoint(images_gen_checkpoint_file(args.target_folder))
366+
367+
if(checkpoint.stage == Stages.GEN_IMAGES):
368+
generate_images(
369+
args.target_folder,
370+
args.n_qubits,
371+
args.max_gates,
372+
args.shots,
373+
args.dataset_size,
374+
args.threads,
375+
checkpoint
376+
)
377+
378+
checkpoint.stage = Stages.DUPLICATES
379+
checkpoint.index = 0
279380

280-
remove_duplicated_files(args.target_folder)
381+
if(checkpoint.stage == Stages.DUPLICATES):
382+
remove_duplicated_files(args.target_folder, checkpoint)
383+
checkpoint.stage = Stages.TRANSFORM
384+
checkpoint.index = 0
281385

282-
transform_images(args.target_folder, args.new_image_dim)
386+
transform_images(args.target_folder, args.new_image_dim, checkpoint)
283387

284388

285389
if __name__ == "__main__":

utils/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,3 +40,4 @@
4040
output_plot_file = lambda target_folder: os.path.join(
4141
target_folder, "training_progress.png"
4242
)
43+
# Path of the JSON checkpoint file used to resume dataset generation.
images_gen_checkpoint_file = lambda target_folder: os.path.join(target_folder, "gen_checkpoint.json")

0 commit comments

Comments
 (0)