Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 16 additions & 3 deletions cyteonto/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np # type:ignore
import pandas as pd # type:ignore
from pydantic_ai import Agent
from tqdm.auto import tqdm # type:ignore

from .config import CONFIG
from .llm_config import EMBDModelConfig
Expand Down Expand Up @@ -68,6 +69,8 @@ def __init__(
self.matcher = CyteOntoMatcher(
embeddings_file_path=self.embeddings_file_path,
base_data_path=base_data_path,
base_agent=self.base_agent,
embedding_model=self.embedding_model,
)

# Initialize embedding generator for user queries
Expand Down Expand Up @@ -220,6 +223,7 @@ async def compare_single_pair(
algorithm_labels: list[str],
algorithm_name: str = "algorithm",
study_name: str | None = None,
metric: str = "cosine_kernel",
) -> list[dict]:
"""
Compare a single pair of author vs algorithm labels with detailed results.
Expand Down Expand Up @@ -342,7 +346,11 @@ async def compare_single_pair(

if author_ontology_id and algorithm_ontology_id:
ontology_similarities = self.matcher.compute_ontology_similarity(
[author_ontology_id], [algorithm_ontology_id]
[author_ontology_id],
[algorithm_ontology_id],
[author_embedding_similarity],
[algorithm_embedding_similarity],
metric=metric,
)
ontology_hierarchy_similarity = (
ontology_similarities[0] if ontology_similarities else 0.0
Expand Down Expand Up @@ -387,6 +395,7 @@ async def compare_batch(
author_labels: list[str],
algo_comparison_data: list[tuple[str, list[str]]],
study_name: str | None = None,
metric: str = "cosine_kernel",
) -> pd.DataFrame:
"""
Perform detailed batch comparisons between multiple algorithm results.
Expand Down Expand Up @@ -416,11 +425,15 @@ async def compare_batch(

all_results = []

for algorithm_name, algorithm_labels in algo_comparison_data:
for algorithm_name, algorithm_labels in tqdm(
algo_comparison_data,
total=len(algo_comparison_data),
desc="Comparing Algorithms",
):
logger.info(f"Processing algorithm: {algorithm_name}")

detailed_results = await self.compare_single_pair(
author_labels, algorithm_labels, algorithm_name, study_name
author_labels, algorithm_labels, algorithm_name, study_name, metric
)

# Add algorithm name to each result
Expand Down
48 changes: 44 additions & 4 deletions cyteonto/matcher/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pathlib import Path

import numpy as np
from pydantic_ai import Agent
from sklearn.metrics.pairwise import cosine_similarity # type: ignore

from ..logger_config import logger
Expand All @@ -19,6 +20,8 @@ def __init__(
self,
embeddings_file_path: Path | None = None,
base_data_path: str | None = None,
base_agent: Agent | None = None,
embedding_model: str | None = None,
):
"""
Initialize CyteOnto matcher.
Expand All @@ -37,6 +40,10 @@ def __init__(
self._ontology_extractor: OntologyExtractor | None = None
self._ontology_similarity: OntologySimilarity | None = None

# Model information
self.base_agent = base_agent
self.embedding_model = embedding_model

logger.info(f"Loading ontology embeddings from {self.embeddings_file_path}")
self.embeddings_ready = self._load_ontology_embeddings()
logger.info(f"Loaded success: {self.embeddings_ready}")
Expand Down Expand Up @@ -82,7 +89,11 @@ def _get_ontology_similarity(self) -> OntologySimilarity:
"""Get ontology similarity calculator, creating if needed."""
if self._ontology_similarity is None:
owl_path = self.file_manager.get_ontology_owl_path()
self._ontology_similarity = OntologySimilarity(owl_path)
embedding_path = self.file_manager.get_embedding_file_path(
text_model=self.base_agent.model.model_name, # type: ignore
embedding_model=self.embedding_model, # type: ignore
)
self._ontology_similarity = OntologySimilarity(owl_path, embedding_path)
return self._ontology_similarity

def find_closest_ontology_terms(
Expand Down Expand Up @@ -141,7 +152,12 @@ def find_closest_ontology_terms(
return results

def compute_ontology_similarity(
self, author_ontology_terms: list[str], user_ontology_terms: list[str]
self,
author_ontology_terms: list[str],
user_ontology_terms: list[str],
author_ontology_score: list[float] | None = None,
user_ontology_score: list[float] | None = None,
metric: str = "cosine_kernel",
) -> list[float]:
"""
Compute ontology hierarchy-based similarity between author and user labels.
Expand All @@ -163,16 +179,40 @@ def compute_ontology_similarity(
# Build mappings if needed
_, label_to_ontology = extractor.build_mappings()

if metric == "cosine_ensemble" and (
author_ontology_score is None or user_ontology_score is None
):
logger.error(
"Cosine ensemble metric requires author and user ontology scores"
)
return [0.0] * len(author_ontology_terms)
else:
# make arrays of 1.0 if scores not provided
if author_ontology_score is None:
author_ontology_score = [1.0] * len(author_ontology_terms)
if user_ontology_score is None:
user_ontology_score = [1.0] * len(user_ontology_terms)

similarities = []
for author_label, user_label in zip(author_ontology_terms, user_ontology_terms):
for author_label, user_label, author_score, user_score in zip(
author_ontology_terms,
user_ontology_terms,
author_ontology_score,
user_ontology_score,
):
# Get ontology IDs for labels
author_ontology_id = label_to_ontology.get(author_label, author_label)
user_ontology_id = label_to_ontology.get(user_label, user_label)

# Compute ontology-based similarity
similarity = similarity_calc.compute_ontology_similarity(
author_ontology_id, user_ontology_id
author_ontology_id,
user_ontology_id,
author_score,
user_score,
metric=metric,
)
logger.debug(f"Matcher Similarity: {similarity}")
similarities.append(similarity)

logger.debug(f"Computed ontology similarities for {len(similarities)} pairs")
Expand Down
4 changes: 0 additions & 4 deletions cyteonto/ontology/extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,6 @@ def build_mappings(self) -> tuple[dict[str, list[str]], dict[str, str]]:

self._ontology_to_labels = ontology_to_labels
self._label_to_ontology = label_to_ontology

logger.info(
f"Built mappings: {len(ontology_to_labels)} ontology terms, {len(label_to_ontology)} labels"
)
return ontology_to_labels, label_to_ontology

def get_ontology_id_for_label(self, label: str) -> str | None:
Expand Down
Loading