11"""Generate dataset"""
22
3- from typing import Dict , List , TypedDict , Tuple , Any , Optional
3+ from typing import Dict , List , TypedDict , Any , Optional
44from enum import Enum
55import os
66import sys
77from concurrent .futures import ThreadPoolExecutor , as_completed
88import hashlib
99import json
10- from itertools import product
11- import random
10+ from itertools import product , combinations
1211import gc
1312import csv
1413
1514from qiskit import QuantumCircuit , ClassicalRegister
16- from qiskit .transpiler import generate_preset_pass_manager , StagedPassManager
15+ from qiskit .transpiler import generate_preset_pass_manager
1716
1817from qiskit_aer import AerSimulator
1918from qiskit_aer .primitives import Sampler
3635from utils .image import transform_image
3736from utils .colors import Colors
3837from generate .random_circuit import get_random_circuit
39- from utils .helpers import get_measurements
4038
4139Schema = Dict [str , Any ]
4240Dist = Dict [int , float ]
4341States = List [int ]
4442Measurements = List [int ]
43+ MeasurementsCombinations = List [Measurements ]
44+ Rows = List [List [Any ]]
4545
4646
4747class Stages (Enum ):
@@ -55,6 +55,8 @@ class Stages(Enum):
5555class Checkpoint :
5656 """Class to handle generate data checkpoints"""
5757
58+ __slots__ = ["_path" , "_stage" , "_index" , "_files" ]
59+
5860 def __init__ (self , path : Optional [FilePath ]):
5961 self ._path = path
6062
@@ -150,41 +152,6 @@ class CircuitResult(TypedDict):
150152 hash : str
151153
152154
153- def generate_circuit (
154- circuit_image_path : FilePath , pm : StagedPassManager , n_qubits : int , total_gates : int
155- ) -> Tuple [QuantumCircuit , int , Measurements ]:
156- """Generate circuit and return the isa version of the circuit, its depth and the qubits that were measured"""
157-
158- # non-interactive backend
159- matplotlib .use ("Agg" )
160-
161- qc = get_random_circuit (n_qubits , total_gates )
162-
163- type_of_meas = random .randint (0 , 1 )
164- measurements = list (range (n_qubits ))
165-
166- if type_of_meas == 0 :
167- measurements = get_measurements (n_qubits )
168- total_measurements = len (measurements )
169-
170- classical_register = ClassicalRegister (total_measurements )
171- qc .add_register (classical_register )
172- qc .measure (measurements , classical_register )
173- else :
174- qc .measure_all ()
175-
176- drawing = qc .draw ("mpl" , filename = circuit_image_path )
177- plt .close (drawing )
178-
179- depth = qc .depth ()
180- isa_qc = pm .run (qc )
181-
182- # release quantum circuit
183- del qc
184- del drawing
185- gc .collect ()
186-
187- return isa_qc , depth , measurements
188155
189156
190157def get_circuit_results (qc : QuantumCircuit , sampler : Sampler , shots : int ) -> Dist :
@@ -200,45 +167,77 @@ def fix_dist_gaps(dist: Dist, states: States):
200167 dist [state ] = 0
201168
202169
203- def generate_image (
204- index : int ,
170+ def generate_circuit_images (
171+ base_index : int ,
205172 states : States ,
206- image_path : FilePath ,
173+ measurements : MeasurementsCombinations ,
174+ base_image_path : FilePath ,
207175 n_qubits : int ,
208176 total_gates : int ,
209177 shots : int ,
210- ) -> CircuitResult :
211- """Run an experiment, save its image and return its results"""
178+ ) -> List [CircuitResult ]:
179+ """
180+ Run an experiment, save its images and return its results for different, combinations
181+ of measurements.
182+ """
212183
213184 sim = AerSimulator ()
214185 pm = generate_preset_pass_manager (backend = sim , optimization_level = 0 )
215- isa_qc , depth , measurements = generate_circuit (
216- image_path , pm , n_qubits , total_gates
217- )
186+ sampler = Sampler ()
187+ results : List [CircuitResult ] = []
188+
189+ # non-interactive backend
190+ matplotlib .use ("Agg" )
218191
219- with open (image_path , "rb" ) as file :
220- file_hash = hashlib .md5 (file .read ()).hexdigest ()
192+ qc = get_random_circuit (n_qubits , total_gates )
193+
194+ for index , measurement in enumerate (measurements ):
195+ image_index = base_index + index
196+ image_path = os .path .join (base_image_path , '%d.png' % (image_index ))
221197
222- sampler = Sampler ()
223- result = get_circuit_results (isa_qc , sampler , shots )
224- fix_dist_gaps (result , states )
198+ qc_copy = qc .copy ()
199+ total_measurements = len (measurement )
200+ classical_register = ClassicalRegister (total_measurements )
201+ qc_copy .add_register (classical_register )
202+ qc_copy .measure (measurement , list (range (total_measurements )))
203+
204+ drawing = qc_copy .draw ("mpl" , filename = image_path )
205+ plt .close (drawing )
206+ del drawing
207+
208+ depth = qc_copy .depth ()
209+ isa_qc = pm .run (qc_copy )
210+ del qc_copy
211+
212+ with open (image_path , "rb" ) as file :
213+ file_hash = hashlib .md5 (file .read ()).hexdigest ()
214+
215+ outcomes = get_circuit_results (isa_qc , sampler , shots )
216+ fix_dist_gaps (outcomes , states )
217+
218+ del isa_qc
219+ gc .collect ()
220+
221+
222+ # once we have more than a few combinations, depending on how many threads we
223+ # start, it can use a lot o memory. It also depends on how many states are possible, growing
224+ # exponentially with the number of qubits (2^n).
225+ results .append ({
226+ "index" : image_index ,
227+ "depth" : depth ,
228+ "file" : image_path ,
229+ "result" : json .dumps (list (outcomes .values ())),
230+ "hash" : file_hash ,
231+ "measurements" : json .dumps (measurement ),
232+ })
225233
226234 # clear data
227235 del sim
228236 del pm
229237 del sampler
230- del isa_qc
231238 gc .collect ()
232239
233- return {
234- "index" : index ,
235- "depth" : depth ,
236- "file" : image_path ,
237- "result" : json .dumps (list (result .values ())),
238- "hash" : file_hash ,
239- "measurements" : json .dumps (measurements ),
240- }
241-
240+ return results
242241
243242def generate_images (
244243 target_folder : FilePath ,
@@ -260,6 +259,14 @@ def generate_images(
260259 int ("" .join (comb ), 2 ) for comb in product ("01" , repeat = n_qubits )
261260 ]
262261
262+ # get all measurement combinations
263+ # may be expensive with a large number of qubits, but for 5,6,... it's good
264+ qubits_iter = list (range (n_qubits ))
265+ measurement_combs : MeasurementsCombinations = [ qubits_iter ] # start with [[0,1,2,3,4,....,n-1]]
266+ for amount in range (1 ,n_qubits ):
267+ measurement_combs = [ * measurement_combs , * list (combinations (qubits_iter ,amount ))] # type: ignore
268+ total_measurement_combs = len (measurement_combs )
269+
263270 base_dataset_path = dataset_path (target_folder )
264271
265272 index = checkpoint .index
@@ -268,14 +275,13 @@ def generate_images(
268275 args = []
269276
270277 for i in range (total_threads ):
271- filename = "%d.jpeg" % (index )
272- circuit_image_path = os .path .join (base_dataset_path , filename )
273-
278+ base_index = index * total_measurement_combs
274279 args .append (
275280 (
276- index ,
281+ base_index ,
277282 bitstrings_to_int ,
278- circuit_image_path ,
283+ measurement_combs ,
284+ base_dataset_path ,
279285 n_qubits ,
280286 total_gates ,
281287 shots ,
@@ -284,7 +290,7 @@ def generate_images(
284290 index += 1
285291
286292 with ThreadPoolExecutor (max_workers = total_threads ) as pool :
287- threads = [pool .submit (generate_image , * arg ) for arg in args ] # type:ignore
293+ threads = [pool .submit (generate_circuit_images , * arg ) for arg in args ] # type:ignore
288294
289295 # The best would be using the polars scan_csv and sink_csv to
290296 # write memory efficient queries easily.
@@ -297,10 +303,9 @@ def generate_images(
297303
298304 # df = open_csv(dataset_file_path)
299305
300- rows = []
306+ rows : Rows = []
301307 for future in as_completed (threads ): # type: ignore
302- values = list (future .result ().values ())
303- rows .append (values )
308+ rows = [* rows , * [list (result .values ()) for result in future .result ()]]
304309
305310 append_rows_to_df (dataset_file_path , rows )
306311
@@ -432,7 +437,7 @@ def save_df(df: pl.LazyFrame, file_path: FilePath):
432437 df .sink_csv (file_path )
433438
434439
435- def append_rows_to_df (file_path : FilePath , rows : List [ List [ Any ]] ):
440+ def append_rows_to_df (file_path : FilePath , rows : Rows ):
436441 """
437442 Use pythons built-in csv library to append rows into a file
438443 without loading it into memory directly.
0 commit comments