Skip to content

add model_config support in TransformersModel #1168

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 55 additions & 27 deletions src/smolagents/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,7 +621,9 @@ def __call__(


class TransformersModel(Model):
"""A class that uses Hugging Face's Transformers library for language model interaction.
"""
A class that uses Hugging Face's Transformers library for language model interaction.
## Now with support to model params and qlora

This model allows you to load and use Hugging Face's models locally using the Transformers library. It supports features like stop sequences and grammar customization.

Expand All @@ -638,35 +640,54 @@ class TransformersModel(Model):
The torch_dtype to initialize your model with.
trust_remote_code (bool, default `False`):
Some models on the Hub require running remote code: for this model, you would have to set this flag to True.
model_config (dict, *optional*):
Extra keyword arguments forwarded to `AutoModelForImageTextToText.from_pretrained` or `AutoModelForCausalLM.from_pretrained`, for instance `quantization_config`.
kwargs (dict, *optional*):
Any additional keyword arguments that you want to use in model.generate(), for instance `max_new_tokens` or `device`.
**kwargs:
Additional keyword arguments to pass to `model.generate()`, for instance `max_new_tokens` or `device`.

Raises:
ValueError:
If the model name is not provided.

Example:
```python
>>> engine = TransformersModel(
... model_id="Qwen/Qwen2.5-Coder-32B-Instruct",
... device="cuda",
... max_new_tokens=5000,
>>> from transformers import BitsAndBytesConfig
>>> from smolagents_v1.smolagents.src.smolagents import CodeAgent, TransformersModel

>>> model_id = "Qwen/Qwen2.5-Coder-32B-Instruct"

>>> bnb_config = BitsAndBytesConfig(
... load_in_4bit=True,
... bnb_4bit_compute_dtype="float16",
... bnb_4bit_use_double_quant=True,
... bnb_4bit_quant_type="nf4"
... )
>>> messages = [{"role": "user", "content": "Explain quantum mechanics in simple terms."}]
>>> response = engine(messages, stop_sequences=["END"])
>>> print(response)
"Quantum mechanics is the branch of physics that studies..."

>>> model = TransformersModel(
... model_id,
... device_map="auto",
... torch_dtype="auto",
... trust_remote_code=True,
... model_config={'quantization_config': bnb_config},
... max_new_tokens=2000
... )

>>> agent = CodeAgent(tools=[], model=model)

>>> result = agent.run("Explain quantum mechanics in simple terms.")
>>> print(result)
"Quantum mechanics is a branch of physics that studies the behavior of particles at the smallest scales, such as atoms and subatomic particles. Unlike classical physics, which..."
```
"""

def __init__(
self,
model_id: Optional[str] = None,
device_map: Optional[str] = None,
torch_dtype: Optional[str] = None,
trust_remote_code: bool = False,
**kwargs,
model_id: Optional[str] = None,
device_map: Optional[str] = None,
torch_dtype: Optional[str] = None,
trust_remote_code: bool = False,
model_config: dict = dict(), # Variável que armazena parâmetros do modelo como dicionário
**kwargs, # Armazena parâmetros do model.generate
):
try:
import torch
Expand All @@ -684,6 +705,7 @@ def __init__(
FutureWarning,
)
model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct"

self.model_id = model_id

default_max_tokens = 5000
Expand All @@ -701,27 +723,31 @@ def __init__(
try:
self.model = AutoModelForImageTextToText.from_pretrained(
model_id,
device_map=device_map,
torch_dtype=torch_dtype,
trust_remote_code=trust_remote_code,
device_map = device_map,
torch_dtype = torch_dtype,
trust_remote_code = trust_remote_code,
**model_config # Adiciona os kwargs, agora permite quantização
)
self.processor = AutoProcessor.from_pretrained(model_id, trust_remote_code=trust_remote_code)
self._is_vlm = True
except ValueError as e:
if "Unrecognized configuration class" in str(e):
self.model = AutoModelForCausalLM.from_pretrained(
model_id,
device_map=device_map,
torch_dtype=torch_dtype,
trust_remote_code=trust_remote_code,
device_map = device_map,
torch_dtype = torch_dtype,
trust_remote_code = trust_remote_code,
**model_config # Adiciona os kwargs, agora permite quantização
)
self.tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=trust_remote_code)
else:
raise e
except Exception as e:
raise ValueError(f"Failed to load tokenizer and model for {model_id=}: {e}") from e

super().__init__(flatten_messages_as_text=not self._is_vlm, **kwargs)


def make_stopping_criteria(self, stop_sequences: List[str], tokenizer) -> "StoppingCriteriaList":
from transformers import StoppingCriteria, StoppingCriteriaList

Expand All @@ -742,7 +768,7 @@ def __call__(self, input_ids, scores, **kwargs):
return False

return StoppingCriteriaList([StopOnStrings(stop_sequences, tokenizer)])

def __call__(
self,
messages: List[Dict[str, str]],
Expand All @@ -752,9 +778,9 @@ def __call__(
**kwargs,
) -> ChatMessage:
completion_kwargs = self._prepare_completion_kwargs(
messages=messages,
stop_sequences=stop_sequences,
grammar=grammar,
messages = messages,
stop_sequences = stop_sequences,
grammar = grammar,
**kwargs,
)

Expand Down Expand Up @@ -801,9 +827,10 @@ def __call__(

out = self.model.generate(
**prompt_tensor,
stopping_criteria=stopping_criteria,
stopping_criteria = stopping_criteria,
**completion_kwargs,
)

generated_tokens = out[0, count_prompt_tokens:]
if hasattr(self, "processor"):
output_text = self.processor.decode(generated_tokens, skip_special_tokens=True)
Expand All @@ -827,6 +854,7 @@ def __call__(
return chat_message



class ApiModel(Model):
"""
Base class for API-based language models.
Expand Down
93 changes: 93 additions & 0 deletions test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
import sys
# Put what appears to be a local smolagents source tree at the front of the
# module search path, so it is consulted before any other location.
# NOTE(review): this is an absolute, user-specific Windows path; presumably it
# only resolves on the author's machine — confirm before running elsewhere.
sys.path.insert(0, r"C:\Users\jonna\Desktop\projetos\2_HUGGING_FACE_AGENT_AI\2_1_THE_SMOLAGENS_FRAMEWORK\smolagents_v1\smolagents\src")

from transformers import BitsAndBytesConfig
from src.smolagents import CodeAgent, TransformersModel

# Hub identifier of the checkpoint to load.
model_id = "Qwen/Qwen2.5-Coder-32B-Instruct"

# Quantization settings: 4-bit NF4 weights with double quantization and
# float16 as the compute dtype.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype="float16",
)

# The quantization config travels through `model_config`; `max_new_tokens`
# goes through the remaining keyword arguments.
model = TransformersModel(
    model_id,
    device_map="auto",
    torch_dtype="auto",  # can be omitted if you prefer
    trust_remote_code=True,
    model_config={"quantization_config": bnb_config},
    max_new_tokens=2000,
)

# Minimal agent run, kept disabled:
# agent = CodeAgent(tools=[], model=model)
# agent.run("Explain quantum mechanics in simple terms.")

from langchain.docstore.document import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from smolagents import Tool
from langchain_community.retrievers import BM25Retriever
from smolagents import CodeAgent, LiteLLMModel, FinalAnswerTool

# Retrieval tool handed to the agent.
class PartyPlanningRetrieverTool(Tool):
    """Tool that looks up party planning ideas in a fixed set of documents.

    The documents are indexed once, at construction time, with a
    ``BM25Retriever``. Each call to :meth:`forward` runs the query against that
    index and returns the matches formatted as a single string.
    """

    name = "party_planning_retriever"
    # NOTE(review): the description advertises "semantic search", while the
    # retriever built in __init__ is BM25 (keyword-based ranking) — consider
    # rewording; left unchanged here because the agent reads this text.
    description = "Uses semantic search to retrieve relevant party planning ideas for Alfred's superhero-themed party at Wayne Manor."
    inputs = {
        "query": {
            "type": "string",
            "description": "The query to perform. This should be a query related to party planning or superhero themes.",
        }
    }
    output_type = "string"

    def __init__(self, docs, **kwargs):
        """Build the retriever over *docs*.

        Args:
            docs: Documents passed to ``BM25Retriever.from_documents``.
            **kwargs: Forwarded unchanged to the base ``Tool`` initializer.
        """
        super().__init__(**kwargs)
        # k=5 presumably caps the number of documents returned per query —
        # TODO confirm against the BM25Retriever documentation.
        self.retriever = BM25Retriever.from_documents(docs, k=5)

    def forward(self, query: str) -> str:
        """Return the retrieved ideas for *query* as one formatted string.

        Args:
            query: The search query.

        Returns:
            A string starting with ``"\\nRetrieved ideas:\\n"`` followed by one
            ``===== Idea <index> =====`` section per retrieved document.

        Raises:
            TypeError: If *query* is not a string.
        """
        # Explicit raise instead of `assert`: assertions are removed when
        # Python runs with -O, which would let a bad query through.
        if not isinstance(query, str):
            raise TypeError("Your search query must be a string")

        docs = self.retriever.invoke(query)

        return "\nRetrieved ideas:\n" + "".join(
            f"\n\n===== Idea {i} =====\n" + doc.page_content
            for i, doc in enumerate(docs)
        )


# Raw party ideas. Each (text, source) pair below becomes one record with the
# keys "text" and "source", in that order.
party_ideas = [
    {"text": text, "source": source}
    for text, source in (
        (
            "A superhero-themed masquerade ball with luxury decor, including gold accents and velvet curtains.",
            "Party Ideas 1",
        ),
        (
            "Hire a professional DJ who can play themed music for superheroes like Batman and Wonder Woman.",
            "Entertainment Ideas",
        ),
        (
            "For catering, serve dishes named after superheroes, like 'The Hulk's Green Smoothie' and 'Iron Man's Power Steak.'",
            "Catering Ideas",
        ),
        (
            "Decorate with iconic superhero logos and projections of Gotham and other superhero cities around the venue.",
            "Decoration Ideas",
        ),
        (
            "Interactive experiences with VR where guests can engage in superhero simulations or compete in themed games.",
            "Entertainment Ideas",
        ),
    )
]

# Wrap every idea in a Document, keeping its origin in the metadata.
source_docs = [
    Document(page_content=idea["text"], metadata={"source": idea["source"]})
    for idea in party_ideas
]

# Split the documents into smaller chunks for more efficient search.
text_splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,
    chunk_overlap=50,
    add_start_index=True,
    strip_whitespace=True,
    separators=["\n\n", "\n", ".", " ", ""],
)
doc_processed = text_splitter.split_documents(source_docs)

# Build the retriever tool over the processed chunks.
party_planning_retriever = PartyPlanningRetrieverTool(doc_processed)

# Initialize the agent with the retriever tool and a final-answer tool.
agent = CodeAgent(
    tools=[party_planning_retriever, FinalAnswerTool()],
    model=model,
    verbosity_level=3,
)

# Run the query and print whatever the agent returns.
response = agent.run(
    "Find ideas for a luxury superhero-themed party, including entertainment, catering, and decoration options."
)

print(response)
# Sample output recorded from an earlier run:
# {'Entertainment Ideas': ['Interactive experiences with VR where guests can engage in superhero simulations or compete in themed games.', 'Hire a professional DJ who can play themed music for superheroes like Batman and Wonder Woman.'], 'Catering Ideas': ["Serve dishes named after superheroes, such as 'The Hulk's Green Smoothie' and 'Iron Man's Power Steak.'"], 'Decoration Ideas': ['A superhero-themed masquerade ball with luxury decor, including gold accents and velvet curtains.', 'Decorate with iconic superhero logos and projections of Gotham and other superhero cities around the venue.']}