22import shutil
33import unittest
44import json
5+ import logging
56
67from medcat .cdb import CDB
78from medcat .config_rel_cat import ConfigRelCAT
89from medcat .rel_cat import RelCAT
910from medcat .utils .relation_extraction .tokenizer import TokenizerWrapperBERT
11+ from medcat .utils .relation_extraction .rel_dataset import RelData
1012
1113from transformers .models .auto .tokenization_auto import AutoTokenizer
1214
1315import spacy
1416from spacy .tokens import Span , Doc
1517
18+
1619class RelCATTests (unittest .TestCase ):
1720
1821 @classmethod
@@ -24,6 +27,7 @@ def setUpClass(cls) -> None:
2427 config .train .nclasses = 3
2528 config .model .hidden_size = 256
2629 config .model .model_size = 2304
30+ config .general .log_level = logging .DEBUG
2731
2832 tokenizer = TokenizerWrapperBERT (AutoTokenizer .from_pretrained (
2933 pretrained_model_name_or_path = config .general .model_name ,
@@ -58,13 +62,37 @@ def setUpClass(cls) -> None:
5862 cls .finished = False
5963 cls .tokenizer = tokenizer
6064
def test_dataset_relation_parser(self) -> None:
    """Smoke-test ``RelData.create_base_relations_from_doc`` on tagged text.

    Every sample marks two entities with ``[s1]...[e1]`` and
    ``[s2]...[e2]``. Each sample is tokenized with the shared tokenizer,
    the token positions of the ``[s1]`` and ``[s2]`` markers are looked
    up, and the text is passed to ``create_base_relations_from_doc``
    together with its index in ``samples``.

    NOTE(review): the final check only confirms that one result was
    collected per sample (i.e. that no call raised); the content of the
    returned relations is not inspected.
    """
    samples = [
        "The [s1]45-year-old male[e1] was diagnosed with [s2]hypertension[e2] during his routine check-up.",
        "The patient’s [s1]chest pain[e1] was associated with [s2]shortness of breath[e2].",
        "[s1]Blood pressure[e1] readings of [s2]160/90 mmHg[e2] indicated possible hypertension.",
        "His elevated [s1]blood glucose[e1] level of [s2]220 mg/dL[e2] raised concerns about his diabetes management.",
        "The doctor recommended a [s1]cardiac enzyme test[e1] to assess the risk of [s2]myocardial infarction[e2].",
        "The patient’s [s1]ECG[e1] showed signs of [s2]ischemia[e2]",
        "To manage his [s1]hypertension[e1], the patient was advised to [s2]reduce salt intake[e2].",
        "[s1]Increased physical activity[e1][s2]type 2 diabetes[e2]."
    ]

    rel_dataset = RelData(cdb=self.rel_cat.cdb, config=self.config_rel_cat, tokenizer=self.tokenizer)

    rels = []

    for idx, sample in enumerate(samples):
        tkns = self.tokenizer(sample)["tokens"]
        # Presumably `tkns` is a list of token strings in which each marker
        # survives as a single token; `.index` raises if a marker is absent,
        # which fails the test for that sample.
        ent1_ent2_tokens_start_pos = (tkns.index("[s1]"), tkns.index("[s2]"))
        rels.append(rel_dataset.create_base_relations_from_doc(
            sample, idx,
            ent1_ent2_tokens_start_pos=ent1_ent2_tokens_start_pos))

    # assertEqual (unlike a bare assert) is not stripped under `python -O`
    # and reports both values on failure.
    self.assertEqual(len(rels), len(samples))
89+
90+
def test_train_csv_no_tags(self) -> None:
    """Train from the train/test CSV files for two epochs, then save the model.

    The model is written to ``self.save_model_path``; checkpoints go to
    ``self.tmp_dir``.
    """
    model = self.rel_cat
    model.config.train.epochs = 2
    model.train(
        train_csv_path=self.medcat_rels_csv_path_train,
        test_csv_path=self.medcat_rels_csv_path_test,
        checkpoint_path=self.tmp_dir,
    )
    model.save(self.save_model_path)
6595
66-
67-
6896 def test_train_mctrainer (self ) -> None :
6997 self .rel_cat = RelCAT .load (self .save_model_path )
7098 self .rel_cat .config .general .create_addl_rels = True
@@ -77,7 +105,6 @@ def test_train_mctrainer(self) -> None:
77105 self .rel_cat .train (export_data_path = self .medcat_export_with_rels_path , checkpoint_path = self .tmp_dir )
78106
79107
80-
81108 def test_train_predict (self ) -> None :
82109 Span .set_extension ('id' , default = 0 , force = True )
83110 Span .set_extension ('cui' , default = None , force = True )
@@ -103,6 +130,7 @@ def test_train_predict(self) -> None:
103130
104131 assert len (doc ._ .relations ) > 0
105132
133+
106134 def tearDown (self ) -> None :
107135 if self .finished :
108136 if os .path .exists (self .tmp_dir ):
0 commit comments