Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 58 additions & 24 deletions medcat-v2/medcat/components/linking/vector_context_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def __init__(self, cui2info: dict[str, CUIInfo],
def get_context_tokens(self, entity: MutableEntity, doc: MutableDocument,
size: int,
per_doc_valid_token_cache: 'PerDocumentTokenCache',
fill_centre_tokens: bool = True,
) -> tuple[list[MutableToken],
list[MutableToken],
list[MutableToken]]:
Expand All @@ -83,17 +84,21 @@ def get_context_tokens(self, entity: MutableEntity, doc: MutableDocument,
per_doc_valid_token_cache[tkn]]
# Reverse because the first token should be the one closest to center
tokens_left.reverse()
tokens_center: list[MutableToken] = list(
cast(Iterable[MutableToken], entity))
if fill_centre_tokens:
tokens_center: list[MutableToken] = list(
cast(Iterable[MutableToken], entity))
else:
tokens_center = []
_right_tokens = doc[end_ind + 1:end_ind + 1 + size]
tokens_right = [tkn for tkn in _right_tokens if
per_doc_valid_token_cache[tkn]]

return tokens_left, tokens_center, tokens_right

def _tokens2vecs(self, tokens: Sequence[Union[MutableToken, str]]
) -> Iterable[np.ndarray]:
for step, tkn in enumerate(tokens):
def _tokens2vecs(self, tokens: Sequence[Union[MutableToken, str]],
step_start: int = 0
) -> Iterable[np.ndarray]:
for step, tkn in enumerate(tokens, start=step_start):
lower = tkn.lower() if isinstance(tkn, str) else tkn.base.lower
if lower not in self.vocab:
continue
Expand Down Expand Up @@ -137,26 +142,55 @@ def get_context_vectors(self, entity: MutableEntity,
"""
vectors: dict[str, np.ndarray] = {}

context_vector_sizes = self.config.context_vector_sizes
for context_type, window_size in context_vector_sizes.items():
tokens_left, tokens_center, tokens_right = self.get_context_tokens(
entity, doc, window_size, per_doc_valid_token_cache)

values: list[np.ndarray] = []
# Add left
values.extend(self._tokens2vecs(tokens_left))

if not self.config.context_ignore_center_tokens:
# Add center
values.extend(
self._preprocess_center_tokens(cui, tokens_center))

# Add right
values.extend(self._tokens2vecs(tokens_right))

# Sort ascending so each iteration is a superset of the previous
sorted_contexts = sorted(
self.config.context_vector_sizes.items(), key=lambda x: x[1])

prev_left: list[MutableToken] = []
prev_right: list[MutableToken] = []
# Accumulated weighted vecs from previous (smaller) windows,
# excluding center (center is the same for all window sizes)
prev_left_vecs: list[np.ndarray] = []
prev_right_vecs: list[np.ndarray] = []

# Center is identical for all window sizes, so compute it only once
if not self.config.context_ignore_center_tokens:
tokens_center = list(
cast(Iterable[MutableToken], entity))
center_vecs = list(
self._preprocess_center_tokens(cui, tokens_center))
else:
center_vecs = []

for context_type, window_size in sorted_contexts:
tokens_left, _, tokens_right = self.get_context_tokens(
entity, doc, window_size, per_doc_valid_token_cache,
fill_centre_tokens=False)

# New outer tokens only — the inner ones were already processed
# NOTE: left-hand tokens are ordered closest-to-centre first, so the
# already-processed inner tokens occupy the front of the list and the
# slice below keeps only the trailing (outer) tokens
new_left = tokens_left[len(prev_left):]
new_right = tokens_right[len(prev_right):]

# step_start for the new left tokens: they are further from the centre,
# so their step indices run from len(tokens_left) - len(new_left)
# through len(tokens_left) - 1, i.e. the new tokens are the outermost,
# highest-step ones
new_left_vecs = list(self._tokens2vecs(
new_left, step_start=len(prev_left)))
new_right_vecs = list(self._tokens2vecs(
new_right, step_start=len(prev_right)))

prev_left_vecs = new_left_vecs + prev_left_vecs
prev_right_vecs = prev_right_vecs + new_right_vecs
prev_left = tokens_left
prev_right = tokens_right

values = prev_left_vecs + center_vecs + prev_right_vecs
if values:
value = np.average(values, axis=0)
vectors[context_type] = value
vectors[context_type] = np.average(values, axis=0)

return vectors

def similarity(self, cui: str, entity: MutableEntity, doc: MutableDocument,
Expand Down
Loading