Multi-node #29

Open

wants to merge 3 commits into main
Changes from 1 commit
2 changes: 1 addition & 1 deletion eval/chat_benchmarks/HumanEval/eval_instruct.py
@@ -112,7 +112,7 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
self.logger.info("Generating responses for Human Eval...")
outputs = self.compute(model, all_instances)

if model.rank != 0:
if model.accelerator.process_index != 0:
continue

generated_examples = []
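Note on the recurring change: this PR replaces the harness-level model.rank attribute with Hugging Face Accelerate's global process index, i.e. the rank of the current process across all nodes. A minimal sketch of the pattern, assuming the model wraps an accelerate.Accelerator as these diffs suggest (illustration only, not code from this repository):

from accelerate import Accelerator

accelerator = Accelerator()

# process_index is the global rank (0 .. num_processes - 1) across every node,
# so only one process in a multi-node job performs post-processing and saving.
if accelerator.process_index == 0:
    print(f"primary process out of {accelerator.num_processes} processes")
else:
    pass  # non-primary processes skip result collection and return early
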
2 changes: 1 addition & 1 deletion eval/chat_benchmarks/IFEval/eval_instruct.py
@@ -115,7 +115,7 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
self.logger.info("Generating responses...")
outputs = self.compute(model, all_instances)

if model.rank != 0:
if model.accelerator.process_index != 0:
return None

generated_examples = []
2 changes: 1 addition & 1 deletion eval/chat_benchmarks/MBPP/eval_instruct.py
@@ -161,7 +161,7 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
outputs = self.compute(model, all_instances)

# Return None early for non-primary ranks
if model.rank != 0:
if model.accelerator.process_index != 0:
return None

generated_examples = []
2 changes: 1 addition & 1 deletion eval/chat_benchmarks/MTBench/eval_instruct.py
@@ -151,7 +151,7 @@ def get_model_answers(self, model: LM, model_id: str, questions: List[Dict[str,
all_convs[q_idx].append({"role": "assistant", "content": output})
all_choices[q_idx]["turns"].append(output)

if model.rank != 0:
if model.accelerator.process_index != 0:
continue

# Save completed conversations
@@ -36,7 +36,7 @@ def load_xft_model(model_path, xft_config: XftConfig):
tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=False, padding_side="left", trust_remote_code=True)
xft_model = xfastertransformer.AutoModel.from_pretrained(model_path, dtype=data_type)
model = XftModel(xft_model=xft_model, xft_config=xft_config)
if model.model.rank > 0:
if model.model.accelerator.process_index > 0:
while True:
model.model.generate()
return model, tokenizer
8 changes: 4 additions & 4 deletions eval/chat_benchmarks/MixEval/eval_instruct.py
@@ -135,15 +135,15 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
for split in splits:
self.args.split = split
all_results = self._eval_split(model, split)
if model.rank == 0:
if model.accelerator.process_index == 0:
response_file = self._get_response_file()
with open(response_file, "w") as f:
for result in all_results:
f.write(json.dumps(result) + "\n")
out_dict[split] = all_results

# Only return results on rank 0
if model.world_size > 1 and model.rank != 0:
if model.world_size > 1 and model.accelerator.process_index != 0:
return None
return out_dict

@@ -192,7 +192,7 @@ def _eval_split(self, model: LM, split: str) -> List[Dict[str, Any]]:
for idx in list(range(len(eval_dataset.raw_inputs))):
eval_dataset.raw_inputs[idx]["response"] = all_responses[idx]

if model.rank == 0:
if model.accelerator.process_index == 0:
with open(response_file, "w") as f:
for item in eval_dataset.raw_inputs:
json_line = json.dumps(item)
@@ -243,7 +243,7 @@ def run_benchmark(self, model: LM) -> Dict[str, Any]:
generation_results = self.generate_responses(model)

# Only evaluate on rank 0
if model.world_size > 1 and model.rank != 0:
if model.world_size > 1 and model.accelerator.process_index != 0:
return None

evaluation_results = self.evaluate_responses(generation_results)
15 changes: 9 additions & 6 deletions eval/chat_benchmarks/RepoBench/eval_instruct.py
@@ -59,7 +59,7 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
if self.legacy_mode:
return self._generate_responses_legacy(model)

if model.rank == 0:
if model.accelerator.process_index == 0:
temp_dir_obj = tempfile.TemporaryDirectory()
temp_dir = temp_dir_obj.name

@@ -76,8 +76,11 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:

all_instances = []
# Split dataset across ranks for parallel construction
# Get subset of dataset for this rank using built-in slice functionality
rank_dataset = list(islice(dataset, model.rank, len(dataset), model.world_size))
# Get subset of dataset for this rank using the same slicing strategy as the compute function
chunk_size = len(dataset) // model.world_size
start = model.accelerator.process_index * chunk_size
end = start + chunk_size if model.accelerator.process_index < model.world_size - 1 else len(dataset)
rank_dataset = dataset.select(range(start, end))

# Process examples for this rank's shard
for idx, example in enumerate(rank_dataset):
@@ -100,7 +103,7 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
outputs = self.compute(model, all_instances, do_slice=False)

# Only rank 0 should save the results
if model.rank != 0:
if model.accelerator.process_index != 0:
continue

generated_examples = []
@@ -118,7 +121,7 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
for ex in generated_examples:
fw.write(json.dumps(ex) + "\n")

if model.rank == 0:
if model.accelerator.process_index == 0:
return {"temp_dir_obj": temp_dir_obj}

def _generate_responses_legacy(self, model: LM) -> Dict[str, Any]:
@@ -156,7 +159,7 @@ def _generate_responses_legacy(self, model: LM) -> Dict[str, Any]:

outputs = self.compute(model, all_instances, do_slice=False)

if model.rank != 0:
if model.accelerator.process_index != 0:
continue

generated_examples = []
2 changes: 1 addition & 1 deletion eval/chat_benchmarks/WildBench/eval_instruct.py
@@ -196,7 +196,7 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
outputs = self.compute(model, all_instances)

# Return None early for non-primary ranks
if model.rank != 0:
if model.accelerator.process_index != 0:
return None

outputs = [[output] for output in outputs]
2 changes: 1 addition & 1 deletion eval/chat_benchmarks/alpaca_eval/eval_instruct.py
@@ -117,7 +117,7 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
self.logger.info("Generating responses for Alpaca Eval...")
outputs = self.compute(model, all_instances)

if model.rank != 0:
if model.accelerator.process_index != 0:
return None

model_outputs = []
Collaborator Author

I guess that file should be removed.

@@ -1,9 +1,10 @@
,win_rate,standard_error,n_wins,n_wins_base,n_draws,n_total,discrete_win_rate,mode,avg_length,length_controlled_winrate,lc_standard_error
Shopee-SlimMoA-v1,75.61428659805350,1.2706274059194700,621,184,0,805,77.14285714285720,community,1994,77.4515432873834,0.43017522149239600
blendaxai-gm-l6-vo31,69.11033492869565,1.3280735654354863,562,242,1,805,69.87577639751554,community,1809,76.91981221023656,0.5725365663132986
Shopee-SlimMoA-v1,75.6142865980535,1.27062740591947,621,184,0,805,77.1428571428572,community,1994,77.4515432873834,0.430175221492396
blendaxai-gm-l6-vo31,69.11033492869565,1.3280735654354865,562,242,1,805,69.87577639751554,community,1809,76.91981221023656,0.5725365663132986
gemma-2-9b-it-WPO-HB,77.82503168985093,1.2355857177790277,640,163,2,805,79.62732919254658,community,2285,76.72506842726064,0.4242603928637889
blendaxai-gm-l3-v35,73.41035740244067,1.254951147343878,607,196,2,805,75.527950310559,community,2186,73.37270365010379,0.6163911450738288
gemma-2-9b-it-SimPO,65.86422561532919,1.423459922555078,540,264,1,805,67.14285714285714,community,1833,72.3508446939842,0.5167873784867067
model_hf_model_args_pretrained=mlfoundations-dev__gemma-simpo-reproduction,67.35102937013792,1.4210070002869848,557,247,1,805,69.25465838509317,community,1950,71.18995900084634,0.5756949353655318
openpipe-moa-gpt-4-turbo-v1,63.15493451236265,1.422980098799326,515,283,7,805,64.40993788819875,community,1856,68.37866250336802,0.7309418614587613
gemma-2-9b-it-DPO,65.35922380122982,1.402802336467638,536,268,1,805,66.64596273291924,community,2016,67.6620382198043,0.6605613085864308
Together-MoA,59.8688062333292,1.434305604543079,490,314,1,805,60.93167701863354,community,1825,65.37996976852163,0.7392392836781445
@@ -22,7 +23,7 @@ gpt4_1106_preview_verbose,64.30360147101865,1.3348590089025316,525,268,12,805,65
gpt-4o-mini-2024-07-18,44.65413862507926,1.4572395578449813,350,451,4,805,43.72670807453416,minimal,1861,50.727144855901976,0.8284734951761676
Storm-7B,50.26886905528583,1.4728176780737183,397,408,0,805,49.31677018633541,community,2045,50.45110959343775,
gpt4_1106_preview,50.0,0.0,0,0,805,805,50.0,minimal,2049,50.0,
REBEL-Llama-3-8B-Instruct-Armo,48.43655307668638,1.480341435123528,394,410,1,805,49.006211180124225,community,1965,49.314293536857114,0.7061879308002301
REBEL-Llama-3-8B-Instruct-Armo,48.43655307668638,1.480341435123528,394,410,1,805,49.00621118012423,community,1965,49.31429353685712,0.7061879308002301
Infinity-Instruct-7M-Gen-Llama3_1-70B,37.46327383827497,1.4734130373862548,299,501,5,805,37.453416149068325,community,1654,46.10043331712677,0.822439983375277
Llama-3-Instruct-8B-SimPO-ExPO,40.63285400856655,1.4439449942168028,325,479,1,805,40.43478260869565,community,1765,45.78021783946177,
Llama-3-Instruct-8B-SimPO,40.52977498461182,1.422574464675002,319,485,1,805,39.68944099378882,community,1825,44.65131348921881,0.8800655791760451
@@ -209,3 +210,4 @@ guanaco-13b,3.469596859739131,0.5518606725700214,22,780,3,805,2.919254658385093,
guanaco-7b,2.880002266173913,0.5202924149314048,21,783,1,805,2.670807453416149,verified,1364,2.871116813131697,
Qwen1.5-1.8B-Chat,3.70555681579365,0.5811750995496215,27,774,3,804,3.544776119402985,verified,2673,2.588498849185137,
baichuan-13b-chat,1.9921455615279504,0.4176985079331233,14,790,1,805,1.8012422360248446,community,1727,2.062170253598568,
model_hf_model_args_pretrained=mlfoundations-dev__gemma-oh-preferences,0.005260368511326853,0.0018774672393365112,0,805,0,805,0.0,community,196,0.010252829751292214,0.0007495965900756891
Collaborator Author

this one too

@@ -186,3 +186,5 @@ Mistral-7B-Instruct-v0.3,-1.5007159011881868,0.9845683091847074,-1.7652759895328
Shopee-SlimMoA-v1,-0.6930943742294789,0.5778443790027642,1.4506276222723822
blendaxai-gm-l6-vo31,-1.4827230167114802,0.8256378421072179,1.5942312525409852
REBEL-Llama-3-8B-Instruct-Armo,-1.0427168605260002,0.6464073051877255,0.0395191056877229
model_hf_model_args_pretrained=mlfoundations-dev__gemma-simpo-reproduction,-1.1818376919023723,0.6835318362039150,1.1479555832649320
model_hf_model_args_pretrained=mlfoundations-dev__gemma-oh-preferences,-1.8345282763259563,0.7434213717748921,-9.8937244442602008
2 changes: 1 addition & 1 deletion eval/chat_benchmarks/zeroeval/eval_instruct.py
@@ -144,7 +144,7 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:

outputs = self.compute(model, all_instances)

if model.rank != 0:
if model.accelerator.process_index != 0:
continue

outputs = [[output] for output in outputs]
4 changes: 2 additions & 2 deletions eval/eval.py
@@ -148,7 +148,7 @@ def evaluate(
cpu_count = os.cpu_count()

max_workers = min(len(valid_tasks), cpu_count * 2)
if lm.world_size <= 1 or lm.rank == 0:
if lm.world_size <= 1 or lm.accelerator.process_index == 0:
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
evaluate_results = list(
executor.map(
@@ -302,7 +302,7 @@ def cli_evaluate(args: Optional[argparse.Namespace] = None) -> None:
)

# Add metadata to results
if lm.rank == 0:
if lm.accelerator.process_index == 0:
add_results_metadata(results, args, lm)
handle_evaluation_output(results, args, evaluation_tracker, wandb_logger)

16 changes: 8 additions & 8 deletions eval/task.py
@@ -20,23 +20,23 @@ def __init__(self, logger: Optional[logging.Logger] = None):

def compute(self, model: LM, inputs: List[Instance], do_slice: bool = True) -> List[str]:
if model.world_size > 1 and do_slice:
prompts = list(islice(inputs, model.rank, len(inputs), model.world_size))
chunk_size = len(inputs) // model.world_size
start = model.accelerator.process_index * chunk_size
end = start + chunk_size if model.accelerator.process_index < model.world_size - 1 else len(inputs)
prompts = inputs[start:end]
else:
prompts = inputs

results = model.generate_until(prompts)
if model.world_size > 1:
all_results = [None for _ in range(model.world_size)]

dist.all_gather_object(all_results, results)

# Merge results from all ranks
length = sum(len(res) for res in all_results if res is not None)
merged = [None] * length
for rank, sub_results in enumerate(all_results):
# Simply concatenate results in rank order
merged = []
for sub_results in all_results:
if sub_results is not None:
for i, item in enumerate(sub_results):
merged[i * model.world_size + rank] = item
merged.extend(sub_results)
return merged
else:
return results
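
For context on the task.py change above: the old compute() distributed prompts round-robin with islice(inputs, rank, len(inputs), world_size) and re-interleaved the gathered results via merged[i * model.world_size + rank], whereas the new version hands each rank one contiguous chunk and simply concatenates the gathered chunks in rank order. Input order is preserved only because the slicing and the merge use the same contiguous scheme. A standalone sketch of that scheme (illustration only, not code from this PR; the final rank absorbs the remainder):

def split_contiguous(inputs, world_size):
    # Each rank gets a contiguous slice; the last rank takes any leftover items.
    chunk_size = len(inputs) // world_size
    chunks = []
    for rank in range(world_size):
        start = rank * chunk_size
        end = start + chunk_size if rank < world_size - 1 else len(inputs)
        chunks.append(inputs[start:end])
    return chunks

inputs = list(range(10))
chunks = split_contiguous(inputs, world_size=3)  # [[0, 1, 2], [3, 4, 5], [6, 7, 8, 9]]

# Mimics concatenating the output of dist.all_gather_object in rank order.
merged = []
for sub in chunks:
    merged.extend(sub)
assert merged == inputs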