
Commit 9e97c90

fix server 2 server net data error (#51)
Co-authored-by: JesusmiCaH <1010851196jch@gmail.com>
Parent: d3a5adf

7 files changed, 70 additions and 32 deletions


benchmarks/benchmark_speculative_decoding.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -99,7 +99,8 @@ def benchmark_inference(process_idx, args, result_pipe):
 
     result = ""
     start_time = perf_counter()
-    result = model.generate(input_ids=input_ids, drafter=drafter)
+    max_new_tokens = getattr(args, 'seq_len', 128)
+    result = model.generate(input_ids=input_ids, drafter=drafter, max_new_tokens=max_new_tokens)
     time = perf_counter() - start_time
     generated_tokens_nums = []
     for i in range(batch_size):
```
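
Without an explicit `max_new_tokens`, the benchmark silently measured whatever the model's default generation length happened to be, so throughput numbers were not comparable across runs; the `getattr` fallback also keeps older launch scripts without a `seq_len` argument working. A minimal sketch of the pattern, using an illustrative `resolve_max_new_tokens` helper and `argparse.Namespace` (both assumptions, not repo code):

```python
from argparse import Namespace

def resolve_max_new_tokens(args, default: int = 128) -> int:
    # Fall back to the default when the run was launched without --seq_len.
    return getattr(args, 'seq_len', default)

assert resolve_max_new_tokens(Namespace(seq_len=256)) == 256
assert resolve_max_new_tokens(Namespace()) == 128
```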

src/bloombee/client/inference_session.py

Lines changed: 7 additions & 8 deletions
```diff
@@ -175,13 +175,14 @@ def _infer_batch_dim(value) -> int:
         input_tensors, args_structure = pack_args_kwargs(
             inputs,
             normalize_arg(keep_indices),
-            normalize_arg(torch.tensor(1 if need_pruning else 0)),
-            prompts, hypo_ids,
+            normalize_arg(torch.tensor(1 if need_pruning else 0)),
             normalize_arg(tree_attention_mask),
             normalize_arg(kv_cache_position_ids),
             normalize_arg(draft_tokens),
             normalize_arg(prefill_length),
             normalize_arg(torch.tensor(1 if is_spec_dec else 0)),
+            prompts,
+            hypo_ids,
         )
         logger.debug(f"_ServerInferenceSession step id {step_id}")
         request_metadata = dict(session_id=self.session_id, step_id=step_id)
@@ -199,12 +200,10 @@ def _infer_batch_dim(value) -> int:
         request_metadata["start_from_position"] = self._position
         # Enable server-to-server communication to trigger CROSS_GPU_TRANSFER
         # Speculative decoding keeps strict full-batch semantics; avoid cross-stage push.
-        if self.config.use_server_to_server and not is_spec_dec:
+        if self.config.use_server_to_server:
             next_servers = self._collect_next_servers()
             if next_servers:
                 request_metadata["next_servers"] = next_servers
-        elif is_spec_dec:
-            request_metadata["disable_cross_stage_push"] = 1
 
         request_metadata["args_structure"] = args_structure
 
@@ -495,7 +494,7 @@ def step(  # perform one inference step, processing the input data and its prompts
             # 🔍 CLIENT DEBUG: Log server span processing start
             span_start_time = time.perf_counter()
 
-            inputs, keep_indices, need_pruning_next = server_session.step(
+            inputs, keep_indices, *_ = server_session.step(
                 inputs,
                 prompts[server_session.span.start : server_session.span.end],
                 hypo_ids,
@@ -516,7 +515,7 @@ def step(  # perform one inference step, processing the input data and its prompts
             # 🔍 CLIENT DEBUG: Log server span processing end
             span_end_time = time.perf_counter()
             span_duration = (span_end_time - span_start_time) * 1000  # ms
-            logger.debug(f"[CLIENT_SERVER_END] ServerIdx={server_idx} | Blocks={server_session.span.start}:{server_session.span.end} | Duration={span_duration:.2f}ms")
+            logger.info(f"[CLIENT_SERVER_END] ServerIdx={server_idx} | Blocks={server_session.span.start}:{server_session.span.end} | Duration={span_duration:.2f}ms")
             # print('inputs ', inputs)
             # print('inputs.shape ', inputs.shape)
             server_idx += 1
@@ -551,7 +550,7 @@ def step(  # perform one inference step, processing the input data and its prompts
         # 🔍 CLIENT DEBUG: Log inference step end
         inference_step_end = time.perf_counter()
         inference_step_duration = (inference_step_end - inference_step_start) * 1000  # ms
-        logger.debug(f"[CLIENT_INFERENCE_END] Position={self._position} | Duration={inference_step_duration:.2f}ms | Servers={server_idx}")
+        logger.info(f"[CLIENT_INFERENCE_END] Position={self._position} | Duration={inference_step_duration:.2f}ms | Servers={server_idx}")
 
         outputs = outputs.to(device=inputs_device, dtype=inputs_dtype)
         # print('client inference session outputs ', outputs.shape)
```
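
Moving `prompts` and `hypo_ids` to the tail of `pack_args_kwargs` is the client half of the wire-format fix: the server unpack (block_functions.py below) and the server-to-server push path (handler.py) both slice the flat tensor list positionally, so every hop must agree on one order. A toy sketch of that positional contract, with simplified stand-ins for the pack/unpack helpers (the real hivemind utilities also carry an `args_structure` descriptor):

```python
import torch

def pack(*tensors):
    # Flatten positional args; the receiver must unpack in the same order.
    return list(tensors)

def unpack(flat):
    (hidden, keep_idx, need_pruning, mask, kv_ids,
     draft, prefill, is_spec, prompts, hypo_ids, *_) = flat
    return hidden, prompts, hypo_ids

flat = pack(torch.zeros(1, 4, 8),                 # hidden_states
            torch.arange(4),                      # keep_indices
            torch.tensor(0),                      # need_pruning
            torch.ones(4, 4, dtype=torch.bool),   # tree_attention_mask
            torch.arange(4),                      # kv_cache_position_ids
            torch.zeros(1, 4, dtype=torch.long),  # draft_tokens
            torch.tensor(4),                      # prefill_length
            torch.tensor(1),                      # is_spec_dec
            torch.zeros(1, 3),                    # prompts: now at the tail
            torch.zeros(1, dtype=torch.long))     # hypo_ids: now at the tail

hidden, prompts, hypo_ids = unpack(flat)
assert prompts.shape == (1, 3)
```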

src/bloombee/models/llama/speculative_model.py

Lines changed: 18 additions & 8 deletions
```diff
@@ -35,11 +35,11 @@ def generate(
     logits_processor: Optional[LogitsProcessorList] = None,
     stopping_criteria: Optional[StoppingCriteriaList] = None,
     streamer: Optional["BaseStreamer"] = None,
-    beam_width: int = 2,
+    beam_width: int = 1,
     max_tree_depth: int = 4,
     use_kv_cache: bool = True,
     kv_cache_window: int = 2048,
-    max_new_tokens: int = 64,
+    max_new_tokens: int = 128,
     session_max_length: Optional[int] = None,
     **model_kwargs,
 ) -> torch.LongTensor:
@@ -132,6 +132,8 @@ def _sample_with_session(
     initial_len = input_ids.shape[1]
     t0 = time.perf_counter()  # used to time when the first sample reaches the target
     has_printed_first_reach = False  # make sure this prints only once
+    sample_finish_times = [None] * batch_size
+    sample_finished = torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device)
     while not finished and (seq_lengths.min().item() - initial_len) < max_new_tokens:
         # 1. Build speculative trees using SSM - pass in seq_lengths
         t1 = time.perf_counter()
@@ -207,16 +209,24 @@ def _sample_with_session(
         finished = unfinished_sequences.max() == 0
         total_time = time.perf_counter() - t1
         logger.info(f"Step {step_idx}: FTotal Time Elapsed={total_time:.4f} seconds")
-        step_idx += 1
         current_generations = seq_lengths - initial_len
-        if not has_printed_first_reach and current_generations.max().item() >= max_new_tokens:
-            first_reach_time = time.perf_counter() - t0
-            logger.info(f"🚀 [First Reach] first sample reached max_new_tokens after {first_reach_time:.4f}s")
-            has_printed_first_reach = True
-
+        for i in range(batch_size):
+            if (current_generations[i] >= max_new_tokens and not sample_finished[i]):
+                finish_time = time.perf_counter() - t0
+                sample_finish_times[i] = finish_time
+                sample_finished[i] = True
+                logger.info(f"step {step_idx} Sample {i} finished generation ({max_new_tokens} tokens) at {finish_time:.4f}s")
+        step_idx += 1
 
     if streamer is not None:
         streamer.end()
+
+    logger.info("====== Batch Generation Summary ======")
+    for i, t in enumerate(sample_finish_times):
+        if t is not None:
+            logger.info(f"Sample {i}: finished at {t:.4f}s")
+        else:
+            logger.info(f"Sample {i}: did not reach max_new_tokens")
 
     return current_input_ids
```
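
Because each sequence in a batch accepts a different number of draft tokens per step, samples hit `max_new_tokens` at different wall-clock times; the old code only reported the first one. A self-contained sketch of the per-sample bookkeeping pattern (the random acceptance counts are synthetic, purely for illustration):

```python
import time
import torch

batch_size, max_new_tokens = 3, 8
generated = torch.zeros(batch_size, dtype=torch.long)
finished = torch.zeros(batch_size, dtype=torch.bool)
finish_times = [None] * batch_size
t0 = time.perf_counter()

while not bool(finished.all()):
    # Each step, every unfinished sample accepts 1-3 draft tokens.
    generated += torch.where(finished, 0, torch.randint(1, 4, (batch_size,)))
    for i in range(batch_size):
        if generated[i] >= max_new_tokens and not finished[i]:
            finish_times[i] = time.perf_counter() - t0
            finished[i] = True

for i, t in enumerate(finish_times):
    print(f"Sample {i}: finished at {t:.6f}s")
```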

src/bloombee/server/backend.py

Lines changed: 1 addition & 5 deletions
```diff
@@ -369,7 +369,7 @@ def _flag_to_bool(value) -> bool:
         position_ids = self._position_ids_cache[cache_key] + (cache_len + offset)
         if self._is_spec_decoding:
             rotary_position_ids = self._create_tree_position_ids_with_invalid_cache(
-                width=2,
+                width=1,
                 depth=4,
                 prefill_length=inference_info.prefill_length - 1,
                 kv_cache_position_ids=kv_cache_position_ids,
@@ -468,14 +468,10 @@ def _flag_to_bool(value) -> bool:
             self.pruner_manager.train_lm_head(middle_norm_hidden_states, norm_hidden_states)
 
         if not training_mode and self._is_spec_decoding and self._need_pruning and self._is_last_block:
-            t6 = time.perf_counter()
             norm_hidden_states = self.module.rms_norm(output_hidden_states)
             keep_indices = self.prune_draft_tree(norm_hidden_states, inference_info.draft_tokens, full_mask)
             keep_indices = keep_indices
 
-            t7 = time.perf_counter()
-            logger.info(f"prune_draft_tree spend: {t7 - t6}")
-
         if not training_mode and self._is_spec_decoding and self._is_last_block:
             original_hidden_states = output_hidden_states
             batch_size, seq_len, hidden_size = original_hidden_states.shape
```
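
The `width` passed here has to match the drafter's `beam_width`, which the companion change in speculative_model.py sets to 1. With width 1 the draft "tree" degenerates into a single chain, so valid rotary positions are just consecutive offsets after the prefill. A hedged sketch of that degenerate geometry (an illustration only, not the repo's `_create_tree_position_ids_with_invalid_cache`):

```python
import torch

def chain_position_ids(prefill_length: int, depth: int) -> torch.Tensor:
    # width == 1: each draft level holds exactly one node, so rotary
    # positions are consecutive offsets immediately after the prefill.
    return prefill_length + torch.arange(depth)

print(chain_position_ids(prefill_length=10, depth=4))  # tensor([10, 11, 12, 13])
```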

src/bloombee/server/block_functions.py

Lines changed: 29 additions & 7 deletions
```diff
@@ -41,6 +41,7 @@
     MBPIPE_SCHEMA_PREFIX,
 )
 from bloombee.utils.debug import dprint
+import traceback
 
 # [MBPIPE] Cross-stage streaming push support
 _cross_stage_push_callback = None  # Will be set by handler for cross-stage streaming
@@ -424,6 +425,27 @@ def restore_hidden_states(
 
     return restored_hidden_states
 
+def ensure_tensors(flat_tensors):
+    result = []
+    for i, t in enumerate(flat_tensors):
+        if t is None:
+            result.append(torch.tensor(0))
+        elif isinstance(t, torch.Tensor):
+            result.append(t)
+        elif isinstance(t, (list, tuple)):
+            t_clean = [x for x in t if x is not None]
+            if len(t_clean) == 0:
+                result.append(torch.tensor(0))
+            elif isinstance(t_clean[0], torch.Tensor):
+                result.append(torch.stack(t_clean))
+            else:
+                result.append(torch.tensor(t_clean))
+        elif isinstance(t, (int, float, bool)):
+            result.append(torch.tensor(t))
+        else:
+            raise TypeError(f"flat_tensors[{i}] cannot be converted to a tensor: type={type(t)}, value={t}")
+    return tuple(result)
+
 async def iterate_rpc_inference(
     requested_uids: Sequence[ExpertUID],
     requested_backends: Sequence[TransformerBackend],
@@ -1197,7 +1219,7 @@ async def iterate_rpc_inference(
         if args_structure is not None:
             flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure)
 
-        hidden_states, keep_indices, need_pruning1, prompts, hypo_ids, tree_attention_mask, kv_cache_position_ids, draft_tokens, prefill_length, is_spec_dec1, *_ = flat_tensors
+        hidden_states, keep_indices, need_pruning1, tree_attention_mask, kv_cache_position_ids, draft_tokens, prefill_length, is_spec_dec1, prompts, hypo_ids, *_ = flat_tensors
         draft_tokens = draft_tokens if draft_tokens is not None and not is_dummy(draft_tokens) else None
 
         # Fix for bus error in cross-machine setups: ensure tensors are contiguous
@@ -1229,11 +1251,7 @@ async def iterate_rpc_inference(
         )
         if not need_pruning and _as_python_bool(step_metadata.get("need_pruning", 0)):
             need_pruning = True
-
-        # logger.info(f"hidden_states: {hidden_states.shape}")
-        # logger.info(f"keep_indices: {keep_indices.shape}")
-        # logger.info(f"draft_tokens: {draft_tokens.shape}")
-
+
         if is_spec_dec and draft_tokens is not None and draft_tokens.shape[0] != hidden_states.shape[0]:
             hidden_states = restore_hidden_states(hidden_states, keep_indices, draft_tokens.shape[-1])
 
@@ -1845,10 +1863,14 @@ async def process_microbatch(mb_idx: int, mb_start: int, mb_end: int):
 
         serialize_start = perf_counter()
         need_pruning_next = torch.tensor(0)
+
+        flat_tensors = (hidden_states, keep_indices, need_pruning_next, tree_attention_mask, kv_cache_position_ids, draft_tokens)
+        flat_tensors = ensure_tensors(flat_tensors)
         output_tensors = [
             serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True)
-            for result, proto in zip((hidden_states, keep_indices, need_pruning_next), nested_flatten(requested_backends[-1].outputs_schema))
+            for result, proto in zip(flat_tensors, nested_flatten(requested_backends[-1].outputs_schema))
         ]
+
        serialize_end = perf_counter()
        serialize_time = (serialize_end - serialize_start) * 1000  # ms
        # print('after serialize and send last layer outputs ', )
```
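
`ensure_tensors` is the defensive half of the serialization fix: every slot handed to `serialize_torch_tensor` must already be a `torch.Tensor`, but in speculative decoding `tree_attention_mask`, `kv_cache_position_ids`, and `draft_tokens` can arrive as `None`, scalars, or lists depending on the step. A quick usage illustration, assuming the `ensure_tensors` definition from the hunk above is in scope:

```python
import torch

mixed = (torch.ones(2, 3),                  # already a tensor: passed through
         None,                              # None: becomes scalar tensor(0)
         1,                                 # python int: becomes tensor(1)
         [4, 5, 6],                         # list of numbers: tensor([4, 5, 6])
         [torch.zeros(2), torch.ones(2)])   # list of tensors: stacked to (2, 2)

print([t.shape for t in ensure_tensors(mixed)])
# [torch.Size([2, 3]), torch.Size([]), torch.Size([]), torch.Size([3]), torch.Size([2, 2])]
```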

src/bloombee/server/handler.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -1480,15 +1480,16 @@ async def _push_outputs(
 
 
         normalized_outputs = self._normalize_serialized_tensors(serialized_outputs)
-        next_tensors = normalized_outputs + list(request.tensors[3:])
+        next_tensors_data = normalized_outputs + list(request.tensors[6:])
+        next_tensors = serialized_outputs + request.tensors[6:]
 
         next_metadata = metadata.copy()
         next_metadata.update(session_id=next_session_id, next_servers=next_servers[1:], pushed=True)
         sender_send_us = self._now_us()
         next_metadata["clock_sync_sender_send_us"] = sender_send_us
 
         stub = self.get_stub(self._p2p, next_peer_id)
-        push_tensor_bytes = sum(len(t.buffer) for t in next_tensors)
+        push_tensor_bytes = sum(len(t.buffer) for t in next_tensors_data)
         serialized_next_metadata = MSGPackSerializer.dumps(next_metadata)
         push_metadata_bytes = len(serialized_next_metadata)
 
```
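The slice change from `request.tensors[3:]` to `request.tensors[6:]` mirrors the serialization change in block_functions.py: a stage now emits six tensors instead of three, so when pushing to the next server the first six slots of the incoming request must be replaced by fresh outputs, while the remainder (`prefill_length`, `is_spec_dec`, `prompts`, `hypo_ids`, ...) is forwarded untouched. A toy model of that splice, with plain strings standing in for serialized tensor protos:

```python
NUM_OUTPUT_TENSORS = 6  # hidden_states, keep_indices, need_pruning_next,
                        # tree_attention_mask, kv_cache_position_ids, draft_tokens

incoming = ["h", "keep", "prune", "mask", "kv_ids", "draft",
            "prefill", "is_spec", "prompts", "hypo_ids"]
outputs = ["h'", "keep'", "prune'", "mask'", "kv_ids'", "draft'"]

forwarded = outputs + incoming[NUM_OUTPUT_TENSORS:]
assert forwarded == ["h'", "keep'", "prune'", "mask'", "kv_ids'", "draft'",
                     "prefill", "is_spec", "prompts", "hypo_ids"]
```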
src/bloombee/server/server.py

Lines changed: 10 additions & 1 deletion
```diff
@@ -330,7 +330,7 @@ def __init__(
         # Create configuration
         config = PruningConfig(
             method=PruningMethod.ADAPTIVE_NEURAL,
-            neural_threshold=0.5,
+            neural_threshold=0.9,
             simple_threshold=0.1
         )
 
@@ -716,6 +716,15 @@ def create(
                 dtype=torch.int64,
                 compression=compression
             ),
+            BatchTensorDescriptor(
+                1, 64, 64, dtype=torch.bool
+            ),  # tree_attention_mask
+            BatchTensorDescriptor(
+                1, 128, dtype=torch.int64
+            ),  # kv_cache_position_ids
+            BatchTensorDescriptor(
+                1, 128, dtype=torch_dtype
+            ),  # draft_tokens
         ),
         min_batch_size=min_batch_size,
         max_batch_size=max_batch_size,
```
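
These three descriptors extend the last block's `outputs_schema` to one slot per tensor serialized in `process_microbatch` (block_functions.py above). Since `zip` silently truncates at the shorter sequence, a schema that is too short would drop the speculative-decoding tensors without raising. A simplified correspondence check, with plain dtypes standing in for `BatchTensorDescriptor`s (shapes are illustrative):

```python
import torch

# One dtype slot per serialized output, in the order used by process_microbatch:
schema_dtypes = (torch.float32, torch.int64, torch.int64,  # hidden_states, keep_indices, need_pruning_next
                 torch.bool, torch.int64, torch.float32)   # tree_attention_mask, kv_cache_position_ids, draft_tokens

flat_tensors = (torch.zeros(1, 4, 8), torch.arange(4), torch.tensor(0),
                torch.ones(1, 64, 64, dtype=torch.bool),
                torch.arange(128).unsqueeze(0), torch.zeros(1, 128))

# zip() stops at the shorter sequence, so schema and outputs must match in length.
assert len(schema_dtypes) == len(flat_tensors)
outputs = [t.to(dtype) for t, dtype in zip(flat_tensors, schema_dtypes)]
```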
