2 changes: 1 addition & 1 deletion src/electrai/configs/MP/config.yaml
@@ -27,7 +27,7 @@ model_precision: 32
use_checkpoint: True
save_every_epochs: 2
epochs: 10
nbatch: 1
nbatch: 4
lr: 0.01
weight_decay: 0.0
warmup_length: 1
7 changes: 5 additions & 2 deletions src/electrai/dataloader/mp.py
@@ -123,8 +123,11 @@ def rotate(d):
return [rotate(rotate(rotate(d))) for d in data_lst]

def __getitem__(self, idx: int):
data = self.read_data(self.data[idx][0])
label = self.read_data(self.data[idx][1])
data_path = self.data[idx][0]
label_path = self.data[idx][1]

data = self.read_data(data_path)
label = self.read_data(label_path)

if self.rho_type == "chgcar":
data = data.data["total"] / np.prod(data.data["total"].shape)
12 changes: 12 additions & 0 deletions src/electrai/entrypoints/train.py
@@ -11,10 +11,20 @@
from src.electrai.dataloader.registry import get_data
from src.electrai.lightning import LightningGenerator
from torch.utils.data import DataLoader
from torch.utils.data._utils.collate import default_collate
Collaborator: What's the motivation for importing default_collate from torch.utils.data._utils.collate rather than torch.utils.data?
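For reference, a minimal sketch of the public import path the reviewer is pointing to; to the best of my knowledge default_collate has been re-exported from torch.utils.data since around PyTorch 1.11, so the private module would not be needed:

from torch.utils.data import default_collate  # public re-export of the same function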


torch.backends.cudnn.conv.fp32_precision = "tf32"


def collate_fn(batch):
try:
return default_collate(batch)
except Exception:
Collaborator: This catches all possible exceptions, which is very broad. Some of them could be real errors that should be thrown. Is there a particular class of exceptions that you want to catch here?
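One way to narrow this, sketched under the assumption that the only failure we expect to recover from is default_collate being unable to stack variable-sized grids (which surfaces as a RuntimeError from torch.stack); anything else would still propagate:

def collate_fn(batch):
    try:
        return default_collate(batch)
    except RuntimeError:
        # Stacking failed (mismatched tensor shapes): fall back to lists of per-sample tensors.
        x, y = zip(*batch)
        return list(x), list(y)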

# Separate and return as lists of tensors
x, y = zip(*batch, strict=False)
Collaborator: Setting strict=False means that data will be silently dropped if x and y somehow are different lengths. Can you say more about why this is safe and preferred to strict=True?
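For comparison, a sketch of the stricter variant the reviewer is asking about (requires Python 3.10+, where zip gained the strict keyword); a malformed sample would then raise a ValueError instead of being silently truncated:

x, y = zip(*batch, strict=True)  # raises ValueError if the per-sample tuples have differing lengths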

return list(x), list(y)


def train(args):
# -----------------------------
# Load YAML config
@@ -35,12 +45,14 @@ def train(args):
batch_size=int(cfg.nbatch),
shuffle=True,
num_workers=cfg.num_workers,
collate_fn=collate_fn,
)
test_loader = DataLoader(
test_data,
batch_size=int(cfg.nbatch),
shuffle=False,
num_workers=cfg.num_workers,
collate_fn=collate_fn,
)

# -----------------------------
24 changes: 20 additions & 4 deletions src/electrai/lightning.py
@@ -27,8 +27,16 @@ def forward(self, x):

def training_step(self, batch):
x, y = batch
pred = self(x)
loss = self.loss_fn(pred, y)
if isinstance(x, list):
losses = []
for x_i, y_i in zip(x, y, strict=False):
pred = self(x_i.unsqueeze(0))
loss = self.loss_fn(pred, y_i.unsqueeze(0))
losses.append(loss)
loss = torch.stack(losses).mean()
else:
pred = self(x)
loss = self.loss_fn(pred, y)
self.log(
"train_loss",
loss,
@@ -41,8 +49,16 @@ def training_step(self, batch):

def validation_step(self, batch):
x, y = batch
pred = self(x)
loss = self.loss_fn(pred, y)
if isinstance(x, list):
losses = []
for x_i, y_i in zip(x, y, strict=False):
pred = self(x_i.unsqueeze(0))
loss = self.loss_fn(pred, y_i.unsqueeze(0))
losses.append(loss)
loss = torch.stack(losses).mean()
else:
pred = self(x)
loss = self.loss_fn(pred, y)
Collaborator: This logic is a duplicate of that in training_step above. It would be good to create a separate function, e.g. _loss_calculation(), that each of these functions calls. That way, next time we update this code, we won't accidentally miss an update to one and cause them to drift.
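A sketch of that refactor; _loss_calculation is a hypothetical helper name, and strict=True is shown only to make length mismatches loud rather than silent:

def _loss_calculation(self, x, y):
    # Shared by training_step and validation_step so the two cannot drift apart.
    if isinstance(x, list):
        # Variable-sized samples: run the model one sample at a time and average the losses.
        losses = []
        for x_i, y_i in zip(x, y, strict=True):
            pred = self(x_i.unsqueeze(0))
            losses.append(self.loss_fn(pred, y_i.unsqueeze(0)))
        return torch.stack(losses).mean()
    return self.loss_fn(self(x), y)

training_step and validation_step would then each reduce to loss = self._loss_calculation(x, y) followed by their respective self.log calls.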

self.log(
"val_loss", loss, prog_bar=True, on_step=True, on_epoch=True, sync_dist=True
)
9 changes: 9 additions & 0 deletions src/electrai/model/loss/charge.py
@@ -9,6 +9,15 @@ def __init__(self):
self.mae = torch.nn.L1Loss(reduction="none")

def forward(self, output, target):
if isinstance(output, torch.Tensor):
return self._forward(output, target)

losses = []
for out, tar in zip(output, target, strict=False):
losses.append(self._forward(out.unsqueeze(0), tar.unsqueeze(0)))
return torch.stack(losses).mean()

def _forward(self, output, target):
mae = self.mae(output, target)
nelec = torch.sum(target, axis=(-3, -2, -1))
mae = mae / nelec[..., None, None, None]
9 changes: 7 additions & 2 deletions src/electrai/model/srgan_layernorm_pbc.py
@@ -37,7 +37,6 @@ def __init__(self, in_features, K=3, use_checkpoint=True):

def forward(self, x):
if self.use_checkpoint and self.training:
# Use gradient checkpointing to save memory during training
return x + checkpoint(self.conv_block, x, use_reentrant=False)
else:
return x + self.conv_block(x)
@@ -148,12 +147,18 @@ def __init__(
)

def forward(self, x):
if isinstance(x, torch.Tensor):
return self._forward(x)
return [self._forward(xi.unsqueeze(0)).squeeze(0) for xi in x]

def _forward(self, x):
out1 = self.conv1(x)
out = self.res_blocks(out1)
out2 = self.conv2(out)
out = torch.add(out1, out2)
Collaborator: Out of curiosity, how is this relevant to the variable-sized batch changes?

out = out1 + out2
out = self.upsampling(out)
out = self.conv3(out)

if self.normalize:
upscale_factor = 8 ** (self.n_upscale_layers)
out = out / torch.sum(out, axis=(-3, -2, -1))[..., None, None, None]