feat(medcat):CU-869cy3xa0 Improve training #414
Status: Open
mart-r wants to merge 15 commits into main from feat/medcat/CU-869cy3xa0-specify-unsupervised-training-in-trainable-component-protocol
Commits (15):
- d307dab CU-869cy3yz9: Add unsupervised training method to trainable component…
- acf623b CU-869cy3yz9: Follow the intercace for unsupervised trianing in conte…
- f204535 CU-869cy3xa0: Use new interface for self-supervised training
- b1cc70a CU-869cy3yz9: Fix fake pipe in tests
- a76fde4 CU-869cy3yz9: Fix issue with unrealised generator
- 7d9adf5 CU-869cy3z45: Allow any component to be trained in an unsupervised ma…
- 4a71c1c CU-869cy3z45: Remove unused import
- 4c25292 CU-869cy3yz9: Add a few more tests for trainable components
- 9620d9b CU-869cy3zb0: Add utilities to create a dataset-aware NER or linker c…
- 7463586 CU-869cy3zb0: Fix minor issues with new utilities
- d782e0e CU-869cy3zb0: Fix minor order of operations issue
- 3e3ae16 CU-869cy3zb0: Add a few tests for training utilities
- 1e7e8cd CU-869cy3zb0: Add a few missing doc strings
- 9185ae7 CU-869cy3zb0: Add a few supervised training based tests
- 5df2870 CU-869cy3zb0: Fix import of Self (from typing extensions)
New file (+164 lines):

```python
from typing import Callable, Optional
from typing_extensions import Self
from contextlib import contextmanager

from medcat.cat import CAT
from medcat.cdb import CDB
from medcat.components.types import (
    CoreComponentType, AbstractEntityProvidingComponent)
from medcat.config.config import ComponentConfig
from medcat.tokenizing.tokenizers import BaseTokenizer
from medcat.tokenizing.tokens import MutableDocument, MutableEntity, MutableToken
from medcat.data.mctexport import (
    MedCATTrainerExport, MedCATTrainerExportDocument, count_all_docs, iter_docs)
from medcat.vocab import Vocab


class _CheatingComponent(AbstractEntityProvidingComponent):
    name = 'cheating_component'

    def __init__(self,
                 comp_type: CoreComponentType,
                 predictor: Callable[[MutableDocument], list[MutableEntity]]):
        self._comp_type = comp_type
        super().__init__(
            comp_type == CoreComponentType.linking,
            comp_type == CoreComponentType.linking)
        self._predictor = predictor

    def get_type(self) -> CoreComponentType:
        return self._comp_type

    def predict_entities(self, doc: MutableDocument,
                         ents: list[MutableEntity] | None = None
                         ) -> list[MutableEntity]:
        return self._predictor(doc)

    @classmethod
    def create_new_component(
            cls, cnf: ComponentConfig, tokenizer: BaseTokenizer,
            cdb: CDB, vocab: Vocab, model_load_path: Optional[str]) -> Self:
        raise ValueError("Cannot create new component of this type")


@contextmanager
def cheating_component(
        cat: CAT,
        comp_type: CoreComponentType,
        predictor: Callable[[MutableDocument], list[MutableEntity]]):
    """Creates and uses a cheating component within the pipe.

    This component will "predict" entities as per the predictor it is given.

    Args:
        cat (CAT): The model pack.
        comp_type (CoreComponentType): The component type (generally NER or linker).
        predictor (Callable[[MutableDocument], list[MutableEntity]]):
            The predictor to use.
    """
    comps_list = cat.pipe._components
    # find original index
    original_comp = cat.pipe.get_component(comp_type)
    replace_index = comps_list.index(original_comp)
    # create and replace
    cheater = _CheatingComponent(comp_type, predictor)
    comps_list[replace_index] = cheater
    try:
        yield
    finally:
        # restore original component
        comps_list[replace_index] = original_comp


def _identify_document(
        doc: MutableDocument,
        dataset: MedCATTrainerExport) -> MedCATTrainerExportDocument:
    for proj in dataset['projects']:
        for ann_doc in proj['documents']:
            if ann_doc['text'] == doc.base.text:
                return ann_doc
    raise ValueError("Unable to identify correct document")


def _create_general_predictor(
        dataset: MedCATTrainerExport,
        tokens2entity: Callable[[list[MutableToken], MutableDocument], MutableEntity],
        set_cui: bool,
        ) -> Callable[[MutableDocument], list[MutableEntity]]:
    def predict(doc: MutableDocument) -> list[MutableEntity]:
        anns = _identify_document(doc, dataset)["annotations"]
        ents: list[MutableEntity] = []
        for ann in anns:
            tkns = doc.get_tokens(ann["start"], ann["end"])
            # TODO: catch possible exception?
            ent = tokens2entity(tkns, doc)
            if set_cui:
                ent.cui = ann["cui"]
            ents.append(ent)
        return ents
    return predict


def _create_linker_predictor(
        dataset: MedCATTrainerExport,
        tokens2entity: Callable[[list[MutableToken], MutableDocument], MutableEntity],
        ) -> Callable[[MutableDocument], list[MutableEntity]]:
    return _create_general_predictor(dataset, tokens2entity, True)


def _create_ner_predictor(
        dataset: MedCATTrainerExport,
        tokens2entity: Callable[[list[MutableToken], MutableDocument], MutableEntity],
        ) -> Callable[[MutableDocument], list[MutableEntity]]:
    return _create_general_predictor(dataset, tokens2entity, False)


def _create_predictor(
        component_type: CoreComponentType,
        dataset: MedCATTrainerExport,
        tokens2entity: Callable[[list[MutableToken], MutableDocument], MutableEntity],
        ) -> Callable[[MutableDocument], list[MutableEntity]]:
    if component_type == CoreComponentType.linking:
        return _create_linker_predictor(dataset, tokens2entity)
    elif component_type == CoreComponentType.ner:
        return _create_ner_predictor(dataset, tokens2entity)
    raise ValueError(
        f"Unable to create predictor for component {component_type}")


def _check_dataset(dataset: MedCATTrainerExport):
    texts = set(
        doc['text'] for _, doc in iter_docs(dataset)
    )
    num_texts = len(texts)
    num_docs = count_all_docs(dataset)
    if num_texts != num_docs:
        raise ValueError(
            "Dataset contains documents with identical texts. "
            "This means it cannot be used for dataset aware components "
            "because the identification of the document is not trivial. "
            f"The check found {num_texts} different texts within {num_docs} "
            "documents")


@contextmanager
def dataset_aware_component(
        cat: CAT,
        comp_type: CoreComponentType,
        dataset: MedCATTrainerExport):
    """Creates and uses a dataset-aware component within the pipe.

    This simplifies training and evaluating one component at
    a time by swapping out the other component for one that has
    perfect performance, since it knows the dataset.

    Args:
        cat (CAT): The model pack.
        comp_type (CoreComponentType): The component type.
        dataset (MedCATTrainerExport): The dataset in question.
    """
    _check_dataset(dataset)
    tokens2entity = cat.pipe.tokenizer.entity_from_tokens_in_doc
    predictor = _create_predictor(comp_type, dataset, tokens2entity)
    with cheating_component(cat, comp_type, predictor):
        yield
```
Review comment: Does this mean we're still doing unsupervised training on a per-entity basis? I can't think of a case where you would need the entity in an unsupervised manner.
Reply: This is still supervised training. This is within the `add_and_train_concept` method.
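The `_check_dataset` guard in the diff above rejects trainer exports containing duplicate document texts, because the text is the only key used to match a document back to its annotations. A minimal illustration of the same check with plain dictionaries (the data and the helper name are made up for this sketch):

```python
def check_unique_texts(dataset: dict) -> None:
    """Raise if two documents share the same text, since text is the
    only key available to map a document back to its annotations."""
    docs = [doc for proj in dataset["projects"] for doc in proj["documents"]]
    texts = {doc["text"] for doc in docs}
    if len(texts) != len(docs):
        raise ValueError(
            f"Dataset contains duplicate texts: {len(texts)} unique "
            f"texts within {len(docs)} documents")


ok = {"projects": [{"documents": [{"text": "a"}, {"text": "b"}]}]}
check_unique_texts(ok)  # passes silently

bad = {"projects": [{"documents": [{"text": "a"}, {"text": "a"}]}]}
try:
    check_unique_texts(bad)
except ValueError:
    print("duplicate detected")
```

Comparing a set of texts against the raw document count is an O(n) way to detect collisions without pairwise comparisons.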