@@ -457,6 +457,10 @@ def get_tokenizer_tulu_v2_2(tc: "TokenizerConfig"):
457
457
"get_tokenizer_tulu_v2_2" : get_tokenizer_tulu_v2_2 ,
458
458
}
459
459
460
+ DEFAULT_SFT_MESSAGES_KEY = "messages"
461
+ GROUND_TRUTHS_KEY = "ground_truth"
462
+ DATASET_SOURCE_KEY = "dataset"
463
+
460
464
461
465
@dataclass
462
466
class TokenizerConfig :
@@ -474,6 +478,10 @@ class TokenizerConfig:
474
478
# backward compatibility to make sure script runs
475
479
use_slow_tokenizer : bool = False # completely ignored
476
480
tokenizer_name : Optional [str ] = None
481
+ ground_truths_key : str = GROUND_TRUTHS_KEY
482
+ """columns name for the ground truth"""
483
+ sft_messages_key : str = DEFAULT_SFT_MESSAGES_KEY
484
+ """columns name for the sft messages"""
477
485
478
486
@cached_property
479
487
def tokenizer (self ):
@@ -499,7 +507,6 @@ def tokenizer(self):
499
507
# ----------------------------------------------------------------------------
500
508
# Dataset Transformation
501
509
# SFT dataset
502
- DEFAULT_SFT_MESSAGES_KEY = "messages"
503
510
INPUT_IDS_KEY = "input_ids"
504
511
ATTENTION_MASK_KEY = "attention_mask"
505
512
LABELS_KEY = "labels"
@@ -526,8 +533,6 @@ def tokenizer(self):
526
533
527
534
INPUT_IDS_PROMPT_KEY = "input_ids_prompt"
528
535
ATTENTION_MASK_PROMPT_KEY = "attention_mask_prompt"
529
- GROUND_TRUTHS_KEY = "ground_truth"
530
- DATASET_SOURCE_KEY = "dataset"
531
536
532
537
TOKENIZED_PREFERENCE_DATASET_KEYS = [
533
538
CHOSEN_INPUT_IDS_KEY ,
0 commit comments