Skip to content

Commit bffeefe

Browse files
winglianLysandreJik
authored andcommitted
handle torch version edge cases (#37399)
1 parent 5c076fb commit bffeefe

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

src/transformers/integrations/flex_attention.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ def __init__(self, training):
6666
# cause errors. The suggested fix is to compile with "max-autotune-no-cudagraphs"
6767
# see https://github.com/pytorch/pytorch/issues/146260 for training
6868
self.training = training
69-
if _torch_version == "2.6.0" and training:
69+
if _torch_version.split("+")[0] == "2.6.0" and training:
7070
self._compiled_flex_attention = torch.compile(
7171
flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs"
7272
)

0 commit comments

Comments
 (0)