Skip to content

Commit 9ee3954

Browse files
committed
Training runs are currently not deterministic
1 parent 5c55a76 commit 9ee3954

File tree

4 files changed

+16
-8
lines changed

4 files changed

+16
-8
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,3 +72,4 @@ Checkpoints for MagFace can be downloaded at: [MagFace Repository](https://githu
7272
- [x] Add ElasticFace header
7373
- [x] Remove `mxnet` dependency. As a consequence, the datasets must be converted.
7474
- [ ] Compare MagFace training to official MagFace code
75+
- [ ] Make training runs deterministic at same seed

src/datamodule_hf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def train_dataloader(self):
6262
batch_size=self.hparams.batch_size,
6363
num_workers=self.hparams.num_workers,
6464
shuffle=True,
65+
# generator=torch.Generator().manual_seed(42),
6566
)
6667

6768

src/pl_module.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -58,13 +58,16 @@ def training_step(self, batch, batch_idx):
5858
# logits vector describes the probability for each image to belong to one of n_classes
5959
loss = self.criterion(logits, targets)
6060
optimizer_lr = self.optimizers().optimizer.param_groups[0]["lr"]
61-
log_dict = {
62-
# "step": float(self.current_epoch), # Overwrite step to plot epochs on x-axis
63-
"loss": loss,
64-
"optimizer_lr": optimizer_lr,
65-
"max_ampl": max_ampl.item(),
66-
}
67-
self.log_dict(log_dict, on_step=True)
61+
self.log("loss", loss, prog_bar=True)
62+
self.log("optimizer_lr", optimizer_lr)
63+
self.log("max_ampl", max_ampl.item())
64+
# log_dict = {
65+
# # "step": float(self.current_epoch), # Overwrite step to plot epochs on x-axis
66+
# "loss": loss,
67+
# "optimizer_lr": optimizer_lr,
68+
# "max_ampl": max_ampl.item(),
69+
# }
70+
# self.log_dict(log_dict, on_step=True)
6871
return loss
6972

7073
def configure_optimizers(self):

train.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,11 @@ def main(
6161
results_dir: str,
6262
version: Optional[int] = None,
6363
):
64-
# 1. Set fixed seed
64+
# 1. Set fixed seed and flags for deterministic behavior
6565
pl.seed_everything(cfg.seed)
66+
torch.use_deterministic_algorithms(True)
67+
torch.backends.cudnn.deterministic = True
68+
torch.backends.cudnn.benchmark = False
6669

6770
# 2. Assign datamodule and pl_module
6871
datamodule = datamodule

0 commit comments

Comments (0)