
Commit 2d96e47

jmercat authored and neginraoof committed

change slicing in compute, change rank to accelerator.process_index

1 parent 639af29 commit 2d96e47

14 files changed: +38 -31 lines changed

eval/chat_benchmarks/HumanEval/eval_instruct.py

+1 -1

@@ -112,7 +112,7 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
         self.logger.info("Generating responses for Human Eval...")
         outputs = self.compute(model, all_instances)

-        if model.rank != 0:
+        if model.accelerator.process_index != 0:
            continue

         generated_examples = []
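The rank -> accelerator.process_index substitution above repeats across most of the files in this commit. Below is a minimal standalone sketch of the gating pattern, assuming model.accelerator is a Hugging Face accelerate.Accelerator (its process_index attribute is the global process index, 0 on the primary process); this is an illustration, not code from the repository.

# Sketch: gate side effects (saving files, post-processing) to the primary process,
# mirroring the `if model.accelerator.process_index != 0: ...` checks in this commit.
# Assumes the `accelerate` package is installed; launch with `accelerate launch` to
# get more than one process, otherwise this runs as a single process with index 0.
from accelerate import Accelerator

accelerator = Accelerator()
print(f"process_index={accelerator.process_index}, num_processes={accelerator.num_processes}")

if accelerator.process_index != 0:
    # Non-primary ranks still participate in generation but skip saving;
    # rank 0 collects and writes everything.
    pass
else:
    print("primary process: safe to write result files here")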

eval/chat_benchmarks/IFEval/eval_instruct.py

+1 -1

@@ -115,7 +115,7 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
         self.logger.info("Generating responses...")
         outputs = self.compute(model, all_instances)

-        if model.rank != 0:
+        if model.accelerator.process_index != 0:
            return None

         generated_examples = []

eval/chat_benchmarks/MBPP/eval_instruct.py

+1 -1

@@ -161,7 +161,7 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
         outputs = self.compute(model, all_instances)

         # Return None early for non-primary ranks
-        if model.rank != 0:
+        if model.accelerator.process_index != 0:
            return None

         generated_examples = []

eval/chat_benchmarks/MTBench/eval_instruct.py

+1 -1

@@ -151,7 +151,7 @@ def get_model_answers(self, model: LM, model_id: str, questions: List[Dict[str,
             all_convs[q_idx].append({"role": "assistant", "content": output})
             all_choices[q_idx]["turns"].append(output)

-            if model.rank != 0:
+            if model.accelerator.process_index != 0:
                continue

             # Save completed conversations

eval/chat_benchmarks/MTBench/fastchat/modules/xfastertransformer.py

+1 -1

@@ -36,7 +36,7 @@ def load_xft_model(model_path, xft_config: XftConfig):
     tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, padding_side="left", trust_remote_code=True)
     xft_model = xfastertransformer.AutoModel.from_pretrained(model_path, dtype=data_type)
     model = XftModel(xft_model=xft_model, xft_config=xft_config)
-    if model.model.rank > 0:
+    if model.model.accelerator.process_index > 0:
        while True:
            model.model.generate()
     return model, tokenizer

eval/chat_benchmarks/MixEval/eval_instruct.py

+4 -4

@@ -135,15 +135,15 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
         for split in splits:
             self.args.split = split
             all_results = self._eval_split(model, split)
-            if model.rank == 0:
+            if model.accelerator.process_index == 0:
                response_file = self._get_response_file()
                with open(response_file, "w") as f:
                    for result in all_results:
                        f.write(json.dumps(result) + "\n")
             out_dict[split] = all_results

         # Only return results on rank 0
-        if model.world_size > 1 and model.rank != 0:
+        if model.world_size > 1 and model.accelerator.process_index != 0:
            return None
         return out_dict

@@ -192,7 +192,7 @@ def _eval_split(self, model: LM, split: str) -> List[Dict[str, Any]]:
         for idx in list(range(len(eval_dataset.raw_inputs))):
             eval_dataset.raw_inputs[idx]["response"] = all_responses[idx]

-        if model.rank == 0:
+        if model.accelerator.process_index == 0:
            with open(response_file, "w") as f:
                for item in eval_dataset.raw_inputs:
                    json_line = json.dumps(item)

@@ -243,7 +243,7 @@ def run_benchmark(self, model: LM) -> Dict[str, Any]:
         generation_results = self.generate_responses(model)

         # Only evaluate on rank 0
-        if model.world_size > 1 and model.rank != 0:
+        if model.world_size > 1 and model.accelerator.process_index != 0:
            return None

         evaluation_results = self.evaluate_responses(generation_results)

eval/chat_benchmarks/RepoBench/eval_instruct.py

+9 -6

@@ -59,7 +59,7 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
         if self.legacy_mode:
             return self._generate_responses_legacy(model)

-        if model.rank == 0:
+        if model.accelerator.process_index == 0:
            temp_dir_obj = tempfile.TemporaryDirectory()
            temp_dir = temp_dir_obj.name

@@ -76,8 +76,11 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:

         all_instances = []
         # Split dataset across ranks for parallel construction
-        # Get subset of dataset for this rank using built-in slice functionality
-        rank_dataset = list(islice(dataset, model.rank, len(dataset), model.world_size))
+        # Get subset of dataset for this rank using the same slicing strategy as the compute function
+        chunk_size = len(dataset) // model.world_size
+        start = model.accelerator.process_index * chunk_size
+        end = start + chunk_size if model.accelerator.process_index < model.world_size - 1 else len(dataset)
+        rank_dataset = dataset.select(range(start, end))

         # Process examples for this rank's shard
         for idx, example in enumerate(rank_dataset):

@@ -100,7 +103,7 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
         outputs = self.compute(model, all_instances, do_slice=False)

         # Only rank 0 should save the results
-        if model.rank != 0:
+        if model.accelerator.process_index != 0:
            continue

         generated_examples = []

@@ -118,7 +121,7 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
             for ex in generated_examples:
                 fw.write(json.dumps(ex) + "\n")

-        if model.rank == 0:
+        if model.accelerator.process_index == 0:
            return {"temp_dir_obj": temp_dir_obj}

     def _generate_responses_legacy(self, model: LM) -> Dict[str, Any]:

@@ -156,7 +159,7 @@ def _generate_responses_legacy(self, model: LM) -> Dict[str, Any]:

         outputs = self.compute(model, all_instances, do_slice=False)

-        if model.rank != 0:
+        if model.accelerator.process_index != 0:
            continue

         generated_examples = []
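The RepoBench shard construction above now mirrors the chunked slicing used in compute() (see eval/task.py below) instead of the old strided islice sharding. Below is an illustrative comparison of the two schemes, using a plain list in place of the Hugging Face dataset (dataset.select is the datasets-library call in the diff; list slicing stands in for it here); this is a sketch, not the repository's code.

# Illustration only: contiguous chunking (new) vs. strided islice sharding (old)
# for a toy "dataset" of length 10 split across 3 processes.  The last rank
# absorbs the remainder under the new scheme.
from itertools import islice

dataset = list(range(10))
world_size = 3

for rank in range(world_size):
    chunk_size = len(dataset) // world_size
    start = rank * chunk_size
    end = start + chunk_size if rank < world_size - 1 else len(dataset)
    contiguous = dataset[start:end]                                   # new behaviour
    strided = list(islice(dataset, rank, len(dataset), world_size))   # old behaviour
    print(f"rank {rank}: contiguous={contiguous}  strided={strided}")

# rank 0: contiguous=[0, 1, 2]     strided=[0, 3, 6, 9]
# rank 1: contiguous=[3, 4, 5]     strided=[1, 4, 7]
# rank 2: contiguous=[6, 7, 8, 9]  strided=[2, 5, 8]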

eval/chat_benchmarks/WildBench/eval_instruct.py

+1 -1

@@ -196,7 +196,7 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
         outputs = self.compute(model, all_instances)

         # Return None early for non-primary ranks
-        if model.rank != 0:
+        if model.accelerator.process_index != 0:
            return None

         outputs = [[output] for output in outputs]

eval/chat_benchmarks/alpaca_eval/eval_instruct.py

+1 -1

@@ -117,7 +117,7 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
         self.logger.info("Generating responses for Alpaca Eval...")
         outputs = self.compute(model, all_instances)

-        if model.rank != 0:
+        if model.accelerator.process_index != 0:
            return None

         model_outputs = []

eval/chat_benchmarks/alpaca_eval/src/alpaca_eval/leaderboards/data_AlpacaEval_2/weighted_alpaca_eval_gpt4_turbo_leaderboard.csv

+5 -3

@@ -1,9 +1,10 @@
 ,win_rate,standard_error,n_wins,n_wins_base,n_draws,n_total,discrete_win_rate,mode,avg_length,length_controlled_winrate,lc_standard_error
-Shopee-SlimMoA-v1,75.61428659805350,1.2706274059194700,621,184,0,805,77.14285714285720,community,1994,77.4515432873834,0.43017522149239600
-blendaxai-gm-l6-vo31,69.11033492869565,1.3280735654354863,562,242,1,805,69.87577639751554,community,1809,76.91981221023656,0.5725365663132986
+Shopee-SlimMoA-v1,75.6142865980535,1.27062740591947,621,184,0,805,77.1428571428572,community,1994,77.4515432873834,0.430175221492396
+blendaxai-gm-l6-vo31,69.11033492869565,1.3280735654354865,562,242,1,805,69.87577639751554,community,1809,76.91981221023656,0.5725365663132986
 gemma-2-9b-it-WPO-HB,77.82503168985093,1.2355857177790277,640,163,2,805,79.62732919254658,community,2285,76.72506842726064,0.4242603928637889
 blendaxai-gm-l3-v35,73.41035740244067,1.254951147343878,607,196,2,805,75.527950310559,community,2186,73.37270365010379,0.6163911450738288
 gemma-2-9b-it-SimPO,65.86422561532919,1.423459922555078,540,264,1,805,67.14285714285714,community,1833,72.3508446939842,0.5167873784867067
+model_hf_model_args_pretrained=mlfoundations-dev__gemma-simpo-reproduction,67.35102937013792,1.4210070002869848,557,247,1,805,69.25465838509317,community,1950,71.18995900084634,0.5756949353655318
 openpipe-moa-gpt-4-turbo-v1,63.15493451236265,1.422980098799326,515,283,7,805,64.40993788819875,community,1856,68.37866250336802,0.7309418614587613
 gemma-2-9b-it-DPO,65.35922380122982,1.402802336467638,536,268,1,805,66.64596273291924,community,2016,67.6620382198043,0.6605613085864308
 Together-MoA,59.8688062333292,1.434305604543079,490,314,1,805,60.93167701863354,community,1825,65.37996976852163,0.7392392836781445

@@ -22,7 +23,7 @@ gpt4_1106_preview_verbose,64.30360147101865,1.3348590089025316,525,268,12,805,65
 gpt-4o-mini-2024-07-18,44.65413862507926,1.4572395578449813,350,451,4,805,43.72670807453416,minimal,1861,50.727144855901976,0.8284734951761676
 Storm-7B,50.26886905528583,1.4728176780737183,397,408,0,805,49.31677018633541,community,2045,50.45110959343775,
 gpt4_1106_preview,50.0,0.0,0,0,805,805,50.0,minimal,2049,50.0,
-REBEL-Llama-3-8B-Instruct-Armo,48.43655307668638,1.480341435123528,394,410,1,805,49.006211180124225,community,1965,49.314293536857114,0.7061879308002301
+REBEL-Llama-3-8B-Instruct-Armo,48.43655307668638,1.480341435123528,394,410,1,805,49.00621118012423,community,1965,49.31429353685712,0.7061879308002301
 Infinity-Instruct-7M-Gen-Llama3_1-70B,37.46327383827497,1.4734130373862548,299,501,5,805,37.453416149068325,community,1654,46.10043331712677,0.822439983375277
 Llama-3-Instruct-8B-SimPO-ExPO,40.63285400856655,1.4439449942168028,325,479,1,805,40.43478260869565,community,1765,45.78021783946177,
 Llama-3-Instruct-8B-SimPO,40.52977498461182,1.422574464675002,319,485,1,805,39.68944099378882,community,1825,44.65131348921881,0.8800655791760451

@@ -209,3 +210,4 @@ guanaco-13b,3.469596859739131,0.5518606725700214,22,780,3,805,2.919254658385093,
 guanaco-7b,2.880002266173913,0.5202924149314048,21,783,1,805,2.670807453416149,verified,1364,2.871116813131697,
 Qwen1.5-1.8B-Chat,3.70555681579365,0.5811750995496215,27,774,3,804,3.544776119402985,verified,2673,2.588498849185137,
 baichuan-13b-chat,1.9921455615279504,0.4176985079331233,14,790,1,805,1.8012422360248446,community,1727,2.062170253598568,
+model_hf_model_args_pretrained=mlfoundations-dev__gemma-oh-preferences,0.005260368511326853,0.0018774672393365112,0,805,0,805,0.0,community,196,0.010252829751292214,0.0007495965900756891

eval/chat_benchmarks/alpaca_eval/src/alpaca_eval/metrics/weights/weighted_alpaca_eval_gpt4_turbo/length_controlled_v1/baseline_gpt4_1106_preview.csv

+2 -0

@@ -186,3 +186,5 @@ Mistral-7B-Instruct-v0.3,-1.5007159011881868,0.9845683091847074,-1.7652759895328
 Shopee-SlimMoA-v1,-0.6930943742294789,0.5778443790027642,1.4506276222723822
 blendaxai-gm-l6-vo31,-1.4827230167114802,0.8256378421072179,1.5942312525409852
 REBEL-Llama-3-8B-Instruct-Armo,-1.0427168605260002,0.6464073051877255,0.0395191056877229
+model_hf_model_args_pretrained=mlfoundations-dev__gemma-simpo-reproduction,-1.1818376919023723,0.6835318362039150,1.1479555832649320
+model_hf_model_args_pretrained=mlfoundations-dev__gemma-oh-preferences,-1.8345282763259563,0.7434213717748921,-9.8937244442602008

eval/chat_benchmarks/zeroeval/eval_instruct.py

+1 -1

@@ -144,7 +144,7 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:

         outputs = self.compute(model, all_instances)

-        if model.rank != 0:
+        if model.accelerator.process_index != 0:
            continue

         outputs = [[output] for output in outputs]

eval/eval.py

+2 -2

@@ -148,7 +148,7 @@ def evaluate(
     cpu_count = os.cpu_count()

     max_workers = min(len(valid_tasks), cpu_count * 2)
-    if lm.world_size <= 1 or lm.rank == 0:
+    if lm.world_size <= 1 or lm.accelerator.process_index == 0:
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            evaluate_results = list(
                executor.map(

@@ -302,7 +302,7 @@ def cli_evaluate(args: Optional[argparse.Namespace] = None) -> None:
     )

     # Add metadata to results
-    if lm.rank == 0:
+    if lm.accelerator.process_index == 0:
        add_results_metadata(results, args, lm)
        handle_evaluation_output(results, args, evaluation_tracker, wandb_logger)

eval/task.py

+8 -8

@@ -20,23 +20,23 @@ def __init__(self, logger: Optional[logging.Logger] = None):

     def compute(self, model: LM, inputs: List[Instance], do_slice: bool = True) -> List[str]:
         if model.world_size > 1 and do_slice:
-            prompts = list(islice(inputs, model.rank, len(inputs), model.world_size))
+            chunk_size = len(inputs) // model.world_size
+            start = model.accelerator.process_index * chunk_size
+            end = start + chunk_size if model.accelerator.process_index < model.world_size - 1 else len(inputs)
+            prompts = inputs[start:end]
         else:
             prompts = inputs

         results = model.generate_until(prompts)
         if model.world_size > 1:
             all_results = [None for _ in range(model.world_size)]
-
             dist.all_gather_object(all_results, results)

-            # Merge results from all ranks
-            length = sum(len(res) for res in all_results if res is not None)
-            merged = [None] * length
-            for rank, sub_results in enumerate(all_results):
+            # Simply concatenate results in rank order
+            merged = []
+            for sub_results in all_results:
                 if sub_results is not None:
-                    for i, item in enumerate(sub_results):
-                        merged[i * model.world_size + rank] = item
+                    merged.extend(sub_results)
             return merged
         else:
             return results
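With contiguous per-rank chunks, concatenating the gathered lists in rank order restores the original input order, which is why the strided reassembly could be dropped. Below is a sketch of the new compute() flow with dist.all_gather_object simulated by a plain list so it runs in one process (generate_until is mocked; this is an illustration under those assumptions, not the repository's code).

# Sketch: contiguous sharding + gather + concatenation preserves input order.
# dist.all_gather_object normally fills `all_results` across ranks; here it is
# simulated by looping over ranks in a single process.
def shard(inputs, rank, world_size):
    chunk_size = len(inputs) // world_size
    start = rank * chunk_size
    end = start + chunk_size if rank < world_size - 1 else len(inputs)
    return inputs[start:end]

def fake_generate_until(prompts):           # stand-in for model.generate_until
    return [f"response:{p}" for p in prompts]

inputs = [f"prompt-{i}" for i in range(10)]
world_size = 3

# Stand-in for dist.all_gather_object(all_results, results) across world_size ranks.
all_results = [fake_generate_until(shard(inputs, r, world_size)) for r in range(world_size)]

merged = []
for sub_results in all_results:             # concatenate in rank order
    if sub_results is not None:
        merged.extend(sub_results)

assert merged == [f"response:{p}" for p in inputs]   # original order recovered
print(f"{len(merged)} results, in original input order")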
