Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions nvflare/fuel/f3/streaming/byte_receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
class RxTask:
"""Receiving task for ByteStream"""

rx_task_map = {}
rx_task_map: Dict[Tuple[str, int], "RxTask"] = {}
map_lock = threading.Lock()

def __init__(self, sid: int, origin: str, cell: CoreCell):
Expand Down Expand Up @@ -94,14 +94,14 @@ def find_or_create_task(cls, message: Message, cell: CoreCell) -> Optional["RxTa
error = message.get_header(StreamHeaderKey.ERROR_MSG, None)

with cls.map_lock:
task = cls.rx_task_map.get(sid, None)
task = cls.rx_task_map.get((origin, sid), None)
if not task:
if error:
log.warning(f"Received error for non-existing stream: SID {sid} from {origin}")
return None

task = RxTask(sid, origin, cell)
cls.rx_task_map[sid] = task
cls.rx_task_map[(origin, sid)] = task
else:
if error:
task.stop(StreamError(f"{task} Received error from {origin}: {error}"), notify=False)
Expand Down Expand Up @@ -195,7 +195,7 @@ def _handle_incoming_data(self, seq: int, message: Message):
def stop(self, error: StreamError = None, notify=True):

with RxTask.map_lock:
RxTask.rx_task_map.pop(self.sid, None)
RxTask.rx_task_map.pop((self.origin, self.sid), None)

if not error:
return
Expand Down
61 changes: 61 additions & 0 deletions tests/unit_test/fuel/f3/streaming/stream_utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import multiprocessing as mp

from nvflare.fuel.f3.streaming.stream_utils import gen_stream_id


def generate_stream_ids(num_ids: int, result_queue: mp.Queue) -> None:
"""Worker function to generate stream IDs in a separate process.

Args:
num_ids: Number of stream IDs to generate
result_queue: Queue to put the generated IDs
"""
ids = []
for _ in range(num_ids):
ids.append(gen_stream_id())
result_queue.put(ids)
Comment on lines +20 to +30
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

unused helper function, leftover from removed multiprocess tests

Suggested change
def generate_stream_ids(num_ids: int, result_queue: mp.Queue) -> None:
"""Worker function to generate stream IDs in a separate process.
Args:
num_ids: Number of stream IDs to generate
result_queue: Queue to put the generated IDs
"""
ids = []
for _ in range(num_ids):
ids.append(gen_stream_id())
result_queue.put(ids)



class TestStreamUtils:
"""Test suite for stream_utils module"""

def test_gen_stream_id_uniqueness_single_process(self):
"""Test that gen_stream_id generates unique IDs within a single process"""
num_ids = 1000
ids = [gen_stream_id() for _ in range(num_ids)]

# Check for uniqueness
assert len(ids) == len(set(ids)), "Generated IDs contain duplicates in single process"

# Check that IDs are monotonically increasing
assert ids == sorted(ids), "Generated IDs are not monotonically increasing"

def test_gen_stream_id_returns_positive_int(self):
"""Test that gen_stream_id returns a positive integer"""
stream_id = gen_stream_id()
assert isinstance(stream_id, int), "Stream ID should be an integer"
assert stream_id > 0, "Stream ID should be positive"

def test_gen_stream_id_sequential_calls(self):
"""Test that sequential calls return increasing IDs"""
id1 = gen_stream_id()
id2 = gen_stream_id()
id3 = gen_stream_id()

assert id2 > id1, "Second ID should be greater than first"
assert id3 > id2, "Third ID should be greater than second"
assert id3 - id2 == id2 - id1 == 1, "IDs should increment by 1"
Loading