From be62fbbb6a06cf6879d07c25f55959b1519b91e9 Mon Sep 17 00:00:00 2001
From: Hananeh Oliaei
Date: Thu, 11 Dec 2025 14:45:32 -0500
Subject: [PATCH 1/2] Add support for training on batches with variable data sizes

---
 src/electrai/configs/MP/config.yaml |  4 ++--
 src/electrai/dataloader/mp.py       |  7 +++++--
 src/electrai/entrypoints/train.py   | 12 ++++++++++++
 src/electrai/lightning.py           | 24 ++++++++++++++++++++----
 4 files changed, 39 insertions(+), 8 deletions(-)

diff --git a/src/electrai/configs/MP/config.yaml b/src/electrai/configs/MP/config.yaml
index 04730ee..1e6b355 100644
--- a/src/electrai/configs/MP/config.yaml
+++ b/src/electrai/configs/MP/config.yaml
@@ -27,7 +27,7 @@ model_precision: 32
 use_checkpoint: True
 save_every_epochs: 2
 epochs: 10
-nbatch: 1
+nbatch: 4
 lr: 0.01
 weight_decay: 0.0
 warmup_length: 1
@@ -35,4 +35,4 @@ warmup_length: 1
 # Weights and biases
 wandb_mode: offline
 wb_pname: uniform-data-training
-entity: PrinceOA
+entity: hanaoli
diff --git a/src/electrai/dataloader/mp.py b/src/electrai/dataloader/mp.py
index 8684a26..1b0e7ca 100644
--- a/src/electrai/dataloader/mp.py
+++ b/src/electrai/dataloader/mp.py
@@ -123,8 +123,11 @@ def rotate(d):
         return [rotate(rotate(rotate(d))) for d in data_lst]
 
     def __getitem__(self, idx: int):
-        data = self.read_data(self.data[idx][0])
-        label = self.read_data(self.data[idx][1])
+        data_path = self.data[idx][0]
+        label_path = self.data[idx][1]
+
+        data = self.read_data(data_path)
+        label = self.read_data(label_path)
 
         if self.rho_type == "chgcar":
             data = data.data["total"] / np.prod(data.data["total"].shape)
diff --git a/src/electrai/entrypoints/train.py b/src/electrai/entrypoints/train.py
index e5e58d1..2c97ed3 100644
--- a/src/electrai/entrypoints/train.py
+++ b/src/electrai/entrypoints/train.py
@@ -11,10 +11,20 @@
 from src.electrai.dataloader.registry import get_data
 from src.electrai.lightning import LightningGenerator
 from torch.utils.data import DataLoader
+from torch.utils.data._utils.collate import default_collate
 
 torch.backends.cudnn.conv.fp32_precision = "tf32"
 
 
+def collate_fn(batch):
+    try:
+        return default_collate(batch)
+    except Exception:
+        # Separate and return as lists of tensors
+        x, y = zip(*batch, strict=False)
+        return list(x), list(y)
+
+
 def train(args):
     # -----------------------------
     # Load YAML config
@@ -35,12 +45,14 @@ def train(args):
         batch_size=int(cfg.nbatch),
         shuffle=True,
         num_workers=cfg.num_workers,
+        collate_fn=collate_fn,
     )
     test_loader = DataLoader(
         test_data,
         batch_size=int(cfg.nbatch),
         shuffle=False,
         num_workers=cfg.num_workers,
+        collate_fn=collate_fn,
     )
 
     # -----------------------------
diff --git a/src/electrai/lightning.py b/src/electrai/lightning.py
index a46c7d5..35d5796 100644
--- a/src/electrai/lightning.py
+++ b/src/electrai/lightning.py
@@ -27,8 +27,16 @@ def forward(self, x):
 
     def training_step(self, batch):
         x, y = batch
-        pred = self(x)
-        loss = self.loss_fn(pred, y)
+        if isinstance(x, list):
+            losses = []
+            for x_i, y_i in zip(x, y, strict=False):
+                pred = self(x_i.unsqueeze(0))
+                loss = self.loss_fn(pred, y_i.unsqueeze(0))
+                losses.append(loss)
+            loss = torch.stack(losses).mean()
+        else:
+            pred = self(x)
+            loss = self.loss_fn(pred, y)
         self.log(
             "train_loss",
             loss,
@@ -41,8 +49,16 @@ def training_step(self, batch):
 
     def validation_step(self, batch):
         x, y = batch
-        pred = self(x)
-        loss = self.loss_fn(pred, y)
+        if isinstance(x, list):
+            losses = []
+            for x_i, y_i in zip(x, y, strict=False):
+                pred = self(x_i.unsqueeze(0))
+                loss = self.loss_fn(pred, y_i.unsqueeze(0))
+                losses.append(loss)
+            loss = torch.stack(losses).mean()
+        else:
+            pred = self(x)
+            loss = self.loss_fn(pred, y)
         self.log(
             "val_loss", loss, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True
         )

From 4489be96a959acbd0124da0fa56fb875d06528b2 Mon Sep 17 00:00:00 2001
From: Hananeh Oliaei
Date: Thu, 11 Dec 2025 15:45:28 -0500
Subject: [PATCH 2/2] Support heterogeneous batches in loss function and model forward pass

---
 src/electrai/configs/MP/config.yaml        | 2 +-
 src/electrai/model/loss/charge.py          | 9 +++++++++
 src/electrai/model/srgan_layernorm_pbc.py  | 9 +++++++--
 3 files changed, 17 insertions(+), 3 deletions(-)

diff --git a/src/electrai/configs/MP/config.yaml b/src/electrai/configs/MP/config.yaml
index 1e6b355..0bf5148 100644
--- a/src/electrai/configs/MP/config.yaml
+++ b/src/electrai/configs/MP/config.yaml
@@ -35,4 +35,4 @@ warmup_length: 1
 # Weights and biases
 wandb_mode: offline
 wb_pname: uniform-data-training
-entity: hanaoli
+entity: PrinceOA
diff --git a/src/electrai/model/loss/charge.py b/src/electrai/model/loss/charge.py
index 09c2554..dfaff47 100644
--- a/src/electrai/model/loss/charge.py
+++ b/src/electrai/model/loss/charge.py
@@ -9,6 +9,15 @@ def __init__(self):
         self.mae = torch.nn.L1Loss(reduction="none")
 
     def forward(self, output, target):
+        if isinstance(output, torch.Tensor):
+            return self._forward(output, target)
+
+        losses = []
+        for out, tar in zip(output, target, strict=False):
+            losses.append(self._forward(out.unsqueeze(0), tar.unsqueeze(0)))
+        return torch.stack(losses).mean()
+
+    def _forward(self, output, target):
         mae = self.mae(output, target)
         nelec = torch.sum(target, axis=(-3, -2, -1))
         mae = mae / nelec[..., None, None, None]
diff --git a/src/electrai/model/srgan_layernorm_pbc.py b/src/electrai/model/srgan_layernorm_pbc.py
index ccd1b13..24f1c09 100644
--- a/src/electrai/model/srgan_layernorm_pbc.py
+++ b/src/electrai/model/srgan_layernorm_pbc.py
@@ -37,7 +37,6 @@ def __init__(self, in_features, K=3, use_checkpoint=True):
 
     def forward(self, x):
         if self.use_checkpoint and self.training:
-            # Use gradient checkpointing to save memory during training
             return x + checkpoint(self.conv_block, x, use_reentrant=False)
         else:
             return x + self.conv_block(x)
@@ -148,12 +147,18 @@ def __init__(
         )
 
     def forward(self, x):
+        if isinstance(x, torch.Tensor):
+            return self._forward(x)
+        return [self._forward(xi.unsqueeze(0)).squeeze(0) for xi in x]
+
+    def _forward(self, x):
         out1 = self.conv1(x)
         out = self.res_blocks(out1)
         out2 = self.conv2(out)
-        out = torch.add(out1, out2)
+        out = out1 + out2
         out = self.upsampling(out)
         out = self.conv3(out)
+
         if self.normalize:
             upscale_factor = 8 ** (self.n_upscale_layers)
             out = out / torch.sum(out, axis=(-3, -2, -1))[..., None, None, None]
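
Below the series, a minimal standalone sketch (not applied by the patches; the grid sizes are made up for illustration) of how the collate_fn added in src/electrai/entrypoints/train.py is expected to behave for uniform versus variable-size samples:

import torch
from torch.utils.data._utils.collate import default_collate

def collate_fn(batch):
    # Same logic as in train.py: stack when all samples share a shape,
    # otherwise fall back to lists of per-sample tensors.
    try:
        return default_collate(batch)
    except Exception:
        x, y = zip(*batch, strict=False)
        return list(x), list(y)

# Uniform grids: default_collate stacks samples into regular batched tensors.
uniform = [(torch.zeros(1, 8, 8, 8), torch.zeros(1, 16, 16, 16)) for _ in range(4)]
x, y = collate_fn(uniform)
print(x.shape)  # torch.Size([4, 1, 8, 8, 8])

# Mixed grid sizes: collation falls back to lists, and the training/validation
# steps and the charge loss then loop over samples one at a time via unsqueeze(0).
mixed = [
    (torch.zeros(1, 8, 8, 8), torch.zeros(1, 16, 16, 16)),
    (torch.zeros(1, 12, 12, 12), torch.zeros(1, 24, 24, 24)),
]
x, y = collate_fn(mixed)
print(type(x), x[0].shape, x[1].shape)  # <class 'list'> with per-sample shapes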