diff --git a/medcat-plugins/embedding-linker/src/medcat_embedding_linker/config.py b/medcat-plugins/embedding-linker/src/medcat_embedding_linker/config.py
index a855c0b5a..6a58b3a0d 100644
--- a/medcat-plugins/embedding-linker/src/medcat_embedding_linker/config.py
+++ b/medcat-plugins/embedding-linker/src/medcat_embedding_linker/config.py
@@ -74,3 +74,9 @@ class EmbeddingLinking(Linking):
     use_ner_link_candidates: bool = True
     """Link candidates are provided by some NER steps. This will flag if
     you want to trust them or not."""
+    learning_rate: float = 1e-4
+    """Learning rate for training the embedding linker. Only used if
+    the embedding linker is trainable."""
+    weight_decay: float = 0.01
+    """Weight decay for training the embedding linker. Only used if
+    the embedding linker is trainable."""
diff --git a/medcat-plugins/embedding-linker/src/medcat_embedding_linker/embedding_linker.py b/medcat-plugins/embedding-linker/src/medcat_embedding_linker/embedding_linker.py
index f41df7051..a056a4b51 100644
--- a/medcat-plugins/embedding-linker/src/medcat_embedding_linker/embedding_linker.py
+++ b/medcat-plugins/embedding-linker/src/medcat_embedding_linker/embedding_linker.py
@@ -19,7 +19,7 @@ class Linker(AbstractEntityProvidingComponent):
 
-    comp_name = "embedding_linker"
+    name = "embedding_linker"
 
     _MODEL_FOLDER_NAME = "embedding_model"
     _STATE_FILE_NAME = "state.json"
@@ -124,13 +124,21 @@ def create_embeddings(
         )
         self.max_length = max_length
         self.cnf_l.max_token_length = max_length
+        self.context_model.max_length = max_length
 
-        self.context_model.embed_cuis(embedding_model_name)
-        self.context_model.embed_names(embedding_model_name)
+        # Route model swaps through linker-level hook so trainable variants can
+        # refresh optimizer/scaler when underlying params change.
+        self.load_transformers(embedding_model_name)
+        self.context_model.embed_cuis()
+        self.context_model.embed_names()
 
         self._names_context_matrix = None
         self._cui_context_matrix = None
 
+    def load_transformers(self, embedding_model_name: str) -> None:
+        """Pass through to the underlying transformer model for context embedding."""
+        self.context_model.load_transformers(embedding_model_name)
+
     def get_type(self) -> CoreComponentType:
         return CoreComponentType.linking
diff --git a/medcat-plugins/embedding-linker/src/medcat_embedding_linker/trainable_embedding_linker.py b/medcat-plugins/embedding-linker/src/medcat_embedding_linker/trainable_embedding_linker.py
index b6e9dfadc..c2d8603ec 100644
--- a/medcat-plugins/embedding-linker/src/medcat_embedding_linker/trainable_embedding_linker.py
+++ b/medcat-plugins/embedding-linker/src/medcat_embedding_linker/trainable_embedding_linker.py
@@ -7,7 +7,7 @@
 from medcat.tokenizing.tokenizers import BaseTokenizer
 from medcat.tokenizing.tokens import MutableDocument, MutableEntity
 from medcat.vocab import Vocab
-from medcat_embedding_linker.embedding_linker import Linker
+from medcat_embedding_linker.embedding_linker import Linker as StaticEmbeddingLinker
 from medcat.storage.serialisables import AbstractManualSerialisable
 import logging
 import torch
@@ -17,13 +17,14 @@ logger = logging.getLogger(__name__)
 
 
-class TrainableEmbeddingLinker(Linker, AbstractManualSerialisable):
+class Linker(StaticEmbeddingLinker, AbstractManualSerialisable):
     """Trainable variant of the embedding linker.
 
     This class inherits inference and embedding behavior from Linker and
     provides method hooks for online/offline training.
     """
 
-    comp_name = "trainable_embedding_linker"
+    name = "trainable_embedding_linker"
+
+    _MODEL_FOLDER_NAME = "trainable_embedding_model"
     _MODEL_STATE_FILE_NAME = "model_state.pt"
@@ -47,11 +48,39 @@ def __init__(self, cdb: CDB, config: Config) -> None:
         self.negative_sampling_candidate_pool_size = (
             self.cnf_l.negative_sampling_candidate_pool_size
         )
-        self.scaler = torch.amp.GradScaler()  # for FP16 training stability
+        self.reset_optimizer_and_scaler()
+
+    def reset_optimizer_and_scaler(
+        self,
+        learning_rate: Optional[float] = None,
+        weight_decay: Optional[float] = None,
+    ) -> None:
+        """Recreate training state bound to the current context model params.
+
+        Optionally update the learning rate and weight decay in the config.
+        If not provided, the current config values are used.
+
+        Args:
+            learning_rate: New learning rate. Updates config if provided.
+            weight_decay: New weight decay. Updates config if provided.
+        """
+        if learning_rate is not None:
+            self.cnf_l.learning_rate = learning_rate
+        if weight_decay is not None:
+            self.cnf_l.weight_decay = weight_decay
+        # Keep scaler and optimizer aligned with the currently loaded model.
+        self.scaler = torch.amp.GradScaler()
         self.optimizer = torch.optim.AdamW(
-            self.context_model.model.parameters(), lr=1e-4, weight_decay=0.01
+            self.context_model.model.parameters(),
+            lr=self.cnf_l.learning_rate,
+            weight_decay=self.cnf_l.weight_decay,
         )
 
+    def load_transformers(self, embedding_model_name: str) -> None:
+        """Switch embedding model and refresh optimizer/scaler to new params."""
+        self.context_model.load_transformers(embedding_model_name)
+        self.reset_optimizer_and_scaler()
+
     def _generate_negative_samples(
         self,
         candidate_indices: Tensor,
@@ -366,7 +395,7 @@ def create_new_component(
         cdb: CDB,
         vocab: Vocab,
         model_load_path: Optional[str],
-    ) -> "TrainableEmbeddingLinker":
+    ) -> "Linker":
         return cls(cdb, cdb.config)
 
     def serialise_to(self, folder_path: str) -> None:
@@ -382,7 +411,7 @@ def serialise_to(self, folder_path: str) -> None:
     @classmethod
     def deserialise_from(
         cls, folder_path: str, **init_kwargs
-    ) -> "TrainableEmbeddingLinker":
+    ) -> "Linker":
         cdb = init_kwargs["cdb"]
         linker = cls(cdb, cdb.config)
 
diff --git a/medcat-plugins/embedding-linker/src/medcat_embedding_linker/transformer_context_model.py b/medcat-plugins/embedding-linker/src/medcat_embedding_linker/transformer_context_model.py
index bd058b802..73924e911 100644
--- a/medcat-plugins/embedding-linker/src/medcat_embedding_linker/transformer_context_model.py
+++ b/medcat-plugins/embedding-linker/src/medcat_embedding_linker/transformer_context_model.py
@@ -315,17 +315,12 @@ def embed(
         outputs = self.model(**batch_dict)
         return outputs.half()
 
-    def embed_cuis(
-        self, embedding_model_name: Optional[Union[str, Path]] = None
-    ) -> None:
+    def embed_cuis(self) -> None:
         """Create embeddings for each CUI's longest name and store in CDB.
 
-        If ``embedding_model_name`` is provided, switch/load that model first.
-        Otherwise, reuse the currently loaded model (training-friendly default).
+        Switch the model first via ``load_transformers`` if needed.
         """
-        target_model = embedding_model_name or self.cnf_l.embedding_model_name
         self._refresh_cdb_keys()  # ensure _cui_keys is up to date before embedding
-        self.load_transformers(target_model)
         cui_names = [self.cdb.get_name(cui) for cui in self._cui_keys]
 
         total_batches = math.ceil(len(cui_names) / self.cnf_l.embedding_batch_size)
@@ -344,17 +339,12 @@ def embed_cuis(
         self.cdb.addl_info["cui_embeddings"] = all_embeddings_matrix
         logger.debug("Embedding cui names done, total: %d", len(cui_names))
 
-    def embed_names(
-        self, embedding_model_name: Optional[Union[str, Path]] = None
-    ) -> None:
+    def embed_names(self) -> None:
         """Create embeddings for all names and store in CDB.
 
-        If ``embedding_model_name`` is provided, switch/load that model first.
-        Otherwise, reuse the currently loaded model (training-friendly default).
+        Switch the model first via ``load_transformers`` if needed.
         """
-        target_model = embedding_model_name or self.cnf_l.embedding_model_name
         self._refresh_cdb_keys()  # ensure _cui_keys is up to date before embedding
-        self.load_transformers(target_model)
         names = self._name_keys
 
         total_batches = math.ceil(len(names) / self.cnf_l.embedding_batch_size)