diff --git a/kge/model/kge_model.py b/kge/model/kge_model.py index 96b27a725..45608393c 100644 --- a/kge/model/kge_model.py +++ b/kge/model/kge_model.py @@ -786,4 +786,8 @@ def score_sp_po( all_subjects = self.get_s_embedder().embed_all() sp_scores = self._scorer.score_emb(s, p, all_objects, combine="sp_") po_scores = self._scorer.score_emb(all_subjects, p, o, combine="_po") + if torch.isnan(sp_scores).any().item() or : + raise ValueError("Found NaN values when scoring sp_ predictions! ") + elif torch.isnan(po_scores).any().item(): + raise ValueError("Found NaN values when scoring _po predictions! ") return torch.cat((sp_scores, po_scores), dim=1)