-
Notifications
You must be signed in to change notification settings - Fork 244
Expand file tree
/
Copy pathjob.py
More file actions
77 lines (67 loc) · 2.88 KB
/
job.py
File metadata and controls
77 lines (67 loc) · 2.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This example shows how to use an NVIDIA FLARE Job Recipe to connect a federated learning client
training script with a server-side aggregation algorithm, and run the resulting job in different
environments.
"""
import argparse
from nvflare.app_common.np.recipes.fedavg import NumpyFedAvgRecipe
from nvflare.client.config import TransferType
from nvflare.recipe import SimEnv, add_experiment_tracking
def _positive_int(value):
    """Argparse ``type`` callable that accepts only integers >= 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(f"must be a positive integer, got {number}")
    return number


def define_parser(argv=None):
    """Define and parse the command-line arguments for this example.

    Despite its name, this function both builds the parser and parses the
    arguments, returning the resulting namespace.

    Args:
        argv: Optional list of argument strings to parse. When ``None`` (the
            default), argparse falls back to ``sys.argv[1:]``, which matches
            the original script behavior.

    Returns:
        argparse.Namespace with the attributes ``n_clients``, ``num_rounds``,
        ``update_type``, ``launch_process``, ``export_config`` and ``log_config``.

    Invalid arguments (e.g. a non-positive client or round count, or an unknown
    update type) make argparse print an error and exit with status 2.
    """
    parser = argparse.ArgumentParser()
    # Client and round counts must be >= 1; zero or negative values are rejected
    # up front instead of being passed on to the recipe / simulator.
    parser.add_argument(
        "--n_clients",
        type=_positive_int,
        default=2,
        help="Number of simulated clients; also used as the minimum number of clients.",
    )
    parser.add_argument(
        "--num_rounds",
        type=_positive_int,
        default=3,
        help="Number of federated learning rounds.",
    )
    parser.add_argument(
        "--update_type",
        type=str,
        default="full",
        choices=["full", "diff"],
        help="Parameter transfer type: 'full' model parameters or parameter 'diff'.",
    )
    parser.add_argument(
        "--launch_process",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Launch the client training script in an external process.",
    )
    parser.add_argument(
        "--export_config",
        action=argparse.BooleanOptionalAction,
        default=False,
        help="Export the job configuration instead of running it in the simulator.",
    )
    parser.add_argument(
        "--log_config",
        type=str,
        default=None,
        help="Log config mode ('concise', 'full', 'verbose'), filepath to a log config json file, or level (info, debug, error, etc.)",
    )
    return parser.parse_args(argv)
def main():
    """Build the hello-numpy FedAvg job recipe, then either export it or run it.

    Options come from ``define_parser``. With ``--export_config`` the recipe is
    exported to a fixed directory under ``/tmp`` and nothing is executed;
    otherwise the recipe is executed in a ``SimEnv`` and the run's result and
    status are printed.
    """
    opts = define_parser()
    client_count = opts.n_clients

    # Forward the update type to client.py through train_args, and pick the
    # parameter transfer type that matches it.
    script_args = f"--update_type {opts.update_type}"
    if opts.update_type == "full":
        transfer_type = TransferType.FULL
    else:
        transfer_type = TransferType.DIFF

    recipe = NumpyFedAvgRecipe(
        name="hello-numpy",
        min_clients=client_count,
        num_rounds=opts.num_rounds,
        # The initial model may be an array, or None when initial_ckpt is used instead.
        model=[[1, 2, 3], [4, 5, 6], [7, 8, 9]],
        # To start from pre-trained weights: initial_ckpt="/server/path/to/model.npy",
        train_script="client.py",
        train_args=script_args,
        launch_external_process=opts.launch_process,
        params_transfer_type=transfer_type,
    )
    add_experiment_tracking(recipe, tracking_type="tensorboard")

    if opts.export_config:
        # Export only; the job is not executed in this branch.
        export_dir = "/tmp/nvflare/jobs/job_config"
        recipe.export(export_dir)
        print(f"Job config exported to {export_dir}")
        return

    sim_env = SimEnv(num_clients=client_count, log_config=opts.log_config)
    job_run = recipe.execute(sim_env)
    print()
    print("Result can be found in :", job_run.get_result())
    print("Job Status is:", job_run.get_status())
    print()


# Run the example only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()