Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions cyteonto/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,7 @@ async def compare_single_pair(
algorithm_name: str = "algorithm",
study_name: str | None = None,
metric: str = "cosine_kernel",
metric_params: dict | None = None,
) -> list[dict]:
"""
Compare a single pair of author vs algorithm labels with detailed results.
Expand All @@ -233,6 +234,8 @@ async def compare_single_pair(
algorithm_labels: Algorithm (predicted) cell type labels from cell ontology
algorithm_name: Name of the algorithm (for file naming and caching)
study_name: Name of the study for organizing files (optional)
metric: Metric to use for similarity calculation (e.g., "cosine_kernel")
metric_params: Additional parameters for the similarity metric

Returns:
List of detailed comparison dictionaries
Expand Down Expand Up @@ -351,6 +354,7 @@ async def compare_single_pair(
[author_embedding_similarity],
[algorithm_embedding_similarity],
metric=metric,
metric_params=metric_params,
)
ontology_hierarchy_similarity = (
ontology_similarities[0] if ontology_similarities else 0.0
Expand Down Expand Up @@ -396,6 +400,7 @@ async def compare_batch(
algo_comparison_data: list[tuple[str, list[str]]],
study_name: str | None = None,
metric: str = "cosine_kernel",
metric_params: dict | None = None,
) -> pd.DataFrame:
"""
Perform detailed batch comparisons between multiple algorithm results.
Expand All @@ -404,6 +409,8 @@ async def compare_batch(
author_labels: Author (reference) cell type labels from cell ontology
algo_comparison_data: List of (algorithm_name, algorithm_labels) tuples
study_name: Name of the study for organizing files (optional)
metric: Metric to use for similarity calculation (e.g., "cosine_kernel")
metric_params: Additional parameters for the similarity metric

Returns:
DataFrame with detailed comparison results including:
Expand Down Expand Up @@ -433,7 +440,12 @@ async def compare_batch(
logger.info(f"Processing algorithm: {algorithm_name}")

detailed_results = await self.compare_single_pair(
author_labels, algorithm_labels, algorithm_name, study_name, metric
author_labels,
algorithm_labels,
algorithm_name,
study_name,
metric,
metric_params,
)

# Add algorithm name to each result
Expand Down Expand Up @@ -479,6 +491,7 @@ async def compare_anndata_objects(
target_columns: list[str],
author_column: str,
algorithm_names: list[str] | None = None,
metric_params: dict | None = None,
) -> pd.DataFrame:
"""
Compare cell type annotations across AnnData objects.
Expand Down Expand Up @@ -523,4 +536,6 @@ async def compare_anndata_objects(
all_comparison_data.append((algo_name, algorithm_labels))

# Perform batch comparison
return await self.compare_batch(author_labels, all_comparison_data)
return await self.compare_batch(
author_labels, all_comparison_data, metric_params=metric_params
)
2 changes: 2 additions & 0 deletions cyteonto/matcher/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,7 @@ def compute_ontology_similarity(
author_ontology_score: list[float] | None = None,
user_ontology_score: list[float] | None = None,
metric: str = "cosine_kernel",
metric_params: dict | None = None,
) -> list[float]:
"""
Compute ontology hierarchy-based similarity between author and user labels.
Expand Down Expand Up @@ -211,6 +212,7 @@ def compute_ontology_similarity(
author_score,
user_score,
metric=metric,
metric_params=metric_params,
)
logger.debug(f"Matcher Similarity: {similarity}")
similarities.append(similarity)
Expand Down
17 changes: 16 additions & 1 deletion cyteonto/ontology/similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,7 @@ def compute_ontology_similarity(
ontology_score1: float = 1.0,
ontology_score2: float = 1.0,
metric: str = "cosine_kernel",
metric_params: dict | None = None,
) -> float:
"""
Compute similarity between two ontology terms using a specified metric.
Expand All @@ -346,6 +347,11 @@ def compute_ontology_similarity(
- Ensemble: 'cosine_kernel' (DEFAULT; embedding with Gaussian hill)
- Combined: 'final'

metric_params: Optional dictionary of parameters for the metric.
For 'cosine_kernel', supported keys are:
- center: Center of the Gaussian (default: 1)
- width: Width of the Gaussian (default: 0.25)
- amplitude: Amplitude of the Gaussian (default: 1)

Returns:
Similarity score between 0 and 1.
Expand Down Expand Up @@ -397,7 +403,16 @@ def compute_ontology_similarity(
logger.debug(f"Embeddings: {embd1[:5]}, {embd2[:5]}")
d3 = self._cosine_similarity(embd1, embd2) # type:ignore
logger.debug(f"Embedding Cosine: {d3}")
d3_hill = self.gaussian_hill(d3, center=1, width=0.25, amplitude=1)

# Get parameters with defaults
params = metric_params or {}
center = params.get("center", 1)
width = params.get("width", 0.25)
amplitude = params.get("amplitude", 1)

d3_hill = self.gaussian_hill(
d3, center=center, width=width, amplitude=amplitude
)
logger.debug(f"Embedding Cosine Hill: {d3_hill}")
return d3_hill

Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,17 @@ dependencies = [
"anndata>=0.12.2",
"asyncio>=4.0.0",
"fastapi>=0.116.1",
"iprogress>=0.4",
"ipykernel>=6.30.1",
"ipython>=9.4.0",
"jupyter>=1.1.1",
"logfire>=4.3.3",
"loguru>=0.7.3",
"mypy>=1.17.1",
"numpy>=2.3.2",
"owlready2>=0.48",
"pandas>=2.3.1",
"pydantic>=2.11.7",
"pydantic-ai>=0.7.4",
"pytest>=8.4.1",
"pytest-asyncio>=1.1.0",
Expand Down
6 changes: 6 additions & 0 deletions scripts/download_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,12 @@
BASE_URL = "https://pub-d8bf3af01ebe421abded39c4cb33d88a.r2.dev/cyteonto"

AVAILABLE_MODELS = {
"llama-3.3-70b-versatile": {
"display_name": "LLaMA 3.3 70B Versatile (descriptions) + Qwen3-Embedding-8B (embeddings)",
"description_file": "descriptions_llama-3.3-70b-versatile.json",
"embedding_file": "embeddings_llama-3.3-70b-versatile_Qwen-Qwen3-Embedding-8B.npz",
"recommended": False,
},
"moonshot-ai_kimi-k2": {
"display_name": "Moonshot AI Kimi-K2 (descriptions) + Qwen3-Embedding-8B (embeddings)",
"description_file": "descriptions_moonshotai-Kimi-K2-Instruct.json",
Expand Down
6 changes: 6 additions & 0 deletions scripts/show_embeddings.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
from pathlib import Path

AVAILABLE_MODELS = {
"llama-3.3-70b-versatile": {
"display_name": "LLaMA 3.3 70B Versatile (descriptions) + Qwen3-Embedding-8B (embeddings)",
"description_file": "descriptions_llama-3.3-70b-versatile.json",
"embedding_file": "embeddings_llama-3.3-70b-versatile_Qwen-Qwen3-Embedding-8B.npz",
"recommended": False,
},
"moonshot-ai_kimi-k2": {
"display_name": "Moonshot AI Kimi-K2 (descriptions) + Qwen3-Embedding-8B (embeddings)",
"description_file": "descriptions_moonshotai-Kimi-K2-Instruct.json",
Expand Down
102 changes: 102 additions & 0 deletions tests/test_cosine_kernel_params.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from unittest.mock import Mock, patch

import numpy as np
import pytest

from cyteonto.matcher.matcher import CyteOntoMatcher
from cyteonto.ontology.similarity import OntologySimilarity


class TestCosineKernelParams:
    """Tests for the newly exposed ``metric_params`` of the cosine_kernel metric.

    Covers three layers: the built-in Gaussian-hill defaults, caller-supplied
    parameter overrides, and propagation of the params dict from
    CyteOntoMatcher down to OntologySimilarity.
    """

    @pytest.fixture
    def mock_similarity(self):
        """Create an OntologySimilarity with ontology loading stubbed out.

        ``_load_ontology`` is patched to report success so construction does
        not touch a real ontology file; the embedding map is replaced with
        three hand-picked 2-D vectors whose pairwise cosine similarities are
        known exactly (0.0 and ~0.707).
        """
        with patch(
            "cyteonto.ontology.similarity.OntologySimilarity._load_ontology",
            return_value=True,
        ):
            sim = OntologySimilarity()
            sim._ontology = Mock()
            # Known embedding vectors so cosine similarities are predictable.
            sim.embedding_map = {
                "CL:0000001": np.array([1.0, 0.0]),
                "CL:0000002": np.array([0.0, 1.0]),  # Orthogonal, cosine sim = 0
                "CL:0000003": np.array([0.707, 0.707]),  # 45 deg, cosine sim ~ 0.707
            }
            # Stub class lookup so term resolution never fails.
            # NOTE(review): assumes compute_ontology_similarity only needs a
            # truthy class object from this helper — confirm against
            # OntologySimilarity internals.
            sim._find_class_cached = Mock(return_value=Mock())
            return sim

    def test_compute_ontology_similarity_default_params(self, mock_similarity):
        """Omitting metric_params must fall back to the built-in defaults."""
        # Cosine similarity of CL:0000001 and CL:0000002 is 0.0.
        # Default params: center=1, width=0.25, amplitude=1, so
        # gaussian_hill(0) = exp(-((0-1)**2)/(2*0.25**2)) = exp(-8) ~= 0.000335.

        score = mock_similarity.compute_ontology_similarity(
            "CL:0000001", "CL:0000002", metric="cosine_kernel"
        )
        # Exact value is deliberately not pinned; just ensure the default
        # path runs and the score stays in the documented [0, 1] range.
        assert isinstance(score, float)
        assert 0 <= score <= 1

    def test_compute_ontology_similarity_custom_params(self, mock_similarity):
        """Custom metric_params must override the Gaussian-hill defaults."""
        # Cosine similarity is 0.0.
        # Custom params: center=0, width=1, amplitude=1, so
        # gaussian_hill(0) = exp(-((0-0)**2)/(2*1**2)) = exp(0) = 1.0.

        params = {"center": 0, "width": 1, "amplitude": 1}
        score = mock_similarity.compute_ontology_similarity(
            "CL:0000001", "CL:0000002", metric="cosine_kernel", metric_params=params
        )
        assert score == pytest.approx(1.0, rel=1e-5)

    def test_params_affect_result(self, mock_similarity):
        """Verify that changing parameters changes the result."""
        term1 = "CL:0000001"
        term2 = "CL:0000003"  # Cosine sim ~ 0.707

        # Score under the default parameters (center=1, width=0.25).
        score_default = mock_similarity.compute_ontology_similarity(
            term1, term2, metric="cosine_kernel"
        )

        # A wider Gaussian should raise the score for a non-perfect match:
        # sim ~ 0.707 and center = 1, so the distance from the peak is ~0.3.
        # Default width 0.25 < 0.3 -> low score; width 1.0 > 0.3 -> higher.
        params = {"width": 1.0}
        score_custom = mock_similarity.compute_ontology_similarity(
            term1, term2, metric="cosine_kernel", metric_params=params
        )

        assert score_custom > score_default

    @patch("cyteonto.matcher.matcher.CyteOntoMatcher._get_ontology_similarity")
    @patch("cyteonto.matcher.matcher.CyteOntoMatcher._get_ontology_extractor")
    def test_matcher_passes_params(self, mock_get_extractor, mock_get_similarity):
        """Test that CyteOntoMatcher passes parameters to OntologySimilarity."""
        # @patch decorators apply bottom-up: mock_get_extractor corresponds to
        # the innermost (last) patch, mock_get_similarity to the outermost.
        mock_extractor = Mock()
        mock_extractor.build_mappings.return_value = (None, {})
        mock_get_extractor.return_value = mock_extractor

        mock_sim_calc = Mock()
        mock_sim_calc.compute_ontology_similarity.return_value = 0.5
        mock_get_similarity.return_value = mock_sim_calc

        matcher = CyteOntoMatcher(embeddings_file_path=Mock())

        params = {"center": 0.5}
        matcher.compute_ontology_similarity(
            ["T cell"], ["B cell"], metric="cosine_kernel", metric_params=params
        )

        # The params dict must arrive unchanged as the metric_params kwarg on
        # the underlying similarity calculator.
        call_args = mock_sim_calc.compute_ontology_similarity.call_args
        assert call_args is not None
        assert call_args.kwargs.get("metric_params") == params
Loading