Commit ab82379

Author: Xu Xiong
Commit message: llama30B
Parent: b04b519

4 files changed: 10 additions & 8 deletions

src/bloombee/server/backend.py (2 additions & 2 deletions)
@@ -304,7 +304,7 @@ def _flag_to_bool(value) -> bool:
         self._is_spec_decoding = _flag_to_bool(inference_info.is_spec_dec)

         training_mode = True
-        if training_mode and self._is_spec_decoding and inference_info.uid == 'llama-13b-hf.20':
+        if training_mode and self._is_spec_decoding and inference_info.uid == 'llama-30b-hf.15':
             self.pruner_manager.middle_states = hidden_states

         # We chunk the inputs so that peak memory for long sequences fits into `autograd_memory`
@@ -464,7 +464,7 @@ def _flag_to_bool(value) -> bool:
         self.pruner_manager.train_model(middle_norm_hidden_states, final_logits, full_mask, inference_info.draft_tokens)

         training_lm_head_mode = True
-        if training_mode and training_lm_head_mode and self._is_spec_decoding and inference_info.uid == 'llama-13b-hf.39':
+        if training_mode and training_lm_head_mode and self._is_spec_decoding and inference_info.uid == 'llama-30b-hf.59':
             logger.info(f"prepare training_lm_head_mode")
             norm_hidden_states = self.module.rms_norm(output_hidden_states)
             middle_norm_hidden_states = self.module.rms_norm(self.pruner_manager.middle_states)
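The two uid checks above tap specific decoder blocks: the mid-network hook moves from block 20 of the 40-layer LLaMA-13B to block 15 of the 60-layer LLaMA-30B, and the final-block hook from block 39 to block 59 (the last block of each model). A minimal sketch of the '<model name>.<block index>' uid convention these string comparisons rely on; the helper below is illustrative, not code from this repo:

# Hypothetical helper (not in the repo): split a block uid such as
# 'llama-30b-hf.15' into its model name and block index.
def parse_block_uid(uid: str) -> tuple[str, int]:
    model_name, _, index = uid.rpartition(".")
    return model_name, int(index)

# The checks above fire on block 15 (middle hidden states for the pruner)
# and block 59 (the final decoder block of the 60-layer LLaMA-30B).
assert parse_block_uid("llama-30b-hf.15") == ("llama-30b-hf", 15)
assert parse_block_uid("llama-30b-hf.59") == ("llama-30b-hf", 59)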

src/bloombee/server/server.py (1 addition & 1 deletion)
@@ -324,7 +324,7 @@ def __init__(
         self.weight_home = array_1d(self.num_blocks, ValueHolder)
         self.path = os.path.join(tempfile.gettempdir(), 'data', 'llama_weights')

-        hidden_size = 5120
+        hidden_size = 6656
         vocab_size = 32000

         # Create configuration
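The change above swaps in the model dimension of LLaMA-30B: 6656 hidden units versus 5120 for LLaMA-13B, while vocab_size stays 32000 because all LLaMA-1 checkpoints share one tokenizer. A small reference table for cross-checking hard-coded dimensions like this one (standard public LLaMA-1 configs, not repo code):

# Hidden sizes of the public LLaMA-1 checkpoints, for sanity-checking
# hard-coded model dimensions.
LLAMA_HIDDEN_SIZE = {
    "llama-7b-hf": 4096,
    "llama-13b-hf": 5120,  # old value
    "llama-30b-hf": 6656,  # new value
    "llama-65b-hf": 8192,
}
VOCAB_SIZE = 32000  # shared across all LLaMA-1 sizes, hence unchanged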

src/bloombee/server/speculative_pruner/lm_head_trainer.py (2 additions & 2 deletions)
@@ -26,13 +26,13 @@ def __init__(

         # ── Frozen original LM head, used to produce target logits ─────────────
         self.original_lm_head = MidLMHead(hidden_size=hidden_size, vocab_size=vocab_size).to(device)
-        self.original_lm_head.load_weight("/tmp/data/llama_weights/llama-13b-np")
+        self.original_lm_head.load_weight("/tmp/data/llama_weights/llama-30b-np")
         self.original_lm_head.requires_grad_(False)
         self.original_lm_head.to(dtype=torch.bfloat16)

         # ── LM head to be trained ────────────────────────────────────
         self.lm_head = MidLMHead(hidden_size=hidden_size, vocab_size=vocab_size).to(device)
-        self.lm_head.load_weight("/tmp/data/llama_weights/llama-13b-np")
+        self.lm_head.load_weight("/tmp/data/llama_weights/llama-30b-np")
         self.lm_head.to(dtype=torch.bfloat16)

         self.optimizer_head = torch.optim.AdamW(self.lm_head.parameters(), lr=3e-5)
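The constructor keeps two copies of the LM head: a frozen one loaded from the original checkpoint to supply target logits, and a second copy that gets fine-tuned, both now pointed at the llama-30b-np weights. A minimal sketch of that frozen-teacher / trainable-student pattern, with torch.nn.Linear standing in for the repo's MidLMHead and the dtype, optimizer, and learning rate mirroring the diff:

import torch
import torch.nn as nn

# Sketch only: nn.Linear stands in for the repo's MidLMHead.
hidden_size, vocab_size = 6656, 32000
device = "cuda" if torch.cuda.is_available() else "cpu"

teacher = nn.Linear(hidden_size, vocab_size, bias=False).to(device, torch.bfloat16)
teacher.requires_grad_(False)  # frozen: only supplies target logits

student = nn.Linear(hidden_size, vocab_size, bias=False).to(device, torch.bfloat16)
optimizer = torch.optim.AdamW(student.parameters(), lr=3e-5)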

upload_file_hf.py (5 additions & 3 deletions)
@@ -1,8 +1,10 @@
-from huggingface_hub import upload_file
+from huggingface_hub import login, upload_file
+
+login(token="")

 upload_file(
-    path_or_fileobj="./checkpoints/lmhead/lm_head_llama13B-20.pt",
-    path_in_repo="lm_head_llama13B-20.pt",
+    path_or_fileobj="./checkpoints/lmhead/lm_head_llama30B-15.pt",
+    path_in_repo="lm_head_llama30B-15.pt",
     repo_id="xxiong59/lm-head-for-speculative-pruning",
     repo_type="model"
)
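The script now authenticates before uploading, with the token left empty in the commit. A variant that keeps the token out of source control by reading it from an environment variable (the HF_TOKEN name is a common convention, not something this commit sets up):

import os
from huggingface_hub import login, upload_file

# Same upload, but the token comes from the environment instead of
# being hard-coded; set HF_TOKEN before running.
login(token=os.environ["HF_TOKEN"])

upload_file(
    path_or_fileobj="./checkpoints/lmhead/lm_head_llama30B-15.pt",
    path_in_repo="lm_head_llama30B-15.pt",
    repo_id="xxiong59/lm-head-for-speculative-pruning",
    repo_type="model",
)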
