Skip to content

Commit e36982b

Browse files
authored
update validation epoch (#7121)
- this allows for validation epoch at the very beginning of training - fixes #7122 ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Wenqi Li <[email protected]>
1 parent 4743945 commit e36982b

File tree

3 files changed

+7
-3
lines changed

3 files changed

+7
-3
lines changed

Diff for: monai/apps/utils.py

+2
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,8 @@ def download_url(
203203
if urlparse(url).netloc == "drive.google.com":
204204
if not has_gdown:
205205
raise RuntimeError("To download files from Google Drive, please install the gdown dependency.")
206+
if "fuzzy" not in gdown_kwargs:
207+
gdown_kwargs["fuzzy"] = True # default to true for flexible url
206208
gdown.download(url, f"{tmp_name}", quiet=not progress, **gdown_kwargs)
207209
elif urlparse(url).netloc == "cloud-api.yandex.net":
208210
with urlopen(url) as response:

Diff for: monai/engines/evaluator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ def run(self, global_epoch: int = 1) -> None: # type: ignore[override]
142142
143143
"""
144144
# init env value for current validation process
145-
self.state.max_epochs = global_epoch
145+
self.state.max_epochs = max(global_epoch, 1) # at least one epoch of validation
146146
self.state.epoch = global_epoch - 1
147147
self.state.iteration = 0
148148
super().run()

Diff for: tests/test_handler_validation.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@
2323

2424
class TestEvaluator(Evaluator):
2525
def _iteration(self, engine, batchdata):
26-
pass
26+
engine.state.output = "called"
27+
return engine.state.output
2728

2829

2930
class TestHandlerValidation(unittest.TestCase):
@@ -42,8 +43,9 @@ def _train_func(engine, batch):
4243
ValidationHandler(interval=2, validator=evaluator, exec_at_start=True).attach(engine)
4344
# test execution at start
4445
engine.run(data, max_epochs=1)
45-
self.assertEqual(evaluator.state.max_epochs, 0)
46+
self.assertEqual(evaluator.state.max_epochs, 1)
4647
self.assertEqual(evaluator.state.epoch_length, 8)
48+
self.assertEqual(evaluator.state.output, "called")
4749

4850
engine.run(data, max_epochs=5)
4951
self.assertEqual(evaluator.state.max_epochs, 4)

0 commit comments

Comments
 (0)