Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
163 changes: 79 additions & 84 deletions web/src/components/code.astro
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def main():

summary_writer.add_scalar(
tag="loss_for_each_batch",
scalar=running_loss,
scalar=running_loss,
global_step=global_step
)
running_loss = 0.0
Expand Down Expand Up @@ -175,70 +175,65 @@ def main():
if __name__ == "__main__":
main()`;

const serverCode_pt = `
from nvflare.app_common.workflows.base_fedavg import BaseFedAvg

class FedAvg(BaseFedAvg):
def run(self) -> None:
self.info("Start FedAvg.")

model = self.load_model()
model.start_round = self.start_round
model.total_rounds = self.num_rounds

for self.current_round in range(self.start_round, self.start_round + self.num_rounds):
self.info(f"Round {self.current_round} started.")
model.current_round = self.current_round

clients = self.sample_clients(self.num_clients)

results = self.send_model_and_wait(targets=clients, data=model)

aggregate_results = self.aggregate(results)
const modelCode_pt = `
import torch
import torch.nn as nn
import torch.nn.functional as F

model = self.update_model(model, aggregate_results)

self.save_model(model)
class SimpleNetwork(nn.Module):
def __init__(self):
super(SimpleNetwork, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(6, 16, 5)
self.fc1 = nn.Linear(16 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)

self.info("Finished FedAvg.")
def forward(self, x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = torch.flatten(x, 1)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
`;

const jobCode_pt = `
from cifar10_pt_fl import Net
from model import SimpleNetwork

from nvflare.app_common.workflows.fedavg import FedAvg
from nvflare.app_opt.pt.job_config.base_fed_job import BaseFedJob
from nvflare.job_config.script_runner import ScriptRunner
from nvflare.app_opt.pt.recipes.fedavg import FedAvgRecipe
from nvflare.recipe import SimEnv, add_experiment_tracking

if __name__ == "__main__":

def main():
n_clients = 2
num_rounds = 2
train_script = "cifar10_pt_fl.py"

# Create BaseFedJob with model
job = BaseFedJob(
name="cifar10_pt_fedavg",
model=Net(),
)

# Define the controller and send to server
controller = FedAvg(
num_clients=n_clients,
recipe = FedAvgRecipe(
name="hello-pt",
min_clients=n_clients,
num_rounds=num_rounds,
model=SimpleNetwork(),
train_script="client.py",
train_args="--batch_size 16",
)
job.to_server(controller)
add_experiment_tracking(recipe, tracking_type="tensorboard")

# Add clients
for i in range(n_clients):
runner = ScriptRunner(script=train_script)
job.to(runner, f"site-{i}")
env = SimEnv(num_clients=n_clients)
run = recipe.execute(env)
print("Job Status is:", run.get_status())
print("Result can be found in:", run.get_result())

# job.export_job("/tmp/nvflare/jobs/job_config")
job.simulator_run("/tmp/nvflare/jobs/workdir", gpu="0")

if __name__ == "__main__":
main()
`;

const runCode_pt = `
python3 fedavg_cifar10_pt_job.py
python job.py
`;

// Lightning Code Sections --------------------------------------------------
Expand Down Expand Up @@ -530,7 +525,7 @@ def main():
# (5) evaluate aggregated/received model
_, test_global_acc = model.evaluate(test_images, test_labels, verbose=2)
print(
f"Accuracy of the received model on round {input_model.current_round} on the 10000 test images:
f"Accuracy of the received model on round {input_model.current_round} on the 10000 test images:
{test_global_acc * 100} %"
)

Expand Down Expand Up @@ -628,8 +623,8 @@ python3 fedavg_cifar10_tf_job.py
const frameworks = [
{
id: "pytorch",
colab_link: `https://colab.research.google.com/github/NVIDIA/NVFlare/blob/${gh_branch}/examples/getting_started/pt/nvflare_pt_getting_started.ipynb`,
github_link: `https://github.com/NVIDIA/NVFlare/blob/${gh_branch}/examples/getting_started/pt/nvflare_pt_getting_started.ipynb`,
colab_link: `https://colab.research.google.com/github/NVIDIA/NVFlare/blob/${gh_branch}/examples/hello-world/hello-pt/hello-pt.ipynb`,
github_link: `https://github.com/NVIDIA/NVFlare/blob/${gh_branch}/examples/hello-world/hello-pt/hello-pt.ipynb`,
sections: [
{
id: "install-pytorch",
Expand All @@ -644,29 +639,28 @@ const frameworks = [
id: "client-pytorch",
type: "client",
framework: "pytorch",
title: "Client Code (cifar10_pt_fl.py)",
title: "Client Code (client.py)",
description:
"We use the Client API to convert the centralized training PyTorch code into federated learning code with only a few lines of changes highlighted below. Essentially the client will receive a model from NVIDIA FLARE, perform local training and validation, and then send the model back.",
"Use the Client API to convert your training script into federated learning code. The client receives the global model from FLARE, performs local training and validation, and sends the updated model back.",
code: clientCode_pt,
highlighted_lines: "29,58,61,63,139-143,145",
},
{
id: "server-pytorch",
type: "server",
id: "model-pytorch",
type: "model",
framework: "pytorch",
title: "Server Code (fedavg.py)",
title: "Model (model.py)",
description:
"The ModelController API is used to write a federated routine with mechanisms for distributing and receiving models from clients. Here we implement the basic FedAvg algorithm using some helper functions from BaseFedAvg.",
code: serverCode_pt,
highlighted_lines: "7,17,23",
"Model definition lives in model.py and is referenced by both the client and the job recipe.",
code: modelCode_pt,
},
{
id: "job-pytorch",
type: "job",
framework: "pytorch",
title: "Job Code (fedavg_cifar10_pt_job.py)",
title: "Job (job.py)",
description:
"Lastly we construct the job with our 'cifar10_pt_fl.py' client script and 'FedAvg' server controller. The BaseFedJob automatically configures components for model persistence, model selection, and TensorBoard streaming. We then run the job with the FL simulator.",
"The Recipe API defines the FL job in Python: FedAvgRecipe with model, client script, and options. No separate server file — run with the simulator via recipe.execute(SimEnv(...)).",
code: jobCode_pt,
},
{
Expand All @@ -675,15 +669,15 @@ const frameworks = [
framework: "pytorch",
title: "Run the Job",
description:
"To run the job with the simulator, copy the code and execute the job script, or run in Google Colab. Alternatively, export the job to a configuration and run the job in other modes.",
"From the example directory, run: python job.py. Or open the notebook in Google Colab.",
code: runCode_pt,
},
],
},
{
id: "lightning",
colab_link: `https://colab.research.google.com/github/NVIDIA/NVFlare/blob/${gh_branch}/examples/getting_started/pt/nvflare_lightning_getting_started.ipynb`,
github_link: `https://github.com/NVIDIA/NVFlare/blob/${gh_branch}/examples/getting_started/pt/nvflare_lightning_getting_started.ipynb`,
colab_link: `https://colab.research.google.com/github/NVIDIA/NVFlare/blob/${gh_branch}/examples/hello-world/hello-lightning/hello_lightning.ipynb`,
github_link: `https://github.com/NVIDIA/NVFlare/blob/${gh_branch}/examples/hello-world/hello-lightning/hello_lightning.ipynb`,
sections: [
{
id: "install-lightning",
Expand Down Expand Up @@ -736,8 +730,8 @@ const frameworks = [
},
{
id: "tensorflow",
colab_link: `https://colab.research.google.com/github/NVIDIA/NVFlare/blob/${gh_branch}/examples/getting_started/tf/nvflare_tf_getting_started.ipynb`,
github_link: `https://github.com/NVIDIA/NVFlare/blob/${gh_branch}/examples/getting_started/tf/nvflare_tf_getting_started.ipynb`,
colab_link: `https://colab.research.google.com/github/NVIDIA/NVFlare/blob/${gh_branch}/examples/hello-world/hello-pt/hello-pt.ipynb`,
github_link: `https://github.com/NVIDIA/NVFlare/tree/${gh_branch}/examples/hello-world/hello-tf`,
sections: [
{
id: "install-tensorflow",
Expand Down Expand Up @@ -842,7 +836,7 @@ const frameworks = [
</div>

<!-- Google Colab Button -->
<a id="colab-button" href=`https://colab.research.google.com/github/NVIDIA/NVFlare/blob/${gh_branch}/examples/getting_started/pt/nvflare_pt_getting_started.ipynb` target="_blank" rel="noopener noreferrer" class="text-xs font-semibold text-gray-900 m-0.5 hover:bg-gray-100 rounded-lg py-2 px-2.5 inline-flex items-center justify-center bg-white border-gray-200 border">
<a id="colab-button" href=`https://colab.research.google.com/github/NVIDIA/NVFlare/blob/${gh_branch}/examples/hello-world/hello-pt/hello-pt.ipynb` target="_blank" rel="noopener noreferrer" class="text-xs font-semibold text-gray-900 m-0.5 hover:bg-gray-100 rounded-lg py-2 px-2.5 inline-flex items-center justify-center bg-white border-gray-200 border">
<span id="default-message" class="inline-flex items-center">
<svg class="w-0 h-3" aria-hidden="true" xmlns="http://www.w3.org/2000/svg" fill="currentColor" viewBox="0 0 18 20">
<img class="w-6 me-1.5" src={GoogleColab.src} alt="NVIDIA logo">
Expand All @@ -852,7 +846,7 @@ const frameworks = [
</a>

<!-- View on Github Button -->
<a id="github-button" href=`https://github.com/NVIDIA/NVFlare/blob/${gh_branch}/examples/getting_started/pt/nvflare_pt_getting_started.ipynb` target="_blank" rel="noopener noreferrer" class="text-xs font-semibold text-gray-900 m-0.5 hover:bg-gray-100 rounded-lg py-2 px-2.5 inline-flex items-center justify-center bg-white border-gray-200 border">
<a id="github-button" href=`https://github.com/NVIDIA/NVFlare/blob/${gh_branch}/examples/hello-world/hello-pt/hello-pt.ipynb` target="_blank" rel="noopener noreferrer" class="text-xs font-semibold text-gray-900 m-0.5 hover:bg-gray-100 rounded-lg py-2 px-2.5 inline-flex items-center justify-center bg-white border-gray-200 border">
<span id="default-message" class="inline-flex items-center">
<svg class="w-4 h-4 text-black mr-2" aria-hidden="true" xmlns="http://www.w3.org/2000/svg" fill="currentColor" viewBox="0 0 24 24">
<path fill-rule="evenodd" d="M12.006 2a9.847 9.847 0 0 0-6.484 2.44 10.32 10.32 0 0 0-3.393 6.17 10.48 10.48 0 0 0 1.317 6.955 10.045 10.045 0 0 0 5.4 4.418c.504.095.683-.223.683-.494 0-.245-.01-1.052-.014-1.908-2.78.62-3.366-1.21-3.366-1.21a2.711 2.711 0 0 0-1.11-1.5c-.907-.637.07-.621.07-.621.317.044.62.163.885.346.266.183.487.426.647.71.135.253.318.476.538.655a2.079 2.079 0 0 0 2.37.196c.045-.52.27-1.006.635-1.37-2.219-.259-4.554-1.138-4.554-5.07a4.022 4.022 0 0 1 1.031-2.75 3.77 3.77 0 0 1 .096-2.713s.839-.275 2.749 1.05a9.26 9.26 0 0 1 5.004 0c1.906-1.325 2.74-1.05 2.74-1.05.37.858.406 1.828.101 2.713a4.017 4.017 0 0 1 1.029 2.75c0 3.939-2.339 4.805-4.564 5.058a2.471 2.471 0 0 1 .679 1.897c0 1.372-.012 2.477-.012 2.814 0 .272.18.592.687.492a10.05 10.05 0 0 0 5.388-4.421 10.473 10.473 0 0 0 1.313-6.948 10.32 10.32 0 0 0-3.39-6.165A9.847 9.847 0 0 0 12.007 2Z" clip-rule="evenodd"/>
Expand All @@ -866,7 +860,7 @@ const frameworks = [
<div id="install-wrapper" class="mx-auto max-w-5xl py-4 text-left"></div>

<div class="mx-auto max-w-[1500px] py-4 text-left flex flex-col md:flex-row">
<!-- Client, Server, Job Tabs -->
<!-- Client, Model, Job Tabs -->
<ul id="code-tabs" class="flex md:flex-col w-full flex-wrap md:w-auto h-fit md:border-r border-b md:border-b-0 space-y-0 md:space-y-2 space-x-2 md:space-x-0 text-xl font-medium md:mr-14 ms-1" id="default-styled-tab" data-tabs-toggle="#default-styled-tab-content" data-tabs-active-classes="stroke-nvidia text-nvidia hover:text-nvidia border-b-2 md:border-nvidia border-nvidia" data-tabs-inactive-classes="stroke-gray-500 hover:stroke-gray-600 text-gray-500 hover:text-gray-600 border-transparent hover:border-gray-300" role="tablist">
<li class="flex-1" role="presentation">
<button class="w-36 md:w-36 inline-block p-5 border-r-0 md:border-r-2 md:border-b-0 font-bold text-left" id="client-tab" data-tabs-target="#client-wrapper" type="button" role="tab" aria-controls="client" aria-selected="false">
Expand All @@ -879,12 +873,12 @@ const frameworks = [
</button>
</li>
<li class="flex-1" role="presentation">
<button class="w-36 md:w-36 inline-block p-5 border-r-0 md:border-r-2 md:border-b-0 font-bold text-left" id="server-tab" data-tabs-target="#server-wrapper" type="button" role="tab" aria-controls="server" aria-selected="false">
<button class="w-36 md:w-36 inline-block p-5 border-r-0 md:border-r-2 md:border-b-0 font-bold text-left" id="model-tab" data-tabs-target="#model-wrapper" type="button" role="tab" aria-controls="model" aria-selected="false">
<span id="default-message" class="inline-flex items-center py-2.5">
<svg class="w-8 h-8 text-gray-800 mr-2" aria-hidden="true" xmlns="http://www.w3.org/2000/svg" width="24" height="24" fill="none" viewBox="0 0 24 24">
<path stroke="" stroke-linecap="round" stroke-linejoin="round" stroke-width="2" d="M5 12a1 1 0 0 0-1 1v4a1 1 0 0 0 1 1h14a1 1 0 0 0 1-1v-4a1 1 0 0 0-1-1M5 12h14M5 12a1 1 0 0 1-1-1V7a1 1 0 0 1 1-1h14a1 1 0 0 1 1 1v4a1 1 0 0 1-1 1m-2 3h.01M14 15h.01M17 9h.01M14 9h.01"/>
</svg>
Server
Model
</span>
</button>
</li>
Expand All @@ -899,11 +893,11 @@ const frameworks = [
</button>
</li>
</ul>

<!-- Client wrapper -->
<div id="client-wrapper" class="py-4 text-left overflow-x-auto"></div>
<!-- Server wrapper -->
<div id="server-wrapper" class="py-4 text-left overflow-x-auto hidden"></div>
<!-- Model wrapper (Recipe: model.py; legacy server content for Lightning/TF) -->
<div id="model-wrapper" class="py-4 text-left overflow-x-auto hidden"></div>
<!-- Job wrapper -->
<div id="job-wrapper" class="py-4 text-left overflow-x-auto hidden"></div>

Expand All @@ -927,33 +921,35 @@ const frameworks = [
const googleColab = document.getElementById('colab-button');
const githubButton = document.getElementById('github-button');

const codeTabs = document.getElementById("code-tabs");
const codeTabs = document.getElementById("code-tabs");

var sectionMap = {
"install": {
wrapper: document.getElementById('install-wrapper'),
wrapper: document.getElementById('install-wrapper'),
elements: [],
},
"client": {
wrapper: document.getElementById('client-wrapper'),
wrapper: document.getElementById('client-wrapper'),
elements: [],
},
"model": {
wrapper: document.getElementById('model-wrapper'),
elements: [],
},
"server": {
wrapper: document.getElementById('server-wrapper'),
wrapper: document.getElementById('model-wrapper'),
elements: [],
},
"job": {
wrapper: document.getElementById('job-wrapper'),
wrapper: document.getElementById('job-wrapper'),
elements: [],
},
"run": {
wrapper: document.getElementById('run-wrapper'),
wrapper: document.getElementById('run-wrapper'),
elements: [],
}
};

console.log(sectionMap);

// Loop over the code sections and create the code elements
frameworks.forEach((framework) => {
framework.sections.forEach((code_section) => {
Expand All @@ -962,7 +958,7 @@ const frameworks = [
sectionMap[code_section.type].elements.push(
{
section: code_section,
code: codeElement,
code: codeElement,
framework: code_section.framework
}
);
Expand All @@ -986,7 +982,6 @@ const frameworks = [
code_height = "h-[700px]";
if (code_section.hasOwnProperty("highlighted_lines")) {
highlighted_lines = "data-line=\"" + code_section.highlighted_lines + "\" "
console.log(code_section.title, code_section.type, code_section.hasOwnProperty("highlighted_lines"), highlighted_lines)
}
}

Expand Down Expand Up @@ -1043,7 +1038,7 @@ const frameworks = [
googleColab?.setAttribute("href", framework.colab_link);
githubButton?.setAttribute("href", framework.github_link);
}
});
});
}

</script>
20 changes: 10 additions & 10 deletions web/src/components/gettingStarted.astro
Original file line number Diff line number Diff line change
Expand Up @@ -17,31 +17,31 @@ const walkthrough = [
{
id: "step2",
step: "Step 2",
title: "Server Code",
title: "Job (Recipe)",
description:
"Use the ModelController API to write a federated control flow for FedAvg.",
"Use the Recipe API to define the FL job in Python: job.py with FedAvgRecipe, model, and client script — no separate server file.",
button_text: "View Source",
link: `https://github.com/NVIDIA/NVFlare/blob/${gh_branch}/nvflare/app_common/workflows/fedavg.py`,
link: `https://github.com/NVIDIA/NVFlare/blob/${gh_branch}/examples/hello-world/hello-pt/job.py`,
video: "https://developer.download.nvidia.com/assets/Clara/flare/Flare%202.5.0%20Getting%20Started%20-%20Part%202%20-%20Server.mp4",
},
{
id: "step3",
step: "Step 3",
title: "Client Code",
description:
"Use the Client API to write local training code for a PyTorch CIFAR-10 trainer.",
"Use the Client API to convert your training script into federated learning code (client.py).",
button_text: "View Source",
link: `https://github.com/NVIDIA/NVFlare/blob/${gh_branch}/examples/getting_started/pt/src/cifar10_fl.py`,
link: `https://github.com/NVIDIA/NVFlare/blob/${gh_branch}/examples/hello-world/hello-pt/client.py`,
video: "https://developer.download.nvidia.com/assets/Clara/flare/Flare%202.5.0%20Getting%20Started%20-%20Part%203%20-%20Client.mp4",
},
{
id: "step4",
step: "Step 4",
title: "FedJob and Simulator",
title: "Model & Run",
description:
"Formulate the NVIDIA FLARE job and simulate a federated run with the multi-process simulator.",
button_text: "View Notebook",
link: `https://colab.research.google.com/github/NVIDIA/NVFlare/blob/${gh_branch}/examples/getting_started/pt/nvflare_pt_getting_started.ipynb`,
"Model lives in model.py. Run the job with the simulator: python job.py, or open the hello-pt notebook in Colab.",
button_text: "View Example",
link: `https://github.com/NVIDIA/NVFlare/tree/${gh_branch}/examples/hello-world/hello-pt`,
video: "https://developer.download.nvidia.com/assets/Clara/flare/Flare%202.5.0%20Getting%20Started%20-%20Part%204%20-%20Job.mp4",
},
{
Expand Down Expand Up @@ -197,7 +197,7 @@ const series = [
</div>
))
}

</div>
</div>

Expand Down
Loading
Loading