|
41 | 41 | MBPIPE_SCHEMA_PREFIX, |
42 | 42 | ) |
43 | 43 | from bloombee.utils.debug import dprint |
| 44 | +import traceback |
44 | 45 |
|
45 | 46 | # [MBPIPE] Cross-stage streaming push support |
46 | 47 | _cross_stage_push_callback = None # Will be set by handler for cross-stage streaming |
@@ -424,6 +425,27 @@ def restore_hidden_states( |
424 | 425 |
|
425 | 426 | return restored_hidden_states |
426 | 427 |
|
def ensure_tensors(flat_tensors):
    """Coerce a heterogeneous sequence of values into a tuple of torch.Tensors.

    Used before serialization so that every slot handed to
    ``serialize_torch_tensor`` is a real tensor. Coercion rules, per element:

    * ``None``                    -> ``torch.tensor(0)`` (scalar-zero placeholder;
      presumably decoded back to "absent" downstream — TODO confirm against the
      deserializing side).
    * ``torch.Tensor``            -> passed through unchanged.
    * ``list`` / ``tuple``        -> ``None`` items are silently dropped; if the
      remainder is empty, the scalar-zero placeholder is used; if the FIRST
      surviving item is a tensor the whole list is ``torch.stack``-ed
      (NOTE(review): assumes the list is homogeneous — a mixed tensor/number
      list would mis-dispatch here); otherwise ``torch.tensor(list)``.
    * ``int`` / ``float`` / ``bool`` -> wrapped via ``torch.tensor``.

    Args:
        flat_tensors: iterable of values to coerce.

    Returns:
        tuple[torch.Tensor, ...]: one tensor per input element, in order.

    Raises:
        TypeError: if an element is of an unsupported type (e.g. str, dict).
    """
    result = []
    for i, t in enumerate(flat_tensors):
        if t is None:
            # Placeholder scalar so the slot still serializes as a tensor.
            result.append(torch.tensor(0))
        elif isinstance(t, torch.Tensor):
            result.append(t)
        elif isinstance(t, (list, tuple)):
            # Drop None entries before deciding how to tensorize the rest.
            t_clean = [x for x in t if x is not None]
            if len(t_clean) == 0:
                result.append(torch.tensor(0))
            elif isinstance(t_clean[0], torch.Tensor):
                result.append(torch.stack(t_clean))
            else:
                result.append(torch.tensor(t_clean))
        elif isinstance(t, (int, float, bool)):
            result.append(torch.tensor(t))
        else:
            # Fixed message wording (was: "cant trans to tensor").
            raise TypeError(
                f"flat_tensors[{i}] cannot be converted to a tensor: "
                f"type={type(t)}, value={t}"
            )
    return tuple(result)
| 448 | + |
427 | 449 | async def iterate_rpc_inference( |
428 | 450 | requested_uids: Sequence[ExpertUID], |
429 | 451 | requested_backends: Sequence[TransformerBackend], |
@@ -1197,7 +1219,7 @@ async def iterate_rpc_inference( |
1197 | 1219 | if args_structure is not None: |
1198 | 1220 | flat_tensors, kwargs = unpack_args_kwargs(flat_tensors, args_structure) |
1199 | 1221 |
|
1200 | | - hidden_states, keep_indices, need_pruning1, prompts, hypo_ids, tree_attention_mask, kv_cache_position_ids, draft_tokens, prefill_length, is_spec_dec1, *_ = flat_tensors |
| 1222 | + hidden_states, keep_indices, need_pruning1, tree_attention_mask, kv_cache_position_ids, draft_tokens, prefill_length, is_spec_dec1, prompts, hypo_ids, *_ = flat_tensors |
1201 | 1223 | draft_tokens = draft_tokens if draft_tokens is not None and not is_dummy(draft_tokens) else None |
1202 | 1224 |
|
1203 | 1225 | # Fix for bus error in cross-machine setups: ensure tensors are contiguous |
@@ -1229,11 +1251,7 @@ async def iterate_rpc_inference( |
1229 | 1251 | ) |
1230 | 1252 | if not need_pruning and _as_python_bool(step_metadata.get("need_pruning", 0)): |
1231 | 1253 | need_pruning = True |
1232 | | - |
1233 | | - # logger.info(f"hidden_states: {hidden_states.shape}") |
1234 | | - # logger.info(f"keep_indices: {keep_indices.shape}") |
1235 | | - # logger.info(f"draft_tokens: {draft_tokens.shape}") |
1236 | | - |
| 1254 | + |
1237 | 1255 | if is_spec_dec and draft_tokens is not None and draft_tokens.shape[0] != hidden_states.shape[0]: |
1238 | 1256 | hidden_states = restore_hidden_states(hidden_states, keep_indices, draft_tokens.shape[-1]) |
1239 | 1257 |
|
@@ -1845,10 +1863,14 @@ async def process_microbatch(mb_idx: int, mb_start: int, mb_end: int): |
1845 | 1863 |
|
1846 | 1864 | serialize_start = perf_counter() |
1847 | 1865 | need_pruning_next = torch.tensor(0) |
| 1866 | + |
| 1867 | + flat_tensors = (hidden_states, keep_indices, need_pruning_next, tree_attention_mask, kv_cache_position_ids, draft_tokens) |
| 1868 | + flat_tensors = ensure_tensors(flat_tensors) |
1848 | 1869 | output_tensors = [ |
1849 | 1870 | serialize_torch_tensor(result.to(proto.dtype), proto.compression, allow_inplace=True) |
1850 | | - for result, proto in zip((hidden_states, keep_indices, need_pruning_next), nested_flatten(requested_backends[-1].outputs_schema)) |
| 1871 | + for result, proto in zip(flat_tensors, nested_flatten(requested_backends[-1].outputs_schema)) |
1851 | 1872 | ] |
| 1873 | + |
1852 | 1874 | serialize_end = perf_counter() |
1853 | 1875 | serialize_time = (serialize_end - serialize_start) * 1000 # ms |
1854 | 1876 | # print('after serialize and send last layer outputs ', ) |
|
0 commit comments