-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathassgin_ds.py
More file actions
70 lines (58 loc) · 2.01 KB
/
assgin_ds.py
File metadata and controls
70 lines (58 loc) · 2.01 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
"""
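Module assgin_ds: prepare datasets for federated learning.
Builds per-dataset train/test splits and the federated client DataLoaders.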
"""
from loguru import logger
from data.data_processing import CDDataset
def get_fed_dataset(args, ds_name: dict):
    """Build a train and a test ``CDDataset`` for every configured dataset.

    :param args: training hyper-parameters; only ``args.img_size`` is read here.
    :param ds_name: mapping from dataset name to its config. Each config must
        provide a ``"path"`` entry: the dataset root directory passed to
        ``CDDataset`` as ``root_dir``.
    :type ds_name: dict
    :return: a ``(train_dict, test_dict)`` tuple. Both dicts are keyed by the
        dataset names of ``ds_name`` (same order) and map to the ``CDDataset``
        built for the ``"train"`` and ``"test"`` split respectively.
    :rtype: tuple[dict, dict]
    :raises KeyError: if a dataset config has no ``"path"`` entry.
    """
    train_dict = {}
    test_dict = {}
    for name, info in ds_name.items():
        root_dir = info["path"]
        train_dict[name] = CDDataset(
            root_dir=root_dir,
            split="train",
            img_size=args.img_size,
            label_transform="norm",
        )
        # Only the test split passes is_train=False; presumably this turns off
        # training-time behaviour (e.g. augmentation) in CDDataset — TODO confirm.
        test_dict[name] = CDDataset(
            root_dir=root_dir,
            split="test",
            img_size=args.img_size,
            is_train=False,
            label_transform="norm",
        )
        logger.info(f"{name} train size: {len(train_dict[name])}")
        logger.info(f"{name} test size: {len(test_dict[name])}")
    return train_dict, test_dict
def get_fed_dataloaders_with_allocator(
    train_datasets, test_datasets, ds_name: dict, args
):
    """Create the federated DataLoaders through the federated data allocator.

    This is a thin wrapper around ``data.fed_allocator.get_fed_dataloaders``
    that forwards every argument by keyword and logs a summary of the result.

    Args:
        train_datasets: dict of training datasets.
        test_datasets: dict of test datasets.
        ds_name: dataset configuration; per the original notes it also carries
            the allocation ratios and sampler settings (consumed by the
            allocator, not by this function).
        args: training arguments, forwarded unchanged.

    Returns:
        A ``(train_loaders, test_loaders, client_info)`` tuple, as produced by
        the allocator. ``len(train_loaders)`` is logged as the client count.
    """
    # Imported lazily, as in the original, so this module can be loaded
    # without pulling in the allocator.
    from data.fed_allocator import get_fed_dataloaders as build_loaders

    allocation = build_loaders(
        train_datasets=train_datasets,
        test_datasets=test_datasets,
        ds_name=ds_name,
        args=args,
    )
    train_loaders, test_loaders, client_info = allocation

    logger.info(f"Total clients: {len(train_loaders)}")
    logger.info(f"Client info: {client_info}")
    return train_loaders, test_loaders, client_info