11import functools
22import logging
3- import math
43import time
54import warnings
65import weakref
@@ -766,14 +765,13 @@ def load_state_dict(self, state_dict: Mapping) -> None:
766765
767766 @staticmethod
768767 def _is_done (state : State ) -> bool :
769- is_done_iters = state .max_iters is not None and state .iteration >= state .max_iters
770768 is_done_count = (
771769 state .epoch_length is not None
772770 and state .max_epochs is not None
773771 and state .iteration >= state .epoch_length * state .max_epochs
774772 )
775773 is_done_epochs = state .max_epochs is not None and state .epoch >= state .max_epochs
776- return is_done_iters or is_done_count or is_done_epochs
774+ return is_done_count or is_done_epochs
777775
778776 def set_data (self , data : Union [Iterable , DataLoader ]) -> None :
779777 """Method to set data. After calling the method the next batch passed to `processing_function` is
@@ -815,14 +813,13 @@ def run(
815813 self ,
816814 data : Optional [Iterable ] = None ,
817815 max_epochs : Optional [int ] = None ,
818- max_iters : Optional [int ] = None ,
819816 epoch_length : Optional [int ] = None ,
820817 ) -> State :
821818 """Runs the ``process_function`` over the passed data.
822819
823820 Engine has a state and the following logic is applied in this function:
824821
825- - At the first call, new state is defined by `max_epochs`, `max_iters`, ` epoch_length`, if provided.
822+ - At the first call, new state is defined by `max_epochs`, `epoch_length`, if provided.
826823 A timer for total and per-epoch time is initialized when Events.STARTED is handled.
827824 - If state is already defined such that there are iterations to run until `max_epochs` and no input arguments
828825 provided, state is kept and used in the function.
@@ -840,9 +837,6 @@ def run(
840837 `len(data)`. If `data` is an iterator and `epoch_length` is not set, then it will be automatically
841838 determined as the iteration on which data iterator raises `StopIteration`.
842839 This argument should not change if run is resuming from a state.
843- max_iters: Number of iterations to run for.
844- `max_iters` and `max_epochs` are mutually exclusive; only one of the two arguments should be provided.
845-
846840 Returns:
847841 State: output state.
848842
@@ -893,6 +887,8 @@ def switch_batch(engine):
893887
894888 if self .state .max_epochs is None or (self ._is_done (self .state ) and self ._internal_run_generator is None ):
895889 # Create new state
890+ if max_epochs is None :
891+ max_epochs = 1
896892 if epoch_length is None :
897893 if data is None :
898894 raise ValueError ("epoch_length should be provided if data is None" )
@@ -901,22 +897,9 @@ def switch_batch(engine):
901897 if epoch_length is not None and epoch_length < 1 :
902898 raise ValueError ("Input data has zero size. Please provide non-empty data" )
903899
904- if max_iters is None :
905- if max_epochs is None :
906- max_epochs = 1
907- else :
908- if max_epochs is not None :
909- raise ValueError (
910- "Arguments max_iters and max_epochs are mutually exclusive."
911- "Please provide only max_epochs or max_iters."
912- )
913- if epoch_length is not None :
914- max_epochs = math .ceil (max_iters / epoch_length )
915-
916900 self .state .iteration = 0
917901 self .state .epoch = 0
918902 self .state .max_epochs = max_epochs
919- self .state .max_iters = max_iters
920903 self .state .epoch_length = epoch_length
921904 # Reset generator if previously used
922905 self ._internal_run_generator = None
@@ -1117,18 +1100,12 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
11171100 if self .state .epoch_length is None :
11181101 # Define epoch length and stop the epoch
11191102 self .state .epoch_length = iter_counter
1120- if self .state .max_iters is not None :
1121- self .state .max_epochs = math .ceil (self .state .max_iters / self .state .epoch_length )
11221103 break
11231104
11241105 # Should exit while loop if we can not iterate
11251106 if should_exit :
1126- if not self ._is_done (self .state ):
1127- total_iters = (
1128- self .state .epoch_length * self .state .max_epochs
1129- if self .state .max_epochs is not None
1130- else self .state .max_iters
1131- )
1107+ if not self ._is_done (self .state ) and self .state .max_epochs is not None :
1108+ total_iters = self .state .epoch_length * self .state .max_epochs
11321109
11331110 warnings .warn (
11341111 "Data iterator can not provide data anymore but required total number of "
@@ -1159,10 +1136,6 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
11591136 if self .state .epoch_length is not None and iter_counter == self .state .epoch_length :
11601137 break
11611138
1162- if self .state .max_iters is not None and self .state .iteration == self .state .max_iters :
1163- self .should_terminate = True
1164- raise _EngineTerminateException ()
1165-
11661139 except _EngineTerminateSingleEpochException :
11671140 self ._fire_event (Events .TERMINATE_SINGLE_EPOCH , iter_counter = iter_counter )
11681141 self ._setup_dataloader_iter ()
@@ -1304,18 +1277,12 @@ def _run_once_on_dataset_legacy(self) -> float:
13041277 if self .state .epoch_length is None :
13051278 # Define epoch length and stop the epoch
13061279 self .state .epoch_length = iter_counter
1307- if self .state .max_iters is not None :
1308- self .state .max_epochs = math .ceil (self .state .max_iters / self .state .epoch_length )
13091280 break
13101281
13111282 # Should exit while loop if we can not iterate
13121283 if should_exit :
1313- if not self ._is_done (self .state ):
1314- total_iters = (
1315- self .state .epoch_length * self .state .max_epochs
1316- if self .state .max_epochs is not None
1317- else self .state .max_iters
1318- )
1284+ if not self ._is_done (self .state ) and self .state .max_epochs is not None :
1285+ total_iters = self .state .epoch_length * self .state .max_epochs
13191286
13201287 warnings .warn (
13211288 "Data iterator can not provide data anymore but required total number of "
@@ -1346,10 +1313,6 @@ def _run_once_on_dataset_legacy(self) -> float:
13461313 if self .state .epoch_length is not None and iter_counter == self .state .epoch_length :
13471314 break
13481315
1349- if self .state .max_iters is not None and self .state .iteration == self .state .max_iters :
1350- self .should_terminate = True
1351- raise _EngineTerminateException ()
1352-
13531316 except _EngineTerminateSingleEpochException :
13541317 self ._fire_event (Events .TERMINATE_SINGLE_EPOCH , iter_counter = iter_counter )
13551318 self ._setup_dataloader_iter ()
0 commit comments