Skip to content

Commit 402ac2e

Browse files
authored
Truely use maxd; Optimize query start loc; Make evaluation use LLM (#86)
* Truely use maxd in Token Throttling * Optimize query_start_loc * Make evaluation use offline LLM
1 parent 6ed4378 commit 402ac2e

6 files changed

Lines changed: 58 additions & 55 deletions

File tree

README.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,8 +137,7 @@ python benchmarks/benchmark_prefix_serving.py \
137137

138138
### Evaluate Output Quality
139139
```
140-
# Launch server first
141-
python evaluations/evaluate_MMLU_pro.py --model $MODEL --port $PORT
140+
python evaluations/evaluate_MMLU_pro.py --model $MODEL
142141
```
143142

144143
## Supported Models
Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,15 @@
11
# adopt from https://github.com/TIGER-AI-Lab/MMLU-Pro/blob/main/evaluate_from_api.py
2-
3-
import os
42
import re
53
import random
6-
from tqdm import tqdm
7-
from datasets import load_dataset
84
import argparse
9-
from benchmarks.backend_request_func import async_request_openai_chat_completions, RequestFuncInput
105
import asyncio
116

7+
from gllm import LLM
128

13-
API_KEY = "EMPTY"
14-
random.seed(12345)
9+
from tqdm import tqdm
10+
from datasets import load_dataset
1511

12+
random.seed(12345)
1613

1714
def load_mmlu_pro():
1815
dataset = load_dataset("TIGER-Lab/MMLU-Pro")
@@ -83,7 +80,7 @@ def extract_final(text):
8380
return None
8481

8582

86-
def single_request(api_url, single_question, cot_examples_dict, pbar):
83+
def single_request(single_question, cot_examples_dict):
8784
category = single_question["category"]
8885
cot_examples = cot_examples_dict[category]
8986
question = single_question["question"]
@@ -95,43 +92,44 @@ def single_request(api_url, single_question, cot_examples_dict, pbar):
9592
prompt += format_example(each["question"],
9693
each["options"], each["cot_content"])
9794
input_text = format_example(question, options)
98-
9995
prompt = prompt + input_text
100-
101-
request_func_input = RequestFuncInput(prompt=prompt,
102-
api_url=api_url,
103-
prompt_len=len(prompt),
104-
output_len=args.output_len,
105-
model=args.model,
106-
)
107-
return async_request_openai_chat_completions(request_func_input=request_func_input, pbar=pbar)
96+
97+
return prompt
10898

10999

110100

111101
async def evaluate(subjects):
112-
api_url = f"http://{args.host}:{args.port}/v1/chat/completions"
113102
test_df, dev_df = load_mmlu_pro()
114103
if not subjects:
115104
subjects = list(test_df.keys())
116105
print("assigned subjects", subjects)
117106
category_record = {'total':{'#correct':0,'#wrong':0}}
118107

119-
print(f"Sending requests ...")
120-
pbar = tqdm()
121-
tasks = []
108+
llm = LLM(model_path=args.model,
109+
gpu_memory_util=args.gpu_memory_util,
110+
kvthresh=args.kvthresh,
111+
pp_size=args.pp,
112+
tp_size=args.tp,
113+
enable_prefix_caching=True,
114+
use_thinking=False)
115+
116+
print(f"generating requests ...")
117+
prompts = []
122118
test_data_total = []
123119
for subject in subjects:
124120
test_data = test_df[subject][:args.num_per_sub]
125121
test_data_total.extend(test_data)
126122
for each in test_data:
127-
tasks.append(single_request(api_url, each, dev_df, pbar))
128-
pbar.total = len(tasks)
129-
completions = await asyncio.gather(*tasks)
130-
pbar.close()
123+
prompts.append(single_request(each, dev_df))
124+
125+
seqs = llm.generate(prompts, output_lens=[args.output_len for i in range(len(prompts))])
126+
127+
outputs = [seq.output for seq in seqs]
128+
131129
print(f"Processing completions ...")
132-
for idx, each in tqdm(enumerate(test_data_total),total=len(tasks)):
130+
for idx, each in tqdm(enumerate(test_data_total),total=len(prompts)):
133131
label = each["answer"]
134-
response = completions[idx].generated_text
132+
response = outputs[idx]
135133
response = response.replace('**', '')
136134
pred = extract_answer(response)
137135
category = each["category"]
@@ -162,8 +160,10 @@ async def evaluate(subjects):
162160
parser.add_argument("--assigned_subjects", "-a", type=str, default="all",
163161
help="business, law, psychology, biology, chemistry, history, other, health, "
164162
"economics, math, physics, computer science, philosophy, engineering")
165-
parser.add_argument("--host", type=str, default='0.0.0.0')
166-
parser.add_argument("--port", type=int, default=8000)
163+
parser.add_argument("--tp", type=int, default=1)
164+
parser.add_argument("--pp", type=int, default=1)
165+
parser.add_argument('--gpu-memory-util', type=float, default=0.9)
166+
parser.add_argument('--kvthresh', type=float, default=0.2)
167167
parser.add_argument("--output-len", type=int, default=1024)
168168
parser.add_argument("--num-per-sub", type=int, default=100)
169169
assigned_subjects = []

gllm/entrypoints/api_server.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -102,11 +102,11 @@ async def run_server(args):
102102
parser.add_argument('--disable-ep', help='Disable expert parallelism (EP is enable by default)', action='store_true')
103103
parser.add_argument('--assigned-layers', type=str, help='If the model have 64 layers, we can set it to 16,16,16,16 or 16,16,17,15', default=None)
104104
# Token Throttling
105-
parser.add_argument('--maxd', type=int, help='Maximum decode token count, used in LLM (offline infernce)', default=512)
106-
parser.add_argument('--maxp', type=int, help='Maximum token count in prefill', default=2048)
107-
parser.add_argument('--minp', type=int, help='Minimum token count in prefill, used in PipeAsyncLLM', default=32)
108-
parser.add_argument('--iterp', type=int, help='Number of iterations to process waiting prefill tokens, used in PipeAsyncLLM', default=8)
109-
parser.add_argument('--kvthresh', type=float, help='KV cache threshold for prefill operations', default=0.05)
105+
parser.add_argument('--maxd', type=int, help='Maximum decode token count per batch (Token Throttling)', default=2048)
106+
parser.add_argument('--maxp', type=int, help='Maximum prefill token count per batch (Token Throttling) or token budget in Sarathi-Serve', default=2048)
107+
parser.add_argument('--minp', type=int, help='Minimum prefill token count per batch (Token Throttling)', default=32)
108+
parser.add_argument('--iterp', type=int, help='Number of iterations to process waiting prefill tokens (Token Throttling)', default=8)
109+
parser.add_argument('--kvthresh', type=float, help='KV cache threshold for prefill operations (Token Throttling)', default=0.05)
110110
parser.add_argument('--use-naive-schedule', help='Use scheduling policy in Sarathi-Serve', action='store_true')
111111
# Multi-Node deployment
112112
parser.add_argument('--launch-mode', type=str, choices=['normal', 'master', 'slave'], default='normal')

gllm/input_data.py

Lines changed: 14 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ def __init__(self, seqs: List[Sequence], memory_manager: MemoryManager):
2929
self.slot_mapping_tensor = self.get_slot_mapping()
3030
self.tokens = self.get_tokens()
3131
self.positions = self.get_position()
32-
self.max_seq_len, self.seq_start_loc = self.get_seq_len_loc()
32+
self.max_seq_len, self.seq_start_loc = self.get_seq_lens()
3333
self.block_table = self.get_block_table()
34-
self.max_query_len, self.query_start_loc = self.get_query_len_loc()
34+
self.max_query_len, self.query_start_loc = self.get_query_start_loc()
3535

3636
assert self.tokens.shape == self.positions.shape
3737

@@ -50,21 +50,18 @@ def get_position(self):
5050
return async_tensor_h2d(
5151
positions_list, torch.long, 'cuda', True)
5252

53-
def get_seq_len_loc(self):
54-
seq_start_loc = [seq.seq_len for seq in self.seqs]
55-
max_seqlen = max(seq_start_loc)
56-
return max_seqlen, async_tensor_h2d(seq_start_loc, torch.int32, 'cuda', True)
53+
def get_seq_lens(self):
54+
seq_lens = [seq.seq_len for seq in self.seqs]
55+
max_seqlen = max(seq_lens)
56+
return max_seqlen, async_tensor_h2d(seq_lens, torch.int32, 'cuda', True)
5757

58-
def get_query_len_loc(self):
59-
max_query_len = 0
60-
cu_query_len = 0
61-
query_start_loc = [0]
62-
for seq in self.seqs:
63-
query_len = seq.to_compute_token_num
64-
cu_query_len += query_len
65-
query_start_loc.append(cu_query_len)
66-
max_query_len = max(query_len, max_query_len)
67-
return max_query_len, async_tensor_h2d(query_start_loc, torch.int32, 'cuda', True)
58+
def get_query_start_loc(self):
59+
query_lens = [0] + [seq.to_compute_token_num for seq in self.seqs]
60+
max_query_len = max(query_lens)
61+
query_start_loc = torch.from_numpy(np.cumsum(query_lens)).to(device='cuda',
62+
dtype=torch.int32,
63+
non_blocking=True)
64+
return max_query_len, query_start_loc
6865

6966
def get_block_table(self):
7067
block_tables_list = [seq.page_table for seq in self.seqs]
@@ -73,7 +70,7 @@ def get_block_table(self):
7370
(len(block_tables_list), max_num_block), 0, dtype=np.int32)
7471
for idx, block_table in enumerate(block_tables_list):
7572
block_tables[idx, :len(block_table)] = block_table
76-
return torch.from_numpy(block_tables).to(device='cuda',non_blocking=True)
73+
return torch.from_numpy(block_tables).to(device='cuda', non_blocking=True)
7774

7875
def get_slot_mapping(self):
7976
slot_mapping = []

gllm/worker.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,7 @@ def init(self):
6565
self.pp_size,
6666
self.model_runner.memory_manager,
6767
self.use_naive_schedule,
68+
self.model_runner.maxd,
6869
self.model_runner.maxp,
6970
self.model_runner.minp,
7071
self.model_runner.iterp,

gllm/worker_scheduler.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,12 @@
1313

1414

1515
class WorkerScheduler():
16-
def __init__(self, pp_size, memory_manager:MemoryManager, use_naive_schedule, maxp, minp, iterp, page_size, kvthresh):
16+
def __init__(self, pp_size, memory_manager:MemoryManager, use_naive_schedule,
17+
maxd, maxp, minp, iterp, page_size, kvthresh):
1718
self.pp_size = pp_size
1819
self.memory_manager = memory_manager
1920
self.use_naive_schedule = use_naive_schedule
21+
self.maxd = maxd
2022
self.maxp = maxp
2123
self.minp = minp
2224
self.iterp = iterp
@@ -35,6 +37,7 @@ def __init__(self, pp_size, memory_manager:MemoryManager, use_naive_schedule, ma
3537
# preempt seqs
3638
self.num_preempt_seqs = 0
3739
self.log_num_preempt_seqs = 0
40+
self.delta_log_num_preempt_seqs = 10
3841
# num wait tokens
3942
self.num_wait_tokens = 0
4043
# abort ids
@@ -102,8 +105,9 @@ def check_preempt(self, num_decode_tokens):
102105
self.seqs_to_prefill.extendleft(preempt_seqs)
103106

104107
self.num_preempt_seqs += len(preempt_seqs)
105-
if self.num_preempt_seqs - self.log_num_preempt_seqs >= 10:
108+
if self.num_preempt_seqs - self.log_num_preempt_seqs >= self.delta_log_num_preempt_seqs:
106109
self.log_num_preempt_seqs = self.num_preempt_seqs
110+
self.delta_log_num_preempt_seqs *= 2
107111
logger.warning(f'#Preempted seqs: {self.num_preempt_seqs}, Try increase --kvthresh or the performance is poor!')
108112

109113
def check_abort_seqs_list(self, seqs:deque, ipc_package:IPCPackage):
@@ -240,6 +244,8 @@ def schedule(self):
240244
# because we want to solve the situation when #seqs=5 pp_size=4
241245
decode_token_budget = (
242246
num_total_decode_seqs + random.randint(0, self.pp_size-1)) // self.pp_size
247+
248+
decode_token_budget = min(self.maxd, decode_token_budget)
243249

244250
self.check_preempt(decode_token_budget)
245251

0 commit comments

Comments
 (0)