AirGapLite: RL Pipeline for PII Sharing Decisions

A comprehensive reinforcement learning pipeline for intelligent PII (Personally Identifiable Information) sharing decisions using policy gradient methods. The system learns domain-specific patterns to minimize PII exposure while maintaining utility.

Project Poster

View Project Poster PDF

Overview

This project implements a complete RL-based system for PII minimization that:

  • Learns domain-specific patterns (restaurant vs. bank) through reinforcement learning
  • Integrates multiple components: Context classifier, RL policy, and PII extraction
  • Compares against baselines: LLM-based minimizers serve as the reference for performance evaluation
  • Supports multiple RL algorithms: GRPO, GroupedPPO, and VanillaRL

Project Structure

final_project/
├── common/                    # Shared code (plug-and-play)
│   ├── config.py             # Configuration constants (PII types, groups, scenarios)
│   └── mdp.py                # MDP helpers (state building, rewards, actions)
│
├── algorithms/               # RL algorithms
│   ├── grpo/                 # GRPO: Group Relative Policy Optimization
│   │   ├── grpo_policy.py   # Policy network
│   │   └── grpo_train.py    # Training with KL regularization
│   ├── groupedppo/          # GroupedPPO: PPO with clipping
│   │   ├── grpo_policy.py   # Policy network
│   │   └── grpo_train.py    # Training with PPO clipping
│   └── vanillarl/           # VanillaRL: REINFORCE
│       ├── policy.py        # Policy network
│       └── train.py         # REINFORCE training
│
├── pipeline/                 # Unified pipelines
│   ├── algorithm_registry.py # Algorithm registry (modular algorithm addition)
│   ├── train.py             # Training pipeline (with convergence detection)
│   ├── test.py              # Testing pipeline (with directive system)
│   ├── integration_pipeline.py  # End-to-end integration (classifier → RL → extraction)
│   └── compare_baseline_vs_rl.py # Baseline vs RL comparison
│
├── MLP/                      # Context-aware domain classifier
│   ├── context_agent_classifier.py  # MLP classifier architecture
│   ├── train_context_agent.py      # Training script
│   ├── inference.py                 # Inference utilities
│   └── context_agent_mlp.pth       # Trained model
│
├── pii_extraction/          # PII extraction module
│   ├── pii_extractor.py     # Main extraction interface
│   ├── spacy_regex.py       # spaCy + regex patterns
│   ├── compute_pii_metrics.py  # Evaluation metrics
│   └── analyze_errors.py    # Error analysis
│
├── baseline/                 # Baseline LLM minimizers
│   ├── baseline_minimizer.py    # GPU-based baseline (CUDA)
│   ├── mlx_baseline_minimizer.py # MLX baseline (Apple Silicon)
│   ├── plot_baseline_results.py  # Visualization
│   └── output/              # Baseline results and plots
│
├── scripts/                  # Analysis and utility scripts
│   ├── compare_all.py       # Compare all RL algorithms
│   ├── analyze_dataset_probabilities.py  # Dataset analysis
│   ├── analyze_with_directives.py        # Directive analysis
│   ├── get_regex_by_directive.py         # Regex extraction
│   ├── endpoint.py          # API endpoint for regex
│   └── rebalance_bank_dataset.py  # Dataset rebalancing
│
├── Regex/                    # Regex patterns
│   └── PII_regex.py         # PII regex definitions
│
├── models/                   # Trained RL models
│   ├── grpo_model.pt
│   ├── groupedppo_model.pt
│   └── vanillarl_model.pt
│
├── results/                  # Evaluation results
│   ├── training_curves.png
│   ├── utility_privacy.png
│   ├── performance.png
│   └── comparison_table.csv
│
└── datasets/                 # Dataset files
    ├── 690-Project-Dataset-final.csv  # Main dataset (recommended)
    ├── 690-Project-Dataset-balanced.csv
    └── 690-Project-Dataset-bank-balanced.csv

Quick Start

1. Setup

# Create environment
conda create -n overthink python=3.10
conda activate overthink

# Install dependencies
cd final_project
pip install -r requirements.txt

# Install spaCy model
python -m spacy download en_core_web_sm

# (Optional) For Apple Silicon baseline
pip install mlx mlx-lm

2. Train RL Algorithm

cd final_project
python pipeline/train.py \
    --algorithm grpo \
    --dataset 690-Project-Dataset-final.csv \
    --num_iters 300 \
    --batch_size 64 \
    --output_dir models

3. Test Trained Model

python pipeline/test.py \
    --algorithm grpo \
    --model models/grpo_model.pt \
    --directive balanced \
    --get-regex

4. Compare All Algorithms

python scripts/compare_all.py \
    --algorithms grpo groupedppo vanillarl \
    --dataset 690-Project-Dataset-final.csv \
    --num_iters 300 \
    --batch_size 64 \
    --output_dir results

Available Algorithms

All algorithms use per-PII binary actions (0=don't share, 1=share) for each of the 11 PII types:

  • grpo: Group Relative Policy Optimization

    • Per-PII binary actions with group-based rewards
    • PPO-style updates with KL regularization
    • Best balance of performance and stability
  • groupedppo: Grouped PPO

    • Per-PII binary actions with group-based rewards
    • PPO updates with a clipped surrogate objective
    • More stable than vanilla REINFORCE
  • vanillarl: Vanilla REINFORCE

    • Per-PII binary actions with group-based rewards
    • Simple REINFORCE policy gradient
    • Baseline for comparison
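To make the difference between these update rules concrete, here is a minimal pure-Python sketch of the standard textbook objectives for a single per-PII share decision. It covers the REINFORCE surrogate, the PPO clipped surrogate, and a Bernoulli KL term of the kind used for KL regularization. The function names and the default clip_eps=0.2 are assumptions for illustration. The actual update code under algorithms/ may combine these terms differently.

```python
import math


def bernoulli_log_prob(p_share: float, action: int) -> float:
    """Log-probability of a binary share decision under Bernoulli(p_share)."""
    return math.log(p_share) if action == 1 else math.log(1.0 - p_share)


def reinforce_objective(p_share: float, action: int, advantage: float) -> float:
    """REINFORCE surrogate: log pi(a|s) * A (maximized by gradient ascent)."""
    return bernoulli_log_prob(p_share, action) * advantage


def ppo_clipped_objective(p_new: float, p_old: float, action: int,
                          advantage: float, clip_eps: float = 0.2) -> float:
    """PPO clipped surrogate: min(r * A, clip(r, 1 - eps, 1 + eps) * A)."""
    # Probability ratio between the current and the behavior (old) policy.
    ratio = math.exp(bernoulli_log_prob(p_new, action) - bernoulli_log_prob(p_old, action))
    clipped = max(1.0 - clip_eps, min(ratio, 1.0 + clip_eps))
    return min(ratio * advantage, clipped * advantage)


def kl_bernoulli(p: float, q: float) -> float:
    """KL(Bernoulli(p) || Bernoulli(q)) for 0 < p, q < 1, usable as a penalty term."""
    return p * math.log(p / q) + (1 - p) * math.log((1 - p) / (1 - q))
```

Each of the 11 actions is an independent Bernoulli, so a per-sample objective would sum these terms over the PII types.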

MDP Formulation

State Space

State: [present_mask (11), scenario_one_hot (2)] = 13 dimensions

  • present_mask: Binary vector indicating which PII types are present
    • PII types: NAME, PHONE, EMAIL, DATE/DOB, company, location, IP, SSN, CREDIT_CARD, age, sex
  • scenario_one_hot: Domain encoding
    • Restaurant: [1, 0]
    • Bank: [0, 1]
  • Important: The model NEVER sees allowed_mask in the state; it must learn domain-specific patterns from rewards alone
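Here is a minimal sketch of the state construction described above. It assumes the PII ordering listed here and plain Python lists. common/mdp.py presumably builds tensors, and build_state is a hypothetical name.

```python
PII_TYPES = ["NAME", "PHONE", "EMAIL", "DATE/DOB", "company", "location",
             "IP", "SSN", "CREDIT_CARD", "age", "sex"]
SCENARIOS = ["restaurant", "bank"]


def build_state(present_pii: set[str], scenario: str) -> list[float]:
    """Concatenate the 11-dim present_mask with the 2-dim scenario one-hot."""
    if scenario not in SCENARIOS:
        raise ValueError(f"unknown scenario: {scenario!r}")
    present_mask = [1.0 if pii in present_pii else 0.0 for pii in PII_TYPES]
    scenario_one_hot = [1.0 if scenario == s else 0.0 for s in SCENARIOS]
    # allowed_mask is deliberately NOT part of the state.
    return present_mask + scenario_one_hot
```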

Action Space

Action: Binary vector of length 11

  • Each element: 0 (don't share) or 1 (share) for each PII type
  • Example: [0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0] means share PHONE and EMAIL only
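With a fixed ordering, decoding an action vector back into PII names is short. This sketch assumes the same ordering as the PII type list in the State Space section. decode_action is a hypothetical helper and is not necessarily part of common/mdp.py.

```python
PII_TYPES = ["NAME", "PHONE", "EMAIL", "DATE/DOB", "company", "location",
             "IP", "SSN", "CREDIT_CARD", "age", "sex"]


def decode_action(action: list[int]) -> list[str]:
    """Return the names of the PII types whose action bit is 1 (share)."""
    if len(action) != len(PII_TYPES):
        raise ValueError(f"expected {len(PII_TYPES)} actions, got {len(action)}")
    return [pii for pii, a in zip(PII_TYPES, action) if a == 1]
```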

Reward Function

Reward: Group-based reward computation

  • Formula: R = α·utility + β·privacy - complexity_penalty
  • Computed per PII group (identity, contact, financial, network, org, demographic)
  • Domain weights:
    • Restaurant: α=0.6, β=0.4 (more privacy-leaning)
    • Bank: α=0.7, β=0.3 (more utility-leaning)
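The sketch below shows one way the group-based reward could be computed. Only the formula and the domain weights come from this README. Everything else is an assumption: the group membership (the real assignment lives in common/config.py), utility and privacy defined like the evaluation metrics in Baseline Comparison, a score of 1.0 for an empty allowed or disallowed set, averaging over groups, and a per-item complexity penalty that defaults to 0. The helper names are hypothetical.

```python
# Illustrative group assignment (an assumption, not the project's config).
PII_GROUPS = {
    "identity": {"NAME", "DATE/DOB", "SSN"},
    "contact": {"PHONE", "EMAIL"},
    "financial": {"CREDIT_CARD"},
    "network": {"IP"},
    "org": {"company"},
    "demographic": {"location", "age", "sex"},
}
# (alpha, beta) per domain, as listed above.
DOMAIN_WEIGHTS = {"restaurant": (0.6, 0.4), "bank": (0.7, 0.3)}


def group_reward(shared, present, allowed, alpha, beta, penalty_per_share=0.0):
    """R = alpha*utility + beta*privacy - complexity_penalty for one group."""
    allowed_present = present & allowed
    disallowed_present = present - allowed
    # Utility: fraction of allowed PII that was shared (1.0 if none allowed).
    utility = len(shared & allowed_present) / len(allowed_present) if allowed_present else 1.0
    # Privacy: fraction of disallowed PII that was withheld (1.0 if none disallowed).
    privacy = len(disallowed_present - shared) / len(disallowed_present) if disallowed_present else 1.0
    # Hypothetical complexity penalty: a fixed cost per shared PII item.
    complexity_penalty = penalty_per_share * len(shared & present)
    return alpha * utility + beta * privacy - complexity_penalty


def episode_reward(shared, present, allowed, scenario, penalty_per_share=0.0):
    """Average the group rewards over groups with at least one present PII type."""
    alpha, beta = DOMAIN_WEIGHTS[scenario]
    rewards = [
        group_reward(shared & members, present & members, allowed,
                     alpha, beta, penalty_per_share)
        for members in PII_GROUPS.values()
        if present & members
    ]
    return sum(rewards) / len(rewards) if rewards else 0.0
```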

Training Options

Convergence Detection (Recommended)

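# --convergence_threshold: minimum improvement that counts as progress
# --patience: evaluations without improvement before stopping
# --max_iters: safety limit on total iterations
python pipeline/train.py \
    --algorithm grpo \
    --convergence_threshold 0.001 \
    --patience 20 \
    --max_iters 1000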

Training stops when either condition holds:

  • The evaluation reward has not improved by more than convergence_threshold for patience consecutive evaluations, OR
  • max_iters iterations have been run
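The stopping rule can be sketched as a small patience counter. This is a minimal illustration that makes two assumptions: the reward is evaluated once per iteration, and "improvement" means beating the best reward so far by more than convergence_threshold. ConvergenceDetector and iterations_until_stop are hypothetical names, not pipeline/train.py's actual code.

```python
class ConvergenceDetector:
    """Patience-based early stopping on an evaluation reward (hypothetical helper)."""

    def __init__(self, convergence_threshold: float = 0.001, patience: int = 20):
        self.convergence_threshold = convergence_threshold
        self.patience = patience
        self.best = float("-inf")
        self.evals_without_improvement = 0

    def update(self, eval_reward: float) -> bool:
        """Record one evaluation; return True once patience is exhausted."""
        if eval_reward > self.best + self.convergence_threshold:
            self.best = eval_reward
            self.evals_without_improvement = 0
        else:
            self.evals_without_improvement += 1
        return self.evals_without_improvement >= self.patience


def iterations_until_stop(eval_rewards: list[float], max_iters: int = 1000,
                          **detector_kwargs) -> int:
    """Replay a sequence of evaluation rewards and return the stopping iteration."""
    detector = ConvergenceDetector(**detector_kwargs)
    for iteration, reward in enumerate(eval_rewards[:max_iters], start=1):
        # In a real loop, a rollout and a policy update would happen before this.
        if detector.update(reward):
            return iteration
    return min(len(eval_rewards), max_iters)
```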

Fixed Iterations

python pipeline/train.py \
    --algorithm grpo \
    --dataset 690-Project-Dataset-final.csv \
    --num_iters 300 \
    --batch_size 64 \
    --output_dir models

Testing & Evaluation

Directive System

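Control the utility-privacy tradeoff at test time with --directive. Each PII type is shared when its learned share probability meets the directive's threshold: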

  • strictly: High threshold (≥0.7), lower utility, higher privacy
  • balanced: Default threshold (0.5), balanced tradeoff
  • accurately: Low threshold (≤0.3), higher utility, lower privacy
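Here is a minimal sketch of how a directive could turn learned per-PII share probabilities into decisions. It assumes a PII type is shared when its probability is at or above the threshold, and it uses 0.7 / 0.5 / 0.3 as the exact cutoffs. apply_directive is a hypothetical name.

```python
# Assumed exact cutoffs, taken from the boundary values listed above.
DIRECTIVE_THRESHOLDS = {"strictly": 0.7, "balanced": 0.5, "accurately": 0.3}


def apply_directive(share_probs: dict[str, float], directive: str = "balanced") -> list[str]:
    """Return the PII types whose share probability meets the directive's threshold."""
    threshold = DIRECTIVE_THRESHOLDS[directive]
    return [pii for pii, p in share_probs.items() if p >= threshold]
```

A PII type with a learned probability above 0.99, such as EMAIL in the recommended dataset, is shared under every directive.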

Testing Commands

# Basic evaluation
python pipeline/test.py --algorithm grpo --model models/grpo_model.pt

# With directive
python pipeline/test.py --algorithm grpo --model models/grpo_model.pt --directive strictly

# Extract learned regex patterns
python pipeline/test.py --algorithm grpo --model models/grpo_model.pt --get-regex

Integration Pipeline

End-to-end pipeline: Context Classifier → RL Policy → PII Extraction

from pipeline.integration_pipeline import minimize_data

result = minimize_data(
    third_party_prompt="I need to book a table for tonight",
    user_data="My name is John Smith, email is john@example.com, phone is 555-1234, SSN is 123-45-6789"
)

print(result['minimized_data'])  # Only EMAIL and PHONE (restaurant domain)
print(result['domain'])          # 'restaurant'
print(result['shared_pii'])       # ['EMAIL', 'PHONE']

Baseline Comparison

Compare the RL-based approach with the LLM baseline minimizers:

python pipeline/compare_baseline_vs_rl.py \
    --num-samples 10 \
    --domain restaurant

Metrics compared:

  • Utility: % of allowed PII correctly shared
  • Privacy: % of disallowed PII correctly NOT shared
  • Quickness: Inference time (seconds)
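The three metrics can be computed per sample roughly as follows. This sketch reports fractions in [0, 1] rather than percentages. It scores a sample as 1.0 when it has no allowed PII (for utility) or no disallowed PII (for privacy). The helper names are hypothetical, and compare_baseline_vs_rl.py may aggregate results differently.

```python
import time


def utility_privacy(shared: set[str], present: set[str], allowed: set[str]) -> tuple[float, float]:
    """Utility: share of allowed PII that was shared; privacy: share of disallowed PII withheld."""
    allowed_present = present & allowed
    disallowed_present = present - allowed
    utility = len(shared & allowed_present) / len(allowed_present) if allowed_present else 1.0
    privacy = len(disallowed_present - shared) / len(disallowed_present) if disallowed_present else 1.0
    return utility, privacy


def timed(fn, *args, **kwargs):
    """Call fn and return (result, elapsed seconds) for the quickness metric."""
    start = time.perf_counter()
    result = fn(*args, **kwargs)
    return result, time.perf_counter() - start
```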

Outputs

Training Outputs

  • models/{algorithm}_model.pt: Trained model weights
  • models/{algorithm}_history.json: Training history (iterations, rewards)

Testing Outputs

  • evaluation_results.json: Detailed metrics (utility, privacy, domain-specific)

Comparison Outputs

  • results/training_curves.png: Learning progress across algorithms
  • results/utility_privacy.png: Utility-privacy tradeoff visualization
  • results/performance.png: Bar charts comparing algorithms
  • results/comparison_table.csv: Numerical comparison

Dataset

Recommended Dataset: 690-Project-Dataset-final.csv

  • Size: 15,805 rows
  • Purpose: Complete dataset with proper PII frequencies for learning domain patterns

Key Features:

  • EMAIL: 98.7% frequency → learned prob >0.99 (shared by all directives)
  • PHONE: 60.8% frequency → learned prob >0.99 (shared by all directives)
  • DATE/DOB: 56.7% frequency → learned prob >0.99 (shared by all directives)
  • SSN: 90.3% frequency → learned prob >0.98 (shared by all directives)
  • CREDIT_CARD: 90.3% frequency → learned prob >0.98 (shared by all directives)
  • 100% coverage: All rows with SSN/CREDIT_CARD in ground_truth also have them in allowed_bank

Expected Results (Bank Domain):

  • STRICTLY (≥0.7): Utility = 1.0, Privacy = 1.0 ✓ Perfect match
  • BALANCED (≥0.5): Utility = 1.0, Privacy = 1.0 ✓ Perfect match
  • ACCURATELY (≤0.3): Utility = 1.0, Privacy = 1.0 ✓ Perfect match

Expected Patterns:

  • Restaurant: EMAIL, PHONE (all directives)
  • Bank: EMAIL, PHONE, DATE/DOB, SSN, CREDIT_CARD (all directives)

Adding a New Algorithm

  1. Create algorithms/my_algorithm/ with:

    • policy.py: Policy network (a subclass of torch.nn.Module)
    • train.py: Training functions (rollout, update, evaluate)
    • __init__.py: Exports
  2. Register in pipeline/algorithm_registry.py:

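    # Assumes these names are exported by algorithms/my_algorithm/__init__.py
    from algorithms.my_algorithm import (
        MyPolicy, load_dataset, rollout_batch,
        policy_gradient_update, evaluate_average_reward,
    )
    from pipeline.algorithm_registry import AlgorithmRegistry

    AlgorithmRegistry.register('my_algorithm', {
        'policy': MyPolicy,
        'config': {
            'load_dataset': load_dataset,
            'rollout': rollout_batch,
            'update': policy_gradient_update,
            'evaluate': evaluate_average_reward,
            # ... remaining config entries
        }
    })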
  3. Use: python pipeline/train.py --algorithm my_algorithm
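If you have not seen the registry pattern before, here is a minimal sketch of what such a registry typically looks like. It is not the contents of pipeline/algorithm_registry.py. Only the register(name, spec) call shape comes from step 2. The validation rules and the get method are assumptions.

```python
from typing import Any


class AlgorithmRegistry:
    """Maps an algorithm name to its policy class and training config."""

    _algorithms: dict[str, dict[str, Any]] = {}

    @classmethod
    def register(cls, name: str, spec: dict[str, Any]) -> None:
        """Store a spec with 'policy' and 'config' entries under a unique name."""
        if name in cls._algorithms:
            raise ValueError(f"algorithm {name!r} is already registered")
        if "policy" not in spec or "config" not in spec:
            raise ValueError("spec needs 'policy' and 'config' entries")
        cls._algorithms[name] = spec

    @classmethod
    def get(cls, name: str) -> dict[str, Any]:
        """Look up a registered spec, listing the available names on failure."""
        try:
            return cls._algorithms[name]
        except KeyError:
            raise KeyError(
                f"unknown algorithm {name!r}; available: {sorted(cls._algorithms)}"
            ) from None
```

With this design, the training pipeline can resolve --algorithm to a spec with a single lookup, so adding an algorithm requires no changes to train.py itself.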

Common Code

All algorithms share:

  • common/config.py: PII types, groups, scenarios, domain weights
  • common/mdp.py: State building, reward computation, action utilities

A change made in these two files applies to every algorithm.

Documentation

  • HOW_TO_RUN.md: Complete guide on running all code, scripts, and workflows
  • ALGORITHM_EXPLANATION.md: Detailed explanation of MDP, algorithms, and training flow
  • FLOW_DIAGRAM.md: System architecture and flow diagrams
  • pii_extraction/README.md: PII extraction module documentation

Key Features

Exploration vs Exploitation

  • During Training: Actions are sampled stochastically from Bernoulli distributions (exploration)
  • During Testing: Actions are deterministic using threshold-based decisions (exploitation)
  • This is the standard approach for policy gradient methods (GRPO/PPO/REINFORCE)
  • No manual exploration decay is needed: as training progresses, the learned probabilities become confident (close to 0 or 1) on their own
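The split between training and testing behavior can be sketched in a few lines. During training, actions are sampled from independent Bernoulli distributions; at test time, the same probabilities are thresholded. This illustration assumes the policy outputs one share probability per PII type. select_actions is a hypothetical name, and the 0.5 default mirrors the balanced directive.

```python
import random
from typing import Optional


def select_actions(share_probs: list[float], training: bool,
                   threshold: float = 0.5,
                   rng: Optional[random.Random] = None) -> list[int]:
    """Sample per-PII actions during training; threshold them during testing."""
    if training:
        # Exploration: each PII type is shared with probability p.
        rng = rng or random.Random()
        return [1 if rng.random() < p else 0 for p in share_probs]
    # Exploitation: deterministic decision from the same probabilities.
    return [1 if p >= threshold else 0 for p in share_probs]
```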

Modular Design

  • Algorithm Registry: Easy to add new RL algorithms
  • Unified Interface: Same training/test commands for all algorithms
  • Component Integration: The context classifier, RL policy, and PII extraction are chained into one end-to-end pipeline

Comprehensive Evaluation

  • Directive System: Control utility-privacy tradeoff
  • Domain-Specific Metrics: Separate evaluation for restaurant and bank domains
  • Baseline Comparison: Compare with LLM-based minimizers
  • Regex Extraction: Extract learned patterns as regex rules

Requirements

See requirements.txt for full dependency list. Key dependencies:

  • torch>=1.9.0: PyTorch for RL training
  • pandas>=1.3.0: Data handling
  • sentence-transformers>=2.2.0: Context classifier embeddings
  • spacy>=3.4.0: PII extraction
  • matplotlib>=3.4.0: Visualization

Optional:

  • mlx>=0.30.0, mlx-lm>=0.28.0: For Apple Silicon baseline
  • bitsandbytes>=0.41.0: For GPU baseline (CUDA)

License

This project is part of the CS 690F Security in AI course at UMass Amherst.