diff --git a/medcat/cat.py b/medcat/cat.py index 13042acd0..ce63c07a6 100644 --- a/medcat/cat.py +++ b/medcat/cat.py @@ -40,6 +40,7 @@ from medcat.utils.saving.serializer import SPECIALITY_NAMES, ONE2MANY from medcat.utils.saving.envsnapshot import get_environment_info, ENV_SNAPSHOT_FILE_NAME from medcat.stats.stats import get_stats +from medcat.stats.mctexport import count_all_annotations, iter_anns from medcat.utils.filters import set_project_filters from medcat.utils.usage_monitoring import UsageMonitor @@ -808,7 +809,8 @@ def train_supervised_from_json(self, retain_extra_cui_filter: bool = False, checkpoint: Optional[Checkpoint] = None, retain_filters: bool = False, - is_resumed: bool = False) -> Tuple: + is_resumed: bool = False, + train_meta_cats: bool = False) -> Tuple: """ Run supervised training on a dataset from MedCATtrainer in JSON format. @@ -825,7 +827,7 @@ def train_supervised_from_json(self, devalue_others, use_groups, never_terminate, train_from_false_positives, extra_cui_filter, retain_extra_cui_filter, checkpoint, - retain_filters, is_resumed) + retain_filters, is_resumed, train_meta_cats) def train_supervised_raw(self, data: Dict[str, List[Dict[str, dict]]], @@ -845,7 +847,8 @@ def train_supervised_raw(self, retain_extra_cui_filter: bool = False, checkpoint: Optional[Checkpoint] = None, retain_filters: bool = False, - is_resumed: bool = False) -> Tuple: + is_resumed: bool = False, + train_meta_cats: bool = False) -> Tuple: """Train supervised based on the raw data provided. The raw data is expected in the following format: @@ -922,6 +925,8 @@ def train_supervised_raw(self, a ValueError is raised. The merging is done in the first epoch. is_resumed (bool): If True resume the previous training; If False, start a fresh new training. + train_meta_cats (bool): + If True, also trains the appropriate MetaCATs. Raises: ValueError: If attempting to retain filters with while training over multiple projects. @@ -1081,6 +1086,21 @@ def train_supervised_raw(self, use_overlaps=use_overlaps, use_groups=use_groups, extra_cui_filter=extra_cui_filter) + if (train_meta_cats and + # NOTE if no annnotaitons, no point + count_all_annotations(data) > 0): # type: ignore + # NOTE: if there + logger.info("Training MetaCATs within train_supervised_raw") + _, _, ann0 = next(iter_anns(data)) # type: ignore + for meta_cat in self._meta_cats: + # only consider meta-cats that have been defined for the category + if 'meta_anns' in ann0: + ann_names = ann0['meta_anns'].keys() # type: ignore + # adapt to alternative names if applicable + cat_name = meta_cat.config.general.get_applicable_category_name(ann_names) + if cat_name in ann_names: + logger.debug("Training MetaCAT %s", meta_cat.config.general.category_name) + meta_cat.train_raw(data) # reset the state of filters self.config.linking.filters = orig_filters diff --git a/tests/test_cat.py b/tests/test_cat.py index 17cdd2819..432073822 100644 --- a/tests/test_cat.py +++ b/tests/test_cat.py @@ -383,6 +383,27 @@ def test_train_supervised_can_retain_MCT_filters(self, extra_cui_filter=None, re with self.subTest(f'CUI: {filtered_cui}'): self.assertTrue(filtered_cui in self.undertest.config.linking.filters.cuis) + def _test_train_sup_with_meta_cat(self, train_meta_cats: bool): + # def side_effect(doc, *args, **kwargs): + # raise ValueError() + # # return doc + meta_cat = _get_meta_cat(self.meta_cat_dir) + cat = CAT(cdb=self.cdb, config=self.cdb.config, vocab=self.vocab, meta_cats=[meta_cat]) + with patch.object(MetaCAT, "train_raw") as mock_train: + with patch.object(MetaCAT, "__call__", side_effect=lambda doc: doc): + cat.train_supervised_raw(get_fixed_meta_cat_data(), never_terminate=True, + train_meta_cats=train_meta_cats) + if train_meta_cats: + mock_train.assert_called() + else: + mock_train.assert_not_called() + + def test_train_supervised_does_not_train_meta_cat_by_default(self): + self._test_train_sup_with_meta_cat(False) + + def test_train_supervised_can_train_meta_cats(self): + self._test_train_sup_with_meta_cat(True) + def test_train_supervised_no_leak_extra_cui_filters(self): self.test_train_supervised_does_not_retain_MCT_filters_default(extra_cui_filter={'C123', 'C111'}) @@ -799,6 +820,9 @@ def test_loading_model_pack_without_any_config_raises_exception(self): CAT.load_model_pack(self.temp_dir.name) +META_CAT_JSON_PATH = os.path.join(os.path.dirname(os.path.realpath(__file__)), "resources", "mct_export_for_meta_cat_test.json") + + def _get_meta_cat(meta_cat_dir): config = ConfigMetaCAT() config.general["category_name"] = "Status" @@ -808,11 +832,31 @@ def _get_meta_cat(meta_cat_dir): embeddings=None, config=config) os.makedirs(meta_cat_dir, exist_ok=True) - json_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "resources", "mct_export_for_meta_cat_test.json") + json_path = META_CAT_JSON_PATH meta_cat.train_from_json(json_path, save_dir_path=meta_cat_dir) return meta_cat +def get_fixed_meta_cat_data(path: str = META_CAT_JSON_PATH): + with open(path) as f: + data = json.load(f) + for proj_num, project in enumerate(data['projects']): + if 'name' not in project: + project['name'] = f"Proj_{proj_num}" + if 'cuis' not in project: + project['cuis'] = '' + if 'id' not in project: + project['id'] = f'P{proj_num}' + for doc in project['documents']: + if 'entities' in doc and 'annotations' not in doc: + ents = doc.pop("entities") + doc['annotations'] = list(ents.values()) + for ann in doc['annotations']: + if 'pretty_name' in ann and 'value' not in ann: + ann['value'] = ann.pop('pretty_name') + return data + + class TestLoadingOldWeights(unittest.TestCase): cdb_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "..", "examples", "cdb_old_broken_weights_in_config.dat")