Skip to content

Commit d3f05b7

Browse files
committed
Revert "Adding max_iters as an optional arg in Engine run (#1381)"
1 parent 2782a00 commit d3f05b7

5 files changed

Lines changed: 9 additions & 109 deletions

File tree

ignite/engine/engine.py

Lines changed: 8 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
import functools
22
import logging
3-
import math
43
import time
54
import warnings
65
import weakref
@@ -764,14 +763,13 @@ def load_state_dict(self, state_dict: Mapping) -> None:
764763

765764
@staticmethod
766765
def _is_done(state: State) -> bool:
767-
is_done_iters = state.max_iters is not None and state.iteration >= state.max_iters
768766
is_done_count = (
769767
state.epoch_length is not None
770768
and state.max_epochs is not None
771769
and state.iteration >= state.epoch_length * state.max_epochs
772770
)
773771
is_done_epochs = state.max_epochs is not None and state.epoch >= state.max_epochs
774-
return is_done_iters or is_done_count or is_done_epochs
772+
return is_done_count or is_done_epochs
775773

776774
def set_data(self, data: Iterable | DataLoader) -> None:
777775
"""Method to set data. After calling the method the next batch passed to `processing_function` is
@@ -813,14 +811,13 @@ def run(
813811
self,
814812
data: Iterable | None = None,
815813
max_epochs: int | None = None,
816-
max_iters: int | None = None,
817814
epoch_length: int | None = None,
818815
) -> State:
819816
"""Runs the ``process_function`` over the passed data.
820817
821818
Engine has a state and the following logic is applied in this function:
822819
823-
- At the first call, new state is defined by `max_epochs`, `max_iters`, `epoch_length`, if provided.
820+
- At the first call, new state is defined by `max_epochs`, `epoch_length`, if provided.
824821
A timer for total and per-epoch time is initialized when Events.STARTED is handled.
825822
- If state is already defined such that there are iterations to run until `max_epochs` and no input arguments
826823
provided, state is kept and used in the function.
@@ -838,9 +835,6 @@ def run(
838835
`len(data)`. If `data` is an iterator and `epoch_length` is not set, then it will be automatically
839836
determined as the iteration on which data iterator raises `StopIteration`.
840837
This argument should not change if run is resuming from a state.
841-
max_iters: Number of iterations to run for.
842-
`max_iters` and `max_epochs` are mutually exclusive; only one of the two arguments should be provided.
843-
844838
Returns:
845839
State: output state.
846840
@@ -891,6 +885,8 @@ def switch_batch(engine):
891885

892886
if self.state.max_epochs is None or (self._is_done(self.state) and self._internal_run_generator is None):
893887
# Create new state
888+
if max_epochs is None:
889+
max_epochs = 1
894890
if epoch_length is None:
895891
if data is None:
896892
raise ValueError("epoch_length should be provided if data is None")
@@ -899,22 +895,9 @@ def switch_batch(engine):
899895
if epoch_length is not None and epoch_length < 1:
900896
raise ValueError("Input data has zero size. Please provide non-empty data")
901897

902-
if max_iters is None:
903-
if max_epochs is None:
904-
max_epochs = 1
905-
else:
906-
if max_epochs is not None:
907-
raise ValueError(
908-
"Arguments max_iters and max_epochs are mutually exclusive."
909-
"Please provide only max_epochs or max_iters."
910-
)
911-
if epoch_length is not None:
912-
max_epochs = math.ceil(max_iters / epoch_length)
913-
914898
self.state.iteration = 0
915899
self.state.epoch = 0
916900
self.state.max_epochs = max_epochs
917-
self.state.max_iters = max_iters
918901
self.state.epoch_length = epoch_length
919902
# Reset generator if previously used
920903
self._internal_run_generator = None
@@ -1114,18 +1097,12 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
11141097
if self.state.epoch_length is None:
11151098
# Define epoch length and stop the epoch
11161099
self.state.epoch_length = iter_counter
1117-
if self.state.max_iters is not None:
1118-
self.state.max_epochs = math.ceil(self.state.max_iters / self.state.epoch_length)
11191100
break
11201101

11211102
# Should exit while loop if we can not iterate
11221103
if should_exit:
1123-
if not self._is_done(self.state):
1124-
total_iters = (
1125-
self.state.epoch_length * self.state.max_epochs
1126-
if self.state.max_epochs is not None
1127-
else self.state.max_iters
1128-
)
1104+
if not self._is_done(self.state) and self.state.max_epochs is not None:
1105+
total_iters = self.state.epoch_length * self.state.max_epochs
11291106

11301107
warnings.warn(
11311108
"Data iterator can not provide data anymore but required total number of "
@@ -1156,10 +1133,6 @@ def _run_once_on_dataset_as_gen(self) -> Generator[State, None, float]:
11561133
if self.state.epoch_length is not None and iter_counter == self.state.epoch_length:
11571134
break
11581135

1159-
if self.state.max_iters is not None and self.state.iteration == self.state.max_iters:
1160-
self.should_terminate = True
1161-
raise _EngineTerminateException()
1162-
11631136
except _EngineTerminateSingleEpochException:
11641137
self._fire_event(Events.TERMINATE_SINGLE_EPOCH, iter_counter=iter_counter)
11651138
self._setup_dataloader_iter()
@@ -1300,18 +1273,12 @@ def _run_once_on_dataset_legacy(self) -> float:
13001273
if self.state.epoch_length is None:
13011274
# Define epoch length and stop the epoch
13021275
self.state.epoch_length = iter_counter
1303-
if self.state.max_iters is not None:
1304-
self.state.max_epochs = math.ceil(self.state.max_iters / self.state.epoch_length)
13051276
break
13061277

13071278
# Should exit while loop if we can not iterate
13081279
if should_exit:
1309-
if not self._is_done(self.state):
1310-
total_iters = (
1311-
self.state.epoch_length * self.state.max_epochs
1312-
if self.state.max_epochs is not None
1313-
else self.state.max_iters
1314-
)
1280+
if not self._is_done(self.state) and self.state.max_epochs is not None:
1281+
total_iters = self.state.epoch_length * self.state.max_epochs
13151282

13161283
warnings.warn(
13171284
"Data iterator can not provide data anymore but required total number of "
@@ -1342,10 +1309,6 @@ def _run_once_on_dataset_legacy(self) -> float:
13421309
if self.state.epoch_length is not None and iter_counter == self.state.epoch_length:
13431310
break
13441311

1345-
if self.state.max_iters is not None and self.state.iteration == self.state.max_iters:
1346-
self.should_terminate = True
1347-
raise _EngineTerminateException()
1348-
13491312
except _EngineTerminateSingleEpochException:
13501313
self._fire_event(Events.TERMINATE_SINGLE_EPOCH, iter_counter=iter_counter)
13511314
self._setup_dataloader_iter()

ignite/engine/events.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -461,7 +461,6 @@ class State:
461461
state.dataloader # data passed to engine
462462
state.epoch_length # optional length of an epoch
463463
state.max_epochs # number of epochs to run
464-
state.max_iters # number of iterations to run
465464
state.batch # batch passed to `process_function`
466465
state.output # output of `process_function` after a single iteration
467466
state.metrics # dictionary with defined metrics if any
@@ -488,7 +487,6 @@ def __init__(self, **kwargs: Any) -> None:
488487
self.epoch = 0
489488
self.epoch_length: int | None = None
490489
self.max_epochs: int | None = None
491-
self.max_iters: int | None = None
492490
self.output: int | None = None
493491
self.batch: int | None = None
494492
self.metrics: dict[str, Any] = {}

ignite/handlers/lr_finder.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,6 @@ def _run(
109109
max_iter = trainer.state.epoch_length * trainer.state.max_epochs
110110
if max_iter < num_iter:
111111
max_iter = num_iter
112-
trainer.state.max_iters = num_iter
113112
trainer.state.max_epochs = ceil(num_iter / trainer.state.epoch_length)
114113

115114
if not trainer.has_event_handler(self._reached_num_iterations):

tests/ignite/engine/test_engine.py

Lines changed: 0 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
import math
21
import os
32
import time
43
from unittest.mock import call, MagicMock, Mock
@@ -1079,47 +1078,6 @@ def switch_dataloader():
10791078

10801079
trainer.run(data1, max_epochs=10)
10811080

1082-
def test_run_with_max_iters(self):
1083-
max_iters = 8
1084-
engine = Engine(lambda e, b: 1)
1085-
engine.run([0] * 20, max_iters=max_iters)
1086-
assert engine.state.iteration == max_iters
1087-
assert engine.state.max_iters == max_iters
1088-
1089-
def test_run_with_max_iters_greater_than_epoch_length(self):
1090-
max_iters = 73
1091-
engine = Engine(lambda e, b: 1)
1092-
engine.run([0] * 20, max_iters=max_iters)
1093-
assert engine.state.iteration == max_iters
1094-
1095-
def test_run_with_invalid_max_iters_and_max_epoch(self):
1096-
max_iters = 12
1097-
max_epochs = 2
1098-
engine = Engine(lambda e, b: 1)
1099-
with pytest.raises(
1100-
ValueError,
1101-
match=r"Arguments max_iters and max_epochs are mutually exclusive."
1102-
"Please provide only max_epochs or max_iters.",
1103-
):
1104-
engine.run([0] * 20, max_iters=max_iters, max_epochs=max_epochs)
1105-
1106-
def test_epoch_events_fired_max_iters(self):
1107-
max_iters = 32
1108-
engine = Engine(lambda e, b: 1)
1109-
1110-
@engine.on(Events.EPOCH_COMPLETED)
1111-
def fired_event(engine):
1112-
assert engine.state.iteration % engine.state.epoch_length == 0
1113-
1114-
engine.run([0] * 10, max_iters=max_iters)
1115-
1116-
def test_is_done_with_max_iters(self):
1117-
state = State(iteration=100, epoch=1, max_epochs=3, epoch_length=100, max_iters=250)
1118-
assert not Engine._is_done(state)
1119-
1120-
state = State(iteration=250, epoch=1, max_epochs=3, epoch_length=100, max_iters=250)
1121-
assert Engine._is_done(state)
1122-
11231081
@pytest.mark.skipif(torch.cuda.device_count() < 1, reason="Skip if no GPU")
11241082
def test_batch_is_released_before_new_one_is_loaded_on_cuda(self):
11251083
torch.cuda.empty_cache()
@@ -1346,24 +1304,6 @@ def __iter__(self):
13461304
assert state.iteration == 3
13471305
assert state.epoch_length == 3
13481306

1349-
def test_max_epochs_calculated_from_max_iters_unknown_epoch_length(self):
1350-
# covers: max_epochs auto-calculated as ceil(max_iters / epoch_length)
1351-
# when data is an iterator of unknown length and max_iters is provided
1352-
1353-
def data_iter():
1354-
for i in range(5):
1355-
yield i
1356-
1357-
engine = Engine(lambda e, b: None)
1358-
1359-
@engine.on(Events.DATALOADER_STOP_ITERATION)
1360-
def restart():
1361-
engine.state.dataloader = data_iter()
1362-
1363-
engine.run(data_iter(), max_iters=7)
1364-
assert engine.state.max_epochs == math.ceil(7 / engine.state.epoch_length)
1365-
assert engine.state.iteration == 7
1366-
13671307
def test_run_resume_raises_on_epoch_length_mismatch(self):
13681308
# covers: ValueError raised when resuming with a different epoch_length than the one stored in state
13691309
engine = Engine(lambda e, b: None)

tests/ignite/handlers/test_lr_finder.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -357,7 +357,7 @@ def test_num_iter_is_not_enough(lr_finder, to_save, dummy_engine, dataloader):
357357
trainer_with_finder.run(dataloader)
358358
assert_output_sizes(lr_finder, dummy_engine)
359359
assert dummy_engine.state.iteration != len(dataloader)
360-
assert dummy_engine.state.iteration == 150
360+
assert dummy_engine.state.iteration == 150 + 1
361361

362362

363363
def test_detach_terminates(lr_finder, to_save, dummy_engine, dataloader):

0 commit comments

Comments
 (0)