|
| 1 | +import json |
| 2 | +import os |
| 3 | +from core.common.constant import ParadigmType |
| 4 | +from examples.yaoba.singletask_learning_boost.resource.utils.infer_and_error import infer_anno, merge_predict_results, \ |
| 5 | + compute_error, gen_txt_according_json, get_new_train_json |
| 6 | +from examples.yaoba.singletask_learning_boost.resource.utils.transform_unkonwn import aug_image_bboxes |
| 7 | +from .singletask_learning import SingleTaskLearning |
| 8 | +import os.path as osp |
| 9 | + |
| 10 | + |
class SingleTaskLearningACBoost(SingleTaskLearning):
    """Single-task learning paradigm that trains on instance-weighted annotations.

    The workflow implemented by :meth:`run` is:

    1. train one auxiliary model on the known data and one on the augmented
       unknown data,
    2. cross-infer each data set with the other model and derive per-instance
       (per bounding box) weights from the prediction error,
    3. train the final model through the paradigm job with the weighted
       annotation file, and
    4. run inference on the test set.
    """

    def __init__(self, workspace, **kwargs):
        super().__init__(workspace, **kwargs)

    def run(self):
        """Execute the whole pipeline.

        Returns:
            tuple: ``(inference_result, system_metric_info)`` where
            ``system_metric_info['use_raw']`` has been set to ``True``.
        """
        job = self.build_paradigm_job(str(ParadigmType.SINGLE_TASK_LEARNING.value))
        known_dataset_json, unknown_dataset_json, img_path = self._prepare_for_calculate_weights()

        # Both files are looked up inside the job's resource directory.
        base_config_path = osp.join(job.resource_dir, "base_config.py")
        train_script_path = osp.join(job.resource_dir, "train.py")

        ac_boost_training_json, aug_img_folder = self._calculate_weights_for_training(
            base_config=base_config_path,
            known_json_path=known_dataset_json,
            unknown_json_path=unknown_dataset_json,
            img_path=img_path,
            tmp_path=osp.join(job.work_dir, "tmp_folder"),
            train_script_path=train_script_path,
        )
        trained_model = self._ac_boost_train(job, ac_boost_training_json, aug_img_folder)
        inference_result = self._inference(job, trained_model)
        self.system_metric_info['use_raw'] = True
        return inference_result, self.system_metric_info

    def _ac_boost_train(self, job, training_anno, training_img_folder):
        """Train through ``job`` with the weighted annotations and save the model.

        Args:
            job: paradigm job exposing ``train`` and ``save``.
            training_anno (str): path of the JSON file with instance weights.
            training_img_folder (str): image folder matching ``training_anno``.

        Returns:
            The value returned by ``job.save`` for the trained model.
        """
        train_output_model_path = job.train((training_img_folder, training_anno))
        return job.save(train_output_model_path)

    def _inference(self, job, trained_model):
        """Load ``trained_model`` into ``job`` and predict on the test set.

        The test annotation file is read as JSON; every entry of its
        ``images`` list contributes ``<image folder>/<file_name>``.

        Returns:
            The value returned by ``job.predict`` for the list of image paths.
        """
        img_prefix = self.dataset.image_folder_url
        ann_file_path = self.dataset.test_url

        # Use a context manager so the annotation file handle is always closed.
        with open(ann_file_path, mode="r", encoding="utf-8") as ann_fp:
            ann_file = json.load(ann_fp)
        test_set = [osp.join(img_prefix, image['file_name']) for image in ann_file['images']]

        job.load(trained_model)
        return job.predict(test_set)

    def _prepare_for_calculate_weights(self):
        """Return the known JSON, unknown JSON and image folder of the dataset."""
        known_dataset_json = self.dataset.known_dataset_url
        unknown_dataset_json = self.dataset.unknown_dataset_url
        img_path = self.dataset.image_folder_url
        return known_dataset_json, unknown_dataset_json, img_path

    @staticmethod
    def _run_training(train_script_path, base_config, ann_file, img_prefix, work_dir):
        """Run the training script once in a child process and fail loudly on error.

        The command is passed as an argument list (no shell), so paths that
        contain spaces or shell metacharacters are handled correctly.

        Raises:
            subprocess.CalledProcessError: if the training process exits with a
                non-zero status.
        """
        # Local import keeps this block self-contained with respect to the
        # file-level import section.
        import subprocess

        command = [
            "python", train_script_path, base_config,
            "--seed", "1", "--deterministic",
            "--cfg-options",
            f"data.train.ann_file={ann_file}",
            f"data.train.img_prefix={img_prefix}",
            f"work_dir={work_dir}",
        ]
        subprocess.run(command, check=True)

    def _calculate_weights_for_training(self,
                                        base_config,
                                        known_json_path,
                                        unknown_json_path,
                                        img_path,
                                        tmp_path,
                                        train_script_path):
        r"""Generate instance weights required for unknown task training. In object detection,
        an instance means a bounding box, i.e., generating training weights for each bounding box.

        Args:
            base_config (str): path of config file for training known/unknown model
            known_json_path (str): path of JSON file for training known model
            unknown_json_path (str): path of JSON file for training unknown model
            img_path (str): image path of training, validation, and test set.
            tmp_path (str): path to save temporary files, including augmented images, training JSON files, etc.
                It is created (together with missing parent directories) when absent.
            train_script_path (str): path of mmdet training script

        Return:
            new_training_weight (str): JSON file with instance weights for unknown task training,
                which contains both the known and unknown training sets.
            aug_img_folder (str): the image paths required for training the model using the JSON file with instance weights.

        Raises:
            subprocess.CalledProcessError: if either auxiliary training run fails.
        """
        # exist_ok avoids the check-then-create race and also creates parents.
        os.makedirs(tmp_path, exist_ok=True)

        # Paths of all intermediate artifacts, every one located under tmp_path.
        aug_img_folder = osp.join(tmp_path, "aug_img_folder")  # augmented images
        known_model_folder = osp.join(tmp_path, "known_model")  # known model training results
        unknown_model_folder = osp.join(tmp_path, "unknown_model")  # unknown model training results
        aug_unknown_json = osp.join(tmp_path, 'aug_unknown.json')  # unknown annotations after augmentation
        unknown_infer_json = osp.join(tmp_path, 'unknown_infer_results.json')
        known_infer_json = osp.join(tmp_path, 'known_infer_results.json')
        merged_json = osp.join(tmp_path, "merge_predict_result.json")
        known_txt = osp.join(tmp_path, 'known.txt')
        aug_unknown_txt = osp.join(tmp_path, 'aug_unknown.txt')
        new_training_weight = osp.join(tmp_path, 'new_training_weight.json')

        # Augment the unknown data. The following steps presumably rely on this
        # helper writing `aug_img_folder` and `aug_unknown.json` under tmp_path
        # -- TODO confirm against aug_image_bboxes.
        aug_image_bboxes(
            anno=unknown_json_path,
            augs=[('flip', 1), ('brightness', 0.6), ('flip', -1)],
            image_path=img_path,
            out_path=tmp_path
        )

        # Train the known model on the original known data.
        self._run_training(train_script_path, base_config,
                           known_json_path, img_path, known_model_folder)

        # Train the unknown model on the augmented unknown data.
        self._run_training(train_script_path, base_config,
                           aug_unknown_json, aug_img_folder, unknown_model_folder)

        # Use the known model to infer the (augmented) unknown data.
        infer_anno(
            config_file=base_config,
            checkpoint_file=osp.join(known_model_folder, 'latest.pth'),
            img_path=aug_img_folder,
            anno_path=aug_unknown_json,
            out_path=unknown_infer_json
        )

        # Use the unknown model to infer the known data.
        # NOTE(review): the image folder passed here is aug_img_folder although
        # the annotations are the known set -- confirm that the known images
        # are available in that folder.
        infer_anno(
            config_file=base_config,
            checkpoint_file=osp.join(unknown_model_folder, 'latest.pth'),
            img_path=aug_img_folder,
            anno_path=known_json_path,
            out_path=known_infer_json
        )

        # Merge the prediction results and compute the error.
        merge_predict_results(
            result1=unknown_infer_json,
            result2=known_infer_json,
            out_dir=merged_json
        )
        new_json = compute_error(merged_json)

        # Generate the weights of the overall training sample based on the prediction error.
        gen_txt_according_json(known_json_path, known_txt)
        gen_txt_according_json(aug_unknown_json, aug_unknown_txt)
        get_new_train_json(
            new_json,
            aug_img_folder,
            known_txt,
            aug_unknown_txt,
            out_dir=new_training_weight)

        return new_training_weight, aug_img_folder
0 commit comments