Skip to content
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions nvflare/fuel/f3/streaming/byte_receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
class RxTask:
"""Receiving task for ByteStream"""

rx_task_map = {}
rx_task_map: Dict[Tuple[str, int], "RxTask"] = {}
map_lock = threading.Lock()

def __init__(self, sid: int, origin: str, cell: CoreCell):
Expand Down Expand Up @@ -94,14 +94,14 @@ def find_or_create_task(cls, message: Message, cell: CoreCell) -> Optional["RxTa
error = message.get_header(StreamHeaderKey.ERROR_MSG, None)

with cls.map_lock:
task = cls.rx_task_map.get(sid, None)
task = cls.rx_task_map.get((origin, sid), None)
if not task:
if error:
log.warning(f"Received error for non-existing stream: SID {sid} from {origin}")
return None

task = RxTask(sid, origin, cell)
cls.rx_task_map[sid] = task
cls.rx_task_map[(origin, sid)] = task
else:
if error:
task.stop(StreamError(f"{task} Received error from {origin}: {error}"), notify=False)
Expand Down Expand Up @@ -195,7 +195,7 @@ def _handle_incoming_data(self, seq: int, message: Message):
def stop(self, error: StreamError = None, notify=True):

with RxTask.map_lock:
RxTask.rx_task_map.pop(self.sid, None)
RxTask.rx_task_map.pop((self.origin, self.sid), None)

if not error:
return
Expand Down
61 changes: 61 additions & 0 deletions tests/unit_test/fuel/f3/streaming/stream_utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import multiprocessing as mp

from nvflare.fuel.f3.streaming.stream_utils import gen_stream_id


def generate_stream_ids(num_ids: int, result_queue: mp.Queue) -> None:
"""Worker function to generate stream IDs in a separate process.

Args:
num_ids: Number of stream IDs to generate
result_queue: Queue to put the generated IDs
"""
ids = []
for _ in range(num_ids):
ids.append(gen_stream_id())
result_queue.put(ids)
Comment on lines +20 to +30
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unused helper function, leftover from removed multiprocess tests

Suggested change
def generate_stream_ids(num_ids: int, result_queue: mp.Queue) -> None:
"""Worker function to generate stream IDs in a separate process.
Args:
num_ids: Number of stream IDs to generate
result_queue: Queue to put the generated IDs
"""
ids = []
for _ in range(num_ids):
ids.append(gen_stream_id())
result_queue.put(ids)



class TestStreamUtils:
"""Test suite for stream_utils module"""

def test_gen_stream_id_uniqueness_single_process(self):
"""Test that gen_stream_id generates unique IDs within a single process"""
num_ids = 1000
ids = [gen_stream_id() for _ in range(num_ids)]

# Check for uniqueness
assert len(ids) == len(set(ids)), "Generated IDs contain duplicates in single process"

# Check that IDs are monotonically increasing
assert ids == sorted(ids), "Generated IDs are not monotonically increasing"

def test_gen_stream_id_returns_positive_int(self):
"""Test that gen_stream_id returns a positive integer"""
stream_id = gen_stream_id()
assert isinstance(stream_id, int), "Stream ID should be an integer"
assert stream_id > 0, "Stream ID should be positive"

def test_gen_stream_id_sequential_calls(self):
"""Test that sequential calls return increasing IDs"""
id1 = gen_stream_id()
id2 = gen_stream_id()
id3 = gen_stream_id()

assert id2 > id1, "Second ID should be greater than first"
assert id3 > id2, "Third ID should be greater than second"
assert id3 - id2 == id2 - id1 == 1, "IDs should increment by 1"
Loading