diff --git a/deployement/working/IndicLID.py b/deployement/working/IndicLID.py index 1157335..a7a3d89 100644 --- a/deployement/working/IndicLID.py +++ b/deployement/working/IndicLID.py @@ -297,7 +297,7 @@ def get_dataloaders(self, indices, input_texts, batch_size): def predict(self, input): input_list = [input,] - self.batch_predict(input_list, 1) + return self.batch_predict(input_list, 1) def batch_predict(self, input_list, batch_size):