Skip to content
This repository was archived by the owner on Jun 30, 2025. It is now read-only.

Commit d446f07

Browse files
authored
CU-86999tnz7 resync with v1 (#74)
* CU-86999tnz7: Update usage of constants for train/test splitting * CU-86999tnz7: Add alternative category names * CU-86999tnz7: Fix hashes due to config changes * CU-86999tnz7: Allow (optionally) training addons (e.g MetaCAT) during supervised training * CU-86999tnz7: Add optional change description when saving model * CU-86999tnz7: Add a test for description upon save * CU-86999tnz7: Fix typing during tests * CU-86999tnz7: Fix issues with extra labels, add relevant tests * CU-86999tnz7: Ported DeID improvments * CU-86999tnz7: Add missing resource files * CU-86999tnz7: Add preprocessors for UMLS and Snomed * CU-86999tnz7: Add usage monitoring * CU-86999tnz7: Use promise of a hash for usage monitoring * CU-86999tnz7: Allow 15 minutes for tests within main workflow * CU-86999tnz7: Add README for release scripts * CU-86999tnz7: Allowing conversion of beta namespaces to proper ones during deserialisation * CU-86999tnz7: Allow clearing unpacked data when saving model pack * CU-86999tnz7: Make sure model pack path refers to existing file/folder * CU-86999tnz7: Add base backwards compatibility stuff. That is, added things to create a fake model and a script to create + check model, as well as run a simple test on vocab. * CU-86999tnz7: Run model regression during workflow * CU-86999tnz7: Fix vocab data path during regression vocab test * CU-86999tnz7: Fix typo in script * CU-86999tnz7: Add hash to custom names unless explicitly disabled * CU-86999tnz7: Add backwards compatibiltiy script * CU-86999tnz7: Run backwards compatibility as part of workflow * CU-86999tnz7: Avoid runtime warnings due to config namespaces * CU-86999tnz7: Add initial multiprocessing option * CU-86999tnz7: Add minor tests for batching * CU-86999tnz7: Allow text index to be a string. 
Add doc string to multiprocessing method * CU-86999tnz7: Allow batching on a per-character basis * CU-86999tnz7: Add a few tests for a per-character batching * CU-86999tnz7: Fix issue with resulting text indices for multiprocessing * CU-86999tnz7: Add minor multiprocessing test * CU-86999tnz7: Allow an extra 5 minutes for workflow /tests
1 parent 03ca0e6 commit d446f07

34 files changed

+2929
-58
lines changed

.github/workflows/main.yml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,4 +34,10 @@ jobs:
3434
uv run ruff check medcat2 --preview
3535
- name: Test
3636
run: |
37-
timeout 10m uv run python -m unittest discover
37+
timeout 20m uv run python -m unittest discover
38+
- name: Model regression
39+
run: |
40+
uv run bash tests/backwards_compatibility/run_current.sh
41+
- name: Backwards compatibility
42+
run: |
43+
uv run bash tests/backwards_compatibility/check_backwards_compatibility.sh

.release/README.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Releases
2+
3+
The scripts within here are designed to help prepare for and deal with releases.
4+
5+
The main idea is to use the `prepare_release.sh` script from within the root of the project and it will delegate either to `prepare_minor_release.sh` or `prepare_patch_release.sh` as necessary.
6+
The workflow within the scripts is as follows:
7+
- Create or check out release branch (`release/v<major>.<minor>`)
8+
- Update version in `pyproject.toml`
9+
- Create a tag based on the version
10+
- Push both the branch as well as the tag to `origin`
11+
12+
The general usage for a minor release based on the `main` branch from within the **root of the project** is simply:
13+
```
14+
bash .release/prepare_release.sh <major>.<minor>.0
15+
```
16+
and the usage for a patch release (from within the **root of the project**) is in the format
17+
```
18+
bash .release/prepare_release.sh <major>.<minor>.<patch> <hash 1> <hash 2> ...
19+
```
20+
where `hash 1` and `hash 2` (and so on) refer to the commit hashes that need to be included / cherry-picked in the patch release.
21+

medcat2/cat.py

Lines changed: 223 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
1-
from typing import Optional, Union, Any, overload, Literal
1+
from typing import Optional, Union, Any, overload, Literal, Iterable, Iterator
2+
from typing import cast
23
import os
34
import json
5+
from datetime import date
6+
from concurrent.futures import ProcessPoolExecutor, as_completed, Future
7+
import itertools
48

59
import shutil
610
import logging
@@ -23,6 +27,7 @@
2327
from medcat2.components.addons.addons import AddonComponent
2428
from medcat2.utils.legacy.identifier import is_legacy_model_pack
2529
from medcat2.utils.defaults import AVOID_LEGACY_CONVERSION_ENVIRON
30+
from medcat2.utils.usage_monitoring import UsageMonitor
2631

2732

2833
logger = logging.getLogger(__name__)
@@ -51,6 +56,8 @@ def __init__(self,
5156

5257
self._trainer: Optional[Trainer] = None
5358
self._pipeline = self._recreate_pipe(model_load_path)
59+
self.usage_monitor = UsageMonitor(
60+
self._get_hash, self.config.general.usage_monitor)
5461

5562
def _recreate_pipe(self, model_load_path: Optional[str] = None
5663
) -> Pipeline:
@@ -75,7 +82,10 @@ def ignore_attrs(cls) -> list[str]:
7582
]
7683

7784
def __call__(self, text: str) -> Optional[MutableDocument]:
    """Annotate a single text with the pipeline.

    Args:
        text (str): The raw input text.

    Returns:
        Optional[MutableDocument]: The annotated document, or None if
            the pipeline did not produce one.
    """
    doc = self._pipeline.get_doc(text)
    # NOTE: the return type is Optional, so the pipeline may yield None;
    # guard before touching doc.final_ents to avoid an AttributeError
    # when usage monitoring is enabled.
    if doc is not None and self.usage_monitor.should_monitor:
        self.usage_monitor.log_inference(len(text), len(doc.final_ents))
    return doc
7989

8090
def _ensure_not_training(self) -> None:
8191
"""Method to ensure config is not set to train.
@@ -139,6 +149,188 @@ def get_entities(self,
139149
return {}
140150
return self._doc_to_out(doc, only_cui=only_cui)
141151

152+
def _mp_worker_func(
        self,
        texts_and_indices: list[tuple[str, str, bool]]
) -> list[tuple[str, str, Union[dict, Entities, OnlyCUIEntities]]]:
    """Annotate one batch of texts.

    Args:
        texts_and_indices: Batch of (text, text_index, only_cui) items.

    Returns:
        One (text, text_index, entities) triple per input item, in
        input order.
    """
    annotated: list[tuple[str, str, Union[dict, Entities, OnlyCUIEntities]]] = []
    for text, text_index, only_cui in texts_and_indices:
        entities = self.get_entities(text, only_cui=only_cui)
        annotated.append((text, text_index, entities))
    return annotated
159+
160+
def _generate_batches_by_char_length(
161+
self,
162+
text_iter: Union[Iterator[str], Iterator[tuple[str, str]]],
163+
batch_size_chars: int,
164+
only_cui: bool,
165+
) -> Iterator[list[tuple[str, str, bool]]]:
166+
docs: list[tuple[str, str, bool]] = []
167+
char_count = 0
168+
for i, _doc in enumerate(text_iter):
169+
# NOTE: not sure why mypy is complaining here
170+
doc = cast(
171+
str, _doc[1] if isinstance(_doc, tuple) else _doc)
172+
doc_index: str = _doc[0] if isinstance(_doc, tuple) else str(i)
173+
clen = len(doc)
174+
char_count += clen
175+
if char_count > batch_size_chars:
176+
yield docs
177+
docs = []
178+
char_count = clen
179+
docs.append((doc_index, doc, only_cui))
180+
181+
if len(docs) > 0:
182+
yield docs
183+
184+
def _generate_batches(
185+
self,
186+
text_iter: Union[Iterator[str], Iterator[tuple[str, str]]],
187+
batch_size: int,
188+
batch_size_chars: int,
189+
only_cui: bool,
190+
) -> Iterator[list[tuple[str, str, bool]]]:
191+
if batch_size_chars < 1 and batch_size < 1:
192+
raise ValueError("Either `batch_size` or `batch_size_chars` "
193+
"must be greater than 0.")
194+
if batch_size > 0 and batch_size_chars > 0:
195+
raise ValueError(
196+
"Cannot specify both `batch_size` and `batch_size_chars`. "
197+
"Please use one of them.")
198+
if batch_size_chars > 0:
199+
return self._generate_batches_by_char_length(
200+
text_iter, batch_size_chars, only_cui)
201+
else:
202+
return self._generate_simple_batches(
203+
text_iter, batch_size, only_cui)
204+
205+
def _generate_simple_batches(
206+
self,
207+
text_iter: Union[Iterator[str], Iterator[tuple[str, str]]],
208+
batch_size: int,
209+
only_cui: bool,
210+
) -> Iterator[list[tuple[str, str, bool]]]:
211+
text_index = 0
212+
while True:
213+
# Take a small batch from the iterator
214+
batch = list(itertools.islice(text_iter, batch_size))
215+
if not batch:
216+
break
217+
# NOTE: typing is correct:
218+
# - if str, then (str, int, bool)
219+
# - if tuple, then (str, int, bool)
220+
# but for some reason mypy complains
221+
yield [
222+
(text, str(text_index + i), only_cui) # type: ignore
223+
if isinstance(text, str) else
224+
(text[1], text[0], only_cui)
225+
for i, text in enumerate(batch)
226+
]
227+
text_index += len(batch)
228+
229+
def _mp_one_batch_per_process(
        self,
        executor: ProcessPoolExecutor,
        batch_iter: Iterator[list[tuple[str, str, bool]]],
        external_processes: int
) -> Iterator[tuple[str, Union[dict, Entities, OnlyCUIEntities]]]:
    """Process batches on the worker pool, one batch per process at a time.

    Seeds each external worker with one batch, processes one batch on
    the main process while they run, then drains the outstanding
    futures, re-submitting a fresh batch every time one completes.

    Args:
        executor: Pool of external worker processes.
        batch_iter: Iterator of worker batches in (text, text_index,
            only_cui) form.
        external_processes: Number of external processes in the pool.

    Yields:
        (text_index, entities) pairs, in completion order (NOT
        necessarily input order).
    """
    futures: list[Future] = []
    # submit batches, one for each external process
    for _ in range(external_processes):
        try:
            batch = next(batch_iter)
            futures.append(
                executor.submit(self._mp_worker_func, batch))
        except StopIteration:
            break
    # Main process works on next batch while workers are busy
    main_batch: Optional[list[tuple[str, str, bool]]]
    try:
        main_batch = next(batch_iter)
        main_results = self._mp_worker_func(main_batch)

        # Yield main process results immediately
        # (worker triples are (text, text_index, entities))
        for result in main_results:
            yield result[1], result[2]

    except StopIteration:
        main_batch = None
    # Drain ALL outstanding futures, not just the first
    # `external_processes` of them: batches are re-submitted inside the
    # loop body, so a bounded `for _ in range(external_processes)` loop
    # (as previously written) would exit with futures still pending and
    # silently drop their results whenever there are more than roughly
    # 2 * n_process batches.
    while futures:
        # Wait for any future to complete
        done_future = next(as_completed(futures))
        futures.remove(done_future)

        # Yield all results from this batch
        for result in done_future.result():
            yield result[1], result[2]

        # Submit next batch to keep workers busy
        try:
            batch = next(batch_iter)
            futures.append(
                executor.submit(self._mp_worker_func, batch))
        except StopIteration:
            # NOTE: if there's nothing to batch, we've got nothing
            # to submit in terms of new work to the workers,
            # but we may still have some futures to wait for
            pass
279+
280+
def get_entities_multi_texts(
        self,
        texts: Union[Iterable[str], Iterable[tuple[str, str]]],
        only_cui: bool = False,
        n_process: int = 1,
        batch_size: int = -1,
        batch_size_chars: int = 1_000_000,
) -> Iterator[tuple[str, Union[dict, Entities, OnlyCUIEntities]]]:
    """Get entities from multiple texts (potentially in parallel).

    If `n_process` > 1, `n_process - 1` new processes will be created
    and data will be processed on those as well as the main process in
    parallel.

    Args:
        texts (Union[Iterable[str], Iterable[tuple[str, str]]]):
            The input text. Either an iterable of raw text or one
            with in the format of `(text_index, text)`.
        only_cui (bool):
            Whether to only return CUIs rather than other information
            like start/end and annotated value. Defaults to False.
        n_process (int):
            Number of processes to use. Defaults to 1.
        batch_size (int):
            The number of texts to batch at a time. A batch of the
            specified size will be given to each worker process.
            Defaults to -1 and in this case the character count will
            be used instead.
        batch_size_chars (int):
            The maximum number of characters to process in a batch.
            Each process will be given batch of texts with a total
            number of characters not exceeding this value. Defaults
            to 1,000,000 characters. Set to -1 to disable.

    Yields:
        Iterator[tuple[str, Union[dict, Entities, OnlyCUIEntities]]]:
            The results in the format of (text_index, entities).
    """
    text_iter = cast(
        Union[Iterator[str], Iterator[tuple[str, str]]], iter(texts))
    batch_iter = self._generate_batches(
        text_iter, batch_size, batch_size_chars, only_cui)
    if n_process == 1:
        # just do in series
        for batch in batch_iter:
            # NOTE: `_mp_worker_func` returns (text, text_index,
            # entities) triples; yield the index (element 1), not the
            # text (element 0) — this mirrors the `result[1]` access in
            # the multiprocessing path. The previous unpacking
            # `(text_index, _, result)` yielded the raw text as the
            # index.
            for _, text_index, result in self._mp_worker_func(batch):
                yield text_index, result
        return

    external_processes = n_process - 1
    with ProcessPoolExecutor(max_workers=external_processes) as executor:
        yield from self._mp_one_batch_per_process(
            executor, batch_iter, external_processes)
333+
142334
def _get_entity(self, ent: MutableEntity,
143335
doc_tokens: list[str],
144336
cui: str) -> Entity:
@@ -253,6 +445,9 @@ def save_model_pack(
253445
self, target_folder: str, pack_name: str = DEFAULT_PACK_NAME,
254446
serialiser_type: Union[str, AvailableSerialisers] = 'dill',
255447
make_archive: bool = True,
448+
only_archive: bool = False,
449+
add_hash_to_pack_name: bool = True,
450+
change_description: Optional[str] = None,
256451
) -> str:
257452
"""Save model pack.
258453
@@ -268,14 +463,22 @@ def save_model_pack(
268463
The serialiser type. Defaults to 'dill'.
269464
make_archive (bool):
270465
Whether to make the arhive /.zip file. Defaults to True.
466+
only_archive (bool):
467+
Whether to clear the non-compressed folder. Defaults to False.
468+
add_hash_to_pack_name (bool):
469+
Whether to add the hash to the pack name. This is only relevant
470+
if pack_name is specified. Defaults to True.
471+
change_description (Optional[str]):
472+
If provided, this the description will be added to the
473+
model description. Defaults to None.
271474
272475
Returns:
273476
str: The final model pack path.
274477
"""
275478
self.config.meta.mark_saved_now()
276479
# figure out the location/folder of the saved files
277-
hex_hash = self._versioning()
278-
if pack_name == DEFAULT_PACK_NAME:
480+
hex_hash = self._versioning(change_description)
481+
if pack_name == DEFAULT_PACK_NAME or add_hash_to_pack_name:
279482
pack_name = f"{pack_name}_{hex_hash}"
280483
model_pack_path = os.path.join(target_folder, pack_name)
281484
# ensure target folder and model pack folder exist
@@ -294,9 +497,16 @@ def save_model_pack(
294497
if make_archive:
295498
shutil.make_archive(model_pack_path, 'zip',
296499
root_dir=model_pack_path)
500+
if only_archive:
501+
logger.info("Removing the non-archived model pack folder: %s",
502+
model_pack_path)
503+
shutil.rmtree(model_pack_path, ignore_errors=True)
504+
# change the model pack path to the zip file so that we
505+
# refer to an existing file
506+
model_pack_path += ".zip"
297507
return model_pack_path
298508

299-
def _versioning(self) -> str:
509+
def _get_hash(self) -> str:
300510
hasher = Hasher()
301511
logger.debug("Hashing the CDB")
302512
hasher.update(self.cdb.get_hash())
@@ -306,6 +516,14 @@ def _versioning(self) -> str:
306516
type(component).__name__)
307517
hasher.update(component.get_hash())
308518
hex_hash = self.config.meta.hash = hasher.hexdigest()
519+
return hex_hash
520+
521+
def _versioning(self, change_description: Optional[str]) -> str:
522+
date_today = date.today().strftime("%d %B %Y")
523+
if change_description is not None:
524+
self.config.meta.description += (
525+
f"\n[{date_today}] {change_description}")
526+
hex_hash = self._get_hash()
309527
history = self.config.meta.history
310528
if not history or history[-1] != hex_hash:
311529
history.append(hex_hash)

0 commit comments

Comments
 (0)