Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
15 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion medcat-v2/medcat/components/linking/context_based_linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,9 @@ def predict_entities(self, doc: MutableDocument,
raise ValueError("Need to have NER'ed entities provided")

if cnf_l.train:
linked_entities = self._train_on_doc(doc, ents)
raise ValueError(
"Use the new train_unsupervised method for unsuperivsed training "
"instead of a regular inference call with a changed config entry.")
else:
linked_entities = self._inference(doc, ents)
# evaluating generator here because the `all_ents` list gets
Expand All @@ -227,6 +229,11 @@ def predict_entities(self, doc: MutableDocument,
return filter_linked_annotations(
doc, le, self.config.general.show_nested_entities)

# TrainableComponent

def train_unsupervised(self, doc: MutableDocument) -> None:
list(self._train_on_doc(doc, doc.ner_ents))

def train(self, cui: str,
entity: MutableEntity,
doc: MutableDocument,
Expand Down
11 changes: 11 additions & 0 deletions medcat-v2/medcat/components/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,17 @@ def get_hash(self) -> str:
@runtime_checkable
class TrainableComponent(Protocol):

def train_unsupervised(self, doc: MutableDocument) -> None:
"""Train unsupervised based on the given document.

If this component doesn't support unsupervised training,
this method can be a no-op.

Args:
doc (MutableDocument): The document to train on.
"""
pass

def train(self, cui: str,
entity: MutableEntity,
doc: MutableDocument,
Expand Down
40 changes: 23 additions & 17 deletions medcat-v2/medcat/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
MedCATTrainerExport, MedCATTrainerExportAnnotation, MedCATTrainerExportProject,
MedCATTrainerExportDocument, count_all_annotations, iter_anns)
from medcat.preprocessors.cleaners import prepare_name, NameDescriptor
from medcat.components.types import CoreComponentType, TrainableComponent
from medcat.components.types import TrainableComponent
from medcat.components.addons.addons import AddonComponent
from medcat.pipeline import Pipeline

Expand Down Expand Up @@ -64,10 +64,8 @@ def train_unsupervised(self,
with self.config.meta.prepare_and_report_training(
data_iterator, nepochs, False
) as wrapped_iter:
with temp_changed_config(self.config.components.linking,
'train', True):
self._train_unsupervised(wrapped_iter, nepochs, fine_tune,
progress_print)
self._train_unsupervised(wrapped_iter, nepochs, fine_tune,
progress_print)

def _train_unsupervised(self,
data_iterator: Iterable,
Expand All @@ -91,11 +89,18 @@ def _train_unsupervised(self,
# Convert to string
line = str(line).strip()


# inference run for the document
try:
_ = self.caller(line)
doc = self.caller(line)
except Exception as e:
logger.warning("LINE: '%s...' \t WAS SKIPPED", line[0:100])
logger.warning("BECAUSE OF:", exc_info=e)
continue
for comp in self._pipeline.iter_all_components():
if isinstance(comp, TrainableComponent):
logger.debug("Training on component %s", comp.full_name)
comp.train_unsupervised(doc)
else:
logger.warning("EMPTY LINE WAS DETECTED AND SKIPPED")

Expand Down Expand Up @@ -632,18 +637,15 @@ def add_and_train_concept(self,

if mut_entity is None or mut_doc is None:
return
linker = self._pipeline.get_component(
CoreComponentType.linking)
if not isinstance(linker, TrainableComponent):
logger.warning(
"Linker cannot be trained during add_and_train_concept"
"because it has no train method: %s", linker)
else:
trained_comps = 0
for component in self._pipeline.iter_all_components():
if not isinstance(component, TrainableComponent):
continue
# Train Linking
if isinstance(mut_entity, list):
mut_entity = self._pipeline.entity_from_tokens(mut_entity)
linker.train(cui=cui, entity=mut_entity, doc=mut_doc,
negative=negative, names=names)
component.train(cui=cui, entity=mut_entity, doc=mut_doc,
negative=negative, names=names)
Comment on lines +647 to +648
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this mean we're still doing unsupervised on a per entity basis? I can't think of a case where in an unsupervised manner you would need the entity.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is supervised training still. This is within the add_and_train_concept method.


if not negative and devalue_others:
# Find all cuis
Expand All @@ -658,8 +660,12 @@ def add_and_train_concept(self,
# Add negative training for all other CUIs that link to
# these names
for _cui in cuis:
linker.train(cui=_cui, entity=mut_entity, doc=mut_doc,
negative=True)
component.train(cui=_cui, entity=mut_entity, doc=mut_doc,
negative=True)
if trained_comps == 0:
logger.warning(
"Nothing was trained during add_and_train_concept because "
"no components followed the TrainableComponent protocol")

@property
def _pn_configs(self) -> tuple[General, Preprocessing, CDBMaker]:
Expand Down
164 changes: 164 additions & 0 deletions medcat-v2/medcat/utils/training_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
from typing import Callable, Optional
from typing_extensions import Self
from contextlib import contextmanager

from medcat.cat import CAT
from medcat.cdb import CDB
from medcat.components.types import (
CoreComponentType, AbstractEntityProvidingComponent)
from medcat.config.config import ComponentConfig
from medcat.tokenizing.tokenizers import BaseTokenizer
from medcat.tokenizing.tokens import MutableDocument, MutableEntity, MutableToken
from medcat.data.mctexport import (
MedCATTrainerExport, MedCATTrainerExportDocument, count_all_docs, iter_docs)
from medcat.vocab import Vocab


class _CheatingComponent(AbstractEntityProvidingComponent):
    """Pipeline component whose 'predictions' come from a supplied callable.

    Used to stand in for a real NER or linking component so that the
    other components can be exercised against known-perfect output.
    """
    name = 'cheating_component'

    def __init__(self,
                 comp_type: CoreComponentType,
                 predictor: Callable[[MutableDocument], list[MutableEntity]]):
        self._comp_type = comp_type
        is_linker = comp_type == CoreComponentType.linking
        super().__init__(is_linker, is_linker)
        self._predictor = predictor

    def get_type(self) -> CoreComponentType:
        return self._comp_type

    def predict_entities(self, doc: MutableDocument,
                         ents: list[MutableEntity] | None = None
                         ) -> list[MutableEntity]:
        # `ents` is deliberately ignored - entities come solely from
        # the injected predictor.
        return self._predictor(doc)

    @classmethod
    def create_new_component(
            cls, cnf: ComponentConfig, tokenizer: BaseTokenizer,
            cdb: CDB, vocab: Vocab, model_load_path: Optional[str]) -> Self:
        # This component is only ever constructed directly with a predictor.
        raise ValueError("Cannot create new component of this type")

@contextmanager
def cheating_component(
        cat: CAT,
        comp_type: CoreComponentType,
        predictor: Callable[[MutableDocument], list[MutableEntity]]):
    """Temporarily swap a pipeline component for a 'cheating' one.

    Within the managed context the component of the given type is
    replaced by a `_CheatingComponent` that "predicts" whatever the
    supplied predictor returns. The original component is restored on
    exit, even if an exception is raised inside the context.

    Args:
        cat (CAT): The model pack.
        comp_type (CoreComponentType): The component type (generally NER or linker).
        predictor (Callable[[MutableDocument], list[MutableEntity]]):
            The predictor to use.
    """
    pipeline_components = cat.pipe._components
    # Locate the component being replaced so it can be restored later.
    replaced = cat.pipe.get_component(comp_type)
    index = pipeline_components.index(replaced)
    pipeline_components[index] = _CheatingComponent(comp_type, predictor)
    try:
        yield
    finally:
        pipeline_components[index] = replaced


def _identify_document(
        doc: MutableDocument,
        dataset: MedCATTrainerExport) -> MedCATTrainerExportDocument:
    """Find the trainer-export document whose text matches `doc`.

    Args:
        doc (MutableDocument): The processed document.
        dataset (MedCATTrainerExport): The trainer export to search.

    Returns:
        MedCATTrainerExportDocument: The exported document with identical text.

    Raises:
        ValueError: If no document with matching text is found.
    """
    target_text = doc.base.text
    for project in dataset['projects']:
        for exported_doc in project['documents']:
            if exported_doc['text'] == target_text:
                return exported_doc
    raise ValueError("Unable to identify correct document")


def _create_general_predictor(
        dataset: MedCATTrainerExport,
        tokens2entity: Callable[[list[MutableToken], MutableDocument], MutableEntity],
        set_cui: bool,
) -> Callable[[MutableDocument], list[MutableEntity]]:
    """Build a predictor that replays a document's annotations from the dataset.

    Args:
        dataset (MedCATTrainerExport): The dataset to take annotations from.
        tokens2entity: Converts a token span within a document to an entity.
        set_cui (bool): Whether to copy the annotated CUI onto each entity
            (needed when cheating for the linker, not for NER).

    Returns:
        Callable[[MutableDocument], list[MutableEntity]]: The predictor.
    """
    def predict(doc: MutableDocument) -> list[MutableEntity]:
        entities: list[MutableEntity] = []
        for annotation in _identify_document(doc, dataset)["annotations"]:
            span_tokens = doc.get_tokens(annotation["start"], annotation["end"])
            # TODO: catch possible exception?
            entity = tokens2entity(span_tokens, doc)
            if set_cui:
                entity.cui = annotation["cui"]
            entities.append(entity)
        return entities
    return predict



def _create_linker_predictor(
        dataset: MedCATTrainerExport,
        tokens2entity: Callable[[list[MutableToken], MutableDocument], MutableEntity],
        ) -> Callable[[MutableDocument], list[MutableEntity]]:
    # The linker needs the annotated CUI set on each replayed entity.
    return _create_general_predictor(dataset, tokens2entity, set_cui=True)


def _create_ner_predictor(
        dataset: MedCATTrainerExport,
        tokens2entity: Callable[[list[MutableToken], MutableDocument], MutableEntity],
        ) -> Callable[[MutableDocument], list[MutableEntity]]:
    # NER only needs the spans; the CUI is left for the linker to fill in.
    return _create_general_predictor(dataset, tokens2entity, set_cui=False)


def _create_predictor(
        component_type: CoreComponentType,
        dataset: MedCATTrainerExport,
        tokens2entity: Callable[[list[MutableToken], MutableDocument], MutableEntity],
) -> Callable[[MutableDocument], list[MutableEntity]]:
    """Create a dataset-replaying predictor for the given component type.

    Raises:
        ValueError: For component types other than NER and linking.
    """
    if component_type == CoreComponentType.ner:
        return _create_ner_predictor(dataset, tokens2entity)
    if component_type == CoreComponentType.linking:
        return _create_linker_predictor(dataset, tokens2entity)
    raise ValueError(
        f"Unable to create predictor for component {component_type}")


def _check_dataset(dataset: MedCATTrainerExport):
    """Ensure every document in the dataset has a unique text.

    Dataset-aware components identify documents by their raw text
    (see `_identify_document`), so duplicate texts would make that
    mapping ambiguous.

    Args:
        dataset (MedCATTrainerExport): The dataset to validate.

    Raises:
        ValueError: If two or more documents share identical text.
    """
    # Set comprehension (rather than set() around a generator) per idiom.
    texts = {doc['text'] for _, doc in iter_docs(dataset)}
    num_texts = len(texts)
    num_docs = count_all_docs(dataset)
    if num_texts != num_docs:
        raise ValueError(
            "Dataset contains documents with identical texts. "
            "This means it cannot be used for dataset aware components "
            "because the identification of the document is not trivial. "
            f"The check found {num_texts} different texts within {num_docs} "
            "documents")


@contextmanager
def dataset_aware_component(
        cat: CAT,
        comp_type: CoreComponentType,
        dataset: MedCATTrainerExport):
    """Creates and uses a dataset aware component within the pipe.

    This simplifies training for and evaluating one component at
    a time by swapping out the other component for one that has
    perfect performance since it knows the dataset.

    Args:
        cat (CAT): The model pack.
        comp_type (CoreComponentType): The component type.
        dataset (MedCATTrainerExport): The dataset in question.

    Raises:
        ValueError: If the dataset contains documents with identical
            texts (document identification would be ambiguous).
    """
    _check_dataset(dataset)
    tokens2entity = cat.pipe.tokenizer.entity_from_tokens_in_doc
    predictor = _create_predictor(comp_type, dataset, tokens2entity)
    with cheating_component(cat, comp_type, predictor):
        yield
56 changes: 56 additions & 0 deletions medcat-v2/tests/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,20 @@ class FakeComponent:
pass


class FakeTrainableNERComponent:
    """Test double that satisfies the TrainableComponent protocol.

    Records every document passed to `train_unsupervised` so tests can
    assert which documents were trained on.
    """

    full_name = "ner:fake"

    def __init__(self):
        # Documents seen by train_unsupervised, in call order.
        self.docs_trained_on = []

    def train_unsupervised(self, doc: MutableDocument) -> None:
        self.docs_trained_on.append(doc)

    def train(self, *args, **kwargs) -> None:
        # Supervised training is a no-op for this fake.
        return


class FakePipeline:

def tokenizer(self, text: str) -> FakeMutDoc:
Expand All @@ -92,13 +106,25 @@ def tokenizer(self, text: str) -> FakeMutDoc:
    def tokenizer_with_tag(self, text: str) -> FakeMutDoc:
        # Same behaviour as `tokenizer` - tagging is irrelevant for the fakes.
        return FakeMutDoc(text)

    def iter_all_components(self):
        # Base fake exposes no components; subclasses may override.
        return []

    def get_component(self, comp_type):
        # NOTE(review): this returns the FakeComponent class itself, not an
        # instance - presumably the tests only need a non-None placeholder;
        # confirm against callers.
        return FakeComponent

    def entity_from_tokens_in_doc(self, tkns: list, doc: MutableDocument) -> FakeMutEnt:
        # Entity span runs from the first token's index to the last token's index.
        return FakeMutEnt(doc, tkns[0].index, tkns[-1].index)


class FakePipelineWithComponents(FakePipeline):
    """FakePipeline variant that exposes a caller-supplied component list."""

    def __init__(self, components: list):
        # Stored under the attribute name iter_all_components reads back.
        self._components = components

    def iter_all_components(self):
        return self._components


class TrainerTestsBase(unittest.TestCase):
DATA_CNT = 14
TRAIN_DATA = [
Expand Down Expand Up @@ -173,6 +199,35 @@ def test_training_gets_remembered_multi(self, repeats: int = 3):
exp_total=repeats,
unsup=self.UNSUP)

def test_unsup_training_trains_non_linking_component(self):
ner_component = FakeTrainableNERComponent()
trainer = Trainer(
self.cdb,
self.caller,
FakePipelineWithComponents([ner_component]),
)
trainer.config = self.cnf

trainer.train_unsupervised(self.TRAIN_DATA, nepochs=1)

self.assertEqual(len(ner_component.docs_trained_on), self.DATA_CNT)
self.assertTrue(
all(isinstance(doc, FakeMutDoc) for doc in ner_component.docs_trained_on)
)

def test_unsup_training_skips_non_trainable_components(self):
ner_component = FakeTrainableNERComponent()
trainer = Trainer(
self.cdb,
self.caller,
FakePipelineWithComponents([FakeComponent(), ner_component, object()]),
)
trainer.config = self.cnf

trainer.train_unsupervised(self.TRAIN_DATA, nepochs=1)

self.assertEqual(len(ner_component.docs_trained_on), self.DATA_CNT)


class TrainerSupervisedTests(TrainerUnsupervisedTests):
DATA_CNT = 1
Expand Down Expand Up @@ -337,3 +392,4 @@ def test_has_trained_all(self):
with self.subTest(cui):
info = self.model.cdb.cui2info[cui]
self.assertGreater(info['count_train'], prev_count)

Loading
Loading