Skip to content

Commit 4ffe6c1

Browse files
committed
Handle missing task names in return_df when data_params lacks tasks
Use .get() to safely access data_params["tasks"]["name"] so return_df works on models that have not been fit (e.g. loaded from checkpoint without training data).
1 parent 4bb3ce4 commit 4ffe6c1

1 file changed

Lines changed: 10 additions & 7 deletions

File tree

src/grelu/lightning/__init__.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -860,14 +860,17 @@ def predict_on_dataset(
860860

861861
if return_df:
862862
if (preds.ndim == 3) and (preds.shape[-1] == 1):
863-
task_names = self.data_params["tasks"]["name"]
864863
n_tasks_pred = preds.shape[-2]
865-
if n_tasks_pred != len(task_names):
866-
warnings.warn(
867-
f"Prediction has {n_tasks_pred} task(s) but the model"
868-
f" has {len(task_names)} task name(s), likely due to a"
869-
" prediction transform. Using generic column names."
870-
)
864+
task_names = (
865+
self.data_params.get("tasks", {}).get("name", None)
866+
)
867+
if task_names is None or n_tasks_pred != len(task_names):
868+
if task_names is not None:
869+
warnings.warn(
870+
f"Prediction has {n_tasks_pred} task(s) but the model"
871+
f" has {len(task_names)} task name(s), likely due to a"
872+
" prediction transform. Using generic column names."
873+
)
871874
task_names = [
872875
f"task_{i}" for i in range(n_tasks_pred)
873876
]

0 commit comments

Comments
 (0)