-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
114 lines (95 loc) · 4.11 KB
/
Copy pathmain.py
File metadata and controls
114 lines (95 loc) · 4.11 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
from __future__ import annotations
import argparse
from analysis.generate_reports import generate_reports
from benchmark import BenchmarkRunner, IterativeOptimizer
from dashboard.build_dashboard import build_dashboard
from kv_cache_engine.config import load_config
from models import InferenceRunner
def build_parser() -> argparse.ArgumentParser:
    """Construct the command-line parser for the KVCacheX CLI.

    The top-level ``--config`` option names the YAML config file.

    A required subcommand selects the action and is stored in ``args.command``:

    - ``benchmark``: full suite; optional ``--model-name``.
    - ``infer``: single job; ``--prompt`` is required. Optional: ``--mode``,
      ``--model-name``, ``--max-new-tokens``.
    - ``optimize``: iterative loop; optional ``--iterations``, ``--model-name``.
    - ``dashboard``: build the HTML dashboard; ``--model-name`` is accepted
      but unused.
    - ``analyze``: no extra options.
    - ``train-importance``: no extra options.

    Note that ``--config`` belongs to the root parser, so it must appear
    before the subcommand name on the command line.

    Returns:
        The fully configured ``argparse.ArgumentParser``.
    """
    root = argparse.ArgumentParser(description="KVCacheX benchmark and inference CLI")
    root.add_argument("--config", default="config.yaml", help="Path to the YAML config file.")
    commands = root.add_subparsers(dest="command", required=True)

    def with_model_name(sub, help_text="Override the model name."):
        # Several subcommands share an optional --model-name override.
        sub.add_argument("--model-name", default=None, help=help_text)
        return sub

    with_model_name(commands.add_parser("benchmark", help="Run the full benchmark suite."))

    infer = commands.add_parser("infer", help="Run a single inference job.")
    infer.add_argument("--prompt", required=True, help="Prompt text to evaluate.")
    infer.add_argument(
        "--mode",
        default="kvcachex",
        choices=["no_cache", "standard_cache", "kvcachex"],
        help="Inference mode.",
    )
    with_model_name(infer)
    infer.add_argument("--max-new-tokens", type=int, default=None, help="Decode length.")

    optimize = commands.add_parser(
        "optimize", help="Run the iterative benchmark-driven optimizer loop."
    )
    optimize.add_argument("--iterations", type=int, default=3, help="Optimization iterations.")
    with_model_name(optimize)

    # The dashboard accepts --model-name only so every command shares the flag.
    with_model_name(
        commands.add_parser("dashboard", help="Build the HTML dashboard."),
        "Unused, reserved for parity.",
    )

    # Subcommands that take no options beyond the global --config.
    option_free = (
        ("analyze", "Generate analysis markdown from saved benchmark results."),
        ("train-importance", "Fit the token importance model from calibration workloads."),
    )
    for command_name, command_help in option_free:
        commands.add_parser(command_name, help=command_help)

    return root
def _run_benchmark(config, args: argparse.Namespace) -> None:
    """Run the benchmark suite, rebuild reports and dashboard, print a summary table."""
    runner = BenchmarkRunner(config)
    metrics_df, _ = runner.run(model_name=args.model_name)
    # Regenerate the analysis reports and the dashboard from the fresh results.
    generate_reports(config)
    build_dashboard(config)
    print(
        metrics_df[
            [
                "workload_name",
                "mode",
                "mean_latency_ms",
                "peak_cache_bytes",
                "compression_ratio",
                "token_agreement_vs_standard",
            ]
        ].to_string(index=False)
    )


def _run_infer(config, args: argparse.Namespace) -> None:
    """Run one ad-hoc inference job and print its output text and summary stats."""
    runner = InferenceRunner(config)
    artifacts = runner.run(
        prompt=args.prompt,
        workload_name="ad_hoc",
        mode=args.mode,
        max_new_tokens=args.max_new_tokens,
        model_name=args.model_name,
    )
    summary = artifacts.summary
    print(summary.output_text)
    print(
        f"\nmean_latency_ms={summary.mean_latency_ms:.2f} "
        f"tokens_per_sec={summary.tokens_per_sec:.2f} "
        f"peak_cache_bytes={summary.peak_cache_bytes}"
    )


def _run_optimize(config, args: argparse.Namespace) -> None:
    """Run the iterative optimizer for ``args.iterations`` rounds and print its history."""
    optimizer = IterativeOptimizer(config)
    _, history = optimizer.run(iterations=args.iterations, model_name=args.model_name)
    print(history)


def _run_dashboard(config, args: argparse.Namespace) -> None:
    """Build the HTML dashboard and print the value returned by ``build_dashboard``."""
    print(build_dashboard(config))


def _run_analyze(config, args: argparse.Namespace) -> None:
    """Generate the analysis reports and print the configured report locations."""
    generate_reports(config)
    print(config.outputs.bottleneck_report)
    print(config.outputs.failure_report)


def _run_train_importance(config, args: argparse.Namespace) -> None:
    """Fit the token-importance predictor and print the configured model path."""
    runner = BenchmarkRunner(config)
    workloads = runner.build_workloads()
    runner.train_importance_predictor(workloads)
    print(config.eviction.semantic_model_path)


def main() -> None:
    """Parse CLI arguments, load the config, and dispatch to the subcommand handler.

    Every handler receives the loaded config and the parsed ``argparse.Namespace``.
    A subcommand with no registered handler aborts through ``parser.error``,
    which prints usage and exits with status 2.
    """
    parser = build_parser()
    args = parser.parse_args()
    # Keep this table in sync with the subcommands registered in build_parser().
    handlers = {
        "benchmark": _run_benchmark,
        "infer": _run_infer,
        "optimize": _run_optimize,
        "dashboard": _run_dashboard,
        "analyze": _run_analyze,
        "train-importance": _run_train_importance,
    }
    handler = handlers.get(args.command)
    if handler is None:
        # Without this check, a subcommand added to build_parser() but not
        # handled here would do nothing and exit 0; fail loudly instead.
        parser.error(f"no handler registered for command {args.command!r}")
    config = load_config(args.config)
    handler(config, args)


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()