Skip to content

Commit eca8380

Browse files
Halve import time by removing torch dependency (#147)
* Halve import time by removing torch dependency
1 parent d8a4b83 commit eca8380

File tree

9 files changed

+64
-54
lines changed

9 files changed

+64
-54
lines changed

.github/workflows/quality.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ jobs:
1313
- name: Set up Python
1414
uses: actions/setup-python@v2
1515
with:
16-
python-version: "3.10"
16+
python-version: "3.12"
1717

1818
# Setup venv
1919
# TODO: revisit when https://github.com/astral-sh/uv/issues/1526 is addressed.

docs/source/en/examples/multiagents.md

+3-3
Original file line numberDiff line numberDiff line change
@@ -48,10 +48,10 @@ Run the line below to install the required dependencies:
4848

4949
Let's login in order to call the HF Inference API:
5050

51-
```py
52-
from huggingface_hub import notebook_login
51+
```
52+
from huggingface_hub import login
5353
54-
notebook_login()
54+
login()
5555
```
5656

5757
⚡️ Our agent will be powered by [Qwen/Qwen2.5-Coder-32B-Instruct](https://huggingface.co/Qwen/Qwen2.5-Coder-32B-Instruct) using `HfApiModel` class that uses HF's Inference API: the Inference API allows to quickly and easily run any OS model.

docs/source/en/tutorials/tools.md

+2-2
Original file line numberDiff line numberDiff line change
@@ -177,7 +177,7 @@ agent.run("How many more blocks (also denoted as layers) are in BERT base encode
177177

178178
### Manage your agent's toolbox
179179

180-
You can manage an agent's toolbox by adding or replacing a tool.
180+
You can manage an agent's toolbox by adding or replacing a tool in attribute `agent.tools`, since it is a standard dictionary.
181181

182182
Let's add the `model_download_tool` to an existing agent initialized with only the default toolbox.
183183

@@ -187,7 +187,7 @@ from smolagents import HfApiModel
187187
model = HfApiModel("Qwen/Qwen2.5-Coder-32B-Instruct")
188188

189189
agent = CodeAgent(tools=[], model=model, add_base_tools=True)
190-
agent.tools.append(model_download_tool)
190+
agent.tools[model_download_tool.name] = model_download_tool
191191
```
192192
Now we can leverage the new tool:
193193

pyproject.toml

+13-4
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,6 @@ authors = [
1212
readme = "README.md"
1313
requires-python = ">=3.10"
1414
dependencies = [
15-
"torch",
16-
"torchaudio",
17-
"torchvision",
1815
"transformers>=4.0.0",
1916
"requests>=2.32.3",
2017
"rich>=13.9.4",
@@ -30,10 +27,22 @@ dependencies = [
3027
]
3128

3229
[tool.ruff]
33-
ignore = ["F403"]
30+
lint.ignore = ["F403"]
3431

3532
[project.optional-dependencies]
33+
dev = [
34+
"torch",
35+
"torchaudio",
36+
"torchvision",
37+
"sqlalchemy",
38+
"accelerate",
39+
"soundfile",
40+
"litellm>=1.55.10",
41+
]
3642
test = [
43+
"torch",
44+
"torchaudio",
45+
"torchvision",
3746
"pytest>=8.1.0",
3847
"sqlalchemy",
3948
"ruff>=0.5.0",

src/smolagents/default_tools.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,9 @@
2020
from typing import Dict, Optional
2121

2222
from huggingface_hub import hf_hub_download, list_spaces
23-
from transformers.models.whisper import (
24-
WhisperForConditionalGeneration,
25-
WhisperProcessor,
26-
)
27-
from transformers.utils import is_offline_mode
23+
24+
25+
from transformers.utils import is_offline_mode, is_torch_available
2826

2927
from .local_python_executor import (
3028
BASE_BUILTIN_MODULES,
@@ -34,6 +32,15 @@
3432
from .tools import TOOL_CONFIG_FILE, PipelineTool, Tool
3533
from .types import AgentAudio
3634

35+
if is_torch_available():
36+
from transformers.models.whisper import (
37+
WhisperForConditionalGeneration,
38+
WhisperProcessor,
39+
)
40+
else:
41+
WhisperForConditionalGeneration = object
42+
WhisperProcessor = object
43+
3744

3845
@dataclass
3946
class PreTool:

src/smolagents/models.py

+10-19
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from enum import Enum
2323
from typing import Dict, List, Optional
2424

25-
import torch
2625
from huggingface_hub import (
2726
InferenceClient,
2827
ChatCompletionOutputMessage,
@@ -35,6 +34,7 @@
3534
AutoTokenizer,
3635
StoppingCriteria,
3736
StoppingCriteriaList,
37+
is_torch_available,
3838
)
3939
import openai
4040

@@ -147,29 +147,12 @@ def __init__(self):
147147
self.last_input_token_count = None
148148
self.last_output_token_count = None
149149

150-
def get_token_counts(self):
150+
def get_token_counts(self) -> Dict[str, int]:
151151
return {
152152
"input_token_count": self.last_input_token_count,
153153
"output_token_count": self.last_output_token_count,
154154
}
155155

156-
def generate(
157-
self,
158-
messages: List[Dict[str, str]],
159-
stop_sequences: Optional[List[str]] = None,
160-
grammar: Optional[str] = None,
161-
max_tokens: int = 1500,
162-
):
163-
raise NotImplementedError
164-
165-
def get_tool_call(
166-
self,
167-
messages: List[Dict[str, str]],
168-
available_tools: List[Tool],
169-
stop_sequences,
170-
):
171-
raise NotImplementedError
172-
173156
def __call__(
174157
self,
175158
messages: List[Dict[str, str]],
@@ -256,6 +239,10 @@ def __call__(
256239
max_tokens: int = 1500,
257240
tools_to_call_from: Optional[List[Tool]] = None,
258241
) -> str:
242+
"""
243+
Gets an LLM output message for the given list of input messages.
244+
If argument `tools_to_call_from` is passed, the model's tool calling options will be used to return a tool call.
245+
"""
259246
messages = get_clean_message_list(
260247
messages, role_conversions=tool_role_conversions
261248
)
@@ -293,6 +280,10 @@ class TransformersModel(Model):
293280

294281
def __init__(self, model_id: Optional[str] = None, device: Optional[str] = None):
295282
super().__init__()
283+
if not is_torch_available():
284+
raise ImportError("Please install torch in order to use TransformersModel.")
285+
import torch
286+
296287
default_model_id = "HuggingFaceTB/SmolLM2-1.7B-Instruct"
297288
if model_id is None:
298289
model_id = default_model_id

src/smolagents/tools.py

+10-12
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,6 @@
2727
from pathlib import Path
2828
from typing import Callable, Dict, Optional, Union, get_type_hints
2929

30-
import torch
3130
from huggingface_hub import (
3231
create_repo,
3332
get_collection,
@@ -37,7 +36,6 @@
3736
)
3837
from huggingface_hub.utils import RepositoryNotFoundError
3938
from packaging import version
40-
from transformers import AutoProcessor
4139
from transformers.dynamic_module_utils import get_imports
4240
from transformers.utils import (
4341
TypeHintParsingException,
@@ -54,13 +52,14 @@
5452

5553
logger = logging.getLogger(__name__)
5654

57-
58-
if is_torch_available():
59-
pass
60-
6155
if is_accelerate_available():
62-
pass
56+
from accelerate import PartialState
57+
from accelerate.utils import send_to_device
6358

59+
if is_torch_available():
60+
from transformers import AutoProcessor
61+
else:
62+
AutoProcessor = object
6463

6564
TOOL_CONFIG_FILE = "tool_config.json"
6665

@@ -1026,8 +1025,6 @@ def setup(self):
10261025
"""
10271026
Instantiates the `pre_processor`, `model` and `post_processor` if necessary.
10281027
"""
1029-
from accelerate import PartialState
1030-
10311028
if isinstance(self.pre_processor, str):
10321029
self.pre_processor = self.pre_processor_class.from_pretrained(
10331030
self.pre_processor, **self.hub_kwargs
@@ -1066,6 +1063,8 @@ def forward(self, inputs):
10661063
"""
10671064
Sends the inputs through the `model`.
10681065
"""
1066+
import torch
1067+
10691068
with torch.no_grad():
10701069
return self.model(**inputs)
10711070

@@ -1076,16 +1075,15 @@ def decode(self, outputs):
10761075
return self.post_processor(outputs)
10771076

10781077
def __call__(self, *args, **kwargs):
1078+
import torch
1079+
10791080
args, kwargs = handle_agent_input_types(*args, **kwargs)
10801081

10811082
if not self.is_initialized:
10821083
self.setup()
10831084

10841085
encoded_inputs = self.encode(*args, **kwargs)
10851086

1086-
import torch
1087-
from accelerate.utils import send_to_device
1088-
10891087
tensor_inputs = {
10901088
k: v for k, v in encoded_inputs.items() if isinstance(v, torch.Tensor)
10911089
}

src/smolagents/types.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -22,10 +22,10 @@
2222
import numpy as np
2323
import requests
2424
from transformers.utils import (
25-
is_soundfile_availble,
2625
is_torch_available,
2726
is_vision_available,
2827
)
28+
from transformers.utils.import_utils import _is_package_available
2929

3030
logger = logging.getLogger(__name__)
3131

@@ -41,7 +41,7 @@
4141
else:
4242
Tensor = object
4343

44-
if is_soundfile_availble():
44+
if _is_package_available("soundfile"):
4545
import soundfile as sf
4646

4747

@@ -189,7 +189,7 @@ class AgentAudio(AgentType, str):
189189
def __init__(self, value, samplerate=16_000):
190190
super().__init__(value)
191191

192-
if not is_soundfile_availble():
192+
if not _is_package_available("soundfile"):
193193
raise ImportError("soundfile must be installed in order to handle audio.")
194194

195195
self._path = None
@@ -253,7 +253,7 @@ def to_string(self):
253253
INSTANCE_TYPE_MAPPING = {
254254
str: AgentText,
255255
ImageType: AgentImage,
256-
torch.Tensor: AgentAudio,
256+
Tensor: AgentAudio,
257257
}
258258

259259
if is_torch_available():

tests/test_types.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -18,20 +18,19 @@
1818
import uuid
1919
from pathlib import Path
2020

21-
import torch
2221
from PIL import Image
2322
from transformers.testing_utils import (
2423
require_soundfile,
2524
require_torch,
2625
require_vision,
2726
)
28-
from transformers.utils import (
29-
is_soundfile_availble,
27+
from transformers.utils.import_utils import (
28+
_is_package_available,
3029
)
3130

3231
from smolagents.types import AgentAudio, AgentImage, AgentText
3332

34-
if is_soundfile_availble():
33+
if _is_package_available("soundfile"):
3534
import soundfile as sf
3635

3736

@@ -44,6 +43,8 @@ def get_new_path(suffix="") -> str:
4443
@require_torch
4544
class AgentAudioTests(unittest.TestCase):
4645
def test_from_tensor(self):
46+
import torch
47+
4748
tensor = torch.rand(12, dtype=torch.float64) - 0.5
4849
agent_type = AgentAudio(tensor)
4950
path = str(agent_type.to_string())
@@ -61,6 +62,8 @@ def test_from_tensor(self):
6162
self.assertTrue(torch.allclose(tensor, torch.tensor(new_tensor), atol=1e-4))
6263

6364
def test_from_string(self):
65+
import torch
66+
6467
tensor = torch.rand(12, dtype=torch.float64) - 0.5
6568
path = get_new_path(suffix=".wav")
6669
sf.write(path, tensor, 16000)
@@ -75,6 +78,8 @@ def test_from_string(self):
7578
@require_torch
7679
class AgentImageTests(unittest.TestCase):
7780
def test_from_tensor(self):
81+
import torch
82+
7883
tensor = torch.randint(0, 256, (64, 64, 3))
7984
agent_type = AgentImage(tensor)
8085
path = str(agent_type.to_string())

0 commit comments

Comments
 (0)