Skip to content
This repository was archived by the owner on Jul 28, 2025. It is now read-only.

Commit 65f7c5e

Browse files
authored
CU-8698mqu96 Transformers update (4.51.0) fix (#531)
* CU-8698mqu96: Update special tokens lengths attribute * CU-8698mqu96: Update MetaCAT usage of BertTokenizer.from_pretrained for type safety * CU-8698mqu96: Ignore typing where mypy is wrong + add note in code * CU-8698mqu96: Ignore typing where mypy may be wrong + add comment * CU-8698mqu96: Fix tokenizer wrapper import for rel cat * CU-8698mqu96: Rename evaluation strategy keyword argument in line with changes * CU-8698mqu96: Type-ignore method where mypy says it does not exist * CU-8698mqu96: Fix TRF-NER output dir typing issue * CU-8698mqu96: Update a doc string for darglint * CU-8698mqu96: Fix typing issue for TrfNER trainer callback
1 parent 65faa7a commit 65f7c5e

File tree

6 files changed

+32
-15
lines changed

6 files changed

+32
-15
lines changed

medcat/meta_cat.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,8 +95,9 @@ def get_model(self, embeddings: Optional[Tensor]) -> nn.Module:
9595
if not config.model.model_freeze_layers:
9696
peft_config = LoraConfig(task_type=TaskType.SEQ_CLS, inference_mode=False, r=8, lora_alpha=16,
9797
target_modules=["query", "value"], lora_dropout=0.2)
98-
99-
model = get_peft_model(model, peft_config)
98+
# Not sure what changed between transformers 4.50.3 and 4.50.1 that made this
99+
# fail for mypy. But as best as I can tell, it still works just the same
100+
model = get_peft_model(model, peft_config) # type: ignore
100101
# model.print_trainable_parameters()
101102

102103
logger.info("BERT model used for classification")
@@ -412,7 +413,7 @@ def load(cls, save_dir_path: str, config_dict: Optional[Dict] = None) -> "MetaCA
412413
tokenizer = TokenizerWrapperBPE.load(save_dir_path)
413414
elif config.general['tokenizer_name'] == 'bert-tokenizer':
414415
from medcat.tokenizers.meta_cat_tokenizers import TokenizerWrapperBERT
415-
tokenizer = TokenizerWrapperBERT.load(save_dir_path, config.model['model_variant'])
416+
tokenizer = TokenizerWrapperBERT.load(save_dir_path, config.model.model_variant)
416417

417418
# Create meta_cat
418419
meta_cat = cls(tokenizer=tokenizer, embeddings=None, config=config)

medcat/ner/transformers_ner.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,8 @@ def __init__(self, cdb, config: Optional[ConfigTransformersNER] = None,
7070
eval_accumulation_steps=1,
7171
gradient_accumulation_steps=4, # We want to get to bs=4
7272
do_eval=True,
73-
evaluation_strategy='epoch', # type: ignore
73+
# eval_strategy over evaluation_strategy since trf==4.46 (apparently)
74+
eval_strategy='epoch', # type: ignore
7475
logging_strategy='epoch', # type: ignore
7576
save_strategy='epoch', # type: ignore
7677
metric_for_best_model='eval_recall', # Can be changed if our preference is not recall but precision or f1
@@ -176,7 +177,7 @@ def train(self,
176177
ignore_extra_labels=False,
177178
dataset=None,
178179
meta_requirements=None,
179-
trainer_callbacks: Optional[List[TrainerCallback]]=None) -> Tuple:
180+
trainer_callbacks: Optional[List[Callable[[Trainer], TrainerCallback]]] = None) -> Tuple:
180181
"""Train or continue training a model give a json_path containing a MedCATtrainer export. It will
181182
continue training if an existing model is loaded or start new training if the model is blank/new.
182183
@@ -188,10 +189,13 @@ def train(self,
188189
labels that did not exist in the old model.
189190
dataset: Defaults to None.
190191
meta_requirements: Defaults to None
191-
trainer_callbacks (List[TrainerCallback]):
192+
trainer_callbacks (List[Callable[[Trainer], TrainerCallback]]):
192193
A list of trainer callbacks for collecting metrics during the training at the client side. The
193194
transformers Trainer object will be passed in when each callback is called.
194195
196+
Raises:
197+
ValueError: If something went wrong with model save path.
198+
195199
Returns:
196200
Tuple: The dataframe, examples, and the dataset
197201
"""
@@ -254,15 +258,21 @@ def train(self,
254258
tokenizer=None)
255259
if trainer_callbacks:
256260
for callback in trainer_callbacks:
257-
trainer.add_callback(callback(trainer))
261+
# No idea why mypy isn't picking up the method.
262+
# It most certainly does exist
263+
trainer.add_callback(callback(trainer)) # type: ignore
258264

259265
trainer.train() # type: ignore
260266

261267
# Save the training time
262268
self.config.general.last_train_on = datetime.now().timestamp() # type: ignore
263269

264270
# Save everything
265-
self.save(save_dir_path=os.path.join(self.training_arguments.output_dir, 'final_model'))
271+
output_dir = self.training_arguments.output_dir
272+
if output_dir is None:
273+
# NOTE: this shouldn't really happen, but we'll do this for type safety
274+
raise ValueError("Output path should not be None!")
275+
self.save(save_dir_path=os.path.join(output_dir, 'final_model'))
266276

267277
# Run an eval step and return metrics
268278
p = trainer.predict(encoded_dataset['test']) # type: ignore

medcat/tokenizers/meta_cat_tokenizers.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -193,8 +193,13 @@ def load(cls, dir_path: str, model_variant: Optional[str] = '', **kwargs) -> "To
193193
try:
194194
tokenizer.hf_tokenizers = BertTokenizerFast.from_pretrained(path, **kwargs)
195195
except Exception as e:
196-
logging.warning("Could not load tokenizer from path due to error: {}. Loading from library for model variant: {}".format(e,model_variant))
197-
tokenizer.hf_tokenizers = BertTokenizerFast.from_pretrained(model_variant)
196+
# So that this is a string - it should be as it's only used in MetaCAT.load method
197+
# with `config.model.model_variant` which is a `str` rather than None
198+
# NOTE: The reason the type in method signature is Optional[str] is because supertype defines it as such
199+
variant = str(model_variant)
200+
logging.warning("Could not load tokenizer from path due to error: %s. Loading from library for model variant: %s",
201+
e, variant)
202+
tokenizer.hf_tokenizers = BertTokenizerFast.from_pretrained(variant)
198203

199204
return tokenizer
200205

medcat/utils/relation_extraction/models.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,8 @@ def forward(self,
202202
encoder_attention_mask = encoder_attention_mask.to(
203203
self.relcat_config.general.device)
204204

205-
self.bert_model = self.bert_model.to(self.relcat_config.general.device)
205+
# NOTE: no idea why, but mypy doesn't understand that there's an implicit `self` argument here...
206+
self.bert_model = self.bert_model.to(device=self.relcat_config.general.device) # type: ignore
206207

207208
model_output = self.bert_model(input_ids=input_ids, attention_mask=attention_mask,
208209
token_type_ids=token_type_ids,

medcat/utils/relation_extraction/tokenizer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@ class TokenizerWrapperBERT(BertTokenizerFast):
1616
def __init__(self, hf_tokenizers=None, max_seq_length: Optional[int] = None, add_special_tokens: Optional[bool] = False):
1717
self.hf_tokenizers = hf_tokenizers
1818
self.max_seq_length = max_seq_length
19-
self.add_special_tokens = add_special_tokens
19+
self._add_special_tokens = add_special_tokens
2020

2121
def __call__(self, text, truncation: Optional[bool] = True):
2222
if isinstance(text, str):
2323
result = self.hf_tokenizers.encode_plus(text, return_offsets_mapping=True, return_length=True, return_token_type_ids=True, return_attention_mask=True,
24-
add_special_tokens=self.add_special_tokens, max_length=self.max_seq_length, padding="longest", truncation=truncation)
24+
add_special_tokens=self._add_special_tokens, max_length=self.max_seq_length, padding="longest", truncation=truncation)
2525

2626
return {'offset_mapping': result['offset_mapping'],
2727
'input_ids': result['input_ids'],
@@ -32,7 +32,7 @@ def __call__(self, text, truncation: Optional[bool] = True):
3232
}
3333
elif isinstance(text, list):
3434
results = self.hf_tokenizers._batch_encode_plus(text, return_offsets_mapping=True, return_length=True, return_token_type_ids=True,
35-
add_special_tokens=self.add_special_tokens, max_length=self.max_seq_length,truncation=truncation)
35+
add_special_tokens=self._add_special_tokens, max_length=self.max_seq_length,truncation=truncation)
3636
output = []
3737
for ind in range(len(results['input_ids'])):
3838
output.append({

medcat/utils/relation_extraction/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from pandas.core.series import Series
1010
from medcat.config_rel_cat import ConfigRelCAT
1111

12-
from medcat.preprocessing.tokenizers import TokenizerWrapperBERT
12+
from medcat.utils.relation_extraction.tokenizer import TokenizerWrapperBERT
1313
from medcat.utils.relation_extraction.models import BertModel_RelationExtraction
1414

1515

0 commit comments

Comments
 (0)