From 794d65f401dd70408bd0b336ee54c209ed38eefc Mon Sep 17 00:00:00 2001 From: Yih-Dar <2521628+ydshieh@users.noreply.github.com> Date: Sun, 5 Apr 2026 11:42:50 +0200 Subject: [PATCH] Fix tf32 issue: set `torch.backends.cudnn.conv.fp32_precision` explicitly. (#45248) * empty * fix * fix * fix --------- Co-authored-by: ydshieh --- conftest.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/conftest.py b/conftest.py index bb42a6cdf922..21bffbac2575 100644 --- a/conftest.py +++ b/conftest.py @@ -149,6 +149,22 @@ def check_output(self, want, got, optionflags): # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True. # We set it to `False` for CI. See https://github.com/pytorch/pytorch/issues/157274#issuecomment-3090791615 enable_tf32(False) + # torch.backends.fp32_precision does not cascade to torch.backends.cudnn.conv.fp32_precision and torch.backends.cudnn.rnn.fp32_precision + # TODO: Consider moving this into `enable_tf32`, or report a bug to `torch`. + import torch + + # In order to set `torch.backends.cudnn.conv.fp32_precision = "ieee"` below (new API), we still need to set this + # (old API) because it defaults to `True` (and is not changed automatically when we change `cudnn.conv.fp32_precision`) + # and such an inconsistency causes `torch` to complain `RuntimeError: PyTorch is checking whether allow_tf32 is enabled for cuDNN without a specific operator name,but the current flag(s) + # indicate that cuDNN conv and cuDNN RNN have different TF32 flags.This combination indicates that you have used a mix of the legacy and new APIs + # to set the TF32 flags. We suggest only using the new API to set the TF32 flag(s).`. 
+ # TODO: report a bug to `torch` + if hasattr(torch.backends.cudnn, "allow_tf32"): + torch.backends.cudnn.allow_tf32 = False + + # This is necessary to make several `test_batching_equivalence` pass (within the tolerance `1e-5`) + if hasattr(torch.backends.cudnn.conv, "fp32_precision"): + torch.backends.cudnn.conv.fp32_precision = "ieee" # patch `torch.compile`: if `TORCH_COMPILE_FORCE_FULLGRAPH=1` (or values considered as true, e.g. yes, y, etc.), # the patched version will always run with `fullgraph=True`.