|
11 | 11 |
|
12 | 12 | import glob
|
13 | 13 | import os
|
| 14 | +import shutil |
14 | 15 | import tempfile
|
15 | 16 | import unittest
|
| 17 | +from concurrent.futures import ThreadPoolExecutor |
16 | 18 |
|
17 | 19 | import numpy as np
|
18 | 20 | from ignite.engine import Engine, Events
|
|
21 | 23 | from monai.utils import path_to_uri
|
22 | 24 |
|
23 | 25 |
|
| 26 | +def dummy_train(tracking_folder): |
| 27 | + tempdir = tempfile.mkdtemp() |
| 28 | + |
| 29 | + # set up engine |
| 30 | + def _train_func(engine, batch): |
| 31 | + return [batch + 1.0] |
| 32 | + |
| 33 | + engine = Engine(_train_func) |
| 34 | + |
| 35 | + # set up testing handler |
| 36 | + test_path = os.path.join(tempdir, tracking_folder) |
| 37 | + handler = MLFlowHandler( |
| 38 | + iteration_log=False, |
| 39 | + epoch_log=True, |
| 40 | + tracking_uri=path_to_uri(test_path), |
| 41 | + state_attributes=["test"], |
| 42 | + close_on_complete=True, |
| 43 | + ) |
| 44 | + handler.attach(engine) |
| 45 | + engine.run(range(3), max_epochs=2) |
| 46 | + return test_path |
| 47 | + |
| 48 | + |
24 | 49 | class TestHandlerMLFlow(unittest.TestCase):
|
| 50 | + def setUp(self): |
| 51 | + self.tmpdir_list = [] |
| 52 | + |
| 53 | + def tearDown(self): |
| 54 | + for tmpdir in self.tmpdir_list: |
| 55 | + if tmpdir and os.path.exists(tmpdir): |
| 56 | + shutil.rmtree(tmpdir) |
| 57 | + |
25 | 58 | def test_metrics_track(self):
|
26 | 59 | experiment_param = {"backbone": "efficientnet_b0"}
|
27 | 60 | with tempfile.TemporaryDirectory() as tempdir:
|
@@ -61,6 +94,18 @@ def _update_metric(engine):
|
61 | 94 | # check logging output
|
62 | 95 | self.assertTrue(len(glob.glob(test_path)) > 0)
|
63 | 96 |
|
| 97 | + def test_multi_thread(self): |
| 98 | + test_uri_list = ["monai_mlflow_test1", "monai_mlflow_test2"] |
| 99 | + with ThreadPoolExecutor(2, "Training") as executor: |
| 100 | + futures = {} |
| 101 | + for t in test_uri_list: |
| 102 | + futures[t] = executor.submit(dummy_train, t) |
| 103 | + |
| 104 | + for _, future in futures.items(): |
| 105 | + res = future.result() |
| 106 | + self.tmpdir_list.append(res) |
| 107 | + self.assertTrue(len(glob.glob(res)) > 0) |
| 108 | + |
64 | 109 |
|
65 | 110 | if __name__ == "__main__":
|
66 | 111 | unittest.main()
|
0 commit comments