Skip to content
This repository was archived by the owner on Jul 28, 2025. It is now read-only.

Commit 15809b2

Browse files
Adding functionality for offline loading
Allow model loading entirely from local files
1 parent 2992f1a commit 15809b2

File tree

3 files changed

+33
-7
lines changed

3 files changed

+33
-7
lines changed

medcat/config_meta_cat.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,12 @@ class Model(MixingConfig, BaseModel):
133133
134134
NB! For these changes to take effect, the pipe would need to be recreated.
135135
"""
136+
load_bert_pretrained_weights: bool = False
137+
"""
138+
Applicable only when using BERT:
139+
Determines if the pretrained weights for BERT are loaded
140+
This should be True if you do not plan on using the weights saved in the model pack."""
141+
136142
num_layers: int = 2
137143
"""Number of layers in the model (both LSTM and BERT)
138144
@@ -164,7 +170,9 @@ class Model(MixingConfig, BaseModel):
164170
165171
Paper reference - https://ieeexplore.ieee.org/document/7533053"""
166172
category_undersample: str = ''
167-
"""When using 2 phase learning, this category is used to undersample the data"""
173+
"""When using 2 phase learning, this category is used to undersample the data
174+
The number of samples in this category sets the upper limit for all categories."""
175+
168176
model_architecture_config: Dict = {'fc2': True, 'fc3': False,'lr_scheduler': True}
169177
"""Specifies the architecture for BERT model.
170178

medcat/meta_cat.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,11 +51,13 @@ class MetaCAT(PipeRunner):
5151
def __init__(self,
5252
tokenizer: Optional[TokenizerWrapperBase] = None,
5353
embeddings: Optional[Union[Tensor, numpy.ndarray]] = None,
54-
config: Optional[ConfigMetaCAT] = None) -> None:
54+
config: Optional[ConfigMetaCAT] = None,
55+
save_dir_path: Optional[str] = None) -> None:
5556
if config is None:
5657
config = ConfigMetaCAT()
5758
self.config = config
5859
set_all_seeds(config.general['seed'])
60+
self.save_dir_path = save_dir_path
5961

6062
if tokenizer is not None:
6163
# Set it in the config
@@ -90,7 +92,7 @@ def get_model(self, embeddings: Optional[Tensor]) -> nn.Module:
9092

9193
elif config.model['model_name'] == 'bert':
9294
from medcat.utils.meta_cat.models import BertForMetaAnnotation
93-
model = BertForMetaAnnotation(config)
95+
model = BertForMetaAnnotation(config,self.save_dir_path)
9496

9597
if not config.model.model_freeze_layers:
9698
peft_config = LoraConfig(task_type=TaskType.SEQ_CLS, inference_mode=False, r=8, lora_alpha=16,
@@ -380,6 +382,8 @@ def save(self, save_dir_path: str) -> None:
380382
model_save_path = os.path.join(save_dir_path, 'model.dat')
381383
torch.save(self.model.state_dict(), model_save_path)
382384

385+
model_config_save_path = os.path.join(save_dir_path, 'bert_config.json')
386+
self.model.bert_config.to_json_file(model_config_save_path)
383387
# This is everything we need to save from the class, we do not
384388
# save the class itself.
385389

@@ -416,7 +420,7 @@ def load(cls, save_dir_path: str, config_dict: Optional[Dict] = None) -> "MetaCA
416420
tokenizer = TokenizerWrapperBERT.load(save_dir_path, config.model.model_variant)
417421

418422
# Create meta_cat
419-
meta_cat = cls(tokenizer=tokenizer, embeddings=None, config=config)
423+
meta_cat = cls(tokenizer=tokenizer, embeddings=None, config=config,save_dir_path=save_dir_path)
420424

421425
# Load the model
422426
model_save_path = os.path.join(save_dir_path, 'model.dat')

medcat/utils/meta_cat/models.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -87,16 +87,30 @@ def forward(self,
8787
class BertForMetaAnnotation(nn.Module):
8888
_keys_to_ignore_on_load_unexpected: List[str] = [r"pooler"] # type: ignore
8989

90-
def __init__(self, config):
90+
def __init__(self, config, save_dir_path = None):
9191
super(BertForMetaAnnotation, self).__init__()
92-
_bertconfig = AutoConfig.from_pretrained(config.model.model_variant,num_hidden_layers=config.model['num_layers'])
92+
if save_dir_path:
93+
try:
94+
_bertconfig = AutoConfig.from_pretrained(save_dir_path + "/bert_config.json",
95+
num_hidden_layers=config.model['num_layers'])
96+
except:
97+
_bertconfig = AutoConfig.from_pretrained(config.model.model_variant,
98+
num_hidden_layers=config.model['num_layers'])
99+
else:
100+
_bertconfig = AutoConfig.from_pretrained(config.model.model_variant,num_hidden_layers=config.model['num_layers'])
101+
93102
if config.model['input_size'] != _bertconfig.hidden_size:
94103
logger.warning("Input size for %s model should be %d, provided input size is %d. Input size changed to %d",config.model.model_variant,_bertconfig.hidden_size,config.model['input_size'],_bertconfig.hidden_size)
95104

96-
bert = BertModel.from_pretrained(config.model.model_variant, config=_bertconfig)
105+
if config.model['load_bert_pretrained_weights']:
106+
bert = BertModel.from_pretrained(config.model.model_variant, config=_bertconfig)
107+
else:
108+
bert = BertModel(_bertconfig)
109+
97110
self.config = config
98111
self.config.use_return_dict = False
99112
self.bert = bert
113+
self.bert_config = _bertconfig
100114
self.num_labels = config.model["nclasses"]
101115
for param in self.bert.parameters():
102116
param.requires_grad = not config.model.model_freeze_layers

0 commit comments

Comments (0)