Description
Hi, I am using sentence-transformers through setfit and want to compile the model with torch.compile(). The (simplified) code looks like this:
import torch
from setfit import SetFitModel

model = SetFitModel.from_pretrained(...)
model = torch.compile(model)
However, the following error occurs:
<...>
File "/path/to/my/workspace/.pyenv/versions/workspace-3.9/lib/python3.9/site-packages/setfit/trainer.py", line 341, in __init__
self.st_trainer = BCSentenceTransformersTrainer(
File "/path/to/my/workspace/.pyenv/versions/workspace-3.9/lib/python3.9/site-packages/setfit/trainer.py", line 50, in __init__
super().__init__(
File "/path/to/my/workspace/.pyenv/versions/workspace-3.9/lib/python3.9/site-packages/sentence_transformers/trainer.py", line 278, in __init__
self.data_collator.include_prompt_lengths = self._include_prompt_length()
File "/path/to/my/workspace/.pyenv/versions/workspace-3.9/lib/python3.9/site-packages/sentence_transformers/trainer.py", line 1028, in _include_prompt_length
for module in self.model:
TypeError: 'OptimizedModule' object is not iterable
Brief analysis:
The problematic call is here:
sentence-transformers/sentence_transformers/trainer.py, lines 280 to 281 (commit 6aaa53b):
if hasattr(self.data_collator, "include_prompt_lengths"):
    self.data_collator.include_prompt_lengths = self._include_prompt_length()
I understand the immediate problem (the OptimizedModule returned by torch.compile() is not iterable), but I don't understand the intention of this logic.
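Why the wrapper is not iterable can be illustrated without PyTorch. As far as I understand, torch.compile() returns an OptimizedModule that stores the original model in _orig_mod and forwards unknown attribute lookups to it via __getattr__. A for loop, however, looks up __iter__ on the wrapper's type, where __getattr__ is never consulted. The two classes below are hypothetical stand-ins that only mimic this behavior; they are not the real PyTorch classes:

```python
class FakeSequential:
    """Hypothetical stand-in for a model that can be iterated over its submodules."""

    def __init__(self, *modules):
        self._modules = list(modules)

    def __iter__(self):
        return iter(self._modules)


class FakeOptimizedModule:
    """Hypothetical stand-in for the wrapper returned by torch.compile()."""

    def __init__(self, mod):
        self._orig_mod = mod

    def __getattr__(self, name):
        # Only called when normal lookup fails; forwards to the wrapped model.
        return getattr(self._orig_mod, name)


wrapped = FakeOptimizedModule(FakeSequential("embedding", "pooling"))

# Explicit attribute access is forwarded to the wrapped model...
print(list(wrapped.__iter__()))  # ['embedding', 'pooling']

# ...but a for loop looks up __iter__ on type(wrapped), so it fails.
try:
    for module in wrapped:
        pass
except TypeError as err:
    print(err)  # 'FakeOptimizedModule' object is not iterable
```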
It seems to me that we first check whether data_collator has the attribute, only to ignore its value and then derive this information from the model's modules instead:
sentence-transformers/sentence_transformers/trainer.py, lines 1031 to 1033 (commit 6aaa53b):
for module in self.model:
    if isinstance(module, Pooling):
        return not module.include_prompt
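One possible fix, sketched here under assumptions, is to unwrap a compiled model through its _orig_mod attribute before iterating, and to fall back to the model itself when it was not compiled. Pooling, FakeCompiled and include_prompt_length below are simplified, hypothetical stand-ins for the real Pooling module, OptimizedModule and the trainer's _include_prompt_length method, not the library's code:

```python
class Pooling:
    """Hypothetical stand-in for sentence_transformers.models.Pooling."""

    def __init__(self, include_prompt: bool = True):
        self.include_prompt = include_prompt


class FakeCompiled:
    """Hypothetical stand-in for the OptimizedModule returned by torch.compile()."""

    def __init__(self, mod):
        self._orig_mod = mod


def include_prompt_length(model) -> bool:
    # Unwrap a compiled model; an uncompiled model has no _orig_mod attribute,
    # so getattr falls back to the model itself.
    model = getattr(model, "_orig_mod", model)
    for module in model:
        if isinstance(module, Pooling):
            return not module.include_prompt
    # Assumed fallback when no Pooling module is present.
    return False


plain = [Pooling(include_prompt=False)]
print(include_prompt_length(plain))                # True
print(include_prompt_length(FakeCompiled(plain)))  # True
```

Unwrapping keeps the existing Pooling check intact and only changes which object is iterated.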
Also, the value of include_prompt_lengths should probably be checked for True (not just its existence), which would already prevent the error, at least in my case.
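The proposed check could look roughly like this. It is only a sketch: Collator, maybe_set_prompt_lengths and failing_check are hypothetical names, and the include_prompt_length callable plays the role of self._include_prompt_length. With getattr(..., False), the model is inspected only when the collator has the attribute and it is True, so a collator set to False never triggers the iteration that fails for compiled models:

```python
from dataclasses import dataclass
from typing import Callable


@dataclass
class Collator:
    """Hypothetical stand-in for the trainer's data collator."""

    include_prompt_lengths: bool = False


def maybe_set_prompt_lengths(collator, include_prompt_length: Callable[[], bool]) -> None:
    # Inspect the model only if the attribute exists AND is truthy.
    if getattr(collator, "include_prompt_lengths", False):
        collator.include_prompt_lengths = include_prompt_length()


def failing_check() -> bool:
    # Simulates iterating over a compiled (non-iterable) model.
    raise TypeError("'OptimizedModule' object is not iterable")


collator = Collator(include_prompt_lengths=False)
maybe_set_prompt_lengths(collator, failing_check)  # the model is never inspected
print(collator.include_prompt_lengths)  # False
```

Note that this changes behavior: a collator that starts with include_prompt_lengths=False would no longer be switched on based on the Pooling configuration.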