Validation function added to examples/structural_mechanics/crash/train.py #1204
base: main
Conversation
Dakhare crash
- Validation function to track model performance on the test dataset added in physicsnemo/examples/structural_mechanics/crash/train.py
- validate_every_n_epochs and save_ckpt_every_n_epochs added in conf/training/default.yaml to set the frequency for running validation and saving checkpoints
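As a rough sketch of how frequency parameters like these typically gate validation and checkpointing in an epoch loop (names and the 1-indexed epoch convention are illustrative, not the actual train.py code):

```python
def maybe_run_periodic_tasks(epoch, cfg, validate_fn, save_ckpt_fn):
    """Gate validation and checkpointing on the configured frequencies.

    Epochs are assumed 1-indexed, so validate_every_n_epochs=10 triggers
    validation at epochs 10, 20, 30, ...
    """
    if epoch % cfg["validate_every_n_epochs"] == 0:
        validate_fn(epoch)
    if epoch % cfg["save_ckpt_every_n_epochs"] == 0:
        save_ckpt_fn(epoch)
```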
Greptile Overview
Greptile Summary
This PR adds validation functionality to the structural mechanics crash simulation training example. The main changes include: (1) adding validation dataset creation and distributed sampling in train.py, (2) implementing a validation loop that computes time-step-wise MSE loss and aggregates results across distributed ranks, (3) adding validation configuration parameters to control validation frequency and checkpoint saving, and (4) refactoring the inference code to use a unified sample object interface instead of passing individual graph components separately.
The validation implementation follows distributed training best practices by properly handling data sampling, metric aggregation, and logging only on rank 0. The changes integrate cleanly with the existing training pipeline and tensorboard logging infrastructure, providing essential model monitoring capabilities for the crash simulation example.
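The "unified sample object interface" refactor mentioned in point (4) can be illustrated with a hypothetical dataclass; the names below are illustrative and not the actual PhysicsNeMo API:

```python
from dataclasses import dataclass
from typing import Any, Optional

@dataclass
class Sample:
    """Bundle the graph components that were previously passed separately."""
    graph: Any                          # e.g. a DGL/PyG graph object
    node_features: Any
    edge_features: Any
    node_target: Optional[Any] = None   # may be absent at pure-inference time

def forward_pass(model, sample):
    # The model consumes one object instead of several positional arguments,
    # so call sites stay stable when the sample grows new fields.
    return model(sample.graph, sample.node_features, sample.edge_features)
```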
PR Description Notes:
- The PR description is largely empty with only unchecked checklist items
- No standalone description of changes provided
- No linked issues or changelog updates mentioned
- Missing information about new dependencies or testing coverage
Important Files Changed
| Filename | Score | Overview |
|---|---|---|
| examples/structural_mechanics/crash/train.py | 4/5 | Added comprehensive validation functionality with distributed sampling, MSE computation, and tensorboard logging |
| examples/structural_mechanics/crash/conf/training/default.yaml | 5/5 | Added validation configuration parameters for sample count, validation frequency, and checkpoint saving |
| examples/structural_mechanics/crash/inference.py | 4/5 | Refactored model forward pass to use unified sample object interface instead of separate graph components |
3 files reviewed, 3 comments
```python
self.num_validation_replicas = min(self.dist.world_size, cfg.training.num_validation_samples)
self.num_validation_samples = cfg.training.num_validation_samples // self.num_validation_replicas * self.num_validation_replicas
```
logic: potential division by zero if num_validation_replicas is 0
Suggested change:

```python
self.num_validation_replicas = min(self.dist.world_size, cfg.training.num_validation_samples)
if self.num_validation_replicas > 0:
    self.num_validation_samples = cfg.training.num_validation_samples // self.num_validation_replicas * self.num_validation_replicas
else:
    self.num_validation_samples = 0
```
The following code already enforces num_validation_replicas >= 1, since self.dist.world_size >= 1 (assuming num_validation_samples >= 1):

```python
self.num_validation_replicas = min(self.dist.world_size, cfg.training.num_validation_samples)
```
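The edge case under discussion can be checked directly with a stand-alone replica of the two lines under review: min() only guarantees a positive replica count when num_validation_samples itself is positive.

```python
def compute_validation_split(world_size, num_validation_samples):
    """Stand-alone replica of the two lines under review.

    Rounds the requested sample count down to a multiple of the number of
    participating replicas. Divides by zero if num_validation_samples == 0,
    regardless of world_size.
    """
    num_replicas = min(world_size, num_validation_samples)
    num_samples = num_validation_samples // num_replicas * num_replicas
    return num_replicas, num_samples
```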
```python
exact_seq = None
if sample.node_target is not None:
    N = sample.node_target.size(0)
    Fo = 3  # output features per node
    assert sample.node_target.size(1) == T * Fo, (
        f"target dim {sample.node_target.size(1)} != {T * Fo}"
    )
    exact_seq = sample.node_target.view(N, T, Fo).transpose(0, 1).contiguous()  # [T,N,Fo]
```
logic: validation will fail if sample.node_target is None - should handle missing targets gracefully
```python
# Sum errors across all ranks
if self.dist.world_size > 1:
    torch.distributed.all_reduce(MSE, op=torch.distributed.ReduceOp.SUM)
    torch.distributed.all_reduce(MSE_w_time, op=torch.distributed.ReduceOp.SUM)
```
logic: all_reduce sums errors from participating ranks but denominator uses total validation samples - this double-counts if not all ranks participate. Should the denominator be adjusted for the actual number of participating ranks rather than total validation samples?
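One way to address this, sketched as plain Python standing in for the collective (not the PR's code): all_reduce a per-rank sample count alongside the error sums, and divide by the reduced count rather than the configured total.

```python
def global_mse(rank_err_sums, rank_counts):
    """Simulate torch.distributed.all_reduce(SUM) across ranks.

    Each rank contributes its local sum of squared errors plus the number of
    validation samples it actually processed; non-participating ranks
    contribute (0.0, 0) and so do not skew the mean.
    """
    total_err = sum(rank_err_sums)    # all_reduce(MSE, op=SUM)
    total_count = sum(rank_counts)    # all_reduce(sample_count, op=SUM)
    return total_err / max(total_count, 1)
```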
The following code, added at lines 114-121, adjusts the number of validation samples for all possible cases:

```python
self.num_validation_replicas = min(
    self.dist.world_size, cfg.training.num_validation_samples
)
self.num_validation_samples = (
    cfg.training.num_validation_samples
    // self.num_validation_replicas
    * self.num_validation_replicas
)
```

E.g., if the user provides world_size = 3 and num_validation_samples = 8, then num_validation_samples = 6 and each of the 3 ranks loads 2 samples.
If the user provides world_size = 24 and num_validation_samples = 8, then num_validation_samples = 8: only 8 ranks load 1 sample each, and the rest have no validation data.
getting updates from NVIDIA/physicsnemo
updating crash branch
Dakhare crash
```yaml
start_lr: 0.0001
end_lr: 0.0000003
epochs: 10000
validate_every_n_epochs: 10
```
Maybe change this to validation_freq?
```yaml
end_lr: 0.0000003
epochs: 10000
validate_every_n_epochs: 10
save_ckpt_every_n_epochs: 10
```
And change this to save_checkpoint_freq?
```python
# Create a validation dataset
val_cfg = self.cfg.datapipe
with open_dict(val_cfg):  # or open_dict(cfg) to open the whole tree
    val_cfg.data_dir = self.cfg.inference.raw_data_dir_test
```
You should not use the test samples for validation. We should have three splits: train, validation, and test
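A minimal sketch of the three-way split being asked for (fractions, seed, and function name are illustrative, not project conventions): the key property is that the validation set is disjoint from the test set used for final reporting.

```python
import random

def three_way_split(sample_ids, val_frac=0.1, test_frac=0.1, seed=0):
    """Deterministically partition sample ids into train/validation/test.

    Shuffling with a fixed seed makes the split reproducible across runs;
    validation and test sets never overlap by construction.
    """
    ids = list(sample_ids)
    random.Random(seed).shuffle(ids)
    n_val = int(len(ids) * val_frac)
    n_test = int(len(ids) * test_frac)
    return {
        "validation": ids[:n_val],
        "test": ids[n_val:n_val + n_test],
        "train": ids[n_val + n_test:],
    }
```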
```python
val_cfg,
name="crash_test",
reader=reader,
split="test",
```
This should be "validation".
```python
torch.distributed.all_reduce(MSE_w_time, op=torch.distributed.ReduceOp.SUM)

val_stats = {
    "MSE_w_time": MSE_w_time / self.num_validation_samples,
```
See greptile's comment. We should divide by the actual number of validation samples, especially because you have drop_last=True.
The following code, added at lines 114-121, computes the actual number of validation samples for all possible cases:

```python
self.num_validation_replicas = min(
    self.dist.world_size, cfg.training.num_validation_samples
)
self.num_validation_samples = (
    cfg.training.num_validation_samples
    // self.num_validation_replicas
    * self.num_validation_replicas
)
```

E.g., if the user provides world_size = 3 and num_validation_samples = 8, then num_validation_samples = 6 and each of the 3 ranks loads 2 samples.
If the user provides world_size = 24 and num_validation_samples = 8, then num_validation_samples = 8: only 8 ranks load 1 sample each, and the rest have no validation data.
/blossom-ci |
val path added and args name changed
PhysicsNeMo Pull Request
Description
Checklist
Dependencies
Review Process
All PRs are reviewed by the PhysicsNeMo team before merging.
Depending on which files are changed, GitHub may automatically assign a maintainer for review.
We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI's assessment of merge readiness and is not a qualitative judgment of your work, nor an indication that the PR will be accepted or rejected.
AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.