Skip to content

Commit b914795

Browse files
committed
distribute
1 parent d952e32 commit b914795

1 file changed

Lines changed: 4 additions & 1 deletion

File tree

src/mixins.py

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -100,7 +100,7 @@ def _train_batch(self, x, y):
100 100
else:
101 101
raise e
102 102
self.optimizer.step()
103-
logger.info(f" Rank {dist.get_rank()} loss {loss.item()}")
103+
logger.info(f" Rank {dist.get_rank()} loss {loss.item()} acc {acc}")
104 104
mlflow.log_metric("loss", loss.item())
105 105
mlflow.log_metric("acc", acc)
106 106
mlflow.log_metric("time", time.time())
@@ -131,6 +131,9 @@ def train(self):
131 131
gc.collect()
132 132
if torch.backends.mps.is_available():
133 133
torch.mps.empty_cache()
134+
if torch.cuda.is_available():
135+
torch.cuda.empty_cache()
136+
torch.cuda.ipc_collect()
134 137
time.sleep(0.2)
135 138
self.save()
136 139
self.load_checkpoint()

0 commit comments

Comments
 (0)