Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
254 changes: 115 additions & 139 deletions framework/py/flwr/server/serverapp/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@


import argparse
import gc
from logging import DEBUG, ERROR, INFO
from pathlib import Path
from queue import Queue
Expand Down Expand Up @@ -65,7 +64,7 @@
from flwr.proto.serverappio_pb2_grpc import ServerAppIoStub
from flwr.server.grid.grpc_grid import GrpcGrid
from flwr.server.run_serverapp import run as run_
from flwr.supercore.app_utils import simple_get_token, start_parent_process_monitor
from flwr.supercore.app_utils import start_parent_process_monitor
from flwr.supercore.superexec.plugin import ServerAppExecPlugin
from flwr.supercore.superexec.run_superexec import run_with_deprecation_warning

Expand Down Expand Up @@ -109,7 +108,6 @@ def flwr_serverapp() -> None:
serverappio_api_address=args.serverappio_api_address,
log_queue=log_queue,
token=args.token,
run_once=(args.token is not None) or args.run_once,
flwr_dir=args.flwr_dir,
certificates=None,
parent_pid=args.parent_pid,
Expand All @@ -122,8 +120,7 @@ def flwr_serverapp() -> None:
def run_serverapp( # pylint: disable=R0913, R0914, R0915, R0917, W0212
serverappio_api_address: str,
log_queue: Queue[Optional[str]],
run_once: bool,
token: Optional[str] = None,
token: str,
flwr_dir: Optional[str] = None,
certificates: Optional[bytes] = None,
parent_pid: Optional[int] = None,
Expand All @@ -142,154 +139,133 @@ def run_serverapp( # pylint: disable=R0913, R0914, R0915, R0917, W0212
heartbeat_sender = None
grid = None
context = None
while True:

try:
# Initialize the GrpcGrid
grid = GrpcGrid(
serverappio_service_address=serverappio_api_address,
root_certificates=certificates,
)
try:
# Initialize the GrpcGrid
grid = GrpcGrid(
serverappio_service_address=serverappio_api_address,
root_certificates=certificates,
)

# If token is not set, loop until token is received from SuperLink
if token is None:
log(DEBUG, "[flwr-serverapp] Request token")
token = simple_get_token(grid._stub)

# Pull ServerAppInputs from LinkState
req = PullAppInputsRequest(token=token)
log(DEBUG, "[flwr-serverapp] Pull ServerAppInputs")
res: PullAppInputsResponse = grid._stub.PullAppInputs(req)
context = context_from_proto(res.context)
run = run_from_proto(res.run)
fab = fab_from_proto(res.fab)

hash_run_id = get_sha256_hash(run.run_id)

grid.set_run(run.run_id)

# Start log uploader for this run
log_uploader = start_log_uploader(
log_queue=log_queue,
node_id=0,
run_id=run.run_id,
stub=grid._stub,
)
# Pull ServerAppInputs from LinkState
req = PullAppInputsRequest(token=token)
log(DEBUG, "[flwr-serverapp] Pull ServerAppInputs")
res: PullAppInputsResponse = grid._stub.PullAppInputs(req)
context = context_from_proto(res.context)
run = run_from_proto(res.run)
fab = fab_from_proto(res.fab)

log(DEBUG, "[flwr-serverapp] Start FAB installation.")
install_from_fab(fab.content, flwr_dir=flwr_dir_, skip_prompt=True)
hash_run_id = get_sha256_hash(run.run_id)

fab_id, fab_version = get_fab_metadata(fab.content)
grid.set_run(run.run_id)

app_path = str(
get_project_dir(fab_id, fab_version, fab.hash_str, flwr_dir_)
)
config = get_project_config(app_path)
# Start log uploader for this run
log_uploader = start_log_uploader(
log_queue=log_queue,
node_id=0,
run_id=run.run_id,
stub=grid._stub,
)

# Obtain server app reference and the run config
server_app_attr = config["tool"]["flwr"]["app"]["components"]["serverapp"]
server_app_run_config = get_fused_config_from_dir(
Path(app_path), run.override_config
)
log(DEBUG, "[flwr-serverapp] Start FAB installation.")
install_from_fab(fab.content, flwr_dir=flwr_dir_, skip_prompt=True)

# Update run_config in context
context.run_config = server_app_run_config
fab_id, fab_version = get_fab_metadata(fab.content)

log(
DEBUG,
"[flwr-serverapp] Will load ServerApp `%s` in %s",
server_app_attr,
app_path,
)
app_path = str(get_project_dir(fab_id, fab_version, fab.hash_str, flwr_dir_))
config = get_project_config(app_path)

# Change status to Running
run_status_proto = run_status_to_proto(RunStatus(Status.RUNNING, "", ""))
grid._stub.UpdateRunStatus(
UpdateRunStatusRequest(run_id=run.run_id, run_status=run_status_proto)
)
# Obtain server app reference and the run config
server_app_attr = config["tool"]["flwr"]["app"]["components"]["serverapp"]
server_app_run_config = get_fused_config_from_dir(
Path(app_path), run.override_config
)

event(
EventType.FLWR_SERVERAPP_RUN_ENTER,
event_details={"run-id-hash": hash_run_id},
)
# Update run_config in context
context.run_config = server_app_run_config

# Set up heartbeat sender
heartbeat_fn = get_grpc_app_heartbeat_fn(
grid._stub,
run.run_id,
failure_message="Heartbeat failed unexpectedly. The SuperLink could "
"not find the provided run ID, or the run status is invalid.",
)
heartbeat_sender = HeartbeatSender(heartbeat_fn)
heartbeat_sender.start()

# Load and run the ServerApp with the Grid
updated_context = run_(
grid=grid,
server_app_dir=app_path,
server_app_attr=server_app_attr,
context=context,
)
log(
DEBUG,
"[flwr-serverapp] Will load ServerApp `%s` in %s",
server_app_attr,
app_path,
)

# Send resulting context
context_proto = context_to_proto(updated_context)
log(DEBUG, "[flwr-serverapp] Will push ServerAppOutputs")
out_req = PushAppOutputsRequest(
token=token, run_id=run.run_id, context=context_proto
)
_ = grid._stub.PushAppOutputs(out_req)

run_status = RunStatus(Status.FINISHED, SubStatus.COMPLETED, "")
except RunNotRunningException:
log(INFO, "")
log(INFO, "Run ID %s stopped.", run.run_id)
log(INFO, "")
run_status = None
success = False

except Exception as ex: # pylint: disable=broad-exception-caught
exc_entity = "ServerApp"
log(ERROR, "%s raised an exception", exc_entity, exc_info=ex)
run_status = RunStatus(Status.FINISHED, SubStatus.FAILED, str(ex))
success = False

finally:
# Stop heartbeat sender
if heartbeat_sender:
heartbeat_sender.stop()
heartbeat_sender = None

# Stop log uploader for this run and upload final logs
if log_uploader:
stop_log_uploader(log_queue, log_uploader)
log_uploader = None

# Update run status
if run_status and grid:
run_status_proto = run_status_to_proto(run_status)
grid._stub.UpdateRunStatus(
UpdateRunStatusRequest(
run_id=run.run_id, run_status=run_status_proto
)
)

# Close the Grpc connection
if grid:
grid.close()

# Clean up the Context and the token
context = None
token = None
gc.collect()

event(
EventType.FLWR_SERVERAPP_RUN_LEAVE,
event_details={"run-id-hash": hash_run_id, "success": success},
# Change status to Running
run_status_proto = run_status_to_proto(RunStatus(Status.RUNNING, "", ""))
grid._stub.UpdateRunStatus(
UpdateRunStatusRequest(run_id=run.run_id, run_status=run_status_proto)
)

event(
EventType.FLWR_SERVERAPP_RUN_ENTER,
event_details={"run-id-hash": hash_run_id},
)

# Set up heartbeat sender
heartbeat_fn = get_grpc_app_heartbeat_fn(
grid._stub,
run.run_id,
failure_message="Heartbeat failed unexpectedly. The SuperLink could "
"not find the provided run ID, or the run status is invalid.",
)
heartbeat_sender = HeartbeatSender(heartbeat_fn)
heartbeat_sender.start()

# Load and run the ServerApp with the Grid
updated_context = run_(
grid=grid,
server_app_dir=app_path,
server_app_attr=server_app_attr,
context=context,
)

# Send resulting context
context_proto = context_to_proto(updated_context)
log(DEBUG, "[flwr-serverapp] Will push ServerAppOutputs")
out_req = PushAppOutputsRequest(
token=token, run_id=run.run_id, context=context_proto
)
_ = grid._stub.PushAppOutputs(out_req)

run_status = RunStatus(Status.FINISHED, SubStatus.COMPLETED, "")
except RunNotRunningException:
log(INFO, "")
log(INFO, "Run ID %s stopped.", run.run_id)
log(INFO, "")
run_status = None
success = False

except Exception as ex: # pylint: disable=broad-exception-caught
exc_entity = "ServerApp"
log(ERROR, "%s raised an exception", exc_entity, exc_info=ex)
run_status = RunStatus(Status.FINISHED, SubStatus.FAILED, str(ex))
success = False

finally:
# Stop heartbeat sender
if heartbeat_sender:
heartbeat_sender.stop()

# Stop log uploader for this run and upload final logs
if log_uploader:
stop_log_uploader(log_queue, log_uploader)

# Update run status
if run_status and grid:
run_status_proto = run_status_to_proto(run_status)
grid._stub.UpdateRunStatus(
UpdateRunStatusRequest(run_id=run.run_id, run_status=run_status_proto)
)

# Stop the loop if `flwr-serverapp` is expected to process a single run
if run_once:
break
# Close the Grpc connection
if grid:
grid.close()

event(
EventType.FLWR_SERVERAPP_RUN_LEAVE,
event_details={"run-id-hash": hash_run_id, "success": success},
)


def _parse_args_run_flwr_serverapp() -> argparse.ArgumentParser:
Expand Down
Loading
Loading