Skip to content

Commit 649b4fb

Browse files
Update train_ms.py (#225)
在每次epoch结束后回收显存以避免显存溢出的问题
1 parent 6733e63 commit 649b4fb

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

train_ms.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from config import config
1414
import argparse
1515
import datetime
16+
import gc
1617

1718
logging.getLogger("numba").setLevel(logging.WARNING)
1819
import commons
@@ -590,7 +591,8 @@ def train_and_evaluate(
590591
)
591592

592593
global_step += 1
593-
594+
gc.collect()
595+
torch.cuda.empty_cache()
594596
if rank == 0:
595597
logger.info("====> Epoch: {}".format(epoch))
596598

0 commit comments

Comments
 (0)