From 0b048b37d1eb35e362d8f55d3b6e08d622889ed7 Mon Sep 17 00:00:00 2001 From: Adam Date: Thu, 5 Oct 2023 15:18:51 -0600 Subject: [PATCH] move pred to cpu, which is the same device as lm.classes() --- captum/concept/_utils/classifier.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/captum/concept/_utils/classifier.py b/captum/concept/_utils/classifier.py index 4ae9d1b9b7..1f855a8a65 100644 --- a/captum/concept/_utils/classifier.py +++ b/captum/concept/_utils/classifier.py @@ -178,7 +178,7 @@ def train_and_eval( predict = self.lm(x_test) - predict = self.lm.classes()[torch.argmax(predict, dim=1)] # type: ignore + predict = self.lm.classes()[torch.argmax(predict, dim=1).cpu()] # type: ignore score = predict.long() == y_test.long().cpu() accs = score.float().mean()