From a6583c3759f48987b99d2e38dd49d30de31337c2 Mon Sep 17 00:00:00 2001 From: yipenghe Date: Sat, 7 Dec 2019 14:04:09 -0800 Subject: [PATCH] temp cuda --- semisupervised/codes/train.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/semisupervised/codes/train.py b/semisupervised/codes/train.py index 5bd39bc..79b05a0 100755 --- a/semisupervised/codes/train.py +++ b/semisupervised/codes/train.py @@ -186,7 +186,8 @@ def train(epoches): trainer.optimizer.zero_grad() k = 10 temp = torch.zeros([k, target_q.shape[0], target_q.shape[1]], dtype=target_q.dtype) - temp = temp.cuda() + if opt['cuda']: + temp = temp.cuda() for i in range(k): temp[i,:,:] = trainer.predict_noisy(inputs_q) target_predict = temp.mean(dim = 0)