
Commit 32137fd

eginhard authored and Rocketknight1 committed
Fix BatchEncoding.to() for nested elements
1 parent: d29482c

File tree

1 file changed: +1, −2 lines changed


src/transformers/tokenization_utils_base.py

Lines changed: 1 addition & 2 deletions
@@ -801,14 +801,13 @@ def to(self, device: Union[str, "torch.device"], *, non_blocking: bool = False)
             [`BatchEncoding`]: The same instance after modification.
         """
         requires_backends(self, ["torch"])
-        import torch
 
         # This check catches things like APEX blindly calling "to" on all inputs to a module
         # Otherwise it passes the casts down and casts the LongTensor containing the token idxs
         # into a HalfTensor
         if isinstance(device, str) or is_torch_device(device) or isinstance(device, int):
             self.data = {
-                k: v.to(device=device, non_blocking=non_blocking) if isinstance(v, torch.Tensor) else v
+                k: v.to(device=device, non_blocking=non_blocking) if hasattr(v, "to") and callable(v.to) else v
                 for k, v in self.data.items()
             }
         else:
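For context, here is a minimal sketch of the behaviour this change enables (not part of the commit; it assumes torch and transformers are installed, and NestedFeatures is a hypothetical container invented purely for illustration). With the hasattr-based check, BatchEncoding.to() moves any stored value that exposes a to() method, not only torch.Tensor instances, so nested objects are no longer silently left on the original device.

# Minimal sketch (not part of the commit): assumes torch and transformers are
# installed. "NestedFeatures" is a hypothetical container used only to
# illustrate the new behaviour of BatchEncoding.to().
import torch
from transformers import BatchEncoding


class NestedFeatures:
    """Hypothetical non-tensor value that knows how to move itself between devices."""

    def __init__(self, pixel_values: torch.Tensor):
        self.pixel_values = pixel_values

    def to(self, device=None, non_blocking: bool = False):
        # Matches the keyword call BatchEncoding.to() makes after this fix.
        self.pixel_values = self.pixel_values.to(device=device, non_blocking=non_blocking)
        return self


enc = BatchEncoding(
    {
        "input_ids": torch.tensor([[101, 2023, 102]]),
        "nested": NestedFeatures(torch.zeros(1, 3, 224, 224)),
    }
)

device = "cuda" if torch.cuda.is_available() else "cpu"
enc = enc.to(device)

# Before the fix, only torch.Tensor values were moved and "nested" stayed put;
# now both values end up on the requested device.
print(enc["input_ids"].device, enc["nested"].pixel_values.device)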
