1616 field_validator ,
1717 model_validator ,
1818)
19- from transformers import AutoConfig , AutoTokenizer , PretrainedConfig
19+ from transformers import AutoConfig , AutoTokenizer , PretrainedConfig , PreTrainedTokenizerBase
2020
2121from ..cli .artifact_structure import Workdir
2222from ..config .parameters import SafeSynthesizerParameters
2727)
2828from ..observability import get_logger
2929from ..utils import load_json , write_json
30+ from .utils import trust_remote_code_for_model
3031
3132logger = get_logger (__name__ )
3233
@@ -77,7 +78,7 @@ class LLMPromptConfig(BaseModel):
7778 """Integer id for the EOS token."""
7879
7980 @classmethod
80- def from_tokenizer (cls , name : str , tokenizer : AutoTokenizer | None = None , ** kwargs ) -> LLMPromptConfig :
81+ def from_tokenizer (cls , name : str , tokenizer : PreTrainedTokenizerBase | None = None , ** kwargs ) -> LLMPromptConfig :
8182 """Create a prompt config by reading from settings of a tokenizer.
8283
8384 If no ``tokenizer`` is supplied one is loaded from ``name``
@@ -94,7 +95,9 @@ def from_tokenizer(cls, name: str, tokenizer: AutoTokenizer | None = None, **kwa
9495 Returns:
9596 A new ``LLMPromptConfig`` populated from the tokenizer.
9697 """
97- tokenizer = tokenizer or AutoTokenizer .from_pretrained (name )
98+ tokenizer = tokenizer or AutoTokenizer .from_pretrained (
99+ name , trust_remote_code = trust_remote_code_for_model (name )
100+ )
98101 bos_token = kwargs .get ("bos_token" , getattr (tokenizer , "bos_token" , None ))
99102 bos_token_id = kwargs .get ("bos_token_id" , getattr (tokenizer , "bos_token_id" , None ))
100103 eos_token = kwargs .get ("eos_token" , getattr (tokenizer , "eos_token" , None ))
@@ -339,7 +342,11 @@ def populate_derived_fields(cls, data: dict) -> dict:
339342 The mutated ``data`` dict with derived fields populated.
340343 """
341344 if data .get ("autoconfig" ) is None :
342- data ["autoconfig" ] = AutoConfig .from_pretrained (data ["model_name_or_path" ])
345+ model_name_or_path = data ["model_name_or_path" ]
346+ data ["autoconfig" ] = AutoConfig .from_pretrained (
347+ model_name_or_path ,
348+ trust_remote_code = trust_remote_code_for_model (model_name_or_path ),
349+ )
343350
344351 if data .get ("base_max_seq_length" ) is None :
345352 data ["base_max_seq_length" ] = get_base_max_seq_length (data ["autoconfig" ])
@@ -447,6 +454,32 @@ def save_metadata(self) -> None:
447454 indent = 4 ,
448455 )
449456
@staticmethod
def _load_config_and_tokenizer(
    model_name_or_path: str,
    tokenizer: PreTrainedTokenizerBase | None = None,
) -> tuple[PretrainedConfig, PreTrainedTokenizerBase]:
    """Return the HuggingFace config and a tokenizer for a model.

    Shared loader for the subclass constructors: the config is always
    loaded via ``AutoConfig.from_pretrained``; a tokenizer is loaded via
    ``AutoTokenizer.from_pretrained`` only when the caller did not
    supply one. Both loads receive the same ``trust_remote_code`` value,
    obtained once from ``trust_remote_code_for_model``.

    Args:
        model_name_or_path: HuggingFace model identifier or local path.
        tokenizer: Pre-loaded tokenizer to reuse. When ``None`` a new
            one is loaded from ``model_name_or_path``.

    Returns:
        A ``(config, tokenizer)`` tuple. A tokenizer passed in by the
        caller is returned unchanged.
    """
    # Resolve the trust decision a single time so the config and the
    # tokenizer loads cannot disagree about it.
    allow_remote_code = trust_remote_code_for_model(model_name_or_path)
    hf_config: PretrainedConfig = AutoConfig.from_pretrained(
        model_name_or_path,
        trust_remote_code=allow_remote_code,
    )

    # Caller already has a tokenizer: nothing further to load.
    if tokenizer is not None:
        return hf_config, tokenizer

    loaded_tokenizer = AutoTokenizer.from_pretrained(
        model_name_or_path,
        trust_remote_code=allow_remote_code,
    )
    return hf_config, loaded_tokenizer
482+
450483 @classmethod
451484 def _resolve_model_class (cls : type ["ModelMetadata" ], model_name_or_path : Path | str ) -> type ["ModelMetadata" ]:
452485 """Resolve model name or path to the matching metadata subclass.
@@ -588,8 +621,7 @@ def __init__(
588621 rope_scaling_factor : float | None = None ,
589622 ** kwargs ,
590623 ) -> None :
591- tokenizer = AutoTokenizer .from_pretrained (model_name_or_path ) if tokenizer is None else tokenizer
592- config : PretrainedConfig = AutoConfig .from_pretrained (model_name_or_path )
624+ config , tokenizer = ModelMetadata ._load_config_and_tokenizer (model_name_or_path , tokenizer )
593625
594626 super ().__init__ (
595627 autoconfig = config ,
@@ -628,8 +660,7 @@ def __init__(
628660 rope_scaling_factor : float | None = None ,
629661 ** kwargs ,
630662 ) -> None :
631- config : PretrainedConfig = AutoConfig .from_pretrained (model_name_or_path )
632- tokenizer = AutoTokenizer .from_pretrained (model_name_or_path ) if tokenizer is None else tokenizer
663+ config , tokenizer = ModelMetadata ._load_config_and_tokenizer (model_name_or_path , tokenizer )
633664
634665 super ().__init__ (
635666 autoconfig = config ,
@@ -668,12 +699,11 @@ class Mistral(ModelMetadata):
668699 def __init__ (
669700 self ,
670701 model_name_or_path : str ,
671- tokenizer : AutoTokenizer | None = None ,
702+ tokenizer : PreTrainedTokenizerBase | None = None ,
672703 rope_scaling_factor : float | None = None ,
673704 ** kwargs ,
674705 ) -> None :
675- tokenizer : AutoTokenizer = AutoTokenizer .from_pretrained (model_name_or_path ) if tokenizer is None else tokenizer
676- config : PretrainedConfig = AutoConfig .from_pretrained (model_name_or_path )
706+ config , tokenizer = ModelMetadata ._load_config_and_tokenizer (model_name_or_path , tokenizer )
677707 if rope_scaling_factor :
678708 logger .warning (
679709 f"Rope scaling factor { rope_scaling_factor } is not supported for Mistral due to longer default context lengths. Ignoring."
@@ -714,8 +744,7 @@ def __init__(
714744 rope_scaling_factor : float | None = None ,
715745 ** kwargs ,
716746 ) -> None :
717- tokenizer : AutoTokenizer = AutoTokenizer .from_pretrained (model_name_or_path ) if tokenizer is None else tokenizer
718- config : PretrainedConfig = AutoConfig .from_pretrained (model_name_or_path )
747+ config , tokenizer = ModelMetadata ._load_config_and_tokenizer (model_name_or_path , tokenizer )
719748
720749 super ().__init__ (
721750 autoconfig = config ,
@@ -751,8 +780,7 @@ def __init__(
751780 rope_scaling_factor : float | None = None ,
752781 ** kwargs ,
753782 ) -> None :
754- tokenizer = AutoTokenizer .from_pretrained (model_name_or_path ) if tokenizer is None else tokenizer
755- config = AutoConfig .from_pretrained (model_name_or_path )
783+ config , tokenizer = ModelMetadata ._load_config_and_tokenizer (model_name_or_path , tokenizer )
756784
757785 super ().__init__ (
758786 autoconfig = config ,
@@ -792,14 +820,13 @@ def __init__(
792820 rope_scaling_factor : float | None = None ,
793821 ** kwargs ,
794822 ) -> None :
795- tokenizer = AutoTokenizer .from_pretrained (model_name_or_path ) if tokenizer is None else tokenizer
796- config = AutoConfig .from_pretrained (model_name_or_path )
823+ config , tokenizer = ModelMetadata ._load_config_and_tokenizer (model_name_or_path , tokenizer )
797824 if rope_scaling_factor :
798825 logger .warning (
799826 f"Rope scaling factor { rope_scaling_factor } is not supported for SmolLM2 due to longer default context lengths. Ignoring."
800827 )
801828
802- im_start_id = tokenizer .convert_tokens_to_ids ("<|im_start|>" )
829+ im_start_id = tokenizer .convert_tokens_to_ids ("<|im_start|>" ) # ty: ignore[unresolved-attribute] -- third-party stub
803830 super ().__init__ (
804831 autoconfig = config ,
805832 instruction = DEFAULT_INSTRUCTION ,
@@ -840,8 +867,7 @@ def __init__(
840867 rope_scaling_factor : float | None = None ,
841868 ** kwargs ,
842869 ) -> None :
843- tokenizer = AutoTokenizer .from_pretrained (model_name_or_path ) if tokenizer is None else tokenizer
844- config = AutoConfig .from_pretrained (model_name_or_path )
870+ config , tokenizer = ModelMetadata ._load_config_and_tokenizer (model_name_or_path , tokenizer )
845871
846872 # we use the bos token here explicitly for support during group-by SFT.
847873 # the groupby assumes there is a bos token at the start of the prompt.
@@ -890,8 +916,7 @@ def __init__(
890916 rope_scaling_factor : float | None = None ,
891917 ** kwargs ,
892918 ) -> None :
893- tokenizer = tokenizer or AutoTokenizer .from_pretrained (model_name_or_path )
894- config = AutoConfig .from_pretrained (model_name_or_path )
919+ config , tokenizer = ModelMetadata ._load_config_and_tokenizer (model_name_or_path , tokenizer )
895920
896921 super ().__init__ (
897922 autoconfig = config ,
0 commit comments