Skip to content
This repository was archived by the owner on Jul 28, 2025. It is now read-only.

Commit c40f9e6

Browse files
committed
Improve abstraction by moving model and config loading outside of tokenizer into a base component class
1 parent 255334d commit c40f9e6

File tree

11 files changed

+192
-99
lines changed

11 files changed

+192
-99
lines changed

medcat/rel_cat.py

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
from medcat.config import Config
1818
from medcat.config_rel_cat import ConfigRelCAT
1919
from medcat.pipeline.pipe_runner import PipeRunner
20-
from medcat.utils.relation_extraction.tokenizer import BaseTokenizerWrapper, load_tokenizer
20+
from medcat.utils.relation_extraction.base_component import load_base_component, BaseComponent
21+
from medcat.utils.relation_extraction.tokenizer import BaseTokenizerWrapper
2122
from spacy.tokens import Doc, Span
2223
from typing import Dict, Iterable, Iterator, List, cast
2324
from torch.utils.data import DataLoader, Sampler
@@ -91,8 +92,13 @@ class RelCAT(PipeRunner):
9192

9293
log = logging.getLogger(__name__)
9394

94-
def __init__(self, cdb: CDB, tokenizer: BaseTokenizerWrapper, config: ConfigRelCAT = ConfigRelCAT(), task="train", init_model=False):
95+
def __init__(self, cdb: CDB,
96+
base_component: BaseComponent,
97+
tokenizer: BaseTokenizerWrapper,
98+
config: ConfigRelCAT = ConfigRelCAT(),
99+
task="train", init_model=False):
95100
self.config = config
101+
self.base_component = base_component
96102
self.tokenizer: BaseTokenizerWrapper = tokenizer
97103
self.cdb = cdb
98104

@@ -154,8 +160,8 @@ def _get_model(self):
154160

155161
""" Used only for model initialisation.
156162
"""
157-
self.model_config = self.tokenizer.config_from_pretrained()
158-
self.model = self.tokenizer.model_from_pretrained(relcat_config=self.config,
163+
self.model_config = self.base_component.config_from_pretrained()
164+
self.model = self.base_component.model_from_pretrained(relcat_config=self.config,
159165
model_config=self.model_config)
160166

161167
@classmethod
@@ -182,20 +188,22 @@ def load(cls, load_path: str = "./") -> "RelCAT":
182188
if "bert" in config.general.tokenizer_name or "llama" in config.general.tokenizer_name:
183189
tokenizer_path = load_path
184190

185-
tokenizer = load_tokenizer(tokenizer_path, config)
191+
base_component = load_base_component(tokenizer_path, config)
192+
tokenizer = base_component.tokenizer
186193

187194
model_config_path = os.path.join(load_path, "model_config.json")
188195

189196
if os.path.exists(model_config_path):
190-
model_config = tokenizer.config_from_json_file(model_config_path)
197+
model_config = base_component.config_from_json_file(model_config_path)
191198
cls.log.info("Loaded config from : " + model_config_path)
192199
else:
193200
cls.log.info("model_config.json not found, using default for the model")
194-
model_config = tokenizer.config_from_pretrained()
201+
model_config = base_component.config_from_pretrained()
195202

196203
model_config.vocab_size = tokenizer.get_size()
197204

198205
rel_cat = cls(cdb=cdb, config=config,
206+
base_component=base_component,
199207
tokenizer=tokenizer,
200208
task=config.general.task)
201209

@@ -209,10 +217,11 @@ def load(cls, load_path: str = "./") -> "RelCAT":
209217

210218
if os.path.exists(os.path.join(load_path, config.general.model_name)):
211219
# NOTE: should it be the joined path? it wasn't previously
212-
rel_cat.model = tokenizer.model_from_pretrained(relcat_config=config, model_config=model_config,
213-
pretrained_model_name_or_path=config.general.model_name)
220+
rel_cat.model = base_component.model_from_pretrained(
221+
relcat_config=config, model_config=model_config,
222+
pretrained_model_name_or_path=config.general.model_name)
214223
else:
215-
rel_cat.model = tokenizer.model_from_pretrained(
224+
rel_cat.model = base_component.model_from_pretrained(
216225
pretrained_model_name_or_path='',
217226
relcat_config=config,
218227
model_config=model_config)
@@ -228,7 +237,7 @@ def load(cls, load_path: str = "./") -> "RelCAT":
228237

229238
cls.log.error("Failed to load specified HF model, defaulting to 'bert-base-uncased', loading...")
230239
# NOTE: this won't really work for Llama or ModernBert, I've got a feeling
231-
rel_cat.model = tokenizer.model_from_pretrained(
240+
rel_cat.model = base_component.model_from_pretrained(
232241
pretrained_model_name_or_path="bert-base-uncased",
233242
relcat_config=config,
234243
model_config=model_config)
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
from abc import ABC, abstractmethod
2+
import logging
3+
4+
from transformers import PretrainedConfig
5+
6+
from medcat.config_rel_cat import ConfigRelCAT
7+
from medcat.utils.relation_extraction.tokenizer import BaseTokenizerWrapper
8+
from medcat.utils.relation_extraction.models import Base_RelationExtraction
9+
10+
11+
logger = logging.getLogger(__name__)
12+
13+
14+
class BaseComponent(ABC):
    """Abstraction over the model-family specific parts of RelCAT.

    A component bundles a tokenizer together with the logic needed to
    create the HuggingFace model config and the relation-extraction
    model for one model family (BERT, ModernBERT, Llama, ...). This
    keeps model/config loading out of the tokenizer wrappers.
    """

    @property
    @abstractmethod
    def tokenizer(self) -> BaseTokenizerWrapper:
        """The tokenizer wrapper used by this component."""

    @abstractmethod
    def config_from_pretrained(self) -> PretrainedConfig:
        """Load the default HuggingFace model config for this family.

        Returns:
            PretrainedConfig: The default model configuration.
        """

    @abstractmethod
    def config_from_json_file(self, file_path: str) -> PretrainedConfig:
        """Load a HuggingFace model config from a JSON file.

        Args:
            file_path (str): Path to the ``model_config.json`` file.

        Returns:
            PretrainedConfig: The parsed model configuration.
        """

    @abstractmethod
    def model_from_pretrained(self, relcat_config: ConfigRelCAT, model_config: PretrainedConfig,
                              pretrained_model_name_or_path: str = 'default') -> Base_RelationExtraction:
        """Create the relation-extraction model for this family.

        Args:
            relcat_config (ConfigRelCAT): The RelCAT configuration.
            model_config (PretrainedConfig): The HuggingFace model config.
            pretrained_model_name_or_path (str): Model to load weights
                from; the sentinel 'default' selects the component's
                default checkpoint.

        Returns:
            Base_RelationExtraction: The relation-extraction model.
        """
33+
34+
35+
def load_base_component(tokenizer_path: str, config: ConfigRelCAT) -> BaseComponent:
    """Create the base component matching the configured tokenizer name.

    Component modules are imported lazily so only the dependencies of
    the selected model family are loaded.

    NOTE: 'modern-bert-tokenizer' must be checked before 'bert' because
    the former contains the latter as a substring.

    Args:
        tokenizer_path (str): Path the component should load its
            tokenizer from (falls back to a default tokenizer if the
            path does not exist).
        config (ConfigRelCAT): The RelCAT configuration; matching is
            done on ``config.general.tokenizer_name``.

    Returns:
        BaseComponent: The component for the configured model family.

    Raises:
        ValueError: If no component matches the tokenizer name.
    """
    if "modern-bert-tokenizer" in config.general.tokenizer_name:
        from medcat.utils.relation_extraction.modernbert.component import ModernBertComponent
        return ModernBertComponent(tokenizer_path, config)
    elif "bert" in config.general.tokenizer_name:
        from medcat.utils.relation_extraction.bert.component import BertComponent
        return BertComponent(tokenizer_path, config)
    elif "llama" in config.general.tokenizer_name:
        from medcat.utils.relation_extraction.llama.component import LlamaComponent
        return LlamaComponent(tokenizer_path, config)
    raise ValueError(f"Could not find matching base component for {config.general.tokenizer_name}")
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import os
2+
from typing import Optional
3+
4+
from transformers import PretrainedConfig, BertConfig
5+
6+
from medcat.config_rel_cat import ConfigRelCAT
7+
from medcat.utils.relation_extraction.base_component import BaseComponent
8+
from medcat.utils.relation_extraction.tokenizer import BaseTokenizerWrapper, load_default_tokenizer
9+
from medcat.utils.relation_extraction.models import Base_RelationExtraction
10+
from medcat.utils.relation_extraction.bert.tokenizer import TokenizerWrapperBERT
11+
from medcat.utils.relation_extraction.bert.model import BertModel_RelationExtraction
12+
13+
14+
class BertComponent(BaseComponent):
    """BaseComponent implementation for the BERT model family."""

    # Default HF checkpoint used when no explicit model path is given.
    pretrained_model_name_or_path = "bert-base-uncased"

    def __init__(self, tokenizer_path: str, config: ConfigRelCAT,
                 tokenizer: Optional[BaseTokenizerWrapper] = None):
        """Initialise the component.

        Args:
            tokenizer_path (str): Path to load the tokenizer from.
            config (ConfigRelCAT): The RelCAT configuration.
            tokenizer (Optional[BaseTokenizerWrapper]): An already-loaded
                tokenizer to use instead of loading from disk. If None,
                the tokenizer is loaded from ``tokenizer_path`` when it
                exists, otherwise a default tokenizer is created.
        """
        if tokenizer is not None:
            self._tokenizer = tokenizer
        elif os.path.exists(tokenizer_path):
            self._tokenizer = TokenizerWrapperBERT.load(tokenizer_path)
        else:
            self._tokenizer = load_default_tokenizer(tokenizer_path, config)

    @property
    def tokenizer(self) -> BaseTokenizerWrapper:
        """The wrapped BERT tokenizer."""
        return self._tokenizer

    def config_from_pretrained(self) -> PretrainedConfig:
        """Load the default BERT config from the default checkpoint."""
        return BertConfig.from_pretrained(self.pretrained_model_name_or_path)

    def config_from_json_file(self, file_path: str) -> PretrainedConfig:
        """Load a BERT config from a JSON file."""
        return BertConfig.from_json_file(file_path)

    def model_from_pretrained(self, relcat_config: ConfigRelCAT, model_config: PretrainedConfig,
                              pretrained_model_name_or_path: str = 'default') -> Base_RelationExtraction:
        """Create the BERT relation-extraction model.

        The sentinel 'default' resolves to the class-level default
        checkpoint (``bert-base-uncased``).
        """
        if pretrained_model_name_or_path == 'default':
            pretrained_model_name_or_path = self.pretrained_model_name_or_path
        return BertModel_RelationExtraction(pretrained_model_name_or_path, relcat_config, model_config)

medcat/utils/relation_extraction/bert/tokenizer.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,8 @@
11
import os
2-
from transformers import PretrainedConfig
3-
from transformers import BertConfig
42
from transformers.models.bert.tokenization_bert_fast import BertTokenizerFast
53
import logging
64

7-
from medcat.config_rel_cat import ConfigRelCAT
85
from medcat.utils.relation_extraction.tokenizer import BaseTokenizerWrapper
9-
from medcat.utils.relation_extraction.models import Base_RelationExtraction
10-
from medcat.utils.relation_extraction.bert.model import BertModel_RelationExtraction
116

127

138
logger = logging.getLogger(__name__)
@@ -22,19 +17,6 @@ class TokenizerWrapperBERT(BaseTokenizerWrapper):
2217
A huggingface Fast BERT.
2318
'''
2419
name = 'bert-tokenizer'
25-
pretrained_model_name_or_path = "bert-base-uncased"
26-
27-
def config_from_pretrained(self) -> PretrainedConfig:
28-
return BertConfig.from_pretrained(self.pretrained_model_name_or_path)
29-
30-
def config_from_json_file(self, file_path: str) -> PretrainedConfig:
31-
return BertConfig.from_json_file(file_path)
32-
33-
def model_from_pretrained(self, relcat_config: ConfigRelCAT, model_config: PretrainedConfig,
34-
pretrained_model_name_or_path: str = 'default') -> Base_RelationExtraction:
35-
if pretrained_model_name_or_path == 'default':
36-
pretrained_model_name_or_path = self.pretrained_model_name_or_path
37-
return BertModel_RelationExtraction(pretrained_model_name_or_path, relcat_config, model_config)
3820

3921
@classmethod
4022
def load(cls, dir_path, **kwargs):
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
import os
2+
3+
from transformers import PretrainedConfig
4+
from transformers.models.llama import LlamaConfig
5+
6+
from medcat.config_rel_cat import ConfigRelCAT
7+
from medcat.utils.relation_extraction.base_component import BaseComponent
8+
from medcat.utils.relation_extraction.tokenizer import BaseTokenizerWrapper, load_default_tokenizer
9+
from medcat.utils.relation_extraction.models import Base_RelationExtraction
10+
from medcat.utils.relation_extraction.llama.tokenizer import TokenizerWrapperLlama
11+
from medcat.utils.relation_extraction.llama.model import LlamaModel_RelationExtraction
12+
13+
14+
class LlamaComponent(BaseComponent):
    """BaseComponent implementation for the Llama model family."""

    # Default HF checkpoint used when no explicit model path is given.
    # NOTE(review): gated model -- loading it may require HF auth; confirm.
    pretrained_model_name_or_path = "meta-llama/Llama-3.1-8B"

    def __init__(self, tokenizer_path: str, config: ConfigRelCAT):
        """Initialise the component.

        Args:
            tokenizer_path (str): Path to load the tokenizer from; if it
                does not exist, a default tokenizer is created.
            config (ConfigRelCAT): The RelCAT configuration.
        """
        if os.path.exists(tokenizer_path):
            self._tokenizer = TokenizerWrapperLlama.load(tokenizer_path)
        else:
            self._tokenizer = load_default_tokenizer(tokenizer_path, config)

    @property
    def tokenizer(self) -> BaseTokenizerWrapper:
        """The wrapped Llama tokenizer."""
        return self._tokenizer

    def config_from_pretrained(self) -> PretrainedConfig:
        """Load the default Llama config from the default checkpoint.

        FIX: this previously was a bare ``pass`` and returned None,
        which broke callers that read attributes off the returned
        config (e.g. ``model_config.vocab_size`` in RelCAT.load).
        """
        return LlamaConfig.from_pretrained(self.pretrained_model_name_or_path)

    def config_from_json_file(self, file_path: str) -> PretrainedConfig:
        """Load a Llama config from a JSON file."""
        return LlamaConfig.from_json_file(file_path)

    def model_from_pretrained(self, relcat_config: ConfigRelCAT, model_config: PretrainedConfig,
                              pretrained_model_name_or_path: str = 'default') -> Base_RelationExtraction:
        """Create the Llama relation-extraction model.

        The sentinel 'default' resolves to the class-level default
        checkpoint.
        """
        if pretrained_model_name_or_path == 'default':
            pretrained_model_name_or_path = self.pretrained_model_name_or_path
        return LlamaModel_RelationExtraction(pretrained_model_name_or_path, relcat_config, model_config)

medcat/utils/relation_extraction/llama/tokenizer.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,9 @@
11
import os
22
from typing import Optional
3-
from transformers import PretrainedConfig
4-
from transformers.models.llama import LlamaConfig
53
from transformers import LlamaTokenizerFast
64
import logging
75

8-
from medcat.config_rel_cat import ConfigRelCAT
96
from medcat.utils.relation_extraction.tokenizer import BaseTokenizerWrapper
10-
from medcat.utils.relation_extraction.models import Base_RelationExtraction
11-
from medcat.utils.relation_extraction.llama.model import LlamaModel_RelationExtraction
127

138

149
logger = logging.getLogger(__name__)
@@ -23,19 +18,6 @@ class TokenizerWrapperLlama(BaseTokenizerWrapper):
2318
A huggingface Fast Llama.
2419
'''
2520
name = 'llama-tokenizer'
26-
pretrained_model_name_or_path = "meta-llama/Llama-3.1-8B"
27-
28-
def config_from_pretrained(self) -> PretrainedConfig:
29-
pass # perhaps some doc string
30-
31-
def config_from_json_file(self, file_path: str) -> PretrainedConfig:
32-
return LlamaConfig.from_json_file(file_path)
33-
34-
def model_from_pretrained(self, relcat_config: ConfigRelCAT, model_config: PretrainedConfig,
35-
pretrained_model_name_or_path: str = 'default') -> Base_RelationExtraction:
36-
if pretrained_model_name_or_path == 'default':
37-
pretrained_model_name_or_path = self.pretrained_model_name_or_path
38-
return LlamaModel_RelationExtraction(pretrained_model_name_or_path, relcat_config, model_config)
3921

4022
def __init__(self, hf_tokenizers=None, max_seq_length: Optional[int] = None, add_special_tokens: Optional[bool] = False):
4123
self.hf_tokenizers = hf_tokenizers
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import os
2+
3+
from transformers import PretrainedConfig, ModernBertConfig
4+
5+
from medcat.config_rel_cat import ConfigRelCAT
6+
from medcat.utils.relation_extraction.base_component import BaseComponent
7+
from medcat.utils.relation_extraction.tokenizer import BaseTokenizerWrapper, load_default_tokenizer
8+
from medcat.utils.relation_extraction.models import Base_RelationExtraction
9+
from medcat.utils.relation_extraction.modernbert.tokenizer import TokenizerWrapperModernBERT
10+
from medcat.utils.relation_extraction.modernbert.model import ModernBertModel_RelationExtraction
11+
12+
13+
class ModernBertComponent(BaseComponent):
    """BaseComponent implementation for the ModernBERT model family."""

    # Default HF checkpoint used when no explicit model path is given.
    pretrained_model_name_or_path = "answerdotai/ModernBERT-base"

    def __init__(self, tokenizer_path: str, config: ConfigRelCAT):
        """Load the tokenizer from ``tokenizer_path`` when it exists,
        otherwise fall back to the default tokenizer for this config."""
        if os.path.exists(tokenizer_path):
            loaded = TokenizerWrapperModernBERT.load(tokenizer_path)
        else:
            loaded = load_default_tokenizer(tokenizer_path, config)
        self._tokenizer = loaded

    @property
    def tokenizer(self) -> BaseTokenizerWrapper:
        """The wrapped ModernBERT tokenizer."""
        return self._tokenizer

    def config_from_pretrained(self) -> PretrainedConfig:
        """Default ModernBERT config, loaded from the default checkpoint."""
        return ModernBertConfig.from_pretrained(self.pretrained_model_name_or_path)

    def config_from_json_file(self, file_path: str) -> PretrainedConfig:
        """ModernBERT config parsed from a JSON file."""
        return ModernBertConfig.from_json_file(file_path)

    def model_from_pretrained(self, relcat_config: ConfigRelCAT, model_config: PretrainedConfig,
                              pretrained_model_name_or_path: str = 'default') -> Base_RelationExtraction:
        """Build the relation-extraction model; the sentinel 'default'
        resolves to the class-level default checkpoint."""
        resolved = (self.pretrained_model_name_or_path
                    if pretrained_model_name_or_path == 'default'
                    else pretrained_model_name_or_path)
        return ModernBertModel_RelationExtraction(resolved, relcat_config, model_config)

medcat/utils/relation_extraction/modernbert/tokenizer.py

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,8 @@
11
import os
2-
from transformers import PretrainedConfig
3-
from transformers import ModernBertConfig
42
from transformers import PreTrainedTokenizerFast
53
import logging
64

7-
from medcat.config_rel_cat import ConfigRelCAT
85
from medcat.utils.relation_extraction.tokenizer import BaseTokenizerWrapper
9-
from medcat.utils.relation_extraction.models import Base_RelationExtraction
10-
from medcat.utils.relation_extraction.modernbert.model import ModernBertModel_RelationExtraction
116

127

138
logger = logging.getLogger(__name__)
@@ -22,19 +17,6 @@ class TokenizerWrapperModernBERT(BaseTokenizerWrapper):
2217
A huggingface Fast tokenizer.
2318
'''
2419
name = 'modern-bert-tokenizer'
25-
pretrained_model_name_or_path = "answerdotai/ModernBERT-base"
26-
27-
def config_from_pretrained(self) -> PretrainedConfig:
28-
return ModernBertConfig.from_pretrained(self.pretrained_model_name_or_path)
29-
30-
def config_from_json_file(self, file_path: str) -> PretrainedConfig:
31-
return ModernBertConfig.from_json_file(file_path)
32-
33-
def model_from_pretrained(self, relcat_config: ConfigRelCAT, model_config: PretrainedConfig,
34-
pretrained_model_name_or_path: str = 'default') -> Base_RelationExtraction:
35-
if pretrained_model_name_or_path == 'default':
36-
pretrained_model_name_or_path = self.pretrained_model_name_or_path
37-
return ModernBertModel_RelationExtraction(pretrained_model_name_or_path, relcat_config, model_config)
3820

3921
@classmethod
4022
def load(cls, dir_path, **kwargs):

0 commit comments

Comments
 (0)