1717from medcat .config import Config
1818from medcat .config_rel_cat import ConfigRelCAT
1919from medcat .pipeline .pipe_runner import PipeRunner
20- from medcat .utils .relation_extraction .tokenizer import BaseTokenizerWrapper , load_tokenizer
20+ from medcat .utils .relation_extraction .base_component import load_base_component , BaseComponent
21+ from medcat .utils .relation_extraction .tokenizer import BaseTokenizerWrapper
2122from spacy .tokens import Doc , Span
2223from typing import Dict , Iterable , Iterator , List , cast
2324from torch .utils .data import DataLoader , Sampler
@@ -91,8 +92,13 @@ class RelCAT(PipeRunner):
9192
9293 log = logging .getLogger (__name__ )
9394
94- def __init__ (self , cdb : CDB , tokenizer : BaseTokenizerWrapper , config : ConfigRelCAT = ConfigRelCAT (), task = "train" , init_model = False ):
95+ def __init__ (self , cdb : CDB ,
96+ base_component : BaseComponent ,
97+ tokenizer : BaseTokenizerWrapper ,
98+ config : ConfigRelCAT = ConfigRelCAT (),
99+ task = "train" , init_model = False ):
95100 self .config = config
101+ self .base_component = base_component
96102 self .tokenizer : BaseTokenizerWrapper = tokenizer
97103 self .cdb = cdb
98104
@@ -154,8 +160,8 @@ def _get_model(self):
154160
155161 """ Used only for model initialisation.
156162 """
157- self .model_config = self .tokenizer .config_from_pretrained ()
158- self .model = self .tokenizer .model_from_pretrained (relcat_config = self .config ,
163+ self .model_config = self .base_component .config_from_pretrained ()
164+ self .model = self .base_component .model_from_pretrained (relcat_config = self .config ,
159165 model_config = self .model_config )
160166
161167 @classmethod
@@ -182,20 +188,22 @@ def load(cls, load_path: str = "./") -> "RelCAT":
182188 if "bert" in config .general .tokenizer_name or "llama" in config .general .tokenizer_name :
183189 tokenizer_path = load_path
184190
185- tokenizer = load_tokenizer (tokenizer_path , config )
191+ base_component = load_base_component (tokenizer_path , config )
192+ tokenizer = base_component .tokenizer
186193
187194 model_config_path = os .path .join (load_path , "model_config.json" )
188195
189196 if os .path .exists (model_config_path ):
190- model_config = tokenizer .config_from_json_file (model_config_path )
197+ model_config = base_component .config_from_json_file (model_config_path )
191198 cls .log .info ("Loaded config from : " + model_config_path )
192199 else :
193200 cls .log .info ("model_config.json not found, using default for the model" )
194- model_config = tokenizer .config_from_pretrained ()
201+ model_config = base_component .config_from_pretrained ()
195202
196203 model_config .vocab_size = tokenizer .get_size ()
197204
198205 rel_cat = cls (cdb = cdb , config = config ,
206+ base_component = base_component ,
199207 tokenizer = tokenizer ,
200208 task = config .general .task )
201209
@@ -209,10 +217,11 @@ def load(cls, load_path: str = "./") -> "RelCAT":
209217
210218 if os .path .exists (os .path .join (load_path , config .general .model_name )):
211219 # NOTE: should it be the joined path? it wasn't previously
212- rel_cat .model = tokenizer .model_from_pretrained (relcat_config = config , model_config = model_config ,
213- pretrained_model_name_or_path = config .general .model_name )
220+ rel_cat .model = base_component .model_from_pretrained (
221+ relcat_config = config , model_config = model_config ,
222+ pretrained_model_name_or_path = config .general .model_name )
214223 else :
215- rel_cat .model = tokenizer .model_from_pretrained (
224+ rel_cat .model = base_component .model_from_pretrained (
216225 pretrained_model_name_or_path = '' ,
217226 relcat_config = config ,
218227 model_config = model_config )
@@ -228,7 +237,7 @@ def load(cls, load_path: str = "./") -> "RelCAT":
228237
229238 cls .log .error ("Failed to load specified HF model, defaulting to 'bert-base-uncased', loading..." )
230239 # NOTE: this won't really work for Llama or ModernBert, I've got a feeling
231- rel_cat .model = tokenizer .model_from_pretrained (
240+ rel_cat .model = base_component .model_from_pretrained (
232241 pretrained_model_name_or_path = "bert-base-uncased" ,
233242 relcat_config = config ,
234243 model_config = model_config )
0 commit comments