42 changes: 42 additions & 0 deletions examples/quickstart-tensorflow-timeseries/pyproject.toml
@@ -0,0 +1,42 @@
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
name = "tftsexample"
version = "0.0.0"
description = "Federated Learning on time series with Tensorflow/Keras and Flower (Quickstart Example)"
license = "Apache-2.0"
dependencies = [
"flwr[simulation]>=1.24.0",
"flwr-datasets[vision]>=0.5.0",
"tensorflow>=2.9.1, != 2.11.1 ; (platform_machine == \"x86_64\" or platform_machine == \"aarch64\")",
"tensorflow-macos>=2.9.1, != 2.11.1 ; sys_platform == \"darwin\" and platform_machine == \"arm64\"",
]
[tool.hatch.build.targets.wheel]
packages = ["."]

[tool.flwr.app]
publisher = "flwrlabs"

[tool.flwr.app.components]
serverapp = "tftsexample.server_app:app"
clientapp = "tftsexample.client_app:app"

[tool.flwr.app.config]
num-server-rounds = 3
local-epochs = 1
batch-size = 32
learning-rate = 0.005
fraction-train = 0.5
verbose = false

[tool.flwr.federations]
default = "local-simulation"

[tool.flwr.federations.local-simulation]
options.num-supernodes = 10

[tool.flwr.federations.local-deployment]
address = "127.0.0.1:9093"
root-certificates = "./.cache/certificates/ca.crt"
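Since the `serverapp`/`clientapp` references above must point at the actual package directory (`tftsexample/`), a quick sanity check can confirm they resolve. A sketch, assuming the package is installed in the current environment (e.g. `pip install -e .`):

```python
from importlib import import_module

# Resolve each "module:attribute" component reference from pyproject.toml
for ref in ("tftsexample.server_app:app", "tftsexample.client_app:app"):
    module_name, attr = ref.split(":")
    obj = getattr(import_module(module_name), attr)
    print(f"{ref} -> {type(obj).__name__}")  # ServerApp / ClientApp
```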
1 change: 1 addition & 0 deletions examples/quickstart-tensorflow-timeseries/tftsexample/__init__.py
@@ -0,0 +1 @@
"""tfexample."""
82 changes: 82 additions & 0 deletions examples/quickstart-tensorflow-timeseries/tftsexample/client_app.py
@@ -0,0 +1,82 @@
"""timeseries: A Flower / TensorFlow app."""

from flwr.app import ArrayRecord, Context, Message, MetricRecord, RecordDict
from flwr.clientapp import ClientApp

from tftsexample.task import load_data, load_model

# Flower ClientApp
app = ClientApp()


@app.train()
def train(msg: Message, context: Context):
"""Train the model on local data."""

# Load the model and initialize it with the received weights
    model = load_model(context.run_config["learning-rate"])
ndarrays = msg.content["arrays"].to_numpy_ndarrays()
model.set_weights(ndarrays)

# Read from config
epochs = context.run_config["local-epochs"]
batch_size = context.run_config["batch-size"]
verbose = context.run_config.get("verbose")

# Load the data
partition_id = context.node_config["partition-id"]
num_partitions = context.node_config["num-partitions"]
tf_train, _ = load_data(partition_id, num_partitions, batch_size)

# Train the model on local data
history = model.fit(
tf_train,
epochs=epochs,
batch_size=batch_size,
verbose=verbose,
)

    # Get the final training loss; the model is a regressor compiled with
    # MSE, so there is no accuracy metric to report
    train_loss = history.history["loss"][-1]

    # Construct and return reply Message
    model_record = ArrayRecord(model.get_weights())
    metrics = {
        "train_loss": train_loss,
        # Assumes the default batch size of 32: 800 of the ~1000 batches
        # per partition are used for training (see load_data in task.py)
        "num-examples": int(32000 * 0.8),
    }
    metric_record = MetricRecord(metrics)
content = RecordDict({"arrays": model_record, "metrics": metric_record})
return Message(content=content, reply_to=msg)


@app.evaluate()
def evaluate(msg: Message, context: Context):
"""Evaluate the model on local data."""

# Load the model and initialize it with the received weights
model = load_model()
ndarrays = msg.content["arrays"].to_numpy_ndarrays()
model.set_weights(ndarrays)

# Load the data
partition_id = context.node_config["partition-id"]
batch_size = context.run_config["batch-size"]
num_partitions = context.node_config["num-partitions"]
_, tf_test = load_data(partition_id, num_partitions, batch_size)

    # Evaluate the model on local data (loss and metric are both MSE here)
    loss, mse = model.evaluate(tf_test, verbose=0)

    # Construct and return reply Message
    metrics = {
        "eval_loss": loss,
        "eval_mse": mse,
        # Assumes the default batch size of 32: 200 test batches
        "num-examples": int(32000 * 0.2),
    }
metric_record = MetricRecord(metrics)
content = RecordDict({"metrics": metric_record})
return Message(content=content, reply_to=msg)
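The hardcoded `num-examples` values in both handlers assume the default `batch-size = 32` (800 train / 200 test batches of the ~32,000 windows per partition). If you experiment with other batch sizes, a hypothetical helper like this sketch could derive the count from the `tf.data` pipeline instead:

```python
def count_examples(tf_dataset) -> int:
    """Count examples in a batched tf.data.Dataset of (x, y) pairs.

    Iterates the pipeline once, so cache the result rather than
    calling it on every round.
    """
    return sum(int(x.shape[0]) for x, _ in tf_dataset)
```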
38 changes: 38 additions & 0 deletions examples/quickstart-tensorflow-timeseries/tftsexample/server_app.py
@@ -0,0 +1,38 @@
"""timeseries: A Flower / TensorFlow app."""

from flwr.app import ArrayRecord, Context
from flwr.serverapp import Grid, ServerApp
from flwr.serverapp.strategy import FedAvg

from tftsexample.task import load_model

# Create ServerApp
app = ServerApp()


@app.main()
def main(grid: Grid, context: Context) -> None:
"""Main entry point for the ServerApp."""

    # Read run config
    num_rounds: int = context.run_config["num-server-rounds"]
    fraction_train: float = context.run_config["fraction-train"]

    # Load global model
    model = load_model()
    arrays = ArrayRecord(model.get_weights())

    # Initialize FedAvg: train on the configured fraction of clients each
    # round, evaluate on all of them
    strategy = FedAvg(fraction_train=fraction_train, fraction_evaluate=1.0)

# Start strategy, run FedAvg for `num_rounds`
result = strategy.start(
grid=grid,
initial_arrays=arrays,
num_rounds=num_rounds,
)

# Save final model to disk
print("\nSaving final model to disk...")
ndarrays = result.arrays.to_numpy_ndarrays()
model.set_weights(ndarrays)
model.save("final_model.keras")
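For context on what the strategy does with the client replies: FedAvg averages the returned arrays, weighting each client by the `num-examples` metric its ClientApp reports. A minimal NumPy sketch of that aggregation (an illustration of the idea, not Flower's implementation):

```python
import numpy as np

def fedavg_aggregate(updates):
    # updates: list of (weights, num_examples) pairs, where weights is a
    # list of NumPy arrays with matching shapes across clients
    total = sum(n for _, n in updates)
    num_layers = len(updates[0][0])
    return [
        np.sum([(n / total) * w[i] for w, n in updates], axis=0)
        for i in range(num_layers)
    ]
```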
61 changes: 61 additions & 0 deletions examples/quickstart-tensorflow-timeseries/tftsexample/task.py
@@ -0,0 +1,61 @@
"""timeseries: A Flower / TensorFlow app."""

import os

import keras
from flwr_datasets import FederatedDataset
from keras import layers
from keras.utils import timeseries_dataset_from_array

# Make TensorFlow log less verbose
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"


def load_model(learning_rate: float = 0.005):
    # Stacked-LSTM regressor: a window of 12 past time steps in, the next
    # step out
    model = keras.Sequential(
        [
            keras.Input(shape=(12, 1)),
            layers.LSTM(units=12, activation="tanh", return_sequences=True),
            layers.Dropout(rate=0.005),
            layers.LSTM(units=6, activation="tanh"),
            layers.Dense(units=1),
        ]
    )
    model.compile(
        optimizer=keras.optimizers.Adam(learning_rate),
        loss="mean_squared_error",
        metrics=["mean_squared_error"],
    )
    return model


fds = None # Cache FederatedDataset


def load_data(partition_id, num_partitions, batch_size):
global fds
if fds is None:
fds = FederatedDataset(
dataset="sayanroy058/Jena-Climate",
partitioners={"train": num_partitions},
shuffle=False
)
partition = fds.load_partition(partition_id, "train")
partition.set_format("numpy")

partition_temperature = partition["T (degC)"][0:32000]

    # Create temporal windows for the LSTM: each window of 12 consecutive
    # steps is used to predict the step that follows it. The feature axis
    # is added so each window has shape (12, 1), as the model expects.
    input_data = partition_temperature[:-12].reshape(-1, 1)
    targets = partition_temperature[12:]
    tf_dataset = timeseries_dataset_from_array(
        input_data,
        targets,
        sequence_length=12,
        sequence_stride=1,
        sampling_rate=1,
        shuffle=False,
        batch_size=batch_size,
    )

    # Split the batches on each node: the first 200 batches become the test
    # set, the rest the training set (an 80/20 split of the ~1000 batches
    # produced at the default batch size of 32)
    tf_test = tf_dataset.take(200)
    tf_train = tf_dataset.skip(200)

return tf_train, tf_test
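To make the windowing offsets concrete, here is a small self-contained sketch (toy data, not part of the example) showing that each window of 12 steps is paired with the value that follows it:

```python
import numpy as np
from keras.utils import timeseries_dataset_from_array

# Toy series 0..29 with a feature axis, mirroring load_data's offsets
series = np.arange(30, dtype="float32").reshape(-1, 1)
ds = timeseries_dataset_from_array(
    series[:-12], series[12:], sequence_length=12, batch_size=4
)
x, y = next(iter(ds))
print(x.shape)            # (4, 12, 1): four windows of 12 steps each
print(y.numpy().ravel())  # [12. 13. 14. 15.]: the step after each window
```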
2 changes: 1 addition & 1 deletion examples/quickstart-tensorflow/README.md
@@ -6,7 +6,7 @@ framework: [tensorflow]

# Federated Learning with Tensorflow/Keras and Flower (Quickstart Example)

This introductory example to Flower uses Tensorflow/Keras but deep knowledge of this frameworks is required to run the example. However, it will help you understand how to adapt Flower to your use case.
This introductory example to Flower uses Tensorflow/Keras, but no deep knowledge of these frameworks is required to run the example. It will, however, help you understand how to adapt Flower to your use case.
Running the example itself is quite easy. This example uses [Flower Datasets](https://flower.ai/docs/datasets/) to download, partition and preprocess the CIFAR-10 dataset.

## Set up the project