
Conversation

@dakhare-creator

PhysicsNeMo Pull Request

Description

Checklist

  • I am familiar with the Contributing Guidelines.
  • New or existing tests cover these changes.
  • The documentation is up to date with these changes.
  • The CHANGELOG.md is up to date with these changes.
  • An issue is linked to this pull request.

Dependencies

Review Process

All PRs are reviewed by the PhysicsNeMo team before merging.

Depending on which files are changed, GitHub may automatically assign a maintainer for review.

We are also testing AI-based code review tools (e.g., Greptile), which may add automated comments with a confidence score.
This score reflects the AI’s assessment of merge readiness; it is not a qualitative judgment of your work, nor an
indication that the PR will be accepted or rejected.

AI-generated feedback should be reviewed critically for usefulness.
You are not required to respond to every AI comment, but they are intended to help both authors and reviewers.
Please react to Greptile comments with 👍 or 👎 to provide feedback on their accuracy.

Dakhare crash

- A validation function that tracks model performance on the test dataset is added in physicsnemo/examples/structural_mechanics/crash/train.py

- validate_every_n_epochs and save_ckpt_every_n_epochs are added in config/training/default.yaml to set how often the validation function is called and checkpoints are saved
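A minimal sketch of how such frequency parameters are typically consumed in a training loop. This is an illustration, not the actual train.py code; `run_training` and the recorded epoch lists are placeholders, and the sketch assumes the parameters gate actions by epoch modulo:

```python
def run_training(epochs, validate_every_n_epochs, save_ckpt_every_n_epochs):
    """Gate validation and checkpointing by epoch frequency.

    Placeholder loop: `validated` and `saved` record which epochs
    triggered each action, standing in for the real calls.
    """
    validated, saved = [], []
    for epoch in range(1, epochs + 1):
        # ... one epoch of training would run here ...
        if epoch % validate_every_n_epochs == 0:
            validated.append(epoch)  # would call the validation function
        if epoch % save_ckpt_every_n_epochs == 0:
            saved.append(epoch)      # would save a checkpoint
    return validated, saved
```

With epochs=30, validate_every_n_epochs=10, and save_ckpt_every_n_epochs=15, validation runs at epochs 10, 20, 30 and checkpoints are saved at 15 and 30.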
Contributor

@greptile-apps greptile-apps bot left a comment


Greptile Overview

Greptile Summary

This PR adds validation functionality to the structural mechanics crash simulation training example. The main changes include: (1) adding validation dataset creation and distributed sampling in train.py, (2) implementing a validation loop that computes time-step-wise MSE loss and aggregates results across distributed ranks, (3) adding validation configuration parameters to control validation frequency and checkpoint saving, and (4) refactoring the inference code to use a unified sample object interface instead of passing individual graph components separately.

The validation implementation follows distributed training best practices by properly handling data sampling, metric aggregation, and logging only on rank 0. The changes integrate cleanly with the existing training pipeline and tensorboard logging infrastructure, providing essential model monitoring capabilities for the crash simulation example.

PR Description Notes:

  • The PR description is largely empty with only unchecked checklist items
  • No standalone description of changes provided
  • No linked issues or changelog updates mentioned
  • Missing information about new dependencies or testing coverage

Important Files Changed

  • examples/structural_mechanics/crash/train.py (score 4/5): Added comprehensive validation functionality with distributed sampling, MSE computation, and tensorboard logging
  • examples/structural_mechanics/crash/conf/training/default.yaml (score 5/5): Added validation configuration parameters for sample count, validation frequency, and checkpoint saving
  • examples/structural_mechanics/crash/inference.py (score 4/5): Refactored model forward pass to use unified sample object interface instead of separate graph components

3 files reviewed, 3 comments


Comment on lines 114 to 115
self.num_validation_replicas = min(self.dist.world_size, cfg.training.num_validation_samples)
self.num_validation_samples = cfg.training.num_validation_samples // self.num_validation_replicas * self.num_validation_replicas
Contributor


logic: potential division by zero if num_validation_replicas is 0

Suggested change:

self.num_validation_replicas = min(self.dist.world_size, cfg.training.num_validation_samples)
if self.num_validation_replicas > 0:
    self.num_validation_samples = (
        cfg.training.num_validation_samples
        // self.num_validation_replicas
        * self.num_validation_replicas
    )
else:
    self.num_validation_samples = 0

Author


The existing line

self.num_validation_replicas = min(self.dist.world_size, cfg.training.num_validation_samples)

already enforces num_validation_replicas >= 1, since self.dist.world_size >= 1 (and num_validation_samples >= 1).

Comment on lines 261 to 268
exact_seq = None
if sample.node_target is not None:
    N = sample.node_target.size(0)
    Fo = 3  # output features per node
    assert sample.node_target.size(1) == T * Fo, (
        f"target dim {sample.node_target.size(1)} != {T * Fo}"
    )
    exact_seq = sample.node_target.view(N, T, Fo).transpose(0, 1).contiguous()  # [T, N, Fo]
Contributor


logic: validation will fail if sample.node_target is None - should handle missing targets gracefully
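For reference, the `view(N, T, Fo).transpose(0, 1)` reshape can be mirrored in pure Python, and returning None for targetless samples (so callers can skip them) is one way to handle missing targets without asserting. `to_time_major` is a hypothetical helper for illustration, not part of train.py:

```python
def to_time_major(node_target, T, Fo=3):
    """Reshape flattened per-node targets [N, T*Fo] into a time-major
    sequence [T, N, Fo], mirroring
    node_target.view(N, T, Fo).transpose(0, 1).

    Returns None for a missing target so the caller can skip the
    sample instead of raising.
    """
    if node_target is None:
        return None
    N = len(node_target)
    assert all(len(row) == T * Fo for row in node_target), "target dim mismatch"
    return [
        [[node_target[n][t * Fo + f] for f in range(Fo)] for n in range(N)]
        for t in range(T)
    ]
```

For N=2 nodes and T=2 time steps, row [1, 2, 3, 4, 5, 6] for node 0 contributes [1, 2, 3] at t=0 and [4, 5, 6] at t=1, matching the tensor layout the assert above protects.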

Comment on lines +275 to +278
# Sum errors across all ranks
if self.dist.world_size > 1:
    torch.distributed.all_reduce(MSE, op=torch.distributed.ReduceOp.SUM)
    torch.distributed.all_reduce(MSE_w_time, op=torch.distributed.ReduceOp.SUM)
Contributor


logic: all_reduce sums errors from participating ranks but denominator uses total validation samples - this double-counts if not all ranks participate. Should the denominator be adjusted for the actual number of participating ranks rather than total validation samples?
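The concern can be illustrated without torch.distributed: summing per-rank error totals (what all_reduce with ReduceOp.SUM does) and then dividing by a configured sample count gives the wrong mean whenever the ranks together processed fewer samples. A pure-Python simulation with hypothetical names:

```python
def global_mse(per_rank_error_sums, per_rank_counts):
    """Simulate all_reduce(SUM) over per-rank squared-error sums and
    per-rank sample counts, then divide by the *actual* number of
    samples processed rather than a configured total."""
    total_error = sum(per_rank_error_sums)  # all_reduce on the error
    total_count = sum(per_rank_counts)      # all_reduce on the count
    return total_error / total_count

# 3 ranks processed 2 samples each (6 of 8 configured samples).
correct = global_mse([4.0, 2.0, 6.0], [2, 2, 2])  # 12.0 / 6 = 2.0
# Dividing by the configured 8 would understate the error:
wrong = sum([4.0, 2.0, 6.0]) / 8                   # 12.0 / 8 = 1.5
```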

Author


The following code, added at lines 114-121, adjusts the number of validation samples for all possible cases:

self.num_validation_replicas = min(
    self.dist.world_size, cfg.training.num_validation_samples
)
self.num_validation_samples = (
    cfg.training.num_validation_samples
    // self.num_validation_replicas
    * self.num_validation_replicas
)

E.g., if the user provides world_size = 3 and num_validation_samples = 8,
then num_validation_samples = 6, and each of the 3 ranks loads 2 samples.
If the user provides world_size = 24 and num_validation_samples = 8,
then num_validation_samples = 8: only 8 ranks load 1 sample each, and the rest have no data for validation.
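The sizing rule in that reply can be checked with a small stand-alone sketch. `plan_validation` is a hypothetical helper that mirrors the PR's logic (and inherits its assumption that num_validation_samples >= 1):

```python
def plan_validation(world_size, num_validation_samples):
    """Replicate the rank/sample sizing from the PR: clamp the number
    of participating ranks to the sample count, then round the sample
    count down to a multiple of it so every participating rank gets an
    equal share. Assumes num_validation_samples >= 1."""
    replicas = min(world_size, num_validation_samples)
    usable = num_validation_samples // replicas * replicas
    return replicas, usable

plan_validation(3, 8)   # 3 ranks, 6 usable samples (2 each)
plan_validation(24, 8)  # 8 ranks, 8 usable samples (1 each)
```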

start_lr: 0.0001
end_lr: 0.0000003
epochs: 10000
validate_every_n_epochs: 10
Collaborator


Maybe change this to validation_freq?

end_lr: 0.0000003
epochs: 10000
validate_every_n_epochs: 10
save_ckpt_every_n_epochs: 10
Collaborator


And change this to save_checkpoint_freq?

# Create a validation dataset
val_cfg = self.cfg.datapipe
with open_dict(val_cfg):  # or open_dict(cfg) to open the whole tree
    val_cfg.data_dir = self.cfg.inference.raw_data_dir_test
Collaborator


You should not use the test samples for validation. We should have three splits: train, validation, and test

val_cfg,
name="crash_test",
reader=reader,
split="test",
Collaborator


validation

torch.distributed.all_reduce(MSE_w_time, op=torch.distributed.ReduceOp.SUM)

val_stats = {
    "MSE_w_time": MSE_w_time / self.num_validation_samples,
Collaborator


See greptile's comment. We should divide by the actual number of validation samples, especially because you have drop_last=True.

Author


The following code, added at lines 114-121, computes the actual number of validation samples for all possible cases:

self.num_validation_replicas = min(
    self.dist.world_size, cfg.training.num_validation_samples
)
self.num_validation_samples = (
    cfg.training.num_validation_samples
    // self.num_validation_replicas
    * self.num_validation_replicas
)

E.g., if the user provides world_size = 3 and num_validation_samples = 8,
then num_validation_samples = 6, and each of the 3 ranks loads 2 samples.
If the user provides world_size = 24 and num_validation_samples = 8,
then num_validation_samples = 8: only 8 ranks load 1 sample each, and the rest have no data for validation.

@mnabian
Collaborator

mnabian commented Nov 3, 2025

/blossom-ci

