11"""Train ResNet based model."""
22
3- from typing import Optional , Tuple , List , Any , TypedDict , Literal
3+ from typing import Optional , Tuple , List , Any , TypedDict , Literal , Callable
44from collections import OrderedDict
55import sys
66import os
@@ -176,6 +176,7 @@ def __init__(self):
176176 Block (256 , 512 , first_stride = 2 ),
177177 Block (512 , 512 ),
178178 Block (512 , 512 ),
179+ Block (512 , 512 ),
179180 ]
180181 )
181182
@@ -342,7 +343,7 @@ def one_epoch(
342343 dataset : DataLoader ,
343344 opt : torch .optim .Optimizer ,
344345 model : Model ,
345- loss_fn : nn . modules . loss . _Loss ,
346+ loss_fn : Callable ,
346347 scheduler : torch .optim .lr_scheduler .LRScheduler ,
347348):
348349 """Run one epoch on data"""
@@ -360,9 +361,10 @@ def one_epoch(
360361 if DEBUG :
361362 print ("%smodel output: %s%s" % (Colors .YELLOWFG , str (output ), Colors .ENDC ))
362363
363- loss = loss_fn (
364- output .log (), label
365- ) # the loss function requires our data to be in log format
364+ # loss = loss_fn(
365+ # output.log(), label
366+ # ) # the loss function requires our data to be in log format
367+ loss = loss_fn (output , label )
366368
367369 loss .backward ()
368370
@@ -393,6 +395,16 @@ def get_model(device: Device, state: Optional[StateDict] = None) -> Model:
393395
394396 return model
395397
398+ def loss_fn (output :torch .Tensor , label :torch .Tensor ) -> torch .Tensor :
399+ """
400+ Apply a loss function.
401+
402+ Given that the output and the label are vectors, we can find the Euclidean distance between them.
403+ """
404+
405+ distance = torch .sqrt (torch .sum ((label - output )** 2 ))
406+ return distance
407+
396408
397409def train (
398410 device : Device ,
@@ -447,7 +459,7 @@ def train(
447459 if checkpoint .was_provided ():
448460 history .load ()
449461
450- loss_fn = nn .KLDivLoss (reduction = "batchmean" )
462+ # loss_fn = nn.KLDivLoss(reduction="batchmean")
451463 lr = 0.1
452464 opt = torch .optim .SGD (model .parameters (), lr = lr , momentum = 0.9 , weight_decay = 5e-4 )
453465 scheduler = torch .optim .lr_scheduler .OneCycleLR (
@@ -489,7 +501,7 @@ def train(
489501 outputs .append (output )
490502 targets .append (torch .Tensor (label ))
491503
492- loss = loss_fn (output . log () , label )
504+ loss = loss_fn (output , label )
493505 running_loss += loss
494506
495507 avg_loss = running_loss / max_test
@@ -527,7 +539,7 @@ def train(
527539 outputs .append (output )
528540 targets .append (torch .Tensor (label ))
529541
530- loss = loss_fn (output . log () , label )
542+ loss = loss_fn (output , label )
531543 eval_loss += loss
532544
533545 avg_loss = eval_loss / max_eval
@@ -562,7 +574,7 @@ def get_device() -> Device:
562574def get_RMSE (targets : List [torch .Tensor ], outputs : List [torch .Tensor ]):
563575 """Root Mean Squared Error."""
564576 diff_sum = sum (
565- [( torch .sum (target - output )) ** 2 for target , output in zip (targets , outputs )]
577+ [torch .sum (( target - output ) ** 2 ) for target , output in zip (targets , outputs )]
566578 )
567579 n = len (targets )
568580 return torch .sqrt ((1 / n ) * diff_sum )
@@ -572,6 +584,9 @@ def run_debug_experiemnt(args: Arguments):
572584 """
573585 Run manual tests.
574586 """
587+
588+ print ("%sRunning in DEBUG mode%s" % (Colors .YELLOWFG , Colors .ENDC ))
589+
575590 device = get_device ()
576591
577592 checkpoint = Checkpoint (args .checkpoint )
@@ -604,7 +619,6 @@ def setup_and_run_training(args: Arguments):
604619 checkpoint = Checkpoint (args .checkpoint )
605620 checkpoint .load ()
606621
607- # TODO: USE PYTORCH SCRIPTING (JIT)
608622 model = train (
609623 device ,
610624 checkpoint ,
0 commit comments