Skip to content

Commit 5b05321

Browse files
committed
BatchEncoding.to with device with tests (#9584)
1 parent 412d878 commit 5b05321

File tree

2 files changed

+11
-1
lines changed

2 files changed

+11
-1
lines changed

src/transformers/tokenization_utils_base.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,12 @@ def _is_torch(x):
6565
return isinstance(x, torch.Tensor)
6666

6767

68+
def _is_torch_device(x):
69+
import torch
70+
71+
return isinstance(x, torch.device)
72+
73+
6874
def _is_tensorflow(x):
6975
import tensorflow as tf
7076

@@ -801,7 +807,7 @@ def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
801807
# This check catches things like APEX blindly calling "to" on all inputs to a module
802808
# Otherwise it passes the casts down and casts the LongTensor containing the token idxs
803809
# into a HalfTensor
804-
if isinstance(device, str) or isinstance(device, torch.device) or isinstance(device, int):
810+
if isinstance(device, str) or _is_torch_device(device) or isinstance(device, int):
805811
self.data = {k: v.to(device=device) for k, v in self.data.items()}
806812
else:
807813
logger.warning(

tests/test_tokenization_common.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1704,6 +1704,10 @@ def test_torch_encode_plus_sent_to_model(self):
17041704
first_ten_tokens = list(tokenizer.get_vocab().keys())[:10]
17051705
sequence = " ".join(first_ten_tokens)
17061706
encoded_sequence = tokenizer.encode_plus(sequence, return_tensors="pt")
1707+
1708+
# Ensure that the BatchEncoding.to() method works.
1709+
encoded_sequence.to(model.device)
1710+
17071711
batch_encoded_sequence = tokenizer.batch_encode_plus([sequence, sequence], return_tensors="pt")
17081712
# This should not fail
17091713

0 commit comments

Comments (0)