diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..7ad6035
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,8 @@
+__pycache__
+wandb/
+*.log
+*-best_model.pt
+*-last_model.pt
+coco_subset_idx_*
+*.bbl
+*.synctex.gz
diff --git a/.vscode/settings.json b/.vscode/settings.json
new file mode 100644
index 0000000..4bf9e06
--- /dev/null
+++ b/.vscode/settings.json
@@ -0,0 +1,52 @@
+{
+    "cSpell.words": [
+        "adamp",
+        "allimages",
+        "batchidx",
+        "Bstdv",
+        "cifa",
+        "cifar",
+        "CIFAR",
+        "crossfold",
+        "crossfolds",
+        "cudnn",
+        "CVPR",
+        "dset",
+        "flickr",
+        "gpuid",
+        "Graphcore",
+        "idxs",
+        "imagenet",
+        "inferencing",
+        "interintra",
+        "Jingjing",
+        "keepdim",
+        "mmdata",
+        "MMFL",
+        "MSCOCO",
+        "multimodal",
+        "Multimodal",
+        "multistep",
+        "noniid",
+        "optim",
+        "PCME",
+        "pycocotools",
+        "Qiying",
+        "rsum",
+        "subconfig",
+        "svhn",
+        "testclasses",
+        "tqdm",
+        "trainclasses",
+        "trainloader",
+        "trainloaders",
+        "trainval",
+        "trainvalclasses",
+        "ujson",
+        "unsqueeze",
+        "valclasses",
+        "vocabs",
+        "wandb",
+        "Yimu"
+    ]
+}
\ No newline at end of file
diff --git a/README.md b/README.md
index 9c198b2..e7e6f4a 100644
--- a/README.md
+++ b/README.md
@@ -1,3 +1,83 @@
+# Networking for CreamFL
+
+## Tasks
+
+* get the code base running locally.
+  * figure out how to run through the entire code quickly.
+  * quick test run: `python src/main.py --name quick --contrast_local_inter --contrast_local_intra --interintra_weight 0.5 --max_size 64 --pub_data_num 2 --feature_dim 2 --num_img_clients 2 --num_txt_clients 2 --num_mm_clients 3 --client_num_per_round 2 --local_epochs 2 --comm_rounds 2 --not_bert`
+    * `--contrast_local_inter --contrast_local_intra --interintra_weight 0.5` Cream options.
+    * `--max_size` added by xiegeo; 0 or 10000 for the old behavior; caps the client training data count, per client.
+    * `--pub_data_num` public training data size (default 50000), proportional to communication (memory for local simulation) cost.
+    * `--feature_dim` number of public features (default 256), proportional to communication cost.
+    * `--num_img_clients 2 --num_txt_clients 2 --num_mm_clients 3 --client_num_per_round 2` maximum number of clients of each type, and total number of clients per round.
+    * `--local_epochs 2 --comm_rounds 2` local and global rounds.
+    * `--not_bert` use a simpler model.
+
+* get the code to run in a network
+  * see the "How to run the network" section.
+
+## Goals
+
+* Learn:
+  * Transformers
+    * Transformer [Attention Is All You Need 2017 v7(2023)](https://arxiv.org/abs/1706.03762)
+    * [An Introduction to Transformers 2023 v5(2024)](https://arxiv.org/abs/2304.10557)
+  * Multimodal
+    * [DeViSE: A Deep Visual-Semantic Embedding Model 2013](https://research.google.com/pubs/archive/41473.pdf)
+    * PCME [Probabilistic Embeddings for Cross-Modal Retrieval 2021 v2](https://arxiv.org/abs/2101.05068)
+  * Federated Learning
+    * Federated Averaging [Communication-Efficient Learning of Deep Networks from Decentralized Data 2016 v4(2023)](https://arxiv.org/abs/1602.05629)
+
+* Implement networking
+  * try FedML? (too much rewrite for FedML to do it properly, otherwise too hacky)
+  * try a custom network? (do a quick demo version)
+
+## How to run the network
+
+### Configuration
+
+* flags: the same as the local version.
+* fed_config: sets up server and client options.
+
+### Run
+
+A network requires n + 2 processes, where n is the number of clients,
+plus a command server over HTTP and a global round computation provider.
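+
+For a quick local test, all n + 2 processes can be started from one shell using the commands detailed in the subsections below. The following is only a minimal sketch, not a supported launcher: it assumes the quick-test flags from the Tasks section and three of the client names from the example `fed_config.yaml` (`txt_0`, `img_0`, `mm_0`), and simply backgrounds each process with its own log file.
+
+```bash
+#!/usr/bin/env bash
+# Hypothetical all-in-one launcher for local testing.
+# Assumes fed_config.yaml defines clients txt_0, img_0 and mm_0;
+# adjust the names and flags to your own setup.
+set -e
+COMMON="--max_size 64 --pub_data_num 2 --feature_dim 2 --not_bert"
+
+# 1. command server (state reporting over HTTP)
+python src/federation/server.py --name test > server.log 2>&1 &
+
+# 2. global round computation provider
+python src/federation/global.py --name test --contrast_local_inter --contrast_local_intra --interintra_weight 0.5 $COMMON > global.log 2>&1 &
+
+# 3. one client process per configured client
+for c in txt_0 img_0 mm_0; do
+  python src/federation/client.py --name test --client_name "$c" $COMMON > "client_$c.log" 2>&1 &
+done
+
+wait  # keep the shell attached until every process exits
+```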
+
+#### Command server
+
+```bash
+python src/federation/server.py --name test
+```
+
+#### Global round computation provider
+
+```bash
+python src/federation/global.py --name test --contrast_local_inter --contrast_local_intra --interintra_weight 0.5 --max_size 64 --pub_data_num 2 --feature_dim 2 --not_bert
+```
+
+#### Clients
+
+Start one client process per configured client. Replace `txt_0` with the name of the client to start.
+
+```bash
+python src/federation/client.py --name test --client_name txt_0 --max_size 64 --pub_data_num 2 --feature_dim 2 --not_bert
+```
+
+### File sharing
+
+The network has to share the learned features. This could be done through a file server,
+a CDN, or shared network storage (i.e. directly accessing the same files).
+Direct file access is the easiest to implement and extends naturally to shared network storage,
+so it is implemented first, for ease of local testing and without loss of generality.
+
+## Proof of Concept
+
+See [report/poc.pdf](report/poc.pdf)
+
+------------------------
+Begin original readme
+
 # Multimodal Federated Learning via Contrastive Representation Ensemble
 
 This repo contains a PyTorch implementation of the paper [Multimodal Federated Learning via Contrastive Representation Ensemble](https://arxiv.org/abs/2302.08888) (ICLR 2023).
@@ -42,6 +122,24 @@ To reproduce CreamFL with BERT and ResNet101 as server models, run the following
 python src/main.py --name CreamFL --server_lr 1e-5 --agg_method con_w --contrast_local_inter --contrast_local_intra --interintra_weight 0.5
 ```
 
+## Run CreamFL retrieval in parallel
+[1] Run the global server
+```shell
+bash retri_center.sh
+```
+[2] Run a txt client
+```shell
+bash client_txt_retri.sh
+```
+[3] Run an img client
+```shell
+bash client_img_retri.sh
+```
+[4] Run an mm client
+```shell
+bash client_mm_retri.sh
+```
+
 ## Citation
 
 If you find the paper provides some insights into multimodal FL or our code useful 🤗, please consider citing:
diff --git a/client_img_retri.sh b/client_img_retri.sh
new file mode 100644
index 0000000..3f9fa9d
--- /dev/null
+++ b/client_img_retri.sh
@@ -0,0 +1,4 @@
+export HF_ENDPOINT=https://hf-mirror.com
+export HF_DATASETS_CACHE="/shared/.cache/huggingface/datasets"
+
+nohup python src/retri_client_img.py --name retri_client_img --server_lr 1e-5 --seed 0 --feature_dim 256 --pub_data_num 50000 --agg_method con_w --contrast_local_inter --contrast_local_intra --interintra_weight 0.5 --local_epochs 5 --client_num_per_round 1 --num_img_clients 1 --num_txt_clients 0 --num_mm_clients 0 > retri_client_img.log 2>&1 &
\ No newline at end of file
diff --git a/client_mm_retri.sh b/client_mm_retri.sh
new file mode 100644
index 0000000..df45b5a
--- /dev/null
+++ b/client_mm_retri.sh
@@ -0,0 +1,4 @@
+export HF_ENDPOINT=https://hf-mirror.com
+export HF_DATASETS_CACHE="/shared/.cache/huggingface/datasets"
+
+nohup python src/retri_client_mm.py --name retri_client_mm --server_lr 1e-5 --seed 0 --feature_dim 256 --pub_data_num 50000 --agg_method con_w --contrast_local_inter --contrast_local_intra --interintra_weight 0.5 --local_epochs 5 --client_num_per_round 1 --num_img_clients 0 --num_txt_clients 0 --num_mm_clients 1 > retri_client_mm.log 2>&1 &
\ No newline at end of file
diff --git a/client_txt_retri.sh b/client_txt_retri.sh
new file mode 100644
index 0000000..1df0b0e
--- /dev/null
+++ b/client_txt_retri.sh
@@ -0,0 +1,4 @@
+export HF_ENDPOINT=https://hf-mirror.com
+export HF_DATASETS_CACHE="/shared/.cache/huggingface/datasets"
+
+nohup python src/retri_client_txt.py --name retri_client_txt --server_lr 1e-5 --seed 0 --feature_dim 256 --pub_data_num 50000 --agg_method con_w
--contrast_local_inter --contrast_local_intra --interintra_weight 0.5 --local_epochs 5 --client_num_per_round 1 --num_img_clients 0 --num_txt_clients 1 --num_mm_clients 0 > retri_client_txt.log 2>&1 & \ No newline at end of file diff --git a/coco_subset_idx_file b/coco_subset_idx_file deleted file mode 100644 index a3a5425..0000000 Binary files a/coco_subset_idx_file and /dev/null differ diff --git a/data_partition/client_AG_NEWS_noniid.pkl b/data_partition/client_AG_NEWS_10_nets_120000_samples_hetero_0.1.pkl similarity index 100% rename from data_partition/client_AG_NEWS_noniid.pkl rename to data_partition/client_AG_NEWS_10_nets_120000_samples_hetero_0.1.pkl diff --git a/data_partition/client_AG_NEWS_1_nets_120000_samples_hetero_0.1.pkl b/data_partition/client_AG_NEWS_1_nets_120000_samples_hetero_0.1.pkl new file mode 100644 index 0000000..f1a7134 Binary files /dev/null and b/data_partition/client_AG_NEWS_1_nets_120000_samples_hetero_0.1.pkl differ diff --git a/data_partition/client_AG_NEWS_2_nets_120000_samples_hetero_0.1.pkl b/data_partition/client_AG_NEWS_2_nets_120000_samples_hetero_0.1.pkl new file mode 100644 index 0000000..c03eb81 Binary files /dev/null and b/data_partition/client_AG_NEWS_2_nets_120000_samples_hetero_0.1.pkl differ diff --git a/data_partition/client_AG_NEWS_4_nets_120000_samples_hetero_0.1.pkl b/data_partition/client_AG_NEWS_4_nets_120000_samples_hetero_0.1.pkl new file mode 100644 index 0000000..48c85ad Binary files /dev/null and b/data_partition/client_AG_NEWS_4_nets_120000_samples_hetero_0.1.pkl differ diff --git a/data_partition/client_cifar100_10_nets_50000_samples_hetero_0.1.pkl b/data_partition/client_cifar100_10_nets_50000_samples_hetero_0.1.pkl new file mode 100644 index 0000000..a4fd2ca Binary files /dev/null and b/data_partition/client_cifar100_10_nets_50000_samples_hetero_0.1.pkl differ diff --git a/data_partition/client_cifar100_1_nets_50000_samples_hetero_0.1.pkl b/data_partition/client_cifar100_1_nets_50000_samples_hetero_0.1.pkl new file mode 100644 index 0000000..93a1a57 Binary files /dev/null and b/data_partition/client_cifar100_1_nets_50000_samples_hetero_0.1.pkl differ diff --git a/data_partition/client_cifar100_2_nets_50000_samples_hetero_0.1.pkl b/data_partition/client_cifar100_2_nets_50000_samples_hetero_0.1.pkl new file mode 100644 index 0000000..a206e8d Binary files /dev/null and b/data_partition/client_cifar100_2_nets_50000_samples_hetero_0.1.pkl differ diff --git a/data_partition/client_cifar100_noniid.pkl b/data_partition/client_cifar100_noniid.pkl deleted file mode 100644 index cae3d62..0000000 Binary files a/data_partition/client_cifar100_noniid.pkl and /dev/null differ diff --git a/fed_config.yaml b/fed_config.yaml new file mode 100644 index 0000000..ac6298e --- /dev/null +++ b/fed_config.yaml @@ -0,0 +1,45 @@ +wandb: + name: "cream_api" + +feature_store: "/tmp/cream_api" # the path to the feature store where client and global features are shared. + +# server configuration +server: + api_url: "http://localhost:2323/cream_api" + min_clients: 3 # the number of required clients reporting to start global distillation. + max_clients: 3 # the number of clients reached to start global distillation immediately. + wait_duration: 10m # the duration to wait for clients to report before starting global distillation. + + +# clients configuration +clients: + - name: "txt_0" + data_type: txt # img, txt, or mm: the type of data the client is handling. 
+ local_epochs: 5 + data_partition: "client_AG_NEWS_10_nets_120000_samples_hetero_0.1.pkl" + data_partition_index: 0 # This is only for testing purposes. In a real-world scenario, the data will not be loaded from the same dataset + - name: "txt_1" + data_type: txt # img, txt, or mm: the type of data the client is handling. + local_epochs: 5 + data_partition: "client_AG_NEWS_10_nets_120000_samples_hetero_0.1.pkl" + data_partition_index: 1 # This is only for testing purposes. In a real-world scenario, the data will not be loaded from the same dataset + - name: "img_0" + data_type: img # img, txt, or mm: the type of data the client is handling. + local_epochs: 5 + data_partition_index: 0 # This is only for testing purposes. In a real-world scenario, the data will not be loaded from the same dataset + - name: "img_1" + data_type: img # img, txt, or mm: the type of data the client is handling. + local_epochs: 5 + data_partition_index: 1 # This is only for testing purposes. In a real-world scenario, the data will not be loaded from the same dataset + - name: "mm_0" + data_type: mm + local_epochs: 5 + data_partition_index: 0 + - name: "mm_1" + data_type: mm + local_epochs: 5 + data_partition_index: 1 + - name: "mm_2" + data_type: mm + local_epochs: 5 + data_partition_index: 2 \ No newline at end of file diff --git a/report/commands.sh b/report/commands.sh new file mode 100644 index 0000000..98e4625 --- /dev/null +++ b/report/commands.sh @@ -0,0 +1,123 @@ + +### server setup +conda activate creamfl +export HF_ENDPOINT=https://hf-mirror.com +export HF_DATASETS_CACHE="/shared/.cache/huggingface/datasets" + +cd /shared/project/xiegeo-dev/CreamFL + +git pull && python src/vqa_exp.py --seed 0 --vqa_hidden_sizes 1024 --vqa_unfreeze_base_epoch 25 --vqa_weight_decay 0.00001 --vqa_epochs 30 --batch_size 64 + + +#vqa_pretrained_eval: +#test scores {'test': { +# 'mean_log_image_sigma': 0.0, 'mean_log_caption_sigma': 0.0, 'n_fold': { +# 'i2t': { +# 'recall_1': 50.339999999999996, 'recall_5': 80.30000000000001, 'recall_10': 90.12, 'rsum': 220.76, 'medr': 1.4, 'meanr': 5.4586 +# }, 't2i': { +# 'recall_1': 38.620000000000005, 'recall_5': 75.02, 'recall_10': 87.124, 'rsum': 200.764, 'medr': 2.0, 'meanr': 7.05468 +# }}, 'i2t': { +# 'recall_1': 26.32, 'recall_5': 54.12, 'recall_10': 67.64, 'rsum': 148.07999999999998, 'medr': 5.0, 'meanr': 22.8744 +# }, 't2i': { +# 'recall_1': 18.348, 'recall_5': 44.296, 'recall_10': 58.544, 'rsum': 121.18799999999999, 'medr': 7.0, 'meanr': 30.7024 +# }, 'rsum': 269.268, 'medr': 12.0, 'meanr': 53.576800000000006 +#}} + +#test +git pull && python src/vqa.py --name test --server_lr 1e-5 --feature_dim 256 --pub_data_num 1000 --client_num_per_round 1 --num_img_clients 1 --num_txt_clients 1 --num_mm_clients 1 --local_epochs 1 --comm_rounds 5 --client_init_local_epochs 2 + + +git pull && python src/vqa.py --name vqa_0c_100k --server_lr 1e-5 --seed 0 --feature_dim 1024 --pub_data_num 100000 --disable_distill --client_num_per_round 0 --num_img_clients 0 --num_txt_clients 0 --num_mm_clients 0 + +git pull && python src/main.py --name base_intra_full_clients --server_lr 1e-5 --agg_method con_w --contrast_local_inter --contrast_local_intra --interintra_weight 0.5 --seed 0 --feature_dim 1024 +git pull && python src/main.py --name base_0_clients --server_lr 1e-5 --seed 0 --feature_dim 1024 --disable_distill --client_num_per_round 0 --num_img_clients 0 --num_txt_clients 0 --num_mm_clients 0 +git pull && python src/main.py --name base_0c200k --server_lr 1e-5 --seed 0 --feature_dim 1024 
--pub_data_num 200000 --disable_distill --client_num_per_round 0 --num_img_clients 0 --num_txt_clients 0 --num_mm_clients 0 +git pull && python src/main.py --name base_0c800k --server_lr 1e-5 --pretrained_model base_0c400k_best_1024_net.pt --seed 0 --feature_dim 1024 --pub_data_num 800000 --disable_distill --client_num_per_round 0 --num_img_clients 0 --num_txt_clients 0 --num_mm_clients 0 +git pull && python src/main.py --name full_200k --server_lr 1e-5 --pretrained_model base_0c800k_best_1024_net.pt --seed 0 --feature_dim 1024 --pub_data_num 200000 --agg_method con_w --contrast_local_inter --contrast_local_intra --interintra_weight 0.5 +git pull && python src/main.py --name full_200k --server_lr 1e-5 --pretrained_model full_200k_best_1024_net.pt --seed 0 --feature_dim 1024 --pub_data_num 200000 --agg_method con_w --contrast_local_inter --contrast_local_intra --interintra_weight 0.5 + + +git pull && python src/vqa.py --name vqa_0c1_pre400k --server_lr 1e-5 --pretrained_model vqa_0c1_pre400k_best_1024_vqa.pt --seed 0 --feature_dim 1024 --pub_data_num 1 --disable_distill --client_num_per_round 0 --num_img_clients 0 --num_txt_clients 0 --num_mm_clients 0 +git pull && python src/vqa.py --name vqa_allc50k_nort --server_lr 1e-5 --pretrained_model vqa_0c1_pre400k_best_1024_vqa.pt --seed 0 --feature_dim 1024 --pub_data_num 50000 --agg_method con_w --contrast_local_inter --contrast_local_intra --interintra_weight 0.5 --no_retrieval_training +git pull && python src/vqa.py --name vqa_allc50k_nortv3 --server_lr 1e-5 --pretrained_model vqa_allc50k_nortv2_best_1024_vqa.pt --seed 0 --feature_dim 1024 --pub_data_num 100000 --agg_method con_w --contrast_local_inter --contrast_local_intra --interintra_weight 0.5 --no_retrieval_training --client_init_local_epochs 10 --client_num_per_round 5 --num_img_clients 2 --num_txt_clients 2 --num_mm_clients 3 +git pull && python src/vqa.py --name vqa_2c50k_nort --server_lr 1e-5 --pretrained_model vqa_allc50k_nortv3_best_1024_vqa.pt --seed 0 --feature_dim 1024 --pub_data_num 50000 --agg_method con_w --contrast_local_inter --contrast_local_intra --interintra_weight 0.5 --no_retrieval_training --client_init_local_epochs 5 --client_num_per_round 2 --num_img_clients 2 --num_txt_clients 2 --num_mm_clients 3 +git pull && python src/vqa.py --name vqa_2cimg50k_nort --server_lr 1e-5 --pretrained_model vqa_2cimg50k_nort_best_1024_vqa.pt --seed 0 --feature_dim 1024 --pub_data_num 50000 --agg_method con_w --contrast_local_inter --contrast_local_intra --interintra_weight 0.5 --no_retrieval_training --client_init_local_epochs 5 --client_num_per_round 3 --num_img_clients 0 --num_txt_clients 0 --num_mm_clients 3 --comm_rounds 100 +git pull && python src/vqa.py --name vqaall_allc50k --server_lr 1e-5 --pretrained_model vqa_allc50k_nortv2_best_1024_vqa.pt --seed 0 --feature_dim 1024 --vqa_data_size_per_epoch -1 --pub_data_num 50000 --agg_method con_w --contrast_local_inter --contrast_local_intra --interintra_weight 0.5 --client_init_local_epochs 10 + +git pull && python src/vqa.py --name vqa100kd01_2c50k --server_lr 1e-5 --pretrained_model vqaall_allc50k_best_1024_vqa.pt --seed 0 --feature_dim 1024 --vqa_data_size_per_epoch 100000 --pub_data_num 50000 --agg_method con_w --contrast_local_inter --contrast_local_intra --interintra_weight 0.5 --client_init_local_epochs 10 --local_epochs 10 --client_num_per_round 2 --comm_rounds 100 --vqa_dropout 0.1 +git pull && python src/vqa.py --name vqaall_d01_allc50k_nort --server_lr 1e-5 --pretrained_model vqa100kd01_2c50k_last_1024_vqa.pt --seed 0 
--feature_dim 1024 --vqa_data_size_per_epoch -1 --pub_data_num 50000 --agg_method con_w --contrast_local_inter --contrast_local_intra --interintra_weight 0.5 --client_init_local_epochs 20 --no_retrieval_training --comm_roucnds 100 --vqa_dropout 0.1 +git pull && python src/vqa.py --name vqaall_d05_0c_nort --server_lr 1e-5 --pretrained_model vqa100kd01_2c50k_last_1024_vqa.pt --seed 0 --feature_dim 1024 --vqa_data_size_per_epoch -1 --pub_data_num 50000 --no_retrieval_training --comm_rounds 100 --vqa_dropout 0.5 --pub_data_num 1 --disable_distill --client_num_per_round 0 --num_img_clients 0 --num_txt_clients 0 --num_mm_clients 0 +git pull && python src/vqa.py --name vqa100k_0c50k --server_lr 1e-5 --pretrained_model vqa100kd01_2c50k_last_1024_vqa.pt --seed 0 --feature_dim 1024 --vqa_data_size_per_epoch 100000 --pub_data_num 50000 --comm_rounds 100 --vqa_dropout 0 --disable_distill --client_num_per_round 0 --num_img_clients 0 --num_txt_clients 0 --num_mm_clients 0 + + + +git pull && python src/vqa.py --name fd1k_fte0_nort_0c --pretrained_model 0c1f_best_1024_vqa.pt --server_lr 1e-5 --seed 0 --feature_dim 1024 --pub_data_num 1 --vqa_full_training_epoch 0 --no_retrieval_training --disable_distill --client_num_per_round 0 --num_img_clients 0 --num_txt_clients 0 --num_mm_clients 0 + +git pull && python src/vqa.py --name vqa_0c50k_nort_d05 --server_lr 1e-5 --pretrained_model vqa_0c1_pre400k_best_1024_vqa.pt --seed 0 --feature_dim 1024 --pub_data_num 50000 --disable_distill --client_num_per_round 0 --num_img_clients 0 --num_txt_clients 0 --num_mm_clients 0 --no_retrieval_training --vqa_dropout 0.5 --comm_rounds 100 +git pull && python src/vqa.py --name vqa_0c50k --server_lr 1e-5 --pretrained_model vqa_0c1_pre400k_best_1024_vqa.pt --seed 0 --feature_dim 1024 --pub_data_num 50000 --disable_distill --client_num_per_round 0 --num_img_clients 0 --num_txt_clients 0 --num_mm_clients 0 --comm_rounds 100 +git pull && python src/vqa.py --name vqa100k_2c10t50k --server_lr 1e-5 --pretrained_model vqa_0c50k_best_1024_vqa.pt --seed 0 --feature_dim 1024 --vqa_data_size_per_epoch 100000 --pub_data_num 50000 --agg_method con_w --contrast_local_inter --contrast_local_intra --interintra_weight 0.5 --client_num_per_round 2 --num_img_clients 0 --num_txt_clients 10 --num_mm_clients 0 --comm_rounds 100 --client_init_local_epochs 10 --local_epochs 10 + +git pull && python src/vqa.py --name vqa_cw1k_0c50k --vqa_cat_weight count+1000 --server_lr 1e-5 --pretrained_model vqa_0c50k_best_1024_vqa.pt --seed 0 --feature_dim 1024 --pub_data_num 50000 --disable_distill --client_num_per_round 0 --num_img_clients 0 --num_txt_clients 0 --num_mm_clients 0 --comm_rounds 100 + +git pull && python src/vqa.py --name vqa_full50k_nort --server_lr 1e-5 --seed 0 --feature_dim 1024 --pub_data_num 50000 --agg_method con_w --contrast_local_inter --contrast_local_intra --interintra_weight 0.5 --no_retrieval_training + + +git pull && python src/vqa.py --name f_all_0c_nort --server_lr 1e-5 --pretrained_model vqa100kd01_2c50k_last_1024_vqa.pt --seed 0 --feature_dim 1024 --vqa_filter_unknown --vqa_data_size_per_epoch -1 --pub_data_num 1 --no_retrieval_training --comm_rounds 10 --vqa_dropout 0 --disable_distill --client_num_per_round 0 --num_img_clients 0 --num_txt_clients 0 --num_mm_clients 0 + + +git pull && python src/vqa.py --name f_100k_0c50k --server_lr 1e-5 --pretrained_model vqa100k_0c50k_best_1024_vqa.pt --seed 0 --feature_dim 1024 --vqa_filter_unknown --vqa_data_size_per_epoch 100000 --pub_data_num 50000 --comm_rounds 100 --vqa_dropout 0 
--disable_distill --client_num_per_round 0 --num_img_clients 0 --num_txt_clients 0 --num_mm_clients 0 +git pull && python src/vqa.py --name f_100k_0c_nort --server_lr 1e-5 --pretrained_model vqa100k_0c50k_best_1024_vqa.pt --seed 0 --feature_dim 1024 --vqa_filter_unknown --vqa_data_size_per_epoch 100000 --pub_data_num 1 --no_retrieval_training --comm_rounds 100 --vqa_dropout 0 --disable_distill --client_num_per_round 0 --num_img_clients 0 --num_txt_clients 0 --num_mm_clients 0 +git pull && python src/vqa.py --name f_100k_0c_nort --server_lr 1e-5 --pretrained_model f_100k_0c_nort_best_1024_vqa.pt --seed 0 --feature_dim 1024 --vqa_filter_unknown --vqa_data_size_per_epoch 100000 --pub_data_num 1 --no_retrieval_training --comm_rounds 100 --vqa_dropout 0 --disable_distill --client_num_per_round 0 --num_img_clients 0 --num_txt_clients 0 --num_mm_clients 0 + + +!git pull && python src/vqa.py --name f0_all_allc50k_nort --server_lr 1e-5 --seed 0 --feature_dim 1024 --vqa_data_size_per_epoch -1 --pub_data_num 50000 --agg_method con_w --contrast_local_inter --contrast_local_intra --interintra_weight 0.5 --no_retrieval_training --comm_rounds 30 +git pull && python src/vqa.py --name f0_all_allc50k_nort_run2 --server_lr 1e-5 --pretrained_model f0_all_allc50k_nort_best_1024_vqa.pt --client_init_local_epochs 5 --seed 0 --feature_dim 1024 --vqa_filter_unknown --vqa_data_size_per_epoch -1 --pub_data_num 50000 --agg_method con_w --contrast_local_inter --contrast_local_intra --interintra_weight 0.5 --no_retrieval_training --comm_rounds 30 + +git pull && python src/vqa.py --name f0_all_0c_nort --server_lr 1e-5 --seed 0 --feature_dim 1024 --vqa_filter_unknown --vqa_data_size_per_epoch -1 --pub_data_num 1 --no_retrieval_training --comm_rounds 30 --vqa_dropout 0 --disable_distill --client_num_per_round 0 --num_img_clients 0 --num_txt_clients 0 --num_mm_clients 0 +git pull && python src/vqa.py --name f1_100k_0c_nort --server_lr 1e-5 --pretrained_model f0_all_0c_nort_best_1024_vqa.pt --seed 0 --feature_dim 1024 --vqa_filter_unknown --vqa_data_size_per_epoch 100000 --pub_data_num 1 --no_retrieval_training --comm_rounds 300 --vqa_dropout 0 --disable_distill --client_num_per_round 0 --num_img_clients 0 --num_txt_clients 0 --num_mm_clients 0 + + +git pull && python src/vqa.py --name f0_100k_0c50k --server_lr 1e-5 --seed 0 --feature_dim 1024 --vqa_filter_unknown --vqa_data_size_per_epoch 100000 --pub_data_num 50000 --comm_rounds 300 --vqa_dropout 0 --disable_distill --client_num_per_round 0 --num_img_clients 0 --num_txt_clients 0 --num_mm_clients 0 +git pull && python src/vqa.py --name f2_100k_0c_nort_from50k --server_lr 1e-5 --pretrained_model f0_100k_0c50k_best_1024_vqa.pt --seed 0 --feature_dim 1024 --vqa_filter_unknown --vqa_data_size_per_epoch 100000 --pub_data_num 1 --no_retrieval_training --comm_rounds 300 --vqa_dropout 0 --disable_distill --client_num_per_round 0 --num_img_clients 0 --num_txt_clients 0 --num_mm_clients 0 + + +git pull && python src/vqa.py --name f3all_0c_nort_pre400k --server_lr 1e-5 --pretrained_model base_0c400k_best_1024_net.pt --seed 0 --feature_dim 1024 --pub_data_num 1 --vqa_filter_unknown --vqa_data_size_per_epoch -1 --disable_distill --client_num_per_round 0 --num_img_clients 0 --num_txt_clients 0 --num_mm_clients 0 --vqa_full_training_epoch 10 --no_retrieval_training --comm_rounds 30 + +git pull && python src/vqa.py --name f3all_0c_nort_pre_raw --server_lr 1e-5 --seed 0 --feature_dim 1024 --pub_data_num 1 --vqa_filter_unknown --vqa_data_size_per_epoch -1 --disable_distill 
--client_num_per_round 0 --num_img_clients 0 --num_txt_clients 0 --num_mm_clients 0 --vqa_full_training_epoch 10 --no_retrieval_training --comm_rounds 30 +git pull && python src/vqa.py --name f3all_allc_nort_pre_raw --server_lr 1e-5 --seed 0 --feature_dim 1024 --vqa_filter_unknown --vqa_data_size_per_epoch -1 --pub_data_num 50000 --agg_method con_w --contrast_local_inter --contrast_local_intra --interintra_weight 0.5 --vqa_full_training_epoch 10 --no_retrieval_training --comm_rounds 30 + + +================================================================================================ +10x1: pretrain retrival for 10 rounds, 0 clients +10x2c: creamfl retrival for 10 rounds, all clients +10x3c: pretrain vqa fussion only, 10 rounds, 0 clients +10x4c: creamfl vqa, 10 rounds, all clients + +10x2b: global only retrival for 10 rounds, 0 clients +10x3b: pretrain vqa fussion only, 10 rounds, 0 clients +10x4b: global only vqa, 10 rounds, 0 clients + + +10x2c234: creamfl retrival for 10 rounds, 2/3/4 clients +10x3c234: pretrain vqa fussion only, 10 rounds, 2/3/4 clients +10x4c234: creamfl vqa, 10 rounds, 2/3/4 clients + +git pull +python src/main.py --name 10x1 --server_lr 1e-5 --seed 0 --feature_dim 1024 --pub_data_num 50000 --comm_rounds 10 --disable_distill --client_num_per_round 0 --num_img_clients 0 --num_txt_clients 0 --num_mm_clients 0 +python src/main.py --name 10x2c --server_lr 1e-5 --pretrained_model 10x1_last_1024_net.pt --seed 0 --feature_dim 1024 --pub_data_num 50000 --comm_rounds 10 --agg_method con_w --contrast_local_inter --contrast_local_intra --interintra_weight 0.5 --client_init_local_epochs 10 +python src/vqa.py --name 10x3c --server_lr 1e-5 --pretrained_model 10x2c_last_1024_net.pt --seed 0 --feature_dim 1024 --pub_data_num 1 --vqa_filter_unknown --vqa_data_size_per_epoch -1 --disable_distill --client_num_per_round 0 --num_img_clients 0 --num_txt_clients 0 --num_mm_clients 0 --vqa_full_training_epoch 99 --no_retrieval_training --comm_rounds 10 +python src/vqa.py --name 10x4c --server_lr 1e-5 --pretrained_model 10x3c_last_1024_vqa.pt --seed 0 --feature_dim 1024 --pub_data_num 50000 --vqa_filter_unknown --vqa_data_size_per_epoch -1 --comm_rounds 10 --agg_method con_w --contrast_local_inter --contrast_local_intra --interintra_weight 0.5 --no_retrieval_training --client_init_local_epochs 10 + +python src/main.py --name 10x2b --server_lr 1e-5 --pretrained_model 10x1_last_1024_net.pt --seed 0 --feature_dim 1024 --pub_data_num 50000 --comm_rounds 10 --disable_distill --client_num_per_round 0 --num_img_clients 0 --num_txt_clients 0 --num_mm_clients 0 +python src/vqa.py --name 10x3b --server_lr 1e-5 --pretrained_model 10x2b_last_1024_net.pt --seed 0 --feature_dim 1024 --pub_data_num 1 --vqa_filter_unknown --vqa_data_size_per_epoch -1 --disable_distill --client_num_per_round 0 --num_img_clients 0 --num_txt_clients 0 --num_mm_clients 0 --vqa_full_training_epoch 99 --no_retrieval_training --comm_rounds 10 +python src/vqa.py --name 10x4b --server_lr 1e-5 --pretrained_model 10x3b_last_1024_vqa.pt --seed 0 --feature_dim 1024 --pub_data_num 1 --vqa_filter_unknown --vqa_data_size_per_epoch -1 --disable_distill --client_num_per_round 0 --num_img_clients 0 --num_txt_clients 0 --num_mm_clients 0 --comm_rounds 10 --no_retrieval_training + +python src/main.py --name 10x2c234 --server_lr 1e-5 --pretrained_model 10x1_last_1024_net.pt --seed 0 --feature_dim 1024 --pub_data_num 50000 --comm_rounds 10 --agg_method con_w --contrast_local_inter --contrast_local_intra --interintra_weight 0.5 
--client_init_local_epochs 10 --client_num_per_round 4 --num_img_clients 2 --num_txt_clients 2 --num_mm_clients 3 +python src/vqa.py --name 10x3c234 --server_lr 1e-5 --pretrained_model 10x2c234_last_1024_net.pt --seed 0 --feature_dim 1024 --pub_data_num 1 --vqa_filter_unknown --vqa_data_size_per_epoch -1 --disable_distill --client_num_per_round 0 --num_img_clients 0 --num_txt_clients 0 --num_mm_clients 0 --vqa_full_training_epoch 99 --no_retrieval_training --comm_rounds 10 +python src/vqa.py --name 10x4c234 --server_lr 1e-5 --pretrained_model 10x3c234_last_1024_vqa.pt --seed 0 --feature_dim 1024 --pub_data_num 50000 --vqa_filter_unknown --vqa_data_size_per_epoch -1 --comm_rounds 10 --agg_method con_w --contrast_local_inter --contrast_local_intra --interintra_weight 0.5 --no_retrieval_training --client_init_local_epochs 10 --client_num_per_round 4 --num_img_clients 2 --num_txt_clients 2 --num_mm_clients 3 + + +python src/main.py --name 30x1 --server_lr 1e-5 --seed 0 --feature_dim 1024 --pub_data_num 50000 --comm_rounds 30 --disable_distill --client_num_per_round 0 --num_img_clients 0 --num_txt_clients 0 --num_mm_clients 0 +python src/main.py --name 30x2c237 --server_lr 1e-5 --pretrained_model 30x1_last_1024_net.pt --seed 0 --feature_dim 1024 --pub_data_num 50000 --comm_rounds 30 --agg_method con_w --contrast_local_inter --contrast_local_intra --interintra_weight 0.5 --client_init_local_epochs 20 --client_num_per_round 7 --num_img_clients 2 --num_txt_clients 2 --num_mm_clients 3 +python src/vqa.py --name 30x3c237 --server_lr 1e-5 --pretrained_model 30x2c237_last_1024_net.pt --seed 0 --feature_dim 1024 --pub_data_num 1 --vqa_filter_unknown --vqa_data_size_per_epoch -1 --disable_distill --client_num_per_round 0 --num_img_clients 0 --num_txt_clients 0 --num_mm_clients 0 --vqa_full_training_epoch 99 --no_retrieval_training --comm_rounds 30 +python src/vqa.py --name 30x4c237 --server_lr 1e-5 --pretrained_model 30x3c237_last_1024_vqa.pt --seed 0 --feature_dim 1024 --pub_data_num 50000 --vqa_filter_unknown --vqa_data_size_per_epoch -1 --comm_rounds 30 --agg_method con_w --contrast_local_inter --contrast_local_intra --interintra_weight 0.5 --no_retrieval_training --client_init_local_epochs 20 --client_num_per_round 7 --num_img_clients 2 --num_txt_clients 2 --num_mm_clients 3 \ No newline at end of file diff --git a/report/poc-r1x10.csv b/report/poc-r1x10.csv new file mode 100644 index 0000000..372fa50 --- /dev/null +++ b/report/poc-r1x10.csv @@ -0,0 +1,6 @@ +"Name","Runtime","ID","Server i2t_r1","Server n_fold_i2t_r1","Server n_fold_t2i_r1","Server rsum_r1","Server t2i_r1" +"poc-r1x10-global","21304","9eebn4qd","1.82","6.860000000000001","6.587999999999999","17.060000000000002","1.792" +"poc-r1x10-global","8255","goljw32i","1.76","7.720000000000001","6.587999999999999","17.848000000000003","1.78" +"poc-r1x10","8054","3u0urus7","1.76","6.660000000000001","6.256","16.408","1.732" +"poc-r1x10","7334","bkiu5ffi","2.12","8.120000000000001","7.803999999999999","20.084","2.04" +"poc-r1x10","7101","0g6ia4ds","1.44","6.58","6.452","16.192","1.72" \ No newline at end of file diff --git a/report/poc-r1x10.png b/report/poc-r1x10.png new file mode 100644 index 0000000..8206b07 Binary files /dev/null and b/report/poc-r1x10.png differ diff --git a/report/poc.bib b/report/poc.bib new file mode 100644 index 0000000..44c7893 --- /dev/null +++ b/report/poc.bib @@ -0,0 +1,6 @@ +@article{yu2023multimodal, + title={Multimodal Federated Learning via Contrastive Representation Ensemble}, + author={Yu, Qiying and Liu, Yang 
and Wang, Yimu and Xu, Ke and Liu, Jingjing},
+  journal={arXiv preprint arXiv:2302.08888},
+  year={2023}
+}
\ No newline at end of file
diff --git a/report/poc.pdf b/report/poc.pdf
new file mode 100644
index 0000000..e6c602a
Binary files /dev/null and b/report/poc.pdf differ
diff --git a/report/poc.tex b/report/poc.tex
new file mode 100644
index 0000000..de7e8d5
--- /dev/null
+++ b/report/poc.tex
@@ -0,0 +1,104 @@
+\documentclass{article}
+
+\usepackage{graphicx}
+\usepackage{rotating} % support sidewaystable
+\usepackage{listings}
+\lstset{
+    breaklines=true,
+    basicstyle=\ttfamily,
+}
+
+\title{Experimental Report: A Proof of Concept Implementation of Networking for CreamFL}
+\author{Xie Yu Guang}
+\date{\today}
+
+\begin{document}
+
+\maketitle
+
+\section{Introduction}
+The purpose of this report is to present the results of modifying CreamFL\cite{yu2023multimodal} to run in a distributed fashion. The report describes replicating the results of running CreamFL under centralized and distributed execution. The goal is to prove that the refactored code supports running CreamFL in a federated learning environment.
+
+\section{Experimental Setup}
+
+\subsection{Hardware}
+Because of limited resources, the experiment was conducted on a single machine with 6 cores, 16GB of RAM and a single NVIDIA GeForce GTX 1050 Ti. As such, fully replicating the results of the original paper is not possible. Instead, we run smaller experiments that require less communication/memory and less compute power.
+
+\subsection{Software modifications}
+To allow fast development cycles, we first modified the code base to make it easy to limit the client training data size by adding a \texttt{--max\_size} flag.
+
+A bug was also fixed where training crashed when a type of client was configured but not selected for training in a round; this bug was more prominent with few clients.
+
+The random initialization process has been fixed so that identical runs can use the predefined random seed set by the flag. A debugging line was left in the original code that always set the seed to 2021. Additionally, the code can now operate with a random seed: if the seed is set to 0, the system defaults to a seed derived from the current time. In this case, the \texttt{cudnn.deterministic} attribute is set to \texttt{False}, and \texttt{cudnn.benchmark} is set to \texttt{True} as a performance improvement.
+
+This setup of seeding from the current time is used in all the experiments in this report, because the path of execution differs between the centralized and distributed runs. Instead of comparing the results of two runs for equality, which is not expected to work, we can only compare the result distributions of the two setups by running each setup multiple times.
+
+\subsection{Distributed Architecture}
+The distributed architecture is split into three main components: the state server, the clients, and the global model computation provider. The server is responsible for sending the learned features to the clients and collecting the results. The clients are responsible for client-side training and for sending the results back to the server. The global model computation provider is responsible for computing the global model from the results of the clients, computing the learned features, and evaluating the global model.
+
+We split the server into a state reporting server and a model computation provider.
This is mainly because the single-threaded nature of Python would otherwise make the server unresponsive while the model computation is running.
+
+The clients and the global model computation provider communicate with the server and with each other only over HTTP and file I/O. The HTTP server uses JSON, and clients wait for updates by polling the server at regular intervals. Features are shared between the clients and the global model computation provider using files and are never uploaded to the server. However, the server distributes hashes of the features that are used to locate the feature files. This design can easily be extended to use a CDN.
+
+\subsection{Client Training Data}
+To keep the experiment as close as possible to the original code base, the client training data is generated using the same method as the original code base, including how the data is loaded onto the client. Both the centralized and distributed clients use the same data splitting indexes when testing under the same parameters.
+
+In a real-world scenario, each client would have its own data. We skipped this step to simplify the experimental setup, at a slight cost to client start-up time, where all data is loaded and only the indexed data is kept. We expect different real-world scenarios to require different data preparation and loading methods; therefore, building one just for this experiment is not worth the cost.
+
+
+\section{Results and Analysis}
+% Present the results obtained from the experiment and analyze them. Discuss the task convergence, model effect, and any other observations or insights.
+
+\begin{figure}[ht]
+    \centering
+    \includegraphics[width=0.8\textwidth]{poc-r1x10.png}
+    \caption{Training with 1 client over 10 communication rounds. "poc-r1x10" are centralized runs and "poc-r1x10-global" are distributed runs.}
+    \label{fig:r1x10}
+\end{figure}

+In Figure \ref{fig:r1x10}, the parameters used were:
+\begin{lstlisting}
+    --contrast_local_inter --contrast_local_intra
+    --interintra_weight 0.5 --max_size 50000
+    --pub_data_num 4000 --feature_dim 64
+    --num_img_clients 0 --num_txt_clients 1
+    --num_mm_clients 0 --client_num_per_round 1
+    --local_epochs 5 --comm_rounds 10 --not_bert
+    --seed 0
+\end{lstlisting}
+
+This configuration was run 3 times centralized and 2 times distributed to compare the two. The graph shows that model performance is highly varied in the first few communication rounds, but converges to similar results at the end. The end results are shown in Table \ref{table:r1x10}.
+
+\begin{sidewaystable}
+    \centering
+    \begin{tabular}{|l|l|l|l|l|l|l|l|}
+        \hline
+        Name & Runtime & ID & i2t\_r1 & n\_fold\_i2t\_r1 & n\_fold\_t2i\_r1 & rsum\_r1 & t2i\_r1 \\
+        \hline
+        poc-r1x10-global & 5.9h\footnote{computer went to sleep} & 9eebn4qd & 1.82 & 6.86 & 6.588 & 17.06 & 1.792 \\
+        poc-r1x10-global & 2.3h & goljw32i & 1.76 & 7.72 & 6.588 & 17.848 & 1.78 \\
+        poc-r1x10 & 2.2h & 3u0urus7 & 1.76 & 6.66 & 6.256 & 16.408 & 1.732 \\
+        poc-r1x10 & 2.0h & bkiu5ffi & 2.12 & 8.12 & 7.804 & 20.084 & 2.04 \\
+        poc-r1x10 & 2.0h & 0g6ia4ds & 1.44 & 6.58 & 6.452 & 16.192 & 1.72 \\
+        \hline
+    \end{tabular}
+    \caption{Comparison of centralized vs distributed execution with 1 client over 10 communication rounds. The "poc-r1x10" rows are centralized runs and the "poc-r1x10-global" rows are distributed runs. The runtime is the time taken to complete the experiment. The ID is the experiment ID, which can be used to look up the experiment details.
The rest of the columns are the results of the experiment.}
+    \label{table:r1x10}
+\end{sidewaystable}
+
+
+\section{Future Work}
+
+There are many possible improvements that can be made, both to the code base and to the experimental setup. Being able to fully replicate the results of the original paper would be ideal.
+
+However, in the interest of usability and long-term support, perhaps we should also turn to another direction --- wrapping CreamFL as a component of an established federated learning framework, such as FATE or FedML. This would save us from reinventing the wheel and allow existing users of these frameworks to easily integrate CreamFL as a component of their federated learning system.
+
+\section{Conclusion}
+% Summarize the findings of the experiment and draw conclusions about the effectiveness of the distributed framework.
+
+The results of the experiment show that the distributed framework is capable of running CreamFL. However, due to the limited resources, the results are not as conclusive as we would like.
+
+\bibliography{poc}
+\bibliographystyle{plain}
+
+\end{document}
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
index fa8f408..fedf847 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,17 +1,21 @@
-adamp==0.3.0
+transformers>=4.37.1
+portalocker>=2.0.0
+wheel>=0.42.0
+fire>=0.5.0
+adamp>=0.3.0
 apex==0.1
-matplotlib==3.5.0
-munch==2.5.0
-nltk==3.7
-opencv-python==4.5.4.58
-pandas==1.3.4
-pycocotools==2.0.4
-scipy==1.7.2
-seaborn==0.11.2
-sentencepiece==0.1.96
-sklearn==0.0
-torch==1.10.0+cu113
-torchtext==0.11.0
-torchvision==0.11.1+cu113
-tqdm==4.62.3
-wandb==0.13.1
+matplotlib>=3.5.0
+munch>=2.5.0
+nltk>=3.7
+opencv-python>=4.5.4.58
+pandas>=1.3.4
+pycocotools>=2.0.4
+scipy>=1.7.2
+seaborn>=0.11.2
+sentencepiece>=0.1.96
+scikit-learn>=0.0
+torch>=1.10.0+cu113
+torchtext>=0.11.0
+torchvision>=0.11.1+cu113
+tqdm>=4.62.3
+wandb>=0.13.1
diff --git a/retri_center.sh b/retri_center.sh
new file mode 100644
index 0000000..1793d84
--- /dev/null
+++ b/retri_center.sh
@@ -0,0 +1,4 @@
+export HF_ENDPOINT=https://hf-mirror.com
+export HF_DATASETS_CACHE="/shared/.cache/huggingface/datasets"
+
+nohup python src/retrivel.py --name retri_center --server_lr 1e-5 --seed 0 --feature_dim 256 --pub_data_num 50000 --agg_method con_w --contrast_local_inter --contrast_local_intra --interintra_weight 0.5 --local_epochs 5 --client_num_per_round 3 --num_img_clients 1 --num_txt_clients 1 --num_mm_clients 1 > retri_center.log 2>&1 &
\ No newline at end of file
diff --git a/src/accuracy.txt b/src/accuracy.txt
new file mode 100644
index 0000000..21a86c8
--- /dev/null
+++ b/src/accuracy.txt
@@ -0,0 +1,20 @@
+20:25.000,100.000
+40:25.000,100.000
+60:25.000,100.000
+80:25.000,100.000
+100:24.579,100.000
+120:25.000,100.000
+140:25.000,100.000
+160:25.000,100.000
+180:25.408,100.000
+200:25.000,100.000
+220:25.224,100.000
+240:25.000,100.000
+260:25.000,100.000
+280:25.197,100.000
+300:25.013,100.000
+320:25.026,100.000
+340:25.605,100.000
+360:25.000,100.000
+380:25.145,100.000
+400:25.013,100.000
diff --git a/src/algorithms/ClientTrainer.py b/src/algorithms/ClientTrainer.py
index 34e6d8a..b510cb9 100644
--- a/src/algorithms/ClientTrainer.py
+++ b/src/algorithms/ClientTrainer.py
@@ -10,8 +10,8 @@ from sklearn.metrics import pairwise_distances
 
 from src import losses
-from src.datasets.cifar import Cifar
-from src.datasets.dataset_L import caption_collate_fn, Language
+from src.custom_datasets.cifar import Cifar
+from src.custom_datasets.dataset_L import 
caption_collate_fn, Language from src.networks.language_model import EncoderText from src.networks.resnet_client import resnet18_client from src.utils.Reader import ImageReader @@ -33,6 +33,7 @@ def seed_torch(seed=2021): + print(f'ClientTrainer.seed_torch called seed={seed}') random.seed(seed) os.environ['PYTHONHASHSEED'] = str(seed) np.random.seed(seed) @@ -136,8 +137,8 @@ def accuracy(output, target, topk=(1,)): class ClientTrainer: def __init__(self, args, dataset, dst, RGBmean, RGBstdv, data_dict, logger, global_test_set, inter_distance=4, loss='softmax', gpuid='cuda:0', num_epochs=30, init_lr=0.0001, decay=0.1, batch_size=512, - imgsize=256, num_workers=4, print_freq=10, save_step=10, scale=128, pool_type='max_avg', client_id=-1, wandb=None): - seed_torch() + imgsize=256, num_workers=12, print_freq=10, save_step=10, scale=128, pool_type='max_avg', client_id=-1, wandb=None): + # seed_torch() self.args = args if dataset == 'Flickr30k': init_lr = 0.0002 @@ -197,6 +198,11 @@ def run(self, global_img_feature, global_txt_feature, distill_index, global_trai self.old_model.cuda() self.lr_scheduler(self.cur_epoch) + + if self.local_epoch == 0: + for i in range(self.args.client_init_local_epochs): + self.local_epoch += 1 + self.tra(global_img_feature, global_txt_feature, distill_index, global_train_loader) for i in range(self.local_epochs): self.local_epoch += 1 @@ -264,7 +270,7 @@ def loadData(self): self.classSize = len(self.data_dict) assert False, 'Dataset Not Supported!' self.class_label = torch.Tensor(np.array(range(self.classSize))) - print('output size: {}'.format(self.classSize)) + print(f'ClientTrainer loadData dataset name: {self.dset_name} class size: {self.classSize}') return @@ -319,7 +325,7 @@ def printnreset(name): # Set model to training mode self.model.train() - for i, data in enumerate(self.train_loader): + for i, data in tqdm(enumerate(self.train_loader), total=len(self.train_loader)): self.optimizer.zero_grad() with torch.set_grad_enabled(True): center_labels_var = torch.autograd.Variable(self.class_label.to(torch.long)).to(self.gpuid) @@ -548,6 +554,9 @@ def printnreset(name): self.test_top1.update(prec1[0], inputs_bt.size(0)) self.test_top5.update(prec5[0], inputs_bt.size(0)) + current_path = os.path.dirname(os.path.dirname(__file__)) + with open(os.path.join(current_path, 'accuracy.txt'), 'a') as f: + f.write(f'{self.local_epoch}:{self.test_top1.avg:.3f},{self.test_top5.avg:.3f}\n') printnreset(self.dset_name) self.model.train() @@ -670,5 +679,5 @@ def to_half(self): opt_level='O2') def __getattr__(self, k): - if k.startwith("__"): + if k.startswith("__"): raise AttributeError diff --git a/src/algorithms/MMClientTrainer.py b/src/algorithms/MMClientTrainer.py index 00abdc0..5c6d570 100644 --- a/src/algorithms/MMClientTrainer.py +++ b/src/algorithms/MMClientTrainer.py @@ -15,7 +15,7 @@ import torch.nn as nn -from src.algorithms.base import EngineBase +from src.algorithms.base import EngineBase2 from tqdm import tqdm import torch @@ -28,6 +28,7 @@ def seed_torch(seed=2021): + print(f'MMClientTrainer.seed_torch called seed={seed}') random.seed(seed) os.environ['PYTHONHASHSEED'] = str(seed) np.random.seed(seed) @@ -86,8 +87,8 @@ def update(self, val, n=1): is_test = False -class MMClientTrainer(EngineBase): - +class MMClientTrainer(EngineBase2): + def run(self, global_img_feature, global_txt_feature, distill_index, global_train_loader, prefix=''): self.old_model = copy.deepcopy(self.model) self.old_model.eval().cuda() @@ -96,6 +97,11 @@ def run(self, global_img_feature, 
global_txt_feature, distill_index, global_trai self.model, self.optimizer = amp.initialize(self.model, self.optimizer, opt_level='O2') self.model.train() + + if self.local_epoch == 0: + for i in range(self.args.client_init_local_epochs): + self.local_epoch += 1 + self.train_epoch(global_img_feature, global_txt_feature, distill_index, global_train_loader, prefix='') for i in range(self.local_epochs): self.local_epoch += 1 diff --git a/src/algorithms/MMFL.py b/src/algorithms/MMFL.py index 4c26c8c..d64ed37 100644 --- a/src/algorithms/MMFL.py +++ b/src/algorithms/MMFL.py @@ -8,24 +8,28 @@ import torch import torch.nn as nn -from torch.distributions import Categorical +import datasets from tqdm import tqdm +import wandb sys.path.append("./") sys.path.append("../") sys.path.append("../../") sys.path.append("../../../") -from src.datasets.load_FL_datasets import get_FL_trainloader, get_dataloader +from src.custom_datasets.load_FL_datasets import get_FL_trainloader from src.algorithms.ClientTrainer import ClientTrainer from src.algorithms.MMClientTrainer import MMClientTrainer from src.utils.color_lib import RGBmean, RGBstdv from src.algorithms.eval_coco import COCOEvaluator -from src.algorithms.retrieval_trainer import TrainerEngine, rawTrainerEngine +from src.algorithms.retrieval_trainer import TrainerEngine +from src.algorithms.vqa_meta import VQAMetaData +from src.algorithms.vqa_trainer import VQAEngine, vqa_validation from src.utils.config import parse_config -from src.utils.load_datasets import prepare_coco_dataloaders +from src.utils.load_datasets import prepare_coco_dataloaders, vqa2_dataloader from src.utils.logger import PythonLogger +from src.utils.util import print_model_tree try: from apex import amp @@ -37,7 +41,7 @@ class MMFL(object): - def __init__(self, args, wandb=None): + def __init__(self, args, wandb:wandb): self.args = args self.wandb = wandb @@ -46,6 +50,7 @@ def __init__(self, args, wandb=None): self.txt_local_trainers = None self.mm_local_trainers = None self.engine = None + self.vqa_engine = None self.best_score = 0 self.cur_epoch = 0 @@ -54,8 +59,11 @@ def __init__(self, args, wandb=None): # coco global dataloaders self.dataloaders_global = None + self.vqa_dataloader = None + self.vqa_meta = None # universal test dataloader self.test_loader = None + self.vqa_test_loader = None self.config = None self.set_config() @@ -87,22 +95,50 @@ def set_config(self, img='cifa100', txt='AG_NEWS'): self.config.model.not_bert = False self.config.model.cnn_type = 'resnet101' - def load_dataset(self, args): + def load_dataset(self, args, is_vqa=False): dataset_root = os.environ['HOME'] + '/data/mmdata/MSCOCO/2014' - vocab_path = './src/datasets/vocabs/coco_vocab.pkl' - self.dataloaders_global, self.vocab = prepare_coco_dataloaders(self.config.dataloader, dataset_root, vocab_path) + vocab_path = './src/custom_datasets/vocabs/coco_vocab.pkl' + self.dataloaders_global, self.vocab = prepare_coco_dataloaders(self.config.dataloader, dataset_root, args.pub_data_num, args.max_size, vocab_path) self.engine = TrainerEngine() self.engine.set_logger(self.logger) + + if is_vqa: + self.vqa_engine = VQAEngine(args,self.engine, self.wandb) + self.config.vqa_dropout = self.args.vqa_dropout self.config.optimizer.learning_rate = self.args.server_lr self._dataloaders = self.dataloaders_global.copy() self.evaluator = COCOEvaluator(eval_method='matmul', - verbose=False, + verbose=True, eval_device='cuda', n_crossfolds=5) - self.engine.create(self.config, self.vocab.word2idx, self.evaluator, 
self.args.mlp_local) + if is_vqa: + vqa_dataset = datasets.load_dataset("HuggingFaceM4/VQAv2", split="train") + meta = VQAMetaData() + meta.build_or_load_categories_top() + self.vqa_meta = meta + self.vqa_dataloader = vqa2_dataloader(vqa_dataset, train=True, filter_unknown=args.vqa_filter_unknown, meta=meta) + test_dataset = datasets.load_dataset("HuggingFaceM4/VQAv2", split="validation") + self.vqa_test_loader = vqa2_dataloader(test_dataset, filter_unknown=args.vqa_filter_unknown, meta=meta) + self.vqa_engine.create(self.config, self.vocab.word2idx, self.evaluator, self.args.mlp_local, meta) + #print_model_tree(self.vqa_engine.fusion_model) + if args.pretrained_model.endswith('_vqa.pt'): + print(f"Loading pretrained model as VQAEngine {args.pretrained_model}") + checkpoint = torch.load(args.pretrained_model) + self.vqa_engine.fusion_model.load_state_dict(checkpoint['vqa']) + self.best_score = getattr(checkpoint, 'score', self.best_score) + else: + self.engine.create(self.config, self.vocab.word2idx, self.evaluator, self.args.mlp_local) + if args.pretrained_model.endswith('_net.pt'): + print(f"Loading pretrained model as TrainerEngine {args.pretrained_model}") + checkpoint = torch.load(args.pretrained_model) + self.engine.model.load_state_dict(checkpoint['net']) + if not is_vqa: + self.best_score = getattr(checkpoint, 'score', self.best_score) + + #print_model_tree(self.engine.model) self.train_eval_dataloader = self._dataloaders.pop( 'train_subset_eval' + f'_{self.args.pub_data_num}') if self._dataloaders is not None else None @@ -111,13 +147,21 @@ def load_dataset(self, args): torch.backends.cudnn.enabled = True if self.config.train.get('use_fp16'): self.engine.logger.log('Train with half precision') - self.engine.to_half() + if is_vqa: + self.vqa_engine.to_half() + else: + self.engine.to_half() + def create_model(self, args): self.logger.log('start creating model and partition datasets') self.device = torch.device("cuda:%d" % args.device) os.makedirs(os.environ['HOME'] + f'/data/yClient', exist_ok=True) + + alpha = args.alpha # was hard-coded to 0.1 + batch_size = args.batch_size # was hard-coded to 512 + max_size = args.max_size # introduced by xiegeo # Create Client Models self.img_local_trainers, self.txt_local_trainers, self.mm_local_trainers = [], [], [] @@ -125,7 +169,7 @@ def create_model(self, args): if args.num_img_clients > 0: dataset = 'cifar100' self.img_trainloaders, test_set = get_FL_trainloader(dataset, os.environ['HOME'] + "/data/cifar100", - args.num_img_clients, "hetero", 0.1, 512) + args.num_img_clients, "hetero", alpha, batch_size, max_size) dataset = 'Cifar100' dst = os.environ['HOME'] + f'/data/yClient/{dataset}' self.img_local_trainers = [] @@ -140,7 +184,7 @@ def create_model(self, args): if args.num_txt_clients > 0: dataset = 'AG_NEWS' self.txt_trainloaders, test_set = get_FL_trainloader(dataset, os.environ['HOME'] + "/data", - args.num_txt_clients, "hetero", 0.1, 512) + args.num_txt_clients, "hetero", alpha, batch_size, max_size) client_id = 1 dst = os.environ['HOME'] + f'/data/yClient/{dataset}-{client_id}' self.txt_local_trainers = [] @@ -166,11 +210,11 @@ def create_model(self, args): self.mm_local_trainers.append( MMClientTrainer(args, config, self.logger, client=client_id, dset_name="flicker30k", device='cuda', - vocab_path='./src/datasets/vocabs/coco_vocab.pkl', + vocab_path='./src/custom_datasets/vocabs/coco_vocab.pkl', mlp_local=self.args.mlp_local)) if is_test and client_id == 0: break - print(f"Samples Num: {[len(i.train_loader.dataset) for i in 
self.mm_local_trainers]}") + print(f"MM Clients Samples Num: {[len(i.train_loader.dataset) for i in self.mm_local_trainers]}") self.total_local_trainers = self.img_local_trainers + self.txt_local_trainers + self.mm_local_trainers @@ -181,21 +225,25 @@ def train(self, round_n): self.cur_epoch = round_n self.cur_trainers = self.total_local_trainers + + self.logger.log(f"Round {round_n + 1}!") - if not is_test: + if not is_test and not self.args.no_retrieval_training: # global training - self.logger.log(f"Round {round_n + 1}!") self.engine.train( tr_loader=self._dataloaders['train_subset' + f'_{self.args.pub_data_num}']) # global train - if len(self.total_local_trainers) != 0: - self.cur_trainers = random.sample(self.total_local_trainers, self.args.client_num_per_round) + if len(self.total_local_trainers) != 0: + self.cur_trainers = random.sample(self.total_local_trainers, self.args.client_num_per_round) # global representations - if self.args.agg_method == "con_w" or self.args.contrast_local_intra or self.args.contrast_local_inter: + if len(self.cur_trainers) == 0: + print("No clients to train, skipping global representations") + elif self.args.agg_method == "con_w" or self.args.contrast_local_intra or self.args.contrast_local_inter: img_feature, txt_feature = [], [] distill_index = [] for idx, (images, captions, captions_word, caption_lens, _, _, index) in tqdm( enumerate(self.dataloaders_global['train_subset_eval' + f'_{self.args.pub_data_num}']), + desc="Global Representations", total=len(self.dataloaders_global['train_subset_eval' + f'_{self.args.pub_data_num}'])): with torch.no_grad(): images = images.to(self.engine.device) # [bs, 3, 224, 224] @@ -219,6 +267,8 @@ def train(self, round_n): self.distill_index = distill_index del img_feature, txt_feature gc.collect() + else: + print("No agg_method or contrast, skipping global representations") # local training and generated representations img_vec, img_num = [], [] @@ -254,70 +304,113 @@ def get_lr(optimizer): for param_group in optimizer.param_groups: return param_group['lr'] + # assert round_n + 1 == self.cur_epoch, "inconstant round_n vs cur_epoch, added to check that code clean up does not change logic." 
+ # record after each epoch training metadata = self.engine.metadata.copy() metadata['cur_epoch'] = round_n + 1 metadata['lr'] = get_lr(self.engine.optimizer) - + + score = 0 + test_scores = self.engine.evaluate({'test': self._dataloaders['test']}) self.engine.report_scores(step=round_n + 1, - scores=test_scores, - metadata=metadata, - prefix=self.engine.eval_prefix) + scores=test_scores, + metadata=metadata, + prefix=self.engine.eval_prefix) rsum = test_scores['test']['n_fold']['i2t']['recall_1'] + test_scores['test']['n_fold']['t2i']['recall_1'] + \ - test_scores['test']['i2t']['recall_1'] + test_scores['test']['t2i']['recall_1'] + test_scores['test']['i2t']['recall_1'] + test_scores['test']['t2i']['recall_1'] self.wandb.log({"Server rsum_r1": rsum}, step=self.cur_epoch) + self.wandb.log({"Server rsum": test_scores['test']['rsum']}, step=self.cur_epoch) self.wandb.log({"Server n_fold_i2t_r1": test_scores['test']['n_fold']['i2t']['recall_1']}, step=self.cur_epoch) self.wandb.log({"Server n_fold_t2i_r1": test_scores['test']['n_fold']['t2i']['recall_1']}, step=self.cur_epoch) self.wandb.log({"Server i2t_r1": test_scores['test']['i2t']['recall_1']}, step=self.cur_epoch) self.wandb.log({"Server t2i_r1": test_scores['test']['t2i']['recall_1']}, step=self.cur_epoch) - - if self.best_score < rsum: - best_score = rsum + score = rsum + + if self.vqa_engine is not None: + test_loader = None + if round_n == 0: + test_loader = self.vqa_test_loader # only test during training in the first round + self.vqa_engine.train_vqa(self.cur_epoch, self.vqa_dataloader, vqa2_test_dataloader=test_loader) + test_scores = vqa_validation(10000, self.vqa_engine.fusion_model, self.vqa_meta, self.vqa_test_loader) + #test_scores = vqa_validation(100000, self.vqa_engine.fusion_model, self.vqa_meta, self.vqa_test_loader) + self.wandb.log(test_scores, step=self.cur_epoch) + score = test_scores['accuracy'] + + def save_model(type_name, score=score): + prefix = f'{self.args.name}_{type_name}_{self.args.feature_dim}' + if self.vqa_engine is not None: + torch.save({'vqa': self.vqa_engine.fusion_model.state_dict(), + 'score':score}, f'{prefix}_vqa.pt') + else: + torch.save({'net': self.engine.model.state_dict(), + 'score':score}, f'{prefix}_net.pt') + + + if self.best_score < score: + best_score = score metadata['best_score'] = best_score metadata['best_epoch'] = round_n + 1 self.best_metadata, self.best_scores = metadata, test_scores - - torch.save({'net': self.engine.model.state_dict()}, self.args.name + '-best_model.pt') + save_model("best") if round_n == self.args.comm_rounds - 1: - torch.save({'net': self.engine.model.state_dict()}, self.args.name + '-last_model.pt') + save_model("last") self.engine.lr_scheduler.step() + if self.vqa_engine is not None: + self.vqa_engine.vqa_lr_scheduler.step() del img_vec, txt_vec gc.collect() def distill(self, round_n, img_vec, txt_vec, img_num, txt_num, distill_index): + + if len(img_vec) == 0 and len(txt_vec) == 0: + print("No img_vec and txt_vec to distill (no clients)") + return self.engine.model.train() if self.config.model.use_img_client or self.config.model.use_txt_client or self.config.model.use_mm_client: client_loss_cri = nn.MSELoss() - def aggregation(i_vec=img_vec, t_vec=txt_vec, i_num=img_num, t_num=txt_num): + def aggregation(i_vec=img_vec, t_vec=txt_vec): if self.args.agg_method == "con_w": - contrastive_w = [] - for vec in i_vec: # vec: [50000, n_feature], global_txt_feature: [50000, n_feature] - logits = torch.matmul(vec, self.global_txt_feature.T) # [50000, 50000] - 
exp_logits = torch.exp(logits) - log_prob = logits - torch.log(torch.sum(exp_logits, dim=1, keepdim=True)) - contrastive_w.append(torch.diagonal(log_prob).reshape(1, -1)) - contrastive_w = torch.softmax(torch.cat(contrastive_w, dim=0), dim=0) - for i in range(len(i_vec)): - i_vec[i] = (i_vec[i] * contrastive_w[i].reshape(-1, 1)).unsqueeze(0) - i_vec = torch.sum(torch.cat(i_vec, dim=0), dim=0) # aggregated image vectors - - contrastive_w = [] - for vec in t_vec: # vec: [50000, n_feature], global_txt_feature: [50000, n_feature] - logits = torch.matmul(vec, self.global_img_feature.T) # [50000, 50000] - exp_logits = torch.exp(logits) - log_prob = logits - torch.log(torch.sum(exp_logits, dim=1, keepdim=True)) - contrastive_w.append(torch.diagonal(log_prob).reshape(1, -1)) - contrastive_w = torch.softmax(torch.cat(contrastive_w, dim=0), dim=0) - for i in range(len(t_vec)): - t_vec[i] = (t_vec[i] * contrastive_w[i].reshape(-1, 1)).unsqueeze(0) - t_vec = torch.sum(torch.cat(t_vec, dim=0), dim=0) # aggregated text vectors + if not i_vec: + self.logger.log("distill.aggregation i_vec is empty") + else: + contrastive_w = [] + for vec in i_vec: # vec: [50000, n_feature], global_txt_feature: [50000, n_feature] + logits = torch.matmul(vec, self.global_txt_feature.T) # [50000, 50000] + exp_logits = torch.exp(logits) + log_prob = logits - torch.log(torch.sum(exp_logits, dim=1, keepdim=True)) + contrastive_w.append(torch.diagonal(log_prob).reshape(1, -1)) + if not contrastive_w: + self.logger.log("distill.aggregation No tensors were added to contrastive_w for images") + else: + contrastive_w = torch.softmax(torch.cat(contrastive_w, dim=0), dim=0) + for i in range(len(i_vec)): + i_vec[i] = (i_vec[i] * contrastive_w[i].reshape(-1, 1)).unsqueeze(0) + i_vec = torch.sum(torch.cat(i_vec, dim=0), dim=0) # aggregated image vectors + + if not t_vec: + self.logger.log("distill.aggregation t_vec is empty") + else: + contrastive_w = [] + for vec in t_vec: # vec: [50000, n_feature], global_txt_feature: [50000, n_feature] + logits = torch.matmul(vec, self.global_img_feature.T) # [50000, 50000] + exp_logits = torch.exp(logits) + log_prob = logits - torch.log(torch.sum(exp_logits, dim=1, keepdim=True)) + contrastive_w.append(torch.diagonal(log_prob).reshape(1, -1)) + if not contrastive_w: + self.logger.log("distill.aggregation No tensors were added to contrastive_w for texts") + else: + contrastive_w = torch.softmax(torch.cat(contrastive_w, dim=0), dim=0) + for i in range(len(t_vec)): + t_vec[i] = (t_vec[i] * contrastive_w[i].reshape(-1, 1)).unsqueeze(0) + t_vec = torch.sum(torch.cat(t_vec, dim=0), dim=0) # aggregated text vectors else: raise NotImplementedError @@ -347,17 +440,17 @@ def code_sim(output, target, config): return client_loss_cri(output, target.type_as(output)) - if self.args.num_img_clients > 0: + if self.args.num_img_clients > 0 and len(img_num)> 0: out_img = output['image_features'] d_idx = operator.itemgetter(*index)(distill_dict) # idx of the current batch target_img = self.img_vec[d_idx, :].type_as(out_img) loss += self.args.kd_weight * code_sim(out_img, target_img, self.config) - if self.args.num_txt_clients > 0: + if self.args.num_txt_clients > 0 and len(txt_num) > 0: out_txt = output['caption_features'] d_idx = operator.itemgetter(*index)(distill_dict) # idx of the current batch target_txt = self.txt_vec[d_idx, :].type_as(out_txt) loss += self.args.kd_weight * code_sim(out_txt, target_txt, self.config) - if self.args.num_mm_clients > 0: + if self.args.num_mm_clients > 0 and len(img_num) > 0 and 
len(txt_num) > 0: out_img = output['image_features'] d_idx = operator.itemgetter(*index)(distill_dict) # idx of the current batch target_img = self.img_vec[d_idx, :].type_as(out_img) diff --git a/src/algorithms/MMFL_paral.py b/src/algorithms/MMFL_paral.py new file mode 100644 index 0000000..7a5bfb8 --- /dev/null +++ b/src/algorithms/MMFL_paral.py @@ -0,0 +1,825 @@ +import gc +import random + +import operator +import os +from copy import deepcopy +import sys + +import torch +import torch.nn as nn +import datasets +from tqdm import tqdm +import wandb + +sys.path.append("./") +sys.path.append("../") +sys.path.append("../../") +sys.path.append("../../../") + +from src.custom_datasets.load_FL_datasets import get_FL_trainloader +from src.algorithms.ClientTrainer import ClientTrainer +from src.algorithms.MMClientTrainer import MMClientTrainer +from src.utils.color_lib import RGBmean, RGBstdv + +from src.algorithms.eval_coco import COCOEvaluator +from src.algorithms.retrieval_trainer import TrainerEngine +from src.algorithms.vqa_meta import VQAMetaData +from src.algorithms.vqa_trainer import VQAEngine, vqa_validation +from src.utils.config import parse_config +from src.utils.load_datasets import prepare_coco_dataloaders, vqa2_dataloader +from src.utils.logger import PythonLogger +from src.utils.util import print_model_tree +# from src.networks.zmq_client import Node + +import zmq +import time +import logging +import threading +import pickle + + +logging.basicConfig( + level=logging.DEBUG, # 确保捕获所有日志信息 + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) + +import zmq +import threading +import pickle + +class Logger: + def log(self, message): + print(message) + +class Node: + def __init__(self, node_id, router_port, peers, logger=Logger()): + self.context = zmq.Context() + + # ROUTER socket to receive messages from other nodes + self.router = self.context.socket(zmq.ROUTER) + self.router.bind(f"tcp://{node_id}:{router_port}") + + # DEALER socket to send messages to other nodes + self.dealer = self.context.socket(zmq.DEALER) + + # Connect to all peers' ROUTER sockets + for peer_id, peer_port in peers: + self.dealer.connect(f"tcp://{peer_id}:{peer_port}") + + self.node_id = node_id + self.peers = peers + self.logger = logger + self.running = True + + # Start a background thread to handle incoming messages + self.recv_thread = threading.Thread(target=self.recv_msg_loop, daemon=True) + self.recv_thread.start() + self.recv_dict = {} + + def send_msg(self, message): + """Send a message to a specific peer.""" + print("start send msg") + data = pickle.dumps({self.node_id: message}) + for _ in range(len(self.peers)): + self.dealer.send_multipart([data]) + self.logger.log(f"Node {self.node_id} sent message to peers") + + def recv_msg_loop(self): + """Loop to receive messages from other nodes.""" + while self.running: + try: + _, message = self.router.recv_multipart() + data = pickle.loads(message) + # self.logger.log() + self.logger.log(f"Node {self.node_id} received message: {data}") + # self.recv_dict[address.decode('utf-8')] = data + for key, val in data.items(): + self.recv_dict[key] = val + + except zmq.ZMQError as e: + self.logger.log(f"ZMQ Error: {e}") + break + + def get_from(self, from_node="", retries=1000, time_interal=1): + # tag = from_node + "_" + key + for _ in range(retries): + val = self.recv_dict.get(from_node, None) + + if val is not None: + del self.recv_dict[from_node] + self.logger.log(f"Get val from node {from_node}") + # logging.debug(f"Get val with tag {val}") + return 
val + else: + time.sleep(time_interal) + + logging.info("Max retry has exceed and result is none.") + return None + + def stop(self): + """Stop the node's operations.""" + self.running = False + self.recv_thread.join() + self.router.close() + self.dealer.close() + self.context.term() + + +try: + from apex import amp +except ImportError: + print('failed to import apex') + +# TODO: test +is_test = False + +class MMFL_Client(object): + def __init__(self, args, wandb:wandb, node_id, router_port, peers): + self.args = args + self.wandb = wandb + + self.device = None + self.img_local_trainers = None + self.txt_local_trainers = None + self.mm_local_trainers = None + self.engine = None + self.vqa_engine = None + self.best_score = 0 + self.cur_epoch = 0 + + # img & txt local dataloaders + self.img_train_loaders, self.txt_train_loaders = None, None + + # coco global dataloaders + self.dataloaders_global = None + self.vqa_dataloader = None + self.vqa_meta = None + # universal test dataloader + self.test_loader = None + self.vqa_test_loader = None + + self.config = None + self.set_config() + + self.logger = PythonLogger(output_file=self.config.train.output_file) + + self.img_vec, self.txt_vec = None, None + self.global_img_feature = None + self.global_txt_feature = None + self.distill_index = None + self.client_node = Node(node_id=node_id, router_port=router_port, peers=peers) + self.peers = peers + + def set_config(self, img='cifa100', txt='AG_NEWS'): + self.config = parse_config("./src/coco.yaml", strict_cast=False) + self.config.train.model_save_path = 'model_last_no_prob' + self.config.train.best_model_save_path = 'model_best_no_prob' + self.config.train.output_file = 'model_noprob' + self.config.model.img_client = img + self.config.model.txt_client = txt + self.config.train.model_save_path = self.config.train.model_save_path + '.pth' + self.config.train.best_model_save_path = self.config.train.best_model_save_path + '.pth' + self.config.train.output_file = self.config.train.output_file + '.log' + + self.config.model.embed_dim = self.args.feature_dim # set global model dim + + if self.args.not_bert: + self.config.model.not_bert = True + self.config.model.cnn_type = 'resnet50' + else: + self.config.model.not_bert = False + self.config.model.cnn_type = 'resnet101' + + def load_dataset(self, args, is_vqa=False): + dataset_root = os.environ['HOME'] + '/data/mmdata/MSCOCO/2014' + vocab_path = './src/custom_datasets/vocabs/coco_vocab.pkl' + self.dataloaders_global, self.vocab = prepare_coco_dataloaders(self.config.dataloader, dataset_root, args.pub_data_num, args.max_size, vocab_path) + + self.engine = TrainerEngine() + self.engine.set_logger(self.logger) + + if is_vqa: + self.vqa_engine = VQAEngine(args,self.engine, self.wandb) + self.config.vqa_dropout = self.args.vqa_dropout + + self.config.optimizer.learning_rate = self.args.server_lr + + self._dataloaders = self.dataloaders_global.copy() + self.evaluator = COCOEvaluator(eval_method='matmul', + verbose=True, + eval_device='cuda', + n_crossfolds=5) + if is_vqa: + vqa_dataset = datasets.load_dataset("HuggingFaceM4/VQAv2", split="train") + meta = VQAMetaData() + meta.build_or_load_categories_top() + self.vqa_meta = meta + self.vqa_dataloader = vqa2_dataloader(vqa_dataset, train=True, filter_unknown=args.vqa_filter_unknown, meta=meta) + test_dataset = datasets.load_dataset("HuggingFaceM4/VQAv2", split="validation") + self.vqa_test_loader = vqa2_dataloader(test_dataset, filter_unknown=args.vqa_filter_unknown, meta=meta) + 
self.vqa_engine.create(self.config, self.vocab.word2idx, self.evaluator, self.args.mlp_local, meta) + #print_model_tree(self.vqa_engine.fusion_model) + if args.pretrained_model.endswith('_vqa.pt'): + print(f"Loading pretrained model as VQAEngine {args.pretrained_model}") + checkpoint = torch.load(args.pretrained_model) + self.vqa_engine.fusion_model.load_state_dict(checkpoint['vqa']) + self.best_score = getattr(checkpoint, 'score', self.best_score) + else: + self.engine.create(self.config, self.vocab.word2idx, self.evaluator, self.args.mlp_local) + if args.pretrained_model.endswith('_net.pt'): + print(f"Loading pretrained model as TrainerEngine {args.pretrained_model}") + checkpoint = torch.load(args.pretrained_model) + self.engine.model.load_state_dict(checkpoint['net']) + if not is_vqa: + self.best_score = getattr(checkpoint, 'score', self.best_score) + + #print_model_tree(self.engine.model) + + self.train_eval_dataloader = self._dataloaders.pop( + 'train_subset_eval' + f'_{self.args.pub_data_num}') if self._dataloaders is not None else None + + self.engine.model_to_device() + torch.backends.cudnn.enabled = True + if self.config.train.get('use_fp16'): + self.engine.logger.log('Train with half precision') + if is_vqa: + self.vqa_engine.to_half() + else: + self.engine.to_half() + + + def create_model(self, args): + self.logger.log('start creating model and partition datasets') + self.device = torch.device("cuda:%d" % args.device) + self.global_round = args.comm_rounds + + os.makedirs(os.environ['HOME'] + f'/data/yClient', exist_ok=True) + + alpha = args.alpha # was hard-coded to 0.1 + batch_size = args.batch_size # was hard-coded to 512 + max_size = args.max_size # introduced by xiegeo + + # Create Client Models + self.img_local_trainers, self.txt_local_trainers, self.mm_local_trainers = [], [], [] + # img clients + if args.num_img_clients > 0: + dataset = 'cifar100' + self.img_trainloaders, test_set = get_FL_trainloader(dataset, os.environ['HOME'] + "/data/cifar100", + args.num_img_clients, "hetero", alpha, batch_size, max_size) + dataset = 'Cifar100' + dst = os.environ['HOME'] + f'/data/yClient/{dataset}' + self.img_local_trainers = [] + for i in range(args.num_img_clients): + self.img_local_trainers.append( + ClientTrainer(args, dataset, dst, RGBmean['Cifar100'], RGBstdv['Cifar100'], None, self.logger, + global_test_set=test_set, inter_distance=4, client_id=i, wandb=self.wandb)) + self.img_local_trainers[i].train_loader = self.img_trainloaders[i] + if is_test and i == 0: + break + self.cur_trainers = self.img_local_trainers + self.cur_type = "img" + + # txt clients + if args.num_txt_clients > 0: + dataset = 'AG_NEWS' + self.txt_trainloaders, test_set = get_FL_trainloader(dataset, os.environ['HOME'] + "/data", + args.num_txt_clients, "hetero", alpha, batch_size, max_size) + client_id = 1 + dst = os.environ['HOME'] + f'/data/yClient/{dataset}-{client_id}' + self.txt_local_trainers = [] + for i in range(args.num_txt_clients): + self.txt_local_trainers.append( + ClientTrainer(args, dataset, dst, RGBmean['Cifar100'], RGBstdv['Cifar100'], None, self.logger, + global_test_set=test_set, inter_distance=4, client_id=i, wandb=self.wandb)) + self.txt_local_trainers[i].train_loader = self.txt_trainloaders[i] + if is_test and i == 0: + break + + self.cur_trainers = self.txt_local_trainers + self.cur_type = "txt" + # mm clients + if args.num_mm_clients > 0: + # mm img models + config = parse_config("./src/f30k.yaml", strict_cast=False) + config.model.cache_dir = config.model.cache_dir + '-' + 
config.train.server_dataset + config.train.output_file = os.path.join(config.model.cache_dir, config.train.output_file) + config.train.best_model_save_path = os.path.join(config.model.cache_dir, config.train.best_model_save_path) + config.train.model_save_path = os.path.join(config.model.cache_dir, config.train.model_save_path) + config.model.embed_dim = self.args.feature_dim + config.model.not_bert = True + self.mm_local_trainers = [] + for client_id in range(args.num_mm_clients): + self.mm_local_trainers.append( + MMClientTrainer(args, config, self.logger, client=client_id, dset_name="flicker30k", + device='cuda', + vocab_path='./src/custom_datasets/vocabs/coco_vocab.pkl', + mlp_local=self.args.mlp_local)) + if is_test and client_id == 0: + break + print(f"MM Clients Samples Num: {[len(i.train_loader.dataset) for i in self.mm_local_trainers]}") + self.cur_trainers = self.mm_local_trainers + self.cur_type = "mm" + + self.total_local_trainers = self.img_local_trainers + self.txt_local_trainers + self.mm_local_trainers + + for i in range(len(self.total_local_trainers)): + self.total_local_trainers[i].client_idx = i + 1 + + def train(self, round_n): + self.cur_epoch = round_n + # distill_index = data['msg']['distill_index'] + + global_items = self.client_node.get_from(from_node=self.peers[0][0], time_interal=20) + + global_img_feature = global_items["global_img_feature"] + global_txt_feature = global_items["global_txt_feature"] + distill_index = global_items["distill_index"] + + # local training and generated representations + img_vec, img_num = [], [] + txt_vec, txt_num = [], [] + for idx, trainer in enumerate(self.cur_trainers): + self.logger.log(f"Training Client {trainer.client_idx} in round {round_n}!") + trainer.cur_epoch = round_n + trainer.run(global_img_feature, global_txt_feature, distill_index, + self._dataloaders['train_subset' + f'_{self.args.pub_data_num}']) + self.logger.log("Generate Local Representations") + _vec, i = trainer.generate_logits( + self.dataloaders_global[ + 'train_subset_eval' + f'_{self.args.pub_data_num}']) # {'img': img_vec, 'txt': txt_vec} + # if not is_test: + if distill_index is None: + distill_index = i + elif distill_index is not None: + assert i == distill_index + if _vec['img'] is not None: + img_vec.append(_vec['img']) + img_num.append(len(trainer.train_loader.dataset)) + print(f'img_vec {_vec["img"].shape}') + if _vec['txt'] is not None: + txt_vec.append(_vec['txt']) + txt_num.append(len(trainer.train_loader.dataset)) + print(f'txt_vec {_vec["txt"].shape}') + + # send local representations to the global node + self.logger.log("start sending local msg") + self.client_node.send_msg({"img_vec":img_vec, "txt_vec":txt_vec, "img_num":img_num, "txt_num":txt_num}) + + del img_vec, txt_vec + gc.collect() + +class MMFL_Global(object): + def __init__(self, args, wandb:wandb, node_id, router_port, peers): + self.args = args + self.wandb = wandb + + self.device = None + self.img_local_trainers = None + self.txt_local_trainers = None + self.mm_local_trainers = None + self.engine = None + self.vqa_engine = None + self.best_score = 0 + self.cur_epoch = 0 + + # img & txt local dataloaders + self.img_train_loaders, self.txt_train_loaders = None, None + + # coco global dataloaders + self.dataloaders_global = None + self.vqa_dataloader = None + self.vqa_meta = None + # universal test dataloader + self.test_loader = None + self.vqa_test_loader = None + + self.config = None + self.set_config() + + self.logger = PythonLogger(output_file=self.config.train.output_file) + + self.img_vec,
self.txt_vec = None, None + self.global_img_feature = None + self.global_txt_feature = None + self.distill_index = None + self.global_node = Node(node_id=node_id, router_port=router_port, peers=peers) + self.peers = peers + + def set_config(self, img='cifa100', txt='AG_NEWS'): + self.config = parse_config("./src/coco.yaml", strict_cast=False) + self.config.train.model_save_path = 'model_last_no_prob' + self.config.train.best_model_save_path = 'model_best_no_prob' + self.config.train.output_file = 'model_noprob' + self.config.model.img_client = img + self.config.model.txt_client = txt + self.config.train.model_save_path = self.config.train.model_save_path + '.pth' + self.config.train.best_model_save_path = self.config.train.best_model_save_path + '.pth' + self.config.train.output_file = self.config.train.output_file + '.log' + + self.config.model.embed_dim = self.args.feature_dim # set global model dim + + if self.args.not_bert: + self.config.model.not_bert = True + self.config.model.cnn_type = 'resnet50' + else: + self.config.model.not_bert = False + self.config.model.cnn_type = 'resnet101' + + def load_dataset(self, args, is_vqa=False): + dataset_root = os.environ['HOME'] + '/data/mmdata/MSCOCO/2014' + vocab_path = './src/custom_datasets/vocabs/coco_vocab.pkl' + self.dataloaders_global, self.vocab = prepare_coco_dataloaders(self.config.dataloader, dataset_root, args.pub_data_num, args.max_size, vocab_path) + + self.engine = TrainerEngine() + self.engine.set_logger(self.logger) + + if is_vqa: + self.vqa_engine = VQAEngine(args,self.engine, self.wandb) + self.config.vqa_dropout = self.args.vqa_dropout + + self.config.optimizer.learning_rate = self.args.server_lr + + self._dataloaders = self.dataloaders_global.copy() + self.evaluator = COCOEvaluator(eval_method='matmul', + verbose=True, + eval_device='cuda', + n_crossfolds=5) + if is_vqa: + train_f = "/root/xus/CreamFL-main/vq_av2-train.arrow" + vqa_dataset = datasets.Dataset.from_file(train_f) + + test_f = "/root/xus/CreamFL-main/vq_av2-validation.arrow" + test_dataset = datasets.Dataset.from_file(test_f) + meta = VQAMetaData() + meta.build_or_load_categories_top() + self.vqa_meta = meta + self.vqa_dataloader = vqa2_dataloader(vqa_dataset, train=True, filter_unknown=args.vqa_filter_unknown, meta=meta) + + self.vqa_test_loader = vqa2_dataloader(test_dataset, filter_unknown=args.vqa_filter_unknown, meta=meta) + self.vqa_engine.create(self.config, self.vocab.word2idx, self.evaluator, self.args.mlp_local, meta) + #print_model_tree(self.vqa_engine.fusion_model) + if args.pretrained_model.endswith('_vqa.pt'): + print(f"Loading pretrained model as VQAEngine {args.pretrained_model}") + checkpoint = torch.load(args.pretrained_model) + self.vqa_engine.fusion_model.load_state_dict(checkpoint['vqa']) + self.best_score = getattr(checkpoint, 'score', self.best_score) + else: + self.engine.create(self.config, self.vocab.word2idx, self.evaluator, self.args.mlp_local) + if args.pretrained_model.endswith('_net.pt'): + print(f"Loading pretrained model as TrainerEngine {args.pretrained_model}") + checkpoint = torch.load(args.pretrained_model) + self.engine.model.load_state_dict(checkpoint['net']) + if not is_vqa: + self.best_score = getattr(checkpoint, 'score', self.best_score) + + #print_model_tree(self.engine.model) + + self.train_eval_dataloader = self._dataloaders.pop( + 'train_subset_eval' + f'_{self.args.pub_data_num}') if self._dataloaders is not None else None + + self.engine.model_to_device() + torch.backends.cudnn.enabled = True + if 
self.config.train.get('use_fp16'): + self.engine.logger.log('Train with half precision') + if is_vqa: + self.vqa_engine.to_half() + else: + self.engine.to_half() + + + def create_model(self, args): + self.logger.log('start creating model and partition datasets') + self.device = torch.device("cuda:%d" % args.device) + self.global_round = args.comm_rounds + + os.makedirs(os.environ['HOME'] + f'/data/yClient', exist_ok=True) + + alpha = args.alpha # was hard-coded to 0.1 + batch_size = args.batch_size # was hard-coded to 512 + max_size = args.max_size # introduced by xiegeo + + # Create Client Models + self.img_local_trainers, self.txt_local_trainers, self.mm_local_trainers = [], [], [] + # img clients + if args.num_img_clients > 0: + dataset = 'cifar100' + self.img_trainloaders, test_set = get_FL_trainloader(dataset, os.environ['HOME'] + "/data/cifar100", + args.num_img_clients, "hetero", alpha, batch_size, max_size) + dataset = 'Cifar100' + dst = os.environ['HOME'] + f'/data/yClient/{dataset}' + self.img_local_trainers = [] + for i in range(args.num_img_clients): + self.img_local_trainers.append( + ClientTrainer(args, dataset, dst, RGBmean['Cifar100'], RGBstdv['Cifar100'], None, self.logger, + global_test_set=test_set, inter_distance=4, client_id=i, wandb=self.wandb)) + self.img_local_trainers[i].train_loader = self.img_trainloaders[i] + if is_test and i == 0: + break + # txt clients + if args.num_txt_clients > 0: + dataset = 'AG_NEWS' + self.txt_trainloaders, test_set = get_FL_trainloader(dataset, os.environ['HOME'] + "/data", + args.num_txt_clients, "hetero", alpha, batch_size, max_size) + client_id = 1 + dst = os.environ['HOME'] + f'/data/yClient/{dataset}-{client_id}' + self.txt_local_trainers = [] + for i in range(args.num_txt_clients): + self.txt_local_trainers.append( + ClientTrainer(args, dataset, dst, RGBmean['Cifar100'], RGBstdv['Cifar100'], None, self.logger, + global_test_set=test_set, inter_distance=4, client_id=i, wandb=self.wandb)) + self.txt_local_trainers[i].train_loader = self.txt_trainloaders[i] + if is_test and i == 0: + break + # mm clients + if args.num_mm_clients > 0: + # mm img models + config = parse_config("./src/f30k.yaml", strict_cast=False) + config.model.cache_dir = config.model.cache_dir + '-' + config.train.server_dataset + config.train.output_file = os.path.join(config.model.cache_dir, config.train.output_file) + config.train.best_model_save_path = os.path.join(config.model.cache_dir, config.train.best_model_save_path) + config.train.model_save_path = os.path.join(config.model.cache_dir, config.train.model_save_path) + config.model.embed_dim = self.args.feature_dim + config.model.not_bert = True + self.mm_local_trainers = [] + for client_id in range(args.num_mm_clients): + self.mm_local_trainers.append( + MMClientTrainer(args, config, self.logger, client=client_id, dset_name="flicker30k", + device='cuda', + vocab_path='./src/custom_datasets/vocabs/coco_vocab.pkl', + mlp_local=self.args.mlp_local)) + if is_test and client_id == 0: + break + print(f"MM Clients Samples Num: {[len(i.train_loader.dataset) for i in self.mm_local_trainers]}") + + self.total_local_trainers = self.img_local_trainers + self.txt_local_trainers + self.mm_local_trainers + + for i in range(len(self.total_local_trainers)): + self.total_local_trainers[i].client_idx = i + 1 + + def train(self, round_n): + # for round_n in range(self.global_round): + self.cur_epoch = round_n + + self.cur_trainers = self.total_local_trainers + + self.logger.log(f"Round {round_n}!") + + if not is_test and not 
self.args.no_retrieval_training: + # global training + print("start retrieval training") + self.engine.train(tr_loader=self._dataloaders['train_subset' + f'_{self.args.pub_data_num}']) + + # global representations + # if len(self.cur_trainers) == 0: + # print("No clients to train, skipping global representations") + if self.args.agg_method == "con_w" or self.args.contrast_local_intra or self.args.contrast_local_inter: + img_feature, txt_feature = [], [] + distill_index = [] + for idx, (images, captions, captions_word, caption_lens, _, _, index) in tqdm( + enumerate(self.dataloaders_global['train_subset_eval' + f'_{self.args.pub_data_num}']), + desc="Global Representations", + total=len(self.dataloaders_global['train_subset_eval' + f'_{self.args.pub_data_num}'])): + with torch.no_grad(): + images = images.to(self.engine.device) # [bs, 3, 224, 224] + captions = captions.to(self.engine.device) # [bs, seq_len] + caption_lens = caption_lens.to(self.engine.device) + + output = self.engine.model(images, captions, captions_word, caption_lens) + out_img = output['image_features'] + out_txt = output['caption_features'] + + out_img = out_img.cpu().detach() + out_txt = out_txt.cpu().detach() + + img_feature.append(out_img) + txt_feature.append(out_txt) + distill_index.extend(index) + + self.global_img_feature = torch.concat(img_feature, dim=0) + self.global_txt_feature = torch.concat(txt_feature, dim=0) + print(self.global_txt_feature.shape, self.global_img_feature.shape) + self.distill_index = distill_index + del img_feature, txt_feature + gc.collect() + else: + print("No agg_method or contrast, skipping global representations") + + self.global_node.send_msg({"global_img_feature":self.global_img_feature, "global_txt_feature":self.global_txt_feature, "distill_index": self.distill_index}) + + img_num = None + img_vec = None + txt_num = None + txt_vec = None + + for cur_peer in self.peers: + cur_node = cur_peer[0] + cur_items = self.global_node.get_from(from_node=cur_node, time_interal=20) + cur_img_num = cur_items["img_num"] + cur_img_vec = cur_items["img_vec"] + cur_txt_num = cur_items["txt_num"] + cur_txt_vec = cur_items["txt_vec"] + + if img_num is None: + img_num = cur_img_num + else: + img_num += cur_img_num + + if img_vec is None: + img_vec = cur_img_vec + else: + img_vec += cur_img_vec + + if txt_num is None: + txt_num = cur_txt_num + else: + txt_num += cur_txt_num + + if txt_vec is None: + txt_vec = cur_txt_vec + else: + txt_vec += cur_txt_vec + + # global distillation + if not self.args.disable_distill: + print("**********start distill*************") + self.distill(round_n, img_vec, txt_vec, img_num, txt_num, self.distill_index) + + def get_lr(optimizer): + for param_group in optimizer.param_groups: + return param_group['lr'] + + # assert round_n + 1 == self.cur_epoch, "inconstant round_n vs cur_epoch, added to check that code clean up does not change logic." 
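+        # NOTE: Node.get_from returns None once its retry budget is exhausted, and the
+        # peer-merge loop earlier in this round indexes the reply directly (cur_items["img_num"]),
+        # so a silent or crashed client would surface there as a TypeError. A minimal guard
+        # (sketch; assumes a peer may simply be skipped for the round) could be:
+        #     if cur_items is None:
+        #         self.logger.log(f"no reply from peer {cur_node}, skipping it this round")
+        #         continue
+        # NOTE: save_model below stores the score under the dict key 'score', while load_dataset
+        # reads it back with getattr(checkpoint, 'score', ...), which always returns the default
+        # for a plain dict; checkpoint.get('score', self.best_score) is presumably the intended read.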
+ + # record after each epoch training + metadata = self.engine.metadata.copy() + metadata['cur_epoch'] = round_n + 1 + metadata['lr'] = get_lr(self.engine.optimizer) + + score = 0 + + test_scores = self.engine.evaluate({'test': self._dataloaders['test']}) + self.engine.report_scores(step=round_n + 1, + scores=test_scores, + metadata=metadata, + prefix=self.engine.eval_prefix) + rsum = test_scores['test']['n_fold']['i2t']['recall_1'] + test_scores['test']['n_fold']['t2i']['recall_1'] + \ + test_scores['test']['i2t']['recall_1'] + test_scores['test']['t2i']['recall_1'] + self.wandb.log({"Server rsum_r1": rsum}, step=self.cur_epoch) + self.wandb.log({"Server rsum": test_scores['test']['rsum']}, step=self.cur_epoch) + self.wandb.log({"Server n_fold_i2t_r1": test_scores['test']['n_fold']['i2t']['recall_1']}, step=self.cur_epoch) + self.wandb.log({"Server n_fold_t2i_r1": test_scores['test']['n_fold']['t2i']['recall_1']}, step=self.cur_epoch) + self.wandb.log({"Server i2t_r1": test_scores['test']['i2t']['recall_1']}, step=self.cur_epoch) + self.wandb.log({"Server t2i_r1": test_scores['test']['t2i']['recall_1']}, step=self.cur_epoch) + score = rsum + + if self.vqa_engine is not None: + test_loader = None + if round_n == 0: + test_loader = self.vqa_test_loader # only test during training in the first round + self.vqa_engine.train_vqa(self.cur_epoch, self.vqa_dataloader, vqa2_test_dataloader=test_loader) + test_scores = vqa_validation(10000, self.vqa_engine.fusion_model, self.vqa_meta, self.vqa_test_loader) + #test_scores = vqa_validation(100000, self.vqa_engine.fusion_model, self.vqa_meta, self.vqa_test_loader) + self.wandb.log(test_scores, step=self.cur_epoch) + score = test_scores['accuracy'] + + # print(f"**********current vqa accuracy {score}*************") + print(f"**********current score {score}*************") + + + def save_model(type_name, score=score): + prefix = f'{self.args.name}_{type_name}_{self.args.feature_dim}' + if self.vqa_engine is not None: + torch.save({'vqa': self.vqa_engine.fusion_model.state_dict(), + 'score':score}, f'{prefix}_vqa.pt') + else: + torch.save({'net': self.engine.model.state_dict(), + 'score':score}, f'{prefix}_net.pt') + + + if self.best_score < score: + best_score = score + metadata['best_score'] = best_score + metadata['best_epoch'] = round_n + 1 + self.best_metadata, self.best_scores = metadata, test_scores + save_model("best") + + if round_n == self.args.comm_rounds - 1: + save_model("last") + + self.engine.lr_scheduler.step() + if self.vqa_engine is not None: + self.vqa_engine.vqa_lr_scheduler.step() + + del img_vec, txt_vec + gc.collect() + + def distill(self, round_n, img_vec, txt_vec, img_num, txt_num, distill_index): + + if len(img_vec) == 0 and len(txt_vec) == 0: + print("No img_vec and txt_vec to distill (no clients)") + return + + self.engine.model.train() + + if self.config.model.use_img_client or self.config.model.use_txt_client or self.config.model.use_mm_client: + client_loss_cri = nn.MSELoss() + + def aggregation(i_vec=img_vec, t_vec=txt_vec): + if self.args.agg_method == "con_w": + if not i_vec: + self.logger.log("distill.aggregation i_vec is empty") + else: + contrastive_w = [] + for vec in i_vec: # vec: [50000, n_feature], global_txt_feature: [50000, n_feature] + logits = torch.matmul(vec, self.global_txt_feature.T) # [50000, 50000] + exp_logits = torch.exp(logits) + log_prob = logits - torch.log(torch.sum(exp_logits, dim=1, keepdim=True)) + contrastive_w.append(torch.diagonal(log_prob).reshape(1, -1)) + if not contrastive_w: + 
self.logger.log("distill.aggregation No tensors were added to contrastive_w for images") + else: + contrastive_w = torch.softmax(torch.cat(contrastive_w, dim=0), dim=0) + for i in range(len(i_vec)): + i_vec[i] = (i_vec[i] * contrastive_w[i].reshape(-1, 1)).unsqueeze(0) + i_vec = torch.sum(torch.cat(i_vec, dim=0), dim=0) # aggregated image vectors + + if not t_vec: + self.logger.log("distill.aggregation t_vec is empty") + else: + contrastive_w = [] + for vec in t_vec: # vec: [50000, n_feature], global_txt_feature: [50000, n_feature] + logits = torch.matmul(vec, self.global_img_feature.T) # [50000, 50000] + exp_logits = torch.exp(logits) + log_prob = logits - torch.log(torch.sum(exp_logits, dim=1, keepdim=True)) + contrastive_w.append(torch.diagonal(log_prob).reshape(1, -1)) + if not contrastive_w: + self.logger.log("distill.aggregation No tensors were added to contrastive_w for texts") + else: + contrastive_w = torch.softmax(torch.cat(contrastive_w, dim=0), dim=0) + for i in range(len(t_vec)): + t_vec[i] = (t_vec[i] * contrastive_w[i].reshape(-1, 1)).unsqueeze(0) + t_vec = torch.sum(torch.cat(t_vec, dim=0), dim=0) # aggregated text vectors + else: + raise NotImplementedError + + return i_vec, t_vec + + # aggregation + img_vec, txt_vec = aggregation() + + self.img_vec = img_vec + self.txt_vec = txt_vec + + distill_dict = {b: a for a, b in enumerate(distill_index)} # index in coco to index to list 'distill_index' + # distill + self.logger.log("start distilling") + for idx, (images, captions, captions_word, caption_lens, _, _, index) in tqdm( + enumerate(self.dataloaders_global['train_subset' + f'_{self.args.pub_data_num}'])): + images = images.to(self.engine.device) # [bs, 3, 224, 224] + captions = captions.to(self.engine.device) # [bs, seq_len] + caption_lens = caption_lens.to(self.engine.device) + + output = self.engine.model(images, captions, captions_word, caption_lens) + loss = 0 + + def code_sim(output, target, config): + output = output.sum(axis=1) if len(output.shape) == 3 else output + target = target.type_as(output) + + return client_loss_cri(output, target.type_as(output)) + + if len(img_num)> 0: + # if self.args.num_img_clients > 0 and len(img_num)> 0: + out_img = output['image_features'] + d_idx = operator.itemgetter(*index)(distill_dict) # idx of the current batch + target_img = self.img_vec[d_idx, :].type_as(out_img) + loss += self.args.kd_weight * code_sim(out_img, target_img, self.config) + # if self.args.num_txt_clients > 0 and len(txt_num) > 0: + if len(txt_num) > 0: + out_txt = output['caption_features'] + d_idx = operator.itemgetter(*index)(distill_dict) # idx of the current batch + target_txt = self.txt_vec[d_idx, :].type_as(out_txt) + loss += self.args.kd_weight * code_sim(out_txt, target_txt, self.config) + if len(img_num) > 0 and len(txt_num) > 0: + # if self.args.num_mm_clients > 0 and len(img_num) > 0 and len(txt_num) > 0: + out_img = output['image_features'] + d_idx = operator.itemgetter(*index)(distill_dict) # idx of the current batch + target_img = self.img_vec[d_idx, :].type_as(out_img) + out_txt = output['caption_features'] + target_txt = self.txt_vec[d_idx, :].type_as(out_txt) + loss += self.args.kd_weight * code_sim(out_img, target_img, self.config) + loss += self.args.kd_weight * code_sim(out_txt, target_txt, self.config) + + self.engine.optimizer.zero_grad() + + if self.config.train.get('use_fp16'): + with amp.scale_loss(loss, self.engine.optimizer) as scaled_loss: + scaled_loss.backward() + else: + loss.backward() + + if self.config.train.grad_clip > 0: + 
nn.utils.clip_grad.clip_grad_norm_(self.engine.model.parameters(), + self.config.train.grad_clip) + self.engine.optimizer.step() diff --git a/src/algorithms/__init__.py b/src/algorithms/__init__.py index ca18301..7df9fb4 100644 --- a/src/algorithms/__init__.py +++ b/src/algorithms/__init__.py @@ -1 +1 @@ -from . import MMFL +# from . import MMFL diff --git a/src/algorithms/__pycache__/ClientTrainer.cpython-38.pyc b/src/algorithms/__pycache__/ClientTrainer.cpython-38.pyc deleted file mode 100644 index 2afb57d..0000000 Binary files a/src/algorithms/__pycache__/ClientTrainer.cpython-38.pyc and /dev/null differ diff --git a/src/algorithms/__pycache__/MMClientTrainer.cpython-38.pyc b/src/algorithms/__pycache__/MMClientTrainer.cpython-38.pyc deleted file mode 100644 index c02bac5..0000000 Binary files a/src/algorithms/__pycache__/MMClientTrainer.cpython-38.pyc and /dev/null differ diff --git a/src/algorithms/__pycache__/MMFL.cpython-38.pyc b/src/algorithms/__pycache__/MMFL.cpython-38.pyc deleted file mode 100644 index c589f5a..0000000 Binary files a/src/algorithms/__pycache__/MMFL.cpython-38.pyc and /dev/null differ diff --git a/src/algorithms/__pycache__/__init__.cpython-38.pyc b/src/algorithms/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index 03d9440..0000000 Binary files a/src/algorithms/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/src/algorithms/__pycache__/base.cpython-38.pyc b/src/algorithms/__pycache__/base.cpython-38.pyc deleted file mode 100644 index c54ff52..0000000 Binary files a/src/algorithms/__pycache__/base.cpython-38.pyc and /dev/null differ diff --git a/src/algorithms/__pycache__/eval_coco.cpython-38.pyc b/src/algorithms/__pycache__/eval_coco.cpython-38.pyc deleted file mode 100644 index 071246b..0000000 Binary files a/src/algorithms/__pycache__/eval_coco.cpython-38.pyc and /dev/null differ diff --git a/src/algorithms/__pycache__/optimizers.cpython-38.pyc b/src/algorithms/__pycache__/optimizers.cpython-38.pyc deleted file mode 100644 index 46ffd3d..0000000 Binary files a/src/algorithms/__pycache__/optimizers.cpython-38.pyc and /dev/null differ diff --git a/src/algorithms/__pycache__/retrieval_trainer.cpython-38.pyc b/src/algorithms/__pycache__/retrieval_trainer.cpython-38.pyc deleted file mode 100644 index 59ceee5..0000000 Binary files a/src/algorithms/__pycache__/retrieval_trainer.cpython-38.pyc and /dev/null differ diff --git a/src/algorithms/base.py b/src/algorithms/base.py index edec5c1..e165170 100755 --- a/src/algorithms/base.py +++ b/src/algorithms/base.py @@ -15,16 +15,16 @@ from src.algorithms.optimizers import get_optimizer, get_lr_scheduler from src.algorithms.eval_coco import COCOEvaluator - from src.datasets._dataloader import prepare_f30k_dataloaders, load_vocab + from src.custom_datasets._dataloader import prepare_f30k_dataloaders, load_vocab except ImportError: from ..criterions import get_criterion from ..networks.models import get_model from ..algorithms.optimizers import get_optimizer, get_lr_scheduler from eval_coco import COCOEvaluator - from ..datasets._dataloader import prepare_f30k_dataloaders, load_vocab + from ..custom_datasets._dataloader import prepare_f30k_dataloaders, load_vocab -from utils.serialize_utils import torch_safe_load +from ..utils.serialize_utils import torch_safe_load try: from apex import amp @@ -59,9 +59,9 @@ def parse_config(config_path, cache_dir=None, pretrained_resnet_model_path=None, return config -class EngineBase(object): +class EngineBase2(object): def __init__(self, args, config, 
logger, client=-1, dset_name="flicker30k", device='cuda', - vocab_path='./datasets/vocabs/coco_vocab.pkl', mlp_local=False): + vocab_path='./custom_datasets/vocabs/coco_vocab.pkl', mlp_local=False): self.dset_name = dset_name self.args = args @@ -83,7 +83,7 @@ def __init__(self, args, config, logger, client=-1, dset_name="flicker30k", devi self.client = client - word2idx = self.set_dset(self.dset_name, client, vocab_path) + word2idx = self.set_dset(self.dset_name, client, args.num_mm_clients, vocab_path) self.config = config self.word2idx = word2idx @@ -114,13 +114,13 @@ def __init__(self, args, config, logger, client=-1, dset_name="flicker30k", devi self.local_epochs = args.local_epochs self.local_epoch = 0 - def set_dset(self, dset_name, client=-1, vocab_path='./datasets/vocabs/coco_vocab.pkl'): + def set_dset(self, dset_name, client=-1, num_users=-1, vocab_path='./custom_datasets/vocabs/coco_vocab.pkl'): if dset_name == "flicker30k": - dataloaders, vocab = prepare_f30k_dataloaders(self.config.dataloader, '', vocab_path, client=client) + dataloaders, vocab = prepare_f30k_dataloaders(self.config.dataloader, '', self.args.max_size, vocab_path, client=client, num_users=num_users) self.train_loader = dataloaders['train'] self.val_loader = dataloaders['te'] elif dset_name == "coco": - dataloaders, vocab = prepare_coco_dataloaders(self.config.dataloader, os.environ['HOME'] + '/data/mmdata/MSCOCO/2014', vocab_path, client=client) + dataloaders, vocab = prepare_coco_dataloaders(self.config.dataloader, os.environ['HOME'] + '/data/mmdata/MSCOCO/2014', self.args.pub_data_num, self.args.max_size, vocab_path, client=client, num_users=num_users) self.train_loader = dataloaders['train_client'] self.val_loader = dataloaders['test'] else: diff --git a/src/algorithms/eval_coco.py b/src/algorithms/eval_coco.py index 6091e86..8fbf2d0 100644 --- a/src/algorithms/eval_coco.py +++ b/src/algorithms/eval_coco.py @@ -227,7 +227,7 @@ def retrieve(self, q_features, g_features, q_ids, g_ids, q_classes=None, g_classes=None, topk=10, - batch_size=1024): + batch_size=32): if len(q_features) != len(q_ids): raise RuntimeError('length mismatch {}, {}'.format(q_features.shape, q_ids.shape)) @@ -273,7 +273,7 @@ def retrieve(self, q_features, g_features, @torch.no_grad() def evaluate_recall(self, q_features, g_features, q_labels, g_labels, q_ids=None, g_ids=None, - batch_size=1024): + batch_size=32): """Evaluate recall Args: @@ -333,6 +333,52 @@ def evaluate_recall(self, q_features, g_features, q_labels, g_labels, return scores + @torch.no_grad() + def evaluate_single(self, q_features, g_features, q_labels, g_labels): + """Evaluate recall + + Args: + q_features (tensor): N_q x d query features + g_features (tensor): N_g x d gallery features + q_labels (tensor): N query labels + g_labels (tensor): N gallery labels + """ + if len(q_features) != len(q_labels): + raise RuntimeError('length mismatch {}, {}'.format(q_features.shape, + q_labels.shape)) + if len(g_features) != len(g_labels): + raise RuntimeError('length mismatch {}, {}'.format(g_features.shape, + g_labels.shape)) + n_queries = len(q_labels) + n_galleries = len(g_labels) + best_pred_ranks = np.zeros(n_queries) + + if self.eval_method == 'matmul': + pmm = ParallelMatMulModule() + g_features = g_features.view(n_galleries * self.n_embeddings, -1).t() + elif self.eval_method == 'matching_prob': + pmm = MatchingProbModule(self.criterion.match_prob) + pmm.set_g_features(g_features) + + q_features = q_features.to(self.eval_device) + + for q_indices in 
self.pbar(batch(range(n_queries), batch_size=batch_size)): + q_indices = np.array(q_indices) + + if self.eval_method != 'matching_prob': + _q_feature = q_features[q_indices, :] + _q_feature = _q_feature.view(len(q_indices) * self.n_embeddings, -1) + else: + _q_feature = q_features[q_indices, :, :] + _, pred_ranks = pmm(_q_feature, n_embeddings=self.n_embeddings) + + for idx, q_idx in enumerate(q_indices): + pos_indices = np.where(g_labels == q_labels[q_idx])[0] + _pred_ranks = [torch.where(pred_ranks[idx] == pos_idx)[0][0].item() for pos_idx in pos_indices] + best_pred_ranks[q_idx] = min(_pred_ranks) + + return best_pred_ranks + def evaluate_n_fold(self, extracted_features, n_crossfolds, n_images_per_crossfold, n_captions_per_crossfold, eval_batch_size): image_features = extracted_features['image_features'] @@ -393,7 +439,7 @@ def evaluate_n_fold(self, extracted_features, n_crossfolds, n_images_per_crossfo def evaluate(self, dataloader, n_crossfolds=None, n_images_per_crossfold=1000, n_captions_per_crossfold=5000, - eval_batch_size=1024, + eval_batch_size=32, key=None): """evaluate image-to-caption and caption-to-image retrieval tasks. """ diff --git a/src/algorithms/retrieval_trainer.py b/src/algorithms/retrieval_trainer.py index 6396644..275bfec 100644 --- a/src/algorithms/retrieval_trainer.py +++ b/src/algorithms/retrieval_trainer.py @@ -18,11 +18,11 @@ from src.algorithms.optimizers import get_optimizer from src.algorithms.optimizers import get_lr_scheduler from src.criterions import get_criterion +from src.utils.load_datasets import load_vocab from src.networks.models import get_model from src.utils.config import parse_config from src.utils.serialize_utils import flatten_dict, torch_safe_load - try: from apex import amp except ImportError: @@ -81,8 +81,6 @@ def create(self, config, word2idx, evaluator, mlp_local): if self.logger is not None: self.logger.log('start train') - self.img_code, self.txt_code, self.mm_txt_code, self.mm_img_code = None, None, None, None - def model_to_device(self): self.model.to(self.device) if self.criterion: @@ -173,6 +171,36 @@ def load_models(self, state_dict_path, load_keys=None): model_hash, load_keys)) + + def load_models2(self, state_dict_path, evaluator, load_keys=None): + with open(state_dict_path, 'rb') as fin: + model_hash = hashlib.sha1(fin.read()).hexdigest() + self.metadata['pretrain_hash'] = model_hash + + state_dict = torch.load(state_dict_path, map_location='cpu') + + vocab_path = './src/custom_datasets/vocabs/coco_vocab.pkl' + vocab = load_vocab(vocab_path) + self.create(munch.munchify(state_dict['config']), vocab.word2idx, evaluator, False) + if 'model' not in state_dict: + torch_safe_load(self.model, state_dict) + return + + if not load_keys: + load_keys = ['model', 'criterion', 'optimizer', 'lr_scheduler'] + for key in load_keys: + try: + torch_safe_load(getattr(self, key), state_dict[key]) + except RuntimeError as e: + print(e) + if self.logger is not None: + self.logger.log('Unable to import state_dict, missing keys are found. 
{}'.format(e)) + torch_safe_load(getattr(self, key), state_dict[key]) + if self.logger is not None: + self.logger.log('state dict is loaded from {} (hash: {}), load_key ({})'.format(state_dict_path, + model_hash, + load_keys)) + def load_state_dict(self, state_dict_path, load_keys=None): state_dict = torch.load(state_dict_path) config = parse_config(state_dict['config']) @@ -212,6 +240,7 @@ def train(self, tr_loader, pub_data_ratio=1.): nn.utils.clip_grad.clip_grad_norm_(self.model.parameters(), self.config.train.grad_clip) self.optimizer.step() + def report_scores(self, step, scores, metadata, prefix=''): report_dict = {data_key: flatten_dict(_scores, sep='_') @@ -224,8 +253,13 @@ def report_scores(self, step, scores, metadata, prefix=''): if 'lr' in metadata: report_dict['{}lr'.format(prefix)] = metadata['lr'] - report_dict[ - 'summary'] = f"{report_dict['__test__n_fold_i2t_recall_1']}, {report_dict['__test__n_fold_i2t_recall_5']}, {report_dict['__test__n_fold_i2t_recall_10']}, {report_dict['__test__n_fold_t2i_recall_1']}, {report_dict['__test__n_fold_t2i_recall_5']}, {report_dict['__test__n_fold_t2i_recall_10']}, {report_dict['__test__i2t_recall_1']}, {report_dict['__test__i2t_recall_5']}, {report_dict['__test__i2t_recall_10']}, {report_dict['__test__t2i_recall_1']}, {report_dict['__test__t2i_recall_5']}, {report_dict['__test__t2i_recall_10']}" + # print all keys of report_dict + # print(report_dict.keys()) + # compatibility with different version dependencies + if 'test__n_fold_i2t_recall_1' in report_dict: + report_dict['summary'] = f"{report_dict['test__n_fold_i2t_recall_1']}, {report_dict['test__n_fold_i2t_recall_5']}, {report_dict['test__n_fold_i2t_recall_10']}, {report_dict['test__n_fold_t2i_recall_1']}, {report_dict['test__n_fold_t2i_recall_5']}, {report_dict['test__n_fold_t2i_recall_10']}, {report_dict['test__i2t_recall_1']}, {report_dict['test__i2t_recall_5']}, {report_dict['test__i2t_recall_10']}, {report_dict['test__t2i_recall_1']}, {report_dict['test__t2i_recall_5']}, {report_dict['test__t2i_recall_10']}" + else: + report_dict['summary'] = f"{report_dict['__test__n_fold_i2t_recall_1']}, {report_dict['__test__n_fold_i2t_recall_5']}, {report_dict['__test__n_fold_i2t_recall_10']}, {report_dict['__test__n_fold_t2i_recall_1']}, {report_dict['__test__n_fold_t2i_recall_5']}, {report_dict['__test__n_fold_t2i_recall_10']}, {report_dict['__test__i2t_recall_1']}, {report_dict['__test__i2t_recall_5']}, {report_dict['__test__i2t_recall_10']}, {report_dict['__test__t2i_recall_1']}, {report_dict['__test__t2i_recall_5']}, {report_dict['__test__t2i_recall_10']}" if self.logger is not None: self.logger.report(report_dict, prefix='[Eval] Report @step: ', @@ -236,7 +270,7 @@ def report_scores(self, step, scores, metadata, prefix=''): if self.logger is not None: self.logger.update_tracker(tracker_data) - +# not used class rawTrainerEngine(EngineBase): def _train_epoch(self, dataloader, cur_epoch, prefix='', pub_data_ratio=1.): diff --git a/src/algorithms/vqa_meta.py b/src/algorithms/vqa_meta.py new file mode 100644 index 0000000..f19a609 --- /dev/null +++ b/src/algorithms/vqa_meta.py @@ -0,0 +1,112 @@ +import os +import pickle + +import torch +import datasets +from tqdm import tqdm + +unknown_category = "" +unknown_category_id = 0 + +def build_or_load_categories(fn, dataset): + """ + Load categories from a file if it exists, otherwise build them from a dataset. + + This function checks if a file with the name `fn` exists. If it does, it loads and returns the + categories stored in that file. 
If the file does not exist, it builds the categories from the + provided dataset. The categories are sorted by their counts (excluding the first category, which + is reserved for unknowns) from most to lest. The sorted categories and their counts are + then saved to the file `fn` for future use. + + The intention of sorting the categories is to make picking the top N categories easier. + + Parameters: + - fn (str): The filename where the categories are stored or will be stored. + - dataset: The dataset from which to build the categories if the file does not exist. + + Returns: + - dict: A dictionary containing the sorted category list under the key 'category_list' and the + corresponding counts under the key 'category_counts'. + """ + if os.path.exists(fn): + with open(fn, "rb") as f: + return pickle.load(f) + builder = VQAMetaData() + builder.set_category_from_dataset(dataset) + sorted_pairs = sorted(zip(builder.category_list[1:], builder.category_counts[1:]), key=lambda x: x[1], reverse=True) + data = {'category_list': builder.category_list[:1]+[cat for cat, _ in sorted_pairs], + 'category_counts': builder.category_counts[:1]+[count for _, count in sorted_pairs]} + with open(fn, "wb") as f: + pickle.dump(data, f) + return data + +class VQAMetaData(): + def __init__(self): + self.category_list = [] # list of categories names + self.category_dict = {} # category name to index + self.category_counts = [] # category counts for each category from the training set + + + def build_or_load_categories_top(self, top = 3000): + if len(self.category_list) != 0: + raise Exception("categories already loaded") + fn = f"vqa2_categories_train.pkl" + data = build_or_load_categories(fn, datasets.load_dataset("HuggingFaceM4/VQAv2", split="train")) + + for i, cat in enumerate(data['category_list']): + count = data['category_counts'][i] + if i <= top: + self.category_list.append(cat) + self.category_counts.append(count) + self.category_dict[cat] = i + else: + self.category_counts[unknown_category_id] += count + + def get_category_size(self): + return len(self.category_list) + + def get_category_id(self, cat, add_new=False): + add_count = add_new # add count only when we are building the list of categories + if len(self.category_list) == 0: + self.category_dict[unknown_category] = unknown_category_id + self.category_list.append(unknown_category) + self.category_counts.append(0) + if cat in self.category_dict: + cat_id = self.category_dict[cat] + if add_count: + self.category_counts[cat_id] += 1 + return cat_id + if not add_new: + return unknown_category_id + new_id = len(self.category_list) + self.category_dict[cat] = new_id + self.category_list.append(cat) + self.category_counts.append(1) + return new_id + + def get_category_by_id(self, cat_id): + return self.category_list[cat_id] + + def set_category_from_dataset(self,dataset): + #for item in tqdm(dataset.map(lambda example: {'multiple_choice_answer': example['multiple_choice_answer']})): + # get_category_id(item['multiple_choice_answer']) + dataset = dataset.map(lambda example: {'multiple_choice_answer': example['multiple_choice_answer']}) + dataloader = torch.utils.data.DataLoader(dataset, batch_size=2048, num_workers=32, + collate_fn=lambda examples: {'multiple_choice_answer': [example['multiple_choice_answer'] for example in examples]}) + self.set_category_from_dataloader(dataloader) + + def set_category_from_dataloader(self, dataloader): + for batch in tqdm(dataloader): + for answer in batch['multiple_choice_answer']: + self.get_category_id(answer, 
add_new=True) + + def get_weights(self, args): + if args.vqa_cat_weight == '1': + return None # use default uniform weights for CrossEntropyLoss + epsilon = 1e-8 # Small value to prevent division by zero + if args.vqa_cat_weight == 'count+1000': + epsilon = 1000 + total_count = sum(self.category_counts) + total_count = total_count + epsilon * self.get_category_size() + return [total_count / (class_count + epsilon) for class_count in self.category_counts] + \ No newline at end of file diff --git a/src/algorithms/vqa_trainer.py b/src/algorithms/vqa_trainer.py new file mode 100644 index 0000000..b9973b7 --- /dev/null +++ b/src/algorithms/vqa_trainer.py @@ -0,0 +1,197 @@ +from collections import Counter +import random +import torch +from tqdm import tqdm +from algorithms.optimizers import get_optimizer, get_lr_scheduler +from algorithms.retrieval_trainer import TrainerEngine +from algorithms.vqa_meta import VQAMetaData, unknown_category_id, unknown_category +from networks.fusion_model import VQAFusionModel + +try: + from apex import amp + #print("enable f16 and using apex.amp for mixed precision training") + #use_f16 = True +except ImportError as e: + print('failed to import apex:', e) + +class VQAEngine(): + def __init__(self, args, base_trainer_engine:TrainerEngine, wandb, device='cuda'): + self.args = args + self.device = device + self.trainer_engine = base_trainer_engine + self.wandb = wandb + self.fusion_model = None + self.vqa_optimizer = None + self.vqa_criterion = None + self.vqa_lr_scheduler = None + self.vqa_meta = None + + def weights_tensor(self, meta:VQAMetaData): + weights = meta.get_weights(self.args) + if weights is None: + return None + return torch.tensor(weights).to(self.device) + + def create(self, config, word2idx, evaluator, mlp_local, meta:VQAMetaData): + self.config = config + self.vqa_meta = meta + self.trainer_engine.create(config, word2idx, evaluator, mlp_local) + self.fusion_model = VQAFusionModel(self.device,self.trainer_engine.model,1,1, meta.get_category_size(), config.vqa_hidden_sizes, dropout_rate=config.vqa_dropout).to(self.trainer_engine.device) + self.vqa_criterion = torch.nn.CrossEntropyLoss(weight=self.weights_tensor(meta)).to(self.device) + self.vqa_optimizer = get_optimizer(config.optimizer.name, + self.fusion_model.parameters(), + config.optimizer) + self.vqa_lr_scheduler = get_lr_scheduler(config.lr_scheduler.name, + self.vqa_optimizer, + config.lr_scheduler) + def to_half(self): + # Mixed precision + # https://nvidia.github.io/apex/amp.html + self.fusion_model, self.vqa_optimizer = amp.initialize(self.fusion_model, [self.vqa_optimizer, self.trainer_engine.optimizer], + opt_level='O2') + + def train(self, tr_loader, pub_data_ratio=1.): + self.trainer_engine.train(tr_loader, pub_data_ratio) + + def train_vqa(self, epoch, vqa_loader, vqa2_test_dataloader = None): + self.fusion_model.train() + full_training_epoch = self.args.vqa_full_training_epoch + if epoch < full_training_epoch: + print("Freezing base model") + self.fusion_model.freeze_base_model() + else: + print("Not freezing base model") + # print_model_tree(self.fusion_model) + + max_batches = len(vqa_loader) + if self.args.vqa_data_size_per_epoch == 0: + max_batches = self.args.pub_data_num / vqa_loader.batch_size + elif self.args.vqa_data_size_per_epoch > 0: + max_batches = self.args.vqa_data_size_per_epoch / vqa_loader.batch_size + + n = 0 + loss_avg = 0 + with tqdm(enumerate(vqa_loader), total=max_batches) as progress_bar: + for i, batch in progress_bar: + if i >= max_batches: + break + + 
self.vqa_optimizer.zero_grad() + outputs, last_features = self.fusion_model.forward(batch) + #answers = batch['multiple_choice_answer'] + answers = batch['answers'] # for multiple answers, learn a random answer based on popularity + if isinstance(answers[0], list): + picked_answers = [] + for answer_list in answers: + known = [] + for answer in answer_list: + name = answer['answer'] + id = self.vqa_meta.get_category_id(name) + if id != unknown_category_id: + known.append(id) + if len(known) == 0: + known = [unknown_category_id] + picked_answers.append(random.choice(known)) + targets = torch.tensor(picked_answers).to(self.device) + else: + targets = torch.tensor([self.vqa_meta.get_category_id(answer) for answer in answers]).to(self.device) + loss = self.vqa_criterion(outputs, targets) + + if self.config.train.get('use_fp16'): + with amp.scale_loss(loss, self.vqa_optimizer) as scaled_loss: + scaled_loss.backward() + else: + loss.backward() + + if epoch >= full_training_epoch and self.config.train.grad_clip > 0: + torch.nn.utils.clip_grad.clip_grad_norm_(self.fusion_model.parameters(), + self.config.train.grad_clip) + + loss_avg_rate = min(i, 99) + loss_avg = (loss_avg * loss_avg_rate + loss.item()) / (loss_avg_rate + 1) + self.vqa_optimizer.step() + progress_bar.set_description(f"Epoch {epoch}, Iter {i}, l100: {loss_avg:.4f}") + + if vqa2_test_dataloader is not None and (i+1) % (128*2**n) == 0: + #vqa_validation(1000, self.fusion_model, self.vqa_meta, vqa2_test_dataloader, 2) + #vqa_validation(1000, self.fusion_model, self.vqa_meta, vqa2_test_dataloader, 500) + #vqa_validation(1000, self.fusion_model, self.vqa_meta, vqa2_test_dataloader) + vqa_validation(10000, self.fusion_model, self.vqa_meta, vqa2_test_dataloader) + n += 1 + self.fusion_model.train() + if epoch < full_training_epoch: + self.fusion_model.freeze_base_model() + self.wandb.log({"train_vqa_loss_100": loss_avg}, step=epoch) + #self.wandb.log({"train_vqa_lr": self.vqa_optimizer.param_groups[0].get_lr()}) + self.fusion_model.unfreeze_base_model() + #print_model_tree(self.fusion_model) + + +@torch.no_grad() +def vqa_validation(n, fusion_model, meta, validation_dataloader, max_cats = 3000): + fusion_model.eval() + right = 0 + unknown_right = 0 + unknown_outputs = 0 + unknown_answers = 0 + unknown_unknown = 0 + total = 0 + for j, testBatch in tqdm(enumerate(validation_dataloader)): + #answers = testBatch['multiple_choice_answer'] # single answer + answers = testBatch['answers'] # multiple answers + outputs, _ = fusion_model.forward(testBatch) + for k, answer in enumerate(answers): + output = outputs[k] + _, top_matches = torch.topk(output, min(5,max_cats+1), largest=True, sorted=True) + top_match_names = [meta.get_category_by_id(cat_id.item()) for cat_id in top_matches] + if isinstance(answer, list): # use testBatch['answers'] + expected_answer = [f"'{item}'x{count}" for item, count in Counter([v['answer'] for v in answer]).most_common()] + for name in top_match_names: + if name in [v['answer'] for v in answer] : + answer = name # answer is the logical expected answer so our guess can be matched + break + if isinstance(answer, list): + answer = answer[0]['answer'] # no matches, pick the first answer + else: + expected_answer = answer + answer_id = meta.get_category_id(answer) + match_type = "wrong answer" + if answer_id > max_cats: + answer_id = unknown_category_id + if len(output) > max_cats: + output = output[:max_cats+1] + if top_match_names[0] == answer: + right += 1 + match_type = "first" + if top_matches[0] == 
unknown_category_id: + unknown_outputs += 1 + match_type = "unknown output" + if top_match_names[1] == answer: + unknown_right += 1 + match_type = "second" + if answer_id == unknown_category_id: + unknown_answers += 1 + answer = unknown_category + answer # mark answers not in the training set + match_type = "unknown answer" + if top_matches[0] == unknown_category_id: + unknown_unknown += 1 + match_type = "unknown unknown" + if total + k < 16: + tqdm.write(f"j {j}, k {k}, expected {expected_answer}, got {top_match_names}, match type {match_type}") + total += len(answers) + if total >= n: + break + accuracy = (right + unknown_right) / total + tqdm.write(f"test {max_cats} accuracy {right+unknown_right}/{total}={accuracy}, unknown_answers:{unknown_answers}, unknown_outputs:{unknown_outputs}, right after unknown:{unknown_right}, unknown_unknown:{unknown_unknown}") + + return { + "max_cats": max_cats, + "right0": right, + "right": right+unknown_right, + "total": total, + "accuracy": accuracy, + "unknown_answers": unknown_answers, + "unknown_outputs": unknown_outputs, + "right1": unknown_right, + "unknown_unknown": unknown_unknown + } diff --git a/src/coco.yaml b/src/coco.yaml index 6ee5559..31e3e4a 100644 --- a/src/coco.yaml +++ b/src/coco.yaml @@ -1,6 +1,6 @@ dataloader: - batch_size: 128 - eval_batch_size: 8 + batch_size: 64 + eval_batch_size: 64 num_workers: 16 crop_size: 224 word_dim: 300 @@ -9,8 +9,8 @@ dataloader: model: name: pcme - embed_dim: 512 # 2048 origin - cnn_type: resnet50 # res152 origin + embed_dim: 2048 # (ignored value, set in code) 2048 origin + cnn_type: resnet101 # (ignored value, set in code) res152 origin wemb_type: glove word_dim: 300 cache_dir: /data/mmdata/log/server @@ -50,10 +50,13 @@ criterion: train: model_save_path: model_last_no_prob.pth best_model_save_path: model_best_no_prob.pth - finetune_epochs: 30 + finetune_epochs: 30 # not used finetune_lr_decay: 0.1 log_step: 100 grad_clip: 2 - val_epochs: 10 - use_fp16: True + val_epochs: 10 # not used + use_fp16: False output_file: model_noprob.log + +vqa_hidden_sizes: [1024,1024] +vqa_dropout: 0.0 # (ignored value, set in code) \ No newline at end of file diff --git a/src/common.py b/src/common.py new file mode 100644 index 0000000..6cedf69 --- /dev/null +++ b/src/common.py @@ -0,0 +1,182 @@ +import argparse +import os +import sys +sys.path.append("./") +sys.path.append("../") +sys.path.append("../../") +sys.path.append("../../../") + +from src.utils.helper import Helper as helper +from src.utils.config import parse_config + + + +def add_args(parser: argparse.ArgumentParser, is_vqa=False): + parser.add_argument('--inference', type=bool, default=False, help='inferencing or not.') + parser.add_argument('--port', type=int, default=2323, help='port') + parser.add_argument('--name', type=str, default='Test', help='The name for different experimental runs.') + parser.add_argument('--disable_wandb', action='store_true', default=False) + parser.add_argument('--exp_dir', type=str, default='./experiments/', + help='Locations to save different experimental runs.') + parser.add_argument('--local_epochs', type=int, default=5) # original default = 5 + parser.add_argument('--client_init_local_epochs', type=int, default=0, help='Number of additional local epochs when clients first receive the global model') + parser.add_argument('--comm_rounds', type=int, default=30) # original default = 30 + + #parser.add_argument('--model', type=str, default='resnet18', help='Target model name') + #parser.add_argument('--img_model_local', 
type=str, default='resnet10') + #parser.add_argument('--pretrained', type=int, default=0) + parser.add_argument('--pretrained_model', type=str, default="") + parser.add_argument('--no-cuda', action='store_true', default=False, + help='disables CUDA training') + parser.add_argument('--seed', type=int, default=1, metavar='S', + help='random seed (default: 1)') + parser.add_argument('--device', type=int, default=0) + parser.add_argument('--num_img_clients', type=int, default=10) # original default = 10 + parser.add_argument('--num_txt_clients', type=int, default=10) # original default = 10 + parser.add_argument('--num_mm_clients', type=int, default=15) # original default = 15 + + parser.add_argument('--client_num_per_round', type=int, default=10) # original default = 10 + + # === dataloader === + # parser.add_argument('--dataset', type=str, default='cifar100', choices=['svhn', 'cifar10', 'cifar100'], + # help='dataset name (default: cifar100)') # not implemented + parser.add_argument('--data_root', type=str, default=os.environ['HOME'] + "/data/") + parser.add_argument('--batch_size', type=int, default=32, metavar='N', + help='input batch size for training (default: 32)') + parser.add_argument('--alpha', type=float, default=0.1, + help='how evenly distributed the data is for img and txt clients (default: 0.1)') + parser.add_argument('--max_size', type=int, default=0, + help='maximum number of data samples to use per client (default: 0 (use all data))') + + # === communication cost === + parser.add_argument('--pub_data_num', type=int, default=50000, help='coco global training data size') + parser.add_argument('--feature_dim', type=int, default=256) + + # === optimization === + parser.add_argument('--server_lr', type=float, default=0.0002) + parser.add_argument('--lr', type=float, default=0.1, metavar='LR', help='Student learning rate (default: 0.1)') + parser.add_argument('--loss', type=str, default='l1', choices=['l1', 'kl', 'l1softmax'], ) + parser.add_argument('--scheduler', type=str, default='multistep', + choices=['multistep', 'cosine', 'exponential', "none"], ) + parser.add_argument('--steps', nargs='+', default=[0.05, 0.15, 0.3, 0.5, 0.75], type=float, + help="Percentage epochs at which to take next step") + parser.add_argument('--scale', type=float, default=0.1, help="Fractional decrease in lr") + parser.add_argument('--weight_decay', type=float, default=5e-4) + parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum (default: 0.9)') + # === logs === + parser.add_argument('--log_interval', type=int, default=10, metavar='N', + help='how many batches to wait before logging training status') + parser.add_argument('--save_interval', type=int, default=10, metavar='N', + help='how many batches to wait before logging training status') + + parser.add_argument('--disable_distill', action="store_true", default=False) + + parser.add_argument('--agg_method', type=str, default='con_w', help='representation aggregation method') + parser.add_argument('--contrast_local_intra', action="store_true", default=False) + parser.add_argument('--contrast_local_inter', action="store_true", default=False) + + parser.add_argument('--mlp_local', action="store_true", default=False) + + parser.add_argument('--kd_weight', type=float, default=0.3, help='coefficient of kd') + parser.add_argument('--interintra_weight', type=float, default=0.5, help='coefficient of inter+intra') + + parser.add_argument('--loss_scale', action='store_true', default=False) + 
parser.add_argument('--save_client', action='store_true', default=False) + + parser.add_argument('--data_local', action='store_true', default=False, + help='change data directory to ~/data_local') + + parser.add_argument('--not_bert', action='store_true', default=False, help="server bert, client not bert") + + # === federated learning networking === + parser.add_argument('--fed_config', default='fed_config.yaml', help="federation network configuration file") + parser.add_argument('--client_name', help="client name, only used by clients") + + # === vqa related option === + parser.add_argument('--no_retrieval_training', action='store_true', default=False,) + + + if is_vqa: # vqa only options + parser.add_argument('--vqa_fusion_network', default='linear') + parser.add_argument('--vqa_pretrained_base_model', default='./sl2-best_model.pt') + parser.add_argument('--vqa_pretrained_eval', action='store_true', default=False, + help='check how good the base model is on the retrieval task to make sure we are loading a good model.') + parser.add_argument('--vqa_hidden_sizes', nargs='*', type=int, default=[], + help='List of hidden layer sizes for the fusion network.') + parser.add_argument('--vqa_epochs', type=int, default=10, help='Number of epochs to train the model.') + #parser.add_argument('--vqa_unfreeze_base_epoch', type=int, default=5, help='Epoch to unfreeze the base model.') + parser.add_argument('--vqa_lr', type=float, default=0.0002, help='Learning rate for the fusion network.') + parser.add_argument('--vqa_weight_decay', type=float, default=0.0, help='Weight decay for the fusion network.') + parser.add_argument('--vqa_dropout', type=float, default=0.0, help='Dropout rate for the fusion network.') + parser.add_argument('--vqa_input_type', type=str, default='AxB', choices=['A_B', 'AxB'],) + parser.add_argument('--vqa_cat_weight', type=str, default='1', choices=['1', 'count', 'count+1000'],) + parser.add_argument('--vqa_full_training_epoch' , type=int, default=0, help='Number of epochs start training the model end to end.') + parser.add_argument('--vqa_data_size_per_epoch' , type=int, default=0, help='Number of data samples to use per epoch for global model training (default: 0 (same as pub_data_num), -1 (use all data)') + parser.add_argument('--vqa_filter_unknown', action='store_true', default=False, help='Filter unknown answers from the training and testing data.') + + +def init_wandb(args, script=None): + """ + wandb will automatically save the log + + wandb.log({"epoch": epoch, "loss": loss}, step=example_ct) + print(f"Loss after " + str(example_ct).zfill(5) + f" examples: {loss:.3f}") + + wandb.log({"test_accuracy": correct / total}) + + # Save the model in the exchangeable ONNX format + torch.onnx.export(model, images, "model.onnx") + wandb.save("model.onnx") + + """ + + import wandb + if args.disable_wandb: + return wandb.init(mode="disabled") + + name = str(args.name) + + if script is not None: + name = f"{name}-{script}" + + wandb.init( + project="CreamFL", + name=name, + resume=None, + # dir=os.path.join(args.exp_dir, args.name), + config=args + ) + + return wandb + +def get_config(args, img='cifa100', txt='AG_NEWS'): + config = parse_config("./src/coco.yaml", strict_cast=False) + config.train.model_save_path = 'model_last_no_prob' + config.train.best_model_save_path = 'model_best_no_prob' + config.train.output_file = 'model_noprob' + config.model.img_client = img + config.model.txt_client = txt + config.train.model_save_path = config.train.model_save_path + '.pth' + 
config.train.best_model_save_path = config.train.best_model_save_path + '.pth' + config.train.output_file = config.train.output_file + '.log' + + config.model.embed_dim = args.feature_dim # set global model dim + + if args.not_bert: + config.model.not_bert = True + config.model.cnn_type = 'resnet50' + else: + config.model.not_bert = False + config.model.cnn_type = 'resnet101' + + return config + +def prepare_args(description: str, script=None, is_vqa=False): + parser = argparse.ArgumentParser(description=description) + add_args(parser, is_vqa=is_vqa) + args = parser.parse_args() + wandb = init_wandb(args, script=script) + args.save_dirs = helper.get_save_dirs(args.exp_dir, args.name) + args.log_dir = args.save_dirs['logs'] + helper.set_seed(args.seed) + return args, wandb diff --git a/src/criterions/__init__.py b/src/criterions/__init__.py index fb7bac8..80a89f9 100644 --- a/src/criterions/__init__.py +++ b/src/criterions/__init__.py @@ -1,8 +1,10 @@ -from criterions.probemb import MCSoftContrastiveLoss - +from . probemb import MCSoftContrastiveLoss +from torch.nn import BCEWithLogitsLoss def get_criterion(criterion_name, config): if criterion_name == 'pcme': return MCSoftContrastiveLoss(config) + if criterion_name == "vqa": + return BCEWithLogitsLoss(config) else: raise ValueError(f'Invalid criterion name: {criterion_name}') diff --git a/src/criterions/__pycache__/__init__.cpython-38.pyc b/src/criterions/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index e019526..0000000 Binary files a/src/criterions/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/src/criterions/__pycache__/probemb.cpython-38.pyc b/src/criterions/__pycache__/probemb.cpython-38.pyc deleted file mode 100644 index b673ae4..0000000 Binary files a/src/criterions/__pycache__/probemb.cpython-38.pyc and /dev/null differ diff --git a/src/custom_datasets/__init__.py b/src/custom_datasets/__init__.py new file mode 100644 index 0000000..7771a6e --- /dev/null +++ b/src/custom_datasets/__init__.py @@ -0,0 +1,9 @@ +from . _dataloader import prepare_coco_dataloaders, prepare_cub_dataloaders +from . vocab import Vocabulary + + +__all__ = [ + 'Vocabulary', + 'prepare_coco_dataloaders', + 'prepare_cub_dataloaders', +] diff --git a/src/datasets/_dataloader.py b/src/custom_datasets/_dataloader.py similarity index 82% rename from src/datasets/_dataloader.py rename to src/custom_datasets/_dataloader.py index 012e292..8e5ee14 100644 --- a/src/datasets/_dataloader.py +++ b/src/custom_datasets/_dataloader.py @@ -1,4 +1,4 @@ -"""libaray for multi-modal dataset loaders. +"""library for multi-modal dataset loaders. 
Acknowledgements: `image_to_caption_collate_fn` is based on @@ -12,11 +12,11 @@ from torch.utils.data import DataLoader try: - from datasets.coco import CocoCaptionsCap - from datasets.flickr30k import F30kCaptionsCap - from datasets.cub import CUBCaption, CUBSampler - from datasets.vocab import Vocabulary - from datasets._transforms import imagenet_transform, caption_transform + from custom_datasets.coco import CocoCaptionsCap + from custom_datasets.flickr30k import F30kCaptionsCap + from custom_datasets.cub import CUBCaption, CUBSampler + from custom_datasets.vocab import Vocabulary + from custom_datasets._transforms import imagenet_transform, caption_transform except: try: from coco import CocoCaptionsCap @@ -25,11 +25,16 @@ from vocab import Vocabulary from _transforms import imagenet_transform, caption_transform except: - from src.datasets.coco import CocoCaptionsCap - from src.datasets.flickr30k import F30kCaptionsCap - from src.datasets.cub import CUBCaption, CUBSampler - from src.datasets.vocab import Vocabulary - from src.datasets._transforms import imagenet_transform, caption_transform + import sys + sys.path.append("./") + sys.path.append("../") + sys.path.append("../../") + sys.path.append("../../../") + from src.custom_datasets.coco import CocoCaptionsCap + from src.custom_datasets.flickr30k import F30kCaptionsCap + from src.custom_datasets.cub import CUBCaption, CUBSampler + from src.custom_datasets.vocab import Vocabulary + from src.custom_datasets._transforms import imagenet_transform, caption_transform def image_to_caption_collate_fn(data): @@ -86,26 +91,26 @@ def _get_cub_file_paths(dataset_name, dataset_root, caption_root): Each split contains 100 train classes / 50 validation classes. - cub: The final split used for the final benchmark. 
- This split conntains 150 train classes / 50 unseen test classes (not in trainval) + This split contains 150 train classes / 50 unseen test classes (not in trainval) """ if dataset_name == 'cub_trainval1': - train_classes = './datasets/annotations/cub/trainclasses1.txt' - val_classes = './datasets/annotations/cub/valclasses1.txt' - omit_ids = './datasets/annotations/cub/seen_test_images.txt' + train_classes = './custom_datasets/annotations/cub/trainclasses1.txt' + val_classes = './custom_datasets/annotations/cub/valclasses1.txt' + omit_ids = './custom_datasets/annotations/cub/seen_test_images.txt' elif dataset_name == 'cub_trainval2': - train_classes = './datasets/annotations/cub/trainclasses2.txt' - val_classes = './datasets/annotations/cub/valclasses2.txt' - omit_ids = './datasets/annotations/cub/seen_test_images.txt' + train_classes = './custom_datasets/annotations/cub/trainclasses2.txt' + val_classes = './custom_datasets/annotations/cub/valclasses2.txt' + omit_ids = './custom_datasets/annotations/cub/seen_test_images.txt' elif dataset_name == 'cub_trainval3': - train_classes = './datasets/annotations/cub/trainclasses3.txt' - val_classes = './datasets/annotations/cub/valclasses3.txt' - omit_ids = './datasets/annotations/cub/seen_test_images.txt' + train_classes = './custom_datasets/annotations/cub/trainclasses3.txt' + val_classes = './custom_datasets/annotations/cub/valclasses3.txt' + omit_ids = './custom_datasets/annotations/cub/seen_test_images.txt' elif dataset_name == 'cub': - train_classes = './datasets/annotations/cub/trainvalclasses.txt' - val_classes = './datasets/annotations/cub/testclasses.txt' - omit_ids = './datasets/annotations/cub/seen_test_images.txt' + train_classes = './custom_datasets/annotations/cub/trainvalclasses.txt' + val_classes = './custom_datasets/annotations/cub/testclasses.txt' + omit_ids = './custom_datasets/annotations/cub/seen_test_images.txt' else: - raise ValueError(f'Invalide dataset_name: {dataset_name}') + raise ValueError(f'Invalid dataset_name: {dataset_name}') image_root = os.path.join(dataset_root, 'images/') @@ -149,7 +154,7 @@ def prepare_cub_dataloaders(dataloader_config, dataset_root, caption_root, vocab_path='./vocabs/cub_vocab.pkl', - num_workers=6): + num_workers=12): """Prepare CUB Caption train / val / test dataloaders CUB Caption loader has a fixed batch size - train loader: # classes (trainval = 100, full = 150) @@ -162,11 +167,11 @@ def prepare_cub_dataloaders(dataloader_config, Each split contains 100 train classes / 50 validation classes. - cub: The final split used for the final benchmark. - This split conntains 150 train classes / 50 unseen test classes (not in trainval) + This split contains 150 train classes / 50 unseen test classes (not in trainval) dataset_root (str): root of your CUB images (see README.md for detailed dataset hierarchy) caption_root (str): root of your CUB captions (see README.md for detailed dataset hierarchy) vocab_path (str, optional): path for vocab pickle file (default: ./vocabs/cub_vocab.pkl). - num_workers (int, optional): num_workers for the dataloaders (default: 6) + num_workers (int, optional): num_workers for the dataloaders (default: 12) Returns: dataloaders (dict): keys = ["train", "val", "val_in"], values are the corresponding dataloaders. vocab (Vocabulary object): vocab object @@ -244,10 +249,10 @@ def _get_coco_loader(image_root, def _get_coco_file_paths(dataset_root): """Select proper train / val classes and omit id files. 
""" - train_ids = np.load('./datasets/annotations/coco_train_ids.npy') - train_extra_ids = np.load('./datasets/annotations/coco_restval_ids.npy') - val_ids = np.load('./datasets/annotations/coco_dev_ids.npy')[:5000] - te_ids = np.load('./datasets/annotations/coco_test_ids.npy') + train_ids = np.load('./custom_datasets/annotations/coco_train_ids.npy') + train_extra_ids = np.load('./custom_datasets/annotations/coco_restval_ids.npy') + val_ids = np.load('./custom_datasets/annotations/coco_dev_ids.npy')[:5000] + te_ids = np.load('./custom_datasets/annotations/coco_test_ids.npy') image_root = os.path.join(dataset_root, 'allimages') train_ann = os.path.join(dataset_root, 'annotations/captions_train2014.json') @@ -259,13 +264,13 @@ def _get_coco_file_paths(dataset_root): def prepare_coco_dataloaders(dataloader_config, dataset_root, vocab_path='./vocabs/coco_vocab.pkl', - num_workers=6): + num_workers=12): """Prepare MS-COCO Caption train / val / test dataloaders Args: dataloader_config (dict): configuration file which should contain "batch_size" dataset_root (str): root of your MS-COCO dataset (see README.md for detailed dataset hierarchy) vocab_path (str, optional): path for vocab pickle file (default: ./vocabs/coco_vocab.pkl). - num_workers (int, optional): num_workers for the dataloaders (default: 6) + num_workers (int, optional): num_workers for the dataloaders (default: 12) Returns: dataloaders (dict): keys = ["train", "val", "te"], values are the corresponding dataloaders. vocab (Vocabulary object): vocab object @@ -341,18 +346,20 @@ def see_coco_len(dataset_root=os.environ['HOME'] + '/data/mmdata/MSCOCO/2014'): transform=None, target_transform=None) - print(f'train {len(train)}') - print(f'test {len(test)}') + print(f'see_coco_len train {len(train)}') + print(f'see_coco_len test {len(test)}') def _get_F30k_loader(vocab, num_workers, + max_size, batch_size=64, train=False, split='train', cutout_prob=0.0, caption_drop_prob=0.0, - client=-1): + client=-1, + num_users=-1): _image_transform = imagenet_transform( random_resize_crop=train, random_erasing_prob=cutout_prob, @@ -362,28 +369,30 @@ def _get_F30k_loader(vocab, coco_dataset = F30kCaptionsCap(train=True if split == 'train' else False, transform=_image_transform, - target_transform=_caption_transform, client=client) + target_transform=_caption_transform, client=client, num_users=num_users, max_size=max_size) dataloader = DataLoader(coco_dataset, batch_size=batch_size, shuffle=train, num_workers=num_workers, collate_fn=image_to_caption_collate_fn, - pin_memory=True) - print(f'Loading F30k Caption: n_images {coco_dataset.n_images} n_captions {len(coco_dataset)}...') + pin_memory=False) + print(f'Loading F30k Caption: split {split} n_images {coco_dataset.n_images} n_captions {len(coco_dataset)}...') return dataloader def prepare_f30k_dataloaders(dataloader_config, dataset_root, + max_size, vocab_path='./vocabs/coco_vocab.pkl', client=-1, - num_workers=6): + num_users=-1, + num_workers=12): """Prepare MS-COCO Caption train / val / test dataloaders Args: dataloader_config (dict): configuration file which should contain "batch_size" dataset_root (str): root of your MS-COCO dataset (see README.md for detailed dataset hierarchy) vocab_path (str, optional): path for vocab pickle file (default: ./vocabs/coco_vocab.pkl). 
- num_workers (int, optional): num_workers for the dataloaders (default: 6) + num_workers (int, optional): num_workers for the dataloaders (default: 12) Returns: dataloaders (dict): keys = ["train", "val", "te"], values are the corresponding dataloaders. vocab (Vocabulary object): vocab object @@ -398,13 +407,15 @@ def prepare_f30k_dataloaders(dataloader_config, dataloaders = {} dataloaders['train'] = _get_F30k_loader( vocab, - num_workers=num_workers, + num_workers, + max_size, batch_size=batch_size, train=True, split='train', cutout_prob=tr_cutout_prob, caption_drop_prob=tr_caption_drop_prob, - client=client + client=client, + num_users=num_users, ) # dataloaders['val'] = _get_F30k_loader( @@ -419,11 +430,13 @@ def prepare_f30k_dataloaders(dataloader_config, dataloaders['te'] = _get_F30k_loader( vocab, - num_workers=num_workers, + num_workers, + max_size, batch_size=eval_batch_size, train=False, split='test', - client=client + client=client, + num_users=num_users ) return dataloaders, vocab @@ -434,8 +447,8 @@ def see_f30k_len(): test = F30kCaptionsCap(split='test') - print(f'f30k train {len(train)}') - print(f'f30k test {len(test)}') + print(f'see_f30k_len train {len(train)}') + print(f'see_f30k_len test {len(test)}') if __name__ == '__main__': diff --git a/src/datasets/_transforms.py b/src/custom_datasets/_transforms.py similarity index 100% rename from src/datasets/_transforms.py rename to src/custom_datasets/_transforms.py diff --git a/src/datasets/annotations/coco_dev_ids.npy b/src/custom_datasets/annotations/coco_dev_ids.npy similarity index 100% rename from src/datasets/annotations/coco_dev_ids.npy rename to src/custom_datasets/annotations/coco_dev_ids.npy diff --git a/src/datasets/annotations/coco_restval_ids.npy b/src/custom_datasets/annotations/coco_restval_ids.npy similarity index 100% rename from src/datasets/annotations/coco_restval_ids.npy rename to src/custom_datasets/annotations/coco_restval_ids.npy diff --git a/src/datasets/annotations/coco_test_ids.npy b/src/custom_datasets/annotations/coco_test_ids.npy similarity index 100% rename from src/datasets/annotations/coco_test_ids.npy rename to src/custom_datasets/annotations/coco_test_ids.npy diff --git a/src/datasets/annotations/coco_train_ids.npy b/src/custom_datasets/annotations/coco_train_ids.npy similarity index 100% rename from src/datasets/annotations/coco_train_ids.npy rename to src/custom_datasets/annotations/coco_train_ids.npy diff --git a/src/datasets/annotations/cub/seen_test_images.txt b/src/custom_datasets/annotations/cub/seen_test_images.txt similarity index 100% rename from src/datasets/annotations/cub/seen_test_images.txt rename to src/custom_datasets/annotations/cub/seen_test_images.txt diff --git a/src/datasets/annotations/cub/testclasses.txt b/src/custom_datasets/annotations/cub/testclasses.txt similarity index 100% rename from src/datasets/annotations/cub/testclasses.txt rename to src/custom_datasets/annotations/cub/testclasses.txt diff --git a/src/datasets/annotations/cub/trainclasses1.txt b/src/custom_datasets/annotations/cub/trainclasses1.txt similarity index 100% rename from src/datasets/annotations/cub/trainclasses1.txt rename to src/custom_datasets/annotations/cub/trainclasses1.txt diff --git a/src/datasets/annotations/cub/trainclasses2.txt b/src/custom_datasets/annotations/cub/trainclasses2.txt similarity index 100% rename from src/datasets/annotations/cub/trainclasses2.txt rename to src/custom_datasets/annotations/cub/trainclasses2.txt diff --git a/src/datasets/annotations/cub/trainclasses3.txt 
b/src/custom_datasets/annotations/cub/trainclasses3.txt similarity index 100% rename from src/datasets/annotations/cub/trainclasses3.txt rename to src/custom_datasets/annotations/cub/trainclasses3.txt diff --git a/src/datasets/annotations/cub/trainvalclasses.txt b/src/custom_datasets/annotations/cub/trainvalclasses.txt similarity index 100% rename from src/datasets/annotations/cub/trainvalclasses.txt rename to src/custom_datasets/annotations/cub/trainvalclasses.txt diff --git a/src/datasets/annotations/cub/valclasses1.txt b/src/custom_datasets/annotations/cub/valclasses1.txt similarity index 100% rename from src/datasets/annotations/cub/valclasses1.txt rename to src/custom_datasets/annotations/cub/valclasses1.txt diff --git a/src/datasets/annotations/cub/valclasses2.txt b/src/custom_datasets/annotations/cub/valclasses2.txt similarity index 100% rename from src/datasets/annotations/cub/valclasses2.txt rename to src/custom_datasets/annotations/cub/valclasses2.txt diff --git a/src/datasets/annotations/cub/valclasses3.txt b/src/custom_datasets/annotations/cub/valclasses3.txt similarity index 100% rename from src/datasets/annotations/cub/valclasses3.txt rename to src/custom_datasets/annotations/cub/valclasses3.txt diff --git a/src/datasets/cifar.py b/src/custom_datasets/cifar.py similarity index 100% rename from src/datasets/cifar.py rename to src/custom_datasets/cifar.py diff --git a/src/datasets/coco.py b/src/custom_datasets/coco.py similarity index 97% rename from src/datasets/coco.py rename to src/custom_datasets/coco.py index 33dde5b..292c527 100644 --- a/src/datasets/coco.py +++ b/src/custom_datasets/coco.py @@ -10,8 +10,8 @@ import torch from torchvision import datasets -from src.datasets.dataset_L import caption_transform -from src.datasets.vocab import Vocabulary +from src.custom_datasets.dataset_L import caption_transform +from src.custom_datasets.vocab import Vocabulary try: import ujson as json @@ -210,7 +210,7 @@ def __init__(self, root=os.environ['HOME'] + "/data/mmdata/MSCOCO/2014/train2014 self.cat2cat[cat] = len(self.cat2cat) if not transform: - vocab_path = './src/datasets/vocabs/coco_vocab.pkl' + vocab_path = './src/custom_datasets/vocabs/coco_vocab.pkl' if isinstance(vocab_path, str): vocab = Vocabulary() vocab.load_from_pickle(vocab_path) @@ -244,5 +244,5 @@ def __getitem__(self, index): # image_id = annotation['image_id'] def __len__(self): - return len(self.ids) - # return 100 + return 1024 # test with less data + return len(self.ids) \ No newline at end of file diff --git a/src/custom_datasets/coco/datasets/__init__.py b/src/custom_datasets/coco/datasets/__init__.py new file mode 100644 index 0000000..682ae7c --- /dev/null +++ b/src/custom_datasets/coco/datasets/__init__.py @@ -0,0 +1,9 @@ +from custom_datasets._dataloader import prepare_coco_dataloaders, prepare_cub_dataloaders, prepare_f30k_dataloaders +from custom_datasets.vocab import Vocabulary + +__all__ = [ + 'Vocabulary', + 'prepare_coco_dataloaders', + 'prepare_cub_dataloaders', + 'prepare_f30k_dataloaders' +] diff --git a/src/datasets/coco/datasets/_dataloader.py b/src/custom_datasets/coco/datasets/_dataloader.py similarity index 89% rename from src/datasets/coco/datasets/_dataloader.py rename to src/custom_datasets/coco/datasets/_dataloader.py index a6a4ee8..e8cdb47 100644 --- a/src/datasets/coco/datasets/_dataloader.py +++ b/src/custom_datasets/coco/datasets/_dataloader.py @@ -12,11 +12,11 @@ from torch.utils.data import DataLoader # try: -from datasets.coco import CocoCaptionsCap -from datasets.f30k import 
F30kCaptionsCap -from datasets.cub import CUBCaption, CUBSampler -from datasets.vocab import Vocabulary -from datasets._transforms import imagenet_transform, caption_transform +from custom_datasets.coco import CocoCaptionsCap +from custom_datasets.f30k import F30kCaptionsCap +from custom_datasets.cub import CUBCaption, CUBSampler +from custom_datasets.vocab import Vocabulary +from custom_datasets._transforms import imagenet_transform, caption_transform # except: # from coco import CocoCaptionsCap # from f30k import F30kCaptionsCap @@ -82,21 +82,21 @@ def _get_cub_file_paths(dataset_name, dataset_root, caption_root): This split conntains 150 train classes / 50 unseen test classes (not in trainval) """ if dataset_name == 'cub_trainval1': - train_classes = './datasets/annotations/cub/trainclasses1.txt' - val_classes = './datasets/annotations/cub/valclasses1.txt' - omit_ids = './datasets/annotations/cub/seen_test_images.txt' + train_classes = './custom_datasets/annotations/cub/trainclasses1.txt' + val_classes = './custom_datasets/annotations/cub/valclasses1.txt' + omit_ids = './custom_datasets/annotations/cub/seen_test_images.txt' elif dataset_name == 'cub_trainval2': - train_classes = './datasets/annotations/cub/trainclasses2.txt' - val_classes = './datasets/annotations/cub/valclasses2.txt' - omit_ids = './datasets/annotations/cub/seen_test_images.txt' + train_classes = './custom_datasets/annotations/cub/trainclasses2.txt' + val_classes = './custom_datasets/annotations/cub/valclasses2.txt' + omit_ids = './custom_datasets/annotations/cub/seen_test_images.txt' elif dataset_name == 'cub_trainval3': - train_classes = './datasets/annotations/cub/trainclasses3.txt' - val_classes = './datasets/annotations/cub/valclasses3.txt' - omit_ids = './datasets/annotations/cub/seen_test_images.txt' + train_classes = './custom_datasets/annotations/cub/trainclasses3.txt' + val_classes = './custom_datasets/annotations/cub/valclasses3.txt' + omit_ids = './custom_datasets/annotations/cub/seen_test_images.txt' elif dataset_name == 'cub': - train_classes = './datasets/annotations/cub/trainvalclasses.txt' - val_classes = './datasets/annotations/cub/testclasses.txt' - omit_ids = './datasets/annotations/cub/seen_test_images.txt' + train_classes = './custom_datasets/annotations/cub/trainvalclasses.txt' + val_classes = './custom_datasets/annotations/cub/testclasses.txt' + omit_ids = './custom_datasets/annotations/cub/seen_test_images.txt' else: raise ValueError(f'Invalide dataset_name: {dataset_name}') @@ -237,10 +237,10 @@ def _get_coco_loader(image_root, def _get_coco_file_paths(dataset_root): """Select proper train / val classes and omit id files. 
""" - train_ids = np.load('./datasets/annotations/coco_train_ids.npy') - train_extra_ids = np.load('./datasets/annotations/coco_restval_ids.npy') - val_ids = np.load('./datasets/annotations/coco_dev_ids.npy')[:5000] - te_ids = np.load('./datasets/annotations/coco_test_ids.npy') + train_ids = np.load('./custom_datasets/annotations/coco_train_ids.npy') + train_extra_ids = np.load('./custom_datasets/annotations/coco_restval_ids.npy') + val_ids = np.load('./custom_datasets/annotations/coco_dev_ids.npy')[:5000] + te_ids = np.load('./custom_datasets/annotations/coco_test_ids.npy') image_root = os.path.join(dataset_root, 'allimages') train_ann = os.path.join(dataset_root, 'annotations/captions_train2014.json') @@ -330,6 +330,7 @@ def see_coco_len(dataset_root='/data/mmdata/MSCOCO/2014'): def _get_F30k_loader(vocab, num_workers, + max_size, batch_size=64, train=False, split='train', @@ -344,7 +345,8 @@ def _get_F30k_loader(vocab, coco_dataset = F30kCaptionsCap(split=split, transform=_image_transform, - target_transform=_caption_transform) + target_transform=_caption_transform, + max_size=max_size) dataloader = DataLoader(coco_dataset, batch_size=batch_size, @@ -358,6 +360,7 @@ def _get_F30k_loader(vocab, def prepare_f30k_dataloaders(dataloader_config, dataset_root, + max_size, vocab_path='./vocabs/coco_vocab.pkl', num_workers=6): """Prepare MS-COCO Caption train / val / test dataloaders @@ -380,7 +383,8 @@ def prepare_f30k_dataloaders(dataloader_config, dataloaders = {} dataloaders['train'] = _get_F30k_loader( vocab, - num_workers=num_workers, + num_workers, + max_size, batch_size=batch_size, train=True, split='train', @@ -390,7 +394,8 @@ def prepare_f30k_dataloaders(dataloader_config, dataloaders['val'] = _get_F30k_loader( vocab, - num_workers=num_workers, + num_workers, + max_size, batch_size=eval_batch_size, train=False, split='val', @@ -399,7 +404,8 @@ def prepare_f30k_dataloaders(dataloader_config, dataloaders['te'] = _get_F30k_loader( vocab, - num_workers=num_workers, + num_workers, + max_size, batch_size=eval_batch_size, train=False, split='test', diff --git a/src/datasets/coco/datasets/coco.py b/src/custom_datasets/coco/datasets/coco.py similarity index 100% rename from src/datasets/coco/datasets/coco.py rename to src/custom_datasets/coco/datasets/coco.py diff --git a/src/datasets/coco/datasets/cub.py b/src/custom_datasets/coco/datasets/cub.py similarity index 100% rename from src/datasets/coco/datasets/cub.py rename to src/custom_datasets/coco/datasets/cub.py diff --git a/src/datasets/coco/datasets/vocab.py b/src/custom_datasets/coco/datasets/vocab.py similarity index 100% rename from src/datasets/coco/datasets/vocab.py rename to src/custom_datasets/coco/datasets/vocab.py diff --git a/src/datasets/coco_transforms.py b/src/custom_datasets/coco_transforms.py similarity index 100% rename from src/datasets/coco_transforms.py rename to src/custom_datasets/coco_transforms.py diff --git a/src/datasets/cub.py b/src/custom_datasets/cub.py similarity index 100% rename from src/datasets/cub.py rename to src/custom_datasets/cub.py diff --git a/src/datasets/dataset_L.py b/src/custom_datasets/dataset_L.py similarity index 98% rename from src/datasets/dataset_L.py rename to src/custom_datasets/dataset_L.py index 6d3ad2e..b2feb6f 100644 --- a/src/datasets/dataset_L.py +++ b/src/custom_datasets/dataset_L.py @@ -17,7 +17,7 @@ from torchvision import transforms from tqdm import tqdm -from src.datasets.vocab import Vocabulary +from src.custom_datasets.vocab import Vocabulary def tokenize(sentence, vocab, 
caption_drop_prob): @@ -185,7 +185,7 @@ def __init__(self, name='AG_NEWS', train=True, transform=None, is_iid=False, self.targets = np.array(self.targets) if not transform: - vocab_path = './src/datasets/vocabs/coco_vocab.pkl' + vocab_path = './src/custom_datasets/vocabs/coco_vocab.pkl' if isinstance(vocab_path, str): vocab = Vocabulary() vocab.load_from_pickle(vocab_path) diff --git a/src/datasets/f30k.py b/src/custom_datasets/f30k.py similarity index 94% rename from src/datasets/f30k.py rename to src/custom_datasets/f30k.py index 4803206..378722c 100644 --- a/src/datasets/f30k.py +++ b/src/custom_datasets/f30k.py @@ -56,7 +56,7 @@ class F30kCaptionsCap(Dataset): """ def __init__(self, annFile=os.environ['HOME'] + '/data/mmdata/Flick30k/dataset_k_split.pkl', split='train', - transform=None, target_transform=None + transform=None, target_transform=None, max_size=0 ): self.transform = transform self.target_transform = target_transform @@ -64,6 +64,7 @@ def __init__(self, annFile=os.environ['HOME'] + '/data/mmdata/Flick30k/dataset_k if split not in self.data.keys(): assert False, f'split wrong {split}' self.data = self.data[split] + self.max_size = max_size def __getitem__(self, index): """ @@ -87,6 +88,8 @@ def __getitem__(self, index): return img, target, index, index, index def __len__(self): + if self.max_size != 0: + return min(self.max_size, len(self.data)) return len(self.data) diff --git a/src/datasets/flickr30k.py b/src/custom_datasets/flickr30k.py similarity index 82% rename from src/datasets/flickr30k.py rename to src/custom_datasets/flickr30k.py index 7bfe67a..caa9008 100644 --- a/src/datasets/flickr30k.py +++ b/src/custom_datasets/flickr30k.py @@ -1,7 +1,7 @@ import os -from src.datasets.coco_transforms import caption_transform -from src.datasets.vocab import Vocabulary +from src.custom_datasets.coco_transforms import caption_transform +from src.custom_datasets.vocab import Vocabulary try: import ujson as json @@ -23,17 +23,19 @@ class F30kCaptionsCap(Dataset): def __init__(self, annFile='./dataset_k_split.pkl', train=True, - transform=None, target_transform=None, is_iid=False, client=-1): + transform=None, target_transform=None, is_iid=False, client=-1, num_users=-1, max_size=0): + assert client == -1 or client < num_users, f'num_users ({num_users}) must be set when client ({client}) is set' split = 'train' if train else 'test' self.transform = transform self.data = pickle.load(open(annFile, 'rb')) if split not in self.data.keys(): assert False, f'split wrong {split}' self.data = self.data[split] # 145,000 img-txt pairs, in list + self.max_size = max_size if client > -1 and train: # print(self.data) - indices = self.iid()[client] if is_iid else self.non_iid()[client] + indices = self.iid(num_users)[client] if is_iid else self.non_iid(num_users)[client] indices = np.array(list(indices)).astype(int) self.data = [self.data[i] for i in indices] @@ -46,7 +48,7 @@ def __init__(self, annFile='./dataset_k_split.pkl', train=True, self.iid_to_cls = {} if not target_transform: - vocab_path = './src/datasets/vocabs/coco_vocab.pkl' + vocab_path = './src/custom_datasets/vocabs/coco_vocab.pkl' if isinstance(vocab_path, str): vocab = Vocabulary() vocab.load_from_pickle(vocab_path) @@ -56,14 +58,14 @@ def __init__(self, annFile='./dataset_k_split.pkl', train=True, else: self.target_transform = target_transform - def iid(self, root=os.environ['HOME']+'/data/mmdata/Flick30k/', num_users=20): + def iid(self, num_users, root=os.environ['HOME']+'/data/mmdata/Flick30k/'): """ Sample I.I.D. 
client data from MNIST dataset :param dataset: :param num_users: :return: dict of image index """ - pkl_path = root + 'client_iid.pkl' + pkl_path = root + f'client_iid_{num_users}.pkl' if os.path.exists(pkl_path): dict_users = pickle.load(open(pkl_path, 'rb')) else: @@ -76,8 +78,8 @@ def iid(self, root=os.environ['HOME']+'/data/mmdata/Flick30k/', num_users=20): pickle.dump(dict_users, open(pkl_path, 'wb')) return dict_users - def non_iid(self, root='./data_partition/', num_users=15): - pkl_path = root + 'client_noniid_flicker30k.pkl' + def non_iid(self, num_users, root='./data_partition/'): + pkl_path = root + f'client_noniid_{num_users}_flicker30k.pkl' if os.path.exists(pkl_path): dict_users = pickle.load(open(pkl_path, 'rb')) else: @@ -125,10 +127,12 @@ def __getitem__(self, index): return img, target, caption, index, int(index / 5), index def __len__(self): + if self.max_size != 0: + return min(self.max_size, len(self.data)) return len(self.data) if __name__ == '__main__': train = F30kCaptionsCap(train=False) - print(len(train)) - print(train.n_images) + print('flickr30k:len(train):', len(train)) + print('flickr30k:train.n_images', train.n_images) diff --git a/src/datasets/load_FL_datasets.py b/src/custom_datasets/load_FL_datasets.py similarity index 72% rename from src/datasets/load_FL_datasets.py rename to src/custom_datasets/load_FL_datasets.py index 9ca3962..5886077 100644 --- a/src/datasets/load_FL_datasets.py +++ b/src/custom_datasets/load_FL_datasets.py @@ -7,11 +7,17 @@ import torch import pickle -from src.datasets.dataset_L import Language, caption_collate_fn +from src.custom_datasets.dataset_L import Language, caption_collate_fn from src.utils.color_lib import RGBmean, RGBstdv - -def get_FL_trainloader(dataset, data_root, num_clients, partition, alpha, batch_size): +# dataset: str, name of the dataset +# data_root: str, path to the dataset +# num_clients: int, number of clients to partition the dataset +# partition: str, the partition method (homo:homogeneous/iid or hetero=heterogeneous/noniid) +# alpha: float, dirichlet parameter, if <1, results in a more skewed distribution. +# batch_size: int, DataLoader batch_size, for test set, it is doubled. +# max_size: int, maximum number of data samples to use per client, 0 means use all data. 
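+ # Example call (hypothetical values); returns a {client_index: DataLoader} map for the
+ # per-client training partitions plus a single shared test DataLoader:
+ #   loaders, test_loader = get_FL_trainloader('cifar100', os.environ['HOME'] + '/data',
+ #                                             num_clients=10, partition='hetero',
+ #                                             alpha=0.1, batch_size=32, max_size=0)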
+def get_FL_trainloader(dataset, data_root, num_clients, partition, alpha, batch_size, max_size): if dataset == 'cifar100': data_transforms = transforms.Compose([transforms.Resize(int(256 * 1.1)), transforms.RandomRotation(10), @@ -55,8 +61,9 @@ def get_FL_trainloader(dataset, data_root, num_clients, partition, alpha, batch_ num_samples = train_set.data.shape[0] net_dataidx_map = data_partitioner(dataset, num_samples, num_clients, partition=partition, check_dir="./data_partition/", alpha=alpha, - y_train=np.array(train_set.targets)) - print(f"Samples Num: {[len(i) for i in net_dataidx_map.values()]}") + y_train=np.array(train_set.targets), max_size=max_size) + print(f"get_FL_trainloader Samples Num: {[len(i) for i in net_dataidx_map.values()]}") + print(f"get_FL_trainloader Samples Keys: {net_dataidx_map.keys()}") net_dataset_map = {i: torch.utils.data.Subset(train_set, net_dataidx_map[i]) for i in net_dataidx_map.keys()} if dataset == "cifar100" or dataset == "cifar10": loader_map = { @@ -76,50 +83,56 @@ def get_FL_trainloader(dataset, data_root, num_clients, partition, alpha, batch_ return loader_map, test_loader -def data_partitioner(dataset, num_samples, num_nets, partition='homo', check_dir=None, alpha=0.5, y_train=None): - check_dir = check_dir + f'client_{dataset}' +def data_partitioner_from_file(file_path, max_size): + net_dataidx_map = pickle.load(open(file_path, 'rb')) + return data_partitioner_apply_max_size(net_dataidx_map, max_size) - if partition == "homo": - check_dir = check_dir + "_iid.pkl" - if os.path.isfile(check_dir): - net_dataidx_map = pickle.load(open(check_dir, 'rb')) - else: - idxs = np.random.permutation(num_samples) - batch_idxs = np.array_split(idxs, num_nets) - net_dataidx_map = {i: batch_idxs[i] for i in range(num_nets)} - pickle.dump(net_dataidx_map, open(check_dir, 'wb')) +def data_partitioner_apply_max_size(net_dataidx_map, max_size): + if max_size > 0: + for i in net_dataidx_map.keys(): + if len(net_dataidx_map[i]) > max_size: + net_dataidx_map[i] = net_dataidx_map[i][:max_size] + return net_dataidx_map - elif partition == "hetero": - check_dir = check_dir + "_noniid.pkl" - if os.path.isfile(check_dir): - net_dataidx_map = pickle.load(open(check_dir, 'rb')) - else: - min_size = 0 - K = max(y_train) + 1 # todo - net_dataidx_map = {} - print('Hetero partition') - while min_size < (10 if dataset == "cifar100" else (3000 if dataset == "AG_NEWS" else 500)): - idx_batch = [[] for _ in range(num_nets)] - # for each class in the dataset - for k in range(K): - idx_k = np.where(y_train == k)[0] - np.random.shuffle(idx_k) - proportions = np.random.dirichlet(np.repeat(alpha, num_nets)) - ## Balance - proportions = np.array( - [p * (len(idx_j) < num_samples / num_nets) for p, idx_j in zip(proportions, idx_batch)]) - proportions = proportions / proportions.sum() - proportions = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1] - idx_batch = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch, np.split(idx_k, proportions))] - min_size = min([len(idx_j) for idx_j in idx_batch]) - - for j in range(num_nets): - np.random.shuffle(idx_batch[j]) - net_dataidx_map[j] = idx_batch[j] - - pickle.dump(net_dataidx_map, open(check_dir, 'wb')) +def data_partitioner(dataset, num_samples, num_nets, partition='homo', check_dir=None, alpha=0.5, y_train=None, max_size=0): + check_dir = check_dir + f'client_{dataset}_{num_nets}_nets_{num_samples}_samples_{partition}_{alpha}.pkl' - return net_dataidx_map + if os.path.isfile(check_dir): + net_dataidx_map = pickle.load(open(check_dir, 
'rb')) + + elif partition == "homo": + idxs = np.random.permutation(num_samples) + batch_idxs = np.array_split(idxs, num_nets) + net_dataidx_map = {i: batch_idxs[i] for i in range(num_nets)} + pickle.dump(net_dataidx_map, open(check_dir, 'wb')) + + elif partition == "hetero": + min_size = 0 + K = max(y_train) + 1 # Calculate the number of classes in the dataset + net_dataidx_map = {} + print('Hetero partition') + while min_size < (10 if dataset == "cifar100" else (3000 if dataset == "AG_NEWS" else 500)): + idx_batch = [[] for _ in range(num_nets)] + # for each class in the dataset + for k in range(K): + idx_k = np.where(y_train == k)[0] + np.random.shuffle(idx_k) + proportions = np.random.dirichlet(np.repeat(alpha, num_nets)) + ## Balance + proportions = np.array( + [p * (len(idx_j) < num_samples / num_nets) for p, idx_j in zip(proportions, idx_batch)]) + proportions = proportions / proportions.sum() + proportions = (np.cumsum(proportions) * len(idx_k)).astype(int)[:-1] + idx_batch = [idx_j + idx.tolist() for idx_j, idx in zip(idx_batch, np.split(idx_k, proportions))] + min_size = min([len(idx_j) for idx_j in idx_batch]) + + for j in range(num_nets): + np.random.shuffle(idx_batch[j]) + net_dataidx_map[j] = idx_batch[j] + + pickle.dump(net_dataidx_map, open(check_dir, 'wb')) + + return data_partitioner_apply_max_size(net_dataidx_map, max_size) def get_dataloader(args): diff --git a/src/datasets/utils.py b/src/custom_datasets/utils.py similarity index 92% rename from src/datasets/utils.py rename to src/custom_datasets/utils.py index d5fb007..8ee8d9e 100644 --- a/src/datasets/utils.py +++ b/src/custom_datasets/utils.py @@ -5,7 +5,7 @@ import torch from torch.utils.data import DataLoader -from src.datasets.vocab import Vocabulary +from src.custom_datasets.vocab import Vocabulary # from datasets.coco_transforms import imagenet_transform, caption_transform def image_to_caption_collate_fn(data): @@ -40,7 +40,7 @@ def image_to_caption_collate_fn(data): return images, targets, cap_lengths, ann_ids, image_ids, index -def load_vocab(vocab_path='./src/datasets/vocabs/coco_vocab.pkl'): +def load_vocab(vocab_path='./src/custom_datasets/vocabs/coco_vocab.pkl'): if isinstance(vocab_path, str): vocab = Vocabulary() vocab.load_from_pickle(vocab_path) diff --git a/src/datasets/vocab.py b/src/custom_datasets/vocab.py similarity index 100% rename from src/datasets/vocab.py rename to src/custom_datasets/vocab.py diff --git a/src/datasets/vocabs/coco_vocab.pkl b/src/custom_datasets/vocabs/coco_vocab.pkl similarity index 100% rename from src/datasets/vocabs/coco_vocab.pkl rename to src/custom_datasets/vocabs/coco_vocab.pkl diff --git a/src/datasets/vocabs/cub_vocab.pkl b/src/custom_datasets/vocabs/cub_vocab.pkl similarity index 100% rename from src/datasets/vocabs/cub_vocab.pkl rename to src/custom_datasets/vocabs/cub_vocab.pkl diff --git a/src/datasets/vocabs/make_vocab.py b/src/custom_datasets/vocabs/make_vocab.py similarity index 100% rename from src/datasets/vocabs/make_vocab.py rename to src/custom_datasets/vocabs/make_vocab.py diff --git a/src/datasets/__init__.py b/src/datasets/__init__.py deleted file mode 100644 index 9cb5f5b..0000000 --- a/src/datasets/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from datasets._dataloader import prepare_coco_dataloaders, prepare_cub_dataloaders -from datasets.vocab import Vocabulary - - -__all__ = [ - 'Vocabulary', - 'prepare_coco_dataloaders', - 'prepare_cub_dataloaders', -] diff --git a/src/datasets/__pycache__/__init__.cpython-38.pyc 
b/src/datasets/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index b893669..0000000 Binary files a/src/datasets/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/src/datasets/__pycache__/_dataloader.cpython-38.pyc b/src/datasets/__pycache__/_dataloader.cpython-38.pyc deleted file mode 100644 index ec53ae0..0000000 Binary files a/src/datasets/__pycache__/_dataloader.cpython-38.pyc and /dev/null differ diff --git a/src/datasets/__pycache__/_transforms.cpython-38.pyc b/src/datasets/__pycache__/_transforms.cpython-38.pyc deleted file mode 100644 index 17f858c..0000000 Binary files a/src/datasets/__pycache__/_transforms.cpython-38.pyc and /dev/null differ diff --git a/src/datasets/__pycache__/cifar.cpython-38.pyc b/src/datasets/__pycache__/cifar.cpython-38.pyc deleted file mode 100644 index 086348b..0000000 Binary files a/src/datasets/__pycache__/cifar.cpython-38.pyc and /dev/null differ diff --git a/src/datasets/__pycache__/coco.cpython-38.pyc b/src/datasets/__pycache__/coco.cpython-38.pyc deleted file mode 100644 index 8ad5c10..0000000 Binary files a/src/datasets/__pycache__/coco.cpython-38.pyc and /dev/null differ diff --git a/src/datasets/__pycache__/coco_transforms.cpython-38.pyc b/src/datasets/__pycache__/coco_transforms.cpython-38.pyc deleted file mode 100644 index f86e713..0000000 Binary files a/src/datasets/__pycache__/coco_transforms.cpython-38.pyc and /dev/null differ diff --git a/src/datasets/__pycache__/cub.cpython-38.pyc b/src/datasets/__pycache__/cub.cpython-38.pyc deleted file mode 100644 index d2eef9f..0000000 Binary files a/src/datasets/__pycache__/cub.cpython-38.pyc and /dev/null differ diff --git a/src/datasets/__pycache__/dataset_L.cpython-38.pyc b/src/datasets/__pycache__/dataset_L.cpython-38.pyc deleted file mode 100644 index c202e6a..0000000 Binary files a/src/datasets/__pycache__/dataset_L.cpython-38.pyc and /dev/null differ diff --git a/src/datasets/__pycache__/flickr30k.cpython-38.pyc b/src/datasets/__pycache__/flickr30k.cpython-38.pyc deleted file mode 100644 index c3b6819..0000000 Binary files a/src/datasets/__pycache__/flickr30k.cpython-38.pyc and /dev/null differ diff --git a/src/datasets/__pycache__/load_FL_datasets.cpython-38.pyc b/src/datasets/__pycache__/load_FL_datasets.cpython-38.pyc deleted file mode 100644 index a5a1486..0000000 Binary files a/src/datasets/__pycache__/load_FL_datasets.cpython-38.pyc and /dev/null differ diff --git a/src/datasets/__pycache__/vocab.cpython-38.pyc b/src/datasets/__pycache__/vocab.cpython-38.pyc deleted file mode 100644 index ebc893d..0000000 Binary files a/src/datasets/__pycache__/vocab.cpython-38.pyc and /dev/null differ diff --git a/src/datasets/coco/datasets/__init__.py b/src/datasets/coco/datasets/__init__.py deleted file mode 100644 index c84198b..0000000 --- a/src/datasets/coco/datasets/__init__.py +++ /dev/null @@ -1,9 +0,0 @@ -from datasets._dataloader import prepare_coco_dataloaders, prepare_cub_dataloaders, prepare_f30k_dataloaders -from datasets.vocab import Vocabulary - -__all__ = [ - 'Vocabulary', - 'prepare_coco_dataloaders', - 'prepare_cub_dataloaders', - 'prepare_f30k_dataloaders' -] diff --git a/src/f30k.yaml b/src/f30k.yaml index c7e88eb..6527aa3 100644 --- a/src/f30k.yaml +++ b/src/f30k.yaml @@ -49,11 +49,11 @@ train: model_save_path: model_ best_model_save_path: model_best.pth pretrain_epochs: 0 - finetune_epochs: 30 + finetune_epochs: 30 # not used finetune_lr_decay: 0.1 log_step: 1000 grad_clip: 2 - val_epochs: 10 - pretrain_val_epochs: 10 + val_epochs: 10 # not used + 
pretrain_val_epochs: 10 # not used use_fp16: True output_file: model.log diff --git a/src/federation/api.py b/src/federation/api.py new file mode 100644 index 0000000..ebb96b2 --- /dev/null +++ b/src/federation/api.py @@ -0,0 +1,273 @@ +from typing import Dict, Optional +from enum import Enum +from datetime import datetime +import os +import time +import copy +import pickle +import torch +import hashlib +import requests + +import sys +sys.path.append("./") +sys.path.append("../") +sys.path.append("../../") +sys.path.append("../../../") +from src.utils.load_datasets import prepare_coco_dataloaders + +from context import Context + +url_prefix = '/cream_api' + +class RoundState(Enum): + BUSY = 1 # The server is busy processing the current round and calculating the next global model. + COLLECT = 2 # The server is collecting the client updates + +class ClientState: + def __init__(self, + name: str, + img_model: bool = False, + txt_model: bool = False, + local_rounds: int = 0, + img_model_hash: str = "", + txt_model_hash: str = "", + ): + self.name = name + self.img_model = img_model + self.txt_model = txt_model + self.local_rounds = local_rounds + self.img_model_hash = img_model_hash + self.txt_model_hash = txt_model_hash + + def to_dict(self): + data = { + 'name': self.name, + 'local_rounds': self.local_rounds, + } + if self.img_model: + data['img_model'] = self.img_model + data['img_model_hash']= self.img_model_hash + if self.txt_model: + data['txt_model'] = self.txt_model + data['txt_model_hash']= self.txt_model_hash + return data + + @classmethod + def from_dict(cls, data): + required_keys = ['name', 'local_rounds'] + for key in required_keys: + if key not in data: + raise ValueError(f"Key '{key}' is missing from ClientState") + return cls( + name=data['name'], + img_model=data.get('img_model', False), + txt_model=data.get('txt_model', False), + local_rounds=data['local_rounds'], + img_model_hash=data.get('img_model_hash', ''), + txt_model_hash=data.get('txt_model_hash', ''), + ) + +class ServerState: + def __init__(self, + round_number: int = 0, # defaults are for the starting state of a new server + round_state: RoundState = RoundState.BUSY, + round_started_at: datetime = datetime.now(), # the timestamp of the start of the round + clients_reported: Optional[Dict[str, ClientState]] = None, # map of names to ClientState that have reported their updates + feature_hash: str = ""): + self.round_number = round_number + self.round_state = round_state + self.round_started_at = round_started_at + self.clients_reported = clients_reported + if self.clients_reported is None: + self.clients_reported = {} + self.feature_hash = feature_hash # the hash of the global model's features, the actual data could be distributed by an external service. + + def add_client(self, client): + """ + Used by server to add a client to the list of clients that have reported their updates. + """ + self.clients_reported[client.name] = client + + def advance_round(self): + """ + Used by the server to advance the round to the next round. When the condition for finishing the round is met. + """ + self.round_number += 1 + self.round_state = RoundState.BUSY + self.round_started_at = datetime.now() + + def update_feature_hash(self, feature_hash): + """ + Used by the server to update the feature hash of the global model and set the round state to COLLECT. 
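+ While the state is COLLECT, clients fetch the feature referenced by feature_hash from the
+ shared feature store (see get_global_feature) and submit their local representations via
+ add_local_repr; advance_round later returns the server to BUSY for the next round.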
+ """ + self.round_state = RoundState.COLLECT + self.clients_reported = {} + self.feature_hash = feature_hash + + def to_dict(self): + return { + 'round_number': self.round_number, + 'round_state': self.round_state.value, + 'round_started_at': self.round_started_at.isoformat(), + 'clients_reported': {k: v.to_dict() for k, v in self.clients_reported.items()}, + 'feature_hash': self.feature_hash, + } + + @classmethod + def from_dict(cls, data): + return cls( + round_number=data['round_number'], + round_state=RoundState(data['round_state']), + round_started_at= datetime.fromisoformat(data['round_started_at']), + clients_reported={k: ClientState.from_dict(v) for k, v in data['clients_reported'].items()}, + feature_hash=data['feature_hash'], + ) + +def feature_hash(data): + return hashlib.sha3_256(data).hexdigest() + +def save(obj, path): + os.makedirs(path, exist_ok=True) + data = pickle.dumps(obj) + hash = feature_hash(data) + fn = f"{path}/{hash}.pkl" + with open(fn, 'wb') as f: + f.write(data) + return fn, hash + +def load(cls, path, hash): + fn = f"{path}/{hash}.pkl" + with open(fn, 'rb') as f: + obj = pickle.load(f) + if not isinstance(obj, cls): + raise ValueError(f"Object loaded from {fn} is not an instance of {cls}") + return obj + + +class GlobalFeature: + def __init__(self, img: torch.Tensor, txt: torch.Tensor, distill_index: list): + self.img: torch.Tensor = img + self.txt: torch.Tensor = txt + self.distill_index = distill_index + self.hash = None # calculated when saved to or loaded from disk + + def save(self, path): + fn, hash = save(self, path) + self.hash = hash + return fn, hash + + @classmethod + def load(cls, path, hash): + obj = load(cls, path, hash) + obj.hash = hash + return obj + +def get_api_url(context:Context): + return context.fed_config.server["api_url"] + +def get_global_dataloader(context:Context): + dataset_root = os.environ['HOME'] + '/data/mmdata/MSCOCO/2014' + vocab_path = './src/custom_datasets/vocabs/coco_vocab.pkl' + return prepare_coco_dataloaders(context.config.dataloader, dataset_root, context.args.pub_data_num, context.args.max_size, vocab_path) + +def status_sleep(context, msg): + context.logger.log(f"{msg}, sleeping for 10 seconds.") + time.sleep(10) + +def error_sleep(context, error): + context.logger.log(f"Error: {error}, sleeping for 60 seconds.") + time.sleep(60) + +# start shared api section +def get_server_state(context:Context, expected_state: Optional[RoundState] = None): + """ + Get the current state of the server. + If expected_state is not None, this function will wait until the server is in the expected state. + """ + while True: + try: + url = get_api_url(context) + context.logger.log(f"Getting server state from {url}") + resp = requests.get(url) + state = ServerState.from_dict(resp.json()['current_state']) + if expected_state is not None and state.round_state != expected_state: + status_sleep(context, f"Server state is not {expected_state}") + time.sleep(10) + continue + return state + except Exception as e: + error_sleep(context, e) + +# start client api section + + +def get_global_feature(context:Context, state:ServerState): + """ + Get the global feature from a distributed storage service. + + Only mounted files are supported currently. + """ + return GlobalFeature.load(context.fed_config.feature_store, state.feature_hash) + +def add_local_repr(context:Context, expected_server_state:ServerState, img, txt, local_rounds:int): + """ + Submit the local representations to the server. 
+ + local_rounds is the number of local rounds that the client has trained the model for, this is only used for reporting. + + Returns True if the submission was successful and the client should restart a new round. + Returns False if the submission was not successful but the client should restart a new round. + """ + context.logger.log(f"Saving local representations.") + client_state = ClientState(name=context.args.client_name, img_model=context.has_img_model, txt_model=context.has_txt_model, local_rounds=local_rounds) + if context.has_img_model: + _, client_state.img_model_hash = save(img, context.fed_config.feature_store) + if context.has_txt_model: + _, client_state.txt_model_hash = save(txt, context.fed_config.feature_store) + + url = get_api_url(context)+ f'/add_client?round_number={expected_server_state.round_number}&feature_hash={expected_server_state.feature_hash}' + context.logger.log(f"Submitting local representations to server.") + while True: + try: + resp = requests.put(url, json=client_state.to_dict()) + if resp.status_code == 200: + context.logger.log(f"Local representations submitted to server.") + return True + context.logger.log(f"Can not add local representations to server. Status code: {resp.status_code} body: {resp.text}") + return False + except Exception as e: + error_sleep(context, e) +# end client api section + +# start server api section +def submit_global_feature(context:Context, state:ServerState, global_feature:GlobalFeature): + context.logger.log(f"Saving global feature to file.") + _, hash = global_feature.save(context.fed_config.feature_store) + url = get_api_url(context) + f'/set_global_feature?round_number={state.round_number}&old_feature_hash={state.feature_hash}&new_feature_hash={hash}' + context.logger.log(f"Submitting global feature to server.") + while True: + try: + resp = requests.put(url) + if resp.status_code == 200: + context.logger.log(f"Global feature submitted to server.") + return True + context.logger.log(f"Can not submit global feature to server. 
Status code: {resp.status_code} body: {resp.text}") + error_sleep(context, resp.status_code) + return False + except Exception as e: + error_sleep(context, e) + +def get_clients_repr(context:Context, clients_reported: Dict[str, ClientState]): + img_vec = [] + txt_vec = [] + for name, client in clients_reported.items(): + if client.img_model: + img = load(torch.Tensor,context.fed_config.feature_store, client.img_model_hash) + img_vec.append(img) + if client.txt_model: + txt = load(torch.Tensor,context.fed_config.feature_store, client.txt_model_hash) + txt_vec.append(txt) + return img_vec, txt_vec + +# end server api section \ No newline at end of file diff --git a/src/federation/client.py b/src/federation/client.py new file mode 100644 index 0000000..501aa91 --- /dev/null +++ b/src/federation/client.py @@ -0,0 +1,161 @@ +import os +from copy import deepcopy +import sys + +import munch + +sys.path.append("./") +sys.path.append("../") +sys.path.append("../../") +sys.path.append("../../../") + +from src.custom_datasets.load_FL_datasets import get_FL_trainloader +from src.algorithms.ClientTrainer import ClientTrainer +from src.algorithms.MMClientTrainer import MMClientTrainer +from src.utils.color_lib import RGBmean, RGBstdv + +from src.utils.config import parse_config + +import api + +try: + from apex import amp +except ImportError: + print('failed to import apex') + +class Client: + def __init__(self, context): + # all configs + self.context = context + self.logger = context.logger + self.args = context.args + self.wandb = context.wandb + + # validation for clients + if context.args.client_name is None: + raise ValueError("The client_name argument is required for clients") + + my_client = None + for client in context.fed_config.clients: + if client["name"] == context.args.client_name: + my_client = munch.Munch(client) + if my_client is None: + raise ValueError(f"Client name {context.args.client_name} not found in configuration file {context.args.fed_config}") + + # setup client + self.client_config = my_client + self.name = context.args.client_name + self.device = context.device + + if self.client_config.data_type == 'txt': + self.context.has_img_model = False + self.context.has_txt_model = True + elif self.client_config.data_type == 'img': + self.context.has_img_model = True + self.context.has_txt_model = False + elif self.client_config.data_type == 'mm': + self.context.has_img_model = True + self.context.has_txt_model = True + else: + raise ValueError(f'client_config.data_type={self.client_config.data_type} is not implemented by federation.client.Client') + + # will setup in setup_data_loader + self.trainer = None + + def setup_data_loader(self): + + self.logger.log('setup global dataloader') + + global_dataloader, _ = api.get_global_dataloader(self.context) + self.global_eval_dataloader = global_dataloader['train_subset_eval' + f'_{self.args.pub_data_num}'] + + self.logger.log('start creating model and partition datasets') + + os.makedirs(os.environ['HOME'] + f'/data/yClient', exist_ok=True) + + args = self.context.args + client_config = self.client_config + + data_type = client_config.data_type + # data_partition_file_name = client_config.data_partition # todo + data_partition_index = client_config.data_partition_index + + alpha = args.alpha + batch_size = args.batch_size + max_size = args.max_size + + if data_type == 'txt': + dataset = 'AG_NEWS' + # this loads all train loaders for all clients, from old code + train_loaders, test_set = get_FL_trainloader(dataset, os.environ['HOME'] + 
"/data", + args.num_txt_clients, "hetero", alpha, batch_size, max_size) + dst = os.environ['HOME'] + f'/data/yClient/{dataset}-{self.name}' # unused? + self.trainer = ClientTrainer(args, dataset, dst, None, None, None, self.logger, + global_test_set=test_set, inter_distance=4, client_id=data_partition_index, wandb=self.wandb) + self.trainer.train_loader = train_loaders[data_partition_index] + elif data_type == 'img': + dataset = 'cifar100' + # this loads all train loaders for all clients, from old code + train_loaders, test_set = get_FL_trainloader(dataset, os.environ['HOME'] + "/data/cifar100", + args.num_img_clients, "hetero", alpha, batch_size, max_size) + dataset = 'Cifar100' + dst = os.environ['HOME'] + f'/data/yClient/{dataset}-{self.name}' # unused? + self.trainer = ClientTrainer(args, dataset, dst, RGBmean['Cifar100'], RGBstdv['Cifar100'], None, self.logger, + global_test_set=test_set, inter_distance=4, client_id=data_partition_index, wandb=self.wandb) + self.trainer.train_loader = train_loaders[data_partition_index] + elif data_type == 'mm': + config = parse_config("./src/f30k.yaml", strict_cast=False) + config.model.cache_dir = config.model.cache_dir + '-' + config.train.server_dataset + config.train.output_file = os.path.join(config.model.cache_dir, config.train.output_file) + config.train.best_model_save_path = os.path.join(config.model.cache_dir, config.train.best_model_save_path) + config.train.model_save_path = os.path.join(config.model.cache_dir, config.train.model_save_path) + config.model.embed_dim = self.args.feature_dim + config.model.not_bert = True + + self.trainer = MMClientTrainer(args, config, self.logger, client=data_partition_index, num_users=args.num_mm_clients, dset_name="flicker30k", + device='cuda', + vocab_path='./src/custom_datasets/vocabs/coco_vocab.pkl', + mlp_local=self.args.mlp_local) + else: + raise ValueError(f'client_config.data_type={data_type} is not implemented by federation.client.get_data_loader()') + + def train(self, gf: api.GlobalFeature): + trainer = self.trainer + self.logger.log(f"Training Client {self.name}") + trainer.run(gf.img, gf.txt, gf.distill_index, self.global_eval_dataloader) + self.logger.log(f"client {self.name} Generate Local Representations") + vec, di = trainer.generate_logits(self.global_eval_dataloader) + if di != gf.distill_index: + raise ValueError(f"distill_index mismatch: {di} != {gf.distill_index}") + img = vec.get('img') + txt = vec.get('txt') + return img, txt + + def submit(self, local_repr: api.ClientState): + self.logger.log(f"Submitting local representations to server.") + api.submit_local_repr(self.context, local_repr) + + def run(self): + current_path = os.path.dirname(os.path.dirname(__file__)) + with open(os.path.join(current_path, 'accuracy.txt'), 'w') as f: + f.write('') + while True: + self.logger.log(f"Client {self.name} is starting a new round.") + server_state = api.get_server_state(self.context, expected_state=api.RoundState.COLLECT) + self.logger.log(f"Server state: {server_state.round_state}, round number: {server_state.round_number}, global feature hash: {server_state.feature_hash}") + global_feature = api.get_global_feature(self.context, server_state) + self.logger.log(f"Global feature retrieved") + img, txt = self.train(global_feature) + del global_feature + api.add_local_repr(self.context, server_state, img, txt, -1) # todo: set local rounds + del img, txt + +if __name__ == "__main__": + from src.federation.context import new_client_context + context = new_client_context() + client = 
Client(context) + client.setup_data_loader() + client.run() + + + \ No newline at end of file diff --git a/src/federation/config.py b/src/federation/config.py new file mode 100644 index 0000000..8667f87 --- /dev/null +++ b/src/federation/config.py @@ -0,0 +1,24 @@ +import os +import munch +import yaml +from yaml.error import YAMLError +import copy + +def read_config_file(fn: str): + with open(fn, 'r') as fin: + return yaml.safe_load(fin) + +def apply_defaults(dic, defaults): + out = copy.deepcopy(defaults) + for k, v in dic.items(): + if isinstance(v, dict) and k in defaults: + out[k] = apply_defaults(v, defaults[k]) + else: + out.update({k: v}) + return out + +def load_config(config_file, defaults = None): + config = read_config_file(config_file) + if defaults is not None: + config = apply_defaults(config, defaults) + return munch.Munch(config) \ No newline at end of file diff --git a/src/federation/context.py b/src/federation/context.py new file mode 100644 index 0000000..86fe712 --- /dev/null +++ b/src/federation/context.py @@ -0,0 +1,55 @@ +import torch +import sys +sys.path.append("./") +sys.path.append("../") +sys.path.append("../../") +sys.path.append("../../../") +from src.common import prepare_args, get_config +from src.utils.logger import PythonLogger +import config + + +class Context(): + """ + Represents the context for executing a script. For sharing configuration logic between federation scripts. + + Args: + description (str): A description of the script. + script (str): The name/type of script to be executed. + """ + def __init__(self, description: str, script: str): + # from main + args, wandb = prepare_args(description, script) + self.args = args + self.wandb = wandb + + # from MMFL + self.config = get_config(args) + self.logger = PythonLogger(output_file=self.config.train.output_file) + + # automatically set device + if torch.cuda.is_available(): + self.device = torch.device("cuda:%d" % args.device) + else: + self.device = torch.device("cpu") + + + # federation specific config file + self.fed_config = config.load_config(args.fed_config) + + # to be configured by client + self.has_img_model = False + self.has_txt_model = False + + +def new_server_context(): + return Context(description="CreamFL Federated Learning (http server)", script="server") + +def new_client_context(): + return Context(description="CreamFL Federated Learning (client)", script="client") + +def new_global_context(): + c = Context(description="CreamFL Federated Learning (global compute)", script="global") + c.has_img_model = True + c.has_txt_model = True + return c diff --git a/src/federation/global.py b/src/federation/global.py new file mode 100644 index 0000000..38187a4 --- /dev/null +++ b/src/federation/global.py @@ -0,0 +1,319 @@ +# moved from MMFL.py with client logic removed + +import gc +import random + +import operator +import os +from copy import deepcopy +import sys + +import torch +import torch.nn as nn +from tqdm import tqdm + +sys.path.append("./") +sys.path.append("../") +sys.path.append("../../") +sys.path.append("../../../") + +from src.utils.color_lib import RGBmean, RGBstdv + +from src.algorithms.eval_coco import COCOEvaluator +from src.algorithms.retrieval_trainer import TrainerEngine +from src.utils.logger import PythonLogger + +import api + +try: + from apex import amp +except ImportError: + print('failed to import apex') + +current_path = os.path.dirname(os.path.dirname(__file__)) + +class Global: + def __init__(self, context): + # all configs + self.context = context + self.args = 
context.args + self.config = context.config + self.device = context.device + self.logger = context.logger + self.wandb = context.wandb + self.engine = None # set in load_dataset + self.best_score = 0 + + # coco global dataloaders + self.dataloaders_global = None + # universal test dataloader + self.test_loader = None + + self.img_vec, self.txt_vec = None, None + self.global_img_feature = None + self.global_txt_feature = None + self.distill_index = None + + def load_dataset(self): + args = self.context.args + self.dataloaders_global, self.vocab = api.get_global_dataloader(self.context) + + self.engine = TrainerEngine() + self.engine.set_logger(self.logger) + + self.config.optimizer.learning_rate = args.server_lr + + self._dataloaders = self.dataloaders_global.copy() + self.evaluator = COCOEvaluator(eval_method='matmul', + verbose=True, + eval_device='cuda', + n_crossfolds=5) + self.engine.create(self.config, self.vocab.word2idx, self.evaluator, args.mlp_local) + + self.train_eval_dataloader = self._dataloaders.pop( + 'train_subset_eval' + f'_{args.pub_data_num}') if self._dataloaders is not None else None + + self.engine.model_to_device() + torch.backends.cudnn.enabled = True + if self.config.train.get('use_fp16'): + self.engine.logger.log('Train with half precision') + self.engine.to_half() + + def train(self, server_state: api.ServerState): + + self.logger.log(f"Global training round {server_state.round_number}!") + self.engine.train( + tr_loader=self._dataloaders['train_subset' + f'_{self.args.pub_data_num}']) # global train + + # calculate global representation + if self.args.agg_method == "con_w" or self.args.contrast_local_intra or self.args.contrast_local_inter: + img_feature, txt_feature = [], [] + distill_index = [] + for idx, (images, captions, captions_word, caption_lens, _, _, index) in tqdm( + enumerate(self.dataloaders_global['train_subset_eval' + f'_{self.args.pub_data_num}']), + total=len(self.dataloaders_global['train_subset_eval' + f'_{self.args.pub_data_num}'])): + with torch.no_grad(): + images = images.to(self.engine.device) # [bs, 3, 224, 224] + captions = captions.to(self.engine.device) # [bs, seq_len] + caption_lens = caption_lens.to(self.engine.device) + + output = self.engine.model(images, captions, captions_word, caption_lens) + out_img = output['image_features'] + out_txt = output['caption_features'] + + out_img = out_img.cpu().detach() + out_txt = out_txt.cpu().detach() + + img_feature.append(out_img) + txt_feature.append(out_txt) + distill_index.extend(index) + + self.global_img_feature = torch.concat(img_feature, dim=0) + self.global_txt_feature = torch.concat(txt_feature, dim=0) + print(self.global_txt_feature.shape, self.global_img_feature.shape) + self.distill_index = distill_index + del img_feature, txt_feature + gc.collect() + + # submit global representation + global_feature = api.GlobalFeature(self.global_img_feature, self.global_txt_feature, self.distill_index) + if not api.submit_global_feature(self.context, server_state, global_feature): + self.logger.log("global train failed: failed to submit global feature") + return server_state, False + + # waiting and retrieving local representations + self.logger.log("Waiting for client representations") + new_server_state = api.get_server_state(self.context, expected_state=api.RoundState.BUSY) + if new_server_state.round_number != server_state.round_number + 1: + self.logger.log(f"global train failed: round number mismatch: {new_server_state.round_number} != {server_state.round_number}") + return 
new_server_state, False + if new_server_state.feature_hash != global_feature.hash: + self.logger.log(f"global train failed: feature hash mismatch: {new_server_state.feature_hash} != {global_feature.hash}") + return new_server_state, False + img_vec, txt_vec = api.get_clients_repr(self.context, new_server_state.clients_reported) + self.logger.log(f"loaded client representations: img_vec x{len(img_vec)}, txt_vec x{len(txt_vec)}") + + # global distillation + if not self.args.disable_distill: + self.distill(new_server_state.round_number, img_vec, txt_vec, self.distill_index) + + def get_lr(optimizer): + for param_group in optimizer.param_groups: + return param_group['lr'] + + round_n = server_state.round_number + # record after each epoch training + metadata = self.engine.metadata.copy() + metadata['cur_epoch'] = round_n + 1 + metadata['lr'] = get_lr(self.engine.optimizer) + + test_scores = self.engine.evaluate({'test': self._dataloaders['test']}) + self.engine.report_scores(step=round_n + 1, + scores=test_scores, + metadata=metadata, + prefix=self.engine.eval_prefix) + rsum = test_scores['test']['n_fold']['i2t']['recall_1'] + test_scores['test']['n_fold']['t2i']['recall_1'] + \ + test_scores['test']['i2t']['recall_1'] + test_scores['test']['t2i']['recall_1'] + rsum5 = test_scores['test']['n_fold']['i2t']['recall_5'] + test_scores['test']['n_fold']['t2i']['recall_5'] + \ + test_scores['test']['i2t']['recall_5'] + test_scores['test']['t2i']['recall_5'] + rsum10 = test_scores['test']['n_fold']['i2t']['recall_10'] + test_scores['test']['n_fold']['t2i']['recall_10'] + \ + test_scores['test']['i2t']['recall_10'] + test_scores['test']['t2i']['recall_10'] + + self.wandb.log({"Server rsum_r1": rsum}, step=round_n) + self.wandb.log({"Server n_fold_i2t_r1": test_scores['test']['n_fold']['i2t']['recall_1']}, step=round_n) + self.wandb.log({"Server n_fold_t2i_r1": test_scores['test']['n_fold']['t2i']['recall_1']}, step=round_n) + self.wandb.log({"Server i2t_r1": test_scores['test']['i2t']['recall_1']}, step=round_n) + self.wandb.log({"Server t2i_r1": test_scores['test']['t2i']['recall_1']}, step=round_n) + + with open(os.path.join(current_path, 'recall.txt'), 'a') as f: + f.write(f'{round_n}:{rsum},{rsum5},{rsum10}\n') + + if self.best_score < rsum: + best_score = rsum + metadata['best_score'] = best_score + metadata['best_epoch'] = round_n + 1 + self.best_metadata, self.best_scores = metadata, test_scores + + self.engine.save_models(self.args.name + '-best_model.pt') + # torch.save({'net': self.engine.model.state_dict()}, self.args.name + '-best_model.pt') + + if round_n == self.args.comm_rounds - 1: + self.engine.save_models(self.args.name + '-last_model.pt') + # torch.save({'net': self.engine.model.state_dict()}, self.args.name + '-last_model.pt') + + self.engine.lr_scheduler.step() + + del img_vec, txt_vec + gc.collect() + return new_server_state, True + + def distill(self, round_n, img_vec, txt_vec, distill_index): + + has_img_vec = len(img_vec) > 0 + has_txt_vec = len(txt_vec) > 0 + + self.engine.model.train() + + if self.config.model.use_img_client or self.config.model.use_txt_client or self.config.model.use_mm_client: + client_loss_cri = nn.MSELoss() + + def aggregation(i_vec=img_vec, t_vec=txt_vec): + if self.args.agg_method == "con_w": + if not i_vec: + self.logger.log("distill.aggregation i_vec is empty") + else: + contrastive_w = [] + for vec in i_vec: # vec: [50000, n_feature], global_txt_feature: [50000, n_feature] + logits = torch.matmul(vec, self.global_txt_feature.T) # [50000, 50000] + 
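+ # Note on the next few lines: log_prob is a row-wise log-softmax of logits over
+ # the public samples, and its diagonal entry i is the log-probability that this
+ # client's i-th image feature matches the i-th global text feature. The stacked
+ # diagonals are softmaxed across clients below, so clients whose features agree
+ # better with the global model receive larger weights in the aggregated representation.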
exp_logits = torch.exp(logits) + log_prob = logits - torch.log(torch.sum(exp_logits, dim=1, keepdim=True)) + contrastive_w.append(torch.diagonal(log_prob).reshape(1, -1)) + if not contrastive_w: + self.logger.log("distill.aggregation No tensors were added to contrastive_w for images") + else: + contrastive_w = torch.softmax(torch.cat(contrastive_w, dim=0), dim=0) + for i in range(len(i_vec)): + i_vec[i] = (i_vec[i] * contrastive_w[i].reshape(-1, 1)).unsqueeze(0) + i_vec = torch.sum(torch.cat(i_vec, dim=0), dim=0) # aggregated image vectors + + if not t_vec: + self.logger.log("distill.aggregation t_vec is empty") + else: + contrastive_w = [] + for vec in t_vec: # vec: [50000, n_feature], global_txt_feature: [50000, n_feature] + logits = torch.matmul(vec, self.global_img_feature.T) # [50000, 50000] + exp_logits = torch.exp(logits) + log_prob = logits - torch.log(torch.sum(exp_logits, dim=1, keepdim=True)) + contrastive_w.append(torch.diagonal(log_prob).reshape(1, -1)) + if not contrastive_w: + self.logger.log("distill.aggregation No tensors were added to contrastive_w for texts") + else: + contrastive_w = torch.softmax(torch.cat(contrastive_w, dim=0), dim=0) + for i in range(len(t_vec)): + t_vec[i] = (t_vec[i] * contrastive_w[i].reshape(-1, 1)).unsqueeze(0) + t_vec = torch.sum(torch.cat(t_vec, dim=0), dim=0) # aggregated text vectors + else: + raise NotImplementedError + + return i_vec, t_vec + + # aggregation + img_vec, txt_vec = aggregation() + + self.img_vec = img_vec + self.txt_vec = txt_vec + + distill_dict = {b: a for a, b in enumerate(distill_index)} # index in coco to index to list 'distill_index' + # distill + self.logger.log("start distilling") + for idx, (images, captions, captions_word, caption_lens, _, _, index) in tqdm( + enumerate(self.dataloaders_global['train_subset' + f'_{self.args.pub_data_num}'])): + images = images.to(self.engine.device) # [bs, 3, 224, 224] + captions = captions.to(self.engine.device) # [bs, seq_len] + caption_lens = caption_lens.to(self.engine.device) + + output = self.engine.model(images, captions, captions_word, caption_lens) + loss = 0 + + def code_sim(output, target, config): + output = output.sum(axis=1) if len(output.shape) == 3 else output + target = target.type_as(output) + + return client_loss_cri(output, target.type_as(output)) + + if has_img_vec: + out_img = output['image_features'] + d_idx = operator.itemgetter(*index)(distill_dict) # idx of the current batch + target_img = self.img_vec[d_idx, :].type_as(out_img) + loss += self.args.kd_weight * code_sim(out_img, target_img, self.config) + if has_txt_vec: + out_txt = output['caption_features'] + d_idx = operator.itemgetter(*index)(distill_dict) # idx of the current batch + target_txt = self.txt_vec[d_idx, :].type_as(out_txt) + loss += self.args.kd_weight * code_sim(out_txt, target_txt, self.config) + if has_img_vec and has_txt_vec: + out_img = output['image_features'] + d_idx = operator.itemgetter(*index)(distill_dict) # idx of the current batch + target_img = self.img_vec[d_idx, :].type_as(out_img) + out_txt = output['caption_features'] + target_txt = self.txt_vec[d_idx, :].type_as(out_txt) + loss += self.args.kd_weight * code_sim(out_img, target_img, self.config) + loss += self.args.kd_weight * code_sim(out_txt, target_txt, self.config) + + self.engine.optimizer.zero_grad() + + with open(os.path.join(current_path, 'loss.txt'), 'a') as f: + f.write(f'{round_n - 1}:{loss:.3f}\n') + + if self.config.train.get('use_fp16'): + with amp.scale_loss(loss, self.engine.optimizer) as scaled_loss: + 
scaled_loss.backward() + else: + loss.backward() + + if self.config.train.grad_clip > 0: + nn.utils.clip_grad.clip_grad_norm_(self.engine.model.parameters(), + self.config.train.grad_clip) + self.engine.optimizer.step() + +if __name__ == "__main__": + from src.federation.context import new_global_context + context = new_global_context() + global_compute = Global(context) + global_compute.load_dataset() + server_state = api.get_server_state(context) + with open(os.path.join(current_path, 'loss.txt'), 'w') as f: + f.write('') + with open(os.path.join(current_path, 'recall.txt'), 'w') as f: + f.write('') + while server_state.round_number < context.args.comm_rounds: + server_state, ok = global_compute.train(server_state) + if not ok: + context.logger.log(f"global compute failed:{server_state.to_dict()}") + context.logger.log(f"global compute finished:{server_state.to_dict()}") + global_compute.logger.log("Best:") + global_compute.engine.report_scores(step=context.args.comm_rounds, + scores=global_compute.best_scores, + metadata=global_compute.best_metadata, + prefix=global_compute.engine.eval_prefix) + diff --git a/src/federation/server.py b/src/federation/server.py new file mode 100644 index 0000000..2d75a1c --- /dev/null +++ b/src/federation/server.py @@ -0,0 +1,201 @@ +# build in +import json +from http import HTTPStatus +# external dependencies +from flask import Flask, request, send_file +from PIL import Image +from uuid import uuid4 +import numpy as np +import api, context, os +import torch +# internal dependencies +import sys +sys.path.append("./") +sys.path.append("../") +sys.path.append("../../") +sys.path.append("../../../") +from src.algorithms.retrieval_trainer import TrainerEngine +from src.utils.load_datasets import imagenet_transform +from src.algorithms.eval_coco import COCOEvaluator +from src.utils.tensor_utils import to_numpy +server = Flask(__name__) + +dir = os.path.dirname(os.path.dirname(os.path.dirname(__file__))) +print(dir) +server.config['UPLOAD_FOLDER'] = f'{dir}/uploads' + +server_context = None # set by main + +current_state = api.ServerState() + +url_prefix = api.url_prefix + +engine = TrainerEngine() + +@server.route(f'{url_prefix}', methods=['GET']) +def get(): + return json.dumps({"current_state":current_state.to_dict()}) + +@server.route(f'{url_prefix}/uploads/') +def send_img(path): + return send_file(f"{server.config['UPLOAD_FOLDER']}/{path}") + +@server.route(f'{url_prefix}/last_train_result', methods=['GET']) +def get_result(): + current_path = os.path.dirname(os.path.dirname(__file__)) + accuracy = '' + loss = '' + recall = '' + with open(os.path.join(current_path, 'accuracy.txt'), 'r') as f: + accuracy = f.readlines() + with open(os.path.join(current_path, 'loss.txt'), 'r') as f: + loss = f.readlines() + with open(os.path.join(current_path, 'recall.txt'), 'r') as f: + recall = f.readlines() + return json.dumps({"accuracy": accuracy, "loss": loss, "recall": recall}) + +@server.route(f'{url_prefix}/upload', methods=['POST']) +def upload(): + filename = '' + if request.method == 'POST': + f = request.files['file'] + filename = f"{uuid4()}.{f.filename.split('.')[-1]}" + f.save(os.path.join(server.config['UPLOAD_FOLDER'], filename)) + + return json.dumps({"status": "ok", "url": filename}) + + +@server.route(f'{url_prefix}/inference', methods=['POST']) +def inference(): + result = None + if request.method == 'POST': + batch = request.form['batch'] + + global engine + engine.model.eval() + result = [] + if batch == 'True': + config_path = 
request.form['config_path'] + config = json.loads(config_path) + for i in config: + captions = [item['text'] for item in i['captions']] + result.append(evl(i['img_path'], captions)) + # with open(config_path, 'r') as f: + # config = json.load(f) + # for i in config: + # captions = [item['text'] for item in i['captions']] + # result.append(evl(i['img_path'], captions)) + else: + captions = request.form['captions'] + f = request.form['file'] + path = os.path.join(server.config['UPLOAD_FOLDER'], f) + result.append(evl(path, captions)) + return json.dumps({"status": "ok", "result": result}) + + +def evl(f, captions): + images = (convert_img(f)) + images = images.unsqueeze(0) + images = images.to(engine.device) # used + sentences = [] + if isinstance(captions, str): + captions = captions.split('\n') # used + + output = engine.model(images, sentences, captions, len(sentences)) + f_ids = [i for i in range(len(captions))] + result = evaluate_single(output, f_ids) + return result + + +def evaluate_single(output, f_ids): + _image_features = output['image_features'] + _caption_features = output['caption_features'] + + n_embeddings = 7 + feat_size = 256 + image_features = np.zeros((1, n_embeddings, feat_size)) + caption_features = np.zeros((len(_caption_features), n_embeddings, feat_size)) + image_features[0] = to_numpy(_image_features[0]) + for i in range(len(_caption_features)): + caption_features[i] = to_numpy(_caption_features[i]) + # caption_features[0] = to_numpy(_caption_features[0]) + # caption_features[1] = to_numpy(_caption_features[1]) + # caption_features[2] = to_numpy(_caption_features[2]) + # caption_features[3] = to_numpy(_caption_features[3]) + # caption_features[4] = to_numpy(_caption_features[4]) + image_features = torch.from_numpy(image_features) + caption_features = torch.from_numpy(caption_features) + + id = 1 + q_id = [id] + + retrieved_items, retrieved_scores, _ = engine.evaluator.retrieve(image_features, caption_features, q_id, torch.tensor(f_ids), topk=5, batch_size=1) + values = [item.item() for item in retrieved_items[id]] + return {"pred": retrieved_items[id][0].item(), "score": values} + +def convert_img(path, cutout_prob=0.0): + _image_transform = imagenet_transform( + random_resize_crop=False, + random_erasing_prob=cutout_prob, + ) + img = Image.open(path).convert('RGB') + img = _image_transform(img) + return img + + +@server.route(f'{url_prefix}/set_global_feature', methods=['PUT']) +def set_global(): + # ensure that the server is in a state update the global model. + if current_state.round_state != api.RoundState.BUSY: + return json.dumps({"status":"error", "message":"The server should not update the global model at this time."}), HTTPStatus.CONFLICT + # ensure that we are submitting to the correct round. + round_number = request.args.get('round_number', default=-1) + if round_number != str(current_state.round_number): + return json.dumps({"status":"error", "message":f"The round has passed. expected={current_state.round_number}, got={round_number}"}), HTTPStatus.CONFLICT + # ensure that the global model is identical to the one the client used. This should always be true considering the round number is already checked. 
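+ # round_number and old_feature_hash together act as an optimistic-concurrency
+ # (compare-and-swap) guard: a stale global provider receives HTTP 409 CONFLICT
+ # and is expected to re-read the server state before retrying. The add_client
+ # endpoint below applies the same guard to client submissions.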
+ old_feature_hash = request.args.get('old_feature_hash', default="missing") + if old_feature_hash != current_state.feature_hash: + return json.dumps({"status":"error", "message":f"The global model has changed {old_feature_hash}!={current_state.feature_hash}"}), HTTPStatus.CONFLICT + new_feature_hash = request.args.get('new_feature_hash', default="") + if new_feature_hash == "": + return json.dumps({"status":"error", "message":"The new feature hash is required"}), HTTPStatus.BAD_REQUEST + current_state.update_feature_hash(new_feature_hash) + return json.dumps({"status":"ok"}) + +@server.route(f'{url_prefix}/add_client', methods=['PUT']) +def add_client(): + # ensure that the server is in a state to accept clients models. + if current_state.round_state == api.RoundState.BUSY: + return json.dumps({"status":"error", "message":"The server is not accepting clients at this time."}), HTTPStatus.CONFLICT + # ensure that we are submitting to the correct round. + round_number = request.args.get('round_number', default=-1) + if round_number != str(current_state.round_number): + return json.dumps({"status":"error", "message":f"The round has passed. expected={current_state.round_number}, got={round_number}"}), HTTPStatus.CONFLICT + # ensure that the global model is identical to the one the client used. This should always be true considering the round number is already checked. + feature_hash = request.args.get('feature_hash', default="") + if feature_hash != current_state.feature_hash: + return json.dumps({"status":"error", "message":"The global model has changed"}), HTTPStatus.CONFLICT + + data = request.get_json() + client = api.ClientState.from_dict(data) + # should also verify client auth on an untrusted network. + current_state.clients_reported[client.name] = client + if len(current_state.clients_reported) == server_context.fed_config.server['max_clients']: + current_state.advance_round() + server_context.logger.log(f"Global round {current_state.round_number} has collected the max number of clients.") + + return json.dumps({"status":"ok"}) + +if __name__ == '__main__': + server_context = context.new_server_context() + if server_context.args.inference: + evaluator = COCOEvaluator(eval_method='matmul', + verbose=True, + eval_device='cuda', + n_crossfolds=5) + engine.load_models2("./sl2-best_model.pt", evaluator) + engine.model_to_device() + + server.run(port=server_context.args.port, host="0.0.0.0", debug=True) + + diff --git a/src/loss.txt b/src/loss.txt new file mode 100644 index 0000000..96037bd --- /dev/null +++ b/src/loss.txt @@ -0,0 +1,20 @@ +0:0.002 +1:0.002 +2:0.001 +3:0.001 +4:0.001 +5:0.001 +6:0.001 +7:0.001 +8:0.001 +9:0.001 +10:0.001 +11:0.001 +12:0.001 +13:0.001 +14:0.001 +15:0.001 +16:0.001 +17:0.001 +18:0.001 +19:0.001 diff --git a/src/losses/__pycache__/BatchAll.cpython-38.pyc b/src/losses/__pycache__/BatchAll.cpython-38.pyc deleted file mode 100644 index f34bf39..0000000 Binary files a/src/losses/__pycache__/BatchAll.cpython-38.pyc and /dev/null differ diff --git a/src/losses/__pycache__/CenterTriplet.cpython-38.pyc b/src/losses/__pycache__/CenterTriplet.cpython-38.pyc deleted file mode 100644 index a06a0d7..0000000 Binary files a/src/losses/__pycache__/CenterTriplet.cpython-38.pyc and /dev/null differ diff --git a/src/losses/__pycache__/DistanceMatchLoss.cpython-38.pyc b/src/losses/__pycache__/DistanceMatchLoss.cpython-38.pyc deleted file mode 100644 index f8d5a1c..0000000 Binary files a/src/losses/__pycache__/DistanceMatchLoss.cpython-38.pyc and /dev/null differ diff --git 
a/src/losses/__pycache__/GaussianMetric.cpython-38.pyc b/src/losses/__pycache__/GaussianMetric.cpython-38.pyc deleted file mode 100644 index 33075c5..0000000 Binary files a/src/losses/__pycache__/GaussianMetric.cpython-38.pyc and /dev/null differ diff --git a/src/losses/__pycache__/HistogramLoss.cpython-38.pyc b/src/losses/__pycache__/HistogramLoss.cpython-38.pyc deleted file mode 100644 index 1dbc787..0000000 Binary files a/src/losses/__pycache__/HistogramLoss.cpython-38.pyc and /dev/null differ diff --git a/src/losses/__pycache__/NeighbourHardLoss.cpython-38.pyc b/src/losses/__pycache__/NeighbourHardLoss.cpython-38.pyc deleted file mode 100644 index 6513350..0000000 Binary files a/src/losses/__pycache__/NeighbourHardLoss.cpython-38.pyc and /dev/null differ diff --git a/src/losses/__pycache__/NeighbourLoss.cpython-38.pyc b/src/losses/__pycache__/NeighbourLoss.cpython-38.pyc deleted file mode 100644 index ac97ed9..0000000 Binary files a/src/losses/__pycache__/NeighbourLoss.cpython-38.pyc and /dev/null differ diff --git a/src/losses/__pycache__/__init__.cpython-38.pyc b/src/losses/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index d48296c..0000000 Binary files a/src/losses/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/src/losses/__pycache__/triplet.cpython-38.pyc b/src/losses/__pycache__/triplet.cpython-38.pyc deleted file mode 100644 index 361759c..0000000 Binary files a/src/losses/__pycache__/triplet.cpython-38.pyc and /dev/null differ diff --git a/src/main.py b/src/main.py index 539c9fc..67afb11 100644 --- a/src/main.py +++ b/src/main.py @@ -1,128 +1,15 @@ -import os -import argparse -from utils.helper import Helper as helper -import algorithms -import random - - -def init_wandb(args): - """ - wandb will automatically save the log - - wandb.log({"epoch": epoch, "loss": loss}, step=example_ct) - print(f"Loss after " + str(example_ct).zfill(5) + f" examples: {loss:.3f}") - - wandb.log({"test_accuracy": correct / total}) - - # Save the model in the exchangeable ONNX format - torch.onnx.export(model, images, "model.onnx") - wandb.save("model.onnx") - - """ - - import wandb - - name = str(args.name) - - wandb.init( - project="CreamFL", - name=name, - resume=None, - # dir=os.path.join(args.exp_dir, args.name), - config=args - ) - - return wandb - - -def args(): - parser.add_argument('--name', type=str, default='Test', help='The name for different experimental runs.') - parser.add_argument('--exp_dir', type=str, default='./experiments/', - help='Locations to save different experimental runs.') - parser.add_argument('--local_epochs', type=int, default=5) - parser.add_argument('--comm_rounds', type=int, default=30) - - parser.add_argument('--model', type=str, default='resnet34', help='Target model name (default: resnet34_8x)') - parser.add_argument('--img_model_local', type=str, default='resnet10') - parser.add_argument('--pretrained', type=int, default=0) - parser.add_argument('--no-cuda', action='store_true', default=False, - help='disables CUDA training') - parser.add_argument('--seed', type=int, default=random.randint(0, 100000), metavar='S', - help='random seed (default: 1)') - parser.add_argument('--device', type=int, default=0) - - parser.add_argument('--num_img_clients', type=int, default=10) - parser.add_argument('--num_txt_clients', type=int, default=10) - parser.add_argument('--num_mm_clients', type=int, default=15) - - parser.add_argument('--client_num_per_round', type=int, default=10) - - # === dataloader === - parser.add_argument('--dataset', 
type=str, default='cifar100', choices=['svhn', 'cifar10', 'cifar100'], - help='dataset name (default: cifar10)') - parser.add_argument('--data_root', type=str, default=os.environ['HOME'] + "/data/") - parser.add_argument('--batch_size', type=int, default=64, metavar='N', - help='input batch size for training (default: 256)') - parser.add_argument('--alpha', type=float, default=0.5) - - # === optimization === - parser.add_argument('--server_lr', type=float, default=0.0002) - parser.add_argument('--lr', type=float, default=0.1, metavar='LR', help='Student learning rate (default: 0.1)') - parser.add_argument('--loss', type=str, default='l1', choices=['l1', 'kl', 'l1softmax'], ) - parser.add_argument('--scheduler', type=str, default='multistep', - choices=['multistep', 'cosine', 'exponential', "none"], ) - parser.add_argument('--steps', nargs='+', default=[0.05, 0.15, 0.3, 0.5, 0.75], type=float, - help="Percentage epochs at which to take next step") - parser.add_argument('--scale', type=float, default=0.1, help="Fractional decrease in lr") - parser.add_argument('--weight_decay', type=float, default=5e-4) - parser.add_argument('--momentum', type=float, default=0.9, metavar='M', help='SGD momentum (default: 0.9)') - # === logs === - parser.add_argument('--log_interval', type=int, default=10, metavar='N', - help='how many batches to wait before logging training status') - parser.add_argument('--save_interval', type=int, default=10, metavar='N', - help='how many batches to wait before logging training status') - - parser.add_argument('--disable_distill', action="store_true", default=False) - - parser.add_argument('--agg_method', type=str, default='con_w', help='representation aggregation method') - parser.add_argument('--contrast_local_intra', action="store_true", default=False) - parser.add_argument('--contrast_local_inter', action="store_true", default=False) - - parser.add_argument('--mlp_local', action="store_true", default=False) - - parser.add_argument('--kd_weight', type=float, default=0.3, help='coefficient of kd') - parser.add_argument('--interintra_weight', type=float, default=0.5, help='coefficient of inter+intra') - - parser.add_argument('--loss_scale', action='store_true', default=False) - parser.add_argument('--save_client', action='store_true', default=False) - - parser.add_argument('--data_local', action='store_true', default=False, - help='change data directory to ~/data_local') - - parser.add_argument('--pub_data_num', type=int, default=50000, help='communication') - parser.add_argument('--feature_dim', type=int, default=256) - - parser.add_argument('--not_bert', action='store_true', default=False, help="server bert, client not bert") - - -parser = argparse.ArgumentParser(description='Federated Learning') -args() -args = parser.parse_args() +import common if __name__ == "__main__": from algorithms.MMFL import MMFL - wandb = init_wandb(args) + args, wandb = common.prepare_args(description="CreamFL Federated Learning (local simulation)") Algo = MMFL(args, wandb) - args.save_dirs = helper.get_save_dirs(args.exp_dir, args.name) - args.log_dir = args.save_dirs['logs'] - helper.set_seed(args.seed) - - Algo.create_model(args) - Algo.load_dataset(args) + Algo.create_model(args) # create client models and datasets + Algo.load_dataset(args) # global model and dataset for round_n in range(args.comm_rounds): Algo.train(round_n) diff --git a/src/networks/__pycache__/__init__.cpython-38.pyc b/src/networks/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index e73b0f4..0000000 
Binary files a/src/networks/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/src/networks/__pycache__/language_model.cpython-38.pyc b/src/networks/__pycache__/language_model.cpython-38.pyc deleted file mode 100644 index 41c3a3a..0000000 Binary files a/src/networks/__pycache__/language_model.cpython-38.pyc and /dev/null differ diff --git a/src/networks/__pycache__/resnet.cpython-38.pyc b/src/networks/__pycache__/resnet.cpython-38.pyc deleted file mode 100644 index 1c42e7b..0000000 Binary files a/src/networks/__pycache__/resnet.cpython-38.pyc and /dev/null differ diff --git a/src/networks/__pycache__/resnet_client.cpython-38.pyc b/src/networks/__pycache__/resnet_client.cpython-38.pyc deleted file mode 100644 index c345d29..0000000 Binary files a/src/networks/__pycache__/resnet_client.cpython-38.pyc and /dev/null differ diff --git a/src/networks/__pycache__/resnet_fedml.cpython-38.pyc b/src/networks/__pycache__/resnet_fedml.cpython-38.pyc deleted file mode 100644 index 6fc595c..0000000 Binary files a/src/networks/__pycache__/resnet_fedml.cpython-38.pyc and /dev/null differ diff --git a/src/networks/fusion_model.py b/src/networks/fusion_model.py new file mode 100644 index 0000000..da1835f --- /dev/null +++ b/src/networks/fusion_model.py @@ -0,0 +1,192 @@ +import torch +import torch.nn as nn +from enum import Enum + +from src.networks.models.pcme import PCME + +class InputType(Enum): + A_B = 'A_B' + AxB = 'AxB' + +def freeze_model(m): + for param in m.parameters(): + param.requires_grad = False + +def unfreeze_model(m): + for param in m.parameters(): + param.requires_grad = True + +# class LinearFusionModel(nn.Module): +# def __init__(self, image_model, text_model, num_classes): +# super(LinearFusionModel, self).__init__() +# self.image_model = image_model +# self.text_model = text_model +# self.fc = nn.Linear(image_model.output_size + text_model.output_size, num_classes) + +# def forward(self, image_input, text_input): +# image_features = self.image_model(image_input) +# text_features = self.text_model(text_input) +# fused_features = torch.cat((image_features, text_features), dim=1) +# output = self.fc(fused_features) +# return output + +class LinearFusionModelEmbedded(nn.Module): + def __init__(self, base_model:PCME): + super(LinearFusionModelEmbedded, self).__init__() + self.base_model = base_model + device = next(self.base_model.parameters()).device + self.fc = nn.Linear(base_model.embed_dim *2 , base_model.embed_dim) + self.to(device) + + def forward(self, images, sentences, captions_word, lengths): + outputs = self.base_model.forward(images, sentences, captions_word, lengths) + image_features = outputs['image_features'] + caption_features = outputs['caption_features'] + fused_features = torch.cat((image_features, caption_features), dim=1) + output = self.fc(fused_features) + return output + +class VQAFusionModel(nn.Module): + def __init__(self, device, base_model: PCME, img_features:int, txt_features:int, num_classes: int, hidden_sizes: list, + dropout_rate=0.0): + super(VQAFusionModel, self).__init__() + self.base_model = base_model + self.device = device + + print(f'VQA Fusion Model device: {self.device}') + + cross_size = hidden_sizes[0] + + self.image_in = nn.Linear(img_features * base_model.embed_dim, cross_size) + self.text_in = nn.Linear(txt_features * base_model.embed_dim, cross_size) + + layers = [] + for hidden_size in hidden_sizes[1:]: + layers.append(nn.Dropout(dropout_rate)) + layers.append(nn.Tanh()), + layers.append(nn.Linear(cross_size, hidden_size)) + 
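+ # Each iteration appends a Dropout -> Tanh -> Linear block; the statement that
+ # follows carries this layer's width forward as the input width of the next block,
+ # so hidden_sizes[1:] defines the MLP stacked on top of the fused features.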
cross_size = hidden_size + self.features_extractor = nn.Sequential(*layers) + + self.classifier_head = nn.Sequential( + nn.Dropout(dropout_rate), + nn.Tanh(), + nn.Linear(cross_size, num_classes), # Final classification layer + ) + + self.to(self.device) + + def forward(self, batch): + questions = batch['question'] + outputs = None + images = batch['image'].to(self.device) + #sub_images = batch['sub_images'].to(self.device) + #print(f'types images: {type(images)}, sub_images: {type(sub_images)}') + #print(f'shapes images: {images.shape}, sub_images: {sub_images.shape}') + outputs = self.base_model.forward(images, [], questions, 0) + image_features = outputs['image_features'] + caption_features = outputs['caption_features'] + #sub_images_features = self.base_model.image_forward(sub_images.view(-1, 3, 224, 224))['embedding'] + #sub_images_features = sub_images_features.view(-1, 4, self.base_model.embed_dim).transpose(0, 1) + #question_type_features = self.base_model.text_forward([], batch['question_type'], 0)['embedding'] + #question_rest_features = self.base_model.text_forward([], batch['question_rest'], 0)['embedding'] + #print(f'image_features: {image_features.shape} sub_images_features[0]: {sub_images_features[0].shape}') + #return self.forward_fusion([image_features]+[f for f in sub_images_features], [caption_features, question_type_features, question_rest_features]) + return self.forward_fusion([image_features], [caption_features]) + + def forward_fusion(self, image_features, text_features): + image_features = self.image_in(torch.cat(image_features, dim=1)) + text_features = self.text_in(torch.cat(text_features, dim=1)) + fused_features = image_features * text_features + last_features = self.features_extractor(fused_features) + return self.classifier_head(last_features), last_features + + def unfreeze_base_model(self): + self.frozen_base_model = False + unfreeze_model(self.base_model) + + def freeze_base_model(self): + self.frozen_base_model = True + freeze_model(self.base_model) + + def unfreeze_base_image_model(self): + self.frozen_base_model = False + unfreeze_model(self.base_model.img_enc.cnn) + + def freeze_base_image_model(self): + self.frozen_base_model = True + freeze_model(self.base_model.img_enc.cnn) + +class LinearFusionModelCategorical(nn.Module): + def __init__(self, base_model: PCME, num_features:int, num_classes: int, hidden_sizes: list, input_type: InputType, + dropout_rate=0.0): + super(LinearFusionModelCategorical, self).__init__() + self.base_model = base_model + input_type = InputType(input_type) + self.input_type = input_type + self.frozen_base_model = True + freeze_model(base_model) + self.device = next(self.base_model.parameters()).device + + layers = [] + input_size = base_model.embed_dim * num_features # Input size to the first hidden layer + if input_type == InputType.AxB: + input_size = base_model.embed_dim + #print(input_size) + for hidden_size in hidden_sizes: + layers.append(nn.Dropout(dropout_rate)) + layers.append(nn.Linear(input_size, hidden_size)) + layers.append(nn.ReLU()) + input_size = hidden_size # Update input size for the next layer + self.features_extractor = nn.Sequential(*layers) + + self.classifier_head = nn.Sequential( + nn.Dropout(dropout_rate), + nn.Linear(input_size, num_classes) # Final classification layer + ) + + self.to(self.device) + + def forward(self, batch): + questions = batch['question'] + outputs = None + if 'image_features' in batch: # use precalculated features if available + outputs = self.forward_fusion( + 
[batch['image_features'], + batch['caption_features']]+batch['sub_images']) + else: + images = batch['image'].to(self.device) + sub_images = batch['sub_images'].to(self.device) + #print(f'types images: {type(images)}, sub_images: {type(sub_images)}') + #print(f'shapes images: {images.shape}, sub_images: {sub_images.shape}') + outputs = self.base_model.forward(images, [], questions, 0) + image_features = outputs['image_features'] + caption_features = outputs['caption_features'] + sub_images_features = self.base_model.image_forward(sub_images.view(-1, 3, 224, 224))['embedding'] + sub_images_features = sub_images_features.view(-1, 4, self.base_model.embed_dim).transpose(0, 1) + question_type_features = self.base_model.text_forward([], batch['question_type'], 0)['embedding'] + question_rest_features = self.base_model.text_forward([], batch['question_rest'], 0)['embedding'] + #print(f'image_features: {image_features.shape} sub_images_features[0]: {sub_images_features[0].shape}') + return self.forward_fusion([image_features, caption_features, question_type_features, question_rest_features]+[f for f in sub_images_features]) + + def forward_fusion(self, features_list): + #print(image_features.shape, caption_features.shape) + #for i, tensor in enumerate(features_list): + # print(f"Tensor {i} shape: {tensor.shape}") + fused_features = None + if self.input_type == InputType.A_B: # Concatenation + fused_features = torch.cat(features_list, dim=1) + if self.input_type == InputType.AxB: # Element-wise multiplication + fused_features = features_list[0] + for i in range(1, len(features_list)): + fused_features = fused_features * features_list[i] + if fused_features is None: + raise ValueError(f"input_type {self.input_type} is not supported in forward_fusion") + #print(fused_features.shape) + last_features = self.features_extractor(fused_features) + return self.classifier_head(last_features), last_features + + def unfreeze_base_model(self): + self.frozen_base_model = False + unfreeze_model(self.base_model) + diff --git a/src/networks/language_model.py b/src/networks/language_model.py index 5e068d9..41fc600 100644 --- a/src/networks/language_model.py +++ b/src/networks/language_model.py @@ -28,7 +28,7 @@ def get_pad_mask(max_length, lengths, set_pad_to_one=True): class EncoderText(nn.Module): def __init__(self, wemb_type='glove', word_dim=300, embed_dim=2048, num_class=4, scale=128, mlp_local=False): super(EncoderText, self).__init__() - with open('src/datasets/vocabs/coco_vocab.pkl', + with open('src/custom_datasets/vocabs/coco_vocab.pkl', 'rb') as fin: vocab = pickle.load(fin) word2idx = vocab['word2idx'] @@ -77,6 +77,7 @@ def init_weights(self, wemb_type, word2idx, word_dim, cache_dir=os.environ['HOME assert wemb.vectors.shape[1] == word_dim, f'wemb.vectors.shape {wemb.vectors.shape}' # quick-and-dirty trick to improve word-hit rate + print(f'improving word-hit rate') missing_words = [] for word, idx in word2idx.items(): if word not in wemb.stoi: diff --git a/src/networks/models/__pycache__/__init__.cpython-38.pyc b/src/networks/models/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index 6fb6da8..0000000 Binary files a/src/networks/models/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/src/networks/models/__pycache__/caption_encoder.cpython-38.pyc b/src/networks/models/__pycache__/caption_encoder.cpython-38.pyc deleted file mode 100644 index 86c43ec..0000000 Binary files a/src/networks/models/__pycache__/caption_encoder.cpython-38.pyc and /dev/null differ diff --git 
a/src/networks/models/__pycache__/image_encoder.cpython-38.pyc b/src/networks/models/__pycache__/image_encoder.cpython-38.pyc deleted file mode 100644 index 4cc27e8..0000000 Binary files a/src/networks/models/__pycache__/image_encoder.cpython-38.pyc and /dev/null differ diff --git a/src/networks/models/__pycache__/pcme.cpython-38.pyc b/src/networks/models/__pycache__/pcme.cpython-38.pyc deleted file mode 100644 index d3b2efc..0000000 Binary files a/src/networks/models/__pycache__/pcme.cpython-38.pyc and /dev/null differ diff --git a/src/networks/models/__pycache__/pie_model.cpython-38.pyc b/src/networks/models/__pycache__/pie_model.cpython-38.pyc deleted file mode 100644 index ad25cfd..0000000 Binary files a/src/networks/models/__pycache__/pie_model.cpython-38.pyc and /dev/null differ diff --git a/src/networks/models/__pycache__/uncertainty_module.cpython-38.pyc b/src/networks/models/__pycache__/uncertainty_module.cpython-38.pyc deleted file mode 100644 index d8a4162..0000000 Binary files a/src/networks/models/__pycache__/uncertainty_module.cpython-38.pyc and /dev/null differ diff --git a/src/networks/models/caption_encoder.py b/src/networks/models/caption_encoder.py index 39dc5a2..f582e19 100644 --- a/src/networks/models/caption_encoder.py +++ b/src/networks/models/caption_encoder.py @@ -71,6 +71,7 @@ def init_weights(self, wemb_type, word2idx, word_dim, cache_dir=os.environ['HOME assert wemb.vectors.shape[1] == word_dim # quick-and-dirty trick to improve word-hit rate + print(f'improving word-hit rate') missing_words = [] for word, idx in word2idx.items(): if word not in wemb.stoi: diff --git a/src/networks/models/pcme.py b/src/networks/models/pcme.py index 9bf8598..ff14c61 100644 --- a/src/networks/models/pcme.py +++ b/src/networks/models/pcme.py @@ -1,4 +1,4 @@ -import sys +import sys, os import torch.nn as nn from transformers import BertModel, BertTokenizer @@ -28,21 +28,17 @@ def __init__(self, word2idx, config, mlp_local): if config.not_bert: self.txt_enc = EncoderText(word2idx, config, mlp_local) else: - self.txt_enc = BertModel.from_pretrained("bert-base-uncased") - self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased") + if os.path.exists("/home/shannon/dev/tools/nlp/models/bert-base-uncased-CoLA"): # hard coded local path + self.txt_enc = BertModel.from_pretrained("/home/shannon/dev/tools/nlp/models/bert-base-uncased-CoLA") + self.tokenizer = BertTokenizer.from_pretrained("/home/shannon/dev/tools/nlp/models/bert-base-uncased-CoLA") + else: + self.txt_enc = BertModel.from_pretrained("bert-base-uncased", resume_download=True) + self.tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", resume_download=True) self.linear = nn.Linear(768, self.embed_dim) def forward(self, images, sentences, captions_word, lengths): - image_output = self.img_enc(images) - if self.config.not_bert: - caption_output = self.txt_enc(sentences, lengths) # sentences: [128, seq_len], lengths: 128 - else: - inputs = self.tokenizer(captions_word, padding=True, return_tensors='pt') - for a in inputs: - inputs[a] = inputs[a].cuda() - caption_output = self.txt_enc(**inputs) - caption_output = {'embedding': l2_normalize(self.linear(caption_output['last_hidden_state'][:, 0, :]))} # [bsz, 768] - + image_output = self.image_forward(images) + caption_output = self.text_forward(sentences, captions_word, lengths) return { 'image_features': image_output['embedding'], 'image_attentions': image_output.get('attention'), @@ -59,5 +55,16 @@ def forward(self, images, sentences, captions_word, lengths): 
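For orientation, a minimal sketch of calling the image/text encoder entry points that the pcme.py changes above split out of forward(). It assumes a constructed PCME instance `model` on a CUDA device (the BERT branch moves tokenized inputs to .cuda()) and an `images` batch prepared by the usual transform, so treat it as illustrative rather than repo code:

```python
import torch

# Hypothetical usage of the refactored entry points; `model` and `images`
# are assumed to exist as described above.
with torch.no_grad():
    img_out = model.image_forward(images)                      # dict with an 'embedding' entry
    txt_out = model.text_forward([], ["a dog on a couch"], 0)  # BERT branch tokenizes the raw caption
    img_emb, txt_emb = img_out['embedding'], txt_out['embedding']
```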
def image_forward(self, images): return self.img_enc(images) - def text_forward(self, sentences, lengths): + def text_forward(self, sentences, captions_word, lengths): + if self.config.not_bert: + return self.txt_enc(sentences, lengths) # sentences: [128, seq_len], lengths: 128 + else: + inputs = self.tokenizer(captions_word, padding=True, return_tensors='pt') + for a in inputs: + inputs[a] = inputs[a].cuda() + caption_output = self.txt_enc(**inputs) + return {'embedding': l2_normalize(self.linear(caption_output['last_hidden_state'][:, 0, :]))} # [bsz, 768] + + + def text_forward_old(self, sentences, lengths): return self.txt_enc(sentences, lengths) diff --git a/src/recall.txt b/src/recall.txt new file mode 100644 index 0000000..a02bde3 --- /dev/null +++ b/src/recall.txt @@ -0,0 +1,20 @@ +0:0.196,1.096,2.2640000000000002 +1:0.156,0.9040000000000001,1.7240000000000002 +2:0.136,0.768,1.564 +3:0.12800000000000003,0.76,2.0000000000000004 +4:0.12,0.844,1.6199999999999999 +5:0.244,1.124,2.1879999999999997 +6:0.22,1.1320000000000001,2.1840000000000006 +7:0.252,1.1280000000000001,2.2720000000000002 +8:0.16800000000000004,1.036,2.112 +9:0.24,1.108,2.136 +10:0.28400000000000003,1.3320000000000003,2.428 +11:0.23199999999999998,1.1880000000000002,2.3120000000000003 +12:0.28,1.1640000000000001,2.5120000000000005 +13:0.29600000000000004,1.3480000000000003,2.668 +14:0.28400000000000003,1.3120000000000003,2.624 +15:0.268,1.2680000000000002,2.6080000000000005 +16:0.18000000000000002,1.124,2.64 +17:0.34800000000000003,1.4120000000000004,3.0080000000000005 +18:0.37599999999999995,1.484,3.132 +19:0.19999999999999998,1.512,3.02 diff --git a/src/retri_client_img.py b/src/retri_client_img.py new file mode 100644 index 0000000..bc25e7d --- /dev/null +++ b/src/retri_client_img.py @@ -0,0 +1,21 @@ +import common + +if __name__ == "__main__": + + from algorithms.MMFL_dist import MMFL_Client + + args, wandb = common.prepare_args( + description="CreamFL Federated Learning for retri task (global rep)", + script="retri_client", + is_vqa=False, + ) + + Algo = MMFL_Client( + args, wandb, node_id="localhost", router_port=5002, peers=[["localhost", 5001]] + ) + + Algo.create_model(args) # create client models and datasets + Algo.load_dataset(args, is_vqa=False) # global model and dataset + + for round_n in range(args.comm_rounds): + Algo.train(round_n) diff --git a/src/retri_client_mm.py b/src/retri_client_mm.py new file mode 100644 index 0000000..9bf7277 --- /dev/null +++ b/src/retri_client_mm.py @@ -0,0 +1,21 @@ +import common + +if __name__ == "__main__": + + from algorithms.MMFL_dist import MMFL_Client + + args, wandb = common.prepare_args( + description="CreamFL Federated Learning for retri task (global rep)", + script="retri_client", + is_vqa=False, + ) + + Algo = MMFL_Client( + args, wandb, node_id="localhost", router_port=5004, peers=[["localhost", 5001]] + ) + + Algo.create_model(args) # create client models and datasets + Algo.load_dataset(args, is_vqa=False) # global model and dataset + + for round_n in range(args.comm_rounds): + Algo.train(round_n) diff --git a/src/retri_client_txt.py b/src/retri_client_txt.py new file mode 100644 index 0000000..7d08f18 --- /dev/null +++ b/src/retri_client_txt.py @@ -0,0 +1,21 @@ +import common + +if __name__ == "__main__": + + from algorithms.MMFL_dist import MMFL_Client + + args, wandb = common.prepare_args( + description="CreamFL Federated Learning for retri task (global rep)", + script="retri_client", + is_vqa=False, + ) + + Algo = MMFL_Client( + args, wandb, 
node_id="localhost", router_port=5003, peers=[["localhost", 5001]] + ) + + Algo.create_model(args) # create client models and datasets + Algo.load_dataset(args, is_vqa=False) # global model and dataset + + for round_n in range(args.comm_rounds): + Algo.train(round_n) diff --git a/src/retri_global.py b/src/retri_global.py new file mode 100644 index 0000000..11e8a18 --- /dev/null +++ b/src/retri_global.py @@ -0,0 +1,25 @@ +import common + +if __name__ == "__main__": + + from algorithms.MMFL_dist import MMFL_Global + + args, wandb = common.prepare_args( + description="CreamFL Federated Learning for retri task (global rep)", + script="retri_global", + is_vqa=False, + ) + + Algo = MMFL_Global( + args, + wandb, + node_id="localhost", + router_port=5001, + peers=[["localhost", 5002], ["localhost", 5003], ["localhost", 5004]], + ) + + Algo.create_model(args) # create client models and datasets + Algo.load_dataset(args, is_vqa=False) # global model and dataset + + for round_n in range(args.comm_rounds): + Algo.train(round_n) diff --git a/src/test.http b/src/test.http new file mode 100644 index 0000000..ba16fba --- /dev/null +++ b/src/test.http @@ -0,0 +1,36 @@ +### +POST http://localhost:2323/cream_api/inference +Content-Type: multipart/form-data; boundary=WebAppBoundary + +--WebAppBoundary +content-Disposition: form-data; name="file" + +/home/shannon/dev/projects/distributed_cream/uploads/52bbf136-11f5-46dc-be4f-33c9a75255f7.jpg +--WebAppBoundary +content-Disposition: form-data; name="captions" + +A bicycle replica with a clock as the front wheel. +A dog is sniffing at a plate of food on a table. +A mother and baby zebra walking in the grass. +The bike has a clock as a tire. +A motorcycle with its brake extended standing outside +--WebAppBoundary +content-Disposition: form-data; name="batch" + +False +--WebAppBoundary-- + + +### +POST http://localhost:2323/cream_api/upoad +Content-Type: application/x-www-form-urlencoded + + +### +POST http://localhost:2323/cream_api/inference +Content-Type: application/x-www-form-urlencoded + +config_path=/home/shannon/dev/projects/web/src/components/batch.json&batch=True + +### 获取详情 +GET http://localhost:5000/job?id=25 \ No newline at end of file diff --git a/src/utils/__pycache__/Reader.cpython-38.pyc b/src/utils/__pycache__/Reader.cpython-38.pyc deleted file mode 100644 index 65bc385..0000000 Binary files a/src/utils/__pycache__/Reader.cpython-38.pyc and /dev/null differ diff --git a/src/utils/__pycache__/Utils.cpython-38.pyc b/src/utils/__pycache__/Utils.cpython-38.pyc deleted file mode 100644 index 8c17d14..0000000 Binary files a/src/utils/__pycache__/Utils.cpython-38.pyc and /dev/null differ diff --git a/src/utils/__pycache__/__init__.cpython-38.pyc b/src/utils/__pycache__/__init__.cpython-38.pyc deleted file mode 100644 index fd1f356..0000000 Binary files a/src/utils/__pycache__/__init__.cpython-38.pyc and /dev/null differ diff --git a/src/utils/__pycache__/color_lib.cpython-38.pyc b/src/utils/__pycache__/color_lib.cpython-38.pyc deleted file mode 100644 index efc246d..0000000 Binary files a/src/utils/__pycache__/color_lib.cpython-38.pyc and /dev/null differ diff --git a/src/utils/__pycache__/config.cpython-38.pyc b/src/utils/__pycache__/config.cpython-38.pyc deleted file mode 100644 index 5bf1177..0000000 Binary files a/src/utils/__pycache__/config.cpython-38.pyc and /dev/null differ diff --git a/src/utils/__pycache__/helper.cpython-38.pyc b/src/utils/__pycache__/helper.cpython-38.pyc deleted file mode 100644 index 96c3022..0000000 Binary files 
a/src/utils/__pycache__/helper.cpython-38.pyc and /dev/null differ diff --git a/src/utils/__pycache__/load_datasets.cpython-38.pyc b/src/utils/__pycache__/load_datasets.cpython-38.pyc deleted file mode 100644 index 9baeebf..0000000 Binary files a/src/utils/__pycache__/load_datasets.cpython-38.pyc and /dev/null differ diff --git a/src/utils/__pycache__/logger.cpython-38.pyc b/src/utils/__pycache__/logger.cpython-38.pyc deleted file mode 100644 index cf0d02b..0000000 Binary files a/src/utils/__pycache__/logger.cpython-38.pyc and /dev/null differ diff --git a/src/utils/__pycache__/serialize_utils.cpython-38.pyc b/src/utils/__pycache__/serialize_utils.cpython-38.pyc deleted file mode 100644 index 8b2c12d..0000000 Binary files a/src/utils/__pycache__/serialize_utils.cpython-38.pyc and /dev/null differ diff --git a/src/utils/__pycache__/tensor_utils.cpython-38.pyc b/src/utils/__pycache__/tensor_utils.cpython-38.pyc deleted file mode 100644 index c96fb53..0000000 Binary files a/src/utils/__pycache__/tensor_utils.cpython-38.pyc and /dev/null differ diff --git a/src/utils/helper.py b/src/utils/helper.py index 48ee9d4..87ecffa 100644 --- a/src/utils/helper.py +++ b/src/utils/helper.py @@ -9,40 +9,40 @@ class AverageMeter(object): - """Computes and stores the average and current value""" - def __init__(self): - self.reset() + """Computes and stores the average and current value""" + def __init__(self): + self.reset() - def reset(self): - self.val = 0 - self.avg = 0 - self.sum = 0 - self.count = 0 + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 - def update(self, val, n=1): - self.val = val - self.sum += val * n - self.count += n - self.avg = self.sum / self.count + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count class Helper: #All directories are end with / @staticmethod def get_accuracy(output, target, topk=(1,)): - """Computes the precision@k for the specified values of k""" - maxk = max(topk) - batch_size = target.size(0) + """Computes the precision@k for the specified values of k""" + maxk = max(topk) + batch_size = target.size(0) - _, pred = output.topk(maxk, 1, True, True) - pred = pred.t() - correct = pred.eq(target.view(1, -1).expand_as(pred)) + _, pred = output.topk(maxk, 1, True, True) + pred = pred.t() + correct = pred.eq(target.view(1, -1).expand_as(pred)) - res = [] - for k in topk: - correct_k = correct[:k].view(-1).float().sum(0) - res.append(correct_k.mul_(100.0 / batch_size)) - return res + res = [] + for k in topk: + correct_k = correct[:k].view(-1).float().sum(0) + res.append(correct_k.mul_(100.0 / batch_size)) + return res @staticmethod def pairwise_L2(x, y): @@ -122,18 +122,25 @@ def backup_codes(src_d, tgt_d, save_types=['.py', '.txt', '.sh', '.out']): @staticmethod def try_make_dir(d): - if not os.path.isdir(d): - # os.mkdir(d) - os.makedirs(d) # nested is allowed + if not os.path.isdir(d): + # os.mkdir(d) + os.makedirs(d) # nested is allowed @staticmethod def get_hms(seconds): - m, s = divmod(seconds, 60) - h, m = divmod(m, 60) - return h, m, s + m, s = divmod(seconds, 60) + h, m = divmod(m, 60) + return h, m, s @staticmethod def set_seed(seed): + if seed == 0: + print("Random seed is used, cudnn.deterministic is set to False.") + torch.backends.cudnn.deterministic = False + torch.backends.cudnn.benchmark = True + return + + print(f"Seed {seed} is used, cudnn.deterministic is set to True.") torch.manual_seed(seed) torch.cuda.manual_seed(seed) np.random.seed(seed) @@ -165,7 
+172,7 @@ def write_dict2csv(log_dir, write_dict, mode="a"): raise ValueError("write_dict has wrong type") - ###======================== Visualization ================= ### + ###======================== Visualization ================= ### @staticmethod def save_images(samples, sample_dir, sample_name, offset=0, nrows=0): if nrows == 0: diff --git a/src/utils/load_datasets.py b/src/utils/load_datasets.py index 4492fd3..a988363 100644 --- a/src/utils/load_datasets.py +++ b/src/utils/load_datasets.py @@ -12,21 +12,24 @@ sys.path.append("../..") sys.path.append("../../..") -from src.datasets._dataloader import image_to_caption_collate_fn -from src.datasets.coco import CocoCaptionsCap +from src.custom_datasets._dataloader import image_to_caption_collate_fn +from src.custom_datasets.coco import CocoCaptionsCap +from src.algorithms.vqa_meta import VQAMetaData, unknown_category_id # COCO def prepare_coco_dataloaders(dataloader_config, dataset_root, + subset_num, # was hard coded to 50000 + client_subset_num, # was hard coded to 10000 vocab_path='./vocabs/coco_vocab.pkl', - num_workers=6, tsne=False, client=-1): + num_workers=12, tsne=False, client=-1): """Prepare MS-COCO Caption train / val / test dataloaders Args: dataloader_config (dict): configuration file which should contain "batch_size" dataset_root (str): root of your MS-COCO dataset (see README.md for detailed dataset hierarchy) vocab_path (str, optional): path for vocab pickle file (default: ./vocabs/coco_vocab.pkl). - num_workers (int, optional): num_workers for the dataloaders (default: 6) + num_workers (int, optional): num_workers for the dataloaders (default: 12) Returns: dataloaders (dict): keys = ["train", "val", "te"], values are the corresponding dataloaders. vocab (Vocabulary object): vocab object @@ -63,10 +66,11 @@ def prepare_coco_dataloaders(dataloader_config, cutout_prob=tr_cutout_prob, caption_drop_prob=tr_caption_drop_prob, subset=False, - client=client + client=client, + client_subset_num=client_subset_num ) else: - dataloaders['train_subset_50000'] = _get_coco_loader( + dataloaders[f'train_subset_{subset_num}'] = _get_coco_loader( image_root, train_ann, train_ids, vocab, num_workers=num_workers, batch_size=batch_size, train=True, @@ -74,10 +78,11 @@ def prepare_coco_dataloaders(dataloader_config, extra_ids=train_extra_ids, cutout_prob=tr_cutout_prob, caption_drop_prob=tr_caption_drop_prob, - subset=True + subset=True, + subset_num=subset_num ) - dataloaders['train_subset_eval_50000'] = _get_coco_loader( + dataloaders[f'train_subset_eval_{subset_num}'] = _get_coco_loader( image_root, train_ann, train_ids, vocab, num_workers=num_workers, batch_size=batch_size * 2, train=False, @@ -85,19 +90,24 @@ def prepare_coco_dataloaders(dataloader_config, extra_ids=train_extra_ids, cutout_prob=tr_cutout_prob, caption_drop_prob=tr_caption_drop_prob, - subset=True + subset=True, + subset_num=subset_num ) dataloaders['val'] = _get_coco_loader( image_root, val_ann, val_ids, vocab, num_workers=num_workers, batch_size=eval_batch_size, train=False, + #subset=True, #AttributeError: 'Subset' object has no attribute 'n_images' + #subset_num=subset_num ) dataloaders['test'] = _get_coco_loader( image_root, val_ann, te_ids, vocab, num_workers=num_workers, batch_size=eval_batch_size if not tsne else 200, train=False, + #subset=True, #AttributeError: 'Subset' object has no attribute 'n_images' + #subset_num=subset_num ) return dataloaders, vocab @@ -106,10 +116,10 @@ def prepare_coco_dataloaders(dataloader_config, def 
_get_coco_file_paths(dataset_root): """Select proper train / val classes and omit id files. """ - train_ids = np.load('./src/datasets/annotations/coco_train_ids.npy') - train_extra_ids = np.load('./src/datasets/annotations/coco_restval_ids.npy') - val_ids = np.load('./src/datasets/annotations/coco_dev_ids.npy')[:5000] - te_ids = np.load('./src/datasets/annotations/coco_test_ids.npy') + train_ids = np.load('./src/custom_datasets/annotations/coco_train_ids.npy') + train_extra_ids = np.load('./src/custom_datasets/annotations/coco_restval_ids.npy') + val_ids = np.load('./src/custom_datasets/annotations/coco_dev_ids.npy')[:5000] + te_ids = np.load('./src/custom_datasets/annotations/coco_test_ids.npy') image_root = os.path.join(dataset_root, 'allimages') train_ann = os.path.join(dataset_root, 'annotations/captions_train2014.json') @@ -130,7 +140,8 @@ def _get_coco_loader(image_root, caption_drop_prob=0.0, subset=False, subset_num=50000, - client=-1): + client=-1, + client_subset_num=10000): _image_transform = imagenet_transform( random_resize_crop=train, random_erasing_prob=cutout_prob, @@ -146,23 +157,29 @@ def _get_coco_loader(image_root, target_transform=_caption_transform, client=client) if subset: - if not os.path.exists('coco_subset_idx_file'): - full_idx = [i for i in range(566435)] + full_size = 566435 + subset_num = min(subset_num, full_size) + + subset_fn = f'coco_subset_idx_{subset_num}' + if not os.path.exists(subset_fn): + full_idx = [i for i in range(full_size)] random.shuffle(full_idx) - idx = full_idx[0: 50000] + idx = full_idx[0: subset_num] idx.sort() - if not os.path.exists('coco_subset_idx_file'): - with open('coco_subset_idx_file', 'wb') as f: + if not os.path.exists(subset_fn): + with open(subset_fn, 'wb') as f: pickle.dump(idx, f) - - if subset_num == 50000: - with open('coco_subset_idx_file', 'rb') as f: - idx = pickle.load(f) + + with open(subset_fn, 'rb') as f: + idx = pickle.load(f) coco_dataset = torch.utils.data.Subset(coco_dataset, idx) elif client > -1: - idx = [i for i in range(100000+client*10000, 110000+client*10000)] + size_per_client = 10000 # 10000 is the old hard coded value + size_per_client = min(subset_num,client_subset_num) + range_start = 100000+client*size_per_client + idx = [i for i in range(range_start, range_start + size_per_client)] coco_dataset = torch.utils.data.Subset(coco_dataset, idx) dataloader = DataLoader(coco_dataset, @@ -177,6 +194,40 @@ def _get_coco_loader(image_root, print(f'Loading COCO Caption: n_images {coco_dataset.n_images} n_captions {len(coco_dataset)}...') return dataloader +def vqa2_dataloader(dataset, + num_workers=12, + batch_size=64, + cutout_prob=0.0, + train=False, + filter_unknown=False, + meta:VQAMetaData = None): + transform = imagenet_transform( + random_resize_crop=train, + random_erasing_prob=cutout_prob, + handle_gray=True, + ) + if filter_unknown: + def filter_fn(example): + return meta.get_category_id(example['multiple_choice_answer']) != unknown_category_id + dataset = dataset.filter(filter_fn) + def collate_fn(): + def func(examples): + batch = {} + batch['image'] = torch.stack([transform(example['image']) for example in examples]) + batch['question'] = [example['question'] for example in examples] + batch['question_type'] = [example['question_type'] for example in examples] + batch['question_rest'] = [example['question'][len(example['question_type'])+1:] for example in examples] + batch['multiple_choice_answer'] = [example['multiple_choice_answer'] for example in examples] + batch['answers'] = 
[example['answers'] for example in examples] + return batch + return func + return DataLoader(dataset, + batch_size=batch_size, + shuffle=train, + num_workers=num_workers, + collate_fn=collate_fn(), + ) + def load_vocab(vocab_path): if isinstance(vocab_path, str): @@ -237,6 +288,7 @@ def imagenet_transform(resize_size=256, crop_size=224, random_resize_crop=False, random_erasing_prob=0.0, + handle_gray=False, custom_transforms=None): """Standard ImageNet transform with resize/crop/normalize. @@ -258,6 +310,10 @@ def imagenet_transform(resize_size=256, else: transform.append(transforms.Resize(resize_size)) transform.append(transforms.CenterCrop(crop_size)) + if handle_gray: + transform.append(transforms.Lambda( + lambda img: img.convert("RGB")), # Convert grayscale to RGB + ) transform.append(transforms.ToTensor()) transform.append(imagenet_normalize()) diff --git a/src/utils/serialize_utils.py b/src/utils/serialize_utils.py index 85e8d97..480e451 100644 --- a/src/utils/serialize_utils.py +++ b/src/utils/serialize_utils.py @@ -31,3 +31,4 @@ def torch_safe_load(module, state_dict, strict=True): module.load_state_dict({ k.replace('module.', ''): v for k, v in state_dict.items() }, strict=strict) + diff --git a/src/utils/util.py b/src/utils/util.py index 6f33cd1..46581c4 100644 --- a/src/utils/util.py +++ b/src/utils/util.py @@ -1,5 +1,19 @@ import copy import torch +import torch.nn as nn + +def print_model_tree(model:nn.Module, indent=0): + for name, module in model.named_children(): + # Print the module type and increase indentation + print(' ' * indent + f'{name}: {module.__class__.__name__}') + + # If the module has children, recursively print them + if list(module.children()): + print_model_tree(module, indent + 2) + else: + # If it's a leaf module, print the parameter shapes + for param_name, param in module.named_parameters(): + print(' ' * (indent + 2) + f'{param_name} shape:{param.shape} grad:{param.requires_grad}') def average_weights(w): @@ -21,4 +35,7 @@ def sum_weights(w): for key in w_sum.keys(): for i in range(1, len(w)): w_sum[key] += w[i][key] - return w_sum \ No newline at end of file + return w_sum + + + diff --git a/src/vqa.py b/src/vqa.py new file mode 100644 index 0000000..75344f8 --- /dev/null +++ b/src/vqa.py @@ -0,0 +1,24 @@ +import common + +if __name__ == "__main__": + + from algorithms.MMFL import MMFL + + args, wandb = common.prepare_args( + description="CreamFL Federated Learning for VQA task (local simulation)", + script="vqa", + is_vqa=True) + + Algo = MMFL(args, wandb) + + Algo.create_model(args) # create client models and datasets + Algo.load_dataset(args, is_vqa=True) # global model and dataset + + for round_n in range(args.comm_rounds): + Algo.train(round_n) + + #Algo.logger.log("Best:") + #Algo.engine.report_scores(step=args.comm_rounds, + # scores=Algo.best_scores, + # metadata=Algo.best_metadata, + # prefix=Algo.engine.eval_prefix) diff --git a/src/vqa_exp.py b/src/vqa_exp.py new file mode 100644 index 0000000..effb4f3 --- /dev/null +++ b/src/vqa_exp.py @@ -0,0 +1,413 @@ +# Experiment with VQA algorithms + +import os +import random +import heapq +import pickle +from tqdm import tqdm + +from datasets import load_dataset, load_from_disk +import torch +import torch.nn.functional as F +from torch.utils.data import DataLoader +from torchvision import transforms + +import sys +sys.path.append("./") +sys.path.append("../") +sys.path.append("../../") +sys.path.append("../../../") + +from src.networks.fusion_model import LinearFusionModelCategorical, VQAFusionModel 
+ +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +print(f"device {device}") + +use_f16 = False +if device == torch.device("cuda"): + try: + from apex import amp + #print("enable f16 and using apex.amp for mixed precision training") + #use_f16 = True + except ImportError as e: + print('failed to import apex:', e) + +text_retrieval_cache = {} +@torch.no_grad() +def get_text_features(engine, text): + global text_retrieval_cache + if text in text_retrieval_cache: + return text_retrieval_cache[text] + text_retrieval_cache[text] = engine.text_forward([],text,0)['embedding'] + return text_retrieval_cache[text] + +@torch.no_grad() +def get_matching_text(features, top_k=5): + global text_retrieval_cache + min_heap = [] # Using a min heap to keep track of top matches + for text, text_features in text_retrieval_cache.items(): + match_score = F.cosine_similarity(features, text_features).item() + heapq.heappush(min_heap, (match_score, text)) + if len(min_heap) > top_k: + heapq.heappop(min_heap) # remove the worst score + top_matches = [(score, text) for score, text in sorted(min_heap, reverse=True)] + return top_matches + +unknown_category = "" +unknown_category_id = 0 + +category_list = [] +category_dict = {} +category_counts = [] +@torch.no_grad() +def get_category_id(cat, add_new=False): + global category_list + global category_dict + global category_counts + add_count = add_new # add count only when we are building the list of categories + if len(category_list) == 0: + category_dict[unknown_category] = unknown_category_id + category_list.append(unknown_category) + category_counts.append(0) + if cat in category_dict: + cat_id = category_dict[cat] + if add_count: + category_counts[cat_id] += 1 + return cat_id + if not add_new: + return unknown_category_id + category_dict[cat] = len(category_list) + category_list.append(cat) + category_counts.append(1) + return category_dict[cat] + +def reset_category_list(): + global category_list + global category_dict + global category_counts + category_list = [] + category_dict = {} + category_counts = [] + +@torch.no_grad() +def get_category_by_id(cat_id): + global category_list + return category_list[cat_id] + +def set_category_from_dataset(dataset): + #for item in tqdm(dataset.map(lambda example: {'multiple_choice_answer': example['multiple_choice_answer']})): + # get_category_id(item['multiple_choice_answer']) + dataset = dataset.map(lambda example: {'multiple_choice_answer': example['multiple_choice_answer']}) + dataloader = DataLoader(dataset, batch_size=2048, num_workers=32, + collate_fn=lambda examples: {'multiple_choice_answer': [example['multiple_choice_answer'] for example in examples]}) + set_category_from_dataloader(dataloader) + +def set_category_from_dataloader(dataloader): + for batch in tqdm(dataloader): + for answer in batch['multiple_choice_answer']: + get_category_id(answer, add_new=True) + +def build_or_load_categories_top(top = 3000): + global category_list + global category_dict + global category_counts + if len(category_list) != 0: + raise Exception("categories already loaded") + fn = f"vqa2_categories_{top}.pkl" + if os.path.exists(fn): + with open(fn, "rb") as f: + data = pickle.load(f) + category_list = data['category_list'] + category_counts = data['category_counts'] + category_dict = {cat: i for i, cat in enumerate(category_list)} + return + # extract common categories from train and validation datasets + set_category_from_dataset(load_dataset("HuggingFaceM4/VQAv2", split="train")) + cutoff = heapq.nlargest(top, 
category_counts)[-1] + train_list = category_list + train_counts = category_counts + reset_category_list() + for i, cat in enumerate(train_list): + if train_counts[i] >= cutoff: + id = get_category_id(cat, add_new=True) + category_counts[id] = train_counts[i] + assert len(category_list) == top + 1 # top categories + 1 unknown + with open(fn, "wb") as f: + data = {'category_list': category_list, 'category_counts': category_counts} + pickle.dump(data, f) + + +def build_or_load_categories(): + global category_list + global category_dict + global category_counts + if len(category_list) != 0: + raise Exception("categories already loaded") + fn = "vqa2_categories_common_count.pkl" + if os.path.exists(fn): + with open(fn, "rb") as f: + data = pickle.load(f) + category_list = data['category_list'] + category_counts = data['category_counts'] + category_dict = {cat: i for i, cat in enumerate(category_list)} + return + # extract common categories from train and validation datasets + set_category_from_dataset(load_dataset("HuggingFaceM4/VQAv2", split="train")) + train_list = category_list + train_counts = category_counts + reset_category_list() + set_category_from_dataset(load_dataset("HuggingFaceM4/VQAv2", split="validation")) + validation_dict = category_dict + reset_category_list() + print(f"train categories {len(train_list)}, validation categories {len(validation_dict)}") + unknowns = 0 + for i, cat in enumerate(train_list): + if cat in validation_dict: + cat_id = get_category_id(cat, add_new=True) + category_counts[cat_id] = train_counts[i] + else: + unknowns += train_counts[i] + category_counts[0] = unknowns + print(f"common categories {len(category_list)}") + with open(fn, "wb") as f: + data = {'category_list': category_list, 'category_counts': category_counts} + pickle.dump(data, f) + +transform = transforms.Compose([ + transforms.Resize(224), + transforms.CenterCrop(224), # make all images the same size + transforms.Lambda(lambda img: img.convert("RGB")), # Convert grayscale to RGB + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # Normalizes the image + ]) + +random_transform = transforms.Compose([ + transform, + transforms.RandomErasing(p=0.5, value='random', scale=(0.05, 0.20), ratio=(0.3, 3.3)), + transforms.RandomRotation(10), + transforms.RandomErasing(p=1, scale=(0.05, 0.20), ratio=(0.3, 3.3)), + transforms.RandomResizedCrop(224), + ]) + +def split_image_into_4_parts(image): + width, height = image.size + mid_width, mid_height = width / 2, height / 2 + overlap_width = width * 0.1 / 2 + overlap_height = height * 0.1 / 2 + + top_left = (0, 0, mid_width + overlap_width, mid_height + overlap_height) + top_right = (mid_width - overlap_width, 0, width, mid_height + overlap_height) + bottom_left = (0, mid_height - overlap_height, mid_width + overlap_width, height) + bottom_right = (mid_width - overlap_width, mid_height - overlap_height, width, height) + + return [image.crop(box) for box in [top_left, top_right, bottom_left, bottom_right]] + +def prepare_question(is_train,question): + return question # disable question transformation + # if is_train: + # words = question.split() + # duplicated = words + words + # for i in random.sample(range(len(duplicated)), random.randint(0, 1)): + # duplicated[i] = "" + # return " ".join(" ".join(duplicated).split()) + # return question + " " + question + +def collate_fn(is_train: bool): + t = random_transform if is_train else transform + + def func(examples): + batch = {} + if 'image' in examples[0]: + 
batch['image'] = torch.stack([t(example['image']) for example in examples]) + batch['sub_images'] = torch.stack([ + torch.stack([t(sub_image) + for sub_image in split_image_into_4_parts(example['image'])]) + for example in examples + ]) + batch['question'] = [prepare_question(is_train,example['question']) for example in examples] + batch['question_type'] = [example['question_type'] for example in examples] + batch['question_rest'] = [example['question'][len(example['question_type'])+1:] for example in examples] + batch['multiple_choice_answer'] = [example['multiple_choice_answer'] for example in examples] + return batch + return func + +@torch.no_grad() +def process_retrieval_batch(batch): + # Transform and move the batch of images to the device + images = torch.stack([transform(image) for image in batch['image']]).to(device) + questions = batch['question'] + + # Forward pass with the batch of images and questions + batch_output, _ = retrieval_model.forward(images, [], questions, 0) + batch['image_features'] = batch_output['image_features'] + batch['caption_features'] = batch_output['caption_features'] + + # Remove the 'image' column from the batch + del batch['image'] + + return batch + +def validation(n, fusion_model, validation_dataloader): + right = 0 + unknown_outputs = 0 + unknown_answers = 0 + unknown_unknown = 0 + total = 0 + for j, testBatch in tqdm(enumerate(validation_dataloader)): + answers = testBatch['multiple_choice_answer'] + outputs, _ = fusion_model.forward(testBatch) + for k, answer in enumerate(answers): + answer_id = get_category_id(answer) + #top_matches = get_matching_text(outputs[k], top_k=5) + #if answer == top_matches[0][1]: + # right += 1 + _, top_matches = torch.topk(outputs[k], 5, largest=True, sorted=True) + top_match_names = [get_category_by_id(cat_id.item()) for cat_id in top_matches] + if top_match_names[0] == answer: + right += 1 + if answer_id == unknown_category_id: + unknown_answers += 1 + answer = unknown_category + answer # mark answers not in the training set + if top_matches[0] == unknown_category_id: + unknown_outputs += 1 + if answer_id == unknown_category_id and top_matches[0] == unknown_category_id: + unknown_unknown += 1 + if total + k < 8: + tqdm.write(f"j {j}, k {k}, expected {answer}, got {top_match_names}") + total += len(answers) + if total >= n: + break + accuracy = right / total + tqdm.write(f"test accuracy {right}/{total}={accuracy}, unknown_answers:{unknown_answers}, unknown_outputs:{unknown_outputs}, unknown_unknown:{unknown_unknown}") + +if __name__ == "__main__": + #with torch.autocast(device_type=device.type): + import common + args, wandb = common.prepare_args( + description="VQA for CreamFL Federated Learning (local simulation)", + is_vqa=True) + + base_path = args.vqa_pretrained_base_model + print(f"loading pretrained img txt model from {base_path}. 
Full path {os.path.abspath(base_path)}") + from src.algorithms.retrieval_trainer import TrainerEngine + retrieval_engine = TrainerEngine() + print(f" load COCOEvaluator") + from src.algorithms.eval_coco import COCOEvaluator + evaluator = COCOEvaluator(eval_method='matmul', + verbose=True, + eval_device='cuda', + n_crossfolds=5) + print(f" load models2") + retrieval_engine.load_models2("./sl2-best_model.pt", evaluator) + retrieval_engine.model_to_device() + #if use_f16: + # retrieval_engine.to_half() + retrieval_model = retrieval_engine.model + + if args.vqa_pretrained_eval: + print(f"loading coco test set") + dataset_root = os.environ['HOME'] + '/data/mmdata/MSCOCO/2014' + vocab_path = './src/custom_datasets/vocabs/coco_vocab.pkl' + from src.utils.config import parse_config + config = parse_config("./src/coco.yaml", strict_cast=False) + from src.utils.load_datasets import prepare_coco_dataloaders + dataloaders, vocab = prepare_coco_dataloaders( + config.dataloader, dataset_root, args.pub_data_num, args.max_size, vocab_path + ) + test_dataloader = dataloaders['test'] + print(f"evaluate coco test set") + test_scores = retrieval_engine.evaluate({'test': test_dataloader}) + print(f"test scores {test_scores}") + + num_workers = 16 + if use_f16: + num_workers = 16 # f16 requires more workers to keep the GPU busy + + print(f"loading vqa2 categories and category weights") + build_or_load_categories() + print(f" category_list size:{len(category_list)}") + print(f" category_list:{category_list[:10]}") + print(f" category_count:{category_counts[:10]}") + + print(f"loading vqa2 dataset") + vqa2_train = load_dataset("HuggingFaceM4/VQAv2", split="train") + # precalculate the forward pass on the base retrieval model + # vqa2_train = vqa2_train.map( + # process_retrieval_batch, + # batched=True, batch_size=32, + # ) + + vqa2_dataloader = DataLoader(vqa2_train, batch_size=args.batch_size, shuffle=True, collate_fn=collate_fn(random_transform), num_workers=num_workers) + + vqa2_test = load_dataset("HuggingFaceM4/VQAv2", split="validation[:10000]") + test_batch_size = 100 + if args.batch_size < test_batch_size: + test_batch_size = 10 + vqa2_test_dataloader = DataLoader(vqa2_test, batch_size=test_batch_size, collate_fn=collate_fn(transform), num_workers=num_workers) + + print(f'init vqa fusion model "{args.vqa_fusion_network}"') + fusion_model = None + if args.vqa_fusion_network == "linear": + fusion_model = LinearFusionModelCategorical(retrieval_model,2+4+2, len(category_list), args.vqa_hidden_sizes, args.vqa_input_type,dropout_rate=args.vqa_dropout).to(device) + elif args.vqa_fusion_network == "vqa1": + fusion_model = VQAFusionModel(device, retrieval_model,5,3, len(category_list), args.vqa_hidden_sizes, dropout_rate=args.vqa_dropout).to(device) + else: + print(f'vqa_fusion_network "{args.vqa_fusion_network}" is not supported') + exit(1) + + total_count = sum(category_counts) + epsilon = 1000 # 1e-8 # Small value to prevent division by zero + total_count = total_count + epsilon * len(category_counts) + category_weights = [total_count / (class_count + epsilon) for class_count in category_counts] + weights_tensor = torch.tensor(category_weights).to(device) + loss_function = torch.nn.CrossEntropyLoss(weight=weights_tensor) + + optimizer = torch.optim.Adam(fusion_model.parameters(), lr=args.vqa_lr, weight_decay=args.vqa_weight_decay) + + if use_f16: + fusion_model, optimizer = amp.initialize(fusion_model, optimizer, opt_level="O2") + + n = 0 + + loss_avg = 0 + + use_embed_loss = False + + for epoch in 
range(1,args.vqa_epochs+1): + print(f"epoch {epoch}") + if epoch >= args.vqa_unfreeze_base_epoch and fusion_model.frozen_base_model: + print(f"unfreeze base model") + fusion_model.unfreeze_base_model() + n = 0 + with tqdm(enumerate(vqa2_dataloader), total=len(vqa2_dataloader)) as progress_bar: + for i, batch in progress_bar: + optimizer.zero_grad() + outputs, last_features = fusion_model.forward(batch) + answers = batch['multiple_choice_answer'] + targets = torch.tensor([get_category_id(answer) for answer in answers]).to(device) + loss = loss_function(outputs, targets) + + if last_features.shape[1] == retrieval_model.embed_dim: + if not use_embed_loss: + progress_bar.write("using embedding loss") # only print once + use_embed_loss = True + target_last_features = torch.stack([get_text_features(retrieval_model, answer) for answer in answers], dim=0).to(device) + loss += 1 - F.cosine_similarity(last_features, target_last_features).mean() + + if use_f16: + with amp.scale_loss(loss, optimizer) as scaled_loss: + scaled_loss.backward() + else: + loss.backward() + loss_avg_rate = max(i, 99) + loss_avg = (loss_avg * loss_avg_rate + loss.item()) / (loss_avg_rate + 1) + optimizer.step() + progress_bar.set_description(f"Epoch {epoch}, Iter {i}, l100: {loss_avg:.4f}") + + if epoch == 1 and (i+1+(epoch-1)*len(vqa2_dataloader)) % (128*2**n) == 0: + validation(1000, fusion_model, vqa2_test_dataloader) + n += 1 + validation(10000, fusion_model, vqa2_test_dataloader) + + + \ No newline at end of file
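
Note on the loss weighting used in `src/vqa_exp.py` above: the answer categories are weighted by smoothed inverse frequency before being passed to `CrossEntropyLoss`. Below is a minimal, self-contained sketch of that scheme; the counts and `epsilon` value are illustrative placeholders, not values taken from this patch.

```python
import torch

# Illustrative per-category answer counts (placeholders, not real VQAv2 statistics).
category_counts = [120000, 4000, 250, 30]
epsilon = 1000  # smoothing term so rare categories do not get extreme weights

# Smoothed inverse-frequency weights, mirroring vqa_exp.py:
# weight_c = (sum(counts) + epsilon * C) / (count_c + epsilon)
total = sum(category_counts) + epsilon * len(category_counts)
weights = torch.tensor([total / (c + epsilon) for c in category_counts])

loss_fn = torch.nn.CrossEntropyLoss(weight=weights)

# Dummy batch: 8 samples over 4 answer categories.
logits = torch.randn(8, len(category_counts))
targets = torch.randint(0, len(category_counts), (8,))
print(loss_fn(logits, targets).item())
```

Rare answers receive proportionally larger weights, which counteracts the heavy class imbalance of VQA answer distributions, while the `epsilon` term keeps the weights of near-empty categories bounded.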