Skip to content
This repository was archived by the owner on Jul 28, 2025. It is now read-only.

Commit eab6632

Browse files
committed
CU-8698jzjj3: pass in extra param if ignore_extra_labels is set, and test
1 parent 7507a18 commit eab6632

File tree

2 files changed

+245
-1
lines changed

2 files changed

+245
-1
lines changed

medcat/ner/transformers_ner.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -227,7 +227,9 @@ def train(self,
227227
if self.model.num_labels != len(self.tokenizer.label_map):
228228
logger.warning("The dataset contains labels we've not seen before, model is being reinitialized")
229229
logger.warning("Model: {} vs Dataset: {}".format(self.model.num_labels, len(self.tokenizer.label_map)))
230-
self.model = AutoModelForTokenClassification.from_pretrained(self.config.general['model_name'], num_labels=len(self.tokenizer.label_map))
230+
self.model = AutoModelForTokenClassification.from_pretrained(self.config.general['model_name'],
231+
num_labels=len(self.tokenizer.label_map),
232+
ignore_mismatched_sizes=True)
231233
self.tokenizer.cui2name = {k:self.cdb.get_name(k) for k in self.tokenizer.label_map.keys()}
232234

233235
self.model.config.id2label = {v:k for k,v in self.tokenizer.label_map.items()}

tests/test_transformers_ner.py

Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
import unittest
2+
import tempfile
3+
import json
4+
import os
5+
import shutil
6+
from medcat.cdb import CDB
7+
from medcat.ner.transformers_ner import TransformersNER
8+
from medcat.config_transformers_ner import ConfigTransformersNER
9+
10+
class TestTransformersNER(unittest.TestCase):
    """Exercise ``TransformersNER.train`` with and without ``ignore_extra_labels``.

    A model is trained on data with two concept labels, saved and reloaded,
    then trained again on data that introduces a third label. The label map
    must stay unchanged with ``ignore_extra_labels=True`` and must grow by
    one with ``ignore_extra_labels=False``.
    """

    @staticmethod
    def _document(text, mentions):
        """Build one training document with offsets derived from ``text``.

        Deriving the offsets (instead of hard-coding them) guarantees that
        ``text[start:end] == value`` for every annotation.

        Args:
            text: The document text.
            mentions: ``(cui, value)`` pairs, in the order the annotations
                should appear. Each ``value`` must occur in ``text``; the
                first occurrence is used.

        Returns:
            A dict with ``"text"`` and ``"annotations"`` keys, where each
            annotation has ``"cui"``, ``"start"``, ``"end"`` and ``"value"``.

        Raises:
            ValueError: If a ``value`` does not occur in ``text``.
        """
        annotations = []
        for cui, value in mentions:
            start = text.index(value)
            annotations.append({
                "cui": cui,
                "start": start,
                "end": start + len(value),
                "value": value,
            })
        return {"text": text, "annotations": annotations}

    def setUp(self):
        """Create the temp dir, a minimal CDB and both training data files."""
        # Create a temporary directory for the test
        self.tmp_dir = tempfile.TemporaryDirectory()
        # Remember which working directories already existed so that
        # tearDown only removes the ones created during this test.
        self._preexisting_dirs = {
            path for path in ('./results', './logs') if os.path.exists(path)
        }
        # Create results dir for training outputs
        self.results_dir = './results'
        os.makedirs(self.results_dir, exist_ok=True)

        # Create a minimal CDB
        self.cdb = CDB()

        diabetes = ("C0011849", "diabetes")
        hypertension = ("C0020538", "hypertension")
        asthma = ("C0004096", "asthma")

        # Create initial training data with 2 labels and multiple examples
        self.initial_data = {
            "projects": [{
                "documents": [
                    self._document(
                        "Patient has diabetes and hypertension.",
                        [diabetes, hypertension]),
                    self._document(
                        "History of diabetes with hypertension.",
                        [diabetes, hypertension]),
                    self._document(
                        "Diagnosed with hypertension and diabetes.",
                        [hypertension, diabetes]),
                ]
            }]
        }

        # Create new training data with an extra label
        self.new_data = {
            "projects": [{
                "documents": [
                    self._document(
                        "Patient has diabetes, hypertension, and asthma.",
                        [diabetes, hypertension, asthma]),
                    self._document(
                        "History of asthma with diabetes and hypertension.",
                        [asthma, diabetes, hypertension]),
                    self._document(
                        "Diagnosed with asthma, diabetes, and hypertension.",
                        [asthma, diabetes, hypertension]),
                ]
            }]
        }

        # Save initial training data
        self.initial_data_path = os.path.join(self.tmp_dir.name, 'initial_data.json')
        with open(self.initial_data_path, 'w') as f:
            json.dump(self.initial_data, f)

        # Save new training data
        self.new_data_path = os.path.join(self.tmp_dir.name, 'new_data.json')
        with open(self.new_data_path, 'w') as f:
            json.dump(self.new_data, f)

    def tearDown(self):
        """Remove the temp dir and any working directories the test created."""
        # Clean up the temporary directory
        self.tmp_dir.cleanup()
        # Clean up the results and logs directories, but never delete a
        # directory that was already present before the test started.
        for path in (self.results_dir, './logs'):
            if path not in self._preexisting_dirs and os.path.exists(path):
                shutil.rmtree(path)

    def test_ignore_extra_labels(self):
        """Label map is frozen with ignore_extra_labels=True, extended otherwise."""
        # Create and train initial model with tiny BERT
        config = ConfigTransformersNER()
        config.general['model_name'] = 'prajjwal1/bert-tiny'
        # Set to single epoch and small test size for faster testing
        config.general['num_train_epochs'] = 1
        config.general['test_size'] = 0.1

        # Create training arguments with reduced epochs
        from transformers import TrainingArguments
        training_args = TrainingArguments(
            output_dir=self.results_dir,  # Use the class results_dir
            num_train_epochs=1
        )

        ner = TransformersNER(self.cdb, config=config, training_arguments=training_args)
        ner.train(self.initial_data_path)

        # Save the model
        model_path = os.path.join(self.tmp_dir.name, 'model')
        ner.save(model_path)

        # Load the saved model
        loaded_ner = TransformersNER.load(model_path)

        # Get initial number of labels
        initial_num_labels = len(loaded_ner.tokenizer.label_map)

        # Train with ignore_extra_labels=True
        loaded_ner.train(self.new_data_path, ignore_extra_labels=True)

        # Verify number of labels hasn't changed
        self.assertEqual(
            len(loaded_ner.tokenizer.label_map),
            initial_num_labels,
            "Number of labels changed despite ignore_extra_labels=True"
        )

        # Verify only original labels are present (including special tokens)
        expected_labels = {"C0011849", "C0020538", "O", "X"}
        self.assertEqual(
            set(loaded_ner.tokenizer.label_map.keys()),
            expected_labels,
            "Label map contains unexpected labels"
        )

        # Train with ignore_extra_labels=False
        loaded_ner.train(self.new_data_path, ignore_extra_labels=False)

        # Verify new label was added
        self.assertEqual(
            len(loaded_ner.tokenizer.label_map),
            initial_num_labels + 1,
            "New label was not added when ignore_extra_labels=False"
        )

        # Verify all labels are present (including special tokens)
        expected_labels = {"C0011849", "C0020538", "C0004096", "O", "X"}
        self.assertEqual(
            set(loaded_ner.tokenizer.label_map.keys()),
            expected_labels,
            "Label map missing expected labels"
        )
240+
241+
# Run the test suite when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()

0 commit comments

Comments
 (0)