Skip to content

Commit e50fa88

Browse files
authored
Add multi-thread unittest for mlflow handler (#5755)
Signed-off-by: binliu <[email protected]> ### Description Add a multi-thread unit test for mlflow handler. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: binliu <[email protected]>
1 parent 35db359 commit e50fa88

File tree

1 file changed

+45
-0
lines changed

1 file changed

+45
-0
lines changed

tests/test_handler_mlflow.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,10 @@
1111

1212
import glob
1313
import os
14+
import shutil
1415
import tempfile
1516
import unittest
17+
from concurrent.futures import ThreadPoolExecutor
1618

1719
import numpy as np
1820
from ignite.engine import Engine, Events
@@ -21,7 +23,38 @@
2123
from monai.utils import path_to_uri
2224

2325

26+
def dummy_train(tracking_folder):
27+
tempdir = tempfile.mkdtemp()
28+
29+
# set up engine
30+
def _train_func(engine, batch):
31+
return [batch + 1.0]
32+
33+
engine = Engine(_train_func)
34+
35+
# set up testing handler
36+
test_path = os.path.join(tempdir, tracking_folder)
37+
handler = MLFlowHandler(
38+
iteration_log=False,
39+
epoch_log=True,
40+
tracking_uri=path_to_uri(test_path),
41+
state_attributes=["test"],
42+
close_on_complete=True,
43+
)
44+
handler.attach(engine)
45+
engine.run(range(3), max_epochs=2)
46+
return test_path
47+
48+
2449
class TestHandlerMLFlow(unittest.TestCase):
50+
def setUp(self):
51+
self.tmpdir_list = []
52+
53+
def tearDown(self):
54+
for tmpdir in self.tmpdir_list:
55+
if tmpdir and os.path.exists(tmpdir):
56+
shutil.rmtree(tmpdir)
57+
2558
def test_metrics_track(self):
2659
experiment_param = {"backbone": "efficientnet_b0"}
2760
with tempfile.TemporaryDirectory() as tempdir:
@@ -61,6 +94,18 @@ def _update_metric(engine):
6194
# check logging output
6295
self.assertTrue(len(glob.glob(test_path)) > 0)
6396

97+
def test_multi_thread(self):
98+
test_uri_list = ["monai_mlflow_test1", "monai_mlflow_test2"]
99+
with ThreadPoolExecutor(2, "Training") as executor:
100+
futures = {}
101+
for t in test_uri_list:
102+
futures[t] = executor.submit(dummy_train, t)
103+
104+
for _, future in futures.items():
105+
res = future.result()
106+
self.tmpdir_list.append(res)
107+
self.assertTrue(len(glob.glob(res)) > 0)
108+
64109

65110
if __name__ == "__main__":
66111
unittest.main()

0 commit comments

Comments
 (0)