Skip to content

Commit f83227d

Browse files
Update train_ms.py
1 parent 8f9884d commit f83227d

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

train_ms.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,11 @@
3838

3939
torch.backends.cudnn.benchmark = True
4040
torch.backends.cuda.matmul.allow_tf32 = True
41-
torch.backends.cudnn.allow_tf32 = True
41+
torch.backends.cudnn.allow_tf32 = True # If encontered training problem,please try to disable TF32.
4242
torch.set_float32_matmul_precision('medium')
4343
torch.backends.cuda.sdp_kernel("flash")
4444
torch.backends.cuda.enable_flash_sdp(True)
45-
torch.backends.cuda.enable_mem_efficient_sdp(True)
45+
torch.backends.cuda.enable_mem_efficient_sdp(True) # Not avaliable if torch version is lower than 2.0
4646
torch.backends.cuda.enable_math_sdp(True)
4747
global_step = 0
4848

@@ -165,7 +165,7 @@ def run(rank, n_gpus, hps):
165165

166166
epoch_str = max(epoch_str, 1)
167167
global_step = (epoch_str - 1) * len(train_loader)
168-
except Exception as e:
168+
except Exception as e:
169169
print(e)
170170
epoch_str = 1
171171
global_step = 0

0 commit comments

Comments
 (0)