From da66c8fd5d0a469a99bc79276c1988d426d9b4d1 Mon Sep 17 00:00:00 2001 From: nichoalscao <982912719@qq.com> Date: Mon, 12 Sep 2022 11:39:16 +0800 Subject: [PATCH] fix InputFeatures .to --- openprompt/data_utils/utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/openprompt/data_utils/utils.py b/openprompt/data_utils/utils.py index 0751105..bb13430 100644 --- a/openprompt/data_utils/utils.py +++ b/openprompt/data_utils/utils.py @@ -178,11 +178,12 @@ def to_tensor(self, device: str = 'cuda'): def to(self, device: str = "cuda:0"): r"""move the tensor keys to runtime device, such as gpu:0 """ - for key in self.tensorable_keys: - value = getattr(self, key) + target = copy.deepcopy(self) + for key in target.tensorable_keys: + value = getattr(target, key) if value is not None: - setattr(self, key, value.to(device)) - return self + setattr(target, key, value.to(device)) + return target def cuda(self, device: str = "cuda:0"): r"""mimic the tensor behavior