Skip to content

Commit 806d505

Browse files
deploy changes
1 parent ed9bdc1 commit 806d505

2 files changed

Lines changed: 16 additions & 11 deletions

File tree

asparagus/pipeline/run/pretrain.py

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -113,15 +113,6 @@ def main(cfg: DictConfig) -> None:
113113
log_images_every_n_epoch=cfg.logger.log_images_every_n_epoch,
114114
)
115115

116-
print("Training duration configured as:")
117-
print(f" - Steps: {cfg.training.steps}")
118-
print(f" - Steps per pseudo epoch: {cfg.training.steps_per_epoch}")
119-
print(f" - Validation steps per pseudo epoch: {cfg.training.val_steps_per_epoch}")
120-
print(
121-
f" - Pseudo Epochs: {cfg.training.steps / (cfg.training.steps_per_epoch * cfg.training.accumulate_grad_batches):.1f}"
122-
)
123-
print(f" - Warmup Pseudo Epochs: {cfg.training.warmup_epochs} (ratio {cfg.training.warmup_ratio})")
124-
125116
trainer = instantiate(
126117
cfg.lightning._trainer,
127118
callbacks=callbacks,
@@ -136,6 +127,19 @@ def main(cfg: DictConfig) -> None:
136127
accumulate_grad_batches=cfg.training.accumulate_grad_batches,
137128
)
138129

130+
if trainer.is_global_zero:
131+
print("Training duration configured as:")
132+
print(f" - Steps: {cfg.training.steps}")
133+
print(f" - Global batch size: {cfg.training.global_batch_size}")
134+
print(f" - Steps per pseudo epoch: {cfg.training.steps_per_epoch}")
135+
print(f" - Validation steps per pseudo epoch: {cfg.training.val_steps_per_epoch}")
136+
print(
137+
" - Pseudo Epochs: {:.1f}".format(
138+
cfg.training.steps / (cfg.training.steps_per_epoch * cfg.training.accumulate_grad_batches)
139+
)
140+
)
141+
print(f" - Warmup Pseudo Epochs: {cfg.training.warmup_epochs} (ratio {cfg.training.warmup_ratio})")
142+
139143
trainer.fit(
140144
model=model_module,
141145
datamodule=data_module,

configs/default_pretrain.yaml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,12 @@ training:
3939
# when we increase the number of devices _or_ use a bigger dataset.
4040
steps_per_epoch: 1890 # <--- should be constant ... but note that if gradient accumulation is used, then steps_per_epoch > number of backward passes.
4141
val_steps_per_epoch: ${eval:"${training.steps_per_epoch} // 100"}
42-
steps: ${eval:"${training.max_samples} // (${hardware.num_devices} * ${hardware.num_nodes} * ${training.batch_size} * ${training.accumulate_grad_batches})"}
42+
global_batch_size: ${eval:"${training.batch_size} * ${hardware.num_devices} * ${hardware.num_nodes} * ${training.accumulate_grad_batches}"}
43+
steps: ${eval:"${training.max_samples} // (${training.global_batch_size})"}
4344
warmup_epochs: ${eval:"max(1, int((${training.steps} // ${training.steps_per_epoch}) * ${training.warmup_ratio}))"}
4445
decoder_warmup_epochs: 0
4546
rec_loss_masked_only: False
46-
check_val_every_n_epoch: 1
47+
check_val_every_n_epoch: 3
4748

4849
# num_devices 4
4950
# num_nodes 2

0 commit comments

Comments
 (0)