Skip to content

Commit f90f731

Browse files
authored
backward compatibility hot fix (#622)
1 parent f37dda5 commit f90f731

File tree

2 files changed

+8
-5
lines changed

2 files changed

+8
-5
lines changed

Diff for: open_instruct/dataset_transformation.py

+8 -3
Original file line number | Diff line number | Diff line change
@@ -457,6 +457,10 @@ def get_tokenizer_tulu_v2_2(tc: "TokenizerConfig"):
457457
"get_tokenizer_tulu_v2_2": get_tokenizer_tulu_v2_2,
458458
}
459459

460+
DEFAULT_SFT_MESSAGES_KEY = "messages"
461+
GROUND_TRUTHS_KEY = "ground_truth"
462+
DATASET_SOURCE_KEY = "dataset"
463+
460464

461465
@dataclass
462466
class TokenizerConfig:
@@ -474,6 +478,10 @@ class TokenizerConfig:
474478
# backward compatibility to make sure script runs
475479
use_slow_tokenizer: bool = False # completely ignored
476480
tokenizer_name: Optional[str] = None
481+
ground_truths_key: str = GROUND_TRUTHS_KEY
482+
"""columns name for the ground truth"""
483+
sft_messages_key: str = DEFAULT_SFT_MESSAGES_KEY
484+
"""columns name for the sft messages"""
477485

478486
@cached_property
479487
def tokenizer(self):
@@ -499,7 +507,6 @@ def tokenizer(self):
499507
# ----------------------------------------------------------------------------
500508
# Dataset Transformation
501509
# SFT dataset
502-
DEFAULT_SFT_MESSAGES_KEY = "messages"
503510
INPUT_IDS_KEY = "input_ids"
504511
ATTENTION_MASK_KEY = "attention_mask"
505512
LABELS_KEY = "labels"
@@ -526,8 +533,6 @@ def tokenizer(self):
526533

527534
INPUT_IDS_PROMPT_KEY = "input_ids_prompt"
528535
ATTENTION_MASK_PROMPT_KEY = "attention_mask_prompt"
529-
GROUND_TRUTHS_KEY = "ground_truth"
530-
DATASET_SOURCE_KEY = "dataset"
531536

532537
TOKENIZED_PREFERENCE_DATASET_KEYS = [
533538
CHOSEN_INPUT_IDS_KEY,

Diff for: open_instruct/grpo_fast.py

-2
Original file line number | Diff line number | Diff line change
@@ -134,8 +134,6 @@ class Args:
134134
"""The maximum token length to use for the dataset"""
135135
max_prompt_token_length: int = 256
136136
"""The maximum prompt token length to use for the dataset"""
137-
ground_truths_key: str = GROUND_TRUTHS_KEY
138-
"""columns name for the ground truth"""
139137

140138
# Experiment
141139
exp_name: str = os.path.basename(__file__)[: -len(".py")]

0 commit comments

Comments
 (0)