Skip to content

Commit 1e43884

Browse files
hjh0119 authored and Jintao-Huang committed
remove prompt id for megatron grpo (#9094)
1 parent 0bb1be0 commit 1e43884

1 file changed

Lines changed: 3 additions & 15 deletions

File tree

swift/megatron/trainers/grpo_trainer.py

Lines changed: 3 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -1528,7 +1528,7 @@ def _process_image_data(image_data: Union[dict, str]) -> str:
15281528

15291529
def _preprocess_inputs(self, inputs: DataType) -> DataType:
15301530
"""Preprocess inputs before inference"""
1531-
processed_inputs = self._add_prompt_id_to_inputs(inputs)
1531+
processed_inputs = self._add_request_id_to_inputs(inputs)
15321532
for input_item in processed_inputs:
15331533
remove_response(input_item['messages'])
15341534
return processed_inputs
@@ -1594,24 +1594,12 @@ def resample_encode_failed_inputs(self, inputs: DataType, max_resample_rounds: i
15941594

15951595
return valid_samples[:required_count]
15961596

1597-
def _add_prompt_id_to_inputs(self, inputs: DataType) -> DataType:
1598-
"""Add unique prompt_id and request_id to each input"""
1597+
def _add_request_id_to_inputs(self, inputs: DataType) -> DataType:
1598+
"""Add unique request_id to each input"""
15991599
if not inputs:
16001600
return inputs
16011601

1602-
all_messages = gather_object([inp['messages'] for inp in inputs])
1603-
messages_to_prompt_id = {}
1604-
prompt_id_counter = 0
1605-
1606-
for messages in all_messages:
1607-
key = json.dumps(messages)
1608-
if key not in messages_to_prompt_id:
1609-
messages_to_prompt_id[key] = f'prompt_{prompt_id_counter}'
1610-
prompt_id_counter += 1
1611-
16121602
for input_item in inputs:
1613-
messages = input_item.get('messages')
1614-
input_item['prompt_id'] = messages_to_prompt_id[json.dumps(messages)]
16151603
input_item['request_id'] = f'chatcmpl-{str(uuid.uuid4().hex)}'
16161604

16171605
return inputs

0 commit comments

Comments
 (0)