Skip to content

Commit c1c718a

Browse files
committed
update dataset generation - generate an image for each measurement combination
1 parent 6586a75 commit c1c718a

File tree

3 files changed

+83
-75
lines changed

3 files changed

+83
-75
lines changed

Makefile

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
SHELL := /bin/bash
22
TARGET_PATH ?= "."
33

4-
clean-all: clean-dataset clean-pred clean-ghz clean-model clean-checkpoints clean-history
4+
clean-all: clean-dataset clean-pred clean-ghz clean-model clean-checkpoints clean-history clean-gen-checkpoint
55

66
clean-dataset:
77
rm -rf $(TARGET_PATH)/dataset/ $(TARGET_PATH)/dataset.csv $(TARGET_PATH)/*.h5 $(TARGET_PATH)/dataset-images.zip
@@ -18,6 +18,9 @@ clean-model:
1818
clean-checkpoints:
1919
rm -rf $(TARGET_PATH)/checkpoint_*
2020

21+
clean-gen-checkpoint:
22+
rm -rf $(TARGET_PATH)/gen_checkpoint.json
23+
2124
clean-history:
2225
rm -rf $(TARGET_PATH)/history.json
2326

dataset.py

Lines changed: 78 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,18 @@
11
"""Generate dataset"""
22

3-
from typing import Dict, List, TypedDict, Tuple, Any, Optional
3+
from typing import Dict, List, TypedDict, Any, Optional
44
from enum import Enum
55
import os
66
import sys
77
from concurrent.futures import ThreadPoolExecutor, as_completed
88
import hashlib
99
import json
10-
from itertools import product
11-
import random
10+
from itertools import product, combinations
1211
import gc
1312
import csv
1413

1514
from qiskit import QuantumCircuit, ClassicalRegister
16-
from qiskit.transpiler import generate_preset_pass_manager, StagedPassManager
15+
from qiskit.transpiler import generate_preset_pass_manager
1716

1817
from qiskit_aer import AerSimulator
1918
from qiskit_aer.primitives import Sampler
@@ -36,12 +35,13 @@
3635
from utils.image import transform_image
3736
from utils.colors import Colors
3837
from generate.random_circuit import get_random_circuit
39-
from utils.helpers import get_measurements
4038

4139
Schema = Dict[str, Any]
4240
Dist = Dict[int, float]
4341
States = List[int]
4442
Measurements = List[int]
43+
MeasurementsCombinations = List[Measurements]
44+
Rows = List[List[Any]]
4545

4646

4747
class Stages(Enum):
@@ -55,6 +55,8 @@ class Stages(Enum):
5555
class Checkpoint:
5656
"""Class to handle generate data checkpoints"""
5757

58+
__slots__ = ["_path", "_stage", "_index", "_files"]
59+
5860
def __init__(self, path: Optional[FilePath]):
5961
self._path = path
6062

@@ -150,41 +152,6 @@ class CircuitResult(TypedDict):
150152
hash: str
151153

152154

153-
def generate_circuit(
154-
circuit_image_path: FilePath, pm: StagedPassManager, n_qubits: int, total_gates: int
155-
) -> Tuple[QuantumCircuit, int, Measurements]:
156-
"""Generate circuit and return the isa version of the circuit, its depth and the qubits that were measured"""
157-
158-
# non-interactive backend
159-
matplotlib.use("Agg")
160-
161-
qc = get_random_circuit(n_qubits, total_gates)
162-
163-
type_of_meas = random.randint(0, 1)
164-
measurements = list(range(n_qubits))
165-
166-
if type_of_meas == 0:
167-
measurements = get_measurements(n_qubits)
168-
total_measurements = len(measurements)
169-
170-
classical_register = ClassicalRegister(total_measurements)
171-
qc.add_register(classical_register)
172-
qc.measure(measurements, classical_register)
173-
else:
174-
qc.measure_all()
175-
176-
drawing = qc.draw("mpl", filename=circuit_image_path)
177-
plt.close(drawing)
178-
179-
depth = qc.depth()
180-
isa_qc = pm.run(qc)
181-
182-
# release quantum circuit
183-
del qc
184-
del drawing
185-
gc.collect()
186-
187-
return isa_qc, depth, measurements
188155

189156

190157
def get_circuit_results(qc: QuantumCircuit, sampler: Sampler, shots: int) -> Dist:
@@ -200,45 +167,77 @@ def fix_dist_gaps(dist: Dist, states: States):
200167
dist[state] = 0
201168

202169

203-
def generate_image(
204-
index: int,
170+
def generate_circuit_images(
171+
base_index: int,
205172
states: States,
206-
image_path: FilePath,
173+
measurements: MeasurementsCombinations,
174+
base_image_path: FilePath,
207175
n_qubits: int,
208176
total_gates: int,
209177
shots: int,
210-
) -> CircuitResult:
211-
"""Run an experiment, save its image and return its results"""
178+
) -> List[CircuitResult]:
179+
"""
180+
Run an experiment, save its images and return its results for different, combinations
181+
of measurements.
182+
"""
212183

213184
sim = AerSimulator()
214185
pm = generate_preset_pass_manager(backend=sim, optimization_level=0)
215-
isa_qc, depth, measurements = generate_circuit(
216-
image_path, pm, n_qubits, total_gates
217-
)
186+
sampler = Sampler()
187+
results : List[CircuitResult] = []
188+
189+
# non-interactive backend
190+
matplotlib.use("Agg")
218191

219-
with open(image_path, "rb") as file:
220-
file_hash = hashlib.md5(file.read()).hexdigest()
192+
qc = get_random_circuit(n_qubits, total_gates)
193+
194+
for index, measurement in enumerate(measurements):
195+
image_index = base_index + index
196+
image_path = os.path.join(base_image_path, '%d.png'%(image_index))
221197

222-
sampler = Sampler()
223-
result = get_circuit_results(isa_qc, sampler, shots)
224-
fix_dist_gaps(result, states)
198+
qc_copy = qc.copy()
199+
total_measurements = len(measurement)
200+
classical_register = ClassicalRegister(total_measurements)
201+
qc_copy.add_register(classical_register)
202+
qc_copy.measure(measurement, list(range(total_measurements)))
203+
204+
drawing = qc_copy.draw("mpl", filename=image_path)
205+
plt.close(drawing)
206+
del drawing
207+
208+
depth = qc_copy.depth()
209+
isa_qc = pm.run(qc_copy)
210+
del qc_copy
211+
212+
with open(image_path, "rb") as file:
213+
file_hash = hashlib.md5(file.read()).hexdigest()
214+
215+
outcomes = get_circuit_results(isa_qc, sampler, shots)
216+
fix_dist_gaps(outcomes, states)
217+
218+
del isa_qc
219+
gc.collect()
220+
221+
222+
# once we have more than a few combinations, depending on how many threads we
223+
# start, it can use a lot o memory. It also depends on how many states are possible, growing
224+
# exponentially with the number of qubits (2^n).
225+
results.append({
226+
"index": image_index,
227+
"depth": depth,
228+
"file": image_path,
229+
"result": json.dumps(list(outcomes.values())),
230+
"hash": file_hash,
231+
"measurements": json.dumps(measurement),
232+
})
225233

226234
# clear data
227235
del sim
228236
del pm
229237
del sampler
230-
del isa_qc
231238
gc.collect()
232239

233-
return {
234-
"index": index,
235-
"depth": depth,
236-
"file": image_path,
237-
"result": json.dumps(list(result.values())),
238-
"hash": file_hash,
239-
"measurements": json.dumps(measurements),
240-
}
241-
240+
return results
242241

243242
def generate_images(
244243
target_folder: FilePath,
@@ -260,6 +259,14 @@ def generate_images(
260259
int("".join(comb), 2) for comb in product("01", repeat=n_qubits)
261260
]
262261

262+
# get all measurement combinations
263+
# may be expensive with a large number of qubits, but for 5,6,... it's good
264+
qubits_iter = list(range(n_qubits))
265+
measurement_combs : MeasurementsCombinations = [ qubits_iter ] # start with [[0,1,2,3,4,....,n-1]]
266+
for amount in range(1,n_qubits):
267+
measurement_combs = [ *measurement_combs, *list(combinations(qubits_iter,amount))] # type: ignore
268+
total_measurement_combs = len(measurement_combs)
269+
263270
base_dataset_path = dataset_path(target_folder)
264271

265272
index = checkpoint.index
@@ -268,14 +275,13 @@ def generate_images(
268275
args = []
269276

270277
for i in range(total_threads):
271-
filename = "%d.jpeg" % (index)
272-
circuit_image_path = os.path.join(base_dataset_path, filename)
273-
278+
base_index = index*total_measurement_combs
274279
args.append(
275280
(
276-
index,
281+
base_index,
277282
bitstrings_to_int,
278-
circuit_image_path,
283+
measurement_combs,
284+
base_dataset_path,
279285
n_qubits,
280286
total_gates,
281287
shots,
@@ -284,7 +290,7 @@ def generate_images(
284290
index += 1
285291

286292
with ThreadPoolExecutor(max_workers=total_threads) as pool:
287-
threads = [pool.submit(generate_image, *arg) for arg in args] # type:ignore
293+
threads = [pool.submit(generate_circuit_images, *arg) for arg in args] # type:ignore
288294

289295
# The best would be using the polars scan_csv and sink_csv to
290296
# write memory efficient queries easily.
@@ -297,10 +303,9 @@ def generate_images(
297303

298304
# df = open_csv(dataset_file_path)
299305

300-
rows = []
306+
rows : Rows = []
301307
for future in as_completed(threads): # type: ignore
302-
values = list(future.result().values())
303-
rows.append(values)
308+
rows = [*rows, *[list(result.values()) for result in future.result()]]
304309

305310
append_rows_to_df(dataset_file_path, rows)
306311

@@ -432,7 +437,7 @@ def save_df(df: pl.LazyFrame, file_path: FilePath):
432437
df.sink_csv(file_path)
433438

434439

435-
def append_rows_to_df(file_path: FilePath, rows: List[List[Any]]):
440+
def append_rows_to_df(file_path: FilePath, rows: Rows):
436441
"""
437442
Use pythons built-in csv library to append rows into a file
438443
without loading it into memory directly.

utils/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414

1515
DEFAULT_BATCH_SIZE = 10
1616

17-
DEFAULT_DATASET_SIZE = 20000
17+
DEFAULT_DATASET_SIZE = 20000 # this one doesn't reflect exactly the size of the dataset, once the dataset might get either bigger, due to the different combinations of mesurements, or smaller due to duplicated circuits
1818

1919
DEFAULT_TARGET_FOLDER = "."
2020

0 commit comments

Comments
 (0)