Skip to content
This repository was archived by the owner on Jul 28, 2025. It is now read-only.

Commit 761b63f

Browse files
Pushing a type fix and adding more information in case of a connection error
1 parent 1643c41 commit 761b63f

File tree

2 files changed

+11
-6
lines changed

2 files changed

+11
-6
lines changed

medcat/meta_cat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -384,7 +384,7 @@ def save(self, save_dir_path: str) -> None:
384384

385385
if self.config.model.model_name == 'bert':
386386
model_config_save_path = os.path.join(save_dir_path, 'bert_config.json')
387-
self.model.bert_config.to_json_file(model_config_save_path)
387+
self.model.bert_config.to_json_file(model_config_save_path) # type: ignore
388388
# This is everything we need to save from the class, we do not
389389
# save the class itself.
390390

medcat/utils/meta_cat/models.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from collections import OrderedDict
33
from typing import Optional, Any, List, Iterable
44
from torch import nn, Tensor
5-
from transformers import BertModel, AutoConfig, BertConfig
5+
from transformers import BertModel, AutoConfig
66
from medcat.meta_cat import ConfigMetaCAT
77
import logging
88
logger = logging.getLogger(__name__)
@@ -87,30 +87,35 @@ def forward(self,
8787
class BertForMetaAnnotation(nn.Module):
8888
_keys_to_ignore_on_load_unexpected: List[str] = [r"pooler"] # type: ignore
8989

90-
def __init__(self, config, save_dir_path = None):
90+
def __init__(self, config, save_dir_path=None):
9191
super(BertForMetaAnnotation, self).__init__()
9292
if save_dir_path:
9393
try:
9494
_bertconfig = AutoConfig.from_pretrained(save_dir_path + "/bert_config.json",
9595
num_hidden_layers=config.model['num_layers'])
96-
except:
96+
except Exception:
9797
_bertconfig = AutoConfig.from_pretrained(config.model.model_variant,
9898
num_hidden_layers=config.model['num_layers'])
99+
logger.info("BERT config not found locally — downloaded successfully from Hugging Face.")
100+
99101
else:
100102
_bertconfig = AutoConfig.from_pretrained(config.model.model_variant,num_hidden_layers=config.model['num_layers'])
101103

102104
if config.model['input_size'] != _bertconfig.hidden_size:
103105
logger.warning("Input size for %s model should be %d, provided input size is %d. Input size changed to %d",config.model.model_variant,_bertconfig.hidden_size,config.model['input_size'],_bertconfig.hidden_size)
104106

105107
if config.model['load_bert_pretrained_weights']:
106-
bert = BertModel.from_pretrained(config.model.model_variant, config=_bertconfig)
108+
try:
109+
bert = BertModel.from_pretrained(config.model.model_variant, config=_bertconfig)
110+
except Exception:
111+
raise Exception("Could not load BERT pretrained weights from Hugging Face. \nIf you're seeing a connection error, set `config.model.load_bert_pretrained_weights=False` and make sure to load the model pack from disk instead.")
107112
else:
108113
bert = BertModel(_bertconfig)
109114

110115
self.config = config
111116
self.config.use_return_dict = False
112117
self.bert = bert
113-
self.bert_config: BertConfig = _bertconfig
118+
self.bert_config = _bertconfig
114119
self.num_labels = config.model["nclasses"]
115120
for param in self.bert.parameters():
116121
param.requires_grad = not config.model.model_freeze_layers

0 commit comments

Comments (0)