-
Notifications
You must be signed in to change notification settings - Fork 244
[Research] FedUMM: Federated Learning for Unified Multimodal Models #4158
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
rollingsu
wants to merge
2
commits into
NVIDIA:main
Choose a base branch
from
rollingsu:feat/fedumm
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,383 @@ | ||
| # FedUMM: Federated Learning for Unified Multimodal Models | ||
|
|
||
| > **Paper:** [FedUMM: A General Framework for Federated Learning with Unified Multimodal Models](https://arxiv.org/abs/2601.15390) | ||
| > | ||
| > Built on [NVIDIA FLARE](https://github.com/NVIDIA/NVFlare) · LoRA adapter-only aggregation · Dirichlet non-IID data partitioning | ||
|
|
||
| ## What Is This? | ||
|
|
||
| This is the code for **FedUMM** — a federated learning framework that lets multiple | ||
| clients collaboratively fine-tune vision-language models (VLMs) **without sharing | ||
| their private data**. Only lightweight LoRA adapter weights (~1-2 MB) are exchanged | ||
| each round, reducing communication by **99.7%** compared to full model fine-tuning. | ||
|
|
||
| ``` | ||
| Server (FedAvg) | ||
| aggregate LoRA deltas only | ||
| / | \ | ||
| Client 1 Client 2 Client K | ||
| LoRA on LoRA on LoRA on | ||
| VLM VLM VLM | ||
| (data (data (data | ||
| stays stays stays | ||
| local) local) local) | ||
| ``` | ||
|
|
||
| ## Supported Models | ||
|
|
||
| | Backend key | Model | LoRA target | Trainable params | | ||
| |---|---|---|---| | ||
| | `blip_vqa` | Salesforce/blip-vqa-base | text_encoder + text_decoder (q/k/v) | ~1.2 M | | ||
| | `januspro` | deepseek-ai/Janus-Pro-1B | language_model (q_proj/k_proj/v_proj) | ~2.4 M | | ||
|
|
||
| ## Project Structure | ||
|
|
||
| ``` | ||
| fedumm/ | ||
| ├── src/ | ||
| │ ├── __init__.py # Auto-registers backends | ||
| │ ├── model_registry.py # register / get backend | ||
| │ ├── common.py # Seeding, Dirichlet partition, train loop | ||
| │ ├── blip_backend.py # BLIP VQA backend | ||
| │ ├── januspro_backend.py # JanusPro backend | ||
| │ ├── fl_client.py # NVFlare FL client (all models) | ||
| │ └── local_train.py # Centralized baseline (no FL) | ||
| ├── envs/ | ||
| │ ├── env_blip.yml # Conda env for BLIP | ||
| │ └── env_januspro.yml # Conda env for JanusPro | ||
| ├── scripts/ | ||
| │ ├── setup_envs.sh # Create both conda envs | ||
| │ ├── launch_blip.sh # Wrapper: activate env -> run python | ||
| │ ├── launch_januspro.sh # Wrapper: activate env -> run python | ||
| │ └── slurm_run.sh # HPC batch job template | ||
| ├── job.py # FedJob config (FedAvg + ScriptRunner) | ||
| ├── requirements.txt # Minimal deps for single-env usage | ||
| └── README.md | ||
| ``` | ||
|
|
||
| --- | ||
|
|
||
| ## Quick Start (3 steps) | ||
|
|
||
| ### Step 1: Install | ||
|
|
||
| ```bash | ||
| # Option A: pip (simplest) | ||
| pip install nvflare torch transformers peft datasets Pillow numpy | ||
|
|
||
| # Option B: conda (recommended for HPC) | ||
| conda create -n fedumm python=3.10 -y | ||
| conda activate fedumm | ||
| pip install nvflare torch transformers peft datasets Pillow numpy | ||
| ``` | ||
|
|
||
| ### Step 2: Run Centralized Baseline | ||
|
|
||
| Verify everything works before starting federated training: | ||
|
|
||
| ```bash | ||
| python src/local_train.py \ | ||
| --model_backend blip_vqa \ | ||
| --batch_size 8 \ | ||
| --num_epochs 1 \ | ||
| --max_train_samples 500 \ | ||
| --max_eval_samples 100 | ||
| ``` | ||
|
|
||
| Expected output: | ||
| ``` | ||
| Train: 500, Eval: 100 | ||
| trainable: 1,179,648 / 252,562,178 (0.4672%) | ||
| Epoch 1/1 loss=2.3456 acc=0.3200 | ||
| ``` | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I got: trainable: 3,538,944 / 364,769,084 (0.9702%)
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Thanks Ziyue, I will make a few change of the mentioned problems |
||
|
|
||
| ### Step 3: Run Federated Learning | ||
|
|
||
| ```bash | ||
| python job.py \ | ||
| --model_backend blip_vqa \ | ||
| --num_clients 2 \ | ||
| --num_rounds 3 \ | ||
| --dirichlet_alpha 0.5 \ | ||
| --max_train_samples 500 \ | ||
| --max_eval_samples 100 \ | ||
| --simulator | ||
| ``` | ||
|
|
||
| That's it! The simulator runs everything on one machine. | ||
|
|
||
| --- | ||
|
|
||
| ## Reproducing Paper Experiments | ||
|
|
||
| ### Table 1: VQA v2 with varying clients and heterogeneity | ||
|
|
||
| The paper tests K = {2, 4, 8, 16} clients with alpha = {0.1, 0.5, 1.0}. | ||
|
|
||
| ```bash | ||
| # alpha=0.1 (extreme non-IID), K=8 clients | ||
| python job.py --model_backend blip_vqa \ | ||
| --num_clients 8 --num_rounds 100 --local_epochs 5 \ | ||
| --dirichlet_alpha 0.1 \ | ||
| --batch_size 32 --lr 2e-5 \ | ||
| --simulator | ||
|
|
||
| # alpha=0.5 (moderate non-IID), K=8 clients | ||
| python job.py --model_backend blip_vqa \ | ||
| --num_clients 8 --num_rounds 100 --local_epochs 5 \ | ||
| --dirichlet_alpha 0.5 \ | ||
| --batch_size 32 --lr 2e-5 \ | ||
| --simulator | ||
|
|
||
| # alpha=1.0 (mild non-IID), K=16 clients | ||
| python job.py --model_backend blip_vqa \ | ||
| --num_clients 16 --num_rounds 100 --local_epochs 5 \ | ||
| --dirichlet_alpha 1.0 \ | ||
| --batch_size 32 --lr 2e-5 \ | ||
| --simulator | ||
| ``` | ||
|
|
||
| ### What `--dirichlet_alpha` does | ||
|
|
||
| Controls how "uneven" the data distribution is across clients: | ||
|
|
||
| ``` | ||
| alpha=0.1 (extreme) alpha=0.5 (moderate) alpha=1.0 (mild) | ||
|
|
||
| Client 0: ████████ yes/no Client 0: ████ yes/no Client 0: ███ yes/no | ||
| █ color ███ color ███ color | ||
| █ count ███ count ██ count | ||
|
|
||
| Client 1: █ yes/no Client 1: ███ yes/no Client 1: ███ yes/no | ||
| ████████ color ████ color ███ color | ||
| █ count ██ count ███ count | ||
|
|
||
| Each client specializes Clients have preferences Nearly uniform | ||
| in a few answer types but still see variety distribution | ||
| ``` | ||
|
|
||
| | alpha value | Meaning | When to use | | ||
| |---|---|---| | ||
| | `0` (default) | IID round-robin | Baseline / debugging | | ||
| | `0.1` | Extreme non-IID | Stress test, worst-case scenario | | ||
| | `0.5` | Moderate non-IID | Realistic hospital/enterprise setting | | ||
| | `1.0` | Mild non-IID | Near-IID with slight variation | | ||
|
|
||
| ### Paper hyperparameters | ||
|
|
||
| | Parameter | Paper value | CLI flag | | ||
| |---|---|---| | ||
| | LoRA rank | 16 | `--lora_r 16` | | ||
| | LoRA alpha | 32 | `--lora_alpha 32` | | ||
| | Learning rate | 2e-5 | `--lr 2e-5` | | ||
| | Batch size | 32 | `--batch_size 32` | | ||
| | Local epochs | 5 | `--local_epochs 5` | | ||
| | Communication rounds | 100 | `--num_rounds 100` | | ||
| | Clients | 2-16 | `--num_clients K` | | ||
| | Optimizer | AdamW (wd=0.05) | hardcoded | | ||
|
|
||
| --- | ||
|
|
||
| ## All CLI Options | ||
|
|
||
| ### `job.py` (Federated Learning) | ||
|
|
||
| ``` | ||
| python job.py [OPTIONS] | ||
|
|
||
| Model: | ||
| --model_backend blip_vqa | januspro (required) | ||
| --model_name_or_path HuggingFace model id (uses default if empty) | ||
|
|
||
| Federated Learning: | ||
| --num_clients Number of FL clients (default: 2) | ||
| --num_rounds Communication rounds (default: 3) | ||
| --local_epochs Local epochs per round (default: 1) | ||
| --dirichlet_alpha Non-IID level: 0=IID, 0.1/0.5/1.0 (default: 0) | ||
|
|
||
| Training: | ||
| --batch_size Per-client batch size (default: 8) | ||
| --grad_accum Gradient accumulation steps (default: 8) | ||
| --lr Learning rate (default: 5e-5) | ||
| --lora_r LoRA rank (default: 16) | ||
| --lora_alpha LoRA scaling factor (default: 32) | ||
|
|
||
| Data: | ||
| --max_train_samples Limit training data (-1 = all) (default: -1) | ||
| --max_eval_samples Limit eval data (-1 = all) (default: -1) | ||
| --data_path HuggingFace cache directory (default: "") | ||
| --seed Random seed (default: 42) | ||
|
|
||
| Execution: | ||
| --simulator Run with NVFlare simulator (single machine) | ||
| --export_dir Export job config for production deployment | ||
| --use_env_script Use conda env isolation via bash wrappers | ||
| ``` | ||
|
|
||
| ### `src/local_train.py` (Centralized Baseline) | ||
|
|
||
| ``` | ||
| python src/local_train.py [OPTIONS] | ||
|
|
||
| --model_backend blip_vqa | januspro | ||
| --num_epochs Training epochs (default: 1) | ||
| --output_dir Save model here (default: ./workspace_centralized) | ||
| (... same training/data options as job.py) | ||
| ``` | ||
|
|
||
| --- | ||
|
|
||
| ## HPC / Slurm Usage | ||
|
|
||
| ### Interactive testing | ||
|
|
||
| ```bash | ||
| srun --gres=gpu:1 --mem=32G --time=02:00:00 --pty bash | ||
| conda activate fedumm | ||
| python src/local_train.py --model_backend blip_vqa \ | ||
| --max_train_samples 200 --max_eval_samples 50 \ | ||
| --data_path /scratch/$USER/hf_cache | ||
| ``` | ||
|
|
||
| ### Batch job | ||
|
|
||
| ```bash | ||
| mkdir -p logs | ||
| sbatch scripts/slurm_run.sh blip_vqa | ||
| ``` | ||
|
|
||
| Edit `scripts/slurm_run.sh` to match your cluster (partition name, GPU type, | ||
| module names). | ||
|
|
||
| ### Pre-downloading data (for compute nodes without internet) | ||
|
|
||
| ```bash | ||
| # On login node (has internet): | ||
| python -c " | ||
| from datasets import load_dataset | ||
| ds = load_dataset('HuggingFaceM4/VQAv2', split='train', | ||
| cache_dir='/scratch/\$USER/hf_cache') | ||
| print(f'Downloaded {len(ds)} samples') | ||
| " | ||
|
|
||
| # Then on compute node: | ||
| python job.py --data_path /scratch/$USER/hf_cache ... | ||
| ``` | ||
|
|
||
| --- | ||
|
|
||
| ## Environment Isolation (for JanusPro) | ||
|
|
||
| JanusPro requires the `janus` package from GitHub which may conflict with | ||
| other dependencies. Use conda env isolation: | ||
|
|
||
| ```bash | ||
| # 1. Create isolated environments | ||
| bash scripts/setup_envs.sh | ||
|
|
||
| # 2. Run with env isolation | ||
| python job.py --model_backend januspro --use_env_script --simulator | ||
| ``` | ||
|
|
||
| How it works: NVFlare's `ScriptRunner` calls a bash wrapper that activates | ||
| the correct conda env before running the Python training script. | ||
|
|
||
| ``` | ||
| job.py --use_env_script | ||
| -> ScriptRunner(command="bash scripts/launch_januspro.sh") | ||
| -> source conda.sh && conda activate nvflare_januspro | ||
| -> exec python -u fl_client.py "$@" | ||
| ``` | ||
|
|
||
| --- | ||
|
|
||
| ## Adding a New Model | ||
|
|
||
| 1. **Create backend** `src/my_model_backend.py`: | ||
|
|
||
| ```python | ||
| from .model_registry import register_backend | ||
|
|
||
| class MyModelBackend: | ||
| name = "my_model" | ||
|
|
||
| def build_model_and_processor(self, model_name_or_path, lora_r, | ||
| lora_alpha, lora_dropout, device): | ||
| # Load model, apply LoRA, return (model, processor) | ||
| ... | ||
|
|
||
| def build_dataset(self, hf_ds, processor, max_q_len, max_a_len): | ||
| # Wrap HF dataset into torch Dataset | ||
| ... | ||
|
|
||
| def collate_fn(self, batch): | ||
| # Custom batching | ||
| ... | ||
|
|
||
| def train_step(self, model, batch, device): | ||
| # Forward pass, return scalar loss (no .backward()) | ||
| ... | ||
|
|
||
| def evaluate(self, model, dataloader, processor, device): | ||
| # Return metric (e.g. VQA accuracy float) | ||
| ... | ||
|
|
||
| def hf_dataset_name(self): return "HuggingFaceM4/VQAv2" | ||
| def hf_train_split(self): return "train" | ||
| def hf_eval_split(self): return "validation[:50%]" | ||
| def keep_columns(self): return ["image", "question", | ||
| "multiple_choice_answer", "answers"] | ||
|
|
||
| register_backend("my_model", MyModelBackend()) | ||
| ``` | ||
|
|
||
| 2. **Register** in `src/__init__.py`: | ||
| ```python | ||
| try: | ||
| from . import my_model_backend # noqa: F401 | ||
| except ImportError: | ||
| pass | ||
| ``` | ||
|
|
||
| 3. **Use it**: | ||
| ```bash | ||
| python job.py --model_backend my_model --simulator | ||
| ``` | ||
|
|
||
| --- | ||
|
|
||
| ## Troubleshooting | ||
|
|
||
| | Problem | Solution | | ||
| |---|---| | ||
| | `Unknown backend 'blip_vqa'. Available: ` | Files not in `src/` package. Check `src/__init__.py` exists. | | ||
| | `EnvironmentFileNotFound` in `setup_envs.sh` | Run from project root: `bash scripts/setup_envs.sh` | | ||
| | CUDA out of memory | Reduce `--batch_size` to 2-4, increase `--grad_accum` | | ||
| | `trust_remote_code` blocked (JanusPro) | `export HF_HUB_TRUST_REMOTE_CODE=1` | | ||
| | No internet on compute node | Pre-download data, then use `--data_path` | | ||
| | Very slow first run | HuggingFace is downloading VQAv2 (~50 GB). Use `--max_train_samples 500` for testing. | | ||
|
|
||
| --- | ||
|
|
||
| ## Citation | ||
|
|
||
| ```bibtex | ||
| @article{su2026fedumm, | ||
| title={FedUMM: A General Framework for Federated Learning | ||
| with Unified Multimodal Models}, | ||
| author={Su, Zhaolong and Zhao, Leheng and Wu, Xiaoying | ||
| and Xu, Ziyue and Wang, Jindong}, | ||
| journal={arXiv preprint arXiv:2601.15390}, | ||
| year={2026} | ||
| } | ||
| ``` | ||
|
|
||
| ## License | ||
|
|
||
| This example is licensed under the Apache License 2.0. | ||
| Model weights are subject to their respective licenses | ||
| (Salesforce for BLIP, DeepSeek Model License for JanusPro). | ||
|
|
||
| ## Acknowledgments | ||
|
|
||
| This work was partially supported by the NVIDIA Academic Grant Program. | ||
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
running it locally, I got: