Open
Description
Hi there,
Thanks for providing Lag-Llama. It is a wonderful model.
Please make the predictor run on the same device as the estimator (i.e. use `self.device` instead of re-deriving the device from `torch.cuda.is_available()`), as in the snippet below:
def create_predictor(
    self,
    transformation: Transformation,
    module,
) -> PyTorchPredictor:
    """Create an inference-time predictor for the trained module.

    Parameters
    ----------
    transformation
        Data transformation applied to the input before the test
        instance splitter.
    module
        The trained prediction network.

    Returns
    -------
    PyTorchPredictor
        A predictor placed on the estimator's own device
        (``self.device``), so training and inference use the same
        device instead of independently probing CUDA availability.
    """
    prediction_splitter = self._create_instance_splitter(module, "test")
    # The only branch-dependent piece is the input name list: with time
    # features enabled, past/future time features are also fed to the net.
    input_names = list(PREDICTION_INPUT_NAMES)
    if self.time_feat:
        input_names += ["past_time_feat", "future_time_feat"]
    return PyTorchPredictor(
        input_transform=transformation + prediction_splitter,
        input_names=input_names,
        prediction_net=module,
        batch_size=self.batch_size,
        prediction_length=self.prediction_length,
        # Reuse the estimator's device rather than hard-coding
        # "cuda"/"cpu", so predictor and estimator always agree.
        device=self.device.type,
    )
Metadata
Metadata
Assignees
Labels
No labels