
Error in masking in the function multiclass_recall #150

Closed
@xiangn95

Description

🐛 Describe the bug

I think I found a bug in the masking step of the functional multiclass_recall. When a class is missing from both the prediction and the label, only num_tp is masked and num_labels is not, so num_tp and num_labels end up with different shapes. For example,

```python
import torch
from torcheval.metrics.functional import multiclass_recall

pred = torch.tensor([0, 1, 2, 5])
label = torch.tensor([0, 2, 1, 3])

# Class 4 appears in neither pred nor label.
num_class = 6

multiclass_recall(pred, label, num_classes=num_class, average="macro")
```

raises the following error:

```
Traceback (most recent call last):
  File "playground.py", line 1205, in <module>
    multiclass_recall(input1, label1, num_classes=num_class, average="macro")
  File "/opt/conda/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/torcheval/metrics/functional/classification/recall.py", line 151, in multiclass_recall
    return _recall_compute(num_tp, num_labels, num_predictions, average)
  File "/opt/conda/lib/python3.7/site-packages/torcheval/metrics/functional/classification/recall.py", line 193, in _recall_compute
    recall = num_tp / num_labels
RuntimeError: The size of tensor a (5) must match the size of tensor b (6) at non-singleton dimension 0
```
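The mismatch can be reproduced with plain PyTorch. The sketch below is a hypothetical re-implementation, not torcheval's actual code: per_class_counts and macro_recall are illustrative names, and the counts are computed with torch.bincount. It shows that masking only num_tp leaves it with 5 entries against 6 in num_labels, then applies the same mask to both tensors, which is one plausible fix. It also assumes that a class which is predicted but never labelled contributes a recall of 0; the library may handle that case differently.

```python
import torch


def per_class_counts(pred, label, num_classes):
    """Return (num_tp, num_labels, num_predictions), one entry per class."""
    num_tp = torch.bincount(label[pred == label], minlength=num_classes).float()
    num_labels = torch.bincount(label, minlength=num_classes).float()
    num_predictions = torch.bincount(pred, minlength=num_classes).float()
    return num_tp, num_labels, num_predictions


def macro_recall(pred, label, num_classes):
    num_tp, num_labels, num_predictions = per_class_counts(pred, label, num_classes)
    # Ignore classes that appear in neither `pred` nor `label` ...
    mask = (num_labels != 0) | (num_predictions != 0)
    # ... and apply the SAME mask to every count tensor so the shapes stay aligned.
    num_tp, num_labels = num_tp[mask], num_labels[mask]
    recall = num_tp / num_labels
    # Assumption: a class predicted but never labelled gives 0/0 = NaN; count it as 0.
    recall = torch.nan_to_num(recall, nan=0.0)
    return recall.mean()


pred = torch.tensor([0, 1, 2, 5])
label = torch.tensor([0, 2, 1, 3])

# The buggy pattern: masking num_tp only leaves it shorter than num_labels.
num_tp, num_labels, num_predictions = per_class_counts(pred, label, 6)
mask = (num_labels != 0) | (num_predictions != 0)
print(num_tp[mask].shape, num_labels.shape)  # torch.Size([5]) torch.Size([6])

# Consistent masking: classes 0-3 and 5 are kept, only class 0 is recalled.
print(macro_recall(pred, label, num_classes=6))  # tensor(0.2000)
```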

Versions

Collecting environment information...
PyTorch version: 1.12.1
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.6 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.26.3
Libc version: glibc-2.17

Python version: 3.7.13 (default, Mar 29 2022, 02:18:16) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-4.19.91-26.al7.x86_64-x86_64-with-debian-buster-sid
Is CUDA available: True
CUDA runtime version: 11.3.109
CUDA_MODULE_LOADING set to:
GPU models and configuration: GPU 0: Tesla T4
Nvidia driver version: 470.103.01
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.2.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
CPU(s): 4
On-line CPU(s) list: 0-3
Thread(s) per core: 2
Core(s) per socket: 2
Socket(s): 1
NUMA node(s): 1
Vendor ID: GenuineIntel
CPU family: 6
Model: 85
Model name: Intel(R) Xeon(R) Platinum 8163 CPU @ 2.50GHz
Stepping: 4
CPU MHz: 2499.998
BogoMIPS: 4999.99
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 32K
L1i cache: 32K
L2 cache: 1024K
L3 cache: 33792K
NUMA node0 CPU(s): 0-3
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves arat

Versions of relevant libraries:
[pip3] mypy==1.2.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.21.5
[pip3] torch==1.12.1
[pip3] torcheval==0.0.6
[pip3] torchtext==0.13.1
[pip3] torchtnt==0.0.7
[pip3] torchvision==0.13.1
[pip3] triton==2.0.0.post1
[conda] blas 1.0 mkl
[conda] cudatoolkit 11.3.1 ha36c431_9 nvidia
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py37h7f8727e_0
[conda] mkl_fft 1.3.1 py37hd3c417c_0
[conda] mkl_random 1.2.2 py37h51133e4_0
[conda] numpy 1.21.5 py37he7a7128_2
[conda] numpy-base 1.21.5 py37hf524024_2
[conda] pytorch 1.12.1 py3.7_cuda11.3_cudnn8.3.2_0 pytorch
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] torcheval 0.0.6 pypi_0 pypi
[conda] torchtext 0.13.1 py37 pytorch
[conda] torchtnt 0.0.7 pypi_0 pypi
[conda] torchvision 0.13.1 py37_cu113 pytorch
[conda] triton 2.0.0.post1 pypi_0 pypi
