diff --git a/captum/concept/_utils/classifier.py b/captum/concept/_utils/classifier.py index 4ae9d1b9b7..1f855a8a65 100644 --- a/captum/concept/_utils/classifier.py +++ b/captum/concept/_utils/classifier.py @@ -178,7 +178,7 @@ def train_and_eval( predict = self.lm(x_test) - predict = self.lm.classes()[torch.argmax(predict, dim=1)] # type: ignore + predict = self.lm.classes()[torch.argmax(predict, dim=1).cpu()] # type: ignore score = predict.long() == y_test.long().cpu() accs = score.float().mean()