Skip to content
This repository was archived by the owner on Jul 28, 2025. It is now read-only.

Commit 9cc52b4

Browse files
committed
Latest fixes for model saving/loading.
1 parent 0f5e61b commit 9cc52b4

File tree

11 files changed

+57
-53
lines changed

11 files changed

+57
-53
lines changed

medcat/rel_cat.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,8 +126,9 @@ def load(cls, load_path: str = "./") -> "RelCAT":
126126

127127
device = torch.device("cuda" if torch.cuda.is_available() and component.relcat_config.general.device != "cpu" else "cpu")
128128

129-
rel_cat = RelCAT(cdb=cdb, config=component.relcat_config, task=component.task, init_model=False)
129+
rel_cat = RelCAT(cdb=cdb, config=component.relcat_config, task=component.task)
130130
rel_cat.device = device
131+
rel_cat.component = component
131132

132133
return rel_cat
133134

@@ -255,7 +256,7 @@ def train(self, export_data_path:str = "", train_csv_path:str = "", test_csv_pat
255256
gamma=self.component.relcat_config.train.multistep_lr_gamma) # type: ignore
256257

257258
self.epoch, self.best_f1 = load_state(
258-
self.component.model, self.component.optimizer, self.component.scheduler, load_best=False, path=checkpoint_path, config=self.component.relcat_config)
259+
self.component.model, self.component.optimizer, self.component.scheduler, load_best=False, path=checkpoint_path, relcat_config=self.component.relcat_config)
259260

260261
self.log.info("Starting training process...")
261262

medcat/utils/relation_extraction/base_component.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -95,11 +95,9 @@ def save(self, save_path: str) -> None:
9595
self.model.hf_model.resize_token_embeddings(self.tokenizer.get_size()) # type: ignore
9696

9797
assert self.model_config is not None
98-
self.model_config.vocab_size = self.tokenizer.get_size()
99-
self.model_config.pad_token_id = self.pad_id
100-
101-
self.model_config.to_json_file(
102-
os.path.join(save_path, "model_config.json"))
98+
self.model_config.hf_model_config.vocab_size = self.tokenizer.get_size()
99+
self.model_config.hf_model_config.pad_token_id = self.pad_id
100+
self.model_config.save(save_path)
103101

104102
save_state(self.model, optimizer=self.optimizer, scheduler=self.scheduler, epoch=self.epoch, best_f1=self.best_f1,
105103
path=save_path, model_name=self.relcat_config.general.model_name,
@@ -116,16 +114,17 @@ def load(cls, pretrained_model_name_or_path: str = "./") -> "BaseComponent_Relat
116114
"""
117115

118116
relcat_config = ConfigRelCAT.load(load_path=pretrained_model_name_or_path)
117+
119118
model_config = BaseConfig_RelationExtraction.load(pretrained_model_name_or_path=pretrained_model_name_or_path,
120119
relcat_config=relcat_config)
121120

121+
tokenizer = BaseTokenizerWrapper_RelationExtraction.load(tokenizer_path=pretrained_model_name_or_path,
122+
relcat_config=relcat_config)
123+
122124
model = BaseModel_RelationExtraction.load(pretrained_model_name_or_path=pretrained_model_name_or_path,
123125
model_config=model_config,
124126
relcat_config=relcat_config)
125127

126-
tokenizer = BaseTokenizerWrapper_RelationExtraction.load(tokenizer_path=pretrained_model_name_or_path,
127-
relcat_config=relcat_config)
128-
129128
model.hf_model.resize_token_embeddings(len(tokenizer.hf_tokenizers)) # type: ignore
130129

131130
optimizer = None # type: ignore
@@ -134,7 +133,7 @@ def load(cls, pretrained_model_name_or_path: str = "./") -> "BaseComponent_Relat
134133
epoch, best_f1 = load_state(model, optimizer, scheduler, path=pretrained_model_name_or_path,
135134
model_name=relcat_config.general.model_name,
136135
file_prefix=relcat_config.general.task,
137-
config=relcat_config)
136+
relcat_config=relcat_config)
138137

139138
component = cls(model=model, tokenizer=tokenizer, model_config=model_config, config=relcat_config)
140139
cls.epoch = epoch

medcat/utils/relation_extraction/bert/config.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,10 @@ class BertConfig_RelationExtraction(BaseConfig_RelationExtraction):
1717
@classmethod
1818
def load(cls, pretrained_model_name_or_path: str, relcat_config: ConfigRelCAT, **kwargs) -> "BertConfig_RelationExtraction":
1919
model_config = cls(pretrained_model_name_or_path, **kwargs)
20-
model_config_path = os.path.join(pretrained_model_name_or_path, "model_config.json")
2120

22-
if pretrained_model_name_or_path and os.path.exists(model_config_path):
23-
model_config.hf_model_config = BertConfig.from_json_file(model_config_path)
24-
logger.info("Loaded config from file: " + model_config_path)
21+
if pretrained_model_name_or_path and os.path.exists(pretrained_model_name_or_path):
22+
model_config.hf_model_config = BertConfig.from_json_file(pretrained_model_name_or_path)
23+
logger.info("Loaded config from file: " + pretrained_model_name_or_path)
2524
else:
2625
relcat_config.general.model_name = cls.pretrained_model_name_or_path
2726
model_config.hf_model_config = BertConfig.from_pretrained(

medcat/utils/relation_extraction/bert/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ def __init__(self, pretrained_model_name_or_path: str, relcat_config: ConfigRelC
4242
self.model_config: Union[BaseConfig_RelationExtraction, BertConfig_RelationExtraction] = model_config
4343
self.pretrained_model_name_or_path: str = pretrained_model_name_or_path
4444

45-
self.hf_model: PreTrainedModel = PreTrainedModel(model_config) # type: ignore
45+
self.hf_model: Union[BertModel, PreTrainedModel] = BertModel(model_config.hf_model_config) # type: ignore
4646

4747
for param in self.hf_model.parameters(): # type: ignore
4848
if self.relcat_config.model.freeze_layers:

medcat/utils/relation_extraction/config.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,19 +40,20 @@ def load(cls, pretrained_model_name_or_path: str, relcat_config: ConfigRelCAT, *
4040
if "modern-bert" in relcat_config.general.tokenizer_name or \
4141
"modern-bert" in relcat_config.general.model_name:
4242
from medcat.utils.relation_extraction.modernbert.config import ModernBertConfig_RelationExtraction
43-
model_config = ModernBertConfig_RelationExtraction.load(model_config_path, **kwargs)
43+
model_config = ModernBertConfig_RelationExtraction.load(model_config_path, relcat_config=relcat_config, **kwargs)
4444
elif "bert" in relcat_config.general.tokenizer_name or \
4545
"bert" in relcat_config.general.model_name:
4646
from medcat.utils.relation_extraction.bert.config import BertConfig_RelationExtraction
47-
model_config = BertConfig_RelationExtraction.load(model_config_path, **kwargs)
47+
model_config = BertConfig_RelationExtraction.load(model_config_path, relcat_config=relcat_config, **kwargs)
4848
elif "llama" in relcat_config.general.tokenizer_name or \
4949
"llama" in relcat_config.general.model_name:
5050
from medcat.utils.relation_extraction.llama.config import LlamaConfig_RelationExtraction
51-
model_config = LlamaConfig_RelationExtraction.load(model_config_path, **kwargs)
51+
model_config = LlamaConfig_RelationExtraction.load(model_config_path, relcat_config=relcat_config, **kwargs)
5252
else:
5353
if pretrained_model_name_or_path:
5454
model_config.hf_model_config = PretrainedConfig.from_pretrained(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
5555
else:
5656
model_config.hf_model_config = PretrainedConfig.from_pretrained(pretrained_model_name_or_path=relcat_config.general.model_name, **kwargs)
5757
logger.info("Loaded config from : " + model_config_path)
58+
5859
return model_config

medcat/utils/relation_extraction/llama/config.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,13 @@ class LlamaConfig_RelationExtraction(BaseConfig_RelationExtraction):
1717
@classmethod
1818
def load(cls, pretrained_model_name_or_path: str, relcat_config: ConfigRelCAT, **kwargs) -> "LlamaConfig_RelationExtraction":
1919
model_config = cls(pretrained_model_name_or_path, **kwargs)
20-
model_config_path = os.path.join(pretrained_model_name_or_path, "model_config.json")
2120

22-
if pretrained_model_name_or_path and os.path.exists(model_config_path):
23-
model_config.model_config = LlamaConfig.from_json_file(model_config_path)
24-
logger.info("Loaded config from file: " + model_config_path)
21+
if pretrained_model_name_or_path and os.path.exists(pretrained_model_name_or_path):
22+
model_config.hf_model_config = LlamaConfig.from_json_file(pretrained_model_name_or_path)
23+
logger.info("Loaded config from file: " + pretrained_model_name_or_path)
2524
else:
2625
relcat_config.general.model_name = cls.pretrained_model_name_or_path
27-
model_config.model_config = LlamaConfig.from_pretrained(
26+
model_config.hf_model_config = LlamaConfig.from_pretrained(
2827
pretrained_model_name_or_path=cls.pretrained_model_name_or_path, **kwargs)
2928
logger.info("Loaded config from pretrained: " + relcat_config.general.model_name)
3029

medcat/utils/relation_extraction/ml_utils.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def save_state(model, optimizer: torch.optim.AdamW, scheduler: torch.optim.lr_sc
105105
If you want to export the model after training set final_export=True and leave is_checkpoint=False.
106106
107107
Args:
108-
model (Base_RelationExtraction): BertModel_RelationExtraction | LlamaModel_RelationExtraction
108+
model (BaseModel_RelationExtraction): BertModel_RelationExtraction | LlamaModel_RelationExtraction etc.
109109
optimizer (torch.optim.AdamW, optional): Defaults to None.
110110
scheduler (torch.optim.lr_scheduler.MultiStepLR, optional): Defaults to None.
111111
epoch (int): Defaults to None.
@@ -136,11 +136,11 @@ def save_state(model, optimizer: torch.optim.AdamW, scheduler: torch.optim.lr_sc
136136
}, os.path.join(path, file_name))
137137

138138

139-
def load_state(model, optimizer, scheduler, path="./", model_name="BERT", file_prefix="train", load_best=False, config: ConfigRelCAT = ConfigRelCAT()) -> Tuple[int, int]:
139+
def load_state(model, optimizer, scheduler, path: str = "./", model_name:str = "BERT", file_prefix:str = "train", load_best: bool = False, relcat_config: ConfigRelCAT = ConfigRelCAT()) -> Tuple[int, int]:
140140
""" Used by RelCAT.load() and RelCAT.train()
141141
142142
Args:
143-
model (Base_RelationExtraction): BertModel_RelationExtraction | LlamaModel_RelationExtraction, it has to be initialized before calling this method via (Bert/Llama)Model_RelationExtraction(...)
143+
model (BaseModel_RelationExtraction): BaseModel_RelationExtraction, it has to be initialized before calling this method via (Bert/Llama)Model_RelationExtraction(...)
144144
optimizer (_type_): optimizer
145145
scheduler (_type_): scheduler
146146
path (str, optional): Defaults to "./".
@@ -153,7 +153,7 @@ def load_state(model, optimizer, scheduler, path="./", model_name="BERT", file_p
153153
Tuple (int, int): last epoch and f1 score.
154154
"""
155155

156-
device: torch.device =torch.device(config.general.device)
156+
device: torch.device =torch.device(relcat_config.general.device)
157157

158158
model_name = model_name.replace("/", "_")
159159
logging.info("Attempting to load RelCAT model on device: " + str(device))
@@ -178,13 +178,13 @@ def load_state(model, optimizer, scheduler, path="./", model_name="BERT", file_p
178178

179179
if optimizer is None:
180180
parameters = filter(lambda p: p.requires_grad, model.parameters())
181-
optimizer = torch.optim.AdamW(params=parameters, lr=config.train.lr, weight_decay=config.train.adam_weight_decay,
182-
betas=config.train.adam_betas, eps=config.train.adam_epsilon)
181+
optimizer = torch.optim.AdamW(params=parameters, lr=relcat_config.train.lr, weight_decay=relcat_config.train.adam_weight_decay,
182+
betas=relcat_config.train.adam_betas, eps=relcat_config.train.adam_epsilon)
183183

184184
if scheduler is None:
185185
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
186-
milestones=config.train.multistep_milestones,
187-
gamma=config.train.multistep_lr_gamma)
186+
milestones=relcat_config.train.multistep_milestones,
187+
gamma=relcat_config.train.multistep_lr_gamma)
188188
optimizer.load_state_dict(checkpoint['optimizer'])
189189
scheduler.load_state_dict(checkpoint['scheduler'])
190190
logging.info("Loaded model and optimizer.")

medcat/utils/relation_extraction/models.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -98,19 +98,26 @@ def __init__(self, relcat_config: ConfigRelCAT,
9898
self.hf_model: Union[ModernBertModel, BertModel, LlamaModel, PreTrainedModel] = PreTrainedModel(config=model_config.hf_model_config) # type: ignore
9999
self.pretrained_model_name_or_path: str = pretrained_model_name_or_path
100100

101+
self._reinitialize_dense_and_frozen_layers(relcat_config=relcat_config)
102+
103+
self.log.info("RelCAT model config: " + str(self.model_config.hf_model_config))
104+
105+
def _reinitialize_dense_and_frozen_layers(self, relcat_config: ConfigRelCAT) -> None:
106+
""" Reinitialize the dense layers of the model
107+
108+
Args:
109+
relcat_config (ConfigRelCAT): relcat config.
110+
"""
111+
112+
self.drop_out = nn.Dropout(relcat_config.model.dropout)
113+
self.fc1, self.fc2, self.fc3 = create_dense_layers(relcat_config)
114+
101115
for param in self.hf_model.parameters(): # type: ignore
102116
if self.relcat_config.model.freeze_layers:
103117
param.requires_grad = False
104118
else:
105119
param.requires_grad = True
106120

107-
self.drop_out = nn.Dropout(self.relcat_config.model.dropout)
108-
109-
# dense layers
110-
self.fc1, self.fc2, self.fc3 = create_dense_layers(self.relcat_config)
111-
112-
self.log.info("RelCAT model config: " + str(self.model_config.hf_model_config))
113-
114121
def forward(self,
115122
input_ids: Optional[torch.Tensor] = None,
116123
attention_mask: Optional[torch.Tensor] = None,
@@ -250,4 +257,7 @@ def load(cls, pretrained_model_name_or_path: str, relcat_config: ConfigRelCAT, m
250257
cls.log.info("Loaded model from relcat_config: " + relcat_config.general.model_name)
251258

252259
cls.log.info("Loaded " + str(model.__class__.__name__) + " from pretrained_model_name_or_path: " + pretrained_model_name_or_path)
260+
261+
model._reinitialize_dense_and_frozen_layers(relcat_config=relcat_config)
262+
253263
return model

medcat/utils/relation_extraction/modernbert/config.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,15 +17,13 @@ class ModernBertConfig_RelationExtraction(BaseConfig_RelationExtraction):
1717
@classmethod
1818
def load(cls, pretrained_model_name_or_path: str, relcat_config: ConfigRelCAT, **kwargs) -> "ModernBertConfig_RelationExtraction":
1919
model_config = cls(pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)
20-
model_config_path = os.path.join(pretrained_model_name_or_path, "model_config.json")
2120

22-
23-
if pretrained_model_name_or_path and os.path.exists(model_config_path):
24-
model_config.model_config = ModernBertConfig.from_json_file(model_config_path)
25-
logger.info("Loaded config from file: " + model_config_path)
21+
if pretrained_model_name_or_path and os.path.exists(pretrained_model_name_or_path):
22+
model_config.hf_model_config = ModernBertConfig.from_json_file(pretrained_model_name_or_path)
23+
logger.info("Loaded config from file: " + pretrained_model_name_or_path)
2624
else:
2725
relcat_config.general.model_name = cls.pretrained_model_name_or_path
28-
model_config.model_config = ModernBertConfig.from_pretrained(
26+
model_config.hf_model_config = ModernBertConfig.from_pretrained(
2927
pretrained_model_name_or_path=cls.pretrained_model_name_or_path, **kwargs)
3028
logger.info("Loaded config from pretrained: " + relcat_config.general.model_name)
3129

medcat/utils/relation_extraction/modernbert/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def __init__(self, pretrained_model_name_or_path: str, relcat_config: ConfigRelC
4040
self.model_config: Union[BaseConfig_RelationExtraction, ModernBertConfig_RelationExtraction] = model_config
4141
self.pretrained_model_name_or_path: str = pretrained_model_name_or_path
4242

43-
self.hf_model: Union[ModernBertModel, PreTrainedModel] = PreTrainedModel(config=model_config.hf_model_config)
43+
self.hf_model: Union[ModernBertModel, PreTrainedModel] = ModernBertModel(config=model_config.hf_model_config)
4444

4545
for param in self.hf_model.parameters(): # type: ignore
4646
if self.relcat_config.model.freeze_layers:

0 commit comments

Comments
 (0)