
Commit 1eac866

LLMPerfV2 (#19)
LLMPerfV2: the latest version of LLMPerf brings a suite of significant updates designed to provide more in-depth and customizable benchmarking capabilities for LLM inference. These updates include:

- Expanded metrics with quantile distribution (P25-99): comprehensive data representation for deeper insights.
- Customizable benchmarking parameters: tailor parameters to fit specific use-case scenarios.
- Introduction of a load test and a correctness test: assessing performance and accuracy under stress.
- Broad compatibility: supports a range of products including [Anyscale Endpoints](https://www.anyscale.com/endpoints), [OpenAI](https://openai.com/blog/openai-api), [Anthropic](https://docs.anthropic.com/claude/reference/getting-started-with-the-api), [together.ai](http://together.ai/), [Fireworks.ai](https://app.fireworks.ai/), [Perplexity](https://www.perplexity.ai/), [Huggingface](https://huggingface.co/inference-endpoints), [Lepton AI](https://www.lepton.ai/docs/overview/model_apis), and various APIs supported by the [LiteLLM project](https://litellm.ai/).
- Easy addition of new LLMs via the LLMClient API.

Signed-off-by: Avnish Narayan <[email protected]>
1 parent ae8a418 commit 1eac866
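At its core, the correctness test introduced here asks the model to convert a number written out in words back into digits and checks the reply. A minimal sketch of that check, condensed from the `llm_correctness.py` diff below (no Ray client pool; the actual API call is left as a placeholder):

```python
import random
import re

import num2words

MAX_RANDOM_NUMBER = 10000


def check_answer(generated_text: str, rnd_number: int) -> bool:
    # Strip thousands separators ("1,234" -> "1234"), then look for the
    # expected number anywhere in the reply, as llm_correctness.py does.
    no_commas = re.sub(r"(\d+),(?=\d)", r"\1", generated_text)
    return str(rnd_number) in re.findall(r"\d+", no_commas)


rnd_number = random.randint(0, MAX_RANDOM_NUMBER)
prompt = (
    "Convert the following sequence of words into a number: "
    f"{num2words.num2words(rnd_number)}.\nPrint the number first."
)
# reply = <send `prompt` to your LLM endpoint>  # placeholder API call
# assert check_answer(reply, rnd_number)
```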

28 files changed: 2,647 additions and 1,890 deletions.

.gitignore: +243 −157 (large diff not rendered by default)
README.md: +383 −54 (large diff not rendered by default)
analyze-raw.ipynb: −588 (file deleted)
analyze-token-benchmark-results.ipynb: +327 (large diff not rendered by default)
configs.py: −29 (file deleted)
env_sample.txt: −19 (file deleted)
llm_correctness.py: +309
@@ -0,0 +1,309 @@
import argparse
import json
import os
from pathlib import Path
import random
import re
import time
from typing import Any, Dict, List, Optional, Tuple

import num2words
import ray
from tqdm import tqdm

from llmperf import common_metrics
from llmperf.common import SUPPORTED_APIS, construct_clients
from llmperf.models import RequestConfig
from llmperf.requests_launcher import RequestsLauncher
from llmperf.utils import (
    LLMPerfResults,
)

MAX_RANDOM_NUMBER = 10000


def llm_correctness(
    model: str,
    additional_sampling_params: Optional[Dict[str, Any]] = None,
    num_concurrent_requests: int = 1,
    max_num_completed_requests: int = 500,
    test_timeout_s=90,
    llm_api="chat",
) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
    """Run a correctness test for the given model.

    Args:
        model: The name of the model to query.
        additional_sampling_params: Additional sampling parameters to send with the request.
            For more information see the LLM APIs documentation for the completions API.
        num_concurrent_requests: The number of concurrent requests to make. Increase
            this to increase the amount of load and vice versa.
        max_num_completed_requests: The number of requests to complete before finishing the test.
        test_timeout_s: The amount of time to run the test for before reporting results.
        llm_api: The type of request to make. Either "chat" or "litellm".

    Returns:
        A tuple containing summary metrics and raw results from the test.

    """

    if not additional_sampling_params:
        additional_sampling_params = {}

    clients = construct_clients(llm_api=llm_api, num_clients=num_concurrent_requests)
    req_launcher = RequestsLauncher(clients)
    start_time = time.monotonic()

    num_errored_requests = 0
    num_mismatched_requests = 0
    num_completed_requests = 0

    sampling_params = {"temperature": 0.0}
    sampling_params.update(additional_sampling_params)
    completed_requests = []
    iter = 0
    pbar = tqdm(total=max_num_completed_requests)
    # Keep launching prompts until the timeout is hit or enough requests complete.
    while (
        time.monotonic() - start_time < test_timeout_s
        and num_completed_requests < max_num_completed_requests
    ):
        iter += 1
        rnd_number = random.randint(0, MAX_RANDOM_NUMBER)
        rnd_num_words = num2words.num2words(rnd_number)

        prompt = f"Convert the following sequence of words into a number: {rnd_num_words}.\nPrint the number first."

        request_config = RequestConfig(
            model=model,
            prompt=(prompt, 0),
            sampling_params=sampling_params,
            metadata={"rnd_number": rnd_number},
            llm_api=llm_api,
        )
        req_launcher.launch_requests(request_config)

        if not (iter % num_concurrent_requests):
            completed_requests.extend(req_launcher.get_next_ready())
            pbar.update(len(completed_requests) - num_completed_requests)
            num_completed_requests = len(completed_requests)

    pbar.close()
    end_time = time.monotonic()
    if end_time - start_time >= test_timeout_s:
        print("Test timed out before all requests could be completed.")

    raw_results = []

    print("Mismatched and errored requests.")
    for out in completed_requests:
        metrics, generated_text, completed_request_config = out

        raw_results.append(
            {
                "metrics": metrics,
                "generated_text": generated_text,
                "request_config": dict(completed_request_config),
            }
        )

        # If there were no errors when making the request, check the answer.
        if not metrics[common_metrics.ERROR_CODE]:
            try:
                commas_between_numbers_re = r"(\d+),(?=\d)"
                gen_text_commas_removed = re.sub(
                    commas_between_numbers_re, r"\1", generated_text
                )
                nums = re.findall(r"\d+", gen_text_commas_removed)
                generated_text = gen_text_commas_removed.replace("\n", " ")

                assert str(completed_request_config.metadata["rnd_number"]) in nums
            except:
                num_mismatched_requests += 1
                print(
                    f" mismatched request: {generated_text}, expected: {completed_request_config.metadata['rnd_number']}"
                )
        else:
            num_errored_requests += 1
            print(
                f" The request errored: {metrics[common_metrics.ERROR_CODE]}, "
                f"{metrics[common_metrics.ERROR_MSG]} "
            )
    print()

    error_rate = num_errored_requests / num_completed_requests
    mismatch_rate = num_mismatched_requests / num_completed_requests
    num_non_errored_requests = num_completed_requests - num_errored_requests
    summary_metrics = {}
    summary_metrics[common_metrics.NUM_ERRORS] = num_errored_requests
    summary_metrics["num_mismatched_requests"] = num_mismatched_requests
    summary_metrics["error_rate"] = error_rate
    summary_metrics["mismatch_rate"] = mismatch_rate
    summary_metrics[common_metrics.NUM_COMPLETED_REQUESTS] = num_completed_requests
    summary_metrics["num_non_errored_requests"] = num_non_errored_requests

    # Metadata
    summary_metrics["model"] = model
    summary_metrics["num_concurrent_requests"] = num_concurrent_requests
    summary_metrics["additional_sampling_params"] = additional_sampling_params
    summary_metrics["llm_api"] = llm_api

    return summary_metrics, raw_results


def run(
    llm_api: str,
    model: str,
    test_timeout_s: int,
    max_num_completed_requests: int,
    num_concurrent_requests: int,
    additional_sampling_params: str,
    results_dir: str,
    user_metadata: Dict[str, str],
):
    """
    Args:
        llm_api: The type of request to make. Either "chat" or "litellm".
        model: The name of the model to query.
        max_num_completed_requests: The number of requests to complete before finishing the test.
        test_timeout_s: The amount of time to run the test for before reporting results.
        num_concurrent_requests: The number of concurrent requests to make. Increase
            this to increase the amount of load and vice versa.
        additional_sampling_params: Additional sampling parameters to send with the request.
            For more information see the LLM APIs documentation for the completions API.
        results_dir: The directory to save the results to.
        user_metadata: Extra metadata to attach to the summary results.

    """

    summary_metrics, raw_results = llm_correctness(
        model=model,
        llm_api=llm_api,
        test_timeout_s=test_timeout_s,
        max_num_completed_requests=max_num_completed_requests,
        num_concurrent_requests=num_concurrent_requests,
        additional_sampling_params=json.loads(additional_sampling_params),
    )

    time.sleep(2)

    print(
        f"Results for llm correctness test for {model} queried with the {llm_api} api."
    )
    print(
        f"Errors: {summary_metrics[common_metrics.NUM_ERRORS]}, "
        f"Error rate: {summary_metrics['error_rate']}"
    )
    print(
        f"Mismatched: {summary_metrics['num_mismatched_requests']}, "
        f"Mismatch rate: {summary_metrics['mismatch_rate']}"
    )
    print(f"Completed: {summary_metrics[common_metrics.NUM_COMPLETED_REQUESTS]}")
    print(f"Completed without errors: {summary_metrics['num_non_errored_requests']}")

    if results_dir:
        # Build filesystem-safe file names from the model name.
        file_name = f"{model}_correctness"
        file_name = re.sub(r"[^\w\d-]+", "-", file_name)
        file_name = re.sub(r"-{2,}", "-", file_name)
        summary_file_name = f"{file_name}_summary"
        individual_responses_filename = f"{file_name}_individual_responses"
        summary_metrics.update(user_metadata)
        results = LLMPerfResults(name=summary_file_name, metadata=summary_metrics)
        results_dir = Path(results_dir)
        if not results_dir.exists():
            results_dir.mkdir(parents=True)
        elif not results_dir.is_dir():
            raise ValueError(f"{results_dir} is not a directory")
        with open(results_dir / f"{summary_file_name}.json", "w") as f:
            json.dump(results.to_dict(), f, indent=4)
        with open(results_dir / f"{individual_responses_filename}.json", "w") as f:
            json.dump(raw_results, f, indent=4)


args = argparse.ArgumentParser(description="Run a correctness test for a given model.")

args.add_argument(
    "--model", type=str, required=True, help="The model to use for this load test."
)
args.add_argument(
    "--num-concurrent-requests",
    type=int,
    default=10,
    help=("The number of concurrent requests to send. (default: %(default)s)"),
)
args.add_argument(
    "--timeout",
    type=int,
    default=90,
    help="The amount of time to run the load test for. (default: %(default)s)",
)
args.add_argument(
    "--max-num-completed-requests",
    type=int,
    default=50,
    help=(
        "The number of requests to complete before finishing the test. Note "
        "that it's possible for the test to time out first. (default: %(default)s)"
    ),
)
args.add_argument(
    "--additional-sampling-params",
    type=str,
    default="{}",
    help=(
        "Additional sampling params to send with each request to the LLM API. "
        "(default: %(default)s, i.e. no additional sampling params are sent)"
    ),
)
args.add_argument(
    "--results-dir",
    type=str,
    default="",
    help=(
        "The directory to save the results to. "
        "(default: %(default)s) If not specified, no results are saved."
    ),
)
args.add_argument(
    "--llm-api",
    type=str,
    default="openai",
    help=(
        f"The type of request to make. The supported llm apis are {SUPPORTED_APIS}. "
        "(default: %(default)s)"
    ),
)
args.add_argument(
    "--metadata",
    type=str,
    default="",
    help=(
        "A comma separated list of metadata to include in the results, e.g. "
        "name=foo,bar=1. These will be added to the metadata field of the results."
    ),
)

if __name__ == "__main__":
    args = args.parse_args()

    env_vars = dict(os.environ)
    ray.init(runtime_env={"env_vars": env_vars})
    # Parse user metadata.
    user_metadata = {}
    if args.metadata:
        for item in args.metadata.split(","):
            key, value = item.split("=")
            user_metadata[key] = value

    run(
        llm_api=args.llm_api,
        model=args.model,
        test_timeout_s=args.timeout,
        max_num_completed_requests=args.max_num_completed_requests,
        num_concurrent_requests=args.num_concurrent_requests,
        additional_sampling_params=args.additional_sampling_params,
        results_dir=args.results_dir,
        user_metadata=user_metadata,
    )
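The script above is driven from the CLI, but `llm_correctness()` can also be called directly. A minimal sketch of such a programmatic run, assuming the llmperf package from this commit is installed and the relevant API key environment variables are already set (the model id below is a placeholder):

```python
# Hypothetical programmatic use of the correctness test; mirrors the
# __main__ block above. The model id is a placeholder, and API credentials
# are assumed to be present in the environment.
import os

import ray

from llm_correctness import llm_correctness

# Propagate API credentials to the Ray workers, as the CLI entry point does.
ray.init(runtime_env={"env_vars": dict(os.environ)})

summary_metrics, raw_results = llm_correctness(
    model="meta-llama/Llama-2-70b-chat-hf",  # placeholder model id
    llm_api="openai",
    num_concurrent_requests=5,
    max_num_completed_requests=50,
    test_timeout_s=120,
)
print(f"mismatch rate: {summary_metrics['mismatch_rate']}")
print(f"error rate: {summary_metrics['error_rate']}")
```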
