Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions nvflare/fuel/f3/streaming/byte_receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
class RxTask:
"""Receiving task for ByteStream"""

rx_task_map = {}
rx_task_map: Dict[Tuple[str, int], "RxTask"] = {}
map_lock = threading.Lock()

def __init__(self, sid: int, origin: str, cell: CoreCell):
Expand Down Expand Up @@ -94,14 +94,14 @@ def find_or_create_task(cls, message: Message, cell: CoreCell) -> Optional["RxTa
error = message.get_header(StreamHeaderKey.ERROR_MSG, None)

with cls.map_lock:
task = cls.rx_task_map.get(sid, None)
task = cls.rx_task_map.get((origin, sid), None)
if not task:
if error:
log.warning(f"Received error for non-existing stream: SID {sid} from {origin}")
return None

task = RxTask(sid, origin, cell)
cls.rx_task_map[sid] = task
cls.rx_task_map[(origin, sid)] = task
else:
if error:
task.stop(StreamError(f"{task} Received error from {origin}: {error}"), notify=False)
Expand Down Expand Up @@ -195,7 +195,7 @@ def _handle_incoming_data(self, seq: int, message: Message):
def stop(self, error: StreamError = None, notify=True):

with RxTask.map_lock:
RxTask.rx_task_map.pop(self.sid, None)
RxTask.rx_task_map.pop((self.origin, self.sid), None)

if not error:
return
Expand Down
Loading