
Commit 39a200b

Merge pull request #7 from Quantum-Accelerators/hanaol/mp-dataloader
Model is running without issues. Next step is to work on input density.
2 parents 3ba2042 + ce3ff13 commit 39a200b

File tree

4 files changed: +5 -10 lines changed

src/electrai/chk.pth

32.7 KB
Binary file not shown.

src/electrai/chk_1.pth

32.7 KB
Binary file not shown.

src/electrai/dataloader/dataset.py

Lines changed: 4 additions & 6 deletions
@@ -59,14 +59,12 @@ def rand_rotate(self, data_lst):
         return [rotate(rotate(rotate(d))) for d in data_lst]
 
     def __getitem__(self, idx):
-        rho1 = torch.tensor(
-            np.load(self.data[idx]), dtype=torch.float32)
-        size = np.loadtxt(self.data_gs[idx], dtype=int)
+        rho1 = torch.tensor(self.data[idx], dtype=torch.float32)
+        size = torch.tensor(self.data_gs[idx], dtype=int)
         rho1 = rho1.reshape(1, *size)
 
-        rho2 = torch.tensor(
-            np.load(self.label[idx]), dtype=torch.float32)
-        size = np.loadtxt(self.label_gs[idx], dtype=int)
+        rho2 = torch.tensor(self.label[idx], dtype=torch.float32)
+        size = torch.tensor(self.label_gs[idx], dtype=int)
         rho2 = rho2.reshape(1, *size)
 
         if self.da:
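
The dataset change drops the per-item np.load / np.loadtxt calls, which implies that self.data, self.data_gs, self.label, and self.label_gs now hold arrays already in memory rather than file paths. Below is a minimal sketch of a Dataset built around the updated __getitem__; the class name, the constructor signature, and the omission of the augmentation branch are illustrative assumptions, not the repository's actual code.

# Minimal sketch of an in-memory Dataset matching the updated __getitem__.
# The DensityDataset name and the constructor arguments are assumptions for
# illustration; only the attribute names and the reshape logic come from the
# diff above.
import torch
from torch.utils.data import Dataset


class DensityDataset(Dataset):
    def __init__(self, data, data_gs, label, label_gs, da=False):
        self.data = data          # input densities, one flat array per sample
        self.data_gs = data_gs    # grid sizes of the inputs, e.g. (nx, ny, nz)
        self.label = label        # target densities, one flat array per sample
        self.label_gs = label_gs  # grid sizes of the targets
        self.da = da              # whether to apply the rotation augmentation

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        rho1 = torch.tensor(self.data[idx], dtype=torch.float32)
        size = torch.tensor(self.data_gs[idx], dtype=int)
        rho1 = rho1.reshape(1, *size)

        rho2 = torch.tensor(self.label[idx], dtype=torch.float32)
        size = torch.tensor(self.label_gs[idx], dtype=int)
        rho2 = rho2.reshape(1, *size)

        # The real class applies rand_rotate here when self.da is set; that
        # branch is not shown in the diff, so it is omitted from this sketch.
        return rho1, rho2

Keeping the file I/O out of __getitem__ means the arrays are read once when the dataset is built, so each epoch only pays for the tensor conversion and reshape.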

src/electrai/scripts/train.py

Lines changed: 1 addition & 4 deletions
@@ -105,11 +105,8 @@ def loss_fn_sum(output, target):
         return loss
 
     optimizer.zero_grad()
-    # print(dataloader)
-    # for batch, (X, y) in enumerate(dataloader):
-    for batch, cont in enumerate(dataloader):
+    for batch, (X, y) in enumerate(dataloader):
         print('batch: ', batch)
-        print('cont: ', cont)
         # print('X: ', X.shape)
         # print('y: ', y.shape)
         X, y = X.to(cfg.device), y.to(cfg.device)
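
The loop change assumes each dataset item is a (rho1, rho2) pair, so PyTorch's default collate turns a batch into an (X, y) tuple that can be unpacked directly. A minimal sketch of the resulting training step is below; the forward pass, loss call, and backward/step lines are not part of the diff and stand in for the script's actual loop body, with model, loss_fn_sum, optimizer, and cfg.device assumed from context.

# Sketch of the per-batch step implied by the new loop header; everything
# after the .to(cfg.device) line is an assumed completion, not the script's
# verbatim code.
for batch, (X, y) in enumerate(dataloader):
    X, y = X.to(cfg.device), y.to(cfg.device)

    pred = model(X)                # predicted output density
    loss = loss_fn_sum(pred, y)    # summed loss against the target density

    optimizer.zero_grad()          # conventional per-batch placement
    loss.backward()
    optimizer.step()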
