|
15 | 15 | import torch.multiprocessing
|
16 | 16 | from azureml._restclient.constants import RunStatus
|
17 | 17 | from azureml.core import Model, Run, model
|
| 18 | +from health_azure import AzureRunInfo |
| 19 | +from health_azure.utils import ENVIRONMENT_VERSION, create_run_recovery_id, is_global_rank_zero |
18 | 20 | from pytorch_lightning import LightningModule, seed_everything
|
19 | 21 | from pytorch_lightning.core.datamodule import LightningDataModule
|
20 | 22 | from torch.utils.data import DataLoader
|
21 | 23 |
|
22 | 24 | from InnerEye.Azure import azure_util
|
23 | 25 | from InnerEye.Azure.azure_config import AzureConfig
|
24 | 26 | from InnerEye.Azure.azure_runner import ENV_OMPI_COMM_WORLD_RANK, get_git_tags
|
25 |
| -from InnerEye.Azure.azure_util import CROSS_VALIDATION_SPLIT_INDEX_TAG_KEY, DEFAULT_CROSS_VALIDATION_SPLIT_INDEX, \ |
26 |
| - EFFECTIVE_RANDOM_SEED_KEY_NAME, IS_ENSEMBLE_KEY_NAME, MODEL_ID_KEY_NAME, PARENT_RUN_CONTEXT, \ |
27 |
| - PARENT_RUN_ID_KEY_NAME, RUN_CONTEXT, RUN_RECOVERY_FROM_ID_KEY_NAME, RUN_RECOVERY_ID_KEY_NAME, \ |
28 |
| - get_all_environment_files, is_offline_run_context |
| 27 | +from InnerEye.Azure.azure_util import ( |
| 28 | + CROSS_VALIDATION_SPLIT_INDEX_TAG_KEY, DEFAULT_CROSS_VALIDATION_SPLIT_INDEX, EFFECTIVE_RANDOM_SEED_KEY_NAME, |
| 29 | + IS_ENSEMBLE_KEY_NAME, MODEL_ID_KEY_NAME, PARENT_RUN_CONTEXT, PARENT_RUN_ID_KEY_NAME, RUN_CONTEXT, |
| 30 | + RUN_RECOVERY_FROM_ID_KEY_NAME, RUN_RECOVERY_ID_KEY_NAME, get_all_environment_files, is_offline_run_context |
| 31 | +) |
29 | 32 | from InnerEye.Common import fixed_paths
|
30 |
| -from InnerEye.Common.common_util import (BASELINE_COMPARISONS_FOLDER, BASELINE_WILCOXON_RESULTS_FILE, |
31 |
| - CROSSVAL_RESULTS_FOLDER, ENSEMBLE_SPLIT_NAME, FULL_METRICS_DATAFRAME_FILE, |
32 |
| - METRICS_AGGREGATES_FILE, ModelProcessing, |
33 |
| - OTHER_RUNS_SUBDIR_NAME, SCATTERPLOTS_SUBDIR_NAME, SUBJECT_METRICS_FILE_NAME, |
34 |
| - change_working_directory, get_best_epoch_results_path, is_windows, |
35 |
| - logging_section, print_exception, remove_file_or_directory) |
| 33 | +from InnerEye.Common.common_util import ( |
| 34 | + BASELINE_COMPARISONS_FOLDER, BASELINE_WILCOXON_RESULTS_FILE, CROSSVAL_RESULTS_FOLDER, ENSEMBLE_SPLIT_NAME, |
| 35 | + FULL_METRICS_DATAFRAME_FILE, METRICS_AGGREGATES_FILE, OTHER_RUNS_SUBDIR_NAME, SCATTERPLOTS_SUBDIR_NAME, |
| 36 | + SUBJECT_METRICS_FILE_NAME, ModelProcessing, change_working_directory, get_best_epoch_results_path, |
| 37 | + is_windows, logging_section, merge_conda_files, print_exception, remove_file_or_directory |
| 38 | +) |
36 | 39 | from InnerEye.Common.fixed_paths import INNEREYE_PACKAGE_NAME, PYTHON_ENVIRONMENT_NAME
|
37 | 40 | from InnerEye.Common.type_annotations import PathOrString
|
38 | 41 | from InnerEye.ML.baselines_util import compare_folders_and_run_outputs
|
39 |
| -from InnerEye.ML.common import CHECKPOINT_FOLDER, EXTRA_RUN_SUBFOLDER, FINAL_ENSEMBLE_MODEL_FOLDER, \ |
40 |
| - FINAL_MODEL_FOLDER, \ |
41 |
| - ModelExecutionMode |
| 42 | +from InnerEye.ML.common import ( |
| 43 | + CHECKPOINT_FOLDER, EXTRA_RUN_SUBFOLDER, FINAL_ENSEMBLE_MODEL_FOLDER, FINAL_MODEL_FOLDER, ModelExecutionMode |
| 44 | +) |
42 | 45 | from InnerEye.ML.config import SegmentationModelBase
|
43 |
| -from InnerEye.ML.deep_learning_config import DeepLearningConfig, ModelCategory, MultiprocessingStartMethod, \ |
44 |
| - load_checkpoint |
| 46 | +from InnerEye.ML.deep_learning_config import ( |
| 47 | + DeepLearningConfig, ModelCategory, MultiprocessingStartMethod, load_checkpoint |
| 48 | +) |
45 | 49 | from InnerEye.ML.lightning_base import InnerEyeContainer
|
46 | 50 | from InnerEye.ML.lightning_container import InnerEyeInference, LightningContainer
|
47 | 51 | from InnerEye.ML.lightning_loggers import StoringLogger
|
|
50 | 54 | from InnerEye.ML.model_inference_config import ModelInferenceConfig
|
51 | 55 | from InnerEye.ML.model_testing import model_test
|
52 | 56 | from InnerEye.ML.model_training import create_lightning_trainer, model_train
|
53 |
| -from InnerEye.ML.reports.notebook_report import generate_classification_crossval_notebook, \ |
54 |
| - generate_classification_multilabel_notebook, generate_classification_notebook, generate_segmentation_notebook, \ |
55 |
| - get_ipynb_report_name, reports_folder |
| 57 | +from InnerEye.ML.reports.notebook_report import ( |
| 58 | + generate_classification_crossval_notebook, generate_classification_multilabel_notebook, |
| 59 | + generate_classification_notebook, generate_segmentation_notebook, get_ipynb_report_name, reports_folder |
| 60 | +) |
56 | 61 | from InnerEye.ML.scalar_config import ScalarModelBase
|
57 | 62 | from InnerEye.ML.utils.checkpoint_handling import CheckpointHandler, download_all_checkpoints_from_run
|
58 | 63 | from InnerEye.ML.visualizers import activation_maps
|
59 |
| -from InnerEye.ML.visualizers.plot_cross_validation import \ |
| 64 | +from InnerEye.ML.visualizers.plot_cross_validation import ( |
60 | 65 | get_config_and_results_for_offline_runs, plot_cross_validation_from_files
|
61 |
| -from health_azure import AzureRunInfo |
62 |
| -from health_azure.utils import ENVIRONMENT_VERSION, create_run_recovery_id, is_global_rank_zero, merge_conda_files |
| 66 | +) |
63 | 67 |
|
64 | 68 | ModelDeploymentHookSignature = Callable[[LightningContainer, AzureConfig, Model, ModelProcessing], Any]
|
65 | 69 | PostCrossValidationHookSignature = Callable[[ModelConfigBase, Path], None]
|
@@ -797,8 +801,10 @@ def create_ensemble_model_and_run_inference(self) -> None:
|
797 | 801 | remove_file_or_directory(other_runs_dir)
|
798 | 802 |
|
799 | 803 | def plot_cross_validation_and_upload_results(self) -> Path:
|
800 |
| - from InnerEye.ML.visualizers.plot_cross_validation import crossval_config_from_model_config, \ |
801 |
| - plot_cross_validation, unroll_aggregate_metrics |
| 804 | + from InnerEye.ML.visualizers.plot_cross_validation import ( |
| 805 | + crossval_config_from_model_config, plot_cross_validation, unroll_aggregate_metrics |
| 806 | + ) |
| 807 | + |
802 | 808 | # perform aggregation as cross val splits are now ready
|
803 | 809 | plot_crossval_config = crossval_config_from_model_config(self.innereye_config)
|
804 | 810 | plot_crossval_config.run_recovery_id = PARENT_RUN_CONTEXT.tags[RUN_RECOVERY_ID_KEY_NAME]
|
|
0 commit comments