Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
281 changes: 127 additions & 154 deletions framework/py/flwr/simulation/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@


import argparse
import gc
from logging import DEBUG, ERROR, INFO
from queue import Queue
from typing import Optional
Expand Down Expand Up @@ -71,7 +70,7 @@
from flwr.server.superlink.fleet.vce.backend.backend import BackendConfig
from flwr.simulation.run_simulation import _run_simulation
from flwr.simulation.simulationio_connection import SimulationIoConnection
from flwr.supercore.app_utils import simple_get_token, start_parent_process_monitor
from flwr.supercore.app_utils import start_parent_process_monitor
from flwr.supercore.superexec.plugin import SimulationExecPlugin
from flwr.supercore.superexec.run_superexec import run_with_deprecation_warning

Expand Down Expand Up @@ -114,7 +113,6 @@ def flwr_simulation() -> None:
run_simulation_process(
simulationio_api_address=args.simulationio_api_address,
log_queue=log_queue,
run_once=(args.token is not None) or args.run_once,
token=args.token,
flwr_dir_=args.flwr_dir,
certificates=None,
Expand All @@ -128,8 +126,7 @@ def flwr_simulation() -> None:
def run_simulation_process( # pylint: disable=R0913, R0914, R0915, R0917, W0212
simulationio_api_address: str,
log_queue: Queue[Optional[str]],
run_once: bool,
token: Optional[str] = None,
token: str,
flwr_dir_: Optional[str] = None,
certificates: Optional[bytes] = None,
parent_pid: Optional[int] = None,
Expand All @@ -150,168 +147,144 @@ def run_simulation_process( # pylint: disable=R0913, R0914, R0915, R0917, W0212
heartbeat_sender = None
run_status = None

while True:
try:
# Pull SimulationInputs from LinkState
req = PullAppInputsRequest(token=token)
res: PullAppInputsResponse = conn._stub.PullAppInputs(req)
context = context_from_proto(res.context)
run = run_from_proto(res.run)
fab = fab_from_proto(res.fab)

# Start log uploader for this run
log_uploader = start_log_uploader(
log_queue=log_queue,
node_id=context.node_id,
run_id=run.run_id,
stub=conn._stub,
)

try:
# If token is not set, loop until token is received from SuperNode
if token is None:
log(DEBUG, "[flwr-simulation] Request token")
token = simple_get_token(conn._stub)

# Pull SimulationInputs from LinkState
req = PullAppInputsRequest(token=token)
res: PullAppInputsResponse = conn._stub.PullAppInputs(req)
context = context_from_proto(res.context)
run = run_from_proto(res.run)
fab = fab_from_proto(res.fab)

# Start log uploader for this run
log_uploader = start_log_uploader(
log_queue=log_queue,
node_id=context.node_id,
run_id=run.run_id,
stub=conn._stub,
)
log(DEBUG, "Simulation process starts FAB installation.")
install_from_fab(fab.content, flwr_dir=flwr_dir, skip_prompt=True)

log(DEBUG, "Simulation process starts FAB installation.")
install_from_fab(fab.content, flwr_dir=flwr_dir, skip_prompt=True)
fab_id, fab_version = get_fab_metadata(fab.content)

fab_id, fab_version = get_fab_metadata(fab.content)
app_path = get_project_dir(fab_id, fab_version, fab.hash_str, flwr_dir)
config = get_project_config(app_path)

app_path = get_project_dir(fab_id, fab_version, fab.hash_str, flwr_dir)
config = get_project_config(app_path)
# Get ClientApp and SeverApp components
app_components = config["tool"]["flwr"]["app"]["components"]
client_app_attr = app_components["clientapp"]
server_app_attr = app_components["serverapp"]
fused_config = get_fused_config_from_dir(app_path, run.override_config)

# Get ClientApp and SeverApp components
app_components = config["tool"]["flwr"]["app"]["components"]
client_app_attr = app_components["clientapp"]
server_app_attr = app_components["serverapp"]
fused_config = get_fused_config_from_dir(app_path, run.override_config)
# Update run_config in context
context.run_config = fused_config

# Update run_config in context
context.run_config = fused_config
log(
DEBUG,
"Flower will load ServerApp `%s` in %s",
server_app_attr,
app_path,
)
log(
DEBUG,
"Flower will load ClientApp `%s` in %s",
client_app_attr,
app_path,
)

log(
DEBUG,
"Flower will load ServerApp `%s` in %s",
server_app_attr,
app_path,
)
log(
DEBUG,
"Flower will load ClientApp `%s` in %s",
client_app_attr,
app_path,
)
# Change status to Running
run_status_proto = run_status_to_proto(RunStatus(Status.RUNNING, "", ""))
conn._stub.UpdateRunStatus(
UpdateRunStatusRequest(run_id=run.run_id, run_status=run_status_proto)
)

# Change status to Running
run_status_proto = run_status_to_proto(RunStatus(Status.RUNNING, "", ""))
conn._stub.UpdateRunStatus(
UpdateRunStatusRequest(run_id=run.run_id, run_status=run_status_proto)
)
# Pull Federation Options
fed_opt_res: GetFederationOptionsResponse = conn._stub.GetFederationOptions(
GetFederationOptionsRequest(run_id=run.run_id)
)
federation_options = config_record_from_proto(fed_opt_res.federation_options)

# Unflatten underlying dict
fed_opt = unflatten_dict({**federation_options})

# Extract configs values of interest
num_supernodes = fed_opt.get("num-supernodes")
if num_supernodes is None:
raise ValueError("Federation options expects `num-supernodes` to be set.")
backend_config: BackendConfig = fed_opt.get("backend", {})
verbose: bool = fed_opt.get("verbose", False)
enable_tf_gpu_growth: bool = fed_opt.get("enable_tf_gpu_growth", False)

event(
EventType.FLWR_SIMULATION_RUN_ENTER,
event_details={
"backend": "ray",
"num-supernodes": num_supernodes,
"run-id-hash": get_sha256_hash(run.run_id),
},
)

# Pull Federation Options
fed_opt_res: GetFederationOptionsResponse = conn._stub.GetFederationOptions(
GetFederationOptionsRequest(run_id=run.run_id)
)
federation_options = config_record_from_proto(
fed_opt_res.federation_options
)
# Set up heartbeat sender
heartbeat_fn = get_grpc_app_heartbeat_fn(
conn._stub,
run.run_id,
failure_message="Heartbeat failed unexpectedly. The SuperLink could "
"not find the provided run ID, or the run status is invalid.",
)
heartbeat_sender = HeartbeatSender(heartbeat_fn)
heartbeat_sender.start()

# Launch the simulation
updated_context = _run_simulation(
server_app_attr=server_app_attr,
client_app_attr=client_app_attr,
num_supernodes=num_supernodes,
backend_config=backend_config,
app_dir=str(app_path),
run=run,
enable_tf_gpu_growth=enable_tf_gpu_growth,
verbose_logging=verbose,
server_app_run_config=fused_config,
is_app=True,
exit_event=EventType.FLWR_SIMULATION_RUN_LEAVE,
)

# Unflatten underlying dict
fed_opt = unflatten_dict({**federation_options})

# Extract configs values of interest
num_supernodes = fed_opt.get("num-supernodes")
if num_supernodes is None:
raise ValueError(
"Federation options expects `num-supernodes` to be set."
)
backend_config: BackendConfig = fed_opt.get("backend", {})
verbose: bool = fed_opt.get("verbose", False)
enable_tf_gpu_growth: bool = fed_opt.get("enable_tf_gpu_growth", False)

event(
EventType.FLWR_SIMULATION_RUN_ENTER,
event_details={
"backend": "ray",
"num-supernodes": num_supernodes,
"run-id-hash": get_sha256_hash(run.run_id),
},
)
# Send resulting context
context_proto = context_to_proto(updated_context)
out_req = PushAppOutputsRequest(
token=token, run_id=run.run_id, context=context_proto
)
_ = conn._stub.PushAppOutputs(out_req)

# Set up heartbeat sender
heartbeat_fn = get_grpc_app_heartbeat_fn(
conn._stub,
run.run_id,
failure_message="Heartbeat failed unexpectedly. The SuperLink could "
"not find the provided run ID, or the run status is invalid.",
)
heartbeat_sender = HeartbeatSender(heartbeat_fn)
heartbeat_sender.start()

# Launch the simulation
updated_context = _run_simulation(
server_app_attr=server_app_attr,
client_app_attr=client_app_attr,
num_supernodes=num_supernodes,
backend_config=backend_config,
app_dir=str(app_path),
run=run,
enable_tf_gpu_growth=enable_tf_gpu_growth,
verbose_logging=verbose,
server_app_run_config=fused_config,
is_app=True,
exit_event=EventType.FLWR_SIMULATION_RUN_LEAVE,
)
run_status = RunStatus(Status.FINISHED, SubStatus.COMPLETED, "")

# Send resulting context
context_proto = context_to_proto(updated_context)
out_req = PushAppOutputsRequest(
token=token, run_id=run.run_id, context=context_proto
except Exception as ex: # pylint: disable=broad-exception-caught
exc_entity = "Simulation"
log(ERROR, "%s raised an exception", exc_entity, exc_info=ex)
run_status = RunStatus(Status.FINISHED, SubStatus.FAILED, str(ex))

finally:
# Stop heartbeat sender
if heartbeat_sender:
heartbeat_sender.stop()

# Stop log uploader for this run and upload final logs
if log_uploader:
stop_log_uploader(log_queue, log_uploader)

# Update run status
if run_status:
run_status_proto = run_status_to_proto(run_status)
conn._stub.UpdateRunStatus(
UpdateRunStatusRequest(run_id=run.run_id, run_status=run_status_proto)
)
_ = conn._stub.PushAppOutputs(out_req)

run_status = RunStatus(Status.FINISHED, SubStatus.COMPLETED, "")

except Exception as ex: # pylint: disable=broad-exception-caught
exc_entity = "Simulation"
log(ERROR, "%s raised an exception", exc_entity, exc_info=ex)
run_status = RunStatus(Status.FINISHED, SubStatus.FAILED, str(ex))

finally:
# Stop heartbeat sender
if heartbeat_sender:
heartbeat_sender.stop()
heartbeat_sender = None

# Stop log uploader for this run and upload final logs
if log_uploader:
stop_log_uploader(log_queue, log_uploader)
log_uploader = None

# Update run status
if run_status:
run_status_proto = run_status_to_proto(run_status)
conn._stub.UpdateRunStatus(
UpdateRunStatusRequest(
run_id=run.run_id, run_status=run_status_proto
)
)
run_status = None

# Clean up the Context if it exists
try:
del updated_context
except NameError:
pass

# Remove the token
token = None
gc.collect()

# Stop the loop if `flwr-simulation` is expected to process a single run
if run_once:
break

# Clean up the Context if it exists
try:
del updated_context
except NameError:
pass


def _parse_args_run_flwr_simulation() -> argparse.ArgumentParser:
Expand Down
31 changes: 0 additions & 31 deletions framework/py/flwr/supercore/app_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,6 @@
import signal
import threading
import time
from typing import Union

from flwr.proto.appio_pb2 import ( # pylint: disable=E0611
ListAppsToLaunchRequest,
ListAppsToLaunchResponse,
RequestTokenRequest,
RequestTokenResponse,
)
from flwr.proto.clientappio_pb2_grpc import ClientAppIoStub
from flwr.proto.serverappio_pb2_grpc import ServerAppIoStub
from flwr.proto.simulationio_pb2_grpc import SimulationIoStub

if os.name == "nt":
from ctypes import windll # type: ignore
Expand Down Expand Up @@ -67,23 +56,3 @@ def monitor() -> None:
os.kill(os.getpid(), signal.SIGKILL)

threading.Thread(target=monitor, daemon=True).start()


def simple_get_token(
stub: Union[ClientAppIoStub, ServerAppIoStub, SimulationIoStub]
) -> str:
"""Get a token from SuperLink/SuperNode.

This shall be removed once the SuperExec is fully implemented.
"""
while True:
res: ListAppsToLaunchResponse = stub.ListAppsToLaunch(ListAppsToLaunchRequest())

for run_id in res.run_ids:
tk_res: RequestTokenResponse = stub.RequestToken(
RequestTokenRequest(run_id=run_id)
)
if tk_res.token:
return tk_res.token

time.sleep(1) # Wait before retrying to get run IDs
Loading