add model_config support in TransformersModel #1168


Open · wants to merge 1 commit into base: main

Conversation


@Jonnathanz commented Apr 10, 2025

This pull request adds support for a model_config parameter in the TransformersModel class. With this change, it is now possible to pass a dictionary of configuration options used when loading the model (via AutoModelForCausalLM.from_pretrained or AutoModelForImageTextToText.from_pretrained), keeping these settings separate from the kwargs forwarded to the generate() method.

Highlights:

Quantization support: Enables the use of configurations such as quantization_config (e.g., for 4-bit quantization using BitsAndBytes), as well as other parameters like torch_dtype and device_map.

Flexible model initialization: Users can now customize model loading with a wide range of parameters without interfering with generation-specific arguments.

Clear separation of concerns: Model configuration is handled through the model_config dictionary, while generation parameters remain in **kwargs during the generate() call.
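The separation described above can be sketched as follows. Note that split_kwargs is a hypothetical helper written purely for illustration; it is not code from this PR or from smolagents:

```python
# Hypothetical sketch (not the actual smolagents source) of the split this
# PR introduces: keys in model_config go to from_pretrained(), while the
# remaining keyword arguments stay available for generate().
def split_kwargs(model_config=None, **generate_kwargs):
    """Return (load_kwargs, generate_kwargs) as two independent dicts."""
    load_kwargs = dict(model_config or {})
    return load_kwargs, generate_kwargs

load_kwargs, gen_kwargs = split_kwargs(
    model_config={"quantization_config": "<BitsAndBytesConfig instance>"},
    max_new_tokens=2000,
)
print(load_kwargs)  # {'quantization_config': '<BitsAndBytesConfig instance>'}
print(gen_kwargs)   # {'max_new_tokens': 2000}
```

Because the two dicts never overlap, loading options such as quantization_config cannot accidentally leak into the generate() call.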

This update improves customization options during model initialization, making the framework more versatile and suitable for models requiring specific loading configurations.

Open to feedback — happy to refine the implementation as needed.

Example:
```python
>>> from transformers import BitsAndBytesConfig
>>> from smolagents import CodeAgent, TransformersModel

>>> model_id = "Qwen/Qwen2.5-Coder-32B-Instruct"

>>> bnb_config = BitsAndBytesConfig(
...     load_in_4bit=True,
...     bnb_4bit_compute_dtype="float16",
...     bnb_4bit_use_double_quant=True,
...     bnb_4bit_quant_type="nf4"
... )

>>> model = TransformersModel(
...     model_id,
...     device_map="auto",
...     torch_dtype="auto",
...     trust_remote_code=True,
...     model_config={'quantization_config': bnb_config},
...     max_new_tokens=2000
... )

>>> agent = CodeAgent(tools=[], model=model)

>>> result = agent.run("Explain quantum mechanics in simple terms.")
>>> print(result)
"Quantum mechanics is a branch of physics that studies the behavior of particles at the smallest scales, such as atoms and subatomic particles. Unlike classical physics, which..."
```
