From 7b9ccc816eceda14fed587949d45bed6559ab092 Mon Sep 17 00:00:00 2001 From: Fabian Bielicki Date: Mon, 10 Feb 2025 21:37:59 +0100 Subject: [PATCH] fix device select \n \n device for create_predictor was hardcoded and not passed from upper class --- lag_llama/gluon/estimator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lag_llama/gluon/estimator.py b/lag_llama/gluon/estimator.py index d0c5b52..d6aa1d9 100644 --- a/lag_llama/gluon/estimator.py +++ b/lag_llama/gluon/estimator.py @@ -477,7 +477,7 @@ def create_predictor( prediction_net=module, batch_size=self.batch_size, prediction_length=self.prediction_length, - device="cuda" if torch.cuda.is_available() else "cpu", + device=self.device, ) else: return PyTorchPredictor( @@ -486,5 +486,5 @@ def create_predictor( prediction_net=module, batch_size=self.batch_size, prediction_length=self.prediction_length, - device="cuda" if torch.cuda.is_available() else "cpu", + device=self.device, )