1+ import unittest
2+ import tempfile
3+ import json
4+ import os
5+ import shutil
6+ from medcat .cdb import CDB
7+ from medcat .ner .transformers_ner import TransformersNER
8+ from medcat .config_transformers_ner import ConfigTransformersNER
9+
10+ class TestTransformersNER (unittest .TestCase ):
11+ def setUp (self ):
12+ # Create a temporary directory for the test
13+ self .tmp_dir = tempfile .TemporaryDirectory ()
14+ # Create results dir for training outputs
15+ self .results_dir = './results'
16+ os .makedirs (self .results_dir , exist_ok = True )
17+
18+ # Create a minimal CDB
19+ self .cdb = CDB ()
20+
21+ # Create initial training data with 2 labels and multiple examples
22+ self .initial_data = {
23+ "projects" : [{
24+ "documents" : [
25+ {
26+ "text" : "Patient has diabetes and hypertension." ,
27+ "annotations" : [
28+ {
29+ "cui" : "C0011849" , # Diabetes
30+ "start" : 14 ,
31+ "end" : 22 ,
32+ "value" : "diabetes"
33+ },
34+ {
35+ "cui" : "C0020538" , # Hypertension
36+ "start" : 27 ,
37+ "end" : 39 ,
38+ "value" : "hypertension"
39+ }
40+ ]
41+ },
42+ {
43+ "text" : "History of diabetes with hypertension." ,
44+ "annotations" : [
45+ {
46+ "cui" : "C0011849" , # Diabetes
47+ "start" : 12 ,
48+ "end" : 20 ,
49+ "value" : "diabetes"
50+ },
51+ {
52+ "cui" : "C0020538" , # Hypertension
53+ "start" : 26 ,
54+ "end" : 38 ,
55+ "value" : "hypertension"
56+ }
57+ ]
58+ },
59+ {
60+ "text" : "Diagnosed with hypertension and diabetes." ,
61+ "annotations" : [
62+ {
63+ "cui" : "C0020538" , # Hypertension
64+ "start" : 15 ,
65+ "end" : 27 ,
66+ "value" : "hypertension"
67+ },
68+ {
69+ "cui" : "C0011849" , # Diabetes
70+ "start" : 32 ,
71+ "end" : 40 ,
72+ "value" : "diabetes"
73+ }
74+ ]
75+ }
76+ ]
77+ }]
78+ }
79+
80+ # Create new training data with an extra label
81+ self .new_data = {
82+ "projects" : [{
83+ "documents" : [
84+ {
85+ "text" : "Patient has diabetes, hypertension, and asthma." ,
86+ "annotations" : [
87+ {
88+ "cui" : "C0011849" , # Diabetes
89+ "start" : 14 ,
90+ "end" : 22 ,
91+ "value" : "diabetes"
92+ },
93+ {
94+ "cui" : "C0020538" , # Hypertension
95+ "start" : 24 ,
96+ "end" : 36 ,
97+ "value" : "hypertension"
98+ },
99+ {
100+ "cui" : "C0004096" , # Asthma
101+ "start" : 42 ,
102+ "end" : 48 ,
103+ "value" : "asthma"
104+ }
105+ ]
106+ },
107+ {
108+ "text" : "History of asthma with diabetes and hypertension." ,
109+ "annotations" : [
110+ {
111+ "cui" : "C0004096" , # Asthma
112+ "start" : 12 ,
113+ "end" : 18 ,
114+ "value" : "asthma"
115+ },
116+ {
117+ "cui" : "C0011849" , # Diabetes
118+ "start" : 24 ,
119+ "end" : 32 ,
120+ "value" : "diabetes"
121+ },
122+ {
123+ "cui" : "C0020538" , # Hypertension
124+ "start" : 37 ,
125+ "end" : 49 ,
126+ "value" : "hypertension"
127+ }
128+ ]
129+ },
130+ {
131+ "text" : "Diagnosed with asthma, diabetes, and hypertension." ,
132+ "annotations" : [
133+ {
134+ "cui" : "C0004096" , # Asthma
135+ "start" : 15 ,
136+ "end" : 21 ,
137+ "value" : "asthma"
138+ },
139+ {
140+ "cui" : "C0011849" , # Diabetes
141+ "start" : 23 ,
142+ "end" : 31 ,
143+ "value" : "diabetes"
144+ },
145+ {
146+ "cui" : "C0020538" , # Hypertension
147+ "start" : 37 ,
148+ "end" : 49 ,
149+ "value" : "hypertension"
150+ }
151+ ]
152+ }
153+ ]
154+ }]
155+ }
156+
157+ # Save initial training data
158+ self .initial_data_path = os .path .join (self .tmp_dir .name , 'initial_data.json' )
159+ with open (self .initial_data_path , 'w' ) as f :
160+ json .dump (self .initial_data , f )
161+
162+ # Save new training data
163+ self .new_data_path = os .path .join (self .tmp_dir .name , 'new_data.json' )
164+ with open (self .new_data_path , 'w' ) as f :
165+ json .dump (self .new_data , f )
166+
167+ def tearDown (self ):
168+ # Clean up the temporary directory
169+ self .tmp_dir .cleanup ()
170+ # Clean up results directory if it exists
171+ if os .path .exists (self .results_dir ):
172+ shutil .rmtree (self .results_dir )
173+ # Clean up logs directory if it exists
174+ if os .path .exists ('./logs' ):
175+ shutil .rmtree ('./logs' )
176+
177+ def test_ignore_extra_labels (self ):
178+ # Create and train initial model with tiny BERT
179+ config = ConfigTransformersNER ()
180+ config .general ['model_name' ] = 'prajjwal1/bert-tiny'
181+ # Set to single epoch and small test size for faster testing
182+ config .general ['num_train_epochs' ] = 1
183+ config .general ['test_size' ] = 0.1
184+
185+ # Create training arguments with reduced epochs
186+ from transformers import TrainingArguments
187+ training_args = TrainingArguments (
188+ output_dir = self .results_dir , # Use the class results_dir
189+ num_train_epochs = 1
190+ )
191+
192+ ner = TransformersNER (self .cdb , config = config , training_arguments = training_args )
193+ ner .train (self .initial_data_path )
194+
195+ # Save the model
196+ model_path = os .path .join (self .tmp_dir .name , 'model' )
197+ ner .save (model_path )
198+
199+ # Load the saved model
200+ loaded_ner = TransformersNER .load (model_path )
201+
202+ # Get initial number of labels
203+ initial_num_labels = len (loaded_ner .tokenizer .label_map )
204+
205+ # Train with ignore_extra_labels=True
206+ loaded_ner .train (self .new_data_path , ignore_extra_labels = True )
207+
208+ # Verify number of labels hasn't changed
209+ self .assertEqual (
210+ len (loaded_ner .tokenizer .label_map ),
211+ initial_num_labels ,
212+ "Number of labels changed despite ignore_extra_labels=True"
213+ )
214+
215+ # Verify only original labels are present (including special tokens)
216+ expected_labels = {"C0011849" , "C0020538" , "O" , "X" }
217+ self .assertEqual (
218+ set (loaded_ner .tokenizer .label_map .keys ()),
219+ expected_labels ,
220+ "Label map contains unexpected labels"
221+ )
222+
223+ # Train with ignore_extra_labels=False
224+ loaded_ner .train (self .new_data_path , ignore_extra_labels = False )
225+
226+ # Verify new label was added
227+ self .assertEqual (
228+ len (loaded_ner .tokenizer .label_map ),
229+ initial_num_labels + 1 ,
230+ "New label was not added when ignore_extra_labels=False"
231+ )
232+
233+ # Verify all labels are present (including special tokens)
234+ expected_labels = {"C0011849" , "C0020538" , "C0004096" , "O" , "X" }
235+ self .assertEqual (
236+ set (loaded_ner .tokenizer .label_map .keys ()),
237+ expected_labels ,
238+ "Label map missing expected labels"
239+ )
240+
241+ if __name__ == '__main__' :
242+ unittest .main ()
0 commit comments