Skip to content
This repository was archived by the owner on Jul 28, 2025. It is now read-only.

Commit 184feea

Browse files
committed
RelCAT: testing updates.
1 parent c8d4aaa commit 184feea

File tree

1 file changed

+31
-3
lines changed

1 file changed

+31
-3
lines changed

tests/test_rel_cat.py

Lines changed: 31 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -2,17 +2,20 @@
22
import shutil
33
import unittest
44
import json
5+
import logging
56

67
from medcat.cdb import CDB
78
from medcat.config_rel_cat import ConfigRelCAT
89
from medcat.rel_cat import RelCAT
910
from medcat.utils.relation_extraction.tokenizer import TokenizerWrapperBERT
11+
from medcat.utils.relation_extraction.rel_dataset import RelData
1012

1113
from transformers.models.auto.tokenization_auto import AutoTokenizer
1214

1315
import spacy
1416
from spacy.tokens import Span, Doc
1517

18+
1619
class RelCATTests(unittest.TestCase):
1720

1821
@classmethod
@@ -24,6 +27,7 @@ def setUpClass(cls) -> None:
2427
config.train.nclasses = 3
2528
config.model.hidden_size= 256
2629
config.model.model_size = 2304
30+
config.general.log_level = logging.DEBUG
2731

2832
tokenizer = TokenizerWrapperBERT(AutoTokenizer.from_pretrained(
2933
pretrained_model_name_or_path=config.general.model_name,
@@ -58,13 +62,37 @@ def setUpClass(cls) -> None:
5862
cls.finished = False
5963
cls.tokenizer = tokenizer
6064

65+
def test_dataset_relation_parser(self) -> None:
    """Build base relations from entity-tagged sentences via ``RelData``.

    Every sample marks its two entities with ``[s1]...[e1]`` and
    ``[s2]...[e2]`` tags. For each sample the positions of the ``[s1]`` and
    ``[s2]`` tokens in the tokenizer output are looked up and passed to
    ``RelData.create_base_relations_from_doc`` together with the sample text
    and its index.

    The final assertion only checks that one result was collected per
    sample, so in practice the test passes when no call raised.
    """
    samples = [
        "The [s1]45-year-old male[e1] was diagnosed with [s2]hypertension[e2] during his routine check-up.",
        "The patient’s [s1]chest pain[e1] was associated with [s2]shortness of breath[e2].",
        "[s1]Blood pressure[e1] readings of [s2]160/90 mmHg[e2] indicated possible hypertension.",
        "His elevated [s1]blood glucose[e1] level of [s2]220 mg/dL[e2] raised concerns about his diabetes management.",
        "The doctor recommended a [s1]cardiac enzyme test[e1] to assess the risk of [s2]myocardial infarction[e2].",
        "The patient’s [s1]ECG[e1] showed signs of [s2]ischemia[e2]",
        "To manage his [s1]hypertension[e1], the patient was advised to [s2]reduce salt intake[e2].",
        "[s1]Increased physical activity[e1][s2]type 2 diabetes[e2]."
    ]

    rel_dataset = RelData(cdb=self.rel_cat.cdb, config=self.config_rel_cat, tokenizer=self.tokenizer)

    rels = []

    # enumerate yields the index and the sample together, avoiding the
    # repeated samples[idx] lookups of a range(len(...)) loop.
    for idx, sample in enumerate(samples):
        tkns = self.tokenizer(sample)["tokens"]
        # Start positions of the two entity-opening tags in the token
        # sequence. NOTE(review): presumably `tkns` is a list of token
        # strings in which the tags survive as single tokens - confirm
        # against the tokenizer wrapper.
        ent1_ent2_tokens_start_pos = (tkns.index("[s1]"), tkns.index("[s2]"))
        rels.append(rel_dataset.create_base_relations_from_doc(
            sample, idx,
            ent1_ent2_tokens_start_pos=ent1_ent2_tokens_start_pos))

    # unittest assertion: still runs under `python -O` and reports both
    # values on failure, unlike a bare assert statement.
    self.assertEqual(len(rels), len(samples))
89+
90+
6191
def test_train_csv_no_tags(self) -> None:
    """Train for two epochs from the CSV train/test files, then save the model.

    Checkpoints are written under ``self.tmp_dir`` and the trained model is
    saved to ``self.save_model_path``.
    """
    rel_cat = self.rel_cat
    rel_cat.config.train.epochs = 2

    rel_cat.train(
        train_csv_path=self.medcat_rels_csv_path_train,
        test_csv_path=self.medcat_rels_csv_path_test,
        checkpoint_path=self.tmp_dir,
    )
    rel_cat.save(self.save_model_path)
6595

66-
67-
6896
def test_train_mctrainer(self) -> None:
6997
self.rel_cat = RelCAT.load(self.save_model_path)
7098
self.rel_cat.config.general.create_addl_rels = True
@@ -77,7 +105,6 @@ def test_train_mctrainer(self) -> None:
77105
self.rel_cat.train(export_data_path=self.medcat_export_with_rels_path, checkpoint_path=self.tmp_dir)
78106

79107

80-
81108
def test_train_predict(self) -> None:
82109
Span.set_extension('id', default=0, force=True)
83110
Span.set_extension('cui', default=None, force=True)
@@ -103,6 +130,7 @@ def test_train_predict(self) -> None:
103130

104131
assert len(doc._.relations) > 0
105132

133+
106134
def tearDown(self) -> None:
107135
if self.finished:
108136
if os.path.exists(self.tmp_dir):

0 commit comments

Comments
 (0)