Enable training on variable-sized batches #42
base: main
```diff
@@ -11,10 +11,20 @@
 from src.electrai.dataloader.registry import get_data
 from src.electrai.lightning import LightningGenerator
 from torch.utils.data import DataLoader
+from torch.utils.data._utils.collate import default_collate

 torch.backends.cudnn.conv.fp32_precision = "tf32"


+def collate_fn(batch):
+    try:
+        return default_collate(batch)
+    except Exception:
```
**Collaborator** (on `except Exception:`): This catches all possible exceptions, which is very broad. Some of them could be real errors that should be thrown. Is there a particular class of exceptions that you want to catch here?
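If the intent is only to handle the size-mismatch case, one option (a sketch under the assumption that the failure mode of interest is the `RuntimeError` raised by `torch.stack` inside `default_collate` when sample shapes differ) is to narrow the handler:

```python
from torch.utils.data._utils.collate import default_collate


def collate_fn(batch):
    try:
        return default_collate(batch)
    except RuntimeError:
        # Assumed narrowing (not in the PR): only the shape-mismatch failure
        # from torch.stack is swallowed; any other error still propagates.
        x, y = zip(*batch, strict=False)
        return list(x), list(y)
```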
```diff
+        # Separate and return as lists of tensors
+        x, y = zip(*batch, strict=False)
```
**Collaborator** (on `x, y = zip(*batch, strict=False)`): Setting …
```diff
+        return list(x), list(y)
+
+
 def train(args):
     # -----------------------------
     # Load YAML config
```
```diff
@@ -35,12 +45,14 @@ def train(args):
         batch_size=int(cfg.nbatch),
         shuffle=True,
         num_workers=cfg.num_workers,
+        collate_fn=collate_fn,
     )
     test_loader = DataLoader(
         test_data,
         batch_size=int(cfg.nbatch),
         shuffle=False,
         num_workers=cfg.num_workers,
+        collate_fn=collate_fn,
     )

     # -----------------------------
```
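As a quick illustration of what the loaders now produce (toy dataset and shapes are assumptions, and `collate_fn` is the function added above), batches whose samples share a shape come back stacked, while mismatched samples fall back to plain lists:

```python
import torch
from torch.utils.data import DataLoader

# Toy samples with different spatial sizes (illustrative only).
samples = [
    (torch.randn(1, 16 + 4 * i, 16 + 4 * i), torch.randn(1, 32 + 8 * i, 32 + 8 * i))
    for i in range(4)
]

loader = DataLoader(samples, batch_size=2, collate_fn=collate_fn)
x, y = next(iter(loader))
print(type(x), [t.shape for t in x])  # list of per-sample tensors, not a stacked batch
```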
```diff
@@ -27,8 +27,16 @@ def forward(self, x):

     def training_step(self, batch):
         x, y = batch
-        pred = self(x)
-        loss = self.loss_fn(pred, y)
+        if isinstance(x, list):
+            losses = []
+            for x_i, y_i in zip(x, y, strict=False):
+                pred = self(x_i.unsqueeze(0))
+                loss = self.loss_fn(pred, y_i.unsqueeze(0))
+                losses.append(loss)
+            loss = torch.stack(losses).mean()
+        else:
+            pred = self(x)
+            loss = self.loss_fn(pred, y)
         self.log(
             "train_loss",
             loss,
```
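As a side note on the per-sample loop (an illustrative check with a stand-in model and an assumed mean-reduction loss, not code from the PR), averaging the per-sample losses matches the usual stacked-batch loss whenever the samples happen to share a shape, so the list path only changes how the batch is fed, not the objective:

```python
import torch

loss_fn = torch.nn.MSELoss()   # assumed mean-reduction loss
model = torch.nn.Identity()    # stand-in for self(...)

x = [torch.randn(3, 8, 8) for _ in range(4)]
y = [torch.randn(3, 8, 8) for _ in range(4)]

per_sample = torch.stack(
    [loss_fn(model(xi.unsqueeze(0)), yi.unsqueeze(0)) for xi, yi in zip(x, y)]
).mean()
batched = loss_fn(model(torch.stack(x)), torch.stack(y))
assert torch.allclose(per_sample, batched)
```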
```diff
@@ -41,8 +49,16 @@ def training_step(self, batch):

     def validation_step(self, batch):
         x, y = batch
-        pred = self(x)
-        loss = self.loss_fn(pred, y)
+        if isinstance(x, list):
+            losses = []
+            for x_i, y_i in zip(x, y, strict=False):
+                pred = self(x_i.unsqueeze(0))
+                loss = self.loss_fn(pred, y_i.unsqueeze(0))
+                losses.append(loss)
+            loss = torch.stack(losses).mean()
+        else:
+            pred = self(x)
+            loss = self.loss_fn(pred, y)
```
**Collaborator** (on the per-sample loss loop in `validation_step`): This logic is a duplicate of that in `training_step`.
```diff
         self.log(
             "val_loss", loss, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True
         )
```
```diff
@@ -37,7 +37,6 @@ def __init__(self, in_features, K=3, use_checkpoint=True):

     def forward(self, x):
         if self.use_checkpoint and self.training:
             # Use gradient checkpointing to save memory during training
             return x + checkpoint(self.conv_block, x, use_reentrant=False)
         else:
             return x + self.conv_block(x)

@@ -148,12 +147,18 @@ def __init__(
         )

     def forward(self, x):
+        if isinstance(x, torch.Tensor):
+            return self._forward(x)
+        return [self._forward(xi.unsqueeze(0)).squeeze(0) for xi in x]
+
+    def _forward(self, x):
         out1 = self.conv1(x)
         out = self.res_blocks(out1)
         out2 = self.conv2(out)
-        out = torch.add(out1, out2)
```
**Collaborator** (on replacing `out = torch.add(out1, out2)` with `out = out1 + out2`): Ooc how is this relevant to the variable-sized batch changes?
```diff
+        out = out1 + out2
         out = self.upsampling(out)
         out = self.conv3(out)

         if self.normalize:
             upscale_factor = 8 ** (self.n_upscale_layers)
             out = out / torch.sum(out, axis=(-3, -2, -1))[..., None, None, None]
```
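For what it's worth, a stand-in module (not the repo's generator; channel counts and sizes are made up) shows how this dispatch behaves: a tensor input takes the single batched path, while a list of differently sized samples is processed one at a time and returned as a list:

```python
import torch
import torch.nn as nn


class VariableSizeWrapper(nn.Module):
    """Stand-in that mimics the forward/_forward dispatch from the diff."""

    def __init__(self):
        super().__init__()
        self.body = nn.Conv2d(1, 1, kernel_size=3, padding=1)

    def forward(self, x):
        if isinstance(x, torch.Tensor):
            return self._forward(x)
        return [self._forward(xi.unsqueeze(0)).squeeze(0) for xi in x]

    def _forward(self, x):
        return self.body(x)


m = VariableSizeWrapper()
batched = m(torch.randn(4, 1, 32, 32))                        # Tensor in, Tensor out
varied = m([torch.randn(1, 24, 24), torch.randn(1, 40, 40)])  # list in, list out
print(batched.shape, [v.shape for v in varied])
```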
**Comment** (on the new `default_collate` import): What's the motivation for importing `default_collate` from `torch.utils.data._utils.collate` rather than `torch.utils.data`?
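For reference, `default_collate` is also exposed through the public `torch.utils.data` namespace in recent PyTorch releases (1.11+), so the same behaviour is available without reaching into the private module:

```python
import torch
from torch.utils.data import default_collate  # public re-export of the same function

batch = [(torch.zeros(1, 8, 8), torch.zeros(1, 16, 16)) for _ in range(2)]
x, y = default_collate(batch)  # stacks each field along a new batch dimension
print(x.shape, y.shape)        # torch.Size([2, 1, 8, 8]) torch.Size([2, 1, 16, 16])
```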