Commit 794d65f

Fix tf32 issue: set torch.backends.cudnn.conv.fp32_precision explicitly. (huggingface#45248)
* empty
* fix
* fix
* fix

---------

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>
1 parent 499ef1d commit 794d65f

1 file changed

Lines changed: 16 additions & 0 deletions

conftest.py

@@ -149,6 +149,22 @@ def check_output(self, want, got, optionflags):
     # The flag below controls whether to allow TF32 on cuDNN. This flag defaults to True.
     # We set it to `False` for CI. See https://github.com/pytorch/pytorch/issues/157274#issuecomment-3090791615
     enable_tf32(False)
+    # `torch.backends.fp32_precision` does not cascade to `torch.backends.cudnn.conv.fp32_precision` and `torch.backends.cudnn.rnn.fp32_precision`
+    # TODO: Consider moving this to `enable_tf32`, or report a bug to `torch`.
+    import torch
+
+    # In order to set `torch.backends.cudnn.conv.fp32_precision = "ieee"` below (new API), we still need to set this
+    # (old API) because it defaults to `True` (and is not changed automatically when we change `cudnn.conv.fp32_precision`),
+    # and such inconsistency causes `torch` to complain `RuntimeError: PyTorch is checking whether allow_tf32 is enabled for cuDNN without a specific operator name,but the current flag(s)
+    # indicate that cuDNN conv and cuDNN RNN have different TF32 flags.This combination indicates that you have used a mix of the legacy and new APIs
+    # to set the TF32 flags. We suggest only using the new API to set the TF32 flag(s).`.
+    # TODO: report a bug to `torch`
+    if hasattr(torch.backends.cudnn, "allow_tf32"):
+        torch.backends.cudnn.allow_tf32 = False
+
+    # This is necessary to make several `test_batching_equivalence` tests pass (within the tolerance `1e-5`)
+    if hasattr(torch.backends.cudnn.conv, "fp32_precision"):
+        torch.backends.cudnn.conv.fp32_precision = "ieee"

     # patch `torch.compile`: if `TORCH_COMPILE_FORCE_FULLGRAPH=1` (or values considered as true, e.g. yes, y, etc.),
     # the patched version will always run with `fullgraph=True`.
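
The TODO in the diff suggests moving this logic into a helper. As a rough sketch of what that could look like, the hypothetical function below (`disable_cudnn_tf32` is not part of `transformers` or `torch`) applies the same two steps in the same order: the legacy `allow_tf32` flag first, then the new `conv.fp32_precision` setting. It takes the backend object as an argument so it can be run here against stand-in objects without PyTorch installed. Unlike the committed code, it also tolerates a missing `conv` attribute, which is an extra assumption about older builds.

```python
from types import SimpleNamespace


def disable_cudnn_tf32(cudnn):
    """Turn TF32 off on an object shaped like `torch.backends.cudnn`.

    Attributes that do not exist are skipped. Returns the names of the
    settings that were actually changed.
    """
    changed = []
    # Legacy API first, mirroring the order used in the diff.
    if hasattr(cudnn, "allow_tf32"):
        cudnn.allow_tf32 = False
        changed.append("allow_tf32")
    # New API: per-operator precision for cuDNN convolutions.
    conv = getattr(cudnn, "conv", None)
    if conv is not None and hasattr(conv, "fp32_precision"):
        conv.fp32_precision = "ieee"
        changed.append("conv.fp32_precision")
    return changed


# Stand-ins for two hypothetical PyTorch builds (illustration only).
legacy_only = SimpleNamespace(allow_tf32=True)
with_new_api = SimpleNamespace(
    allow_tf32=True, conv=SimpleNamespace(fp32_precision="tf32")
)

print(disable_cudnn_tf32(legacy_only))   # ['allow_tf32']
print(disable_cudnn_tf32(with_new_api))  # ['allow_tf32', 'conv.fp32_precision']
```

With PyTorch available, the call would presumably be `disable_cudnn_tf32(torch.backends.cudnn)`. Whether that fully avoids the mixed-API `RuntimeError` quoted in the diff depends on the installed `torch` version.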

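For context on the `1e-5` tolerance mentioned in the diff: TF32 keeps 10 explicit mantissa bits, against 23 for IEEE float32, so a single value can already lose more precision than that tolerance allows. The sketch below illustrates this in pure Python by zeroing the 13 low mantissa bits of a float32. Plain truncation is an assumption made for simplicity; actual hardware rounding may differ, and the failures in `test_batching_equivalence` involve accumulated error across whole convolutions rather than a single value.

```python
import struct


def truncate_to_tf32(x):
    """Approximate TF32 by zeroing the 13 low mantissa bits of a float32.

    Assumption: plain truncation; real hardware rounding may differ.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFE000))[0]


tolerance = 1e-5
exact = 1.0 + 2.0 ** -11          # exactly representable in float32
approx = truncate_to_tf32(exact)  # the 2**-11 bit lies below TF32's 10 mantissa bits
rel_err = abs(approx - exact) / exact

print(approx)               # 1.0
print(rel_err > tolerance)  # True
```

The relative error here is roughly `4.9e-4`, well above `1e-5`, which is consistent with the commit forcing `"ieee"` precision for cuDNN convolutions in CI.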