
3FM: Multi-modal Meta-learning for Federated Tasks


Minh Tran¹, Roochi Shah², Zejun Gong³
¹Robotics Institute, Carnegie Mellon University
²Department of Statistics and Data Science, Carnegie Mellon University
³Department of Electrical and Computer Engineering, Carnegie Mellon University

📋 Overview

We present 3FM (Three-modal Federated Meta-learning), a novel meta-learning framework designed for multimodal federated learning tasks. Our approach addresses three critical challenges of federated environments:

  • Modality Heterogeneity: Different clients may have access to different types of data modalities
  • Variable Modality Availability: Not all clients have the same set of modalities available
  • Missing Data: Common occurrence of incomplete multimodal data across federated clients

🎯 Key Contributions

  1. Novel Meta-Learning Framework: Integration of Model-Agnostic Meta-Learning (MAML) with Federated Averaging (FedAvg) for multimodal scenarios
  2. Robust Adaptation: The model adapts effectively to newly introduced modalities, even with limited training data
  3. Missing Modality Handling: Framework designed to handle various missing modality scenarios across clients
  4. Comprehensive Evaluation: Extensive experiments on custom multimodal dataset with image, audio, and sign language modalities

🏗️ Architecture

Our approach combines:

  • MAML-based Client Training: Each client uses meta-learning to adapt to available modalities
  • FedAvg Global Aggregation: Server aggregates client updates using federated averaging
  • Support/Query Set Division: Data split for inner and outer loop optimization in MAML

Algorithm Overview

For each communication round:
1. Server distributes global model θ to sampled clients
2. Each client u performs local meta-learning:
   - Split data into support set D_S^u and query set D_Q^u
   - Inner loop: Update on support set (limited modalities)
   - Outer loop: Evaluate on query set (full modalities when available)
3. Server aggregates client gradients using FedAvg
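
The round above can be sketched in code. The sketch below is a simplified illustration, not the repository's implementation: it uses a first-order approximation of the MAML update, a toy linear model with synthetic data, and parameter averaging on the server; the names (client_update, fedavg) and values are assumptions for illustration only.

import copy
import torch
import torch.nn as nn
import torch.nn.functional as F


def client_update(global_model, support, query, inner_lr, outer_lr, inner_steps=5):
    """One client round: adapt on the support set, then take a meta-step on the query set."""
    model = copy.deepcopy(global_model)

    # Inner loop: adapt to the client's available (possibly limited) modalities.
    inner_opt = torch.optim.SGD(model.parameters(), lr=inner_lr)
    x_s, y_s = support
    for _ in range(inner_steps):
        inner_opt.zero_grad()
        F.cross_entropy(model(x_s), y_s).backward()
        inner_opt.step()

    # Outer loop: evaluate the adapted model on the query set and take one
    # meta-update step (first-order approximation of the MAML gradient).
    outer_opt = torch.optim.SGD(model.parameters(), lr=outer_lr)
    x_q, y_q = query
    outer_opt.zero_grad()
    F.cross_entropy(model(x_q), y_q).backward()
    outer_opt.step()
    return model.state_dict()


def fedavg(client_states):
    """Server side: average client parameters with equal weights."""
    avg = copy.deepcopy(client_states[0])
    for key in avg:
        avg[key] = torch.stack([s[key].float() for s in client_states]).mean(dim=0)
    return avg


# Toy setup: 3 clients, flattened feature vectors standing in for the real modalities.
torch.manual_seed(0)
global_model = nn.Linear(64, 10)
clients = []
for _ in range(3):
    x, y = torch.randn(50, 64), torch.randint(0, 10, (50,))
    n_support = int(0.2 * len(x))                        # 20% support / 80% query
    clients.append(((x[:n_support], y[:n_support]), (x[n_support:], y[n_support:])))

for round_idx in range(5):                               # communication rounds (shortened here)
    states = [client_update(global_model, s, q, inner_lr=1e-5, outer_lr=1e-3)
              for s, q in clients]
    global_model.load_state_dict(fedavg(states))         # FedAvg aggregation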

📊 Dataset

We created a custom multimodal dataset by aligning three single-modality datasets:

Data Sources

  • MNIST: 60,000 handwritten digit images (0-9)
  • Free Spoken Digit Dataset: Audio recordings of spoken digits (0-9) represented as spectrograms
  • Sign Language Digits Dataset: 2,062 sign language gesture images for digits (0-9)

Final Dataset

  • Total Samples: 2,062 aligned multimodal samples
  • Labels: 10 classes (digits 0-9)
  • Modalities: Image, Audio (spectrogram), Sign language
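
One way to build such an aligned set is to group each source by digit label and pair samples index-wise, capped by the smallest per-digit count (the sign-language set, with 2,062 images, is the limiting factor). The sketch below is illustrative only; the loaders and the function name align_by_label are placeholders, not the repository's code.

# Illustrative alignment by digit label; `mnist`, `spectrograms`, and `signs` are
# assumed to be lists of (sample, label) pairs produced by your own loaders.
def align_by_label(mnist, spectrograms, signs, digits=range(10)):
    aligned = []
    for d in digits:
        imgs     = [x for x, y in mnist        if y == d]
        spects   = [x for x, y in spectrograms if y == d]
        gestures = [x for x, y in signs        if y == d]
        n = min(len(imgs), len(spects), len(gestures))   # limited by the smallest source
        aligned += [(imgs[i], spects[i], gestures[i], d) for i in range(n)]
    return aligned  # list of (image, spectrogram, sign_image, label) tuples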

🧪 Experimental Setup

Missing Modality Scenarios

We evaluated six missing-modality scenarios (encoded in the sketch after the list):

  1. img/sign: Image + Sign language (missing audio)
  2. spect/sign: Audio + Sign language (missing image)
  3. img/spect: Image + Audio (missing sign)
  4. img: Image only
  5. spect: Audio only
  6. sign: Sign language only
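
In code, each scenario can be represented as the set of modalities a client observes during local training. The encoding below is a hypothetical illustration; the name SCENARIOS and the modality strings are not taken from the repository.

# Hypothetical encoding of the six training scenarios as sets of available modalities.
SCENARIOS = {
    "img/sign":   {"image", "sign"},          # audio missing
    "spect/sign": {"spectrogram", "sign"},    # image missing
    "img/spect":  {"image", "spectrogram"},   # sign missing
    "img":        {"image"},
    "spect":      {"spectrogram"},
    "sign":       {"sign"},
}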

Hyperparameters

  • Communication Rounds: 50
  • Local Epochs: 5
  • Client Numbers: 3, 5, 10
  • Support/Query Split: 20%/80%
  • Learning Rates:
    • Outer LR: {0.001, 0.01}
    • Inner LR: {0.00001, 0.0001}
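
For reference, the same sweep can be written as a configuration grid. This is an illustrative sketch; the key names are assumptions and do not come from the repository's code.

# Hyperparameter grid from the list above, expressed as a config dict (illustrative names).
SWEEP = {
    "communication_rounds": 50,
    "local_epochs": 5,
    "num_clients": [3, 5, 10],
    "support_query_split": (0.2, 0.8),   # 20% support / 80% query
    "outer_lr": [0.001, 0.01],
    "inner_lr": [0.00001, 0.0001],
}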

📈 Results

Best Performance Configuration

  • Client Number: 3
  • Outer Learning Rate: 0.001
  • Inner Learning Rate: 0.00001

Performance on Full Modality Testing

Training Scenario    Test Accuracy
img/sign             86.407%
spect/sign           94.660%
img/spect            91.747%
img                  94.174%
spect                69.417%
sign                 92.718%

Key Findings

  • Significant accuracy gains over the baseline in the spectrogram+sign, spectrogram-only, and sign-only scenarios
  • Lower variance in performance across the missing-modality scenarios than the baseline
  • Best performance achieved with 3 clients and carefully tuned meta-learning rates

🚀 Getting Started

Prerequisites

pip install torch torchvision torchaudio
pip install numpy matplotlib
pip install librosa  # for audio processing

Usage

  • See the Jupyter notebooks in the notebooks/ directory for usage examples and analysis

📁 Project Structure

MLMF/
├── models/           # Neural network architectures
├── data/             # Dataset processing and loading
├── federated/        # Federated learning implementations
├── meta_learning/    # MAML implementation
├── experiments/      # Experiment configurations
├── notebooks/        # Jupyter notebooks for analysis
└── results/          # Experimental results and plots

🔄 Limitations & Future Work

Current Limitations

  • Dataset Scope: Limited to specific multimodal MNIST variant
  • Computational Constraints: Experiments limited by available compute resources
  • Baseline Comparisons: No comparison yet against the SMIL Bayesian meta-learning baseline

Future Directions

  • Scalability: Extend to more diverse and complex datasets
  • Advanced Missing Data: Explore extreme cases of modality missingness
  • Privacy Analysis: Compare privacy guarantees with other federated multimodal baselines
  • Meta-Learning Strategies: Investigate alternative meta-learning approaches

📚 Citation

If you find this work useful, please cite:

@article{tran2023_3fm,
  title={3FM: Multi-modal Meta-learning for Federated Tasks},
  author={Tran, Minh and Shah, Roochi and Gong, Zejun},
  journal={arXiv preprint arXiv:2312.10179},
  year={2023}
}

🙏 Acknowledgments

  • Federated Learning Base Code: Modified from Federated Learning PyTorch
  • Datasets: MNIST, Free Spoken Digit Dataset, Sign Language Digits Dataset
  • Carnegie Mellon University for computational resources and support

📄 License

This project is licensed under the MIT License - see the LICENSE file for details.


This work was conducted as part of research at Carnegie Mellon University. For questions or collaborations, please reach out to the authors.
