Enable training on variable-sized batches #42
base: main
Conversation
def collate_fn(batch):
    try:
        return default_collate(batch)
    except Exception:
This catches all possible exceptions, which is very broad. Some of them could be real errors that should be thrown. Is there a particular class of exceptions that you want to catch here?
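If the intent is only to catch the stacking failure that default_collate hits on mismatched tensor sizes, one option is to narrow the handler to that class. A minimal sketch, assuming RuntimeError (what torch.stack raises on a size mismatch) is the failure mode of interest; any other exception would still propagate:

```python
from torch.utils.data import default_collate

def collate_fn(batch):
    try:
        return default_collate(batch)
    except RuntimeError:
        # Variable-sized samples: fall back to lists of tensors.
        x, y = zip(*batch)
        return list(x), list(y)
```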
out1 = self.conv1(x)
out = self.res_blocks(out1)
out2 = self.conv2(out)
out = torch.add(out1, out2)
Out of curiosity, how is this relevant to the variable-sized batch changes?
    loss = torch.stack(losses).mean()
else:
    pred = self(x)
    loss = self.loss_fn(pred, y)
This logic is a duplicate of that in training_step above. It would be good to create a separate function, e.g. _loss_calculation() that each of these functions call. That way, next time we update this code, we won't accidentally miss an update to one and cause them to drift.
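A sketch of that refactor, assuming the shared helper is named _loss_calculation and that list-typed batches mark the variable-sized case (the step signatures and logging calls below are illustrative, not copied from the diff):

```python
def _loss_calculation(self, x, y):
    if isinstance(x, list):
        # Variable-sized batch: score each sample separately, then average.
        losses = [self.loss_fn(self(xi), yi) for xi, yi in zip(x, y)]
        return torch.stack(losses).mean()
    pred = self(x)
    return self.loss_fn(pred, y)

def training_step(self, batch, batch_idx):
    x, y = batch
    loss = self._loss_calculation(x, y)
    self.log("train_loss", loss)
    return loss

def validation_step(self, batch, batch_idx):
    x, y = batch
    loss = self._loss_calculation(x, y)
    self.log("val_loss", loss)
    return loss
```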
    return default_collate(batch)
except Exception:
    # Separate and return as lists of tensors
    x, y = zip(*batch, strict=False)
Setting strict=False means that data will be silently dropped if x and y somehow are different lengths. Can you say more about why this is safe and preferred to strict=True?
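For reference, a toy example of what strict changes (Python 3.10+; the values are made up):

```python
good = [(1, 10), (2, 20)]
bad = [(1, 10), (2, 20, 999)]     # a malformed sample with an extra element

x, y = zip(*bad, strict=False)    # x == (1, 2), y == (10, 20); 999 is silently dropped
x, y = zip(*bad, strict=True)     # raises ValueError, surfacing the malformed sample
```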
from src.electrai.dataloader.registry import get_data
from src.electrai.lightning import LightningGenerator
from torch.utils.data import DataLoader
from torch.utils.data._utils.collate import default_collate
What's the motivation for importing default_collate from torch.utils.data._utils.collate rather than torch.utils.data?
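For context, default_collate has been re-exported from the public torch.utils.data namespace since PyTorch 1.11, so the private path can likely be avoided:

```python
from torch.utils.data import DataLoader, default_collate
```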
Motivation
The current training pipeline assumes that every sample in a batch has the same shape, so that default_collate can stack the batch into uniform tensors.
This restricts the ability to work with datasets that naturally contain heterogeneous shapes (e.g., grids at different resolutions), forcing users to manually pad or preprocess data.
Solution
This PR introduces support for training on heterogeneous batches by:
- adding a collate_fn that falls back to returning the batch as lists of tensors when default_collate cannot stack the samples, and
- updating training_step and validation_step to compute per-sample losses and average them when the batch arrives as lists (see the usage sketch below).
These updates allow the training loop to handle variable-sized input data seamlessly, reducing the need for intrusive or manual preprocessing.
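A usage sketch of how the pieces fit together (my_dataset and the batch size are placeholders; collate_fn is the function added in this PR):

```python
from torch.utils.data import DataLoader

loader = DataLoader(my_dataset, batch_size=8, collate_fn=collate_fn)
for x, y in loader:
    # x and y are stacked tensors when the samples share a shape,
    # and plain lists of per-sample tensors otherwise.
    ...
```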
Notes
Handling heterogeneous batches may introduce additional overhead compared to fully vectorized uniform-data training. However, enabling this flexibility is valuable for datasets where variable-sized samples are inherent to the problem rather than avoidable.