RetrievalRecall, RetrievalPrecision require different, 1D input than MulticlassRecall, MulticlassPrecision which accept batch input #188

Open
@jaanli

Description

🐛 Describe the bug

RetrievalRecall and RetrievalPrecision expect flattened 1D inputs plus an indexes tensor, whereas MulticlassRecall and MulticlassPrecision accept batched input. This difference makes it difficult to compute standard metrics such as Precision@k or Recall@k for multiclass classification problems.

Would it be possible to have them accept the same input shapes, e.g. inputs of shape (batch_size, num_classes) and targets of shape (batch_size, num_classes)?

Example code below:

To install: pip install --pre torcheval-nightly (I am using version 0.0.7).

import torch
from torch.nn import functional as F
from torcheval.metrics import RetrievalRecall


batch_size = 10
num_classes = 20
# generate random predictions
preds = torch.rand(batch_size, num_classes)
# generate random targets
targets = torch.randint(0, num_classes, (batch_size,))

recall = RetrievalRecall(num_queries=batch_size, k=5)

# first one-hot encode the targets (RetrievalRecall takes no num_classes argument and requires binary targets)
targets_one_hot = F.one_hot(targets.type(torch.long), num_classes)
targets_one_hot.shape  # -> torch.Size([10, 20])

# indexes map each flattened prediction back to its query (here, its row in the batch)
indexes = torch.arange(batch_size).repeat(num_classes, 1).T

recall.update(preds.ravel(), targets_one_hot.ravel(), indexes=indexes.ravel())

recall.compute().mean() # -> 0.1


from torcheval.metrics import MulticlassRecall, MulticlassPrecision

recall = MulticlassRecall(num_classes=num_classes)
precision = MulticlassPrecision(num_classes=num_classes)
recall.update(preds, targets)
precision.update(preds, targets)
recall.compute(), precision.compute() # -> 0.1, 0.1
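
To make the flattened layout passed to RetrievalRecall above concrete, here is a tiny sketch with 2 samples and 3 classes. The values are made up for illustration, and `alt_indexes` is just an alternative spelling I am assuming is equivalent, not something torcheval requires:

```python
import torch
from torch.nn import functional as F

# Toy sizes chosen for readability (not the sizes used in the issue).
batch_size, num_classes = 2, 3
targets = torch.tensor([2, 0])  # class ids, one per sample

# One row per sample, one column per class; 1 marks the true class.
targets_one_hot = F.one_hot(targets, num_classes)

# Row i is filled with i, so every score in row i belongs to query i.
indexes = torch.arange(batch_size).repeat(num_classes, 1).T

# Flattened versions, in the row-major order produced by ravel().
flat_targets = targets_one_hot.ravel()
flat_indexes = indexes.ravel()

# An equivalent, arguably more direct way to build the same index vector.
alt_indexes = torch.arange(batch_size).repeat_interleave(num_classes)

print(flat_targets.tolist())  # [0, 0, 1, 1, 0, 0]
print(flat_indexes.tolist())  # [0, 0, 0, 1, 1, 1]
print(torch.equal(flat_indexes, alt_indexes))  # True
```

Each group of num_classes consecutive entries corresponds to one sample, which is why the batch dimension has to be re-encoded through indexes.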

Current workaround:

import torch
from torch.nn import functional as F
from torcheval.metrics import RetrievalRecall


class MulticlassRetrievalRecall(RetrievalRecall):
    def __init__(self, batch_size, num_classes, **kwargs):
        super().__init__(num_queries=batch_size, **kwargs)
        self.num_classes = num_classes
        
    def update(self, input, target):
        target_one_hot = F.one_hot(target.type(torch.long), self.num_classes)
        indexes = torch.arange(len(input)).repeat(self.num_classes, 1).T
        super().update(input.ravel(), target_one_hot.ravel(), indexes=indexes.ravel())

Usage:

recall_multi = MulticlassRetrievalRecall(batch_size, num_classes, k=5)
recall_multi.update(preds, targets)
recall_multi.compute().mean() # -> 0.1
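
As a sanity check for the workaround above, here is a minimal sketch of Recall@k and Precision@k on batched input in plain PyTorch. `recall_at_k` and `precision_at_k` are hypothetical helper names, not torcheval functions, and the sketch assumes exactly one true class per sample:

```python
import torch


def recall_at_k(scores: torch.Tensor, targets: torch.Tensor, k: int) -> torch.Tensor:
    """Per-sample Recall@k for single-label targets.

    scores: (batch_size, num_classes) floats; targets: (batch_size,) class ids.
    With one relevant class per sample, recall is 1.0 if that class is among
    the k highest-scoring classes and 0.0 otherwise.
    """
    topk_classes = scores.topk(k, dim=1).indices           # (batch_size, k)
    hits = (topk_classes == targets.unsqueeze(1)).any(dim=1)
    return hits.float()


def precision_at_k(scores: torch.Tensor, targets: torch.Tensor, k: int) -> torch.Tensor:
    """Per-sample Precision@k: at most one of the k retrieved classes is relevant."""
    return recall_at_k(scores, targets, k) / k


# Hand-made scores so the result can be checked by eye.
scores = torch.tensor([[0.1, 0.7, 0.2],
                       [0.5, 0.3, 0.2]])
labels = torch.tensor([2, 0])

print(recall_at_k(scores, labels, k=1).tolist())     # [0.0, 1.0]
print(recall_at_k(scores, labels, k=2).tolist())     # [1.0, 1.0]
print(precision_at_k(scores, labels, k=2).tolist())  # [0.5, 0.5]
```

This accepts the same (batch_size, num_classes) input as the Multiclass metrics, which is the interface I am hoping the Retrieval metrics could support.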

Open to any tips on how best to do this! Thanks for this helpful canonical library :)

Versions

python collect_env.py

Collecting environment information...
PyTorch version: 2.1.1
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 13.6.2 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.0.40.1)
CMake version: version 3.22.2
Libc version: N/A

Python version: 3.11.6 (main, Nov  2 2023, 04:39:43) [Clang 14.0.3 (clang-1403.0.22.14.1)] (64-bit runtime)
Python platform: macOS-13.6.2-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M1 Max

Versions of relevant libraries:
[pip3] numpy==1.26.2
[pip3] torch==2.1.1
[pip3] torchaudio==2.1.1
[pip3] torchdata==0.7.1
[pip3] torcheval==0.0.7
[pip3] torcheval-nightly==2023.12.21
[pip3] torchtext==0.16.1
[pip3] torchvision==0.16.1
[conda] numpy                     1.24.3          py310hb93e574_0  
[conda] numpy-base                1.24.3          py310haf87e8b_0  
[conda] torch                     2.0.1                    pypi_0    pypi
