From d307dab3cebc03894df685a8d1b4c25a1b539dd9 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Thu, 16 Apr 2026 15:24:21 +0100 Subject: [PATCH 01/15] CU-869cy3yz9: Add unsupervised training method to trainable component protocol --- medcat-v2/medcat/components/types.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/medcat-v2/medcat/components/types.py b/medcat-v2/medcat/components/types.py index e883d4a20..fe26bb1e0 100644 --- a/medcat-v2/medcat/components/types.py +++ b/medcat-v2/medcat/components/types.py @@ -168,6 +168,17 @@ def get_hash(self) -> str: @runtime_checkable class TrainableComponent(Protocol): + def train_unsupervised(self, doc: MutableDocument) -> None: + """Train unsupervised based on the given document. + + If this component doesn't support unsupervised training, + this method can be a no-op. + + Args: + doc (MutableDocument): The document to train on. + """ + pass + def train(self, cui: str, entity: MutableEntity, doc: MutableDocument, From acf623be09995afac10b20463c54ed561ffa5b61 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Thu, 16 Apr 2026 15:27:24 +0100 Subject: [PATCH 02/15] CU-869cy3yz9: Follow the intercace for unsupervised trianing in context based linker --- .../medcat/components/linking/context_based_linker.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/medcat-v2/medcat/components/linking/context_based_linker.py b/medcat-v2/medcat/components/linking/context_based_linker.py index ce389864f..1451e2ef6 100644 --- a/medcat-v2/medcat/components/linking/context_based_linker.py +++ b/medcat-v2/medcat/components/linking/context_based_linker.py @@ -210,7 +210,9 @@ def predict_entities(self, doc: MutableDocument, raise ValueError("Need to have NER'ed entities provided") if cnf_l.train: - linked_entities = self._train_on_doc(doc, ents) + raise ValueError( + "Use the new train_unsupervised method for unsuperivsed training " + "instead of a regular inference call with a changed config entry.") else: linked_entities = self._inference(doc, ents) # evaluating generator here because the `all_ents` list gets @@ -227,6 +229,11 @@ def predict_entities(self, doc: MutableDocument, return filter_linked_annotations( doc, le, self.config.general.show_nested_entities) + # TrainableComponent + + def train_unsupervised(self, doc: MutableDocument) -> None: + self._train_on_doc(doc, doc.ner_ents) + def train(self, cui: str, entity: MutableEntity, doc: MutableDocument, From f204535cab00ffad609ebee9c74368125a19b77a Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Thu, 16 Apr 2026 15:31:23 +0100 Subject: [PATCH 03/15] CU-869cy3xa0: Use new interface for self-supervised training --- medcat-v2/medcat/trainer.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/medcat-v2/medcat/trainer.py b/medcat-v2/medcat/trainer.py index a48691a4a..f5c7b6723 100644 --- a/medcat-v2/medcat/trainer.py +++ b/medcat-v2/medcat/trainer.py @@ -64,10 +64,8 @@ def train_unsupervised(self, with self.config.meta.prepare_and_report_training( data_iterator, nepochs, False ) as wrapped_iter: - with temp_changed_config(self.config.components.linking, - 'train', True): - self._train_unsupervised(wrapped_iter, nepochs, fine_tune, - progress_print) + self._train_unsupervised(wrapped_iter, nepochs, fine_tune, + progress_print) def _train_unsupervised(self, data_iterator: Iterable, @@ -91,11 +89,18 @@ def _train_unsupervised(self, # Convert to string line = str(line).strip() + + # inference run for the document try: - _ = self.caller(line) + doc = self.caller(line) except Exception as e: logger.warning("LINE: '%s...' \t WAS SKIPPED", line[0:100]) logger.warning("BECAUSE OF:", exc_info=e) + continue + for comp in self._pipeline.iter_all_components(): + if isinstance(comp, TrainableComponent): + logger.debug("Training on component %s", comp.full_name) + comp.train_unsupervised(doc) else: logger.warning("EMPTY LINE WAS DETECTED AND SKIPPED") From b1cc70ab3ffac07a0793406ed2dded3fc657ed2d Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 21 Apr 2026 11:30:35 +0100 Subject: [PATCH 04/15] CU-869cy3yz9: Fix fake pipe in tests --- medcat-v2/tests/test_trainer.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/medcat-v2/tests/test_trainer.py b/medcat-v2/tests/test_trainer.py index 84703713d..55f0f6310 100644 --- a/medcat-v2/tests/test_trainer.py +++ b/medcat-v2/tests/test_trainer.py @@ -92,6 +92,9 @@ def tokenizer(self, text: str) -> FakeMutDoc: def tokenizer_with_tag(self, text: str) -> FakeMutDoc: return FakeMutDoc(text) + def iter_all_components(self): + return [] + def get_component(self, comp_type): return FakeComponent From a76fde4af3075e56aabcdcab92e305e44129f30a Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 21 Apr 2026 13:41:00 +0100 Subject: [PATCH 05/15] CU-869cy3yz9: Fix issue with unrealised generator --- medcat-v2/medcat/components/linking/context_based_linker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/medcat-v2/medcat/components/linking/context_based_linker.py b/medcat-v2/medcat/components/linking/context_based_linker.py index 1451e2ef6..2813d1744 100644 --- a/medcat-v2/medcat/components/linking/context_based_linker.py +++ b/medcat-v2/medcat/components/linking/context_based_linker.py @@ -232,7 +232,7 @@ def predict_entities(self, doc: MutableDocument, # TrainableComponent def train_unsupervised(self, doc: MutableDocument) -> None: - self._train_on_doc(doc, doc.ner_ents) + list(self._train_on_doc(doc, doc.ner_ents)) def train(self, cui: str, entity: MutableEntity, From 7d9adf5e10082716564a1a2da87ef1b8c04a3065 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 21 Apr 2026 14:07:20 +0100 Subject: [PATCH 06/15] CU-869cy3z45: Allow any component to be trained in an unsupervised manner --- medcat-v2/medcat/trainer.py | 23 ++++++++++++----------- 1 file changed, 12 insertions(+), 11 deletions(-) diff --git a/medcat-v2/medcat/trainer.py b/medcat-v2/medcat/trainer.py index f5c7b6723..7c932011c 100644 --- a/medcat-v2/medcat/trainer.py +++ b/medcat-v2/medcat/trainer.py @@ -637,18 +637,15 @@ def add_and_train_concept(self, if mut_entity is None or mut_doc is None: return - linker = self._pipeline.get_component( - CoreComponentType.linking) - if not isinstance(linker, TrainableComponent): - logger.warning( - "Linker cannot be trained during add_and_train_concept" - "because it has no train method: %s", linker) - else: + trained_comps = 0 + for component in self._pipeline.iter_all_components(): + if not isinstance(component, TrainableComponent): + continue # Train Linking if isinstance(mut_entity, list): mut_entity = self._pipeline.entity_from_tokens(mut_entity) - linker.train(cui=cui, entity=mut_entity, doc=mut_doc, - negative=negative, names=names) + component.train(cui=cui, entity=mut_entity, doc=mut_doc, + negative=negative, names=names) if not negative and devalue_others: # Find all cuis @@ -663,8 +660,12 @@ def add_and_train_concept(self, # Add negative training for all other CUIs that link to # these names for _cui in cuis: - linker.train(cui=_cui, entity=mut_entity, doc=mut_doc, - negative=True) + component.train(cui=_cui, entity=mut_entity, doc=mut_doc, + negative=True) + if trained_comps == 0: + logger.warning( + "Nothing was trained during add_and_train_concept because " + "no components followed the TrainableComponent protocol") @property def _pn_configs(self) -> tuple[General, Preprocessing, CDBMaker]: From 4a71c1cbda68152fc03da4d1d7eb7997f57ab804 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 21 Apr 2026 14:07:50 +0100 Subject: [PATCH 07/15] CU-869cy3z45: Remove unused import --- medcat-v2/medcat/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/medcat-v2/medcat/trainer.py b/medcat-v2/medcat/trainer.py index 7c932011c..f1c65eeca 100644 --- a/medcat-v2/medcat/trainer.py +++ b/medcat-v2/medcat/trainer.py @@ -15,7 +15,7 @@ MedCATTrainerExport, MedCATTrainerExportAnnotation, MedCATTrainerExportProject, MedCATTrainerExportDocument, count_all_annotations, iter_anns) from medcat.preprocessors.cleaners import prepare_name, NameDescriptor -from medcat.components.types import CoreComponentType, TrainableComponent +from medcat.components.types import TrainableComponent from medcat.components.addons.addons import AddonComponent from medcat.pipeline import Pipeline From 4c25292c61ea5daf6af0ae75af34afdde8d789bb Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 21 Apr 2026 14:15:34 +0100 Subject: [PATCH 08/15] CU-869cy3yz9: Add a few more tests for trainable components --- medcat-v2/tests/test_trainer.py | 53 +++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/medcat-v2/tests/test_trainer.py b/medcat-v2/tests/test_trainer.py index 55f0f6310..9e3808e5d 100644 --- a/medcat-v2/tests/test_trainer.py +++ b/medcat-v2/tests/test_trainer.py @@ -84,6 +84,20 @@ class FakeComponent: pass +class FakeTrainableNERComponent: + + full_name = "ner:fake" + + def __init__(self): + self.docs_trained_on = [] + + def train_unsupervised(self, doc: MutableDocument) -> None: + self.docs_trained_on.append(doc) + + def train(self, *args, **kwargs) -> None: + return + + class FakePipeline: def tokenizer(self, text: str) -> FakeMutDoc: @@ -102,6 +116,15 @@ def entity_from_tokens_in_doc(self, tkns: list, doc: MutableDocument) -> FakeMut return FakeMutEnt(doc, tkns[0].index, tkns[-1].index) +class FakePipelineWithComponents(FakePipeline): + + def __init__(self, components: list): + self._components = components + + def iter_all_components(self): + return self._components + + class TrainerTestsBase(unittest.TestCase): DATA_CNT = 14 TRAIN_DATA = [ @@ -176,6 +199,35 @@ def test_training_gets_remembered_multi(self, repeats: int = 3): exp_total=repeats, unsup=self.UNSUP) + def test_unsup_training_trains_non_linking_component(self): + ner_component = FakeTrainableNERComponent() + trainer = Trainer( + self.cdb, + self.caller, + FakePipelineWithComponents([ner_component]), + ) + trainer.config = self.cnf + + trainer.train_unsupervised(self.TRAIN_DATA, nepochs=1) + + self.assertEqual(len(ner_component.docs_trained_on), self.DATA_CNT) + self.assertTrue( + all(isinstance(doc, FakeMutDoc) for doc in ner_component.docs_trained_on) + ) + + def test_unsup_training_skips_non_trainable_components(self): + ner_component = FakeTrainableNERComponent() + trainer = Trainer( + self.cdb, + self.caller, + FakePipelineWithComponents([FakeComponent(), ner_component, object()]), + ) + trainer.config = self.cnf + + trainer.train_unsupervised(self.TRAIN_DATA, nepochs=1) + + self.assertEqual(len(ner_component.docs_trained_on), self.DATA_CNT) + class TrainerSupervisedTests(TrainerUnsupervisedTests): DATA_CNT = 1 @@ -340,3 +392,4 @@ def test_has_trained_all(self): with self.subTest(cui): info = self.model.cdb.cui2info[cui] self.assertGreater(info['count_train'], prev_count) + From 9620d9bcc739c467780d2aee62a5af632fbdcbbf Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 21 Apr 2026 15:55:06 +0100 Subject: [PATCH 09/15] CU-869cy3zb0: Add utilities to create a dataset-aware NER or linker component --- medcat-v2/medcat/utils/training_utils.py | 140 +++++++++++++++++++++++ 1 file changed, 140 insertions(+) create mode 100644 medcat-v2/medcat/utils/training_utils.py diff --git a/medcat-v2/medcat/utils/training_utils.py b/medcat-v2/medcat/utils/training_utils.py new file mode 100644 index 000000000..6a5cbaf45 --- /dev/null +++ b/medcat-v2/medcat/utils/training_utils.py @@ -0,0 +1,140 @@ +from typing import Callable, Optional, Self +from contextlib import contextmanager + +from medcat.cat import CAT +from medcat.cdb import CDB +from medcat.components.types import ( + CoreComponentType, AbstractEntityProvidingComponent) +from medcat.config.config import ComponentConfig +from medcat.tokenizing.tokenizers import BaseTokenizer +from medcat.tokenizing.tokens import MutableDocument, MutableEntity, MutableToken +from medcat.data.mctexport import ( + MedCATTrainerExport, MedCATTrainerExportDocument, count_all_docs, iter_docs) +from medcat.vocab import Vocab + + +class _CheatingComponent(AbstractEntityProvidingComponent): + name = 'cheating_component' + + def __init__(self, + comp_type: CoreComponentType, + predictor: Callable[[MutableDocument], list[MutableEntity]]): + super().__init__( + comp_type == CoreComponentType.linking, + comp_type == CoreComponentType.linking) + self._comp_type = comp_type + self._predictor = predictor + + def get_type(self) -> CoreComponentType: + return self._comp_type + + def predict_entities(self, doc: MutableDocument, + ents: list[MutableEntity] | None = None + ) -> list[MutableEntity]: + return self._predictor(doc) + + @classmethod + def create_new_component( + cls, cnf: ComponentConfig, tokenizer: BaseTokenizer, + cdb: CDB, vocab: Vocab, model_load_path: Optional[str]) -> Self: + raise ValueError("Cannot create new compnoent of this type") + +@contextmanager +def cheating_component( + cat: CAT, + comp_type: CoreComponentType, + predictor: Callable[[MutableDocument], list[MutableEntity]]): + comps_list = cat.pipe._components + # find original index + original_comp = cat.pipe.get_component(comp_type) + replace_index = comps_list.index(original_comp) + # create and replace + cheater = _CheatingComponent(comp_type, predictor) + comps_list[replace_index] = cheater + try: + yield + finally: + # restore original component + comps_list[replace_index] = original_comp + + +def _identify_document( + doc: MutableDocument, + dataset: MedCATTrainerExport) -> MedCATTrainerExportDocument: + for proj in dataset['projects']: + for ann_doc in proj['documents']: + if ann_doc['text'] == doc.base.text: + return ann_doc + raise ValueError("Unable to identify correct document") + + +def _create_general_predictor( + dataset: MedCATTrainerExport, + tokens2entity: Callable[[list[MutableToken], MutableDocument], MutableEntity], + set_cui: bool, + ) -> Callable[[MutableDocument], list[MutableEntity]]: + def predict(doc: MutableDocument) -> list[MutableEntity]: + anns = _identify_document(doc, dataset)["annotations"] + ents: list[MutableEntity] = [] + for ann in anns: + tkns = doc.get_tokens(ann["start"], ann["end"]) + # TODO: catch possible exception? + ent = tokens2entity(tkns, doc) + if set_cui: + ent.cui = ann["cui"] + return ents + return predict + + + +def _create_linker_predictor( + dataset: MedCATTrainerExport, + tokens2entity: Callable[[list[MutableToken], MutableDocument], MutableEntity], + ) -> Callable[[MutableDocument], list[MutableEntity]]: + return _create_general_predictor(dataset, tokens2entity, True) + + +def _create_ner_predictor( + dataset: MedCATTrainerExport, + tokens2entity: Callable[[list[MutableToken], MutableDocument], MutableEntity], + ) -> Callable[[MutableDocument], list[MutableEntity]]: + return _create_general_predictor(dataset, tokens2entity, False) + + +def _create_predictor( + component_type: CoreComponentType, + dataset: MedCATTrainerExport, + tokens2entity: Callable[[list[MutableToken], MutableDocument], MutableEntity], + ) -> Callable[[MutableDocument], list[MutableEntity]]: + if component_type == CoreComponentType.linking: + return _create_linker_predictor(dataset, tokens2entity) + elif component_type == CoreComponentType.ner: + return _create_ner_predictor(dataset, tokens2entity) + raise ValueError( + f"Unable to create predictor for component {component_type}") + + +def _check_dataset(dataset: MedCATTrainerExport): + texts = set( + doc['text'] for _, doc in iter_docs(dataset) + ) + num_texts = len(texts) + num_docs = count_all_docs(dataset) + if num_texts != num_docs: + raise ValueError( + "Dataset contains documents with identical texts. " + "This means it cannot be used for dataset aware components " + "because the identification of the document isn't trivial. " + "The check found %d different texts within %d documents", + num_texts, num_docs) + + +@contextmanager +def dataset_aware_component( + cat: CAT, + comp_type: CoreComponentType, + dataset: MedCATTrainerExport): + _check_dataset(dataset) + tokens2entity = cat.pipe.tokenizer.entity_from_tokens_in_doc + predictor = _create_predictor(comp_type, dataset, tokens2entity) + yield cheating_component(cat, comp_type, predictor) From 7463586814ed396589db3725d4412cbf15744ab1 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 21 Apr 2026 16:18:05 +0100 Subject: [PATCH 10/15] CU-869cy3zb0: Fix minor issues with new utilities --- medcat-v2/medcat/utils/training_utils.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/medcat-v2/medcat/utils/training_utils.py b/medcat-v2/medcat/utils/training_utils.py index 6a5cbaf45..9be8c8b72 100644 --- a/medcat-v2/medcat/utils/training_utils.py +++ b/medcat-v2/medcat/utils/training_utils.py @@ -37,7 +37,7 @@ def predict_entities(self, doc: MutableDocument, def create_new_component( cls, cnf: ComponentConfig, tokenizer: BaseTokenizer, cdb: CDB, vocab: Vocab, model_load_path: Optional[str]) -> Self: - raise ValueError("Cannot create new compnoent of this type") + raise ValueError("Cannot create new component of this type") @contextmanager def cheating_component( @@ -82,6 +82,7 @@ def predict(doc: MutableDocument) -> list[MutableEntity]: ent = tokens2entity(tkns, doc) if set_cui: ent.cui = ann["cui"] + ents.append(ent) return ents return predict @@ -124,9 +125,9 @@ def _check_dataset(dataset: MedCATTrainerExport): raise ValueError( "Dataset contains documents with identical texts. " "This means it cannot be used for dataset aware components " - "because the identification of the document isn't trivial. " - "The check found %d different texts within %d documents", - num_texts, num_docs) + "because the identification of the document is not trivial. " + f"The check found {num_texts} different texts within {num_docs} " + "documents") @contextmanager @@ -137,4 +138,5 @@ def dataset_aware_component( _check_dataset(dataset) tokens2entity = cat.pipe.tokenizer.entity_from_tokens_in_doc predictor = _create_predictor(comp_type, dataset, tokens2entity) - yield cheating_component(cat, comp_type, predictor) + with cheating_component(cat, comp_type, predictor): + yield From d782e0e711c00baae4996c97f3c4d545338a8aa3 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 21 Apr 2026 16:21:30 +0100 Subject: [PATCH 11/15] CU-869cy3zb0: Fix minor order of operations issue --- medcat-v2/medcat/utils/training_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/medcat-v2/medcat/utils/training_utils.py b/medcat-v2/medcat/utils/training_utils.py index 9be8c8b72..55cdd0ea4 100644 --- a/medcat-v2/medcat/utils/training_utils.py +++ b/medcat-v2/medcat/utils/training_utils.py @@ -19,10 +19,10 @@ class _CheatingComponent(AbstractEntityProvidingComponent): def __init__(self, comp_type: CoreComponentType, predictor: Callable[[MutableDocument], list[MutableEntity]]): + self._comp_type = comp_type super().__init__( comp_type == CoreComponentType.linking, comp_type == CoreComponentType.linking) - self._comp_type = comp_type self._predictor = predictor def get_type(self) -> CoreComponentType: From 3e3ae1666d4155bd0b9cdf93d2d1d7bca381599c Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 21 Apr 2026 16:24:01 +0100 Subject: [PATCH 12/15] CU-869cy3zb0: Add a few tests for training utilities --- medcat-v2/tests/utils/test_training_utils.py | 240 +++++++++++++++++++ 1 file changed, 240 insertions(+) create mode 100644 medcat-v2/tests/utils/test_training_utils.py diff --git a/medcat-v2/tests/utils/test_training_utils.py b/medcat-v2/tests/utils/test_training_utils.py new file mode 100644 index 000000000..8eab33966 --- /dev/null +++ b/medcat-v2/tests/utils/test_training_utils.py @@ -0,0 +1,240 @@ +import unittest + +from medcat.config import Config +from medcat.components.types import CoreComponentType, AbstractEntityProvidingComponent +from medcat.stats.stats import get_stats +from medcat.trainer import Trainer +from medcat.utils.training_utils import dataset_aware_component + + +class _FakeEntityBase: + + def __init__(self, start: int, end: int, text: str): + self.start_char_index = start + self.end_char_index = end + self.text = text + + +class _FakeEntity: + + def __init__(self, start: int, end: int, text: str, cui: str): + self.base = _FakeEntityBase(start, end, text) + self.cui = cui + self.context_similarity = 1.0 + + +class _FakeToken: + + def __init__(self, start: int, end: int): + self.start_char_index = start + self.end_char_index = end + + +class _FakeDocBase: + + def __init__(self, text: str): + self.text = text + + +class _FakeDoc: + + def __init__(self, text: str, cui_by_start: dict[int, str]): + self.text = text + self.base = _FakeDocBase(text) + self.cui_by_start = cui_by_start + self.ner_ents = [] + self.linked_ents = [] + + def get_tokens(self, start: int, end: int): + return [_FakeToken(start, end)] + + +class _FakeTokenizer: + + def entity_from_tokens_in_doc(self, tkns, doc: _FakeDoc): + start = tkns[0].start_char_index + end = tkns[-1].end_char_index + text = doc.text[start:end] + cui = doc.cui_by_start.get(start, "C_WRONG") + return _FakeEntity(start, end, text, cui) + + +class _EmptyNER(AbstractEntityProvidingComponent): + name = "empty_ner" + + def __init__(self): + super().__init__(False, False) + + def get_type(self) -> CoreComponentType: + return CoreComponentType.ner + + def predict_entities(self, doc, ents=None): + return [] + + +class _PassThroughLinker(AbstractEntityProvidingComponent): + name = "pass_linker" + + def __init__(self): + super().__init__(True, True) + + def get_type(self) -> CoreComponentType: + return CoreComponentType.linking + + def predict_entities(self, doc, ents=None): + return list(doc.ner_ents) + + +class _TrainablePassThroughLinker(_PassThroughLinker): + + full_name = "linking:pass_linker" + + def __init__(self): + super().__init__() + self.unsup_train_calls = 0 + + def train_unsupervised(self, doc): + self.unsup_train_calls += 1 + + def train(self, *args, **kwargs): + return + + +class _TrainableNER(_EmptyNER): + + full_name = "ner:trainable_ner" + + def __init__(self): + super().__init__() + self.unsup_train_calls = 0 + + def train_unsupervised(self, doc): + self.unsup_train_calls += 1 + + def train(self, *args, **kwargs): + return + + +class _FakeFilters: + + def __init__(self) -> None: + self.cuis = [] + self.exclude_cuis = [] + + def check_filters(self, cui: str) -> bool: + return True + + +class _FakePipeline: + + def __init__(self, components): + self._components = components + self.tokenizer = _FakeTokenizer() + + def get_component(self, comp_type): + for comp in self._components: + if comp.get_type() == comp_type: + return comp + raise KeyError(comp_type) + + def iter_all_components(self): + return self._components + + def __call__(self, doc): + for comp in self._components: + doc = comp(doc) + return doc + + +class _FakeCDB: + + def __init__(self, config): + self.config = config + self.addl_info = {} + self.cui2info = {} + + def reset_training(self): + return + + +class _FakeCat: + + def __init__(self, dataset, components): + self.config = Config() + self.config.components.linking.filters = _FakeFilters() + self.cdb = _FakeCDB(self.config) + self.pipe = _FakePipeline(components) + self._dataset = dataset + + def __call__(self, text): + by_text = {doc["text"]: doc for project in self._dataset["projects"] + for doc in project["documents"]} + ann_doc = by_text[text] + cui_by_start = {ann["start"]: ann["cui"] for ann in ann_doc["annotations"]} + doc = _FakeDoc(text, cui_by_start) + return self.pipe(doc) + + +class TrainingUtilsTests(unittest.TestCase): + + DATASET = { + "projects": [{ + "id": "P1", + "name": "P1", + "cuis": "", + "tuis": "", + "documents": [{ + "id": "D1", + "name": "D1", + "text": "abc def", + "annotations": [{"start": 0, "end": 3, "cui": "C1", "value": "abc"}], + }] + }] + } + + def test_get_stats_can_be_perfect_when_ner_and_linker_are_dataset_aware(self): + cat = _FakeCat(self.DATASET, [_EmptyNER(), _PassThroughLinker()]) + + with dataset_aware_component(cat, CoreComponentType.ner, self.DATASET): + with dataset_aware_component(cat, CoreComponentType.linking, self.DATASET): + _, fns, tps, _, _, cui_f1, _, _ = get_stats( + cat, self.DATASET, do_print=False) + + self.assertEqual(fns, {}) + self.assertEqual(tps.get("C1"), 1) + self.assertEqual(cui_f1.get("C1"), 1.0) + + def test_get_stats_can_isolate_ner_quality_by_cheating_ner_only(self): + cat = _FakeCat(self.DATASET, [_EmptyNER(), _PassThroughLinker()]) + + with dataset_aware_component(cat, CoreComponentType.ner, self.DATASET): + _, fns, tps, _, _, cui_f1, _, _ = get_stats( + cat, self.DATASET, do_print=False) + + self.assertEqual(fns, {}) + self.assertEqual(tps.get("C1"), 1) + self.assertEqual(cui_f1.get("C1"), 1.0) + + def test_train_unsupervised_can_train_only_linker_when_ner_is_cheating(self): + ner = _TrainableNER() + linker = _TrainablePassThroughLinker() + cat = _FakeCat(self.DATASET, [ner, linker]) + trainer = Trainer(cat.cdb, cat.__call__, cat.pipe) + + with dataset_aware_component(cat, CoreComponentType.ner, self.DATASET): + trainer.train_unsupervised(["abc def"], nepochs=1) + + self.assertEqual(ner.unsup_train_calls, 0) + self.assertEqual(linker.unsup_train_calls, 1) + + def test_train_unsupervised_can_train_only_ner_when_linker_is_cheating(self): + ner = _TrainableNER() + linker = _TrainablePassThroughLinker() + cat = _FakeCat(self.DATASET, [ner, linker]) + trainer = Trainer(cat.cdb, cat.__call__, cat.pipe) + + with dataset_aware_component(cat, CoreComponentType.linking, self.DATASET): + trainer.train_unsupervised(["abc def"], nepochs=1) + + self.assertEqual(ner.unsup_train_calls, 1) + self.assertEqual(linker.unsup_train_calls, 0) From 1e7e8cd67e66c1791bce1170b7b86151b112abb1 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 21 Apr 2026 16:30:11 +0100 Subject: [PATCH 13/15] CU-869cy3zb0: Add a few missing doc strings --- medcat-v2/medcat/utils/training_utils.py | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/medcat-v2/medcat/utils/training_utils.py b/medcat-v2/medcat/utils/training_utils.py index 55cdd0ea4..65972a0c6 100644 --- a/medcat-v2/medcat/utils/training_utils.py +++ b/medcat-v2/medcat/utils/training_utils.py @@ -44,6 +44,16 @@ def cheating_component( cat: CAT, comp_type: CoreComponentType, predictor: Callable[[MutableDocument], list[MutableEntity]]): + """Creates and uses a cheating component within the pipe. + + This component will "predict" entities as per the predictor it is given. + + Args: + cat (CAT): The model pack. + comp_type (CoreComponentType): The component type (generally NER or linker). + predictor (Callable[[MutableDocument], list[MutableEntity]]): + The predictor to use. + """ comps_list = cat.pipe._components # find original index original_comp = cat.pipe.get_component(comp_type) @@ -135,6 +145,17 @@ def dataset_aware_component( cat: CAT, comp_type: CoreComponentType, dataset: MedCATTrainerExport): + """Creates and uses a dataset aware component within the pipe. + + This simplfies trainin for and evaluating one component at + a time by swapping out the other component for one that has + perfect performance since it knows the dataset. + + Args: + cat (CAT): The model pack. + comp_type (CoreComponentType): The component type. + dataset (MedCATTrainerExport): The dataset in question. + """ _check_dataset(dataset) tokens2entity = cat.pipe.tokenizer.entity_from_tokens_in_doc predictor = _create_predictor(comp_type, dataset, tokens2entity) From 9185ae7c04e843ebb4ce600c1c355ecde0b30902 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Tue, 21 Apr 2026 16:57:32 +0100 Subject: [PATCH 14/15] CU-869cy3zb0: Add a few supervised training based tests --- medcat-v2/tests/utils/test_training_utils.py | 42 +++++++++++++++++++- 1 file changed, 40 insertions(+), 2 deletions(-) diff --git a/medcat-v2/tests/utils/test_training_utils.py b/medcat-v2/tests/utils/test_training_utils.py index 8eab33966..6e2549897 100644 --- a/medcat-v2/tests/utils/test_training_utils.py +++ b/medcat-v2/tests/utils/test_training_utils.py @@ -1,4 +1,5 @@ import unittest +import unittest.mock from medcat.config import Config from medcat.components.types import CoreComponentType, AbstractEntityProvidingComponent @@ -92,12 +93,13 @@ class _TrainablePassThroughLinker(_PassThroughLinker): def __init__(self): super().__init__() self.unsup_train_calls = 0 + self.sup_train_calls = 0 def train_unsupervised(self, doc): self.unsup_train_calls += 1 def train(self, *args, **kwargs): - return + self.sup_train_calls += 1 class _TrainableNER(_EmptyNER): @@ -107,12 +109,13 @@ class _TrainableNER(_EmptyNER): def __init__(self): super().__init__() self.unsup_train_calls = 0 + self.sup_train_calls = 0 def train_unsupervised(self, doc): self.unsup_train_calls += 1 def train(self, *args, **kwargs): - return + self.sup_train_calls += 1 class _FakeFilters: @@ -140,6 +143,12 @@ def get_component(self, comp_type): def iter_all_components(self): return self._components + def entity_from_tokens_in_doc(self, tkns, doc): + return self.tokenizer.entity_from_tokens_in_doc(tkns, doc) + + def tokenizer_with_tag(self, text): + return _FakeDoc(text, {}) + def __call__(self, doc): for comp in self._components: doc = comp(doc) @@ -156,6 +165,9 @@ def __init__(self, config): def reset_training(self): return + def _add_concept(self, *args, **kwargs): + return + class _FakeCat: @@ -238,3 +250,29 @@ def test_train_unsupervised_can_train_only_ner_when_linker_is_cheating(self): self.assertEqual(ner.unsup_train_calls, 1) self.assertEqual(linker.unsup_train_calls, 0) + + def test_train_supervised_can_train_only_linker_when_ner_is_cheating(self): + ner = _TrainableNER() + linker = _TrainablePassThroughLinker() + cat = _FakeCat(self.DATASET, [ner, linker]) + trainer = Trainer(cat.cdb, cat.__call__, cat.pipe) + + with unittest.mock.patch("medcat.trainer.prepare_name", return_value={"abc": {}}): + with dataset_aware_component(cat, CoreComponentType.ner, self.DATASET): + trainer.train_supervised_raw(self.DATASET, disable_progress=True) + + self.assertEqual(ner.sup_train_calls, 0) + self.assertEqual(linker.sup_train_calls, 1) + + def test_train_supervised_can_train_only_ner_when_linker_is_cheating(self): + ner = _TrainableNER() + linker = _TrainablePassThroughLinker() + cat = _FakeCat(self.DATASET, [ner, linker]) + trainer = Trainer(cat.cdb, cat.__call__, cat.pipe) + + with unittest.mock.patch("medcat.trainer.prepare_name", return_value={"abc": {}}): + with dataset_aware_component(cat, CoreComponentType.linking, self.DATASET): + trainer.train_supervised_raw(self.DATASET, disable_progress=True) + + self.assertEqual(ner.sup_train_calls, 1) + self.assertEqual(linker.sup_train_calls, 0) From 5df28709e503fcd56983eccf6ac5476e555b3311 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" Date: Wed, 22 Apr 2026 09:25:32 +0100 Subject: [PATCH 15/15] CU-869cy3zb0: Fix import of Self (from typing extensions) --- medcat-v2/medcat/utils/training_utils.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/medcat-v2/medcat/utils/training_utils.py b/medcat-v2/medcat/utils/training_utils.py index 65972a0c6..1a79efea8 100644 --- a/medcat-v2/medcat/utils/training_utils.py +++ b/medcat-v2/medcat/utils/training_utils.py @@ -1,4 +1,5 @@ -from typing import Callable, Optional, Self +from typing import Callable, Optional +from typing_extensions import Self from contextlib import contextmanager from medcat.cat import CAT