BinaryConfusionMatrix does not work with float target #212
Description
🐛 Describe the bug
`BinaryConfusionMatrix` does not work if `target` is a float tensor, even when all its values are 0 or 1.
Note: it could be argued that `target` should always be an integer tensor. However, given that #146 was solved, I assume you will also want to solve this one, or at least raise a clear error for it.
In my case, `target` was float because my CSV was loaded automatically by pandas.
Minimal example:
```python
import torch
from torcheval.metrics import BinaryConfusionMatrix

input = torch.randint(0, 2, (10,)).to(torch.float32)  # 0/1 predictions as float32
target = torch.randint(0, 2, (10,))  # 0/1 labels as int64

cm = BinaryConfusionMatrix()
cm.update(input, target)  # no error here
print(cm.compute())

cm = BinaryConfusionMatrix()
cm.update(input, target.to(torch.float32)) # error here
print(cm.compute())
```
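Until the metric handles this itself, a workaround is to convert the target to an integer tensor before calling `update`. The helper below is a hypothetical sketch (`as_binary_target` is not part of torcheval): it checks that the target only contains 0 and 1, raising a `ValueError` otherwise, and casts it to `torch.int64`.

```python
import torch


def as_binary_target(target: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: return `target` as an int64 tensor of 0s and 1s.

    Raises ValueError if any value is not exactly 0 or 1, instead of
    letting a later call fail with an obscure TorchScript error.
    """
    is_binary = (target == 0) | (target == 1)
    if not bool(is_binary.all()):
        raise ValueError("target must only contain the values 0 and 1")
    # int64 is the dtype torch.sparse_coo_tensor requires for its indices
    return target.to(torch.int64)


# A float target, as produced e.g. when pandas loads the labels from a CSV
target = torch.tensor([0.0, 1.0, 1.0, 0.0])
fixed = as_binary_target(target)
print(fixed.dtype)  # torch.int64
```

With this helper, the failing line of the example would become `cm.update(input, as_binary_target(target.to(torch.float32)))`.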
Error and Traceback
```
Traceback (most recent call last):
  File "/mnt/c/Users/rafae/Documents/faculdade/thesis/thesis-framework/temp.py", line 12, in <module>
    cm.update(input, target.to(torch.float32)) # error here
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/rafae/Documents/faculdade/thesis/thesis-framework/.venv/lib/python3.12/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/rafae/Documents/faculdade/thesis/thesis-framework/.venv/lib/python3.12/site-packages/torcheval/metrics/classification/confusion_matrix.py", line 311, in update
    self.confusion_matrix += _binary_confusion_matrix_update(
                             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/mnt/c/Users/rafae/Documents/faculdade/thesis/thesis-framework/.venv/lib/python3.12/site-packages/torcheval/metrics/functional/classification/confusion_matrix.py", line 175, in _binary_confusion_matrix_update
    return _update(input, target, 2)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
  File "/mnt/c/Users/rafae/Documents/faculdade/thesis/thesis-framework/.venv/lib/python3.12/site-packages/torcheval/metrics/functional/classification/confusion_matrix.py", line 232, in _update
    # Each prediction creates an entry at the position (true, pred)
    sparse_cm = torch.sparse_coo_tensor(coordinates, torch.ones_like(target), cm_shape)
                ~~~~~~~~~~~~~~~~~~~~~~~ <--- HERE
    return sparse_cm.to_dense()
RuntimeError: indices must be an int64 tensor
```
Interpretation
The `coordinates` tensor (passed as the `indices` parameter of `sparse_coo_tensor`) is a `vstack` of `input` (after applying the threshold, and therefore an int tensor) and `target`. Since `target` is float, type promotion makes `vstack` return a float tensor, which `sparse_coo_tensor` then rejects because its indices must be int64.
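One possible fix is to cast the target to `torch.int64` before the coordinates are built. The function below is a hypothetical, simplified sketch of such an update step, modeled on the `_update` lines visible in the traceback; the name `binary_confusion_matrix`, the `>= threshold` rule and the `(target, pred)` row order are assumptions, not torcheval's actual code. Casting early also makes `torch.ones_like(target)` produce integer counts instead of floats.

```python
import torch


def binary_confusion_matrix(
    input: torch.Tensor, target: torch.Tensor, threshold: float = 0.5
) -> torch.Tensor:
    """Hypothetical sketch: 2x2 confusion matrix, rows = true, columns = predicted."""
    # Threshold the scores into 0/1 predictions with dtype int64 (assumed rule)
    pred = (input >= threshold).to(torch.int64)
    # The fix: cast target to int64 so vstack keeps an integer dtype
    target = target.to(torch.int64)
    # Each prediction creates an entry at the position (true, pred)
    coordinates = torch.vstack([target, pred])
    sparse_cm = torch.sparse_coo_tensor(coordinates, torch.ones_like(target), (2, 2))
    # Entries that share the same coordinates are summed by to_dense()
    return sparse_cm.to_dense()


scores = torch.tensor([0.2, 0.8, 0.6, 0.1])
target = torch.tensor([0.0, 1.0, 0.0, 0.0])  # float target, as in the issue
cm = binary_confusion_matrix(scores, target)
print(cm)
```

For these inputs the (true, pred) pairs are (0, 0), (1, 1), (0, 1) and (0, 0), so the result is `[[2, 1], [0, 1]]` with dtype int64.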
Versions
torcheval 0.0.7
torch 2.5.0