diff --git a/medcat-plugins/embedding-linker/src/medcat_embedding_linker/embedding_linker.py b/medcat-plugins/embedding-linker/src/medcat_embedding_linker/embedding_linker.py index a056a4b51..60ac8e6fd 100644 --- a/medcat-plugins/embedding-linker/src/medcat_embedding_linker/embedding_linker.py +++ b/medcat-plugins/embedding-linker/src/medcat_embedding_linker/embedding_linker.py @@ -585,7 +585,7 @@ def _pre_inference( def predict_entities( self, doc: MutableDocument, ents: list[MutableEntity] | None = None ) -> list[MutableEntity]: - if self.cnf_l.train and self.comp_name == "embedding_linker": + if self.cnf_l.train and self.name == "embedding_linker": logger.warning( "Attemping to train a static embedding linker. " "This is not possible / required." diff --git a/medcat-plugins/embedding-linker/tests/test_embedding_linker.py b/medcat-plugins/embedding-linker/tests/test_embedding_linker.py index 06dedc031..1b9591c10 100644 --- a/medcat-plugins/embedding-linker/tests/test_embedding_linker.py +++ b/medcat-plugins/embedding-linker/tests/test_embedding_linker.py @@ -81,9 +81,9 @@ class TrainableEmbeddingLinkerTests(unittest.TestCase): cnf = Config() cnf.components.linking = embedding_linker.EmbeddingLinking() cnf.components.linking.comp_name = ( - trainable_embedding_linker.TrainableEmbeddingLinker.name + trainable_embedding_linker.Linker.name ) - linker = trainable_embedding_linker.TrainableEmbeddingLinker(FakeCDB(cnf), cnf) + linker = trainable_embedding_linker.Linker(FakeCDB(cnf), cnf) def test_linker_is_trainable(self): self.assertIsInstance(self.linker, TrainableComponent)