Validation fu added to examples/structural_mechanics/crash/train.py #1204
base: main
@@ -27,9 +27,12 @@ max_workers_preprocessing: 64 # Maximum parallel workers

num_time_steps: 14
num_training_samples: 8
num_validation_samples: 8
start_lr: 0.0001
end_lr: 0.0000003
epochs: 10000
validate_every_n_epochs: 10
save_ckpt_every_n_epochs: 10

# ┌───────────────────────────────────────────┐
# │         Performance Optimization          │
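Two of these keys, num_validation_samples and validate_every_n_epochs, drive the validation added by this PR. A standalone sketch (ours, not PR code) of how main() further down gates on them:

```python
# With validate_every_n_epochs = 10, validation runs at epochs 10, 20, 30, ...
num_validation_samples = 8
validate_every_n_epochs = 10
epochs = 10000

validation_epochs = [
    epoch + 1
    for epoch in range(epochs)
    if num_validation_samples > 0 and (epoch + 1) % validate_every_n_epochs == 0
]
assert validation_epochs[:3] == [10, 20, 30]
assert len(validation_epochs) == epochs // validate_every_n_epochs  # 1000 runs
```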
@@ -38,6 +38,7 @@

# Import unified datapipe
from datapipe import SimSample, simsample_collate
from omegaconf import open_dict


class Trainer:

@@ -109,6 +110,58 @@ def __init__(self, cfg: DictConfig, logger0: RankZeroLoggingWrapper):
        )
        self.sampler = sampler

        if cfg.training.num_validation_samples > 0:
            self.num_validation_replicas = min(
                self.dist.world_size, cfg.training.num_validation_samples
            )
            self.num_validation_samples = (
                cfg.training.num_validation_samples
                // self.num_validation_replicas
                * self.num_validation_replicas
            )
            logger0.info(f"Number of validation samples: {self.num_validation_samples}")

            # Create a validation dataset
            val_cfg = self.cfg.datapipe
            with open_dict(val_cfg):  # or open_dict(cfg) to open the whole tree
                val_cfg.data_dir = self.cfg.inference.raw_data_dir_test
                val_cfg.num_samples = self.num_validation_samples
            val_dataset = instantiate(
                val_cfg,
                name="crash_test",
                reader=reader,
                split="test",
                logger=logger0,
            )

Review comment (on the split="test" line): validation
            if self.dist.rank < self.num_validation_replicas:
                # Sampler
                if self.dist.world_size > 1:
                    sampler = DistributedSampler(
                        val_dataset,
                        num_replicas=self.num_validation_replicas,
                        rank=self.dist.rank,
                        shuffle=False,
                        drop_last=True,
                    )
                else:
                    sampler = None

                self.val_dataloader = torch.utils.data.DataLoader(
                    val_dataset,
                    batch_size=1,  # variable N per sample
                    shuffle=(sampler is None),
                    drop_last=True,
                    pin_memory=True,
                    num_workers=cfg.training.num_dataloader_workers,
                    sampler=sampler,
                    collate_fn=simsample_collate,
                )
            else:
                self.val_dataloader = torch.utils.data.DataLoader(
                    torch.utils.data.Subset(val_dataset, []), batch_size=1
                )

        # Model
        self.model = instantiate(cfg.model)
        logging.getLogger().setLevel(logging.INFO)
@@ -199,6 +252,50 @@ def backward(self, loss):
        loss.backward()
        self.optimizer.step()

    @torch.no_grad()
    def validate(self, epoch):
        """Run validation error computation"""
        self.model.eval()

        MSE = torch.zeros(1, device=self.dist.device)
        MSE_w_time = torch.zeros(self.rollout_steps, device=self.dist.device)
        for idx, sample in enumerate(self.val_dataloader):
            sample = sample[0].to(self.dist.device)  # SimSample .to()
            T = self.rollout_steps

            # Model forward
            pred_seq = self.model(sample=sample, data_stats=self.data_stats)

            # Exact sequence (if provided)
            exact_seq = None
            if sample.node_target is not None:
                N = sample.node_target.size(0)
                Fo = 3  # output features per node
                assert sample.node_target.size(1) == T * Fo, (
                    f"target dim {sample.node_target.size(1)} != {T * Fo}"
                )
                exact_seq = (
                    sample.node_target.view(N, T, Fo).transpose(0, 1).contiguous()
                )  # [T,N,Fo]

            # Compute and add error
            SqError = torch.square(pred_seq - exact_seq)
            MSE_w_time += torch.mean(SqError, dim=(1, 2))
            MSE += torch.mean(SqError)

        # Sum errors across all ranks
        if self.dist.world_size > 1:
            torch.distributed.all_reduce(MSE, op=torch.distributed.ReduceOp.SUM)
            torch.distributed.all_reduce(MSE_w_time, op=torch.distributed.ReduceOp.SUM)
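The [N, T*Fo] to [T, N, Fo] reshape above is easy to get wrong; a small standalone check (ours, with assumed toy shapes) showing that view followed by transpose recovers per-timestep feature blocks:

```python
import torch

N, T, Fo = 4, 14, 3  # nodes, rollout steps, output features per node
node_target = torch.arange(N * T * Fo, dtype=torch.float32).view(N, T * Fo)

exact_seq = node_target.view(N, T, Fo).transpose(0, 1).contiguous()  # [T, N, Fo]
assert exact_seq.shape == (T, N, Fo)

# Timestep t of node n is the contiguous block node_target[n, t*Fo:(t+1)*Fo]:
t, n = 5, 2
assert torch.equal(exact_seq[t, n], node_target[n, t * Fo:(t + 1) * Fo])
```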
Review comment on lines +284 to +287:
logic: all_reduce sums errors from participating ranks, but the denominator uses the total number of validation samples - this double-counts if not all ranks participate. Should the denominator be adjusted for the actual number of participating ranks rather than the total validation samples?

Author reply:
The code added at lines 114-121 adjusts the number of validation samples for all possible cases.
E.g., if the user provides world_size = 3 and num_validation_samples = 8, then num_validation_replicas = min(3, 8) = 3 and the effective num_validation_samples becomes 8 // 3 * 3 = 6, so the denominator matches exactly the number of samples the participating ranks process.
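To make the reply concrete, a standalone sketch mirroring the adjustment at lines 114-121 (the helper name is ours, not the PR's):

```python
def effective_validation_counts(world_size: int, requested_samples: int):
    """Mirror of the sample-count adjustment at lines 114-121 (hypothetical helper)."""
    num_replicas = min(world_size, requested_samples)
    effective_samples = requested_samples // num_replicas * num_replicas
    return num_replicas, effective_samples

# The example from the reply: 3 ranks, 8 requested samples.
assert effective_validation_counts(3, 8) == (3, 6)  # 2 samples per rank, 2 dropped
# More ranks than samples: only 3 of the 8 ranks participate.
assert effective_validation_counts(8, 3) == (3, 3)
# Even split: nothing is dropped.
assert effective_validation_counts(4, 8) == (4, 8)
```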
        val_stats = {
            "MSE_w_time": MSE_w_time / self.num_validation_samples,
            "MSE": MSE / self.num_validation_samples,
        }

Review comment (on the MSE_w_time line):
See greptile's comment. We should divide by the actual number of validation samples, especially because you have …

Author reply:
The code added at lines 114-121 computes the actual number of validation samples for all possible cases. E.g., if the user provides world_size = 3 and num_validation_samples = 8, the effective count becomes 6 (see the reply above).
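A standalone numerical check (our sketch, assuming each rank accumulates per-sample means exactly as validate() does) that summing the per-rank accumulators and dividing by the effective sample count reproduces the global mean:

```python
import torch

num_replicas, samples_per_rank = 3, 2
effective_samples = num_replicas * samples_per_rank  # 6, as in the example above

per_sample_mse = torch.rand(effective_samples)  # stand-ins for mean(SqError) per sample

# Each rank's local MSE accumulator: the sum of its per-sample means.
per_rank_sums = per_sample_mse.view(num_replicas, samples_per_rank).sum(dim=1)

# all_reduce with ReduceOp.SUM is equivalent to summing the accumulators.
reduced = per_rank_sums.sum()

assert torch.allclose(reduced / effective_samples, per_sample_mse.mean())
```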
        self.model.train()  # Switch back to training mode
        return val_stats


@hydra.main(version_base="1.3", config_path="conf", config_name="config")
def main(cfg: DictConfig) -> None:
@@ -243,7 +340,8 @@ def main(cfg: DictConfig) -> None:

        if dist.world_size > 1:
            torch.distributed.barrier()
-       if dist.rank == 0:
+       if dist.rank == 0 and (epoch + 1) % cfg.training.save_ckpt_every_n_epochs == 0:
            save_checkpoint(
                cfg.training.ckpt_path,
                models=trainer.model,

@@ -254,6 +352,31 @@ def main(cfg: DictConfig) -> None:
            )
            logger.info(f"Saved model on rank {dist.rank}")
        # Validation
        if (
            cfg.training.num_validation_samples > 0
            and (epoch + 1) % cfg.training.validate_every_n_epochs == 0
        ):
            # logger0.info(f"Validation started...")
            val_stats = trainer.validate(epoch)

            # Log detailed validation statistics
            logger0.info(
                f"Validation epoch {epoch + 1}: MSE: {val_stats['MSE'].item():.3e}, "
            )

            if dist.rank == 0:
                # Log to tensorboard
                trainer.writer.add_scalar("val/MSE", val_stats["MSE"].item(), epoch)

                # Log individual timestep relative errors
                for i in range(len(val_stats["MSE_w_time"])):
                    trainer.writer.add_scalar(
                        f"val/timestep_{i}_MSE",
                        val_stats["MSE_w_time"][i].item(),
                        epoch,
                    )

    logger0.info("Training completed!")
    if dist.rank == 0:
        trainer.writer.close()
Review comment: Maybe change this to validation_freq?