Skip to content

Commit a701f25

Browse files
committed
distribute
1 parent a5f28f5 commit a701f25

1 file changed

Lines changed: 10 additions & 1 deletion

File tree

src/mixins.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,16 @@ def _train_batch(self, x, y):
8787
preds = outputs
8888
correct = (preds == labels).sum().item()
8989
acc = correct / len(labels)
90-
loss.backward()
90+
try:
91+
loss.backward()
92+
except Exception as e:
93+
if "out of memory" in str(e):
94+
logger.error(f"[Rank {dist.get_rank()}] CUDA OOM detected")
95+
torch.cuda.empty_cache()
96+
gc.collect()
97+
return 0.0 # skip the batch
98+
else:
99+
raise e
91100
self.optimizer.step()
92101
logger.info(f" Rank {dist.get_rank()} loss {loss.item()}")
93102
mlflow.log_metric("loss", loss.item())

0 commit comments

Comments
 (0)