Description
🐛 Describe the bug
I think I found a bug around this part of the code in the functional `multiclass_recall`. When a class is missing from both the prediction and the label, only `num_tp` is masked and `num_labels` is not, so the two tensors end up with different shapes. For example:
```python
import torch
from torcheval.metrics.functional import multiclass_recall
pred = torch.tensor([0,1,2,5])
label = torch.tensor([0,2,1,3])
num_class = 6
multiclass_recall(pred, label, num_classes=num_class, average="macro")
```
fails with the following error:
```
Traceback (most recent call last):
  File "playground.py", line 1205, in <module>
    multiclass_recall(input1, label1, num_classes=num_class, average="macro")
  File "/opt/conda/lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/lib/python3.7/site-packages/torcheval/metrics/functional/classification/recall.py", line 151, in multiclass_recall
    return _recall_compute(num_tp, num_labels, num_predictions, average)
  File "/opt/conda/lib/python3.7/site-packages/torcheval/metrics/functional/classification/recall.py", line 193, in _recall_compute
    recall = num_tp / num_labels
RuntimeError: The size of tensor a (5) must match the size of tensor b (6) at non-singleton dimension 0
```
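To make the mismatch concrete, here is a minimal pure-PyTorch sketch of macro recall that applies the "ignore classes absent from both input and target" mask to both `num_tp` and `num_labels`. It is not torcheval's implementation: the function name `macro_recall_sketch` is hypothetical, the per-class counts are computed with `torch.bincount` as an assumption, the mask condition is assumed, and treating 0/0 (a class that is predicted but never labeled) as a recall of 0 is an assumed convention.

```python
import torch


def macro_recall_sketch(pred: torch.Tensor, label: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical macro recall that masks num_tp AND num_labels with the same mask."""
    # Per-class counts; the names mirror the tensors in the traceback.
    num_tp = torch.bincount(label[pred == label], minlength=num_classes).float()
    num_labels = torch.bincount(label, minlength=num_classes).float()
    num_predictions = torch.bincount(pred, minlength=num_classes).float()

    # Assumed rule: ignore classes that appear in neither pred nor label.
    mask = (num_labels != 0) | (num_predictions != 0)

    # Index both tensors with the same mask so their shapes stay aligned.
    num_tp, num_labels = num_tp[mask], num_labels[mask]

    recall = num_tp / num_labels
    # Assumed convention: a class predicted but never labeled (0/0 -> nan) counts as 0.
    recall = torch.nan_to_num(recall, nan=0.0)
    return recall.mean()


if __name__ == "__main__":
    pred = torch.tensor([0, 1, 2, 5])
    label = torch.tensor([0, 2, 1, 3])
    print(macro_recall_sketch(pred, label, num_classes=6))  # tensor(0.2000)
```

With the reproduction inputs, class 4 is dropped by the mask, leaving five classes with recalls 1, 0, 0, 0 and 0 (class 5 is 0/0, mapped to 0 here), so the sketch returns 0.2 instead of raising. Masking only `num_tp` would leave it with 5 entries against 6 in `num_labels`, which matches the size mismatch in the traceback.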
Versions
Collecting environment information...
PyTorch version: 1.12.1
Is debug build: False
CUDA used to build PyTorch: 11.3
ROCM used to build PyTorch: N/A
OS: Ubuntu 18.04.6 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: Could not collect
CMake version: version 3.26.3
Libc version: glibc-2.17
Python version: 3.7.13 (default, Mar 29 2022, 02:18:16) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-4.19.91-26.al7.x86_64-x86_64-with-debian-buster-sid
Is CUDA available: True
CUDA runtime version: 11.3.109
CUDA_MODULE_LOADING set to:
GPU models and configuration: GPU 0: Tesla T4
Nvidia driver version: 470.103.01
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.2.0
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.2.0
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Byte Order: Little Endian
CPU(s): 4
On-line CPU(s) list: 0-3
Thread(s) per core: 2
Core(s) per socket: 2
Socket(s): 1
NUMA node(s): 1
Vendor ID: GenuineIntel
CPU family: 6
Model: 85
Model name: Intel(R) Xeon(R) Platinum 8163 CPU @ 2.50GHz
Stepping: 4
CPU MHz: 2499.998
BogoMIPS: 4999.99
Hypervisor vendor: KVM
Virtualization type: full
L1d cache: 32K
L1i cache: 32K
L2 cache: 1024K
L3 cache: 33792K
NUMA node0 CPU(s): 0-3
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ss ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopology nonstop_tsc cpuid tsc_known_freq pni pclmulqdq monitor ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch invpcid_single pti fsgsbase tsc_adjust bmi1 hle avx2 smep bmi2 erms invpcid rtm mpx avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves arat
Versions of relevant libraries:
[pip3] mypy==1.2.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==1.21.5
[pip3] torch==1.12.1
[pip3] torcheval==0.0.6
[pip3] torchtext==0.13.1
[pip3] torchtnt==0.0.7
[pip3] torchvision==0.13.1
[pip3] triton==2.0.0.post1
[conda] blas 1.0 mkl
[conda] cudatoolkit 11.3.1 ha36c431_9 nvidia
[conda] ffmpeg 4.3 hf484d3e_0 pytorch
[conda] mkl 2021.4.0 h06a4308_640
[conda] mkl-service 2.4.0 py37h7f8727e_0
[conda] mkl_fft 1.3.1 py37hd3c417c_0
[conda] mkl_random 1.2.2 py37h51133e4_0
[conda] numpy 1.21.5 py37he7a7128_2
[conda] numpy-base 1.21.5 py37hf524024_2
[conda] pytorch 1.12.1 py3.7_cuda11.3_cudnn8.3.2_0 pytorch
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] torcheval 0.0.6 pypi_0 pypi
[conda] torchtext 0.13.1 py37 pytorch
[conda] torchtnt 0.0.7 pypi_0 pypi
[conda] torchvision 0.13.1 py37_cu113 pytorch
[conda] triton 2.0.0.post1 pypi_0 pypi