-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
162 lines (139 loc) · 5.38 KB
/
main.py
File metadata and controls
162 lines (139 loc) · 5.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
#!/usr/bin/env python3
import argparse
import json
import os
import subprocess
import sys
import threading
import time
from statistics import mean
from typing import Any, List, Optional, Sequence
from urllib import error, request
# --- Command-line defaults -------------------------------------------------
# Each constant below is the fallback for the matching CLI option.

# Workload selection.
DEFAULT_MODEL_NAME = "some model"  # NOTE(review): looks like a placeholder -- confirm
DEFAULT_PROMPT_FILE = "prompts"    # text file read one prompt per line
DEFAULT_BATCH_SIZE = 1             # prompts sent per request
DEFAULT_MAX_TOKENS = 256           # completion-token cap sent with each request

# Server connection. The base URL can be overridden through the
# RHEL_AI_BASE_URL environment variable.
DEFAULT_BASE_URL = os.environ.get("RHEL_AI_BASE_URL", "http://127.0.0.1:8000/v1")
DEFAULT_REQUEST_TIMEOUT = 300      # passed to urlopen as its timeout (seconds)

# GPU monitoring: pause between two utilization samples, in seconds.
DEFAULT_GPU_SAMPLING_INTERVAL_S = 0.5
def sample_gpu_utilization(
    samples: List[int],
    stop_event: threading.Event,
    interval: float
) -> None:
    """Poll ``nvidia-smi`` for GPU utilization until ``stop_event`` is set.

    Intended to run in a background thread. Sampling is best-effort: if
    ``nvidia-smi`` is missing or exits with an error, that round is skipped
    and polling continues, so a host without a usable GPU tool simply
    yields no samples.

    Args:
        samples: List mutated in place; every utilization percentage read
            is appended to it.
        stop_event: Polling stops once this event is set.
        interval: Pause between two polls, in seconds.
    """
    query = [
        "nvidia-smi",
        "--query-gpu=utilization.gpu",
        "--format=csv,noheader,nounits",
    ]
    while not stop_event.is_set():
        try:
            output = subprocess.check_output(query)
        except (OSError, subprocess.SubprocessError):
            # nvidia-smi not installed or it failed: skip this round.
            pass
        else:
            # The output can contain several lines (nvidia-smi reports one
            # row per GPU), so parse each line on its own rather than the
            # whole output as a single integer.
            for line in output.decode("utf-8", errors="replace").splitlines():
                try:
                    samples.append(int(line.strip()))
                except ValueError:
                    # Blank or non-numeric row (e.g. an unsupported GPU).
                    continue
        # Wait on the event instead of sleeping so that setting it ends the
        # loop immediately rather than after a full interval.
        stop_event.wait(interval)
def load_prompts(prompt_file: str) -> List[str]:
    """Read the prompt workload from a text file, one prompt per line.

    Leading and trailing whitespace is stripped from every line, and lines
    that are empty after stripping are dropped.

    Args:
        prompt_file: Path of the file to read.

    Returns:
        The non-empty, stripped lines in file order.

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
    """
    # Decode explicitly as UTF-8 (the same encoding the request payload is
    # sent in) instead of relying on the platform's locale default.
    with open(prompt_file, "r", encoding="utf-8") as f:
        stripped_lines = (line.strip() for line in f)
        return [line for line in stripped_lines if line]
def run_inference(
    prompts: List[str],
    model_name: str,
    max_tokens: int,
    base_url: str,
    request_timeout: int
) -> dict[str, Any]:
    """POST one completion request and return the decoded JSON response.

    Args:
        prompts: Prompts for this request. A single prompt is sent as a
            plain string, several prompts as a list.
        model_name: Model identifier sent in the ``model`` field. An empty
            value omits the field from the payload.
        max_tokens: Value of the ``max_tokens`` field.
        base_url: Server base URL; ``/completions`` is appended to it.
        request_timeout: Timeout handed to ``urlopen``, in seconds.

    Returns:
        The response body parsed as JSON.

    Raises:
        ValueError: If ``prompts`` is empty.
        urllib.error.HTTPError: If the server answers with an error status.
        urllib.error.URLError: If the server cannot be reached.
    """
    if not prompts:
        raise ValueError("prompts must contain at least one prompt")

    payload: dict[str, Any] = {
        "prompt": prompts if len(prompts) > 1 else prompts[0],
        "max_tokens": max_tokens,
    }
    # Previously the model name was accepted but never sent.
    # NOTE(review): assumes the server accepts an OpenAI-style "model"
    # field naming the served model -- confirm against the deployment.
    if model_name:
        payload["model"] = model_name

    # Tolerate a trailing slash on the base URL ("…/v1/" -> "…/v1").
    endpoint = base_url.rstrip("/") + "/completions"
    req = request.Request(
        endpoint,
        data=json.dumps(payload).encode("utf-8"),
        headers={
            "Content-Type": "application/json",
        },
        method="POST",
    )
    with request.urlopen(req, timeout=request_timeout) as resp:
        return json.loads(resp.read().decode("utf-8"))
# Run benchmark prompts
def run_benchmark(args: argparse.Namespace) -> dict[str, Any]:
    """Send the prompt workload to the server and collect throughput stats.

    Prompts are loaded from ``args.prompt_file``, split into batches of
    ``args.batch_size`` and sent one batch per request, sequentially. While
    requests are in flight a background thread samples GPU utilization.
    A failing request does not abort the run; it is recorded in the
    ``errors`` list of the result.

    Args:
        args: Parsed options. Reads ``prompt_file``, ``batch_size``,
            ``model``, ``max_tokens``, ``base_url``, ``request_timeout``
            and ``gpu_sampling_interval``.

    Returns:
        A JSON-serializable summary: model, prompt/batch counts, total
        generated tokens, duration in seconds, tokens per second/hour,
        average GPU utilization, success/failure counts and the list of
        failure records. Each failure record has the keys ``index``
        (1-based batch number), ``type``, ``status`` (HTTP status code, or
        ``None`` for non-HTTP failures) and ``error`` (message text).

    Raises:
        ValueError: If the batch size is not positive or the prompt file
            contains no prompts.
        OSError: If the prompt file cannot be read.
    """
    # Validate first: a zero step would crash range() below and a negative
    # one would silently produce no batches.
    if args.batch_size < 1:
        raise ValueError(
            f"batch size must be a positive integer, got {args.batch_size}"
        )

    prompts = load_prompts(args.prompt_file)
    if not prompts:
        raise ValueError(f"No prompts found in {args.prompt_file}")

    batches = [
        prompts[i:i + args.batch_size]
        for i in range(0, len(prompts), args.batch_size)
    ]

    gpu_samples: List[int] = []
    stop_event = threading.Event()
    gpu_thread = threading.Thread(
        target=sample_gpu_utilization,
        args=(gpu_samples, stop_event, args.gpu_sampling_interval),
        daemon=True,
    )
    responses: List[dict[str, Any]] = []
    failures: List[dict[str, Any]] = []

    gpu_thread.start()
    # perf_counter is monotonic, so the duration is immune to clock changes.
    start_time = time.perf_counter()
    try:
        for idx, batch in enumerate(batches, start=1):
            try:
                responses.append(run_inference(
                    prompts=batch,
                    model_name=args.model,
                    max_tokens=args.max_tokens,
                    base_url=args.base_url,
                    request_timeout=args.request_timeout,
                ))
            except error.HTTPError as e:
                body = e.read().decode("utf-8", errors="replace")
                failures.append(
                    {"index": idx, "type": "http", "status": e.code, "error": body}
                )
            except Exception as e:
                # Same record shape as the HTTP branch: the message goes
                # under "error"; there is no status code to report.
                failures.append(
                    {"index": idx, "type": "exception", "status": None, "error": str(e)}
                )
        duration = time.perf_counter() - start_time
    finally:
        # Always stop the sampler, even if the loop is interrupted.
        stop_event.set()
        gpu_thread.join()

    # Calculate results. "usage" or "completion_tokens" may be missing or
    # null in a response; both count as zero tokens.
    total_generated_tokens = sum(
        (r.get("usage") or {}).get("completion_tokens") or 0
        for r in responses
    )
    tokens_per_sec = total_generated_tokens / duration if duration else 0
    tokens_per_hour = tokens_per_sec * 3600
    avg_gpu_util = mean(gpu_samples) if gpu_samples else 0

    return {
        "model": args.model,
        "num_prompts": len(prompts),
        "batch_size": args.batch_size,
        "num_batches": len(batches),
        "total_generated_tokens": total_generated_tokens,
        "duration": duration,
        "tokens_per_sec": tokens_per_sec,
        "tokens_per_hour": tokens_per_hour,
        "avg_gpu_utilization_percent": avg_gpu_util,
        "successful_requests": len(responses),
        "failed_requests": len(failures),
        "errors": failures,
    }
def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the benchmark's command-line options.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes
            argparse read them from ``sys.argv``.

    Returns:
        Namespace with the attributes ``model``, ``prompt_file``,
        ``max_tokens``, ``gpu_sampling_interval``, ``base_url``,
        ``request_timeout`` and ``batch_size``.
    """
    parser = argparse.ArgumentParser(
        description="Benchmark prompts against a local RHEL AI inference server."
    )
    parser.add_argument("--model", default=DEFAULT_MODEL_NAME, help="Name of the huggingface model served by vLLM")
    parser.add_argument("--prompt-file", default=DEFAULT_PROMPT_FILE, help="Path to file with prompt workload. Each line is turned into a prompt.")
    parser.add_argument("--max-tokens", type=int, default=DEFAULT_MAX_TOKENS, help="Max number of completion tokens per prompt.")
    parser.add_argument("--gpu-sampling-interval", type=float, default=DEFAULT_GPU_SAMPLING_INTERVAL_S, help="Interval in which to sample GPU utilization (seconds).")
    parser.add_argument("--base-url", default=DEFAULT_BASE_URL, help="Base URL for RHEL AI inference server.")
    parser.add_argument("--request-timeout", type=int, default=DEFAULT_REQUEST_TIMEOUT, help="Specify request timeout for prompting.")
    # A batch size below 1 cannot be used to slice the prompt list, so
    # reject it here with a normal argparse usage error.
    parser.add_argument("--batch-size", type=_positive_int, default=DEFAULT_BATCH_SIZE, help="Number of prompts to send in a single request.")
    return parser.parse_args(argv)


def _positive_int(value: str) -> int:
    """argparse ``type`` callable: parse ``value`` as an integer >= 1."""
    try:
        number = int(value)
    except ValueError:
        raise argparse.ArgumentTypeError(f"invalid int value: {value!r}") from None
    if number < 1:
        raise argparse.ArgumentTypeError(
            f"must be a positive integer, got {number}"
        )
    return number
def main() -> int:
    """Run the benchmark from the command line and report the outcome.

    On success the result summary is printed to stdout as indented JSON.
    If the benchmark raises, a JSON object holding the error message is
    printed to stderr instead.

    Returns:
        Process exit status: 0 on success, 1 if the benchmark failed.
    """
    cli_args = parse_args()
    try:
        report = run_benchmark(cli_args)
    except Exception as exc:
        failure = {"error": str(exc)}
        print(json.dumps(failure, indent=2), file=sys.stderr)
        return 1
    else:
        print(json.dumps(report, indent=2))
        return 0
# Run the benchmark only when executed as a script; main()'s return value
# becomes the process exit status.
if __name__ == "__main__":
    sys.exit(main())