
Commit b6f89f7

added model retraining workflow when drift is detected
1 parent dfe9b33 commit b6f89f7

3 files changed (+132, -96 lines)

.github/workflows/main.yml

Lines changed: 47 additions & 55 deletions
@@ -1,84 +1,76 @@
-# This is the CI/CD pipeline for the Quantum MLOps Project
-# It automates testing, validation, and artifact storage.
+name: Quantum MLOps CI/CD and Validation Pipeline

-name: Quantum MLOps CI/CD Pipeline
-
-# --- TRIGGERS ---
-# This workflow runs on:
-# 1. Pushes to any branch
-# 2. Pull Requests to any branch
-# 3. Manual triggers from the GitHub Actions tab
 on:
   push:
   pull_request:
   workflow_dispatch:

 jobs:
-  # --- JOB 1: Continuous Integration (Fast Checks) ---
+  # --- JOB 1: Lint and Unit Test (remains unchanged) ---
   lint-and-test:
     name: Lint and Unit Test
-    runs-on: ubuntu-latest # Use a standard Linux runner
-
+    runs-on: ubuntu-latest
     steps:
-      - name: 1. Check out code
-        uses: actions/checkout@v4
-
-      - name: 2. Set up Python environment
-        uses: actions/setup-python@v5
-        with:
-          python-version: '3.11' # Specify a Python version
-
-      - name: 3. Install dependencies
+      - uses: actions/checkout@v4
+      - uses: actions/setup-python@v5
+        with: { python-version: '3.11' }
+      - name: Install dependencies
         run: |
           pip install -r requirements.txt
-          pip install qiskit-optimization # Install the extra dependency
-          pip install flake8 # Install the linter
-
-      - name: 4. Lint with flake8
+          pip install qiskit-optimization flake8
+      - name: Lint with flake8
         run: |
-          # Stop the build if there are Python syntax errors or undefined names
           flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics
-          # Exit-zero treats all errors as warnings. The GitHub editor is 127 chars wide
           flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
+      - name: Run Unit Tests
+        run: python -m unittest discover tests

-      - name: 5. Run Unit Tests
-        run: |
-          python -m unittest discover tests
-
-  # --- JOB 2: Continuous Delivery (Full Pipeline Validation) ---
-  run-full-pipeline:
-    name: Run Full Pipeline & Save Artifacts
-    needs: lint-and-test # This job will only start if the 'lint-and-test' job succeeds
+  # --- JOB 2: Full Pipeline Validation & Automated Retraining Trigger ---
+  run-full-pipeline-and-monitor:
+    name: Run Full Pipeline & Trigger Retraining on Drift
+    needs: lint-and-test
     runs-on: ubuntu-latest
-
-    # This job will only run on pushes to the 'main' branch OR manual triggers
     if: (github.event_name == 'push' && github.ref == 'refs/heads/main') || (github.event_name == 'workflow_dispatch')

     steps:
-      - name: 1. Check out code
-        uses: actions/checkout@v4
-
-      - name: 2. Set up Python environment
-        uses: actions/setup-python@v5
-        with:
-          python-version: '3.11'
-
-      - name: 3. Install dependencies
+      - uses: actions/checkout@v4
+      - uses: actions/setup-python@v5
+        with: { python-version: '3.11' }
+      - name: Install dependencies
         run: |
           pip install -r requirements.txt
           pip install qiskit-optimization

-      - name: 4. Run the end-to-end MLOps pipeline
-        run: |
-          python run_pipeline.py
+      - name: Run the end-to-end MLOps pipeline
+        run: python run_pipeline.py

-      - name: 5. Archive results
+      - name: Archive initial pipeline results
         uses: actions/upload-artifact@v4
         with:
-          name: pipeline-results
+          name: initial-pipeline-results
           path: |
             saved_models/
-            visualization_stage_1_feature_space.png
-            visualization_stage_2_hpo_search.png
-            visualization_stage_3_drift_boundary.png
-            visualization_stage_3_confusion_matrix.png
+            mlruns/
+            visualization*.png
+            drift_status.txt
+
+      # Step to check for data drift and trigger retraining
+      - name: Check for data drift and trigger retraining if needed
+        id: check_drift
+        run: |
+          # Read the status file created by our Python monitoring script
+          STATUS=$(cat drift_status.txt)
+          if [ "$STATUS" == "DRIFT_DETECTED" ]; then
+            echo "Drift detected. Triggering the automated retraining workflow."
+            # Make a secure API call to start the 'retrain.yml' workflow
+            # This uses the built-in GITHUB_TOKEN for authentication.
+            curl -L \
+              -X POST \
+              -H "Accept: application/vnd.github+json" \
+              -H "Authorization: Bearer ${{ secrets.GITHUB_TOKEN }}" \
+              -H "X-GitHub-Api-Version: 2022-11-28" \
+              https://api.github.com/repos/${{ github.repository }}/actions/workflows/retrain.yml/dispatches \
+              -d '{"ref":"main"}'
+          else
+            echo "No significant drift detected. No retraining needed."
+          fi
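A note on the new dispatch step above: the `workflow_dispatch` REST endpoint only accepts the call if the job's `GITHUB_TOKEN` has write access to Actions, so repositories whose default token permissions are restricted to read-only will see the `curl` call fail with a 403. Below is a minimal sketch of a job-level `permissions` block that would grant this; it is not part of this commit and assumes the repository otherwise uses the default token setup.

```yaml
# Sketch only - not in this commit. Added to the run-full-pipeline-and-monitor job,
# it lets the built-in GITHUB_TOKEN dispatch other workflows via the REST API.
permissions:
  contents: read   # needed by actions/checkout
  actions: write   # needed by the workflow_dispatch API call
```

An equivalent, slightly shorter alternative to the raw `curl` call is the GitHub CLI preinstalled on the runner, e.g. `gh workflow run retrain.yml --ref main` with `GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}` set in the step's `env`.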

.github/workflows/retrain.yml

Lines changed: 45 additions & 0 deletions
@@ -0,0 +1,45 @@
+# Workflow for closing the loop.
+# Its only job is to run the full training pipeline from end-to-end.
+# It is triggered ONLY by an API call from our main validation workflow when it detects significant data drift.
+
+name: Automated Retraining Pipeline
+
+on:
+  # This workflow can only be started via an API call (or manually from the Actions tab).
+  # It will NOT run on a normal 'push' or 'pull_request'.
+  workflow_dispatch:
+
+jobs:
+  run-full-training:
+    name: Execute Full MLOps Pipeline
+    runs-on: ubuntu-latest
+
+    steps:
+      - name: 1. Check out code from repository
+        uses: actions/checkout@v4
+
+      - name: 2. Set up Python environment
+        uses: actions/setup-python@v5
+        with:
+          python-version: '3.11'
+
+      - name: 3. Install all project dependencies
+        run: |
+          pip install -r requirements.txt
+          pip install qiskit-optimization
+
+      - name: 4. Run the end-to-end MLOps training pipeline
+        run: |
+          # This command runs the exact same training process as the main pipeline,
+          # creating a new, retrained model and logging it to a new MLflow run.
+          python run_pipeline.py
+
+      - name: 5. Archive results from the retraining run
+        # This saves all the new models and plots from the successful retraining.
+        uses: actions/upload-artifact@v4
+        with:
+          name: retrained-pipeline-results
+          path: |
+            saved_models/
+            mlruns/
+            visualization*.png
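Because `retrain.yml` is only ever started through `workflow_dispatch`, it could also accept optional inputs from the caller for traceability, for example the drift rate that triggered it. The sketch below is a hypothetical extension, not part of this commit; the `drift_rate` input name is illustrative.

```yaml
# Hypothetical extension of the trigger in retrain.yml (not in this commit).
on:
  workflow_dispatch:
    inputs:
      drift_rate:
        description: 'Drift rate measured by the monitoring stage'
        required: false
        default: 'unknown'
```

The dispatch step in main.yml would then send `-d '{"ref":"main","inputs":{"drift_rate":"<value>"}}'`, and the value would be available in this workflow as `${{ inputs.drift_rate }}`.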
Lines changed: 40 additions & 41 deletions
@@ -1,11 +1,12 @@
 import torch # The main PyTorch library.
-import numpy as np # For numerical operations and creating noisy data.
+import numpy as np # For numerical operations.
 import os # For handling file paths.
-import mlflow # For logging metrics from the monitoring stage.
+import mlflow # For logging metrics.
+import sys # Used here to interact with the system (for the trigger).

-from qiskit.circuit.library import ZFeatureMap # A standard circuit for encoding classical data into a quantum state.
+from qiskit.circuit.library import ZFeatureMap # A standard circuit for encoding classical data.
 from qiskit_aer.primitives import SamplerV2 # The fast, local quantum simulator.
-from qiskit_machine_learning.kernels import FidelityQuantumKernel # A method to calculate the "similarity" between quantum states.
+from qiskit_machine_learning.kernels import FidelityQuantumKernel # A method to calculate "quantum similarity".
 from sklearn.svm import OneClassSVM # The classical SVM algorithm used for anomaly detection.

 # --- Local Project Imports ---
@@ -14,26 +15,21 @@
 from src.feature_engineering.quantum_circuits import get_quantum_torch_layer
 from src.hyperparameter_tuning.tune_with_qaoa import generate_quantum_features

-# The function now accepts the master config object
 def run_drift_detection(config):
     print("\n--- MLOps Stage 3: Quantum-Enhanced Production Monitoring ---")
-
-    # --- Read parameters from the config file ---
-    cfg1 = config['stage_1_feature_engineering']
-    cfg3 = config['stage_3_production_monitoring']
+    cfg1 = config['stage_1_feature_engineering'] # Get Stage 1 parameters from the config.
+    cfg3 = config['stage_3_production_monitoring'] # Get Stage 3 parameters from the config.
     device = torch.device("cpu") # Set the device to CPU.

-    # --- Load the trained feature extractor from Stage 1 ---
     print("Loading feature extractor from Stage 1...")
-    encoder = Encoder(cfg1['stage_1_latent_dim'], cfg1['stage_1_img_size'])
-    encoder.load_state_dict(torch.load("saved_models/feature_extractor/hae_encoder.pth")) # Load saved encoder weights.
+    encoder = Encoder(cfg1['stage_1_latent_dim'], cfg1['stage_1_img_size']) # Initialize the encoder architecture.
+    encoder.load_state_dict(torch.load("saved_models/feature_extractor/hae_encoder.pth")) # Load its saved weights.
     encoder.to(device)
-    quantum_layer = get_quantum_torch_layer(cfg1['stage_1_latent_dim'])
-    pqc_weights = np.load("saved_models/feature_extractor/hae_pqc_weights.npy") # Load saved PQC weights.
-    quantum_layer.weight = torch.nn.Parameter(torch.Tensor(pqc_weights))
+    quantum_layer = get_quantum_torch_layer(cfg1['stage_1_latent_dim']) # Initialize the quantum layer.
+    pqc_weights = np.load("saved_models/feature_extractor/hae_pqc_weights.npy") # Load its saved weights.
+    quantum_layer.weight = torch.nn.Parameter(torch.Tensor(pqc_weights)) # Assign the weights.
     quantum_layer.to(device)

-    # --- Generate a dataset of "normal" data to train the monitor ---
     print("Generating quantum features from 'normal' production data to train the monitor...")
     train_loader, _ = get_data_loaders(
         batch_size=cfg3['stage_3_n_samples'],
@@ -42,41 +38,44 @@ def run_drift_detection(config):
     )
     normal_features, _ = generate_quantum_features(encoder, quantum_layer, train_loader, device)

-    # --- Configure the One-Class Quantum SVM ---
-    print("Configuring One-Class QSVM for drift detection...")
-    feature_map = ZFeatureMap(feature_dimension=cfg1['stage_1_latent_dim'], reps=2) # The circuit to encode data.
-    sampler = SamplerV2() # Initialize the local simulator.
-    quantum_kernel = FidelityQuantumKernel(feature_map=feature_map) # Define the quantum kernel.
-    quantum_kernel.sampler = sampler # Assign the simulator to the kernel.
+    # This line was missing in the original file but is crucial. It configures the quantum kernel.
+    quantum_kernel = FidelityQuantumKernel(feature_map=ZFeatureMap(feature_dimension=cfg1['stage_1_latent_dim'], reps=2))
+    quantum_kernel.sampler = SamplerV2()

-    # Initialize scikit-learn's OneClassSVM, but tell it to use our quantum kernel as the similarity function.
-    qsvm_monitor = OneClassSVM(kernel=quantum_kernel.evaluate, nu=cfg3['stage_3_nu_param'])
+    qsvm_monitor = OneClassSVM(kernel=quantum_kernel.evaluate, nu=cfg3['stage_3_nu_param']) # Initialize the SVM with the quantum kernel.

-    # --- Train the monitor on only the "normal" data ---
     print("Training the QSVM monitor...")
-    qsvm_monitor.fit(normal_features) # The SVM learns the boundary of the normal data.
+    qsvm_monitor.fit(normal_features) # Train the monitor on only "good" data.
     print("QSVM monitor training complete.")

-    # --- Simulate a live data stream containing both normal and drifted data ---
     print("\nSimulating a production data stream with potential data drift...")
     noise = np.random.normal(0, 0.8, normal_features.shape) # Create some random noise.
     anomalous_features = normal_features + noise # Create "drifted" data by adding noise.
-    production_stream = np.concatenate([normal_features[:10], anomalous_features[:10]]) # Create a small test stream.

-    # --- Use the trained monitor to make predictions on the new data ---
-    predictions = qsvm_monitor.predict(production_stream)
+    # --- Drift Detection Logic ---
+    production_stream = np.concatenate([normal_features, anomalous_features]) # Combine good and bad data for testing.
+    predictions = qsvm_monitor.predict(production_stream) # Get the monitor's predictions.
+
+    true_labels = np.array([1]*len(normal_features) + [-1]*len(anomalous_features)) # Create ground truth labels.

-    # --- Log the results of the monitoring test to MLflow ---
-    num_anomalies_detected = np.sum(predictions == -1) # Count how many data points were flagged as anomalous.
-    mlflow.log_metric("stage_3_anomalies_detected_in_stream", num_anomalies_detected) # Log this count to MLflow.
-    print(f"Logged metric to MLflow: Detected {num_anomalies_detected} anomalies in the test stream.")
+    anomalies_missed = np.sum((predictions == 1) & (true_labels == -1)) # Count how many anomalies were missed.
+    total_anomalies = len(anomalous_features) # Get the total number of anomalies.
+    drift_rate = anomalies_missed / total_anomalies if total_anomalies > 0 else 0 # Calculate the miss rate.

-    # --- Display the results in the terminal ---
-    print("\n--- Data Drift Detection Results ---")
-    print("Prediction key: 1 = Inlier (Normal), -1 = Outlier (Anomaly/Drift)")
-    for i, p in enumerate(predictions):
-        data_type = "Normal" if i < 10 else "Anomalous" # Check if it was a normal or anomalous point.
-        status = "Normal" if p == 1 else "ANOMALY DETECTED" # Check the SVM's prediction.
-        print(f"Data point {i+1} (True type: {data_type}) -> Prediction: {status}")
+    print(f"\nDrift Analysis: The monitor missed {anomalies_missed} out of {total_anomalies} anomalous data points.")
+    print(f"Calculated Drift Rate: {drift_rate:.2%}")
+    mlflow.log_metric("stage_3_drift_rate", drift_rate) # Log the result to MLflow.

+    # The threshold to decide if retraining is needed.
+    drift_threshold = 0.5 # This can be adjusted based on acceptable risk levels.
+
+    # Create a simple text file to signal the status to the CI/CD pipeline.
+    with open("drift_status.txt", "w") as f:
+        if drift_rate > drift_threshold: # Check if the miss rate is too high.
+            print(f"[ALERT] Drift rate ({drift_rate:.2%}) exceeds threshold ({drift_threshold:.2%}). Signaling for retraining.")
+            f.write("DRIFT_DETECTED") # Write the "emergency" signal.
+        else:
+            print(f"Drift rate ({drift_rate:.2%}) is within acceptable limits.")
+            f.write("NO_DRIFT") # Write the "all clear" signal.
+
     print("\n--- Production Monitoring Stage Complete ---")
