Skip to content
This repository was archived by the owner on Jun 30, 2025. It is now read-only.

Commit 14d4cf2

Browse files
authored
CU-8699hj2dx Revamp component initialisation (#95)
* CU-8699hj2dx: Initial changes to remove config-based init args and hardcode it (WIP) * CU-8699hj2dx: Update/fix a registration test * CU-8699hj2dx: Some minor keyword argument renaming * CU-8699hj2dx: Fix RelCAT tests (init) * CU-8699hj2dx: Update Transformers NER to work when loading models * CU-8699hj2dx: Fix DeID deserialising test * CU-8699hj2dx: Fix MeaCAT init * CU-8699hj2dx: Fix RelCAT init/load * CU-8699hj2dx: Remove unused import * CU-8699hj2dx: Add doc string regarding keyword arguments when manually deserialising * CU-8699hj2dx: Update pipeline with notes regarding keyword arguments for manual deserialisation
1 parent 682fe11 commit 14d4cf2

34 files changed

+354
-490
lines changed

medcat/components/addons/addons.py

Lines changed: 34 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,11 @@
1-
from typing import Callable, Protocol, Any, runtime_checkable
1+
from typing import Callable, Protocol, Any, runtime_checkable, Optional
22

33
from medcat.components.types import BaseComponent, MutableEntity
44
from medcat.utils.registry import Registry
55
from medcat.config.config import ComponentConfig
6+
from medcat.cdb import CDB
7+
from medcat.vocab import Vocab
8+
from medcat.tokenizing.tokenizers import BaseTokenizer
69

710

811
@runtime_checkable
@@ -19,9 +22,15 @@ def addon_type(self) -> str:
1922
def is_core(self) -> bool:
2023
return False
2124

25+
@classmethod
26+
def get_folder_name_for_addon_and_name(
27+
cls, addon_type: str, name: str) -> str:
28+
return (cls.NAME_PREFIX + addon_type +
29+
cls.NAME_SPLITTER + name)
30+
2231
def get_folder_name(self) -> str:
23-
return (self.NAME_PREFIX + self.addon_type +
24-
self.NAME_SPLITTER + self.name)
32+
return self.get_folder_name_for_addon_and_name(
33+
self.addon_type, self.name)
2534

2635
@property
2736
def full_name(self) -> str:
@@ -36,51 +45,64 @@ def get_output_key_val(self, ent: MutableEntity
3645
pass
3746

3847

48+
AddonClass = Callable[[ComponentConfig, BaseTokenizer,
49+
CDB, Vocab, Optional[str]], AddonComponent]
50+
51+
3952
_DEFAULT_ADDONS: dict[str, tuple[str, str]] = {
4053
'meta_cat': ('medcat.components.addons.meta_cat.meta_cat',
41-
'MetaCATAddon.create_new'),
54+
'MetaCATAddon.create_new_component'),
4255
'rel_cat': ('medcat.components.addons.relation_extraction.rel_cat',
43-
'RelCATAddon.create_new')
56+
'RelCATAddon.create_new_component')
4457
}
4558

4659
# NOTE: type error due to non-concrete type
4760
_ADDON_REGISTRY = Registry(AddonComponent, _DEFAULT_ADDONS) # type: ignore
4861

4962

5063
def register_addon(addon_name: str,
51-
addon_cls: Callable[..., AddonComponent]) -> None:
64+
addon_cls: AddonClass) -> None:
5265
"""Register a new addon.
5366
5467
Args:
5568
addon_name (str): The addon name.
56-
addon_cls (Callable[..., AddonComponent]): The addon creator.
69+
addon_cls (AddonClass): The addon creator.
5770
"""
5871
_ADDON_REGISTRY.register(addon_name, addon_cls)
5972

6073

61-
def get_addon_creator(addon_name: str) -> Callable[..., AddonComponent]:
74+
def get_addon_creator(addon_name: str) -> AddonClass:
6275
"""Get the creator for an addon.
6376
6477
Args:
6578
addon_name (str): The name of the addon.
6679
6780
Returns:
68-
Callable[..., AddonComponent]: The creator of the addon.
81+
AddonClass: The creator of the addon.
6982
"""
7083
return _ADDON_REGISTRY.get_component(addon_name)
7184

7285

73-
def create_addon(addon_name: str, cnf: ComponentConfig,
74-
*args, **kwargs) -> AddonComponent:
86+
def create_addon(
87+
addon_name: str, cnf: ComponentConfig,
88+
tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab,
89+
model_load_path: Optional[str]) -> AddonComponent:
7590
"""Create an addon of the specified name with the specified arguments.
7691
7792
The `cnf`, `tokenizer`, `cdb`, `vocab`, and `model_load_path` are passed to the creator.
7893
7994
Args:
8095
addon_name (str): The name of the addon.
8196
cnf (ComponentConfig): The addon config.
97+
tokenizer (BaseTokenizer): The base tokenizer to be passed to creator.
98+
cdb (CDB): The CDB to be passed to creator.
99+
vocab (Vocab): The Vocab to be passed to creator.
100+
model_load_path (Optional[str]): The optional model load path to be
101+
passed to creator.
102+
82103
83104
Returns:
84105
AddonComponent: The resulting / created addon.
85106
"""
86-
return get_addon_creator(addon_name)(cnf, *args, **kwargs)
107+
return get_addon_creator(addon_name)(
108+
cnf, tokenizer, cdb, vocab, model_load_path)

medcat/components/addons/meta_cat/meta_cat.py

Lines changed: 23 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import torch
1212
from torch import nn, Tensor
1313
from medcat.tokenizing.tokenizers import BaseTokenizer
14+
from medcat.config.config import ComponentConfig
1415
from medcat.config.config_meta_cat import ConfigMetaCAT
1516
from medcat.components.addons.meta_cat.ml_utils import (
1617
predict, train_model, set_all_seeds, eval_model)
@@ -25,6 +26,7 @@
2526
from medcat.tokenizing.tokens import MutableDocument, MutableEntity
2627
from medcat.cdb import CDB
2728
from medcat.vocab import Vocab
29+
from medcat.utils.defaults import COMPONENTS_FOLDER
2830
from peft import get_peft_model, LoraConfig, TaskType
2931

3032
# It should be safe to do this always, as all other multiprocessing
@@ -84,6 +86,23 @@ def create_new(cls, config: ConfigMetaCAT, base_tokenizer: BaseTokenizer,
8486
meta_cat = MetaCAT(tokenizer, embeddings=None, config=config)
8587
return cls(config, base_tokenizer, meta_cat)
8688

89+
@classmethod
90+
def create_new_component(
91+
cls, cnf: ComponentConfig, tokenizer: BaseTokenizer,
92+
cdb: CDB, vocab: Vocab, model_load_path: Optional[str]
93+
) -> 'MetaCATAddon':
94+
if not isinstance(cnf, ConfigMetaCAT):
95+
raise ValueError(f"Incompatible config: {cnf}")
96+
if model_load_path is not None:
97+
components_folder = os.path.join(
98+
model_load_path, COMPONENTS_FOLDER)
99+
folder_name = cls.get_folder_name_for_addon_and_name(
100+
cls.addon_type, str(cnf.general.category_name))
101+
load_path = os.path.join(components_folder, folder_name)
102+
return cls.load_existing(cnf, tokenizer, load_path)
103+
# TODO: tokenizer preprocessing for (e.g) BPE tokenizer (see PR #67)
104+
return cls.create_new(cnf, tokenizer, None)
105+
87106
@classmethod
88107
def load_existing(cls, cnf: ConfigMetaCAT,
89108
base_tokenizer: BaseTokenizer,
@@ -100,18 +119,6 @@ def name(self) -> str:
100119
def __call__(self, doc: MutableDocument) -> MutableDocument:
101120
return self.mc(doc)
102121

103-
@classmethod
104-
def get_init_args(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab,
105-
model_load_path: Optional[str]) -> list[Any]:
106-
# NOTE: cnf is silent init parameter
107-
return []
108-
109-
@classmethod
110-
def get_init_kwargs(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab,
111-
model_load_path: Optional[str]) -> dict[str, Any]:
112-
# cls.init_tokenizer(cnf, model_load_path)
113-
return {'base_tokenizer': tokenizer}
114-
115122
def load(self, folder_path: str) -> 'MetaCAT':
116123
mc_path, tokenizer_folder = self._get_meta_cat_and_tokenizer_paths(
117124
folder_path)
@@ -169,8 +176,10 @@ def serialise_to(self, folder_path: str) -> None:
169176
@classmethod
170177
def deserialise_from(cls, folder_path: str, **init_kwargs
171178
) -> 'MetaCATAddon':
172-
# NOTE: model load path sent by kwargs
173-
return cls.load_existing(load_path=folder_path, **init_kwargs)
179+
return cls.load_existing(
180+
load_path=folder_path,
181+
cnf=init_kwargs['cnf'],
182+
base_tokenizer=init_kwargs['tokenizer'])
174183

175184
def get_strategy(self) -> SerialisingStrategy:
176185
return SerialisingStrategy.MANUAL

medcat/components/addons/relation_extraction/rel_cat.py

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import logging
33
import os
44
import random
5-
from typing import Optional, Any
5+
from typing import Optional
66

77
from sklearn.utils import compute_class_weight
88
import torch
@@ -18,7 +18,7 @@
1818

1919
from medcat.cdb import CDB
2020
from medcat.vocab import Vocab
21-
from medcat.config import Config
21+
from medcat.config.config import Config, ComponentConfig
2222
from medcat.config.config_rel_cat import ConfigRelCAT
2323
from medcat.storage.serialisers import deserialise
2424
from medcat.storage.serialisables import SerialisingStrategy
@@ -32,6 +32,7 @@
3232
from medcat.components.addons.relation_extraction.rel_dataset import RelData
3333
from medcat.tokenizing.tokenizers import BaseTokenizer, create_tokenizer
3434
from medcat.tokenizing.tokens import MutableDocument
35+
from medcat.utils.defaults import COMPONENTS_FOLDER
3536

3637

3738
logger = logging.getLogger(__name__)
@@ -54,6 +55,20 @@ def create_new(cls, config: ConfigRelCAT, base_tokenizer: BaseTokenizer,
5455
return cls(config,
5556
RelCAT(base_tokenizer, cdb, config=config, init_model=True))
5657

58+
@classmethod
59+
def create_new_component(
60+
cls, cnf: ComponentConfig, tokenizer: BaseTokenizer,
61+
cdb: CDB, vocab: Vocab, model_load_path: Optional[str]
62+
) -> 'RelCATAddon':
63+
if not isinstance(cnf, ConfigRelCAT):
64+
raise ValueError(f"Incompatible config: {cnf}")
65+
config = cnf
66+
if model_load_path is not None:
67+
load_path = os.path.join(model_load_path, COMPONENTS_FOLDER,
68+
cls.NAME_PREFIX + cls.addon_type)
69+
return cls.load_existing(config, tokenizer, cdb, load_path)
70+
return cls.create_new(config, tokenizer, cdb)
71+
5772
@classmethod
5873
def load_existing(cls, cnf: ConfigRelCAT,
5974
base_tokenizer: BaseTokenizer,
@@ -70,21 +85,6 @@ def serialise_to(self, folder_path: str) -> None:
7085
os.mkdir(folder_path)
7186
self._rel_cat.save(folder_path)
7287

73-
@classmethod
74-
def get_init_args(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab,
75-
model_load_path: Optional[str]) -> list[Any]:
76-
# NOTE: cnf is silent init parameter
77-
return []
78-
79-
@classmethod
80-
def get_init_kwargs(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab,
81-
model_load_path: Optional[str]) -> dict[str, Any]:
82-
# cls.init_tokenizer(cnf, model_load_path)
83-
return {
84-
'base_tokenizer': tokenizer,
85-
"cdb": cdb
86-
}
87-
8888
@property
8989
def name(self) -> str:
9090
return str(self.addon_type)
@@ -95,7 +95,12 @@ def name(self) -> str:
9595
def deserialise_from(cls, folder_path: str, **init_kwargs
9696
) -> 'RelCATAddon':
9797
# NOTE: tokenizer, cnf and cdb are taken from init_kwargs
98-
return cls.load_existing(load_path=folder_path, **init_kwargs)
98+
return cls.load_existing(
99+
load_path=folder_path,
100+
base_tokenizer=init_kwargs['tokenizer'],
101+
cnf=init_kwargs['cnf'],
102+
cdb=init_kwargs['cdb'],
103+
)
99104

100105
def get_strategy(self) -> SerialisingStrategy:
101106
return SerialisingStrategy.MANUAL
@@ -232,7 +237,7 @@ def load(cls, load_path: str = "./") -> "RelCAT":
232237

233238
rel_cat = RelCAT(
234239
# NOTE: this is a throwaway tokenizer just for registrations
235-
create_tokenizer(cdb.config.general.nlp.provider),
240+
create_tokenizer(cdb.config.general.nlp.provider, cdb.config),
236241
cdb=cdb, config=component.relcat_config, task=component.task)
237242
rel_cat.device = device
238243
rel_cat.component = component
@@ -883,7 +888,8 @@ def predict_text_with_anns(self, text: str, annotations: list[dict]
883888
Doc: spacy doc with the relations.
884889
"""
885890
# NOTE: This runs the specified language rather than an empty one
886-
base_tokenizer = create_tokenizer(self.cdb.config.general.nlp.provider)
891+
base_tokenizer = create_tokenizer(
892+
self.cdb.config.general.nlp.provider, self.cdb.config)
887893
doc = base_tokenizer(text)
888894

889895
for ann in annotations:

medcat/components/linking/context_based_linker.py

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,14 @@
11
import random
22
import logging
3-
from typing import Iterator, Optional, Union, Any
3+
from typing import Iterator, Optional, Union
44

55
from medcat.components.types import CoreComponentType, AbstractCoreComponent
66
from medcat.tokenizing.tokens import MutableEntity, MutableDocument
77
from medcat.components.linking.vector_context_model import (
88
ContextModel, PerDocumentTokenCache)
99
from medcat.cdb import CDB
1010
from medcat.vocab import Vocab
11-
from medcat.config import Config
11+
from medcat.config.config import Config, ComponentConfig
1212
from medcat.utils.defaults import StatusTypes as ST
1313
from medcat.utils.postprocessing import create_main_ann
1414
from medcat.tokenizing.tokenizers import BaseTokenizer
@@ -245,11 +245,8 @@ def train(self, cui: str,
245245
cui, entity, doc, per_doc_valid_token_cache, negative, names)
246246

247247
@classmethod
248-
def get_init_args(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab,
249-
model_load_path: Optional[str]) -> list[Any]:
250-
return [cdb, vocab, cdb.config]
251-
252-
@classmethod
253-
def get_init_kwargs(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab,
254-
model_load_path: Optional[str]) -> dict[str, Any]:
255-
return {}
248+
def create_new_component(
249+
cls, cnf: ComponentConfig, tokenizer: BaseTokenizer,
250+
cdb: CDB, vocab: Vocab, model_load_path: Optional[str]
251+
) -> 'Linker':
252+
return cls(cdb, vocab, cdb.config)
Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
1-
from typing import Any, Optional
1+
from typing import Optional
22

33
from medcat.components.types import CoreComponentType, AbstractCoreComponent
44
from medcat.tokenizing.tokens import MutableDocument
55
from medcat.tokenizing.tokenizers import BaseTokenizer
66
from medcat.cdb.cdb import CDB
77
from medcat.vocab import Vocab
8+
from medcat.config.config import ComponentConfig
89

910

1011
class NoActionLinker(AbstractCoreComponent):
@@ -17,11 +18,8 @@ def __call__(self, doc: MutableDocument) -> MutableDocument:
1718
return doc
1819

1920
@classmethod
20-
def get_init_args(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab,
21-
model_load_path: Optional[str]) -> list[Any]:
22-
return []
23-
24-
@classmethod
25-
def get_init_kwargs(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab,
26-
model_load_path: Optional[str]) -> dict[str, Any]:
27-
return {}
21+
def create_new_component(
22+
cls, cnf: ComponentConfig, tokenizer: BaseTokenizer,
23+
cdb: CDB, vocab: Vocab, model_load_path: Optional[str]
24+
) -> 'NoActionLinker':
25+
return cls()

medcat/components/linking/two_step_context_based_linker.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from medcat.cdb.cdb import CDB
99
from medcat.vocab import Vocab
10-
from medcat.config.config import Config, SerialisableBaseModel
10+
from medcat.config.config import Config, SerialisableBaseModel, ComponentConfig
1111
from medcat.utils.defaults import StatusTypes as ST
1212
from medcat.utils.matutils import sigmoid
1313
from medcat.utils.config_utils import temp_changed_config
@@ -255,14 +255,11 @@ def train(self, cui: str,
255255
per_doc_valid_token_cache=pdc)
256256

257257
@classmethod
258-
def get_init_args(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab,
259-
model_load_path: Optional[str]) -> list[Any]:
260-
return [cdb, vocab, cdb.config]
261-
262-
@classmethod
263-
def get_init_kwargs(cls, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab,
264-
model_load_path: Optional[str]) -> dict[str, Any]:
265-
return {}
258+
def create_new_component(
259+
cls, cnf: ComponentConfig, tokenizer: BaseTokenizer,
260+
cdb: CDB, vocab: Vocab, model_load_path: Optional[str]
261+
) -> 'TwoStepLinker':
262+
return cls(cdb, vocab, cdb.config)
266263

267264
@property
268265
def two_step_config(self) -> 'TwoStepLinkerConfig':

0 commit comments

Comments
 (0)