import torch # The main PyTorch library.
import numpy as np # For numerical operations.
import os # For handling file paths.
import mlflow # For logging metrics.
import sys # Used here to interact with the system (for the trigger).

from qiskit.circuit.library import ZFeatureMap # A standard circuit for encoding classical data.
from qiskit_aer.primitives import SamplerV2 # The fast, local quantum simulator.
from qiskit_machine_learning.kernels import FidelityQuantumKernel # A method to calculate "quantum similarity".
from sklearn.svm import OneClassSVM # The classical SVM algorithm used for anomaly detection.

# --- Local Project Imports ---
from src.feature_engineering.quantum_circuits import get_quantum_torch_layer
from src.hyperparameter_tuning.tune_with_qaoa import generate_quantum_features

def run_drift_detection(config):
    print("\n--- MLOps Stage 3: Quantum-Enhanced Production Monitoring ---")
    cfg1 = config['stage_1_feature_engineering'] # Get Stage 1 parameters from the config.
    cfg3 = config['stage_3_production_monitoring'] # Get Stage 3 parameters from the config.
    device = torch.device("cpu") # Set the device to CPU.

    print("Loading feature extractor from Stage 1...")
    encoder = Encoder(cfg1['stage_1_latent_dim'], cfg1['stage_1_img_size']) # Initialize the encoder architecture.
    encoder.load_state_dict(torch.load("saved_models/feature_extractor/hae_encoder.pth")) # Load its saved weights.
    encoder.to(device)
    quantum_layer = get_quantum_torch_layer(cfg1['stage_1_latent_dim']) # Initialize the quantum layer.
    pqc_weights = np.load("saved_models/feature_extractor/hae_pqc_weights.npy") # Load its saved weights.
    quantum_layer.weight = torch.nn.Parameter(torch.Tensor(pqc_weights)) # Assign the weights.
    quantum_layer.to(device)

    print("Generating quantum features from 'normal' production data to train the monitor...")
    train_loader, _ = get_data_loaders(
        batch_size=cfg3['stage_3_n_samples'],
        # ... other loader arguments elided in this excerpt ...
    )
    normal_features, _ = generate_quantum_features(encoder, quantum_layer, train_loader, device)

    # This configuration was missing in the original file but is crucial: it sets up the quantum kernel.
    quantum_kernel = FidelityQuantumKernel(feature_map=ZFeatureMap(feature_dimension=cfg1['stage_1_latent_dim'], reps=2))
    quantum_kernel.sampler = SamplerV2()

    qsvm_monitor = OneClassSVM(kernel=quantum_kernel.evaluate, nu=cfg3['stage_3_nu_param']) # Initialize the SVM with the quantum kernel.
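    # Note: scikit-learn accepts a callable as the kernel. During fit() and predict(),
    # OneClassSVM calls quantum_kernel.evaluate(X, Y) and expects the pairwise
    # similarity (Gram) matrix back, so the decision boundary is learned directly
    # from the quantum fidelity between encoded samples.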

    print("Training the QSVM monitor...")
    qsvm_monitor.fit(normal_features) # Train the monitor on only "good" data.
    print("QSVM monitor training complete.")

    print("\nSimulating a production data stream with potential data drift...")
    noise = np.random.normal(0, 0.8, normal_features.shape) # Create some random noise.
    anomalous_features = normal_features + noise # Create "drifted" data by adding noise.

    # --- Drift Detection Logic ---
    production_stream = np.concatenate([normal_features, anomalous_features]) # Combine good and bad data for testing.
    predictions = qsvm_monitor.predict(production_stream) # Get the monitor's predictions.

    true_labels = np.array([1] * len(normal_features) + [-1] * len(anomalous_features)) # Create ground-truth labels.

    anomalies_missed = np.sum((predictions == 1) & (true_labels == -1)) # Count how many anomalies were missed.
    total_anomalies = len(anomalous_features) # Get the total number of anomalies.
    drift_rate = anomalies_missed / total_anomalies if total_anomalies > 0 else 0 # Calculate the miss rate.
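    # For example, if 100 drifted points are streamed and the monitor labels 20 of
    # them as inliers, the drift rate is 20 / 100 = 20%.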

    print(f"\nDrift Analysis: The monitor missed {anomalies_missed} out of {total_anomalies} anomalous data points.")
    print(f"Calculated Drift Rate: {drift_rate:.2%}")
    mlflow.log_metric("stage_3_drift_rate", drift_rate) # Log the result to MLflow.

    # The threshold to decide if retraining is needed.
    drift_threshold = 0.5 # This can be adjusted based on acceptable risk levels.

    # Create a simple text file to signal the status to the CI/CD pipeline.
    with open("drift_status.txt", "w") as f:
        if drift_rate > drift_threshold: # Check if the miss rate is too high.
            print(f"[ALERT] Drift rate ({drift_rate:.2%}) exceeds threshold ({drift_threshold:.2%}). Signaling for retraining.")
            f.write("DRIFT_DETECTED") # Write the "emergency" signal.
        else:
            print(f"Drift rate ({drift_rate:.2%}) is within acceptable limits.")
            f.write("NO_DRIFT") # Write the "all clear" signal.

    print("\n--- Production Monitoring Stage Complete ---")
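

# A minimal sketch of how the CI/CD side of the pipeline could consume the
# drift_status.txt signal written above. The helper name and the exit-code
# convention are assumptions for illustration, not part of the pipeline's
# existing interface.
def check_drift_signal(status_path="drift_status.txt"):
    """Read the drift signal file and exit non-zero if retraining is required."""
    if not os.path.exists(status_path):
        print(f"No status file found at {status_path}; skipping drift check.")
        return
    with open(status_path) as f:
        status = f.read().strip()
    if status == "DRIFT_DETECTED":
        print("Drift signal received: exiting with code 1 so the CI/CD job can trigger retraining.")
        sys.exit(1) # A non-zero exit code is one common way to fail a pipeline step on purpose.
    print("No drift detected: the deployed model can stay in production.")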