Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 21 additions & 6 deletions nvflare/fuel/utils/pipe/file_pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,14 +205,29 @@ def _get_next(self, from_dir: str):
except Exception:
raise BrokenPipeError(f"error reading from {from_dir}")

if files:
files = [os.path.join(from_dir, f) for f in files]
files.sort(key=os.path.getmtime, reverse=False)
file_path = files[0]
return self._read_file(file_path)
else:
if not files:
return None

files = [os.path.join(from_dir, f) for f in files]

def _safe_mtime(f):
try:
return os.path.getmtime(f)
except FileNotFoundError:
return float("inf")

files.sort(key=_safe_mtime)
for file_path in files:
try:
return self._read_file(file_path)
except BrokenPipeError:
# File was removed between listdir and read (TOCTOU race).
# This happens when the sender's heartbeat send times out and
# deletes its own file just as the receiver is about to read it.
# Skip this file and try the next one.
continue
return None

def _get_from_dir(self, from_dir: str, timeout=None):
if not timeout or timeout <= 0:
return self._get_next(from_dir)
Expand Down
9 changes: 8 additions & 1 deletion nvflare/fuel/utils/pipe/pipe_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,7 +344,14 @@ def _try_read(self):
# the pipe handler is most likely stopped, but we leave it for the while loop to decide
continue

msg = p.receive()
try:
msg = p.receive()
except BrokenPipeError as e:
if not self.asked_to_stop:
self._add_message(
self._make_event_message(Topic.PEER_GONE, f"read error: {secure_format_exception(e)}")
)
break
now = time.time()

if msg:
Expand Down
75 changes: 75 additions & 0 deletions tests/unit_test/fuel/utils/pipe/file_pipe_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
# limitations under the License.

import os
from unittest.mock import patch

import pytest

from nvflare.fuel.utils.constants import Mode
from nvflare.fuel.utils.pipe.file_pipe import FilePipe
Expand Down Expand Up @@ -70,3 +73,75 @@ def test_open_does_not_set_remove_root_for_preexisting_dir(self, tmp_path):
pipe = FilePipe(mode=Mode.PASSIVE, root_path=root)
pipe.open("test_pipe")
assert pipe._remove_root is False


class TestFilePipeGetNextTOCTOU:
"""_get_next must handle TOCTOU races without raising BrokenPipeError."""

def _make_pipe(self, tmp_path):
root = str(tmp_path / "pipe_root")
pipe = FilePipe(mode=Mode.PASSIVE, root_path=root)
pipe.open("test_pipe")
return pipe

def test_file_disappears_during_mtime_sort_returns_none(self, tmp_path):
"""If a file disappears while os.getmtime is called during sort, _get_next returns None.

_safe_mtime catches FileNotFoundError and returns float('inf') so the sort
succeeds, but _read_file then also raises BrokenPipeError because the file
is gone. _get_next must skip it and return None.
"""
pipe = self._make_pipe(tmp_path)
from_dir = pipe.y_path

fake_file = os.path.join(from_dir, "fake_msg.fobs")
open(fake_file, "w").close()

with patch("nvflare.fuel.utils.pipe.file_pipe.os.path.getmtime", side_effect=FileNotFoundError):
with patch.object(pipe, "_read_file", side_effect=BrokenPipeError("pipe closed")):
result = pipe._get_next(from_dir)

assert result is None

def test_file_disappears_between_listdir_and_read_returns_none(self, tmp_path):
"""If a file disappears between listdir and _read_file, _get_next returns None."""
pipe = self._make_pipe(tmp_path)
from_dir = pipe.y_path

fake_file = os.path.join(from_dir, "fake_msg.fobs")
open(fake_file, "w").close()

# Simulate the file being removed just before _read_file is called.
original_read_file = pipe._read_file

def disappear_then_raise(path):
os.remove(path)
raise BrokenPipeError("pipe closed")

with patch.object(pipe, "_read_file", side_effect=disappear_then_raise):
result = pipe._get_next(from_dir)

assert result is None

def test_all_files_race_away_returns_none(self, tmp_path):
"""When every file in the listing races away, _get_next returns None without raising."""
pipe = self._make_pipe(tmp_path)
from_dir = pipe.y_path

for i in range(3):
f = os.path.join(from_dir, f"msg_{i}.fobs")
open(f, "w").close()

with patch.object(pipe, "_read_file", side_effect=BrokenPipeError("pipe closed")):
result = pipe._get_next(from_dir)

assert result is None

def test_get_next_does_not_suppress_broken_pipe_from_listdir(self, tmp_path):
"""_get_next must re-raise BrokenPipeError when os.listdir itself fails (real failure)."""
pipe = self._make_pipe(tmp_path)
from_dir = pipe.y_path

with patch("nvflare.fuel.utils.pipe.file_pipe.os.listdir", side_effect=OSError("dir gone")):
with pytest.raises(BrokenPipeError):
pipe._get_next(from_dir)
111 changes: 111 additions & 0 deletions tests/unit_test/fuel/utils/pipe/pipe_handler_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time

from nvflare.fuel.utils.constants import Mode
from nvflare.fuel.utils.pipe.pipe import Pipe, Topic
from nvflare.fuel.utils.pipe.pipe_handler import PipeHandler


class _BrokenPipe(Pipe):
"""Minimal Pipe stub whose receive() always raises BrokenPipeError."""

def __init__(self, error_msg):
super().__init__(mode=Mode.ACTIVE)
self._error_msg = error_msg

def open(self, name):
pass

def close(self):
pass

def send(self, msg, timeout=None):
return True

def receive(self, timeout=None):
raise BrokenPipeError(self._error_msg)

def get_last_peer_active_time(self):
return 0.0

def clear(self):
pass

def can_resend(self) -> bool:
return False


class TestPipeHandlerBrokenPipe:
"""PipeHandler._try_read must emit PEER_GONE and stop when receive() raises BrokenPipeError."""

def _make_handler(self, pipe):
return PipeHandler(
pipe=pipe,
read_interval=0.01,
heartbeat_interval=5.0,
heartbeat_timeout=30.0,
)

def _drain_messages(self, handler, timeout=1.0):
messages = []
deadline = time.time() + timeout
while time.time() < deadline:
msg = handler.get_next()
if msg:
messages.append(msg)
time.sleep(0.01)
return messages

def test_pipe_closed_emits_peer_gone_and_stops(self):
"""When receive() raises BrokenPipeError('pipe is not open'), PEER_GONE is emitted and the reader stops."""
handler = self._make_handler(_BrokenPipe("pipe is not open"))
handler.start()

messages = self._drain_messages(handler, timeout=1.0)
handler.stop()

assert any(m.topic == Topic.PEER_GONE for m in messages)
assert handler.reader is None

def test_pipe_dir_gone_emits_peer_gone_and_stops(self):
"""When receive() raises BrokenPipeError('error reading from ...'), PEER_GONE is emitted and the reader stops."""
handler = self._make_handler(_BrokenPipe("error reading from /some/dir"))
handler.start()

messages = self._drain_messages(handler, timeout=1.0)
handler.stop()

assert any(m.topic == Topic.PEER_GONE for m in messages)
assert handler.reader is None

def test_graceful_stop_does_not_emit_peer_gone(self):
"""BrokenPipeError raised after stop() is called must not emit PEER_GONE."""
pipe = _BrokenPipe("pipe is not open")
handler = self._make_handler(pipe)

received = []
handler.set_status_cb(lambda msg: received.append(msg))

handler.start()
# Stop immediately — asked_to_stop=True before _try_read can emit anything.
handler.stop()

# Give the reader thread time to finish.
deadline = time.time() + 1.0
while handler.reader is not None and time.time() < deadline:
time.sleep(0.01)

assert not any(m.topic == Topic.PEER_GONE for m in received)
Loading