Skip to content

Commit e27caa8

Browse files
sherm8nclaude
andcommitted
fix(distributed): leaf timeout between requests + streaming space loss
Both found by live-testing 2-node native mode on real hardware with TinyLlama. Both would have hit users on the multi-machine path. 1. Leaf timed out between requests recv_tensor had a 300s default applied to every receive, including the leaf's "wait for next prompt" call in broadcast_metadata. After 5 minutes of user think-time the socket fired TimeoutError and the leaf loop crashed. Fix: broadcast_metadata and broadcast_metadata_objects now pass timeout=None on the receive side — the wait for the next user message is unbounded and shouldn't time out. 2. Streaming output had no spaces _generate_stream decoded one token at a time and yielded the result. For SentencePiece/BPE tokenizers (Llama/Mistral/Qwen), the leading space metadata only appears when decoding multiple tokens together, so individual decodes produced "Thereare50states" instead of "There are 50 states". Fix: track the running token id list, decode the full list each step, and yield only the new substring. Standard incremental-decode pattern used by HF TextStreamer. Both verified live: leaf survived a multi-minute pause between requests, and TinyLlama-Chat output now renders with spaces in the chat UI. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
1 parent 6606be8 commit e27caa8

2 files changed

Lines changed: 23 additions & 8 deletions

File tree

ravnest/communication/communication_dynamic.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,8 @@ def send_tensor(sock, tensor):
4242

4343
@staticmethod
4444
def recv_tensor(sock, device="cpu", timeout=300):
45-
"""Receive a tensor from a socket."""
46-
sock.settimeout(timeout)
45+
"""Receive a tensor from a socket. Pass timeout=None to block indefinitely."""
46+
sock.settimeout(timeout) # None = blocking forever
4747
header = TensorSocket._recv_exactly(sock, 4)
4848
if header is None:
4949
raise ConnectionError("Connection closed while reading header")
@@ -420,8 +420,11 @@ def broadcast_metadata(self, data):
420420
if isinstance(key, str) and key.startswith("meta_"):
421421
TensorSocket.send_tensor(sock, data)
422422
else:
423-
# Receive from root
424-
received = TensorSocket.recv_tensor(self.peers["meta_root"], device=str(self.device))
423+
# Receive from root — block forever; user think-time between
424+
# requests is unbounded so a fixed timeout isn't appropriate here
425+
received = TensorSocket.recv_tensor(
426+
self.peers["meta_root"], device=str(self.device), timeout=None
427+
)
425428
data.copy_(received)
426429

427430
def broadcast_metadata_objects(self, data):
@@ -430,7 +433,8 @@ def broadcast_metadata_objects(self, data):
430433
if isinstance(key, str) and key.startswith("meta_"):
431434
TensorSocket.send_object(sock, data)
432435
else:
433-
received = TensorSocket.recv_object(self.peers["meta_root"])
436+
# Block forever waiting for next prompt — user think-time is unbounded
437+
received = TensorSocket.recv_object(self.peers["meta_root"], timeout=None)
434438
for i in range(len(data)):
435439
if i < len(received):
436440
data[i] = received[i]

ravnest/inference/inference_engine.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,11 @@ def _generate_stream(self, input_ids=None, max_seq_lengths=None, top_k=1, temper
219219
num_generated_tokens = 0
220220
is_generation_done = torch.tensor([False]*bs).to(self.node.device)
221221
pad_token_tensor = torch.tensor([self.tokenizer.pad_token_id]*bs).to(self.node.device)
222+
# Track generated token ids for incremental decoding (BPE/SentencePiece
223+
# tokenizers lose leading-space info when decoded one token at a time,
224+
# so we decode the running list and yield only the new substring)
225+
generated_ids_so_far = []
226+
decoded_so_far = ""
222227
while num_generated_tokens < max_seq_length_in_batch:
223228
self.comm_session.forward_input_shapes[0][1] = seq_length
224229

@@ -261,10 +266,16 @@ def _generate_stream(self, input_ids=None, max_seq_lengths=None, top_k=1, temper
261266
new_token_mask = kwargs['attention_mask'].new_ones((bs,1))
262267
kwargs['attention_mask'] = torch.cat((kwargs['attention_mask'], new_token_mask), axis=-1)
263268

264-
# Yield the decoded token for each sequence in the batch
269+
# Yield the decoded token for each sequence in the batch.
270+
# Decode the full running list and emit only the new substring,
271+
# so BPE/SentencePiece leading spaces are preserved.
265272
if self.node_type != NodeTypes.LEAF:
266-
token_text = self.tokenizer.decode(next_token_ids[0].item(), skip_special_tokens=True)
267-
yield token_text
273+
generated_ids_so_far.append(next_token_ids[0].item())
274+
full_decoded = self.tokenizer.decode(generated_ids_so_far, skip_special_tokens=True)
275+
if len(full_decoded) > len(decoded_so_far):
276+
delta = full_decoded[len(decoded_so_far):]
277+
decoded_so_far = full_decoded
278+
yield delta
268279

269280
is_generation_done = self.is_generation_complete(is_generation_done, next_token_ids, num_generated_tokens, max_seq_lengths)
270281

0 commit comments

Comments
 (0)