Skip to content

Commit 6586a75

Browse files
committed
testing fixes
1 parent 2f6b6e9 commit 6586a75

File tree

3 files changed

+29
-11
lines changed

3 files changed

+29
-11
lines changed

.vscode/settings.json

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
11
{
22
"cSpell.words": [
3+
"firstname",
34
"isready",
5+
"lastname",
6+
"mypy",
47
"postgresql",
8+
"pytest",
59
"PYTHONPATH"
610
]
711
}

train.py

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
"""Train ResNet based model."""
22

3-
from typing import Optional, Tuple, List, Any, TypedDict, Literal
3+
from typing import Optional, Tuple, List, Any, TypedDict, Literal, Callable
44
from collections import OrderedDict
55
import sys
66
import os
@@ -176,6 +176,7 @@ def __init__(self):
176176
Block(256, 512, first_stride=2),
177177
Block(512, 512),
178178
Block(512, 512),
179+
Block(512, 512),
179180
]
180181
)
181182

@@ -342,7 +343,7 @@ def one_epoch(
342343
dataset: DataLoader,
343344
opt: torch.optim.Optimizer,
344345
model: Model,
345-
loss_fn: nn.modules.loss._Loss,
346+
loss_fn: Callable,
346347
scheduler: torch.optim.lr_scheduler.LRScheduler,
347348
):
348349
"""Run one epoch on data"""
@@ -360,9 +361,10 @@ def one_epoch(
360361
if DEBUG:
361362
print("%smodel output: %s%s" % (Colors.YELLOWFG, str(output), Colors.ENDC))
362363

363-
loss = loss_fn(
364-
output.log(), label
365-
) # the loss function requires our data to be in log format
364+
# loss = loss_fn(
365+
# output.log(), label
366+
# ) # the loss function requires our data to be in log format
367+
loss = loss_fn(output, label)
366368

367369
loss.backward()
368370

@@ -393,6 +395,16 @@ def get_model(device: Device, state: Optional[StateDict] = None) -> Model:
393395

394396
return model
395397

398+
def loss_fn(output:torch.Tensor, label:torch.Tensor) -> torch.Tensor:
399+
"""
400+
Apply a loss function.
401+
402+
Given that the output and the label are vectors. We can find the distance between them.
403+
"""
404+
405+
distance = torch.sqrt(torch.sum((label - output)**2))
406+
return distance
407+
396408

397409
def train(
398410
device: Device,
@@ -447,7 +459,7 @@ def train(
447459
if checkpoint.was_provided():
448460
history.load()
449461

450-
loss_fn = nn.KLDivLoss(reduction="batchmean")
462+
#loss_fn = nn.KLDivLoss(reduction="batchmean")
451463
lr = 0.1
452464
opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
453465
scheduler = torch.optim.lr_scheduler.OneCycleLR(
@@ -489,7 +501,7 @@ def train(
489501
outputs.append(output)
490502
targets.append(torch.Tensor(label))
491503

492-
loss = loss_fn(output.log(), label)
504+
loss = loss_fn(output, label)
493505
running_loss += loss
494506

495507
avg_loss = running_loss / max_test
@@ -527,7 +539,7 @@ def train(
527539
outputs.append(output)
528540
targets.append(torch.Tensor(label))
529541

530-
loss = loss_fn(output.log(), label)
542+
loss = loss_fn(output, label)
531543
eval_loss += loss
532544

533545
avg_loss = eval_loss / max_eval
@@ -562,7 +574,7 @@ def get_device() -> Device:
562574
def get_RMSE(targets: List[torch.Tensor], outputs: List[torch.Tensor]):
563575
"""Root Mean Squared Error."""
564576
diff_sum = sum(
565-
[(torch.sum(target - output)) ** 2 for target, output in zip(targets, outputs)]
577+
[torch.sum((target - output) ** 2) for target, output in zip(targets, outputs)]
566578
)
567579
n = len(targets)
568580
return torch.sqrt((1 / n) * diff_sum)
@@ -572,6 +584,9 @@ def run_debug_experiemnt(args: Arguments):
572584
"""
573585
Run Manual tests.
574586
"""
587+
588+
print("%sRunning in DEBUG mode%s"%(Colors.YELLOWFG, Colors.ENDC))
589+
575590
device = get_device()
576591

577592
checkpoint = Checkpoint(args.checkpoint)
@@ -604,7 +619,6 @@ def setup_and_run_training(args: Arguments):
604619
checkpoint = Checkpoint(args.checkpoint)
605620
checkpoint.load()
606621

607-
# TODO: USE PYTORCH SCRIPTING (JIT)
608622
model = train(
609623
device,
610624
checkpoint,

utils/constants.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
DEBUG = os.environ.get("DEBUG") or False
66

7-
DEFAULT_EPOCHS = 80
7+
DEFAULT_EPOCHS = 100
88

99
DEFAULT_SHOTS = 1000
1010
DEFAULT_NUM_QUBITS = 5

0 commit comments

Comments
 (0)