|
2 | 2 | from collections import OrderedDict |
3 | 3 | from typing import Optional, Any, List, Iterable |
4 | 4 | from torch import nn, Tensor |
5 | | -from transformers import BertModel, AutoConfig, BertConfig |
| 5 | +from transformers import BertModel, AutoConfig |
6 | 6 | from medcat.meta_cat import ConfigMetaCAT |
7 | 7 | import logging |
8 | 8 | logger = logging.getLogger(__name__) |
@@ -87,30 +87,35 @@ def forward(self, |
87 | 87 | class BertForMetaAnnotation(nn.Module): |
88 | 88 | _keys_to_ignore_on_load_unexpected: List[str] = [r"pooler"] # type: ignore |
89 | 89 |
|
90 | | - def __init__(self, config, save_dir_path = None): |
| 90 | + def __init__(self, config, save_dir_path=None): |
91 | 91 | super(BertForMetaAnnotation, self).__init__() |
92 | 92 | if save_dir_path: |
93 | 93 | try: |
94 | 94 | _bertconfig = AutoConfig.from_pretrained(save_dir_path + "/bert_config.json", |
95 | 95 | num_hidden_layers=config.model['num_layers']) |
96 | | - except: |
| 96 | + except Exception: |
97 | 97 | _bertconfig = AutoConfig.from_pretrained(config.model.model_variant, |
98 | 98 | num_hidden_layers=config.model['num_layers']) |
| 99 | + logger.info("BERT config not found locally — downloaded successfully from Hugging Face.") |
| 100 | + |
99 | 101 | else: |
100 | 102 | _bertconfig = AutoConfig.from_pretrained(config.model.model_variant,num_hidden_layers=config.model['num_layers']) |
101 | 103 |
|
102 | 104 | if config.model['input_size'] != _bertconfig.hidden_size: |
103 | 105 | logger.warning("Input size for %s model should be %d, provided input size is %d. Input size changed to %d",config.model.model_variant,_bertconfig.hidden_size,config.model['input_size'],_bertconfig.hidden_size) |
104 | 106 |
|
105 | 107 | if config.model['load_bert_pretrained_weights']: |
106 | | - bert = BertModel.from_pretrained(config.model.model_variant, config=_bertconfig) |
| 108 | + try: |
| 109 | + bert = BertModel.from_pretrained(config.model.model_variant, config=_bertconfig) |
| 110 | + except Exception: |
| 111 | + raise Exception("Could not load BERT pretrained weights from Hugging Face. \nIf you're seeing a connection error, set `config.model.load_bert_pretrained_weights=False` and make sure to load the model pack from disk instead.") |
107 | 112 | else: |
108 | 113 | bert = BertModel(_bertconfig) |
109 | 114 |
|
110 | 115 | self.config = config |
111 | 116 | self.config.use_return_dict = False |
112 | 117 | self.bert = bert |
113 | | - self.bert_config: BertConfig = _bertconfig |
| 118 | + self.bert_config = _bertconfig |
114 | 119 | self.num_labels = config.model["nclasses"] |
115 | 120 | for param in self.bert.parameters(): |
116 | 121 | param.requires_grad = not config.model.model_freeze_layers |
|
0 commit comments