🐛 Describe the bug
The following tests fail on A100 with numerical differences (full logs at https://gist.githubusercontent.com/davidberard98/74009082166d2ac37d65c481b300376f/raw/b5160f8abfc14490f985d19cfb95696a10251e99/torchvision_0513.txt):
FAILED test/test_models.py::test_classification_model[cuda-resnet101] - Asser...
FAILED test/test_models.py::test_segmentation_model[cuda-fcn_resnet101] - Ass...
FAILED test/test_models.py::test_detection_model[cuda-fasterrcnn_resnet50_fpn]
FAILED test/test_models.py::test_detection_model[cuda-maskrcnn_resnet50_fpn]
FAILED test/test_models.py::test_detection_model[cuda-maskrcnn_resnet50_fpn_v2]
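For a sense of scale: TF32, which A100 Tensor Cores can use for float32 matmuls and convolutions, keeps float32's 8-bit exponent but only 10 explicit mantissa bits. Each input is therefore perturbed by up to about 2^-11 (roughly 4.9e-4) in relative terms, versus about 2^-24 (roughly 6e-8) for float32, which can be enough to exceed tight comparison tolerances. The sketch below simulates only the input rounding in pure Python. It assumes round-to-nearest-even, which may not match exactly what cuBLAS/cuDNN do, and it accumulates in binary64, whereas Tensor Core matmuls accumulate in float32. `to_tf32`, `to_float32`, `max_relative_error` and `dot` are hypothetical helpers, not PyTorch APIs.

```python
import struct

MANTISSA_BITS_FP32 = 23
MANTISSA_BITS_TF32 = 10


def to_float32(x: float) -> float:
    """Round a Python float (IEEE binary64) to the nearest IEEE binary32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_tf32(x: float) -> float:
    """Sketch of TF32 input rounding (assumption: round-to-nearest-even).

    TF32 keeps float32's sign bit and 8-bit exponent but only the top 10 of the
    23 mantissa bits. NaN/inf inputs are not handled in this sketch.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    drop = MANTISSA_BITS_FP32 - MANTISSA_BITS_TF32  # 13 low mantissa bits to discard
    kept_lsb = (bits >> drop) & 1
    # Adding (half - 1 + lsb) before truncating implements round-half-to-even;
    # a carry out of the mantissa correctly bumps the exponent.
    bits += (1 << (drop - 1)) - 1 + kept_lsb
    bits &= ~((1 << drop) - 1) & 0xFFFFFFFF
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def max_relative_error(values, rounding):
    """Largest |rounding(v) - v| / |v| over the non-zero inputs."""
    return max(abs(rounding(v) - v) / abs(v) for v in values if v != 0.0)


def dot(xs, ys, rounding):
    """Dot product of rounded inputs; accumulation stays in binary64 for simplicity."""
    return sum(rounding(x) * rounding(y) for x, y in zip(xs, ys))


if __name__ == "__main__":
    samples = [1 / 3, 0.1, 3.14159265, -2.718281828, 123.456, 1e-3]
    print("float32 max relative input error:", max_relative_error(samples, to_float32))
    print("tf32    max relative input error:", max_relative_error(samples, to_tf32))

    xs = [1 / (k + 1) for k in range(256)]
    ys = [1 / (k + 3) for k in range(256)]
    exact = dot(xs, ys, lambda v: v)
    for name, rounding in (("float32", to_float32), ("tf32", to_tf32)):
        rel = abs(dot(xs, ys, rounding) - exact) / exact
        print(f"{name} dot-product relative error: {rel:.2e}")
```

For example, 1/3 rounds to 1365/4096 under this model, a relative error of exactly 1/4096.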
Disabling TF32 with the patch below makes the failures go away for all tests except cuda-fasterrcnn_resnet50_fpn:
diff --git a/test/test_models.py b/test/test_models.py
index c0afe9f10..8f6dda357 100644
--- a/test/test_models.py
+++ b/test/test_models.py
@@ -20,6 +20,12 @@ from torchvision import models
 ACCEPT = os.getenv("EXPECTTEST_ACCEPT", "0") == "1"
 SKIP_BIG_MODEL = os.getenv("SKIP_BIG_MODEL", "1") == "1"
 
+# The flag below controls whether to allow TF32 on matmul. This flag defaults to True.
+torch.backends.cuda.matmul.allow_tf32 = False
+
+# The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
+torch.backends.cudnn.allow_tf32 = False
+
 
 def get_models_from_module(module):
     # TODO add a registration mechanism to torchvision.models
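Setting the flags at module import time, as in the patch above, changes them for every test in the file. A narrower option is a context manager that turns TF32 off only around the model call and restores the previous values afterwards, even if the body raises. This is a minimal sketch: `tf32_disabled` is a hypothetical helper, not part of torchvision's test suite. It takes the `torch.backends` module as an argument so it can be exercised without a GPU; the stand-in object in the demo is purely illustrative.

```python
import contextlib
from types import SimpleNamespace


@contextlib.contextmanager
def tf32_disabled(backends):
    """Temporarily disable TF32 for matmul and cuDNN, restoring previous settings.

    `backends` is expected to be the `torch.backends` module; any object exposing
    `cuda.matmul.allow_tf32` and `cudnn.allow_tf32` attributes works.
    """
    saved = (backends.cuda.matmul.allow_tf32, backends.cudnn.allow_tf32)
    backends.cuda.matmul.allow_tf32 = False
    backends.cudnn.allow_tf32 = False
    try:
        yield
    finally:
        # Restore the original values even if the wrapped code raised.
        backends.cuda.matmul.allow_tf32, backends.cudnn.allow_tf32 = saved


if __name__ == "__main__":
    # Illustrative stand-in for torch.backends, so the demo runs without torch.
    fake = SimpleNamespace(
        cuda=SimpleNamespace(matmul=SimpleNamespace(allow_tf32=True)),
        cudnn=SimpleNamespace(allow_tf32=True),
    )
    with tf32_disabled(fake):
        print(fake.cuda.matmul.allow_tf32, fake.cudnn.allow_tf32)  # False False
    print(fake.cuda.matmul.allow_tf32, fake.cudnn.allow_tf32)  # True True
```

With PyTorch installed, a test would wrap just the forward pass, e.g. `with tf32_disabled(torch.backends): out = model(x)`, leaving the other tests in the module on the default settings.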
The cuda-fasterrcnn_resnet50_fpn failures remain even after disabling TF32.
Versions
Collecting environment information...
PyTorch version: 1.12.0a0+git98a20eb
Is debug build: True
CUDA used to build PyTorch: 11.6
ROCM used to build PyTorch: N/A
OS: Ubuntu 18.04.6 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: 6.0.0-1ubuntu2 (tags/RELEASE_600/final)
CMake version: version 3.22.3
Libc version: glibc-2.27
Python version: 3.9.12 (main, Apr 5 2022, 06:56:58) [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.4.0-1069-aws-x86_64-with-glibc2.27
Is CUDA available: True
CUDA runtime version: 11.6.112
GPU models and configuration:
GPU 0: NVIDIA A100-SXM4-40GB
GPU 1: NVIDIA A100-SXM4-40GB
GPU 2: NVIDIA A100-SXM4-40GB
GPU 3: NVIDIA A100-SXM4-40GB
Nvidia driver version: 510.47.03
cuDNN version: Probably one of the following:
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.0.5
/usr/local/cuda-11.0/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.0.5
/usr/local/cuda-11.1/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.0.5
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn.so.8.1.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_adv_infer.so.8.1.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_adv_train.so.8.1.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_cnn_infer.so.8.1.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_cnn_train.so.8.1.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_ops_infer.so.8.1.1
/usr/local/cuda-11.2/targets/x86_64-linux/lib/libcudnn_ops_train.so.8.1.1
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
Versions of relevant libraries:
[pip3] numpy==1.22.3
[pip3] torch==1.12.0a0+git98a20eb
[pip3] torchvision==0.13.0a0+edb7bbb
[conda] numpy 1.22.3 pypi_0 pypi
[conda] torch 1.12.0a0+git98a20eb dev_0 <develop>
[conda] torchvision 0.13.0a0+edb7bbb dev_0 <develop>
cc @jjsjann123