
Commit aa7b451
fix for single process
1 parent 2d96e47 commit aa7b451

12 files changed, +140 -26 lines changed
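The recurring change in this commit replaces bare model.accelerator.process_index checks, which fail when the model carries no accelerator attribute (single-process runs), with a guard that also treats a lone process as the main process. A minimal sketch of that guard as a standalone helper; the function name is illustrative, the commit inlines the expression at each call site:

def is_main_process(model) -> bool:
    # Accelerate-wrapped models expose accelerator.process_index; rank 0 is the
    # main process. Single-process models lack the attribute and report
    # world_size <= 1, so they always count as the main process.
    if hasattr(model, "accelerator"):
        return model.accelerator.process_index == 0
    return model.world_size <= 1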

eval/chat_benchmarks/HumanEval/eval_instruct.py

+2 -1
@@ -112,7 +112,8 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
 self.logger.info("Generating responses for Human Eval...")
 outputs = self.compute(model, all_instances)

-if model.accelerator.process_index != 0:
+is_main_process = model.accelerator.process_index == 0 if hasattr(model, 'accelerator') else model.world_size <= 1
+if not is_main_process:
     continue

 generated_examples = []

eval/chat_benchmarks/IFEval/eval_instruct.py

+2 -1
@@ -115,7 +115,8 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
 self.logger.info("Generating responses...")
 outputs = self.compute(model, all_instances)

-if model.accelerator.process_index != 0:
+is_main_process = model.accelerator.process_index == 0 if hasattr(model, 'accelerator') else model.world_size <= 1
+if not is_main_process:
     return None

 generated_examples = []

eval/chat_benchmarks/MBPP/eval_instruct.py

+2 -1
@@ -161,7 +161,8 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
 outputs = self.compute(model, all_instances)

 # Return None early for non-primary ranks
-if model.accelerator.process_index != 0:
+is_main_process = model.accelerator.process_index == 0 if hasattr(model, 'accelerator') else model.world_size <= 1
+if not is_main_process:
     return None

 generated_examples = []

eval/chat_benchmarks/MTBench/eval_instruct.py

+2 -1
@@ -151,7 +151,8 @@ def get_model_answers(self, model: LM, model_id: str, questions: List[Dict[str,
 all_convs[q_idx].append({"role": "assistant", "content": output})
 all_choices[q_idx]["turns"].append(output)

-if model.accelerator.process_index != 0:
+is_main_process = model.accelerator.process_index == 0 if hasattr(model, 'accelerator') else model.world_size <= 1
+if not is_main_process:
     continue

 # Save completed conversations

eval/chat_benchmarks/MixEval/eval_instruct.py

+8 -4
@@ -132,18 +132,20 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
 out_dict = {}

 self.logger.info("Generating responses for MixEval...")
+is_main_process = model.accelerator.process_index == 0 if hasattr(model, 'accelerator') else model.world_size <= 1
+
 for split in splits:
     self.args.split = split
     all_results = self._eval_split(model, split)
-    if model.accelerator.process_index == 0:
+    if is_main_process:
         response_file = self._get_response_file()
         with open(response_file, "w") as f:
             for result in all_results:
                 f.write(json.dumps(result) + "\n")
     out_dict[split] = all_results

 # Only return results on rank 0
-if model.world_size > 1 and model.accelerator.process_index != 0:
+if not is_main_process:
     return None
 return out_dict

@@ -192,7 +194,8 @@ def _eval_split(self, model: LM, split: str) -> List[Dict[str, Any]]:
 for idx in list(range(len(eval_dataset.raw_inputs))):
     eval_dataset.raw_inputs[idx]["response"] = all_responses[idx]

-if model.accelerator.process_index == 0:
+is_main_process = model.accelerator.process_index == 0 if hasattr(model, 'accelerator') else model.world_size <= 1
+if is_main_process:
     with open(response_file, "w") as f:
         for item in eval_dataset.raw_inputs:
             json_line = json.dumps(item)
@@ -243,7 +246,8 @@ def run_benchmark(self, model: LM) -> Dict[str, Any]:
 generation_results = self.generate_responses(model)

 # Only evaluate on rank 0
-if model.world_size > 1 and model.accelerator.process_index != 0:
+is_main_process = model.accelerator.process_index == 0 if hasattr(model, 'accelerator') else model.world_size <= 1
+if not is_main_process:
     return None

 evaluation_results = self.evaluate_responses(generation_results)

eval/chat_benchmarks/RepoBench/eval_instruct.py

+12 -7
@@ -59,7 +59,8 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
 if self.legacy_mode:
     return self._generate_responses_legacy(model)

-if model.accelerator.process_index == 0:
+is_main_process = model.accelerator.process_index == 0 if hasattr(model, 'accelerator') else model.world_size <= 1
+if is_main_process:
     temp_dir_obj = tempfile.TemporaryDirectory()
     temp_dir = temp_dir_obj.name

@@ -77,10 +78,13 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
 all_instances = []
 # Split dataset across ranks for parallel construction
 # Get subset of dataset for this rank using the same slicing strategy as the compute function
-chunk_size = len(dataset) // model.world_size
-start = model.accelerator.process_index * chunk_size
-end = start + chunk_size if model.accelerator.process_index < model.world_size - 1 else len(dataset)
-rank_dataset = dataset.select(range(start, end))
+if hasattr(model, 'accelerator'):
+    chunk_size = len(dataset) // model.world_size
+    start = model.accelerator.process_index * chunk_size
+    end = start + chunk_size if model.accelerator.process_index < model.world_size - 1 else len(dataset)
+    rank_dataset = dataset.select(range(start, end))
+else:
+    rank_dataset = list(islice(dataset, model.rank, len(dataset), model.world_size))

 # Process examples for this rank's shard
 for idx, example in enumerate(rank_dataset):
@@ -103,7 +107,8 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
 outputs = self.compute(model, all_instances, do_slice=False)

 # Only rank 0 should save the results
-if model.accelerator.process_index != 0:
+is_main_process = model.accelerator.process_index == 0 if hasattr(model, 'accelerator') else model.world_size <= 1
+if not is_main_process:
     continue

 generated_examples = []
@@ -121,7 +126,7 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
 for ex in generated_examples:
     fw.write(json.dumps(ex) + "\n")

-if model.accelerator.process_index == 0:
+if is_main_process:
     return {"temp_dir_obj": temp_dir_obj}

 def _generate_responses_legacy(self, model: LM) -> Dict[str, Any]:
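For reference, a sketch of the two sharding strategies the RepoBench change now switches between: contiguous chunks via dataset.select when an accelerator is present, strided slicing via itertools.islice otherwise. The helper name and the plain-sequence input are illustrative; the real code operates on a Hugging Face dataset and on model.rank / model.world_size:

from itertools import islice

def shard_for_rank(items, rank, world_size, contiguous=True):
    # Contiguous: rank r takes items[r*chunk : (r+1)*chunk]; the last rank
    # absorbs the remainder so no example is dropped.
    if contiguous:
        chunk_size = len(items) // world_size
        start = rank * chunk_size
        end = start + chunk_size if rank < world_size - 1 else len(items)
        return items[start:end]
    # Strided: rank r takes items r, r + world_size, r + 2*world_size, ...
    return list(islice(items, rank, len(items), world_size))

With ten items and world_size 3, rank 1 gets items 3-5 in contiguous mode and items 1, 4, 7 in strided mode; either way each example is handled by exactly one rank.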

eval/chat_benchmarks/WildBench/eval_instruct.py

+2 -1
@@ -196,7 +196,8 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
 outputs = self.compute(model, all_instances)

 # Return None early for non-primary ranks
-if model.accelerator.process_index != 0:
+is_main_process = model.accelerator.process_index == 0 if hasattr(model, 'accelerator') else model.world_size <= 1
+if not is_main_process:
     return None

 outputs = [[output] for output in outputs]

eval/chat_benchmarks/alpaca_eval/eval_instruct.py

+2 -1
@@ -117,7 +117,8 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
 self.logger.info("Generating responses for Alpaca Eval...")
 outputs = self.compute(model, all_instances)

-if model.accelerator.process_index != 0:
+is_main_process = model.accelerator.process_index == 0 if hasattr(model, 'accelerator') else model.world_size <= 1
+if not is_main_process:
     return None

 model_outputs = []

eval/chat_benchmarks/alpaca_eval/src/alpaca_eval/leaderboards/data_AlpacaEval_2/weighted_alpaca_eval_gpt4_turbo_leaderboard.csv

+4 -6
@@ -1,10 +1,9 @@
 ,win_rate,standard_error,n_wins,n_wins_base,n_draws,n_total,discrete_win_rate,mode,avg_length,length_controlled_winrate,lc_standard_error
-Shopee-SlimMoA-v1,75.6142865980535,1.27062740591947,621,184,0,805,77.1428571428572,community,1994,77.4515432873834,0.430175221492396
-blendaxai-gm-l6-vo31,69.11033492869565,1.3280735654354865,562,242,1,805,69.87577639751554,community,1809,76.91981221023656,0.5725365663132986
+Shopee-SlimMoA-v1,75.61428659805350,1.2706274059194700,621,184,0,805,77.14285714285720,community,1994,77.4515432873834,0.43017522149239600
+blendaxai-gm-l6-vo31,69.11033492869565,1.3280735654354863,562,242,1,805,69.87577639751554,community,1809,76.91981221023656,0.5725365663132986
 gemma-2-9b-it-WPO-HB,77.82503168985093,1.2355857177790277,640,163,2,805,79.62732919254658,community,2285,76.72506842726064,0.4242603928637889
 blendaxai-gm-l3-v35,73.41035740244067,1.254951147343878,607,196,2,805,75.527950310559,community,2186,73.37270365010379,0.6163911450738288
 gemma-2-9b-it-SimPO,65.86422561532919,1.423459922555078,540,264,1,805,67.14285714285714,community,1833,72.3508446939842,0.5167873784867067
-model_hf_model_args_pretrained=mlfoundations-dev__gemma-simpo-reproduction,67.35102937013792,1.4210070002869848,557,247,1,805,69.25465838509317,community,1950,71.18995900084634,0.5756949353655318
 openpipe-moa-gpt-4-turbo-v1,63.15493451236265,1.422980098799326,515,283,7,805,64.40993788819875,community,1856,68.37866250336802,0.7309418614587613
 gemma-2-9b-it-DPO,65.35922380122982,1.402802336467638,536,268,1,805,66.64596273291924,community,2016,67.6620382198043,0.6605613085864308
 Together-MoA,59.8688062333292,1.434305604543079,490,314,1,805,60.93167701863354,community,1825,65.37996976852163,0.7392392836781445
@@ -23,7 +22,7 @@ gpt4_1106_preview_verbose,64.30360147101865,1.3348590089025316,525,268,12,805,65
 gpt-4o-mini-2024-07-18,44.65413862507926,1.4572395578449813,350,451,4,805,43.72670807453416,minimal,1861,50.727144855901976,0.8284734951761676
 Storm-7B,50.26886905528583,1.4728176780737183,397,408,0,805,49.31677018633541,community,2045,50.45110959343775,
 gpt4_1106_preview,50.0,0.0,0,0,805,805,50.0,minimal,2049,50.0,
-REBEL-Llama-3-8B-Instruct-Armo,48.43655307668638,1.480341435123528,394,410,1,805,49.00621118012423,community,1965,49.31429353685712,0.7061879308002301
+REBEL-Llama-3-8B-Instruct-Armo,48.43655307668638,1.480341435123528,394,410,1,805,49.006211180124225,community,1965,49.314293536857114,0.7061879308002301
 Infinity-Instruct-7M-Gen-Llama3_1-70B,37.46327383827497,1.4734130373862548,299,501,5,805,37.453416149068325,community,1654,46.10043331712677,0.822439983375277
 Llama-3-Instruct-8B-SimPO-ExPO,40.63285400856655,1.4439449942168028,325,479,1,805,40.43478260869565,community,1765,45.78021783946177,
 Llama-3-Instruct-8B-SimPO,40.52977498461182,1.422574464675002,319,485,1,805,39.68944099378882,community,1825,44.65131348921881,0.8800655791760451
@@ -209,5 +208,4 @@ oasst-sft-pythia-12b,1.790114083180124,0.3985580883049341,13,790,2,805,1.7391304
 guanaco-13b,3.469596859739131,0.5518606725700214,22,780,3,805,2.919254658385093,verified,1774,3.003787329611614,
 guanaco-7b,2.880002266173913,0.5202924149314048,21,783,1,805,2.670807453416149,verified,1364,2.871116813131697,
 Qwen1.5-1.8B-Chat,3.70555681579365,0.5811750995496215,27,774,3,804,3.544776119402985,verified,2673,2.588498849185137,
-baichuan-13b-chat,1.9921455615279504,0.4176985079331233,14,790,1,805,1.8012422360248446,community,1727,2.062170253598568,
-model_hf_model_args_pretrained=mlfoundations-dev__gemma-oh-preferences,0.005260368511326853,0.0018774672393365112,0,805,0,805,0.0,community,196,0.010252829751292214,0.0007495965900756891
+baichuan-13b-chat,1.9921455615279504,0.4176985079331233,14,790,1,805,1.8012422360248446,community,1727,2.062170253598568,

eval/chat_benchmarks/alpaca_eval/src/alpaca_eval/metrics/weights/weighted_alpaca_eval_gpt4_turbo/length_controlled_v1/baseline_gpt4_1106_preview.csv

+1 -1
@@ -187,4 +187,4 @@ Shopee-SlimMoA-v1,-0.6930943742294789,0.5778443790027642,1.4506276222723822
 blendaxai-gm-l6-vo31,-1.4827230167114802,0.8256378421072179,1.5942312525409852
 REBEL-Llama-3-8B-Instruct-Armo,-1.0427168605260002,0.6464073051877255,0.0395191056877229
 model_hf_model_args_pretrained=mlfoundations-dev__gemma-simpo-reproduction,-1.1818376919023723,0.6835318362039150,1.1479555832649320
-model_hf_model_args_pretrained=mlfoundations-dev__gemma-oh-preferences,-1.8345282763259563,0.7434213717748921,-9.8937244442602008
+model_hf_model_args_pretrained=mlfoundations-dev__gemma-oh-preferences,-1.8345282763259563,0.7434213717748921,-9.8937244442602008

eval/chat_benchmarks/zeroeval/eval_instruct.py

+2 -1
@@ -144,7 +144,8 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:

 outputs = self.compute(model, all_instances)

-if model.accelerator.process_index != 0:
+is_main_process = model.accelerator.process_index == 0 if hasattr(model, 'accelerator') else model.world_size <= 1
+if not is_main_process:
     continue

 outputs = [[output] for output in outputs]

eval/eval.py

+101 -1
@@ -5,9 +5,11 @@
 import sys
 import time
 from typing import Optional, List, Dict, Union
+from pathlib import Path

 import concurrent.futures
 import torch.distributed as dist
+from huggingface_hub import snapshot_download

 from lm_eval import utils
 from lm_eval import evaluator as pretrain_evaluator
@@ -26,6 +28,103 @@
 from eval.eval_tracker import DCEvaluationTracker


+class ModelInitializer:
+    """Handles model initialization for distributed evaluations."""
+
+    def __init__(self, cache_dir: Optional[str] = None):
+        self.cache_dir = cache_dir or os.getenv('HF_HOME',
+            os.path.join(os.path.expanduser("~"), ".cache", "huggingface"))
+        self._ensure_directory(self.cache_dir)
+
+    def _ensure_directory(self, path: str) -> None:
+        """Safely create directory if it doesn't exist."""
+        Path(path).mkdir(parents=True, exist_ok=True)
+
+    def download_model(self, model_id: str) -> None:
+        """Download model files with proper error handling."""
+        try:
+            snapshot_download(
+                repo_id=model_id,
+                cache_dir=self.cache_dir,
+                local_files_only=False,
+                resume_download=True
+            )
+        except Exception as e:
+            raise RuntimeError(f"Failed to download model {model_id}: {str(e)}")
+
+
+def initialize_model_for_eval(
+    model: Union[str, LM],
+    model_args: Optional[str] = None,
+    batch_size: int = None,
+    max_batch_size: Optional[int] = None,
+    device: Optional[str] = None,
+    cache_dir: Optional[str] = None
+) -> LM:
+    """
+    Initialize model for distributed evaluation where each node runs independent evaluations.
+
+    Args:
+        model (Union[str, LM]):
+            Either a string identifier for the model to load from registry,
+            or an already instantiated LM object.
+        model_args (Optional[str], optional):
+            Additional arguments for model initialization as a string.
+            Only used if model is provided as a string. Defaults to None.
+        batch_size (Optional[int], optional):
+            Batch size for model inference. Defaults to None.
+        max_batch_size (Optional[int], optional):
+            Maximum allowed batch size. Defaults to None.
+        device (Optional[str], optional):
+            Device to load the model on (e.g., 'cuda', 'cpu'). Defaults to None.
+
+    Returns:
+        LM:
+            Initialized language model instance with configured parameters
+            and a sanitized model identifier.
+    """
+    local_rank = int(os.getenv('LOCAL_RANK', '0'))
+
+    if isinstance(model, str):
+        initializer = ModelInitializer(cache_dir)
+
+        try:
+            initializer.download_model(model)
+        except Exception as e:
+            print(f"Rank {local_rank} failed to initialize model: {str(e)}")
+            if dist.is_initialized():
+                dist.barrier()  # Ensure all ranks fail together
+            raise e
+
+        if dist.is_initialized():
+            dist.barrier()
+
+        if model_args is None:
+            model_args = ""
+
+        config = {
+            "batch_size": batch_size,
+            "max_batch_size": max_batch_size,
+            "device": device,
+        }
+
+        try:
+            lm = lm_eval.api.registry.get_model(model).create_from_arg_string(
+                model_args,
+                config,
+            )
+        except Exception as e:
+            print(f"Rank {local_rank} failed to create model: {str(e)}")
+            if dist.is_initialized():
+                dist.barrier()
+            raise e
+    else:
+        lm = model
+
+    lm.model_identifier = sanitize_model_name(f"model_{model}_model_args_{model_args}")
+    return lm
+
+
 def setup_custom_parser():
     """
     Create a custom argument parser that extends lm-eval-harness parser.
@@ -302,7 +401,8 @@ def cli_evaluate(args: Optional[argparse.Namespace] = None) -> None:
 )

 # Add metadata to results
-if lm.accelerator.process_index == 0:
+is_main_process = lm.accelerator.process_index == 0 if hasattr(lm, 'accelerator') else lm.world_size <= 1
+if is_main_process:
     add_results_metadata(results, args, lm)
     handle_evaluation_output(results, args, evaluation_tracker, wandb_logger)
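For context, a minimal sketch of the download-then-synchronize pattern the new ModelInitializer and initialize_model_for_eval follow; this is a simplified illustration (hypothetical helper name), not the exact call path in eval.py:

import torch.distributed as dist
from huggingface_hub import snapshot_download

def prefetch_weights(repo_id: str, cache_dir: str) -> None:
    # Each rank fetches (or revalidates) the snapshot into the local cache, then
    # all ranks meet at a barrier so none constructs the model before the weights
    # are on disk; a rank that fails still reaches the barrier before re-raising.
    try:
        snapshot_download(repo_id=repo_id, cache_dir=cache_dir)
    finally:
        if dist.is_initialized():
            dist.barrier()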
