Skip to content

Commit 8a630ba

Browse files
authored
perf: fix low-hanging performance issues in MetaCAT and linking (#400)
* perf(metacat): scope max_seq_len and batch slice to current batch

  create_batch_piped_data was computing max_seq_len over the entire dataset on every batch call, and slicing data[start_ind:end_ind] three times. Scope both to a single batch slice — this reduces padding overhead and eliminates redundant iteration.

* perf(linking): update similarities in-place during disambiguation

  Replace list copy + clear + rebuild with a simple in-place loop. Eliminates three intermediate list allocations in the disambiguation hot path.

* perf(metacat): replace O(n) dict values scan with O(1) key lookup

  undersample_data and encode_category_values both checked membership against category_value2id.values() (a linear scan) on every iteration. Since the label_data dicts are keyed by the same IDs, check membership against the dict itself (an O(1) hash lookup).

* perf(metacat): use append instead of list concatenation in eval

  dict.get(k, []) + [item] allocates a new list on every iteration, making example collection O(n*k). Use setdefault + append for O(1) amortized cost per insertion.
1 parent 823151c commit 8a630ba

3 files changed

Lines changed: 17 additions & 18 deletions

File tree

medcat-v2/medcat/components/addons/meta_cat/data_utils.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -319,10 +319,9 @@ def undersample_data(data: list, category_value2id: dict, label_data_,
319319
label_data_counter[sample[-1]] += 1
320320

321321
label_data = {v: 0 for v in category_value2id.values()}
322-
for i in range(len(data_undersampled)):
323-
if data_undersampled[i][2] in category_value2id.values():
324-
label_data[data_undersampled[i][2]] = (
325-
label_data[data_undersampled[i][2]] + 1)
322+
for sample in data_undersampled:
323+
if sample[2] in label_data:
324+
label_data[sample[2]] += 1
326325
logger.info("Updated number of samples per label (for 2-phase learning):"
327326
" %s", label_data)
328327
return data_undersampled
@@ -414,9 +413,9 @@ def encode_category_values(data: list[tuple[list, list, str]],
414413

415414
# Creating dict with labels and its number of samples
416415
label_data_ = {v: 0 for v in category_value2id.values()}
417-
for i in range(len(data)):
418-
if data[i][2] in category_value2id.values():
419-
label_data_[data[i][2]] = label_data_[data[i][2]] + 1
416+
for sample in data:
417+
if sample[2] in label_data_:
418+
label_data_[sample[2]] += 1
420419

421420
logger.info("Original number of samples per label: %s", label_data_)
422421

medcat-v2/medcat/components/addons/meta_cat/ml_utils.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,15 @@ def create_batch_piped_data(data: list[tuple[list[int], int, Optional[int]]],
6363
y (Optional[torch.Tensor]):
6464
class label of the data
6565
"""
66-
max_seq_len = max([len(x[0]) for x in data])
66+
batch = data[start_ind:end_ind]
67+
max_seq_len = max(len(x[0]) for x in batch)
6768
x = [x[0][0:max_seq_len] + [pad_id] * max(0, max_seq_len - len(x[0]))
68-
for x in data[start_ind:end_ind]]
69-
cpos = [x[1] for x in data[start_ind:end_ind]]
69+
for x in batch]
70+
cpos = [x[1] for x in batch]
7071
y = None
7172
if len(data[0]) == 3:
7273
# Means we have the y column
73-
y = torch.tensor([x[2] for x in data[start_ind:end_ind]],
74+
y = torch.tensor([x[2] for x in batch],
7475
dtype=torch.long).to(device)
7576

7677
x2 = torch.tensor(x, dtype=torch.long).to(device)
@@ -511,10 +512,10 @@ def _eval_predictions(
511512
info = "Predicted: {}, True: {}".format(pred, y)
512513
if pred != y:
513514
# We made a mistake
514-
examples['FN'][y] = examples['FN'].get(y, []) + [(info, text)]
515-
examples['FP'][pred] = examples['FP'].get(pred, []) + [(info, text)]
515+
examples['FN'].setdefault(y, []).append((info, text))
516+
examples['FP'].setdefault(pred, []).append((info, text))
516517
else:
517-
examples['TP'][y] = examples['TP'].get(y, []) + [(info, text)]
518+
examples['TP'].setdefault(y, []).append((info, text))
518519

519520
return {'precision': precision, 'recall': recall, 'f1': f1,
520521
'examples': examples, 'confusion matrix': confusion}

medcat-v2/medcat/components/linking/vector_context_model.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -231,10 +231,9 @@ def _preprocess_disamb_similarities(self, entity: MutableEntity,
231231
pref_freq = self.config.prefer_frequent_concepts
232232
scales = [np.log10(cnt / m) * pref_freq if cnt > 10 else 0
233233
for cnt in cnts]
234-
old_sims = list(similarities)
235-
similarities.clear()
236-
similarities += [float(min(0.99, sim + sim * scale))
237-
for sim, scale in zip(old_sims, scales)]
234+
for i, scale in enumerate(scales):
235+
similarities[i] = float(min(0.99,
236+
similarities[i] + similarities[i] * scale))
238237

239238
def get_all_similarities(self, cuis: list[str], entity: MutableEntity,
240239
name: str, doc: MutableDocument,

0 commit comments

Comments
 (0)