
Conversation

@hanaol (Collaborator) commented Dec 11, 2025

Motivation

The current training pipeline assumes either

  • uniformly sized data across all samples in a batch, or
  • single-sample processing when dimensions vary

This restricts the ability to work with datasets that naturally contain heterogeneous shapes (e.g., grids at different resolutions), forcing users to manually pad or preprocess data.

Solution

This PR introduces support for training on heterogeneous batches by:

  • Implementing a dynamic collate mechanism that groups and prepares samples of varying sizes at runtime.
  • Updating the dataloader and relevant training components (e.g., model forward pass or loss function, where applicable) to handle variable-shaped inputs safely.
  • Ensuring training remains stable and consistent even when batch elements differ in their spatial dimensions.

These updates allow the training loop to handle variable-sized input data seamlessly, reducing the need for intrusive or manual preprocessing.
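
As a rough sketch of how these pieces can fit together (the fallback's return structure and the DataLoader wiring are illustrative assumptions, and the import assumes a PyTorch version that exports default_collate publicly; the actual snippets are discussed in the review comments below):

import torch
from torch.utils.data import DataLoader, default_collate

def collate_fn(batch):
    """Stack uniformly shaped samples; otherwise return lists of per-sample tensors."""
    try:
        # Fast path: every sample in the batch has the same shape.
        return default_collate(batch)
    except Exception:
        # Shapes differ, so stacking fails: keep each sample as its own tensor.
        # (The breadth of this except clause is discussed in the review below.)
        x, y = zip(*batch)
        return list(x), list(y)

# Toy example: two grids at different resolutions in the same batch.
samples = [(torch.zeros(1, 32, 32), torch.zeros(1, 32, 32)),
           (torch.zeros(1, 48, 48), torch.zeros(1, 48, 48))]
loader = DataLoader(samples, batch_size=2, collate_fn=collate_fn)
x, y = next(iter(loader))  # x and y are lists of differently sized tensors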

Notes

Handling heterogeneous batches may introduce additional overhead compared to fully vectorized uniform-data training. However, enabling this flexibility is valuable for datasets where variable-sized samples are inherent to the problem rather than avoidable.

@hanaol requested a review from forklady42 on Dec 11, 2025 at 20:53
def collate_fn(batch):
    try:
        return default_collate(batch)
    except Exception:

Collaborator:

This catches all possible exceptions, which is very broad. Some of them could be real errors that should propagate. Is there a particular class of exceptions that you want to catch here?
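
For illustration (not the PR's code): default_collate fails with a RuntimeError when torch.stack cannot stack tensors of unequal size, so catching that narrower class would let unrelated errors surface, assuming the size mismatch is the only case the fallback is meant to handle:

from torch.utils.data import default_collate

def collate_fn(batch):
    try:
        return default_collate(batch)
    except RuntimeError:
        # Raised by torch.stack inside default_collate when sample shapes
        # differ; anything else (e.g. a TypeError from an unsupported
        # element type) still propagates.
        x, y = zip(*batch)
        return list(x), list(y)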

out1 = self.conv1(x)
out = self.res_blocks(out1)
out2 = self.conv2(out)
out = torch.add(out1, out2)

Collaborator:

Out of curiosity, how is this relevant to the variable-sized batch changes?

    loss = torch.stack(losses).mean()
else:
    pred = self(x)
    loss = self.loss_fn(pred, y)

Collaborator:

This logic is a duplicate of that in training_step above. It would be good to create a separate function, e.g. _loss_calculation(), that each of these functions calls. That way, the next time we update this code, we won't accidentally update one and miss the other, causing them to drift.
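
Something along these lines, as a sketch (the isinstance check for spotting a list-collated batch is an assumption about how the fallback batches arrive, and torch is assumed to be imported in the module):

def _loss_calculation(self, x, y):
    """Shared by training_step and validation_step so the two paths stay in sync."""
    if isinstance(x, (list, tuple)):
        # Heterogeneous batch from the fallback collate: score each sample
        # individually and average the per-sample losses.
        losses = [self.loss_fn(self(xi), yi) for xi, yi in zip(x, y)]
        return torch.stack(losses).mean()
    # Homogeneous batch: one vectorized forward pass.
    pred = self(x)
    return self.loss_fn(pred, y)

Both training_step and validation_step could then reduce to loss = self._loss_calculation(x, y).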

    return default_collate(batch)
except Exception:
    # Separate and return as lists of tensors
    x, y = zip(*batch, strict=False)

Collaborator:

Setting strict=False means that data will be silently dropped if x and y somehow end up with different lengths. Can you say more about why this is safe and preferable to strict=True?
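
For reference, this is roughly the difference on mismatched lengths (the strict= keyword requires Python 3.10+):

pairs = zip([1, 2, 3], ["a", "b"], strict=False)
print(list(pairs))   # [(1, 'a'), (2, 'b')]: the extra element 3 is silently dropped

pairs = zip([1, 2, 3], ["a", "b"], strict=True)
print(list(pairs))   # raises ValueError because the iterables have different lengths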

from src.electrai.dataloader.registry import get_data
from src.electrai.lightning import LightningGenerator
from torch.utils.data import DataLoader
from torch.utils.data._utils.collate import default_collate

Collaborator:

What's the motivation for importing default_collate from torch.utils.data._utils.collate rather than torch.utils.data?
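
For reference, recent PyTorch releases (1.11 and later, if I remember correctly) re-export it publicly, so the private path shouldn't be needed:

from torch.utils.data import default_collate  # public API, same function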
