Skip to content
This repository was archived by the owner on Jul 28, 2025. It is now read-only.

Commit b7f5efe

Browse files
committed
CU-8697v6qr2 support expansion of transformers ner models to include new concepts
1 parent b8692fe commit b7f5efe

File tree

4 files changed

+91
-0
lines changed

4 files changed

+91
-0
lines changed

medcat/ner/transformers_ner.py

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import json
33
import logging
44
import datasets
5+
import torch
56
from spacy.tokens import Doc
67
from datetime import datetime
78
from typing import Iterable, Iterator, Optional, Dict, List, cast, Union, Tuple, Callable, Type
@@ -330,6 +331,57 @@ def save(self, save_dir_path: str) -> None:
330331
# This is everything we need to save from the class, we do not
331332
#save the class itself.
332333

334+
def expand_model_with_concepts(self, cui2preferred_name: Dict[str, str], use_avg_init: bool = True) -> None:
    """Expand the model with new concepts and their preferred names.

    Each concept that is not yet a key of ``self.model.config.label2id`` is
    registered in ``self.cdb``, given the next free label ID and appended as
    one extra output row to ``self.model.classifier``. Concepts whose ID is
    already a known label are skipped. The new rows are untrained, so the
    model requires subsequent retraining.

    Args:
        cui2preferred_name(Dict[str, str]):
            Dictionary where each key is the literal ID of the concept to be added and each value is its preferred name.
        use_avg_init(bool):
            Whether to use the average of existing weights or biases as the initial value for the new concept.
            When False, values are drawn from a standard normal distribution. Defaults to True.
    """
    config = self.model.config
    classifier = self.model.classifier

    # Only concepts that are not already labels need any work.
    new_concepts = [
        (label, preferred_name)
        for label, preferred_name in cui2preferred_name.items()
        if label not in config.label2id
    ]
    if not new_concepts:
        return

    # Register every name in the CDB first: if this raises, the label maps
    # and the classifier have not been touched and stay consistent.
    for label, preferred_name in new_concepts:
        sname = preferred_name.lower().replace(" ", "~")
        new_names = {
            sname: {
                "tokens": [],
                "snames": [sname],
                "raw_name": preferred_name,
                "is_upper": True
            }
        }
        self.cdb.add_names(cui=label, names=new_names, name_status="P", full_build=True)

    # Hand out consecutive IDs following the largest existing one.
    next_label_id = max(config.label2id.values(), default=-1) + 1
    for label, _ in new_concepts:
        config.label2id[label] = next_label_id
        config.id2label[next_label_id] = label
        self.tokenizer.label_map[label] = next_label_id
        next_label_id += 1
    # Rebuild the name lookup once, after all labels are known.
    self.tokenizer.cui2name = {k: self.cdb.get_name(k) for k in self.tokenizer.label_map}

    # Detached views: the new parameters must not drag along autograd history.
    old_weight = classifier.weight.detach()
    old_bias = classifier.bias.detach()
    nr_new = len(new_concepts)
    if use_avg_init:
        # Every new row starts as the mean of the pre-existing rows.
        extra_weight = torch.mean(old_weight, dim=0, keepdim=True).expand(nr_new, -1)
        extra_bias = torch.mean(old_bias, dim=0, keepdim=True).expand(nr_new)
    else:
        # Create the random rows with the classifier's own dtype and device
        # so the concatenation neither fails across devices nor changes dtype.
        extra_weight = torch.randn(
            nr_new, old_weight.shape[1], dtype=old_weight.dtype, device=old_weight.device
        )
        extra_bias = torch.randn(nr_new, dtype=old_bias.dtype, device=old_bias.device)

    # Grow the output layer once for all new concepts.
    classifier.weight = torch.nn.Parameter(torch.cat((old_weight, extra_weight), 0))
    classifier.bias = torch.nn.Parameter(torch.cat((old_bias, extra_bias), 0))
    self.model.num_labels += nr_new
    classifier.out_features += nr_new
384+
333385
@classmethod
334386
def load(cls, save_dir_path: str, config_dict: Optional[Dict] = None) -> "TransformersNER":
335387
"""Load a meta_cat object.

medcat/utils/ner/model.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,21 @@ def get_entities(self, text: str, *args, **kwargs) -> dict:
7676
"""
7777
return self.cat.get_entities(text, *args, **kwargs)
7878

79+
def add_new_concepts(self,
                     cui2preferred_name: Dict[str, str],
                     train_nr: int = 0,
                     with_random_init: bool = False) -> None:
    """Add new concepts to the model and the concept database.

    The work is delegated to ``expand_model_with_concepts`` of the selected
    NER object. Invoking this requires subsequent retraining on the model.

    Args:
        cui2preferred_name(Dict[str, str]): Dictionary where each key is the literal ID of the concept to be added and each value is its preferred name.
        train_nr (int): The index of the NER object in cat._addl_ner to which new concepts will be added. Defaults to 0.
        with_random_init (bool): Whether to use the random init strategy for the new concepts. Defaults to False.
    """
    target_ner = self.cat._addl_ner[train_nr]
    # The NER object expresses the same choice the other way round.
    average_init = not with_random_init
    target_ner.expand_model_with_concepts(cui2preferred_name, use_avg_init=average_init)
93+
7994
@property
8095
def config(self) -> Config:
8196
return self.cat.config

tests/ner/test_transformers_ner.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,3 +48,20 @@ def on_epoch_end(self, *args, **kwargs) -> None:
4848
assert dataset["train"].num_rows == 48
4949
assert dataset["test"].num_rows == 12
5050
self.assertEqual(tracker.call.call_count, 2)
51+
52+
def test_expand_model_with_concepts(self):
    """Expanding with two unseen concepts grows every label structure by two."""
    ner = self.undertest
    new_concepts = {
        "concept_1": "Preferred Name 1",
        "concept_2": "Preferred Name 2",
    }
    nr_added = len(new_concepts)
    # Capture the expected sizes before the model is modified.
    expected_num_labels = ner.model.num_labels + nr_added
    expected_out_features = ner.model.classifier.out_features + nr_added
    expected_label_map_size = len(ner.tokenizer.label_map) + nr_added

    ner.expand_model_with_concepts(new_concepts)

    assert ner.model.num_labels == expected_num_labels
    assert ner.model.classifier.out_features == expected_out_features
    assert len(ner.tokenizer.label_map) == expected_label_map_size
    # Each new concept must resolve to the preferred name it was added with.
    for cui, preferred_name in new_concepts.items():
        assert ner.tokenizer.cui2name.get(cui) == preferred_name

tests/utils/ner/test_deid.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,13 @@ def test_training(self):
9090
self.assertIsNotNone(examples)
9191
self.assertIsNotNone(dataset)
9292

93+
def test_add_new_concepts(self):
    """A concept added via the de-id model shows up in the CDB, the model config and the tokenizer."""
    self.deid_model.add_new_concepts({'CONCEPT': "Concept"}, with_random_init=True)
    cdb = self.deid_model.cat.cdb
    ner = self.deid_model.cat._addl_ner[0]
    # assertIn reports the searched container on failure, unlike assertTrue(x in y).
    self.assertIn("CONCEPT", cdb.cui2names)
    self.assertEqual(cdb.cui2names["CONCEPT"], {"concept"})
    self.assertIn("CONCEPT", ner.model.config.label2id)
    self.assertIn("CONCEPT", ner.tokenizer.label_map)
    self.assertIn("CONCEPT", ner.tokenizer.cui2name)
93100

94101
input_text = '''
95102
James Joyce

0 commit comments

Comments
 (0)