@@ -310,7 +310,7 @@ def execute_something():
310310 except ValueError :
311311 _check_signature (handler , "handler" , * (event_args + args ), ** kwargs )
312312 self ._event_handlers [event_name ].append ((handler , args , kwargs ))
313- self .logger .debug (f"added handler for event { event_name } " )
313+ self .logger .debug (f"Added handler for event { event_name } " )
314314
315315 return RemovableEventHandle (event_name , handler , self )
316316
@@ -406,7 +406,7 @@ def _fire_event(self, event_name: Any, *event_args: Any, **event_kwargs: Any) ->
406406 **event_kwargs: optional keyword args to be passed to all handlers.
407407
408408 """
409- self .logger .debug (f"firing handlers for event { event_name } " )
409+ self .logger .debug (f"{ self . state . epoch } | { self . state . iteration } , Firing handlers for event { event_name } " )
410410 self .last_event_name = event_name
411411 for func , args , kwargs in self ._event_handlers [event_name ]:
412412 kwargs .update (event_kwargs )
@@ -720,6 +720,11 @@ def switch_batch(engine):
720720 if self .state .epoch_length is None and data is None :
721721 raise ValueError ("epoch_length should be provided if data is None" )
722722
723+ if self .should_terminate :
724+ # If the engine was terminated and is now resuming from the terminated state,
725+ # we need to initialize iter_counter to 0.
726+ self ._init_iter .append (0 )
727+
723728 self .state .dataloader = data
724729 return self ._internal_run ()
725730
@@ -750,12 +755,13 @@ def _setup_dataloader_iter(self) -> None:
750755
751756 def _setup_engine (self ) -> None :
752757 self ._setup_dataloader_iter ()
753- iteration = self .state .iteration
754758
755- # Below we define initial counter value for _run_once_on_dataset to measure a single epoch
756- if self .state .epoch_length is not None :
757- iteration %= self .state .epoch_length
758- self ._init_iter .append (iteration )
759+ if len (self ._init_iter ) == 0 :
760+ iteration = self .state .iteration
761+ # Below we define initial counter value for _run_once_on_dataset to measure a single epoch
762+ if self .state .epoch_length is not None :
763+ iteration %= self .state .epoch_length
764+ self ._init_iter .append (iteration )
759765
760766 def _internal_run (self ) -> State :
761767 self .should_terminate = self .should_terminate_single_epoch = False
@@ -826,6 +832,11 @@ def _run_once_on_dataset(self) -> float:
826832 start_time = time .time ()
827833
828834 # We need to set up iter_counter > 0 if we resume from an iteration
835+ if len (self ._init_iter ) > 1 :
836+ raise RuntimeError (
837+ "Internal error, len(self._init_iter) should 0 or 1, "
838+ f"but got: { len (self ._init_iter )} , { self ._init_iter } "
839+ )
829840 iter_counter = self ._init_iter .pop () if len (self ._init_iter ) > 0 else 0
830841 should_exit = False
831842 try :
0 commit comments