Commit 3ba2042

Merge pull request #6 from Quantum-Accelerators/hanaol/mp-dataloader
Comment about the inefficiency of the dataloading procedure.
Parents: e3f2999 + c458ddd

File tree: 3 files changed, +11 −4 lines

src/electrai/dataloader/chgcar_read.py

Lines changed: 0 additions & 1 deletion
```diff
@@ -26,7 +26,6 @@ def __init__(
             train_fraction: fraction of the data used for training (0 to 1).
         '''
         self.data_dir = Path(data_dir)
-        print(self.data_dir)
         self.label_dir = Path(label_dir)
         self.map_dir = Path(map_dir)
         self.rho_type = rho_type
```

src/electrai/dataloader/dataset.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -2,6 +2,8 @@
 from torch.utils.data import DataLoader, Dataset
 import numpy as np
 
+# !!! Reading in all the data at once is probably not a good idea; we should just read in the filenames and have __getitem__ read each index instead.
+# The code should be updated as such.
 class RhoData(Dataset):
     def __init__(self, list_data, list_label, list_data_gridsizes, list_label_gridsizes, data_augmentation=True, downsample_data=1, downsample_label=1):
         '''
```
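The added comment is arguing for lazy loading: keep only filenames in memory and defer the per-sample I/O to `__getitem__`. Below is a minimal sketch of that pattern; the class name, directory arguments, and `.npy` storage format are illustrative assumptions, not the actual `RhoData` API.

```python
from pathlib import Path

import numpy as np
from torch.utils.data import Dataset


class LazyRhoData(Dataset):
    """Sketch of the lazy-loading variant proposed in the comment.

    Only file paths are stored in __init__; the actual reads happen in
    __getitem__, so memory use stays flat regardless of dataset size.
    """

    def __init__(self, data_dir, label_dir):
        # Reading just the filenames up front is cheap, even for huge datasets.
        self.data_paths = sorted(Path(data_dir).glob("*.npy"))
        self.label_paths = sorted(Path(label_dir).glob("*.npy"))

    def __len__(self):
        return len(self.data_paths)

    def __getitem__(self, idx):
        # Each sample is loaded from disk only when the DataLoader asks for it.
        X = np.load(self.data_paths[idx])
        y = np.load(self.label_paths[idx])
        return X, y
```

With `num_workers > 0` on the `DataLoader`, these per-item reads run in worker processes and overlap with training, so the change trades a large upfront load for small amortized I/O.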

src/electrai/scripts/train.py

Lines changed: 9 additions & 3 deletions
```diff
@@ -47,6 +47,8 @@
         downsample_label=cfg.downsample_label,
         data_augmentation=False)
 
+print('train_data: ', train_data)
+
 train_loader = DataLoader(train_data, batch_size=int(cfg.nbatch), shuffle=True)
 test_loader = DataLoader(test_data, batch_size=int(cfg.nbatch), shuffle=False)
 
@@ -62,8 +64,6 @@
         normalize=not cfg.normalize_label
     ).to(cfg.device)
 
-print("train chckpt")
-
 optimizer = torch.optim.Adam(model.parameters(), lr=float(cfg.lr), weight_decay=float(cfg.weight_decay))
 
 # Linear + Cosine scheduler
@@ -105,7 +105,13 @@ def loss_fn_sum(output, target):
     return loss
 
 optimizer.zero_grad()
-for batch, (X, y) in enumerate(dataloader):
+# print(dataloader)
+# for batch, (X, y) in enumerate(dataloader):
+for batch, cont in enumerate(dataloader):
+    print('batch: ', batch)
+    print('cont: ', cont)
+    # print('X: ', X.shape)
+    # print('y: ', y.shape)
     X, y = X.to(cfg.device), y.to(cfg.device)
     pred = model(X)
```
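Note that, as committed, the last hunk leaves the loop body broken: `X` and `y` are no longer bound anywhere once the original `for batch, (X, y) in ...` line is commented out, so `X.to(cfg.device)` will raise a `NameError` on the first iteration. A minimal repair, assuming the loader still yields `(data, label)` pairs, is to unpack `cont` before the device transfer:

```python
for batch, cont in enumerate(dataloader):
    X, y = cont  # unpack the (data, label) pair yielded by the DataLoader
    X, y = X.to(cfg.device), y.to(cfg.device)
    pred = model(X)
```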
