diff --git a/medcat-v2/medcat/components/linking/vector_context_model.py b/medcat-v2/medcat/components/linking/vector_context_model.py index 9815c171e..c009d4cee 100644 --- a/medcat-v2/medcat/components/linking/vector_context_model.py +++ b/medcat-v2/medcat/components/linking/vector_context_model.py @@ -58,6 +58,7 @@ def __init__(self, cui2info: dict[str, CUIInfo], def get_context_tokens(self, entity: MutableEntity, doc: MutableDocument, size: int, per_doc_valid_token_cache: 'PerDocumentTokenCache', + fill_centre_tokens: bool = True, ) -> tuple[list[MutableToken], list[MutableToken], list[MutableToken]]: @@ -83,17 +84,21 @@ def get_context_tokens(self, entity: MutableEntity, doc: MutableDocument, per_doc_valid_token_cache[tkn]] # Reverse because the first token should be the one closest to center tokens_left.reverse() - tokens_center: list[MutableToken] = list( - cast(Iterable[MutableToken], entity)) + if fill_centre_tokens: + tokens_center: list[MutableToken] = list( + cast(Iterable[MutableToken], entity)) + else: + tokens_center = [] _right_tokens = doc[end_ind + 1:end_ind + 1 + size] tokens_right = [tkn for tkn in _right_tokens if per_doc_valid_token_cache[tkn]] return tokens_left, tokens_center, tokens_right - def _tokens2vecs(self, tokens: Sequence[Union[MutableToken, str]] - ) -> Iterable[np.ndarray]: - for step, tkn in enumerate(tokens): + def _tokens2vecs(self, tokens: Sequence[Union[MutableToken, str]], + step_start: int = 0 + ) -> Iterable[np.ndarray]: + for step, tkn in enumerate(tokens, start=step_start): lower = tkn.lower() if isinstance(tkn, str) else tkn.base.lower if lower not in self.vocab: continue @@ -137,26 +142,55 @@ def get_context_vectors(self, entity: MutableEntity, """ vectors: dict[str, np.ndarray] = {} - context_vector_sizes = self.config.context_vector_sizes - for context_type, window_size in context_vector_sizes.items(): - tokens_left, tokens_center, tokens_right = self.get_context_tokens( - entity, doc, window_size, per_doc_valid_token_cache) - - values: list[np.ndarray] = [] - # Add left - values.extend(self._tokens2vecs(tokens_left)) - - if not self.config.context_ignore_center_tokens: - # Add center - values.extend( - self._preprocess_center_tokens(cui, tokens_center)) - - # Add right - values.extend(self._tokens2vecs(tokens_right)) - + # Sort ascending so each iteration is a superset of the previous + sorted_contexts = sorted( + self.config.context_vector_sizes.items(), key=lambda x: x[1]) + + prev_left: list[MutableToken] = [] + prev_right: list[MutableToken] = [] + # Accumulated weighted vecs from previous (smaller) windows, + # excluding center (center is the same for all window sizes) + prev_left_vecs: list[np.ndarray] = [] + prev_right_vecs: list[np.ndarray] = [] + + # Center is identical for all window sizes, only compute once + if not self.config.context_ignore_center_tokens: + tokens_center = list( + cast(Iterable[MutableToken], entity)) + center_vecs = list( + self._preprocess_center_tokens(cui, tokens_center)) + else: + center_vecs = [] + + for context_type, window_size in sorted_contexts: + tokens_left, _, tokens_right = self.get_context_tokens( + entity, doc, window_size, per_doc_valid_token_cache, + fill_centre_tokens=False) + + # New outer tokens only — the inner ones were already processed + # NOTE: left hand tokens are in order of closest first, which is why + # we're slicing from the start of the list + new_left = tokens_left[len(prev_left):] + new_right = tokens_right[len(prev_right):] + + # step_start for new left tokens: they are further from centre + # so their step index is + # len(tokens_left) - len(new_left) ... len(tokens_left)-1 + # i.e. the new tokens are the outermost, highest-step ones + new_left_vecs = list(self._tokens2vecs( + new_left, step_start=len(prev_left))) + new_right_vecs = list(self._tokens2vecs( + new_right, step_start=len(prev_right))) + + prev_left_vecs = new_left_vecs + prev_left_vecs + prev_right_vecs = prev_right_vecs + new_right_vecs + prev_left = tokens_left + prev_right = tokens_right + + values = prev_left_vecs + center_vecs + prev_right_vecs if values: - value = np.average(values, axis=0) - vectors[context_type] = value + vectors[context_type] = np.average(values, axis=0) + return vectors def similarity(self, cui: str, entity: MutableEntity, doc: MutableDocument,