Skip to content

Commit 508a6be

Browse files
committed
Add an example script to run lighteval-based evaluation
1 parent a2a8f82 commit 508a6be

File tree

2 files changed

+146
-0
lines changed

2 files changed

+146
-0
lines changed

matrix/cluster/ray_cluster.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -392,3 +392,15 @@ def __enter__(self):
392392
def __exit__(self, _exc_type, _exc_value, _traceback):
    """Context-manager exit hook: always stop the cluster.

    Returns None (falsy), so any exception raised inside the ``with``
    block still propagates after cleanup runs.
    """
    # Ensure the cluster is stopped when exiting the context
    self.stop()
395+
396+
def get_resources(self):
    """Return a snapshot of the Ray cluster's nodes and resources.

    Requires the head node to be ready: ``self.cluster_info()`` must
    return a non-None value. Connects this process to the cluster on
    demand via ``init_ray_if_necessary`` before querying Ray.

    Returns:
        dict with keys:
            - "nodes": ``ray.nodes()`` — per-node metadata.
            - "total_resources": ``ray.cluster_resources()`` — cluster totals.
            - "available_resources": ``ray.available_resources()`` — currently free.

    Raises:
        RuntimeError: if the head node is not ready.
    """
    cluster_info = self.cluster_info()
    # Raise explicitly instead of `assert`: asserts are stripped under
    # `python -O`, and "head not ready" is a runtime condition, not a bug.
    if cluster_info is None:
        raise RuntimeError("Head is not ready")

    # Import lazily (as the original did) so module import stays cheap;
    # deferred until after the readiness check since it is only needed below.
    import ray

    init_ray_if_necessary(cluster_info)
    return {
        "nodes": ray.nodes(),
        "total_resources": ray.cluster_resources(),
        "available_resources": ray.available_resources(),
    }
Lines changed: 134 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,134 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
# pip install lighteval[litellm]
8+
9+
import time
10+
11+
import fire
12+
import lighteval
13+
import yaml
14+
from lighteval.logging.evaluation_tracker import EvaluationTracker
15+
from lighteval.models.endpoints.litellm_model import LiteLLMModelConfig
16+
from lighteval.pipeline import ParallelismManager, Pipeline, PipelineParameters
17+
18+
import matrix
19+
20+
21+
def main(
    cluster_id="lighteval_test",
    model_name="/datasets/pretrained-llms/Llama-3.1-8B-Instruct",
    num_replicas=8,
    app_name="8B",
    slurm_account="data",
    slurm_qos="data_high",
    eval_task="lighteval|math_500|0",
    max_samples: int | None = None,
    cleanup=False,
):
    """Run a lighteval benchmark against a matrix-managed vLLM deployment.

    Starts (or reuses) a Ray cluster on Slurm, deploys ``model_name`` as a
    serving app, waits for it to reach RUNNING, then evaluates ``eval_task``
    through lighteval's LiteLLM/OpenAI-compatible backend.

    Args:
        cluster_id: matrix cluster identifier to create or attach to.
        model_name: path or HF id of the model to serve and evaluate.
        num_replicas: minimum number of serving replicas to deploy.
        app_name: matrix application name for the deployment.
        slurm_account: Slurm account used when a new cluster is started.
        slurm_qos: Slurm QOS used when a new cluster is started.
        eval_task: lighteval task spec, e.g. "lighteval|math_500|0".
        max_samples: cap on evaluated samples; None evaluates everything.
        cleanup: if True, remove the app and stop the cluster afterwards.
    """

    def setup():
        """Ensure cluster + app are up; return (cli, base_url)."""
        cli = matrix.Cli(
            cluster_id=cluster_id,
        )
        num_gpu = 0
        try:
            resources = cli.cluster.get_resources()
            num_gpu = resources["available_resources"].get("GPU", 0)
        except Exception:
            # Best-effort probe: a missing/unreachable cluster just means
            # we need to start one below. (Narrowed from a bare `except:`
            # so Ctrl-C / SystemExit still propagate.)
            pass

        if num_gpu == 0:
            cli.start_cluster(
                add_workers=1,
                slurm={
                    "account": slurm_account,
                    "qos": slurm_qos,
                },
            )

        try:
            cli.get_app_metadata(app_name)
        except Exception as e:
            # NOTE(review): "uknown" [sic] matches the error text emitted by
            # matrix for a missing app — do not "fix" the spelling here
            # without confirming the upstream message changed too.
            if "uknown app_name" in str(e):
                cli.deploy_applications(
                    applications=[
                        {
                            "model_name": model_name,
                            "min_replica": num_replicas,
                            "name": app_name,
                            "model_size": "8B",
                        }
                    ]
                )
        # Poll until the serving app reports RUNNING.
        while (status := cli.app.app_status(app_name)) != "RUNNING":
            print(f"{app_name} not ready, current status {status}")
            time.sleep(10)
        base_url = cli.get_app_metadata(app_name)["endpoints"]["head"]
        return cli, base_url

    # https://huggingface.co/docs/lighteval/en/using-the-python-api
    def run_eval(base_url):
        """Run the lighteval pipeline against the OpenAI-compatible endpoint."""
        evaluation_tracker = EvaluationTracker(
            output_dir="./results",
            save_details=True,
            push_to_hub=False,
        )

        pipeline_params = PipelineParameters(
            launcher_type=ParallelismManager.OPENAI,
            max_samples=max_samples,
        )

        # Build the LiteLLM model config from YAML so the structure mirrors
        # lighteval's documented config-file format.
        yaml_str = f"""
model_parameters:
  model_name: "openai/{model_name}"
  provider: "openai"
  base_url: {base_url}
  api_key: "EMPTY"
  generation_parameters:
    temperature: 0.6
    max_new_tokens: 16384
    top_p: 0.95
    seed: 42
    repetition_penalty: 1.0
    frequency_penalty: 0.0
"""
        data: dict = yaml.safe_load(yaml_str)

        model_config = LiteLLMModelConfig(**data["model_parameters"])

        pipeline = Pipeline(
            tasks=eval_task,
            pipeline_parameters=pipeline_params,
            evaluation_tracker=evaluation_tracker,
            model_config=model_config,
        )

        pipeline.evaluate()
        pipeline.save_and_push_results()
        pipeline.show_results()

    cli = None
    try:
        cli, base_url = setup()
        run_eval(base_url)
    finally:
        # Tear down only on request, and only if setup got far enough to
        # produce a client handle.
        if cleanup and cli is not None:
            cli.deploy_applications(
                action=matrix.utils.ray.Action.REMOVE,
                applications=[
                    {
                        "name": app_name,
                    }
                ],
            )

            cli.stop_cluster()
133+
if __name__ == "__main__":
    # Expose main()'s keyword arguments as command-line flags via python-fire.
    fire.Fire(main)

0 commit comments

Comments
 (0)