Commit 4b16e10: init

0 parents · 17 files changed · 1350 additions & 0 deletions

.gitignore

Lines changed: 113 additions & 0 deletions
# project related
dataset/
models*/
runs*/
figs/
.vscode/

*.DS_Store

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

LICENSE

Lines changed: 21 additions & 0 deletions
MIT License

Copyright (c) 2020 THU Media

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

README.md

Lines changed: 76 additions & 0 deletions
# Continual Local Training for Better Initialization of Federated Models

The implementation of "Continual Local Training for Better Initialization of Federated Models" (ICIP 2020).
[[Conference Version]](#) [[arXiv Version]](https://arxiv.org/abs/2005.12657)

## Introduction

Federated learning (FL) is a learning paradigm that trains machine learning models directly on decentralized systems of smart edge devices without transmitting the raw data, thereby avoiding heavy communication costs and privacy concerns.
Given the typically heterogeneous data distributions in such settings, the popular FL algorithm *Federated Averaging* (FedAvg) suffers from weight divergence and thus cannot achieve competitive performance for the global model (denoted as the *initial performance* in FL) compared to centralized methods.

In this paper, we propose a continual local training strategy to address this problem.
Importance weights are evaluated on a small proxy dataset on the central server and then used to constrain the local training.
With this additional term, we alleviate weight divergence and continually integrate the knowledge from different local clients into the global model, which ensures better generalization ability.
Experiments on various FL settings demonstrate that our method significantly improves the initial performance of federated models with little extra communication cost.

<div align="center">
<img src="./overview.png" width="70%" height="70%" alt="overview" />
</div>

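In code, the constrained local objective amounts to the task loss plus an importance-weighted quadratic penalty that pulls local parameters toward the broadcast global model. A minimal sketch only, not the repository's exact implementation (which lives in `core/agent.py`); all names here are hypothetical except `coe`, which corresponds to the `--coe` flag in `config.py`:

```python
import torch

def constrained_local_loss(model, task_loss, global_params, importance, coe=0.5):
    """Task loss plus an importance-weighted quadratic penalty (a sketch).

    `global_params` and `importance` are dicts keyed by parameter name,
    broadcast from the central server (hypothetical names; cf. `w_d` in
    `cifar_main.py`). `coe` corresponds to `--coe` in `config.py`.
    """
    penalty = 0.0
    for name, param in model.named_parameters():
        if name in importance:
            penalty = penalty + (importance[name] *
                                 (param - global_params[name]) ** 2).sum()
    return task_loss + coe * penalty
```
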
## Dependency

```
python==3.7
pytorch==1.4
prefetch_generator
tensorboardx
```

## How To Run

1. Download `dataset.tar.gz` from the [release page](https://github.com/thu-media/FedCL/releases/tag/v1.0) and unzip it to the root of the repository.

2. Then start training with
```shell
python cifar_main.py
```
or
```shell
python mnist_main.py
```
The hyperparameters are defined in the standalone file `config.py`.

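For example, a non-IID run that estimates EWC-style importance weights on the server might look like this (flag values are illustrative; all flags are defined in `config.py`):
```shell
python cifar_main.py --distr_type dirichlet --alpha 0.5 \
    --policy ewc --estimate_weights_in_center --coe 0.5
```
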
## Code Structure

```
-- mnist_main.py   # the main training script for experiments on split MNIST
-- cifar_main.py   # the main training script for experiments on split CIFAR10
-- config.py       # the global config file
-- model/
   |_ cifar_model.py  # the model definition for CIFAR10
   |_ mnist_model.py  # the model definition for MNIST
-- data/
   |_ cifar_data.py   # data loader and allocator for CIFAR10
   |_ mnist_data.py   # data loader and allocator for MNIST
-- core/
   |_ agent.py        # core functions for FL clients, e.g., train/test/estimate importance weights
   |_ trainer.py      # core functions for the FL server, e.g., model aggregation, client initialization
   |_ utils.py        # utility functions
```

## Cite

If you find this work useful, please cite [the conference version](#):

```
To be published
```
or [the arXiv version](https://arxiv.org/abs/2005.12657):
```
@article{yao2020continual,
  title={Continual Local Training for Better Initialization of Federated Models},
  author={Yao, Xin and Sun, Lifeng},
  journal={arXiv preprint arXiv:2005.12657},
  year={2020}
}
```

cifar_main.py

Lines changed: 107 additions & 0 deletions
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os

import numpy as np
import torch
from torch import nn
import torch.multiprocessing as mp

import config
from core import Agent, Trainer, train_local_mp
from model import CifarModel
from data import CifarData


class CIFARAgent(Agent):
    """CIFARAgent for CIFAR10 and CIFAR100."""
    def __init__(self, global_args, subset=tuple(range(10)), fine='CIFAR10'):
        super().__init__(global_args, subset, fine)
        self.distr_type = global_args.distr_type
        if self.distr_type == 'uniform':
            self.distribution = np.array([0.1] * 10)
        elif self.distr_type == 'dirichlet':
            self.distribution = np.random.dirichlet([global_args.alpha] * 10)
        else:
            raise ValueError(f'Invalid distribution type: {self.distr_type}.')

    def load_data(self, data_alloc, center=False):
        print("=> loading data")
        if center:
            self.train_loader, self.test_loader, self.num_train = \
                data_alloc.create_dataset_for_center(self.batch_size, self.num_workers)
        else:
            self.train_loader, self.test_loader, self.num_train = \
                data_alloc.create_dataset_for_client(self.distribution, self.batch_size,
                                                     self.num_workers, self.subset)

    def build_model(self):
        print("=> building model")
        if self.fine == 'CIFAR10':
            num_class = 10
        elif self.fine == 'CIFAR100':
            num_class = 100
        else:
            raise ValueError('Invalid dataset choice.')
        self.model = CifarModel(num_class).to(self.device)
        self.criterion = nn.CrossEntropyLoss().to(self.device)
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=self.lr,
                                         momentum=0.9, weight_decay=1e-4)


class CIFARTrainer(Trainer):
    """CIFAR Trainer."""
    def __init__(self, global_args):
        super().__init__(global_args)
        self.data_alloc = CifarData(self.num_locals, self.sample_rate, fine=self.fine)

        # init the global model
        self.global_agent = CIFARAgent(global_args, fine=self.fine)
        self.global_agent.load_data(self.data_alloc, center=True)
        self.global_agent.build_model()
        self.global_agent.resume_model(self.resume)

    def build_local_models(self, global_args):
        self.nets_pool = list()
        for _ in range(self.num_locals):
            self.nets_pool.append(CIFARAgent(global_args, fine=self.fine))
        self.init_local_models()

    def train(self):
        for rnd in range(self.rounds):
            np.random.shuffle(self.nets_pool)
            pool = mp.Pool(self.num_per_rnd)
            self.q = mp.Manager().Queue()
            dict_new = self.global_agent.model.state_dict()
            # estimate parameter importance on the server's proxy data every `interval` rounds
            if self.estimate_weights_in_center and rnd % self.interval == 0:
                w_d = self.global_agent.estimate_weights(self.policy)
            else:
                w_d = None
            # broadcast the global model and dispatch the selected clients in parallel
            for net in self.nets_pool[:self.num_per_rnd]:
                net.model.load_state_dict(dict_new)
                net.set_lr(self.global_agent.lr)
                pool.apply_async(train_local_mp,
                                 (net, self.local_epochs, rnd, self.q, self.policy, w_d))
            pool.close()
            pool.join()
            self.update_global(rnd)


def main():
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)
    np.random.seed(args.seed)
    mp.set_start_method('forkserver')

    cifar_trainer = CIFARTrainer(args)

    # test
    if args.mode == 'test':
        cifar_trainer.test()
        return

    cifar_trainer.build_local_models(args)
    cifar_trainer.train()


if __name__ == '__main__':
    args = config.get_args()
    args.fine = 'CIFAR10'
    main()
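`train_local_mp` and `update_global` are defined in `core/trainer.py`, not in this file. As a rough sketch of what a FedAvg-style `update_global` does with the client state dicts collected through the queue, assuming equal client weighting (the repository's actual aggregation may differ, e.g., by weighting clients by sample count):

```python
import copy
import torch

def average_state_dicts(state_dicts):
    """Element-wise mean of client model state dicts (FedAvg-style sketch,
    not the repository's exact code)."""
    avg = copy.deepcopy(state_dicts[0])
    for key in avg:
        avg[key] = torch.stack(
            [sd[key].float() for sd in state_dicts]).mean(dim=0)
    return avg
```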

config.py

Lines changed: 40 additions & 0 deletions
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# pylint: disable=C0301,C0326
import argparse


def get_args():
    parser = argparse.ArgumentParser()

    parser.add_argument('--model_file', type=str, default='model.pth.tar', help='File to save the model.')
    parser.add_argument('--model_dir', type=str, default='models', help='Directory for storing checkpoint files.')
    parser.add_argument('--resume', type=str, default='', metavar='PATH', help='Path to a checkpoint to resume from (default: none).')
    parser.add_argument('--mode', type=str, default='train', choices=('train', 'test'), help='Train or test.')
    parser.add_argument('--log_dir', type=str, default='runs_attn', help='Directory for logging.')
    parser.add_argument('--gpu', type=str, default='0', help='ID(s) of the GPU(s) to use.')
    parser.add_argument('--seed', type=int, default=1234, help='Random seed.')

    # hyperparameters for local data and training
    parser.add_argument('--distr_type', type=str, default='uniform', choices=('uniform', 'dirichlet'), help='Distribution used to construct local data.')
    parser.add_argument('--alpha', type=float, default=1., help='Alpha for the Dirichlet distribution; must be > 0 if the Dirichlet distribution is chosen.')
    parser.add_argument('--lr', type=float, default=5e-3, help='Learning rate.')
    parser.add_argument('--min_lr', type=float, default=1e-4, help='Minimum learning rate.')
    parser.add_argument('--decay_rate', type=float, default=0.99, help='Learning-rate decay rate.')
    parser.add_argument('--batch_size', type=int, default=64, help='Batch size (B).')
    parser.add_argument('--local_epochs', type=int, default=2, help='Number of local epochs (E).')
    parser.add_argument('--num_workers', type=int, default=0, help='Number of workers to preprocess data; must be 0 for multiprocessing agents.')

    # hyperparameters for the central server
    parser.add_argument('--num_locals', type=int, default=10, help='Number of local agents.')
    parser.add_argument('--num_per_rnd', type=int, default=2, help='Number of local agents to train per round.')
    parser.add_argument('--rounds', type=int, default=500, help='Number of communication rounds.')
    parser.add_argument('--sample_rate', type=float, default=-1., help='Sample rate of central data.')
    parser.add_argument('--policy', type=str, default='avg', choices=('avg', 'ewc', 'mas'), help='Policy for estimating parameter importance.')
    parser.add_argument('--estimate_weights_in_center', action='store_true', help='Estimate parameter importance on the central server.')

    # hyperparameters for the EWC-style constraint
    parser.add_argument('--coe', type=float, default=0.5, help='Coefficient for the additional local constraint.')
    parser.add_argument('--interval', type=float, default=1, help='Interval (in rounds) between weight estimations.')

    args = parser.parse_args()
    return args
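The effect of `--alpha` under `--distr_type dirichlet` is easy to see directly: each client's label proportions over the 10 classes are drawn from Dir(alpha), as in `CIFARAgent.__init__`, and smaller alpha yields more skewed, more non-IID clients. A quick illustration:

```python
import numpy as np

np.random.seed(1234)
# alpha = 1.0 (the default): moderately uneven class shares for one client
print(np.random.dirichlet([1.0] * 10).round(2))
# alpha = 0.1: highly skewed -- this client sees mostly a few classes
print(np.random.dirichlet([0.1] * 10).round(2))
```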

core/__init__.py

Lines changed: 2 additions & 0 deletions
from .agent import Agent
from .trainer import Trainer, train_local_mp, test_local_mp
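`Agent.estimate_weights` (in `core/agent.py`, whose body is not shown in this excerpt) produces the per-parameter importance under the chosen `--policy`. A hedged sketch of the EWC variant, approximating the diagonal Fisher information by averaging squared loss gradients over the server's proxy data (all names hypothetical, not the repository's exact code):

```python
import torch

def estimate_ewc_importance(model, criterion, proxy_loader, device='cpu'):
    """Diagonal Fisher approximation: mean squared gradient of the loss
    over the proxy dataset (a sketch under EWC-style assumptions)."""
    importance = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
    model.eval()
    for inputs, targets in proxy_loader:
        inputs, targets = inputs.to(device), targets.to(device)
        model.zero_grad()
        criterion(model(inputs), targets).backward()
        for n, p in model.named_parameters():
            if p.grad is not None:
                importance[n] += p.grad.detach() ** 2
    return {n: v / max(len(proxy_loader), 1) for n, v in importance.items()}
```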
