Skip to content

Commit 2b596aa

Browse files
committed
split main.py into two files
1 parent 47dd4b5 commit 2b596aa

4 files changed

Lines changed: 349 additions & 348 deletions

File tree

hadamard_random_forest/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,14 @@
1414
__license__ = "MIT"
1515

1616
# Core functionality from main module
17-
from .main import (
17+
from .random_forest import (
1818
fix_random_seed,
1919
optimized_uniform_spanning_tree,
2020
generate_hypercube_tree,
21-
generate_random_forest,
21+
generate_random_forest
22+
)
23+
24+
from .sample import (
2225
get_circuits,
2326
get_circuits_hardware,
2427
get_samples,
Lines changed: 3 additions & 342 deletions
Original file line numberDiff line numberDiff line change
@@ -9,46 +9,15 @@
99
import logging
1010
import warnings
1111
from pathlib import Path
12-
from typing import Dict, List, Optional, Tuple
12+
from typing import List, Optional, Tuple
1313

1414
import numpy as np
1515
import networkx as nx
1616
import treelib
1717
from math import comb
18-
from scipy import sparse
1918
from scipy.sparse import coo_matrix
2019
import matplotlib.pyplot as plt
2120

22-
import qiskit
23-
import mthree
24-
from qiskit.providers import Backend
25-
from qiskit.transpiler import generate_preset_pass_manager
26-
from qiskit_ibm_runtime import Session, SamplerV2 as Sampler
27-
from qiskit_aer.primitives import Sampler as Aer_Sampler
28-
from mthree import M3Mitigation
29-
import mthree.utils as mthree_utils
30-
31-
# Public API
32-
__all__ = [
33-
"fix_random_seed",
34-
"hamming_weight",
35-
"pascal_layer",
36-
"optimized_uniform_spanning_tree",
37-
"generate_hypercube_tree",
38-
"find_global_roots_and_leafs",
39-
"get_path",
40-
"get_path_sparse_matrix",
41-
"get_weight",
42-
"get_signs",
43-
"majority_voting",
44-
"get_circuits",
45-
"get_samples",
46-
"get_samples_noisy",
47-
"get_circuits_hardware",
48-
"get_samples_hardware",
49-
"generate_random_forest",
50-
"get_statevector",
51-
]
5221

5322
def fix_random_seed(seed: int) -> None:
5423
"""
@@ -327,279 +296,14 @@ def majority_voting(votes: np.ndarray) -> np.ndarray:
327296
return result
328297

329298

330-
def get_circuits(
331-
num_qubits: int,
332-
base_circuit: qiskit.QuantumCircuit
333-
) -> List[qiskit.QuantumCircuit]:
334-
"""
335-
Generate a list of circuits each with a single Hadamard on one qubit appended.
336-
337-
Args:
338-
num_qubits: Total number of qubits.
339-
base_circuit: A QuantumCircuit to which measurements and H gates are appended.
340-
341-
Returns:
342-
List of QuantumCircuit objects including the base circuit with measure_all
343-
and one variant with an H applied to each qubit.
344-
"""
345-
circuits: List[qiskit.QuantumCircuit] = []
346-
# Base circuit with measurements
347-
circuits.append(base_circuit.measure_all(inplace=False))
348-
# Variants with extra Hadamard on each qubit
349-
for iq in range(num_qubits):
350-
qc = qiskit.QuantumCircuit(num_qubits)
351-
qc.compose(base_circuit, inplace=True)
352-
qc.h(iq)
353-
circuits.append(qc.measure_all(inplace=False))
354-
return circuits
355-
356-
357-
def get_samples(
358-
num_qubits: int,
359-
sampler: Aer_Sampler | Sampler,
360-
circuits: List[qiskit.QuantumCircuit],
361-
parameters: np.ndarray
362-
) -> List[np.ndarray]:
363-
"""
364-
Execute circuits and collect probability distributions using a noiseless sampler.
365-
366-
Args:
367-
num_qubits: Number of qubits (defines statevector size 2**num_qubits).
368-
sampler: Sampler object providing run().result().quasi_dists.
369-
circuits: List of QuantumCircuit to execute.
370-
parameters: 1D array of parameter values to bind to each circuit.
371-
372-
Returns:
373-
List of 1D numpy arrays of length 2**num_qubits representing probabilities.
374-
"""
375-
n = len(circuits)
376-
results = sampler.run(circuits, [parameters] * n).result().quasi_dists
377-
samples: List[np.ndarray] = []
378-
for res in results:
379-
proba = np.zeros(2**num_qubits, dtype=float)
380-
for idx, val in res.items():
381-
proba[idx] = val
382-
samples.append(proba)
383-
return samples
384-
385-
386-
def get_samples_noisy(
387-
num_qubits: int,
388-
circuits: List[qiskit.QuantumCircuit],
389-
shots: int,
390-
parameters: np.ndarray,
391-
backend_sim: Backend,
392-
error_mitigation: bool = False
393-
) -> List[np.ndarray]:
394-
"""
395-
Transpile and run circuits with optional M3 error mitigation.
396-
397-
Args:
398-
num_qubits: Number of qubits.
399-
circuits: List of QuantumCircuit to transpile and run.
400-
shots: Number of shots per circuit execution.
401-
parameters: Parameter values to assign.
402-
backend_sim: Qiskit backend to run circuits on.
403-
error_mitigation: If True, perform M3 calibration and mitigation.
404-
405-
Returns:
406-
List of numpy arrays of length 2**num_qubits with (mitigated) probabilities.
407-
"""
408-
409-
# Generate a preset pass manager.
410-
pm = generate_preset_pass_manager(
411-
optimization_level=3,
412-
backend=backend_sim,
413-
layout_method="default",
414-
routing_method="sabre",
415-
seed_transpiler=999
416-
)
417-
samples: List[np.ndarray] = []
418-
419-
if error_mitigation:
420-
# Dictionary to store unique mapping keys and their M3Mitigation objects.
421-
mapping_mit: Dict[str, M3Mitigation] = {}
422-
# Measurement results.
423-
counts_data: List[Tuple[Dict[str, int], Any, str]] = []
424-
425-
# Transpile and run each circuit.
426-
for circuit in circuits:
427-
transpiled = pm.run(circuit)
428-
mapping = mthree_utils.final_measurement_mapping(transpiled)
429-
430-
# Create a key for the mapping.
431-
key = str(mapping)
432-
433-
# If this mapping hasn't been seen, calibrate a new mitigation object.
434-
if key not in mapping_mit:
435-
# print("=========== New M3 calibration detected ===========")
436-
mit = M3Mitigation(backend_sim)
437-
mit.cals_from_system(mapping)
438-
mapping_mit[key] = mit
439-
440-
# Assign parameters and execute the circuit.
441-
transpiled.assign_parameters(parameters, inplace=True)
442-
counts = backend_sim.run(transpiled, shots=shots).result().get_counts()
443-
counts_data.append((counts, mapping, key))
444-
445-
# Apply error mitigation to each result.
446-
for counts, mapping, key in counts_data:
447-
mit = mapping_mit[key]
448-
# print(f"Applying M3 error mitigation with mapping: {mapping}")
449-
quasi = mit.apply_correction(counts, mapping)
450-
451-
# Convert counts to a probability distribution.
452-
probs = quasi.nearest_probability_distribution()
453-
dist = {k: v / shots for k, v in qiskit.result.ProbDistribution(probs, shots=shots).items()}
454-
455-
# Build a probability vector.
456-
proba = np.zeros(2**num_qubits, dtype=float)
457-
for idx, val in dist.items():
458-
proba[idx] = val
459-
samples.append(proba)
460-
else:
461-
for circuit in circuits:
462-
transpiled = pm.run(circuit)
463-
transpiled.assign_parameters(parameters, inplace=True)
464-
counts = backend_sim.run(transpiled, shots=shots).result().get_counts()
465-
proba = np.zeros(2**num_qubits, dtype=float)
466-
for bitstr, count in counts.items():
467-
idx = int(bitstr, 2)
468-
proba[idx] = count / shots
469-
samples.append(proba)
470-
return samples
471-
472-
473-
def get_circuits_hardware(
474-
num_qubits: int,
475-
base_circuit: qiskit.QuantumCircuit,
476-
device: Backend
477-
) -> List[qiskit.QuantumCircuit]:
478-
"""
479-
Transpile a base circuit for hardware and generate variants with an appended Hadamard gate.
480-
481-
Args:
482-
num_qubits: Total number of qubits.
483-
base_circuit: The original QuantumCircuit to transpile and append to.
484-
device: Qiskit backend or simulator to target for transpilation.
485-
486-
Returns:
487-
A list of transpiled QuantumCircuit objects:
488-
- The first is the base circuit with measurements.
489-
- Each subsequent circuit has an additional H gate on qubit i before measurement.
490-
"""
491-
# Create a pass manager for transpilation
492-
pm = generate_preset_pass_manager(
493-
optimization_level=3,
494-
backend=device,
495-
layout_method="default",
496-
routing_method="sabre",
497-
seed_transpiler=999
498-
)
499-
500-
circuits: List[qiskit.QuantumCircuit] = []
501-
# Base circuit: add measurements and transpile
502-
qc_base = base_circuit.measure_all(inplace=False)
503-
circuits.append(pm.run(qc_base))
504-
505-
# Variants: apply Hadamard on each qubit, then measure and transpile
506-
for qubit in range(num_qubits):
507-
qc = qiskit.QuantumCircuit(num_qubits)
508-
qc.compose(base_circuit, inplace=True)
509-
qc.h(qubit)
510-
qc.measure_all(inplace=True)
511-
circuits.append(pm.run(qc))
512-
513-
return circuits
514-
515-
516-
def get_samples_hardware(
517-
num_qubits: int,
518-
shots: int,
519-
circuits: List[qiskit.QuantumCircuit],
520-
parameters: np.ndarray,
521-
device: Backend,
522-
error_mitigation: bool = True
523-
) -> Tuple[List[np.ndarray], List[np.ndarray], List[str], List[float]]:
524-
"""
525-
Execute circuits on hardware with optional M3 error mitigation and record raw and mitigated samples.
526-
527-
Args:
528-
num_qubits: Number of qubits (defines vector size 2**num_qubits).
529-
shots: Number of shots per circuit execution.
530-
circuits: List of transpiled QuantumCircuit objects.
531-
parameters: 1D array of parameter values to bind to each circuit.
532-
device: Qiskit backend to run circuits on.
533-
error_mitigation: If True, perform M3 calibration and apply measurement mitigation.
534-
535-
Returns:
536-
A tuple of four items:
537-
mitigated_samples: List of numpy arrays (length 2**num_qubits) after mitigation.
538-
raw_samples: List of numpy arrays without mitigation.
539-
job_ids: List of job ID strings for each circuit execution.
540-
quantum_times: List of quantum execution times (in seconds).
541-
"""
542-
# Prepare sampler for hardware
543-
sampler = Sampler(device)
544-
sampler.options.default_shots = shots
545-
546-
mapping_mit: dict = {}
547-
results = [] # List of tuples: (counts, mapping_key)
548-
job_ids: List[str] = []
549-
quantum_times: List[float] = []
550-
551-
# Submit jobs and collect raw counts
552-
for idx, circ in enumerate(circuits):
553-
# Measurement mitigation setup
554-
mapping = mthree_utils.final_measurement_mapping(circ)
555-
key = str(mapping)
556-
if error_mitigation and key not in mapping_mit:
557-
# print("=========== New M3 calibration detected ===========")
558-
mit = mthree.M3Mitigation(device)
559-
mit.cals_from_system(mapping)
560-
mapping_mit[key] = mit
561-
562-
# Run circuit on hardware
563-
job = sampler.run([(circ, parameters)])
564-
result = job.result()[0]
565-
counts = result.data.meas.get_counts()
566-
results.append((counts, key))
567-
568-
job_ids.append(job.job_id())
569-
quantum_times.append(job.usage_estimation.get('quantum_seconds', 0.0))
570-
571-
# Process raw samples
572-
raw_samples: List[np.ndarray] = []
573-
for counts, _ in results:
574-
vec = np.zeros(2**num_qubits, dtype=float)
575-
for bitstr, cnt in counts.items():
576-
idx = int(bitstr, 2)
577-
vec[idx] = cnt / shots
578-
raw_samples.append(vec)
579-
580-
# Apply mitigation if requested
581-
mitigated_samples: List[np.ndarray] = []
582-
for (counts, key), raw in zip(results, raw_samples):
583-
if error_mitigation:
584-
mit = mapping_mit[key]
585-
quasi = mit.apply_correction(counts, mthree_utils.final_measurement_mapping(circuits[0]))
586-
probs = quasi.nearest_probability_distribution()
587-
vec = np.zeros(2**num_qubits, dtype=float)
588-
for bitstr, p in probs.items():
589-
vec[int(bitstr, 2)] = p
590-
mitigated_samples.append(vec)
591-
else:
592-
mitigated_samples.append(raw.copy())
593-
594-
return mitigated_samples, raw_samples, job_ids, quantum_times
595-
596299

597300
def generate_random_forest(
598301
num_qubits: int,
599302
num_trees: int,
600303
samples: List[np.ndarray],
601304
save_tree: bool = True,
602-
show_tree: bool = False
305+
show_tree: bool = False,
306+
show_first: bool = False
603307
) -> np.ndarray:
604308
"""
605309
Build multiple random spanning trees on a hypercube and aggregate signs by majority voting.
@@ -688,46 +392,3 @@ def generate_random_forest(
688392
assert signs_stack is not None
689393
return majority_voting(signs_stack)
690394

691-
def get_statevector(
692-
num_qubits: int,
693-
num_trees: int,
694-
samples: List[np.ndarray],
695-
save_tree: bool = True,
696-
show_tree: bool = False
697-
) -> np.ndarray:
698-
"""
699-
Construct the estimated statevector from measured samples and sign forest.
700-
Allows passing save_tree flag to control tree visualization.
701-
702-
Args:
703-
num_qubits: Cube dimension (log2 of state size).
704-
num_trees: Number of trees in the random forest.
705-
samples: List of sample probability arrays.
706-
save_tree: If True, save the first 10 forest tree visualizations.
707-
708-
Returns:
709-
A 1D numpy array of length 2**num_qubits representing the statevector.
710-
"""
711-
# Compute amplitudes
712-
base = samples[0]
713-
if np.any(base < 0):
714-
import warnings
715-
warnings.warn("Negative sample probabilities found; using absolute values.")
716-
amplitudes = np.sqrt(np.abs(base))
717-
else:
718-
amplitudes = np.sqrt(base)
719-
720-
# Generate signs (with optional save_tree)
721-
signs = generate_random_forest(
722-
num_qubits=num_qubits,
723-
num_trees=num_trees,
724-
samples=samples,
725-
save_tree=save_tree,
726-
show_tree=show_tree
727-
)
728-
729-
# Normalization
730-
statevector = amplitudes * signs
731-
statevector = statevector/np.linalg.norm(statevector)
732-
733-
return statevector

0 commit comments

Comments
 (0)