11import functools
22import logging
3- import math
43import time
54import warnings
65import weakref
@@ -764,14 +763,13 @@ def load_state_dict(self, state_dict: Mapping) -> None:
764763
765764 @staticmethod
766765 def _is_done (state : State ) -> bool :
767- is_done_iters = state .max_iters is not None and state .iteration >= state .max_iters
768766 is_done_count = (
769767 state .epoch_length is not None
770768 and state .max_epochs is not None
771769 and state .iteration >= state .epoch_length * state .max_epochs
772770 )
773771 is_done_epochs = state .max_epochs is not None and state .epoch >= state .max_epochs
774- return is_done_iters or is_done_count or is_done_epochs
772+ return is_done_count or is_done_epochs
775773
776774 def set_data (self , data : Iterable | DataLoader ) -> None :
777775 """Method to set data. After calling the method the next batch passed to `processing_function` is
@@ -813,14 +811,13 @@ def run(
813811 self ,
814812 data : Iterable | None = None ,
815813 max_epochs : int | None = None ,
816- max_iters : int | None = None ,
817814 epoch_length : int | None = None ,
818815 ) -> State :
819816 """Runs the ``process_function`` over the passed data.
820817
821818 Engine has a state and the following logic is applied in this function:
822819
823- - At the first call, new state is defined by `max_epochs`, `max_iters`, ` epoch_length`, if provided.
820+ - At the first call, new state is defined by `max_epochs`, `epoch_length`, if provided.
824821 A timer for total and per-epoch time is initialized when Events.STARTED is handled.
825822 - If state is already defined such that there are iterations to run until `max_epochs` and no input arguments
826823 provided, state is kept and used in the function.
@@ -838,9 +835,6 @@ def run(
838835 `len(data)`. If `data` is an iterator and `epoch_length` is not set, then it will be automatically
839836 determined as the iteration on which data iterator raises `StopIteration`.
840837 This argument should not change if run is resuming from a state.
841- max_iters: Number of iterations to run for.
842- `max_iters` and `max_epochs` are mutually exclusive; only one of the two arguments should be provided.
843-
844838 Returns:
845839 State: output state.
846840
@@ -891,6 +885,8 @@ def switch_batch(engine):
891885
892886 if self .state .max_epochs is None or (self ._is_done (self .state ) and self ._internal_run_generator is None ):
893887 # Create new state
888+ if max_epochs is None :
889+ max_epochs = 1
894890 if epoch_length is None :
895891 if data is None :
896892 raise ValueError ("epoch_length should be provided if data is None" )
@@ -899,22 +895,9 @@ def switch_batch(engine):
899895 if epoch_length is not None and epoch_length < 1 :
900896 raise ValueError ("Input data has zero size. Please provide non-empty data" )
901897
902- if max_iters is None :
903- if max_epochs is None :
904- max_epochs = 1
905- else :
906- if max_epochs is not None :
907- raise ValueError (
908- "Arguments max_iters and max_epochs are mutually exclusive."
909- "Please provide only max_epochs or max_iters."
910- )
911- if epoch_length is not None :
912- max_epochs = math .ceil (max_iters / epoch_length )
913-
914898 self .state .iteration = 0
915899 self .state .epoch = 0
916900 self .state .max_epochs = max_epochs
917- self .state .max_iters = max_iters
918901 self .state .epoch_length = epoch_length
919902 # Reset generator if previously used
920903 self ._internal_run_generator = None
@@ -1114,18 +1097,12 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
11141097 if self .state .epoch_length is None :
11151098 # Define epoch length and stop the epoch
11161099 self .state .epoch_length = iter_counter
1117- if self .state .max_iters is not None :
1118- self .state .max_epochs = math .ceil (self .state .max_iters / self .state .epoch_length )
11191100 break
11201101
11211102 # Should exit while loop if we can not iterate
11221103 if should_exit :
1123- if not self ._is_done (self .state ):
1124- total_iters = (
1125- self .state .epoch_length * self .state .max_epochs
1126- if self .state .max_epochs is not None
1127- else self .state .max_iters
1128- )
1104+ if not self ._is_done (self .state ) and self .state .max_epochs is not None :
1105+ total_iters = self .state .epoch_length * self .state .max_epochs
11291106
11301107 warnings .warn (
11311108 "Data iterator can not provide data anymore but required total number of "
@@ -1156,10 +1133,6 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
11561133 if self .state .epoch_length is not None and iter_counter == self .state .epoch_length :
11571134 break
11581135
1159- if self .state .max_iters is not None and self .state .iteration == self .state .max_iters :
1160- self .should_terminate = True
1161- raise _EngineTerminateException ()
1162-
11631136 except _EngineTerminateSingleEpochException :
11641137 self ._fire_event (Events .TERMINATE_SINGLE_EPOCH , iter_counter = iter_counter )
11651138 self ._setup_dataloader_iter ()
@@ -1300,18 +1273,12 @@ def _run_once_on_dataset_legacy(self) -> float:
13001273 if self .state .epoch_length is None :
13011274 # Define epoch length and stop the epoch
13021275 self .state .epoch_length = iter_counter
1303- if self .state .max_iters is not None :
1304- self .state .max_epochs = math .ceil (self .state .max_iters / self .state .epoch_length )
13051276 break
13061277
13071278 # Should exit while loop if we can not iterate
13081279 if should_exit :
1309- if not self ._is_done (self .state ):
1310- total_iters = (
1311- self .state .epoch_length * self .state .max_epochs
1312- if self .state .max_epochs is not None
1313- else self .state .max_iters
1314- )
1280+ if not self ._is_done (self .state ) and self .state .max_epochs is not None :
1281+ total_iters = self .state .epoch_length * self .state .max_epochs
13151282
13161283 warnings .warn (
13171284 "Data iterator can not provide data anymore but required total number of "
@@ -1342,10 +1309,6 @@ def _run_once_on_dataset_legacy(self) -> float:
13421309 if self .state .epoch_length is not None and iter_counter == self .state .epoch_length :
13431310 break
13441311
1345- if self .state .max_iters is not None and self .state .iteration == self .state .max_iters :
1346- self .should_terminate = True
1347- raise _EngineTerminateException ()
1348-
13491312 except _EngineTerminateSingleEpochException :
13501313 self ._fire_event (Events .TERMINATE_SINGLE_EPOCH , iter_counter = iter_counter )
13511314 self ._setup_dataloader_iter ()
0 commit comments