Skip to content
This repository was archived by the owner on Jul 28, 2025. It is now read-only.

Commit 64a7a6a

Browse files
committed
Merge branch 'master' of https://github.com/CogStack/MedCAT into relation_extraction_llama
2 parents 2d6b2f3 + 65f7c5e commit 64a7a6a

File tree

5 files changed

+274
-14
lines changed

5 files changed

+274
-14
lines changed

medcat/meta_cat.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,9 @@ def get_model(self, embeddings: Optional[Tensor]) -> nn.Module:
9595
if not config.model.model_freeze_layers:
9696
peft_config = LoraConfig(task_type=TaskType.SEQ_CLS, inference_mode=False, r=8, lora_alpha=16,
9797
target_modules=["query", "value"], lora_dropout=0.2)
98-
99-
model = get_peft_model(model, peft_config)
98+
# Not sure what changed between transformers 4.50.3 and 4.50.1 that made this
99+
# fail for mypy. But as best as I can tell, it still works just the same
100+
model = get_peft_model(model, peft_config) # type: ignore
100101
# model.print_trainable_parameters()
101102

102103
logger.info("BERT model used for classification")
@@ -412,7 +413,7 @@ def load(cls, save_dir_path: str, config_dict: Optional[Dict] = None) -> "MetaCA
412413
tokenizer = TokenizerWrapperBPE.load(save_dir_path)
413414
elif config.general['tokenizer_name'] == 'bert-tokenizer':
414415
from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBERT
415-
tokenizer = TokenizerWrapperBERT.load(save_dir_path, config.model['model_variant'])
416+
tokenizer = TokenizerWrapperBERT.load(save_dir_path, config.model.model_variant)
416417

417418
# Create meta_cat
418419
meta_cat = cls(tokenizer=tokenizer, embeddings=None, config=config)

medcat/ner/transformers_ner.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,8 @@ def __init__(self, cdb, config: Optional[ConfigTransformersNER] = None,
7070
eval_accumulation_steps=1,
7171
gradient_accumulation_steps=4, # We want to get to bs=4
7272
do_eval=True,
73-
evaluation_strategy='epoch', # type: ignore
73+
# eval_strategy over evaluation_strategy since trf==4.46 (apparently)
74+
eval_strategy='epoch', # type: ignore
7475
logging_strategy='epoch', # type: ignore
7576
save_strategy='epoch', # type: ignore
7677
metric_for_best_model='eval_recall', # Can be changed if our preference is not recall but precision or f1
@@ -176,7 +177,7 @@ def train(self,
176177
ignore_extra_labels=False,
177178
dataset=None,
178179
meta_requirements=None,
179-
trainer_callbacks: Optional[List[TrainerCallback]]=None) -> Tuple:
180+
trainer_callbacks: Optional[List[Callable[[Trainer], TrainerCallback]]] = None) -> Tuple:
180181
"""Train or continue training a model give a json_path containing a MedCATtrainer export. It will
181182
continue training if an existing model is loaded or start new training if the model is blank/new.
182183
@@ -188,10 +189,13 @@ def train(self,
188189
labels that did not exist in the old model.
189190
dataset: Defaults to None.
190191
meta_requirements: Defaults to None
191-
trainer_callbacks (List[TrainerCallback]):
192+
trainer_callbacks (List[Callable[[Trainer], TrainerCallback]]):
192193
A list of trainer callbacks for collecting metrics during the training at the client side. The
193194
transformers Trainer object will be passed in when each callback is called.
194195
196+
Raises:
197+
ValueError: If something went wrong with model save path.
198+
195199
Returns:
196200
Tuple: The dataframe, examples, and the dataset
197201
"""
@@ -227,7 +231,9 @@ def train(self,
227231
if self.model.num_labels != len(self.tokenizer.label_map):
228232
logger.warning("The dataset contains labels we've not seen before, model is being reinitialized")
229233
logger.warning("Model: {} vs Dataset: {}".format(self.model.num_labels, len(self.tokenizer.label_map)))
230-
self.model = AutoModelForTokenClassification.from_pretrained(self.config.general['model_name'], num_labels=len(self.tokenizer.label_map))
234+
self.model = AutoModelForTokenClassification.from_pretrained(self.config.general['model_name'],
235+
num_labels=len(self.tokenizer.label_map),
236+
ignore_mismatched_sizes=True)
231237
self.tokenizer.cui2name = {k:self.cdb.get_name(k) for k in self.tokenizer.label_map.keys()}
232238

233239
self.model.config.id2label = {v:k for k,v in self.tokenizer.label_map.items()}
@@ -252,15 +258,21 @@ def train(self,
252258
tokenizer=None)
253259
if trainer_callbacks:
254260
for callback in trainer_callbacks:
255-
trainer.add_callback(callback(trainer))
261+
# No idea why mypy isn't picking up the method.
262+
# It most certainly does exist
263+
trainer.add_callback(callback(trainer)) # type: ignore
256264

257265
trainer.train() # type: ignore
258266

259267
# Save the training time
260268
self.config.general.last_train_on = datetime.now().timestamp() # type: ignore
261269

262270
# Save everything
263-
self.save(save_dir_path=os.path.join(self.training_arguments.output_dir, 'final_model'))
271+
output_dir = self.training_arguments.output_dir
272+
if output_dir is None:
273+
# NOTE: this shouldn't really happen, but we'll do this for type safety
274+
raise ValueError("Output path should not be None!")
275+
self.save(save_dir_path=os.path.join(output_dir, 'final_model'))
264276

265277
# Run an eval step and return metrics
266278
p = trainer.predict(encoded_dataset['test']) # type: ignore

medcat/tokenizers/meta_cat_tokenizers.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,13 @@ def load(cls, dir_path: str, model_variant: Optional[str] = '', **kwargs) -> "To
193193
try:
194194
tokenizer.hf_tokenizers = BertTokenizerFast.from_pretrained(path, **kwargs)
195195
except Exception as e:
196-
logging.warning("Could not load tokenizer from path due to error: {}. Loading from library for model variant: {}".format(e,model_variant))
197-
tokenizer.hf_tokenizers = BertTokenizerFast.from_pretrained(model_variant)
196+
# So that this is a string - it should be as it's only used in MetaCAT.load method
197+
# with `config.model.model_variant` which is a `str` rather than None
198+
# NOTE: The reason the type in method signature is Optional[str] is because supertype defines it as such
199+
variant = str(model_variant)
200+
logging.warning("Could not load tokenizer from path due to error: %s. Loading from library for model variant: %s",
201+
e, variant)
202+
tokenizer.hf_tokenizers = BertTokenizerFast.from_pretrained(variant)
198203

199204
return tokenizer
200205

medcat/utils/relation_extraction/tokenizer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ class BaseTokenizerWrapper_RelationExtraction(PreTrainedTokenizerFast, ABC):
1616
def __init__(self, hf_tokenizers=None, max_seq_length: Optional[int] = None, add_special_tokens: Optional[bool] = False):
1717
self.hf_tokenizers = hf_tokenizers
1818
self.max_seq_length = max_seq_length
19-
self.add_special_tokens = add_special_tokens
19+
self._add_special_tokens = add_special_tokens
2020

2121
def get_size(self):
2222
return len(self.hf_tokenizers.vocab)
@@ -30,7 +30,7 @@ def get_pad_id(self):
3030
def __call__(self, text, truncation: Optional[bool] = True):
3131
if isinstance(text, str):
3232
result = self.hf_tokenizers.encode_plus(text, return_offsets_mapping=True, return_length=True, return_token_type_ids=True, return_attention_mask=True,
33-
add_special_tokens=self.add_special_tokens, max_length=self.max_seq_length, padding="longest", truncation=truncation)
33+
add_special_tokens=self._add_special_tokens, max_length=self.max_seq_length, padding="longest", truncation=truncation)
3434

3535
return {'offset_mapping': result['offset_mapping'],
3636
'input_ids': result['input_ids'],
@@ -41,7 +41,7 @@ def __call__(self, text, truncation: Optional[bool] = True):
4141
}
4242
elif isinstance(text, list):
4343
results = self.hf_tokenizers._batch_encode_plus(text, return_offsets_mapping=True, return_length=True, return_token_type_ids=True,
44-
add_special_tokens=self.add_special_tokens, max_length=self.max_seq_length,truncation=truncation)
44+
add_special_tokens=self._add_special_tokens, max_length=self.max_seq_length,truncation=truncation)
4545
output = []
4646
for ind in range(len(results['input_ids'])):
4747
output.append({

tests/test_transformers_ner.py

Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
import unittest
2+
import tempfile
3+
import json
4+
import os
5+
import shutil
6+
from medcat.cdb import CDB
7+
from medcat.ner.transformers_ner import TransformersNER
8+
from medcat.config_transformers_ner import ConfigTransformersNER
9+
10+
class TestTransformersNER(unittest.TestCase):
11+
def setUp(self):
12+
# Create a temporary directory for the test
13+
self.tmp_dir = tempfile.TemporaryDirectory()
14+
# Create results dir for training outputs
15+
self.results_dir = './results'
16+
os.makedirs(self.results_dir, exist_ok=True)
17+
18+
# Create a minimal CDB
19+
self.cdb = CDB()
20+
21+
# Create initial training data with 2 labels and multiple examples
22+
self.initial_data = {
23+
"projects": [{
24+
"documents": [
25+
{
26+
"text": "Patient has diabetes and hypertension.",
27+
"annotations": [
28+
{
29+
"cui": "C0011849", # Diabetes
30+
"start": 14,
31+
"end": 22,
32+
"value": "diabetes"
33+
},
34+
{
35+
"cui": "C0020538", # Hypertension
36+
"start": 27,
37+
"end": 39,
38+
"value": "hypertension"
39+
}
40+
]
41+
},
42+
{
43+
"text": "History of diabetes with hypertension.",
44+
"annotations": [
45+
{
46+
"cui": "C0011849", # Diabetes
47+
"start": 12,
48+
"end": 20,
49+
"value": "diabetes"
50+
},
51+
{
52+
"cui": "C0020538", # Hypertension
53+
"start": 26,
54+
"end": 38,
55+
"value": "hypertension"
56+
}
57+
]
58+
},
59+
{
60+
"text": "Diagnosed with hypertension and diabetes.",
61+
"annotations": [
62+
{
63+
"cui": "C0020538", # Hypertension
64+
"start": 15,
65+
"end": 27,
66+
"value": "hypertension"
67+
},
68+
{
69+
"cui": "C0011849", # Diabetes
70+
"start": 32,
71+
"end": 40,
72+
"value": "diabetes"
73+
}
74+
]
75+
}
76+
]
77+
}]
78+
}
79+
80+
# Create new training data with an extra label
81+
self.new_data = {
82+
"projects": [{
83+
"documents": [
84+
{
85+
"text": "Patient has diabetes, hypertension, and asthma.",
86+
"annotations": [
87+
{
88+
"cui": "C0011849", # Diabetes
89+
"start": 14,
90+
"end": 22,
91+
"value": "diabetes"
92+
},
93+
{
94+
"cui": "C0020538", # Hypertension
95+
"start": 24,
96+
"end": 36,
97+
"value": "hypertension"
98+
},
99+
{
100+
"cui": "C0004096", # Asthma
101+
"start": 42,
102+
"end": 48,
103+
"value": "asthma"
104+
}
105+
]
106+
},
107+
{
108+
"text": "History of asthma with diabetes and hypertension.",
109+
"annotations": [
110+
{
111+
"cui": "C0004096", # Asthma
112+
"start": 12,
113+
"end": 18,
114+
"value": "asthma"
115+
},
116+
{
117+
"cui": "C0011849", # Diabetes
118+
"start": 24,
119+
"end": 32,
120+
"value": "diabetes"
121+
},
122+
{
123+
"cui": "C0020538", # Hypertension
124+
"start": 37,
125+
"end": 49,
126+
"value": "hypertension"
127+
}
128+
]
129+
},
130+
{
131+
"text": "Diagnosed with asthma, diabetes, and hypertension.",
132+
"annotations": [
133+
{
134+
"cui": "C0004096", # Asthma
135+
"start": 15,
136+
"end": 21,
137+
"value": "asthma"
138+
},
139+
{
140+
"cui": "C0011849", # Diabetes
141+
"start": 23,
142+
"end": 31,
143+
"value": "diabetes"
144+
},
145+
{
146+
"cui": "C0020538", # Hypertension
147+
"start": 37,
148+
"end": 49,
149+
"value": "hypertension"
150+
}
151+
]
152+
}
153+
]
154+
}]
155+
}
156+
157+
# Save initial training data
158+
self.initial_data_path = os.path.join(self.tmp_dir.name, 'initial_data.json')
159+
with open(self.initial_data_path, 'w') as f:
160+
json.dump(self.initial_data, f)
161+
162+
# Save new training data
163+
self.new_data_path = os.path.join(self.tmp_dir.name, 'new_data.json')
164+
with open(self.new_data_path, 'w') as f:
165+
json.dump(self.new_data, f)
166+
167+
def tearDown(self):
168+
# Clean up the temporary directory
169+
self.tmp_dir.cleanup()
170+
# Clean up results directory if it exists
171+
if os.path.exists(self.results_dir):
172+
shutil.rmtree(self.results_dir)
173+
# Clean up logs directory if it exists
174+
if os.path.exists('./logs'):
175+
shutil.rmtree('./logs')
176+
177+
def test_ignore_extra_labels(self):
178+
# Create and train initial model with tiny BERT
179+
config = ConfigTransformersNER()
180+
config.general['model_name'] = 'prajjwal1/bert-tiny'
181+
# Set to single epoch and small test size for faster testing
182+
config.general['num_train_epochs'] = 1
183+
config.general['test_size'] = 0.1
184+
185+
# Create training arguments with reduced epochs
186+
from transformers import TrainingArguments
187+
training_args = TrainingArguments(
188+
output_dir=self.results_dir, # Use the class results_dir
189+
num_train_epochs=1
190+
)
191+
192+
ner = TransformersNER(self.cdb, config=config, training_arguments=training_args)
193+
ner.train(self.initial_data_path)
194+
195+
# Save the model
196+
model_path = os.path.join(self.tmp_dir.name, 'model')
197+
ner.save(model_path)
198+
199+
# Load the saved model
200+
loaded_ner = TransformersNER.load(model_path)
201+
202+
# Get initial number of labels
203+
initial_num_labels = len(loaded_ner.tokenizer.label_map)
204+
205+
# Train with ignore_extra_labels=True
206+
loaded_ner.train(self.new_data_path, ignore_extra_labels=True)
207+
208+
# Verify number of labels hasn't changed
209+
self.assertEqual(
210+
len(loaded_ner.tokenizer.label_map),
211+
initial_num_labels,
212+
"Number of labels changed despite ignore_extra_labels=True"
213+
)
214+
215+
# Verify only original labels are present (including special tokens)
216+
expected_labels = {"C0011849", "C0020538", "O", "X"}
217+
self.assertEqual(
218+
set(loaded_ner.tokenizer.label_map.keys()),
219+
expected_labels,
220+
"Label map contains unexpected labels"
221+
)
222+
223+
# Train with ignore_extra_labels=False
224+
loaded_ner.train(self.new_data_path, ignore_extra_labels=False)
225+
226+
# Verify new label was added
227+
self.assertEqual(
228+
len(loaded_ner.tokenizer.label_map),
229+
initial_num_labels + 1,
230+
"New label was not added when ignore_extra_labels=False"
231+
)
232+
233+
# Verify all labels are present (including special tokens)
234+
expected_labels = {"C0011849", "C0020538", "C0004096", "O", "X"}
235+
self.assertEqual(
236+
set(loaded_ner.tokenizer.label_map.keys()),
237+
expected_labels,
238+
"Label map missing expected labels"
239+
)
240+
241+
if __name__ == '__main__':
242+
unittest.main()

0 commit comments

Comments
 (0)