Skip to content

Commit 29ef26e

Browse files
vladd-bitshubham-s-agarwalmart-r
authored
Relation extraction llama (CogStack/MedCAT#522)
* Added files. * More additions to rel extraction. * Rel base. * Update. * Updates. * Dependency parsing. * Updates. * Added pre-training steps. * Added training & model utils. * Cleanup & fixes. * Update. * Evaluation updates for pretraining. * Removed duplicate relation storage. * Moved RE model file location. * Structure revisions. * Added custom config for RE. * Implemented custom dataset loader for RE. * More changes. * Small fix. * Latest additions to RelCAT (pipe + predictions) * Setup.py fix. * RE utils update. * rel model update. * rel dataset + tokenizer improvements. * RelCAT updates. * RelCAT saving/loading improvements. * RelCAT saving/loading improvements. * RelCAT model fixes. * Attempted gpu learning fix. Dataset label generation fixes. * Minor train dataset gen fix. * Minor train dataset gen fix No.2. * Config updates. * Gpu support fixes. Added label stats. * Evaluation stat fixes. * Cleaned stat output mode during training. * Build fix. * removed unused dependencies and fixed code formatting * Mypy compliance. * Fixed linting. * More Gpu mode train fixes. * Fixed model saving/loading issues when using other baes models. * More fixes to stat evaluation. Added proper CAT integration of RelCAT. * Setup.py typo fix. * RelCAT loading fix. * RelCAT Config changes. * Type fix. Minor additions to RelCAT model. * Type fixes. * Type corrections. * RelCAT update. * Type fixes. * Fixed type issue. * RelCATConfig: added seed param. * Adaptations to the new codebase + type fixes.. * Doc/type fixes. * Fixed input size issue for model. * Fixed issue(s) with model size and config. * RelCAT: updated configs to new style. * RelCAT: removed old refs to logging. * Fixed GPU training + added extra stat print for train set. * Type fixes. * Updated dev requirements. * Linting. * Fixed pin_memory issue when training on CPU. * Updated RelCAT dataset get + default config. * Updated RelDS generator + default config * Linting. * Updated RelDatset + config. * Pushing updates to model Made changes to: 1) Extracting given number of context tokens left and right of the entities 2) Extracting hidden state from bert for all the tokens of the entities and performing max pooling on them * Fixing formatting * Update rel_dataset.py * Update rel_dataset.py * Update rel_dataset.py * RelCAT: added test resource files. * RelCAT: Fixed model load/checkpointing. * RelCAT: updated to pipe spacy doc call. * RelCAT: added tests. * Fixed lint/type issues & added rel tag to test DS. * Fixed ann id to token issue. * RelCAT: updated test dataset + tests. * RelCAT: updates to requested changes + dataset improvements. * RelCAT: updated docs/logs according to commends. * RelCAT: type fix. * RelCAT: mct export dataset updates. * RelCAT: test updates + requested changes p2. * RelCAT: log for MCT export train. * Updated docs + split train_test & dataset for benchmarks. * type fixes. * RelCAT: Initial Llama integration. * RelCAT: updates to Llama impl. * RelCAT: model typo fix. * RelCAT: label_id /sample no. mixup fix. * Updated cleaned up Relataset, added new ways to create relations via anno types (doc/export only for now). * Added option to predict any text /w annotations via RelCAT. MCT export train fixes. * RelCAT: added sample limiter / class, more logging info. * RelCAT: test/train ds shuffle update. * RelCAT: added option to keep original text when using reldataset class. * Pushing change for stratified batching Implement stratified batching for improved class representation and balanced training * RelCAT: fixed doc processing issue + class weights. * RelCAT: class weights addtions to cfg + param. * RelCAT: added config params for Adam optimizer. * RelCAT updated default config. * RelCAT: config update + optimizer change. * RelCAT: fixed model freeze flags. * RelCAT: model optimizer save/load fix. * RelCAT: added export ent tag check. * Fixed issues when saving/loading model for class weights + inference device cast. * RelCAT: bug fix for ents that are @ EoS. * Rel Dataset updates. * Rel Dataset updates. * Pushing change for ModernBERT * Bumped transformers version. * Updated rel dataset generation from fake Spacy Docs. * ModernBert updates. * Updated RelCAT model-load/save. * Minor relCAT updates, code format. * Type check updates. * Fixed inference issue. * RelCAT: testing updates. * Type fixes. * Type fixes. * Type fixes. * Type fixes IV. * Type fixes python 3.9. * RelCAT: flake8 fixes. * RelCAT: flake8 fixes. * RelCAT: Updates (fixed model loading after save). * Fixed test. * Update RelCAT stuff for improved abstraction * Move separate model implementations to separate packages * Some minor abstraction changes * Remove accidentally copied abstract method decorator * Fix import in test * Fix RelCAT impport in pipe tests * Update base relcat model implementation to include config * Latest RelCAT module updates. * Type fixes + run issues. * Type fixes. * Fixed Llama tokenizer. * Type fixes. * Type fixes: Python3.10 adjustements. * Linting. * Fix base flake8 lint issues * Fix doc string in ConfigRelCAT.load * Fix base component init doc string * Fixed BaseComponent.load method doc string * Fix doc strings in rel_cat ml_utils * Fix doc strings in rel_cat models module * Fix rel-cat test time import * Fix type casting * Align pipe tests with rel cat changes * Fix property paths in rel cat tests * Updates. * Fixed tests. * Fixed relCAT config save. * Latest fixes for model saving/loading. * Lint fix. * RelCAT cfg load test fix. * Remove install requirements from gitignore --------- Co-authored-by: Shubham Agarwal <66172189+shubham-s-agarwal@users.noreply.github.com> Co-authored-by: mart-r <mart.ratas@gmail.com>
1 parent 6726789 commit 29ef26e

27 files changed

+2038
-1036
lines changed

medcat-v1/.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,3 +55,4 @@ tests/model_creator/output/*
5555
docs/auto/
5656
docs/_build
5757

58+
models/

medcat-v1/install_requires.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
'gensim>=4.3.0,<5.0.0' # 5.3.0 is first to support 3.11; avoid major version bump
44
'spacy>=3.6.0,<4.0.0' # avoid major bump
55
'scipy>=1.9.2,<1.14.0' # 1.9.2 is first to support 3.11; 1.14.0 does not support 3.9
6-
'transformers>=4.34.0,<5.0.0' # avoid major version bump
6+
'transformers>=4.48.1,<5.0.0' # avoid major version bump
77
'accelerate>=0.23.0' # required by Trainer class in de-id
88
'torch>=2.4.0,<3.0.0' # 2.4.0 is first to support 3.12; avoid major 3.0.0 for now
99
'tqdm>=4.27'

medcat-v1/medcat/cat.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -143,7 +143,7 @@ def _create_pipeline(self, config: Config):
143143
self.pipe.add_meta_cat(meta_cat, meta_cat.config.general.category_name)
144144

145145
for rel_cat in self._rel_cats:
146-
self.pipe.add_rel_cat(rel_cat, "_".join(list(rel_cat.config.general["labels2idx"].keys())))
146+
self.pipe.add_rel_cat(rel_cat, "_".join(list(rel_cat.component.relcat_config.general["labels2idx"].keys())))
147147

148148
# Set max document length
149149
self.pipe.spacy_nlp.max_length = config.preprocessing.max_document_length

medcat-v1/medcat/config_rel_cat.py

Lines changed: 85 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1+
import os
12
import logging
2-
from typing import Dict, Any, List
3+
from typing import Any, Dict, List, Tuple, Union, cast
34
from medcat.config import MixingConfig, BaseModel, Optional
45

56

@@ -21,10 +22,14 @@ class General(MixingConfig, BaseModel):
2122
window_size: int = 300
2223
"""Max acceptable dinstance between entities (in characters), care when using this as it can produce sentences that are over 512 tokens (limit is given by tokenizer)"""
2324

24-
mct_export_max_non_rel_sample_size:int = 200
25+
limit_samples_per_class: int = -1
26+
"""Number of samples per class, this limit is applied for train samples, so if train samples are 100 then test would be 20."""
27+
addl_rels_max_sample_size:int = 200
2528
"""Limit the number of 'Other' samples selected for training/test. This is applied per encountered medcat project, sample_size/num_projects. """
26-
mct_export_create_addl_rels: bool = False
27-
"""When processing relations from a MedCAT export, relations labeled as 'Other' are created from all the annotations pairs available"""
29+
create_addl_rels: bool = False
30+
"""When processing relations from a MedCAT export/docs, relations labeled as 'Other' are created from all the annotations pairs available"""
31+
create_addl_rels_by_type: bool = False
32+
"""When creating the 'Other' relation class, actually split this class into subclasses based on concept types"""
2833

2934
tokenizer_name: str = "bert"
3035
"""The name of the tokenizer user.
@@ -46,21 +51,47 @@ class General(MixingConfig, BaseModel):
4651
"""Tokenizer.
4752
4853
NB! For these changes to take effect, the pipe would need to be recreated."""
49-
annotation_schema_tag_ids: List = []
54+
annotation_schema_tag_ids: List = [30522, 30523, 30524, 30525]
5055
"""If a foreign non-MCAT trainer dataset is used, you can insert your own Rel entity token delimiters into the tokenizer, \
51-
copy those token IDs here, and also resize your tokenizer embeddings and adjust the hidden_size of the model, this will depend on the number of tokens you introduce"""
52-
labels2idx: Dict = {}
53-
idx2labels: Dict = {}
56+
copy those token IDs here, and also resize your tokenizer embeddings and adjust the hidden_size of the model, this will depend on the number of tokens you introduce
57+
for example: 30522 - [s1], 30523 - [e1], 30524 - [s2], 30525 - [e2], 30526 - [BLANK], 30527 - [ENT1], 30528 - [ENT2], 30529 - [/ENT1], 30530 - [/ENT2]
58+
Please note that the tokenizer special tokens are supposed to be in pairs of two for example [s1] and [e1], [s2] and [e2], the [BLANK] is just an example placeholder token
59+
If you have more than four tokens here then you need to make sure they are present in the text,
60+
otherwise the pipeline will throw an error in the get_annotation_schema_tag() function.
61+
"""
62+
63+
tokenizer_relation_annotation_special_tokens_tags: List[str] = ["[s1]", "[e1]", "[s2]", "[e2]"]
64+
65+
tokenizer_other_special_tokens: Dict[str, str] = {"pad_token": "[PAD]"}
66+
"""
67+
The special tokens used by the tokenizer. The {PAD} is for Lllama tokenizer."""
68+
69+
labels2idx: Dict[str, int] = {}
70+
idx2labels: Dict[int, str] = {}
71+
5472
pin_memory: bool = True
73+
"""If True the data loader will copy the tensors to the GPU pinned memory"""
74+
5575
seed: int = 13
5676
"""The seed for random number generation.
5777
58-
NOTE: If used along MetaCAT or additional NER, only one of the seeds will take effect
5978
NB! For these changes to take effect, the pipe would need to be recreated."""
6079
task: str = "train"
61-
"""The task for RelCAT.
80+
"""The task for RelCAT."""
6281

63-
NB! For these changes to take effect, the pipe would need to be recreated."""
82+
language: str = "en"
83+
"""Used for Spacy lang setting"""
84+
85+
@classmethod
86+
def convert_keys_to_int(cls, value):
87+
if isinstance(value, dict):
88+
return {int(k): v for k, v in value.items()}
89+
return value
90+
91+
def __setattr__(self, key: str, value: Any):
92+
if key == "idx2labels" and isinstance(value, dict):
93+
value = self.convert_keys_to_int(value) # Ensure conversion
94+
super().__setattr__(key, value)
6495

6596

6697
class Model(MixingConfig, BaseModel):
@@ -82,12 +113,18 @@ class Model(MixingConfig, BaseModel):
82113
num_directions: int = 2
83114
"""2 - bidirectional model, 1 - unidirectional"""
84115

116+
freeze_layers: bool = True
117+
"""If we update the weights during training"""
118+
85119
padding_idx: int = -1
86120
emb_grad: bool = True
87121
"""If True the embeddings will also be trained"""
88122
ignore_cpos: bool = False
89123
"""If set to True center positions will be ignored when calculating representation"""
90124

125+
llama_use_pooled_output: bool = False
126+
"""If set to True, used only in Llama model, it will add the extra tensor formed from selecting the max of the last hidden layer"""
127+
91128
class Config:
92129
extra = 'allow'
93130
validate_assignment = True
@@ -98,9 +135,24 @@ class Train(MixingConfig, BaseModel):
98135
nclasses: int = 2
99136
"""Number of classes that this model will output"""
100137
batch_size: int = 25
138+
"""batch size"""
101139
nepochs: int = 1
140+
"""Epochs"""
102141
lr: float = 1e-4
103-
adam_epsilon: float = 1e-4
142+
"""Learning rate"""
143+
stratified_batching: bool = False
144+
"""Train the model with stratified batching"""
145+
batching_samples_per_class: list = []
146+
"""Number of samples per class in each batch
147+
example for batch size 64: [6,6,6,8,8,8,6,8,8]"""
148+
batching_minority_limit: Union[List[int], int] = 0
149+
"""Maximum number of samples the minority class can have.
150+
Since the minority class elements need to be repeated, this is used to facilitate that
151+
example: batching_samples_per_class - [6,6,6,8,8,8,6,8,8]
152+
batching_minority_limit - 6"""
153+
adam_betas: Tuple[float, float] = (0.9, 0.999)
154+
adam_weight_decay: float = 0
155+
adam_epsilon: float = 1e-8
104156
test_size: float = 0.2
105157
gradient_acc_steps: int = 1
106158
multistep_milestones: List[int] = [
@@ -109,7 +161,8 @@ class Train(MixingConfig, BaseModel):
109161
max_grad_norm: float = 1.0
110162
shuffle_data: bool = True
111163
"""Used only during training, if set the dataset will be shuffled before train/test split"""
112-
class_weights: Optional[Any] = None
164+
class_weights: Union[List[float], None] = None
165+
enable_class_weights: bool = False
113166
score_average: str = "weighted"
114167
"""What to use for averaging F1/P/R across labels"""
115168
auto_save_model: bool = True
@@ -129,3 +182,22 @@ class ConfigRelCAT(MixingConfig, BaseModel):
129182
class Config:
130183
extra = 'allow'
131184
validate_assignment = True
185+
186+
@classmethod
187+
def load(cls, load_path: str = "./") -> "ConfigRelCAT":
188+
"""Load the config from a file.
189+
190+
Args:
191+
load_path (str): Path to RelCAT config. Defaults to "./".
192+
193+
Returns:
194+
ConfigRelCAT: The loaded config.
195+
"""
196+
config = cls()
197+
if os.path.exists(load_path):
198+
if "config.json" not in load_path:
199+
load_path = os.path.join(load_path, "config.json")
200+
config = cast(ConfigRelCAT, super().load(load_path))
201+
logging.info("Loaded config.json")
202+
203+
return config

0 commit comments

Comments
 (0)