Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions nvflare/fuel/f3/cellnet/defs.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,11 @@ class MessageHeaderKey:
OPTIONAL = CELLNET_PREFIX + "optional"
MSG_ROOT_ID = CELLNET_PREFIX + "msg_root_id"
MSG_ROOT_TTL = CELLNET_PREFIX + "msg_root_ttl"
# When True on an incoming cell message, Adapter.call() builds a per-call
# FOBS decode context with FOBSContextKey.PASS_THROUGH=True so that tensors
# in that message arrive as LazyDownloadRef placeholders rather than being
# downloaded inline. Set by CellPipe.send() when pass_through_on_send=True.
PASS_THROUGH = CELLNET_PREFIX + "pass_through"


class ReturnReason:
Expand Down
10 changes: 7 additions & 3 deletions nvflare/fuel/f3/cellnet/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,12 @@ def encode_payload(message: Message, encoding_key=MessageHeaderKey.PAYLOAD_ENCOD


def decode_payload(message: Message, encoding_key=MessageHeaderKey.PAYLOAD_ENCODING, fobs_ctx: dict = None):
if isinstance(fobs_ctx, dict):
fobs_ctx[fobs.FOBSContextKey.MESSAGE] = message

# Always normalize to a dict and avoid mutating caller-owned context.
ctx = fobs_ctx.copy() if fobs_ctx is not None else {}
if MessageHeaderKey.PASS_THROUGH in message.headers:
ctx[fobs.FOBSContextKey.PASS_THROUGH] = bool(message.get_header(MessageHeaderKey.PASS_THROUGH))
ctx[fobs.FOBSContextKey.MESSAGE] = message

size = buffer_len(message.payload)
message.set_header(MessageHeaderKey.PAYLOAD_LEN, size)
Expand All @@ -185,7 +189,7 @@ def decode_payload(message: Message, encoding_key=MessageHeaderKey.PAYLOAD_ENCOD
return

if encoding == Encoding.FOBS:
message.payload = fobs.loads(message.payload, fobs_ctx=fobs_ctx)
message.payload = fobs.loads(message.payload, fobs_ctx=ctx)
elif encoding == Encoding.NONE:
message.payload = None
else:
Expand Down
24 changes: 24 additions & 0 deletions nvflare/fuel/f3/comm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,12 @@ class VarName:
STREAMING_ACK_INTERVAL = "streaming_ack_interval"
STREAMING_MAX_OUT_SEQ_CHUNKS = "streaming_max_out_seq_chunks"
STREAMING_READ_TIMEOUT = "streaming_read_timeout"
STREAMING_SEND_TIMEOUT = "streaming_send_timeout"
STREAMING_ACK_PROGRESS_TIMEOUT = "streaming_ack_progress_timeout"
STREAMING_ACK_PROGRESS_CHECK_INTERVAL = "streaming_ack_progress_check_interval"
SFM_SEND_STALL_TIMEOUT = "sfm_send_stall_timeout"
SFM_CLOSE_STALLED_CONNECTION = "sfm_close_stalled_connection"
SFM_SEND_STALL_CONSECUTIVE_CHECKS = "sfm_send_stall_consecutive_checks"


class CommConfigurator:
Expand Down Expand Up @@ -114,6 +120,24 @@ def get_streaming_max_out_seq_chunks(self, default):
def get_streaming_read_timeout(self, default):
return ConfigService.get_int_var(VarName.STREAMING_READ_TIMEOUT, self.config, default)

def get_streaming_send_timeout(self, default):
    """Return the streaming send timeout (seconds) from config, or *default* when unset."""
    var_name = VarName.STREAMING_SEND_TIMEOUT
    return ConfigService.get_float_var(var_name, self.config, default=default)

def get_streaming_ack_progress_timeout(self, default):
    """Return the ACK-progress timeout (seconds) from config, or *default* when unset."""
    var_name = VarName.STREAMING_ACK_PROGRESS_TIMEOUT
    return ConfigService.get_float_var(var_name, self.config, default=default)

def get_streaming_ack_progress_check_interval(self, default):
    """Return how often (seconds) ACK progress is checked, or *default* when unset."""
    var_name = VarName.STREAMING_ACK_PROGRESS_CHECK_INTERVAL
    return ConfigService.get_float_var(var_name, self.config, default=default)

def get_sfm_send_stall_timeout(self, default):
    """Return the SFM send-stall detection threshold (seconds), or *default* when unset."""
    var_name = VarName.SFM_SEND_STALL_TIMEOUT
    return ConfigService.get_float_var(var_name, self.config, default=default)

def get_sfm_close_stalled_connection(self, default=False):
    """Return whether stalled connections should be closed; disabled unless configured."""
    var_name = VarName.SFM_CLOSE_STALLED_CONNECTION
    return ConfigService.get_bool_var(var_name, self.config, default=default)

def get_sfm_send_stall_consecutive_checks(self, default=3):
    """Return how many consecutive stalled checks are required before acting, or *default*."""
    var_name = VarName.SFM_SEND_STALL_CONSECUTIVE_CHECKS
    return ConfigService.get_int_var(var_name, self.config, default=default)

def get_int_var(self, name: str, default=None):
return ConfigService.get_int_var(name, self.config, default=default)

Expand Down
68 changes: 65 additions & 3 deletions nvflare/fuel/f3/drivers/socket_conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import errno
import logging
import select
import socket
import time
from socketserver import BaseRequestHandler
from typing import Any, Union

from nvflare.fuel.f3.comm_config import CommConfigurator
from nvflare.fuel.f3.comm_error import CommError
from nvflare.fuel.f3.connection import BytesAlike, Connection
from nvflare.fuel.f3.drivers.driver import ConnectorInfo
Expand All @@ -35,6 +39,7 @@ def __init__(self, sock: Any, connector: ConnectorInfo, secure: bool = False):
self.secure = secure
self.closing = False
self.conn_props = self._get_socket_properties()
self.send_timeout = CommConfigurator().get_streaming_send_timeout(30.0)

def get_conn_properties(self) -> dict:
return self.conn_props
Expand All @@ -52,24 +57,81 @@ def close(self):

def send_frame(self, frame: BytesAlike):
    """Send one complete frame on this connection, enforcing the configured send timeout.

    Raises:
        CommError: TIMEOUT when the send deadline expires, CLOSED when the peer is
            gone, ERROR for any other send failure. All errors are swallowed when
            the connection is already closing.
    """
    try:
        self._send_with_timeout(frame, self.send_timeout)
    except CommError as error:
        if self.closing:
            return
        # A send timeout may occur after partial bytes are already written to the stream.
        # Close the connection to avoid frame-boundary desync on subsequent sends.
        if error.code == CommError.TIMEOUT:
            self.close()
        raise
    except Exception as ex:
        if self.closing:
            return
        if self._is_timeout_exception(ex):
            self.close()
            raise CommError(
                CommError.TIMEOUT,
                f"send_frame timeout on conn {self}: {secure_format_exception(ex)}",
            )
        if self._is_closed_socket_exception(ex):
            raise CommError(
                CommError.CLOSED,
                f"Connection {self.name} is closed while sending: {secure_format_exception(ex)}",
            )
        raise CommError(CommError.ERROR, f"Error sending frame on conn {self}: {secure_format_exception(ex)}")

@staticmethod
def _is_timeout_exception(ex: Exception) -> bool:
return isinstance(ex, (TimeoutError, socket.timeout))

@staticmethod
def _is_closed_socket_exception(ex: Exception) -> bool:
if isinstance(ex, (BrokenPipeError, ConnectionResetError, ConnectionAbortedError)):
return True

if isinstance(ex, OSError):
return ex.errno in {
errno.EPIPE,
errno.ECONNRESET,
errno.ENOTCONN,
errno.ECONNABORTED,
errno.EBADF,
errno.ESHUTDOWN,
}

return False

def _send_with_timeout(self, frame: BytesAlike, timeout_sec: float):
    """Write *frame* fully to the socket within *timeout_sec* seconds.

    Waits for writability with select() before each send() so a stalled peer
    cannot block the sender forever; send() may accept only part of the data,
    so the write is retried from the unsent offset until complete.

    Raises:
        CommError: TIMEOUT when the deadline passes before all bytes are written,
            CLOSED when send() reports the connection is gone.
    """
    view = frame if isinstance(frame, memoryview) else memoryview(frame)
    deadline = time.monotonic() + timeout_sec
    total = len(view)
    offset = 0
    while offset < total:
        remaining = deadline - time.monotonic()
        if remaining <= 0:
            raise CommError(CommError.TIMEOUT, f"send_frame timeout after {timeout_sec} seconds on {self.name}")

        _, writable, _ = select.select([], [self.sock], [], remaining)
        if not writable:
            raise CommError(CommError.TIMEOUT, f"send_frame timeout after {timeout_sec} seconds on {self.name}")

        sent = self.sock.send(view[offset:])
        if sent <= 0:
            raise CommError(CommError.CLOSED, f"Connection {self.name} is closed while sending")

        offset += sent

def read_loop(self):
    """Run the frame-reading loop until the connection ends, logging why it ended."""
    try:
        self.read_frame_loop()
    except CommError as error:
        if error.code == CommError.CLOSED:
            log.debug(f"Connection {self.name} is closed by peer")
        else:
            log.debug(f"Connection {self.name} is closed due to CommError: {error}")
    except Exception as ex:
        # An exception during an intentional close is expected and logged tersely.
        msg = (
            f"Connection {self.name} is closed"
            if self.closing
            else f"Connection {self.name} is closed due to exception: {secure_format_exception(ex)}"
        )
        log.debug(msg)

def read_frame_loop(self):
# read_frame throws exception on stale/bad connection so this is not a dead loop
Expand Down
33 changes: 27 additions & 6 deletions nvflare/fuel/f3/sfm/conn_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,11 +152,11 @@ def stop(self):
connector.stopped.set()
connector.driver.shutdown()

self.stopped = True

self.conn_mgr_executor.shutdown(True)
self.frame_mgr_executor.shutdown(True)

self.stopped = True

def find_endpoint(self, name: str) -> Optional[Endpoint]:

sfm_endpoint = self.sfm_endpoints.get(name)
Expand Down Expand Up @@ -257,7 +257,10 @@ def start_connector(self, connector: ConnectorInfo):

log.info(f"Connector {connector} is starting")

self.conn_mgr_executor.submit(self.start_connector_task, connector)
try:
self.conn_mgr_executor.submit(self.start_connector_task, connector)
except RuntimeError:
log.debug("Connector start skipped — executor already shut down")

@staticmethod
def start_connector_task(connector: ConnectorInfo):
Expand Down Expand Up @@ -332,6 +335,9 @@ def state_change(self, connection: Connection):

def process_frame_task(self, sfm_conn: SfmConnection, frame: BytesAlike):

if self.stopped:
return

try:
prefix = Prefix.from_bytes(frame)
log.debug(f"Received frame: {prefix} on {sfm_conn.conn}")
Expand Down Expand Up @@ -367,16 +373,25 @@ def process_frame_task(self, sfm_conn: SfmConnection, frame: BytesAlike):

else:
log.error(f"Received unsupported frame type {prefix.type} on {sfm_conn.get_name()}")
except RuntimeError as ex:
if self.stopped:
log.debug(f"Frame processing interrupted by shutdown: {secure_format_exception(ex)}")
else:
log.error(f"Error processing frame (RuntimeError): {secure_format_exception(ex)}")
log.debug(secure_format_traceback())
except Exception as ex:
log.error(f"Error processing frame: {secure_format_exception(ex)}")
log.error(f"Error processing frame (Exception): {secure_format_exception(ex)}")
log.debug(secure_format_traceback())

def process_frame(self, sfm_conn: SfmConnection, frame: BytesAlike):
    """Queue an incoming frame for async processing; drop it if shutdown has begun.

    The stopped flag is checked first, but shutdown can also race with the
    submit call itself, in which case the executor raises RuntimeError — both
    paths simply log the dropped frame at debug level.
    """
    if not self.stopped:
        try:
            self.frame_mgr_executor.submit(self.process_frame_task, sfm_conn, frame)
            return
        except RuntimeError:
            # Executor rejected the task: shutdown raced with this frame's arrival.
            pass
    log.debug(f"Frame received after shutdown for connection {sfm_conn.get_name()}")

def update_endpoint(self, sfm_conn: SfmConnection, data: dict):

Expand Down Expand Up @@ -460,8 +475,14 @@ def close_connection(self, connection: Connection):
def send_loopback_message(self, endpoint: Endpoint, app_id: int, headers: Optional[dict], payload: BytesAlike):
    """Send message to itself"""

    if not self.stopped:
        # Call receiver in a different thread to avoid deadlock
        try:
            self.frame_mgr_executor.submit(self.loopback_message_task, endpoint, app_id, headers, payload)
        except RuntimeError as e:
            # Executor already shut down; drop the loopback message quietly.
            log.debug(f"Loopback submit skipped: {e}")

def loopback_message_task(self, endpoint: Endpoint, app_id: int, headers: Optional[dict], payload: BytesAlike):

Expand Down
31 changes: 30 additions & 1 deletion nvflare/fuel/f3/sfm/heartbeat_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

HEARTBEAT_TICK = 5
DEFAULT_HEARTBEAT_INTERVAL = 60
DEFAULT_SEND_STALL_CONSECUTIVE_CHECKS = 3


class HeartbeatMonitor(Thread):
Expand All @@ -33,7 +34,14 @@ def __init__(self, conns: Dict[str, SfmConnection]):
self.conns = conns
self.stopped = Event()
self.curr_time = 0
self.interval = CommConfigurator().get_heartbeat_interval(DEFAULT_HEARTBEAT_INTERVAL)
config = CommConfigurator()
self.interval = config.get_heartbeat_interval(DEFAULT_HEARTBEAT_INTERVAL)
self.send_stall_timeout = config.get_sfm_send_stall_timeout(45.0)
self.close_stalled_connection = config.get_sfm_close_stalled_connection(False)
self.stall_consecutive_checks = max(
1, config.get_sfm_send_stall_consecutive_checks(DEFAULT_SEND_STALL_CONSECUTIVE_CHECKS)
)
self.stall_counts = {}
if self.interval < HEARTBEAT_TICK:
log.warning(f"Heartbeat interval is too small ({self.interval} < {HEARTBEAT_TICK})")

Expand All @@ -55,7 +63,24 @@ def run(self):

def _check_heartbeat(self):

active_keys = set()
for sfm_conn in self.conns.values():
conn_key = sfm_conn.get_name() if hasattr(sfm_conn, "get_name") else str(id(sfm_conn))
active_keys.add(conn_key)

stall_sec = sfm_conn.get_send_stall_seconds()
if stall_sec > self.send_stall_timeout:
count = self.stall_counts.get(conn_key, 0) + 1
self.stall_counts[conn_key] = count
log.warning(
f"Detected stalled send on {sfm_conn.conn}: blocked {stall_sec:.1f}s "
f"({count}/{self.stall_consecutive_checks})"
)
if self.close_stalled_connection and count >= self.stall_consecutive_checks:
sfm_conn.conn.close()
continue

self.stall_counts[conn_key] = 0

driver = sfm_conn.conn.connector.driver
caps = driver.capabilities()
Expand All @@ -65,3 +90,7 @@ def _check_heartbeat(self):
if self.curr_time - sfm_conn.last_activity > self.interval:
sfm_conn.send_heartbeat(Types.PING)
log.debug(f"Heartbeat sent to connection: {sfm_conn.conn}")

stale_keys = [k for k in self.stall_counts.keys() if k not in active_keys]
for k in stale_keys:
self.stall_counts.pop(k, None)
16 changes: 15 additions & 1 deletion nvflare/fuel/f3/sfm/sfm_conn.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@ def __init__(self, conn: Connection, local_endpoint: Endpoint):
self.last_activity = 0
self.sequence = 0
self.lock = threading.Lock()
self.send_state_lock = threading.Lock()
self.send_started_at = 0.0

def get_name(self) -> str:
return self.conn.name
Expand Down Expand Up @@ -145,7 +147,19 @@ def send_frame(self, prefix: Prefix, headers: Optional[dict], payload: Optional[
log.debug(f"Sending frame: {prefix} on {self.conn}")
# Only one thread can send data on a connection. Otherwise, the frames may interleave.
with self.lock:
self.conn.send_frame(buffer)
with self.send_state_lock:
self.send_started_at = time.monotonic()
try:
self.conn.send_frame(buffer)
finally:
with self.send_state_lock:
self.send_started_at = 0.0

def get_send_stall_seconds(self) -> float:
    """Return how long (seconds) the in-flight send has been blocked, or 0.0 if none.

    send_started_at is written under send_state_lock by send_frame; a value of
    0.0 (or less) means no send is currently in progress.
    """
    with self.send_state_lock:
        started = self.send_started_at
        return 0.0 if started <= 0.0 else time.monotonic() - started

@staticmethod
def headers_to_bytes(headers: Optional[dict]) -> Optional[bytes]:
Expand Down
14 changes: 11 additions & 3 deletions nvflare/fuel/f3/streaming/blob_streamer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from nvflare.fuel.f3.streaming.byte_streamer import STREAM_CHUNK_SIZE, STREAM_TYPE_BLOB, ByteStreamer
from nvflare.fuel.f3.streaming.stream_const import EOS
from nvflare.fuel.f3.streaming.stream_types import Stream, StreamError, StreamFuture
from nvflare.fuel.f3.streaming.stream_utils import FastBuffer, stream_thread_pool, wrap_view
from nvflare.fuel.f3.streaming.stream_utils import FastBuffer, callback_thread_pool, stream_thread_pool, wrap_view
from nvflare.fuel.utils.buffer_list import BufferList
from nvflare.security.logging import secure_format_traceback

Expand Down Expand Up @@ -96,11 +96,19 @@ def handle_blob_cb(self, future: StreamFuture, stream: Stream, resume: bool, *ar
blob_task = BlobTask(future, stream)

stream_thread_pool.submit(self._read_stream, blob_task)

self.blob_cb(future, *args, **kwargs)
callback_thread_pool.submit(self._run_blob_cb, future, stream, args, kwargs)

return 0

def _run_blob_cb(self, future: StreamFuture, stream: Stream, args: tuple, kwargs: dict):
    """Run blob_cb on the callback pool, propagating any callback failure.

    Args:
        future: the StreamFuture handed to the callback and awaited by the caller.
        stream: the stream being received; used to stop its reader task on failure.
        args: positional arguments forwarded to blob_cb.
        kwargs: keyword arguments forwarded to blob_cb.
    """
    try:
        self.blob_cb(future, *args, **kwargs)
    except Exception as ex:
        log.error(f"blob_cb threw: {ex}\n{secure_format_traceback()}")
        error = StreamError(f"blob_cb threw: {ex}")
        # Fail the future so callers awaiting the blob observe the callback failure;
        # otherwise they would receive the data with no indication anything went wrong.
        try:
            future.set_exception(error)
        except Exception:
            # Best effort: the future may already have been resolved by the reader.
            pass
        # stream.task is not guaranteed on the Stream base type — stop the reader
        # task only when present, so _read_stream does not keep pulling chunks
        # with no active consumer.
        if hasattr(stream, "task"):
            stream.task.stop(error)
Comment on lines +103 to +110
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Stream may not be stopped if blob_cb raises and stream.task is absent

stream.task is not a guaranteed attribute on the Stream base type; hasattr(stream, "task") will be False for any stream that was not enriched with this attribute, causing the task.stop() call to be silently skipped.

When blob_cb raises and the task is not stopped, _read_stream continues to run on stream_thread_pool — pulling chunks from the stream and writing them to blob_task.buffer — even though no consumer is active. The future is eventually resolved via blob_task.future.set_result(result), but any caller await-ing that future will receive the data with no indication that the callback failed.

Consider propagating the error to the future directly so callers can observe it:

def _run_blob_cb(self, future: StreamFuture, stream: Stream, args: tuple, kwargs: dict):
    try:
        self.blob_cb(future, *args, **kwargs)
    except Exception as ex:
        log.error(f"blob_cb threw: {ex}\n{secure_format_traceback()}")
        future.set_exception(StreamError(f"blob_cb threw: {ex}"))
        if hasattr(stream, "task"):
            stream.task.stop(StreamError(f"blob_cb threw: {ex}"))


def _read_stream(self, blob_task: BlobTask):

try:
Expand Down
Loading