Skip to content

Commit 5b186f0

Browse files
author
Xu Xiong
committed
llama13B
1 parent bca8ebd commit 5b186f0

5 files changed

Lines changed: 8 additions & 7 deletions

File tree

benchmarks/benchmark_speculative_decoding.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def benchmark_inference(process_idx, args, result_pipe):
7373

7474
drafter = MultiSSMDrafter(
7575
ssm_model_name="JackFram/llama-68m",
76-
num_workers=4,
76+
num_workers=2,
7777
device="cuda"
7878
)
7979
model = AutoDistributedSpeculativeModel.from_pretrained(
@@ -88,6 +88,7 @@ def benchmark_inference(process_idx, args, result_pipe):
8888
test_prompts = []
8989
for item in sampled:
9090
test_prompts.append(item["instruction"])
91+
# test_prompts.append("Generate a list of the best places to eat in London.")
9192

9293
# base_prompt = (
9394
# "Quantum mechanics explains the behavior of particles at very small scales. "

src/bloombee/server/backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -468,8 +468,8 @@ def _flag_to_bool(value) -> bool:
468468
self.pruner_manager.train_lm_head(middle_norm_hidden_states, norm_hidden_states)
469469

470470
if not training_mode and self._is_spec_decoding and self._need_pruning and self._is_last_block:
471-
# norm_hidden_states = self.module.rms_norm(output_hidden_states)
472-
# keep_indices = self.prune_draft_tree(norm_hidden_states, inference_info.draft_tokens, full_mask)
471+
norm_hidden_states = self.module.rms_norm(output_hidden_states)
472+
keep_indices = self.prune_draft_tree(norm_hidden_states, inference_info.draft_tokens, full_mask)
473473
keep_indices = keep_indices
474474

475475
if not training_mode and self._is_spec_decoding and self._is_last_block:

src/bloombee/server/handler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -795,7 +795,7 @@ async def _cross_stage_push_wrapper(mb_hidden, mb_keep, push_metadata):
795795
push_tensor_bytes = sum(len(t.buffer) for t in next_tensors)
796796

797797
# Simulate network transmission delay
798-
NETWORK_SPEED_BYTES_PER_SEC = 10 * 1024 * 1024 # 10 MB/s
798+
NETWORK_SPEED_BYTES_PER_SEC = 50 * 1024 * 1024  # 50 MB/s
799799
transfer_delay = push_tensor_bytes / NETWORK_SPEED_BYTES_PER_SEC
800800
await asyncio.sleep(transfer_delay)
801801
task = asyncio.create_task(self._push_outputs(request, output_tensors, step_metadata))

src/bloombee/server/server.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -324,13 +324,13 @@ def __init__(
324324
self.weight_home = array_1d(self.num_blocks, ValueHolder)
325325
self.path = os.path.join(tempfile.gettempdir(), 'data', 'llama_weights')
326326

327-
hidden_size = 6656
327+
hidden_size = 5120
328328
vocab_size = 32000
329329

330330
# Create configuration
331331
config = PruningConfig(
332332
method=PruningMethod.ADAPTIVE_NEURAL,
333-
neural_threshold=0.6,
333+
neural_threshold=0.5,
334334
simple_threshold=0.1
335335
)
336336

src/bloombee/server/speculative_pruner/adaptive_neural_pruner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def __init__(
5555
self.lm_head = MidLMHead(hidden_size=hidden_size, vocab_size=vocab_size).to("cuda")
5656
lm_head_weights_path = hf_hub_download(
5757
repo_id="xxiong59/lm-head-for-speculative-pruning",
58-
filename="lm_head_llama30B-15.pt",
58+
filename="lm_head_llama13B-20.pt",
5959
cache_dir="./cache"
6060
)
6161
lm_head_checkpoint = torch.load(lm_head_weights_path, map_location="cuda")

0 commit comments

Comments
 (0)