Skip to content

Commit ccc4a10

Browse files
Author: Xu Xiong (committed)
Commit message: llama30B
1 parent: 5765675 · commit: ccc4a10

9 files changed

Lines changed: 49 additions & 42 deletions

File tree

benchmarks/benchmark_speculative_decoding.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def benchmark_inference(process_idx, args, result_pipe):
7373

7474
drafter = MultiSSMDrafter(
7575
ssm_model_name="JackFram/llama-68m",
76-
num_workers=1,
76+
num_workers=4,
7777
device="cuda"
7878
)
7979
model = AutoDistributedSpeculativeModel.from_pretrained(
@@ -82,12 +82,12 @@ def benchmark_inference(process_idx, args, result_pipe):
8282
tokenizer = AutoTokenizer.from_pretrained(args.model, use_fast=False)
8383

8484
batch_size = getattr(args, 'batch_size', 8)
85-
# dataset = load_dataset("tatsu-lab/alpaca")["train"]
86-
# indices = random.sample(range(len(dataset)), batch_size)
87-
# sampled = dataset.select(indices)
88-
# test_prompts = []
89-
# for item in sampled:
90-
# test_prompts.append(item["instruction"])
85+
dataset = load_dataset("tatsu-lab/alpaca")["train"]
86+
indices = random.sample(range(len(dataset)), batch_size)
87+
sampled = dataset.select(indices)
88+
test_prompts = []
89+
for item in sampled:
90+
test_prompts.append(item["instruction"])
9191

9292
# base_prompt = (
9393
# "Quantum mechanics explains the behavior of particles at very small scales. "
@@ -104,11 +104,11 @@ def benchmark_inference(process_idx, args, result_pipe):
104104
# f"{base_prompt} Example {i + 1} discusses large-scale AI systems and scientific discovery."
105105
# for i in range(batch_size)
106106
# ]
107-
prompt_indices = [args.prompt_start_index + i for i in range(batch_size)]
108-
if "{i}" not in args.prompt_template:
109-
raise ValueError("--prompt_template must include '{i}' placeholder")
110-
prompts = [args.prompt_template.format(i=i) for i in prompt_indices]
111-
test_prompts = prompts
107+
# prompt_indices = [args.prompt_start_index + i for i in range(batch_size)]
108+
# if "{i}" not in args.prompt_template:
109+
# raise ValueError("--prompt_template must include '{i}' placeholder")
110+
# prompts = [args.prompt_template.format(i=i) for i in prompt_indices]
111+
# test_prompts = prompts
112112

113113
tokenizer.pad_token = tokenizer.eos_token
114114
input_ids = tokenizer(test_prompts, return_tensors="pt", padding=True).to(device)["input_ids"]

src/bloombee/models/llama/spec_decoding_drafter.py

Lines changed: 27 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,22 +15,26 @@ def __init__(self, ssm_model_name: str, num_workers: int = 2, device: str = 'cud
1515
from transformers import AutoModelForCausalLM
1616

1717
self.num_workers = num_workers
18-
self.device = torch.device(device)
1918

2019
self.ssms = []
2120
self.streams = []
22-
for _ in range(num_workers):
21+
self.devices = []
22+
23+
for i in range(num_workers):
24+
device_i = torch.device(f'cuda:{i}')
25+
self.devices.append(device_i)
26+
2327
ssm = AutoModelForCausalLM.from_pretrained(
2428
ssm_model_name,
2529
torch_dtype=torch.float16)
26-
ssm = ssm.to(self.device)
30+
ssm = ssm.to(device_i)
2731
ssm.eval()
2832
self.ssms.append(ssm)
29-
self.streams.append(torch.cuda.Stream(device=self.device))
33+
self.streams.append(torch.cuda.Stream(device=device_i))
3034

3135
with torch.no_grad():
32-
dummy = torch.ones(1, 8, dtype=torch.long, device=self.device)
33-
for ssm in self.ssms:
36+
for i, ssm in enumerate(self.ssms):
37+
dummy = torch.ones(1, 8, dtype=torch.long, device=self.devices[i])
3438
ssm(dummy, attention_mask=torch.ones_like(dummy))
3539

3640
def build_trees_parallel(
@@ -49,10 +53,11 @@ def build_trees_parallel(
4953
def worker_fn(worker_idx: int, batch_indices: List[int]):
5054
ssm = self.ssms[worker_idx]
5155
stream = self.streams[worker_idx]
56+
device = self.devices[worker_idx]
5257

5358
with torch.cuda.stream(stream):
5459
results = self._build_trees_batched(
55-
batch_indices, input_ids, seq_lengths, ssm, beam_width, max_depth
60+
batch_indices, input_ids, seq_lengths, ssm, beam_width, max_depth, device
5661
)
5762
for batch_idx, tree in results:
5863
all_results[batch_idx] = tree
@@ -72,8 +77,9 @@ def worker_fn(worker_idx: int, batch_indices: List[int]):
7277
t.join()
7378

7479
# 同步所有 streams
75-
for stream in self.streams:
76-
stream.synchronize()
80+
for i, stream in enumerate(self.streams):
81+
with torch.cuda.device(self.devices[i]):
82+
stream.synchronize()
7783

7884
return all_results
7985

@@ -85,9 +91,10 @@ def _build_trees_batched(
8591
ssm,
8692
beam_width: int,
8793
max_depth: int,
94+
device: torch.device,
8895
) -> List:
8996

90-
pad_token_id = getattr(ssm.config, 'pad_token_id', 0)
97+
pad_token_id = getattr(ssm.config, 'pad_token_id', None) or 0
9198

9299
trees = {}
93100
valid_inputs = {}
@@ -96,7 +103,7 @@ def _build_trees_batched(
96103

97104
for batch_idx in batch_indices:
98105
actual_len = seq_lengths[batch_idx].item()
99-
valid_input_ids = input_ids[batch_idx, :actual_len]
106+
valid_input_ids = input_ids[batch_idx, :actual_len].to(device)
100107
valid_inputs[batch_idx] = valid_input_ids
101108
prefix_lengths[batch_idx] = max(actual_len - 1, 0)
102109

@@ -118,23 +125,23 @@ def _build_trees_batched(
118125
if pf_len > 0:
119126
prefix = valid_inputs[batch_idx][:-1]
120127
else:
121-
prefix = torch.tensor([], dtype=torch.long, device=self.device)
128+
prefix = torch.tensor([], dtype=torch.long, device=device)
122129

123130
pad_len = max_prefix_len - pf_len
124131

125132
if pf_len > 0:
126133
padded_prefixes.append(torch.cat([
127-
torch.full((pad_len,), pad_token_id, dtype=torch.long, device=self.device),
134+
torch.full((pad_len,), pad_token_id, dtype=torch.long, device=device),
128135
prefix
129136
]))
130137
else:
131138
padded_prefixes.append(
132-
torch.full((max_prefix_len,), pad_token_id, dtype=torch.long, device=self.device)
139+
torch.full((max_prefix_len,), pad_token_id, dtype=torch.long, device=device)
133140
)
134141

135142
prefix_masks.append(torch.cat([
136-
torch.zeros(pad_len, dtype=torch.long, device=self.device),
137-
torch.ones(pf_len, dtype=torch.long, device=self.device)
143+
torch.zeros(pad_len, dtype=torch.long, device=device),
144+
torch.ones(pf_len, dtype=torch.long, device=device)
138145
]))
139146

140147
batch_prefixes = torch.stack(padded_prefixes)
@@ -166,7 +173,7 @@ def _build_trees_batched(
166173

167174
for node in tree.get_nodes_at_depth(depth):
168175
path = node.get_path_from_root()
169-
path_tokens = torch.tensor([root_token] + path, dtype=torch.long, device=self.device)
176+
path_tokens = torch.tensor([root_token] + path, dtype=torch.long, device=device)
170177
all_paths.append(path_tokens)
171178
node_mapping.append((batch_idx, node))
172179
cache_indices.append(idx_map[batch_idx])
@@ -180,8 +187,8 @@ def _build_trees_batched(
180187
total_mask_len = max_pf_len + max_path_len
181188

182189
# 预分配
183-
batch_paths = torch.full((num_nodes, max_path_len), pad_token_id, dtype=torch.long, device=self.device)
184-
batch_path_masks = torch.zeros((num_nodes, total_mask_len), dtype=torch.long, device=self.device)
190+
batch_paths = torch.full((num_nodes, max_path_len), pad_token_id, dtype=torch.long, device=device)
191+
batch_path_masks = torch.zeros((num_nodes, total_mask_len), dtype=torch.long, device=device)
185192

186193
# 填充
187194
for i, path in enumerate(all_paths):
@@ -212,7 +219,7 @@ def _build_trees_batched(
212219
all_logits = outputs.logits[:, -1, :]
213220

214221
t_forward += time.perf_counter() - t0
215-
t0 = time.perf_counter()
222+
t0 = time.perf_counter()
216223
# 批量 topk
217224
_, all_top_k_indices = torch.topk(all_logits, k=beam_width, dim=-1)
218225
all_probs = torch.softmax(all_logits, dim=-1)

src/bloombee/models/llama/speculative_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ def generate(
3535
logits_processor: Optional[LogitsProcessorList] = None,
3636
stopping_criteria: Optional[StoppingCriteriaList] = None,
3737
streamer: Optional["BaseStreamer"] = None,
38-
beam_width: int = 1,
39-
max_tree_depth: int = 4,
38+
beam_width: int = 2,
39+
max_tree_depth: int = 3,
4040
use_kv_cache: bool = True,
4141
kv_cache_window: int = 2048,
4242
max_new_tokens: int = 128,

src/bloombee/server/backend.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -369,8 +369,8 @@ def _flag_to_bool(value) -> bool:
369369
position_ids = self._position_ids_cache[cache_key] + (cache_len + offset)
370370
if self._is_spec_decoding:
371371
rotary_position_ids = self._create_tree_position_ids_with_invalid_cache(
372-
width=1,
373-
depth=4,
372+
width=2,
373+
depth=3,
374374
prefill_length=inference_info.prefill_length - 1,
375375
kv_cache_position_ids=kv_cache_position_ids,
376376
batch_offset=inference_info.batch_offset,

src/bloombee/server/handler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -795,7 +795,7 @@ async def _cross_stage_push_wrapper(mb_hidden, mb_keep, push_metadata):
795795
push_tensor_bytes = sum(len(t.buffer) for t in next_tensors)
796796

797797
# 模拟网络传输延时
798-
NETWORK_SPEED_BYTES_PER_SEC = 5 * 1024 * 1024 # 10 MB/s
798+
NETWORK_SPEED_BYTES_PER_SEC = 10 * 1024 * 1024 # 10 MB/s
799799
transfer_delay = push_tensor_bytes / NETWORK_SPEED_BYTES_PER_SEC
800800
await asyncio.sleep(transfer_delay)
801801
task = asyncio.create_task(self._push_outputs(request, output_tensors, step_metadata))

src/bloombee/server/server.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -324,13 +324,13 @@ def __init__(
324324
self.weight_home = array_1d(self.num_blocks, ValueHolder)
325325
self.path = os.path.join(tempfile.gettempdir(), 'data', 'llama_weights')
326326

327-
hidden_size = 4096
327+
hidden_size = 6656
328328
vocab_size = 32000
329329

330330
# Create configuration
331331
config = PruningConfig(
332332
method=PruningMethod.ADAPTIVE_NEURAL,
333-
neural_threshold=0.9,
333+
neural_threshold=0.6,
334334
simple_threshold=0.1
335335
)
336336

src/bloombee/server/speculative_pruner/adaptive_neural_pruner.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def __init__(
5555
self.lm_head = MidLMHead(hidden_size=hidden_size, vocab_size=vocab_size).to("cuda")
5656
lm_head_weights_path = hf_hub_download(
5757
repo_id="xxiong59/lm-head-for-speculative-pruning",
58-
filename="lm_head_weights_15.pt",
58+
filename="lm_head_llama30B-15.pt",
5959
cache_dir="./cache"
6060
)
6161
lm_head_checkpoint = torch.load(lm_head_weights_path, map_location="cuda")

src/bloombee/server/speculative_pruner/pruner_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def __init__(
4242
self.iteration = 0
4343
self.middle_states = None
4444

45-
train_lm_head_mode = True
45+
train_lm_head_mode = False
4646
self.lm_head_trainer = LM_head_trainer(hidden_size, vocab_size, device, config) if train_lm_head_mode else None
4747

4848
def switch_method(self, method: Union[str, PruningMethod], keep_stats: bool = False):

src/bloombee/utils/lossless_wrapper_config.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
"""
66

77
# 0 = disable lossless wrapper, 1 = enable
8-
ENABLE_LOSSLESS_WRAPPER = 0
8+
ENABLE_LOSSLESS_WRAPPER = 1
99

1010
# "zstd" (recommended), "zlib", "none"
1111
LOSSLESS_ALGO = "zstd"

0 commit comments

Comments (0)