Add transformers_kwargs to Transformer's from_pretrained function
#1357
base: main
Conversation
Signed-off-by: Sarah Yurick <[email protected]>
Greptile Overview
Greptile Summary
Overview
This PR adds a transformers_kwargs parameter across multiple classes to allow users to pass additional arguments (like trust_remote_code=True) to HuggingFace's from_pretrained() methods. This addresses issue #1296 where models requiring custom code couldn't be loaded.
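As a rough usage sketch (the import path, stage name spelling, and every constructor argument other than transformers_kwargs are illustrative assumptions, not verified against the repository), the feature is meant to enable calls like this:

```python
# Illustrative only: the import path and the model-identifier argument are
# assumptions; transformers_kwargs is the parameter added by this PR and is
# forwarded to the Hugging Face from_pretrained() calls.
from nemo_curator.stages.text.embedders import EmbeddingCreatorStage  # assumed path

stage = EmbeddingCreatorStage(
    model_identifier="my-org/model-with-custom-code",  # hypothetical model name and parameter
    transformers_kwargs={"trust_remote_code": True},   # extra kwargs for AutoTokenizer/AutoModel.from_pretrained()
)
```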
Key Changes
- EmbeddingModelStage - Added transformers_kwargs parameter with validation to prevent local_files_only override
- TokenizerStage - Added transformers_kwargs parameter with comprehensive validation for cache_dir, padding_side, and local_files_only
- MegatronTokenizerWriter - Added transformers_kwargs parameter with validation for cache_dir and local_files_only
- TokenCountFilter - Added transformers_kwargs parameter with validation for local_files_only in load_tokenizer()
- EmbeddingCreatorStage - Added transformers_kwargs field and propagates it to child stages
Issues Found
1. Missing padding_side Validation in EmbeddingModelStage (Logic Error)
The EmbeddingModelStage class accepts a padding_side parameter in its __init__ signature but does not validate that users avoid also passing padding_side via transformers_kwargs. This is inconsistent with TokenizerStage, which properly validates this conflict. If a user passes transformers_kwargs={"padding_side": "left"} while also setting padding_side="right", the call effectively becomes AutoModel.from_pretrained(..., padding_side="right", **{"padding_side": "left"}), which raises a confusing TypeError about multiple values for the padding_side keyword argument.
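For context, this failure mode is generic Python behavior whenever an explicit keyword collides with the same key supplied via **kwargs; the function below is a stand-in, not code from this PR:

```python
# Minimal stand-in demonstrating the duplicate-keyword failure described above.
def from_pretrained_stub(name: str, padding_side: str = "right", **kwargs: object) -> None:
    print(name, padding_side, kwargs)

transformers_kwargs = {"padding_side": "left"}

# Explicit padding_side plus the same key inside **transformers_kwargs raises:
# TypeError: from_pretrained_stub() got multiple values for keyword argument 'padding_side'
from_pretrained_stub("some-model", padding_side="right", **transformers_kwargs)
```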
2. Inconsistent Validation Timing in TokenCountFilter (Style Issue)
The TokenCountFilter class validates local_files_only in the load_tokenizer() method instead of in __init__, unlike all other classes in this PR. This defers validation until runtime and is inconsistent with the pattern used elsewhere. Users won't discover configuration errors until load_tokenizer() is actually called. Additionally, the validation is only performed if load_tokenizer() is explicitly called - if the user provides a pre-loaded tokenizer, the validation never runs.
Positive Aspects
- MegatronTokenizerWriter and TokenizerStage have correct validation in place
- transformers_kwargs is properly passed to from_pretrained() in all classes
- Good protection against internal parameter conflicts
- Proper parameter propagation in EmbeddingCreatorStage
Confidence Score: 2/5
- This PR contains 2 significant issues that should be fixed before merging: missing validation in EmbeddingModelStage and inconsistent validation placement in TokenCountFilter.
- The PR adds a useful feature (transformers_kwargs) but has two bugs that need to be addressed: (1) EmbeddingModelStage is missing validation for the padding_side parameter, which could lead to confusing runtime errors for users; (2) TokenCountFilter defers validation until load_tokenizer() is called instead of validating in __init__, creating inconsistent behavior and delayed error detection. Both issues are fixable but should be resolved before merging. The MegatronTokenizerWriter and TokenizerStage implementations are correct and well-validated.
- nemo_curator/stages/text/embedders/base.py (missing padding_side validation in EmbeddingModelStage), nemo_curator/stages/text/filters/heuristic_filter.py (inconsistent validation timing in TokenCountFilter)
Important Files Changed
File Analysis
| Filename | Score | Overview |
|---|---|---|
| nemo_curator/stages/text/embedders/base.py | 2/5 | Added transformers_kwargs parameter and validation. Missing validation for padding_side parameter which should be checked to prevent conflicts. |
| nemo_curator/stages/text/filters/heuristic_filter.py | 2/5 | Added transformers_kwargs parameter. Validation for local_files_only is done in load_tokenizer() instead of __init__, delaying validation until runtime. |
| nemo_curator/stages/text/io/writer/megatron_tokenizer.py | 4/5 | Added transformers_kwargs parameter with proper validation for cache_dir and local_files_only conflicts. Correctly passes transformers_kwargs to from_pretrained(). |
| nemo_curator/stages/text/models/tokenizer.py | 4/5 | Added transformers_kwargs parameter with comprehensive validation for cache_dir, padding_side, and local_files_only conflicts. Correctly passes transformers_kwargs to from_pretrained(). |
Sequence Diagram
sequenceDiagram
participant User
participant EmbeddingCreatorStage
participant TokenizerStage
participant EmbeddingModelStage
participant AutoTokenizer
participant AutoModel
User->>EmbeddingCreatorStage: Initialize with transformers_kwargs
EmbeddingCreatorStage->>TokenizerStage: Create with transformers_kwargs
EmbeddingCreatorStage->>EmbeddingModelStage: Create with transformers_kwargs
Note over EmbeddingModelStage: Missing padding_side validation!
User->>TokenizerStage: Call setup()
TokenizerStage->>AutoTokenizer: from_pretrained(..., **transformers_kwargs)
AutoTokenizer-->>TokenizerStage: Returns tokenizer
User->>EmbeddingModelStage: Call setup()
EmbeddingModelStage->>AutoModel: from_pretrained(..., **transformers_kwargs)
Note over AutoModel: If padding_side in transformers_kwargs,<br/>will raise TypeError
AutoModel-->>EmbeddingModelStage: Returns model (or error)
    def __init__(  # noqa: PLR0913
        self,
        tokenizer: AutoTokenizer | None = None,
        hf_model_name: str | None = None,
        hf_token: str | None = None,
        min_tokens: int = 0,
        max_tokens: int = float("inf"),
        transformers_kwargs: dict[str, Any] | None = None,
    ):
        """
        Args:
            tokenizer (AutoTokenizer | None): The pre-loaded tokenizer to use to count the tokens.
                If None, the tokenizer will be initialized from the hf_model_name.
            hf_model_name (str | None): The name of the Hugging Face model to use to count the tokens.
                If None, the pre-loaded tokenizer must be provided via the tokenizer argument.
            hf_token (str | None): The token to use to access the Hugging Face model, if needed.
            min_tokens (int): The minimum number of tokens the document must contain.
                Set to 0 to disable the minimum token count filter.
            max_tokens (int): The maximum number of tokens the document can contain.
                Set to infinity to disable the maximum token count filter.
            transformers_kwargs: Additional keyword arguments to pass to the tokenizer's from_pretrained method.
                Defaults to {}.
        """
        super().__init__()

        if tokenizer is None and hf_model_name is None:
            msg = "Either tokenizer or hf_model_name must be provided"
            raise ValueError(msg)
        if tokenizer is not None and hf_model_name is not None:
            msg = "Either tokenizer or hf_model_name must be provided, not both"
            raise ValueError(msg)

        self._token_count_filter_tokenizer = tokenizer
        self._hf_model_name = hf_model_name
        self._hf_token = hf_token
        self._min_tokens = min_tokens
        self._max_tokens = max_tokens
        self._transformers_kwargs = transformers_kwargs or {}
        self._name = "token_count"

    def model_check_or_download(self) -> None:
        if self._hf_model_name is not None:
            # Use snapshot_download to download all files without loading the model into memory.
            huggingface_hub.snapshot_download(
                repo_id=self._hf_model_name,
                token=self._hf_token,
                local_files_only=False,  # Download if not cached
                resume_download=True,  # Resume interrupted downloads
            )

    def load_tokenizer(self) -> None:
        if self._hf_model_name is not None:
            if "local_files_only" in self._transformers_kwargs and self._transformers_kwargs["local_files_only"] is not None:
                msg = "Passing the local_files_only parameter is not allowed"
                raise ValueError(msg)

            self._token_count_filter_tokenizer = AutoTokenizer.from_pretrained(
                self._hf_model_name, local_files_only=True, **self._transformers_kwargs
            )
The validation for local_files_only is performed in the load_tokenizer() method (line 698-700) instead of in __init__ (line 683-684). This creates inconsistent behavior with other classes in this PR (TokenizerStage, EmbeddingModelStage, MegatronTokenizerWriter), which all validate in __init__ or setup().
By deferring validation until load_tokenizer(), users won't discover invalid configurations until runtime when load_tokenizer() is called. If the tokenizer is never loaded (e.g., if a pre-loaded tokenizer is provided via the tokenizer parameter), the validation never runs.
For consistency and fail-fast behavior, move the validation to __init__ immediately after self._transformers_kwargs = transformers_kwargs or {} on line 683, and remove the validation check from the load_tokenizer() method.
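A minimal sketch of that fail-fast change, using a stripped-down stand-in for TokenCountFilter (only the transformers_kwargs handling is shown; the real __init__ keeps the additional parameters from the diff above):

```python
from typing import Any


class TokenCountFilter:  # simplified stand-in; only the relevant parts are shown
    def __init__(self, transformers_kwargs: dict[str, Any] | None = None) -> None:
        self._transformers_kwargs = transformers_kwargs or {}
        # Fail fast: reject the reserved key at construction time instead of in load_tokenizer().
        if self._transformers_kwargs.get("local_files_only") is not None:
            msg = "Passing the local_files_only parameter is not allowed"
            raise ValueError(msg)
```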
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
Additional Comments (1)
If a user passes padding_side via transformers_kwargs while EmbeddingModelStage also sets its own padding_side parameter, the from_pretrained() call receives the keyword twice and fails at runtime. To maintain consistency with TokenizerStage, EmbeddingModelStage should validate that padding_side is not present in transformers_kwargs.
Closes #1296.