Kao TW et al. Demographic Differences in Tissue Architecture Drive Disparities in Pathology AI Performance.
This repository implements UNVEIL, a framework for identifying, quantifying, and mitigating demographic-associated signals in pathology foundation model representations. By integrating demographic classification, nuclear morphometric analysis, and demographic signal-aware agentic scheduling, UNVEIL addresses performance disparities in computational pathology tasks.
Pathology foundation models encode demographic-linked morphological variations during training, which can contribute to performance disparities across population groups. UNVEIL provides a systematic approach to:
- Detect demographic signals in learned representations through demographic attribute classification
- Quantify the contribution of demographic-associated features to downstream task performance disparities
- Mitigate disparities using demographic signal-aware agentic scheduling that adaptively modulates patch contributions
WSI Patches → Foundation Model Encoding → Attention-Based MIL Aggregation
                                   ↓
             ┌─────────────────────┴──────────────────────┐
             ↓                                            ↓
  Demographic Classifier                        Downstream Task Model
(Age, Race, Sex prediction)                  (Mutation prediction, etc.)
             |                                            ↓
             |                                   Fairness Evaluation
             |                                            ↓
             └────────────────→ Demographic Signal-Aware Agentic Scheduling
                                 (Adaptive patch contribution modulation)
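The aggregation step in the diagram can be illustrated with a minimal pure-Python sketch of attention-based MIL pooling: each patch embedding receives a scalar score, the scores are softmax-normalized, and the slide representation is the attention-weighted sum of the patches. The scoring vector `w` and the function names are hypothetical; the attention module in `network.py` may be parameterized differently (for example, with a learned or gated attention network).

```python
import math
from typing import List, Sequence, Tuple


def softmax(scores: Sequence[float]) -> List[float]:
    """Numerically stable softmax over scalar scores."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention_pool(
    patches: Sequence[Sequence[float]], w: Sequence[float]
) -> Tuple[List[float], List[float]]:
    """Aggregate patch embeddings (N_patches x feature_dim) into one slide embedding.

    `w` is a hypothetical linear scoring vector of length feature_dim.
    Returns (slide_embedding, attention_weights).
    """
    # One scalar score per patch (dot product with the scoring vector).
    scores = [sum(x * wi for x, wi in zip(p, w)) for p in patches]
    attn = softmax(scores)
    dim = len(patches[0])
    # Attention-weighted sum over patches, one feature dimension at a time.
    slide = [sum(a * p[d] for a, p in zip(attn, patches)) for d in range(dim)]
    return slide, attn


# Toy example: three patches with 2-dimensional features.
patches = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
slide, attn = attention_pool(patches, w=[0.5, 0.5])
```

The per-patch weights in `attn` are the kind of quantity that attention mapping inspects when linking predictions to specific tissue regions.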
The framework operates in three stages:
- Demographic Signal Detection: Train classifiers to predict demographic attributes (age, race, sex) from WSI representations, quantifying the extent of demographic information encoded in foundation model features.
- Fairness Assessment: Train downstream task models (mutation prediction as exemplar) and evaluate performance disparities across demographic groups, linking disparities to demographic-predictive signals through attention mapping and nuclear morphometric analysis.
- Disparity Mitigation: Apply demographic signal-aware agentic scheduling that uses multi-factor routing to selectively modulate patches most strongly associated with demographic signals, reducing performance gaps while maintaining overall predictive accuracy.
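One way to quantify a performance disparity across demographic groups is to compute a metric per group and report the spread. The sketch below uses AUROC and a max-minus-min gap; both the metric choice and the function names are assumptions for illustration, not necessarily what the repository's evaluation code reports.

```python
from collections import defaultdict
from typing import Dict, Sequence, Tuple


def auc(labels: Sequence[int], scores: Sequence[float]) -> float:
    """AUROC via pairwise comparison of positives and negatives (ties count 0.5)."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    if not pos or not neg:
        raise ValueError("AUC needs at least one positive and one negative sample")
    wins = sum((p > n) + 0.5 * (p == n) for p in pos for n in neg)
    return wins / (len(pos) * len(neg))


def group_auc_gap(
    labels: Sequence[int], scores: Sequence[float], groups: Sequence[str]
) -> Tuple[Dict[str, float], float]:
    """Return per-group AUROC and the gap between the best and worst group."""
    by_group = defaultdict(lambda: ([], []))
    for y, s, g in zip(labels, scores, groups):
        by_group[g][0].append(y)
        by_group[g][1].append(s)
    per_group = {g: auc(ys, ss) for g, (ys, ss) in by_group.items()}
    gap = max(per_group.values()) - min(per_group.values())
    return per_group, gap


# Toy predictions for two hypothetical groups "A" and "B".
labels = [1, 0, 1, 0, 1, 0, 1, 0]
scores = [0.9, 0.1, 0.8, 0.3, 0.6, 0.7, 0.4, 0.2]
groups = ["A", "A", "A", "A", "B", "B", "B", "B"]
per_group, gap = group_auc_gap(labels, scores, groups)
```

In this toy data the model separates group A perfectly but ranks group B at chance level, so the gap is large even though the pooled metric would look acceptable.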
.
├── demographic_classifier/ # Demographic attribute prediction
│ ├── configs/ # Configuration examples
│ ├── run.py # Main entry point
│ ├── run.sh # Execution wrapper
│ ├── experiment_runner.py # Training orchestration
│ ├── dataset.py # Feature dataset loaders
│ └── model.py # MLP classifier architectures
│
├── mutation_prediction/ # Downstream task pipeline (mutation prediction exemplar)
│ ├── configs/ # Dataset and attribute mapping configurations
│ ├── main_genetic.py # Main training script
│ ├── dataset.py # Dataset loading and preprocessing
│ ├── network.py # Attention-based MIL architectures
│ ├── demographic_agent.py # Base demographic-aware filtering agent
│ ├── unified_demographic_agent.py # Multi-factor routing implementation
│ ├── run_baseline.sh # Standard training without demographic awareness
│ └── run_agentic.sh # Training with demographic signal-aware scheduling
│
├── example_data/ # Mock data demonstrating required formats
│ ├── demographic_classifier/ # Demographic classification data
│ ├── mutation_prediction/ # Exemplar downstream task data
│ └── README.md # Detailed data format specifications
│
├── requirements.txt # Python dependencies for all components
└── LICENSE # GNU AFFERO GENERAL PUBLIC LICENSE v3.0
Both components require:
- WSI Feature Embeddings (.pt files): Pre-extracted patch-level features from foundation models
  - Format: {'features': torch.Tensor} with shape (N_patches, feature_dim)
  - Supported dimensions: CHIEF (768), UNI (1024), GIGAPATH (1536), VIRCHOW2 (2560)
- Metadata Files:
  - Demographic Classifier: JSON files mapping slide IDs to demographic categories
  - Mutation Prediction: TCGA Pan-Cancer Atlas structure with clinical metadata and mutation status
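As a rough illustration of these formats, the following sketch checks a feature tensor's shape against the listed dimensions and parses a label mapping. The slide IDs, category strings, and helper name are hypothetical; see example_data/README.md for the actual specifications.

```python
import json

# Feature dimensions listed above for each supported foundation model.
FEATURE_DIMS = {"CHIEF": 768, "UNI": 1024, "GIGAPATH": 1536, "VIRCHOW2": 2560}


def check_feature_shape(shape, foundation_model):
    """Return True if `shape` is (N_patches, feature_dim) for the given model."""
    return (
        len(shape) == 2
        and shape[0] > 0
        and shape[1] == FEATURE_DIMS[foundation_model]
    )


# Hypothetical label file content: slide ID -> demographic category.
labels_json = '{"slide_001": "female", "slide_002": "male"}'
labels = json.loads(labels_json)

ok = check_feature_shape((1200, 768), "CHIEF")   # 768-dim matches CHIEF
bad = check_feature_shape((1200, 768), "UNI")    # UNI expects 1024-dim
```

With PyTorch installed, the shape for a real file could be obtained as `tuple(torch.load(path)["features"].shape)`, assuming the `{'features': torch.Tensor}` layout described above.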
Mock example data is provided in example_data/ to demonstrate required formats:
- 5 mock WSI feature files with CHIEF embeddings (768-dim)
- Clinical metadata and mutation status files for BRCA
- Demographic label files (age, race, sex)
- Example configuration file for demographic classifier
See example_data/README.md for detailed specifications and instructions for generating mock data.
- Python 3.8+
- CUDA-capable GPU (recommended)
- SLURM cluster environment (for batch jobs)
# Clone repository
git clone <repository_url>
cd HiddenFeature_Bao_new_github
# Create environment
conda create -n unveil python=3.8
conda activate unveil
# Install dependencies
pip install -r requirements.txt
# Create output directories
mkdir -p output/{mutation_models,mutation_models_agentic,demographic_classifier}
mkdir -p logs/{train_TCGA_mutation,demographic_agentic}
Train classifiers to detect demographic signals in foundation model representations:
cd demographic_classifier
./run.sh configs/example_train.json
Key configuration parameters:
- train_targets_file_path_list: Paths to demographic label JSON files
- features_dir_path_list: Directories containing WSI feature .pt files
- model_init_args.input_dim: Feature dimension matching foundation model
- model_init_args.output_dim: Number of demographic categories
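A configuration using these keys might look roughly like the sketch below. Only the key names come from the list above; every path and value is a hypothetical placeholder, placing `save_dir` at the top level is an assumption, and `validate_config` is an illustrative helper rather than part of the repository. Consult configs/example_train.json for the authoritative layout.

```python
import json

# All paths and values are placeholders for illustration only.
config = {
    "train_targets_file_path_list": ["path/to/train_labels.json"],
    "features_dir_path_list": ["path/to/features_dir"],
    "model_init_args": {
        "input_dim": 768,   # must match the foundation model (768 for CHIEF)
        "output_dim": 2,    # number of demographic categories
    },
    "save_dir": "output/demographic_classifier/example_run",  # assumed top-level key
}

REQUIRED_KEYS = (
    "train_targets_file_path_list",
    "features_dir_path_list",
    "model_init_args",
    "save_dir",
)


def validate_config(cfg):
    """Basic sanity checks before launching training (hypothetical helper)."""
    missing = [k for k in REQUIRED_KEYS if k not in cfg]
    if missing:
        raise KeyError(f"missing config keys: {missing}")
    args = cfg["model_init_args"]
    if args["input_dim"] <= 0 or args["output_dim"] < 2:
        raise ValueError("input_dim must be positive and output_dim at least 2")
    return True


is_valid = validate_config(config)
config_text = json.dumps(config, indent=2)  # what would be written to a .json file
```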
Output: Trained classifiers saved in specified save_dir with performance metrics.
Train mutation prediction models (or other pathology tasks) and evaluate performance across demographic groups:
cd mutation_prediction
sbatch run_baseline.sh
This runs standard attention-based MIL training across all 31 cancer types without demographic-aware filtering.
Key parameters (edit in run_baseline.sh):
- FOUNDATION_MODEL: Feature extractor (CHIEF, UNI, GIGAPATH, VIRCHOW2)
- SENSITIVE: Demographic attribute to track (age, race, sex)
- cancer: Cancer type (lowercase abbreviation)
Output: Models and performance metrics saved in ./output/mutation_models/{ATTRIBUTE}/{DATA_SOURCE}/{FOUNDATION_MODEL}/{SLIDE_TYPE}/
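To locate results programmatically, the path template above can be filled in with a small helper. The function name is hypothetical, and the example values (`race`, `TCGA`, `CHIEF`, `FS`) are only illustrative choices for the placeholders.

```python
def baseline_output_dir(
    attribute: str,
    data_source: str,
    foundation_model: str,
    slide_type: str,
    root: str = "./output/mutation_models",
) -> str:
    """Fill in {ATTRIBUTE}/{DATA_SOURCE}/{FOUNDATION_MODEL}/{SLIDE_TYPE} under root."""
    parts = [root.rstrip("/"), attribute, data_source, foundation_model, slide_type]
    return "/".join(parts)


# Example values are assumptions chosen for illustration.
out_dir = baseline_output_dir("race", "TCGA", "CHIEF", "FS")
```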
Mitigate performance disparities using adaptive patch filtering:
cd mutation_prediction
sbatch run_agentic.sh
The demographic classifier from Stage I provides per-patch attention scores indicating demographic signal strength. During mutation model training, the agent dynamically selects between two strategies: signal-leveraging filtering, which filters patches with high demographic signal, and random filtering, which is used when the demographic signal is unreliable. The choice depends on demographic prediction correctness, model reliability, group imbalance, and training progress. All decisions use only training data statistics, which prevents data leakage.
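The routing idea can be sketched as follows. This is a simplified illustration under stated assumptions, not the logic in `unified_demographic_agent.py`: the thresholds, the direction in which each factor pushes the decision, the warm-up rule, the drop fraction, and the interpretation of "filtering" as dropping patches are all hypothetical.

```python
import random
from typing import List, Optional, Sequence


def choose_strategy(
    pred_correct: bool,
    reliability: float,
    imbalance: float,
    progress: float,
    min_reliability: float = 0.7,  # illustrative threshold
    min_imbalance: float = 0.2,    # illustrative threshold
    warmup: float = 0.1,           # illustrative fraction of training
) -> str:
    """Return 'signal' or 'random' from the four routing factors (assumed rule)."""
    if progress < warmup:
        return "random"  # early in training, treat the signal as unreliable
    if pred_correct and reliability >= min_reliability and imbalance >= min_imbalance:
        return "signal"
    return "random"


def filter_patches(
    signal_scores: Sequence[float],
    strategy: str,
    drop_frac: float = 0.2,
    rng: Optional[random.Random] = None,
) -> List[int]:
    """Return the sorted indices of patches that are kept."""
    n = len(signal_scores)
    n_drop = int(n * drop_frac)
    if strategy == "signal":
        # Drop the patches with the strongest demographic signal.
        order = sorted(range(n), key=lambda i: signal_scores[i], reverse=True)
        dropped = set(order[:n_drop])
    else:
        # Drop the same number of patches uniformly at random.
        rng = rng or random.Random(0)
        dropped = set(rng.sample(range(n), n_drop))
    return [i for i in range(n) if i not in dropped]


scores = [0.1, 0.9, 0.3, 0.8, 0.2]  # per-patch demographic signal strength
strategy = choose_strategy(True, reliability=0.85, imbalance=0.4, progress=0.5)
kept = filter_patches(scores, strategy, drop_frac=0.4)
```

Dropping an equal number of patches in the random branch keeps the bag size comparable between the two strategies, so any change in performance is attributable to which patches were removed rather than how many.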
Output: Models saved in ./output/mutation_models_agentic/{ATTRIBUTE}/TCGA/{FOUNDATION_MODEL}/FS/agent_demographic_agentic/ with agent decision logs.
output/
├── mutation_models/ # Baseline models
│ └── {ATTRIBUTE}/{DATA_SOURCE}/{FOUNDATION_MODEL}/{SLIDE_TYPE}/
│ └── {CANCER}_{GENE}/
│ ├── checkpoint_best.pt # Best model weights
│ ├── results.csv # Performance metrics
│ └── inference_results_fold{N}.csv # Per-fold predictions
│
├── mutation_models_agentic/ # Agentic scheduling models
│ └── {ATTRIBUTE}/TCGA/{FOUNDATION_MODEL}/FS/agent_demographic_agentic/
│ └── {CANCER}_{GENE}/
│ ├── checkpoint_best.pt
│ ├── results.csv
│ ├── inference_results_fold{N}.csv
│ └── agent_logs/ # Routing decision logs
│
└── demographic_classifier/
└── {save_dir}/
├── configs.json # Configuration
├── best_model.pt # Trained classifier
└── metrics/ # Performance metrics
This project is licensed under the GNU Affero General Public License Version 3.0. See the LICENSE file for details.