Skip to content

Commit 9086c7c

Browse files
author
bens
committed
fix numbering bug
1 parent cd8b3fe commit 9086c7c

2 files changed

Lines changed: 5 additions & 2 deletions

File tree

fuse/eval/metrics/libs/classification.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def auc_pr(
186186
y_score = np.asarray(pred)[:, pos_class_index]
187187

188188
precision, recall, _ = metrics.precision_recall_curve(
189-
probas_pred=y_score,
189+
y_score=y_score,
190190
y_true=np.asarray(target) == pos_class_index,
191191
sample_weight=sample_weight,
192192
)

fuse/eval/metrics/metrics_common.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,10 @@ def collect(self, batch: Dict) -> None:
190190
if isinstance(value, np.ndarray) and value.ndim == 0:
191191
self._collected_data[name].append(value.item())
192192
else:
193-
self._collected_data[name].extend(value)
193+
if isinstance(value, np.ndarray):
194+
self._collected_data[name].extend(value.tolist())
195+
else:
196+
self._collected_data[name].extend(value)
194197

195198
# extract ids and store it in self._collected_ids
196199
if self._to_collect_ids:

0 commit comments

Comments
 (0)