Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 198 additions & 0 deletions tests/unit/test_analytics_queries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
"""Unit tests for analytics query user-agent fallback behavior."""

from datetime import datetime

import pandas as pd
import pytest

import wikidatasearch.services.logger.analytics_queries as analytics_queries
from wikidatasearch.services.logger.analytics_queries import AnalyticsQueryService


class _DummyConnection:
"""No-op connection context manager for stubbing SQLAlchemy engine."""

def __enter__(self):
return object()

def __exit__(self, exc_type, exc, tb):
return False


class _DummyEngine:
    """Engine stand-in whose connect() hands out no-op context managers."""

    def connect(self):
        # A fresh dummy connection per call, mirroring a real engine.
        return _DummyConnection()


def _stub_read_sql(monkeypatch, df: pd.DataFrame) -> None:
    """Patch the engine and pandas read_sql to yield a canned DataFrame."""

    def _canned_read_sql(*_args, **_kwargs):
        # Copy so a test mutating the result cannot leak into another call.
        return df.copy()

    monkeypatch.setattr(analytics_queries, "engine", _DummyEngine())
    monkeypatch.setattr(analytics_queries.pd, "read_sql", _canned_read_sql)


def _capture_sql(monkeypatch, df: pd.DataFrame) -> dict[str, str]:
    """Patch read_sql to record the SQL text while returning a canned frame."""
    seen: dict[str, str] = {"query": ""}

    def _recording_read_sql(sql, *_args, **_kwargs):
        seen["query"] = str(sql)
        return df.copy()

    monkeypatch.setattr(analytics_queries, "engine", _DummyEngine())
    monkeypatch.setattr(analytics_queries.pd, "read_sql", _recording_read_sql)
    return seen


def _assert_vector_routes_and_status_filter(sql_text: str) -> None:
    """Assert vector-route queries exclude 400 and 422 statuses."""
    routes_clause = f"route IN {AnalyticsQueryService.VECTOR_QUERY_ROUTES_SQL}"
    assert routes_clause in sql_text
    assert "status NOT IN (400, 422)" in sql_text
    # The superseded single-status filter must no longer appear.
    assert "status <> 422" not in sql_text


def test_get_total_user_agents_prefers_original_user_agent(monkeypatch):
    """Prefer original user-agent strings, falling back to the stored hash."""
    rows = [
        {"client": "browser", "user_agent_hash": "hash_browser", "user_agent_value": "Mozilla/5.0 X"},
        {"client": "api", "user_agent_hash": "hash_api_a", "user_agent_value": "WikiBot/1.0"},
        {"client": "api", "user_agent_hash": "hash_api_b", "user_agent_value": "hash_api_b"},
    ]
    _stub_read_sql(monkeypatch, pd.DataFrame(rows))

    result = AnalyticsQueryService.get_total_user_agents(
        datetime(2026, 4, 1),
        datetime(2026, 4, 23),
        include_user_agents=True,
    )

    assert result["total"] == 3
    assert result["browser"] == 1
    assert result["api"] == 2
    assert result["user_agents"] == ["Mozilla/5.0 X", "WikiBot/1.0", "hash_api_b"]


def test_get_new_user_agents_prefers_original_user_agent(monkeypatch):
    """New-user-agent listing returns originals where known, hashes otherwise."""
    frame = pd.DataFrame(
        [
            {"user_agent_hash": "hash_a", "user_agent_value": "CustomAgent/2.0"},
            {"user_agent_hash": "hash_b", "user_agent_value": "hash_b"},
        ]
    )
    _stub_read_sql(monkeypatch, frame)

    result = AnalyticsQueryService.get_new_user_agents(
        datetime(2026, 4, 1),
        datetime(2026, 4, 23),
        include_user_agents=True,
    )

    assert result["user_agents"] == ["CustomAgent/2.0", "hash_b"]
    assert result["total"] == 2


def test_get_new_user_agents_count_only_uses_total_query(monkeypatch):
    """Only the total count is returned when include_user_agents is False."""
    _stub_read_sql(monkeypatch, pd.DataFrame([{"total": 7}]))

    result = AnalyticsQueryService.get_new_user_agents(
        datetime(2026, 4, 1),
        datetime(2026, 4, 23),
        include_user_agents=False,
    )

    assert result == {"total": 7}


def test_get_consistent_user_agents_prefers_original_user_agent(monkeypatch):
    """Consistent-user-agent listing prefers originals over hashes."""
    frame = pd.DataFrame(
        [
            {"user_agent_hash": "hash_a", "user_agent_value": "AgentA/1.2"},
            {"user_agent_hash": "hash_b", "user_agent_value": "hash_b"},
            {"user_agent_hash": "hash_c", "user_agent_value": "AgentC/3.4"},
        ]
    )
    _stub_read_sql(monkeypatch, frame)

    result = AnalyticsQueryService.get_consistent_user_agents(
        datetime(2026, 4, 1),
        datetime(2026, 4, 23),
        include_user_agents=True,
    )

    assert result["user_agents"] == ["AgentA/1.2", "AgentC/3.4", "hash_b"]
    assert result["total"] == 3


def test_get_consistent_user_agents_count_only_uses_total_query(monkeypatch):
    """Only the total count is returned when include_user_agents is False."""
    _stub_read_sql(monkeypatch, pd.DataFrame([{"total": 4}]))

    result = AnalyticsQueryService.get_consistent_user_agents(
        datetime(2026, 4, 1),
        datetime(2026, 4, 23),
        include_user_agents=False,
    )

    assert result == {"total": 4}


@pytest.mark.parametrize(
    ("method_name", "kwargs"),
    [
        ("get_total_user_agents", {"include_user_agents": True}),
        ("get_total_user_agents", {"include_user_agents": False}),
        ("get_total_requests", {}),
        ("get_total_requests_by_lang", {}),
        ("get_new_user_agents", {"include_user_agents": True}),
        ("get_new_user_agents", {"include_user_agents": False}),
        ("get_consistent_user_agents", {"include_user_agents": True}),
        ("get_consistent_user_agents", {"include_user_agents": False}),
    ],
)
def test_vector_route_queries_exclude_400_and_422(monkeypatch, method_name, kwargs):
    """Ensure all vector-route analytics queries exclude 400 and 422."""
    captured = _capture_sql(monkeypatch, pd.DataFrame())
    # Same date window for every query variant under test.
    query_method = getattr(AnalyticsQueryService, method_name)
    query_method(datetime(2026, 4, 1), datetime(2026, 4, 23), **kwargs)
    _assert_vector_routes_and_status_filter(captured["query"])
5 changes: 3 additions & 2 deletions tests/unit/test_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
from fastapi import BackgroundTasks, HTTPException


def test_languages_route_returns_split_languages(test_ctx, run_async):
def test_languages_route_returns_split_languages(test_ctx, run_async, make_request):
"""Validate languages route returns split languages."""
frontend = test_ctx["frontend"]
data = run_async(frontend.languages())
req = make_request("/languages")
data = run_async(frontend.languages(req))
assert data["vectordb_langs"] == ["en", "fr"]
assert "de" in data["other_langs"]
assert "ar" in data["other_langs"]
Expand Down
30 changes: 30 additions & 0 deletions wikidatasearch/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
"""Dependencies for the FastAPI application."""

import base64
import binascii
import secrets
import time

from fastapi import FastAPI, HTTPException, Request
from slowapi import Limiter, _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded

from .config import settings
from .services.logger import Logger


Expand All @@ -27,6 +30,32 @@ def user_agent_key(request: Request) -> str:
limiter = Limiter(key_func=user_agent_key)


def verify_admin_auth(request: Request) -> str:
    """Verify HTTP Basic auth for the admin page.

    Returns the authenticated username ("admin" when the client sent an
    empty username). Raises 404 when no admin secret is configured, so the
    admin surface is indistinguishable from a missing route, and 401 with a
    ``WWW-Authenticate: Basic`` challenge on missing or bad credentials.
    """
    expected = settings.ANALYTICS_API_SECRET
    if not expected:
        # No secret configured: hide the admin page entirely.
        raise HTTPException(status_code=404, detail="Not found")

    authorization = request.headers.get("authorization", "")
    if not authorization.startswith("Basic "):
        decoded = None
    else:
        token = authorization[6:].strip()
        try:
            payload = base64.b64decode(token).decode("utf-8")
        except (binascii.Error, UnicodeDecodeError):
            # Malformed base64 or non-UTF-8 payload: treat as no credentials.
            payload = ""
        decoded = payload.split(":", 1) if ":" in payload else None

    # compare_digest keeps the comparison constant-time so the secret cannot
    # be recovered via a timing side channel.
    if not decoded or not secrets.compare_digest(decoded[1], expected):
        raise HTTPException(
            status_code=401,
            detail="Incorrect admin credentials",
            headers={"WWW-Authenticate": "Basic"},
        )
    return decoded[0] or "admin"


def require_descriptive_user_agent(request: Request) -> None:
"""Enforce a descriptive User-Agent and blocks generic HTTP clients."""
ua = request.headers.get("user-agent", "").strip()
Expand All @@ -37,6 +66,7 @@ def require_descriptive_user_agent(request: Request) -> None:


def _logged_rate_limit_exceeded_handler(request: Request, exc: Exception):
    """Log a rate-limit breach before delegating to slowapi's handler."""
    message = str(exc) or "Rate limit exceeded"
    # Record the 429 in the request log, then let slowapi build the response.
    Logger.add_request(request, 429, time.time(), error=message)
    return _rate_limit_exceeded_handler(request, exc)
Expand Down
15 changes: 12 additions & 3 deletions wikidatasearch/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@
from gradio.routes import mount_gradio_app

from .config import settings
from .dependencies import register_rate_limit
from .dependencies import register_rate_limit, verify_admin_auth
from .routes import frontend, health, item, property, similarity
from .services.analytics import build_analytics_app
from .routes.admin import analytics_api_router, build_analytics_app
from .services.logger.database import initialize_database

app = FastAPI(
title="Wikidata Vector Search",
Expand Down Expand Up @@ -37,6 +38,7 @@
@app.on_event("startup")
async def startup_event():
    """Initialize the request-log database and the in-memory FastAPI cache at startup."""
    initialize_database()
    FastAPICache.init(InMemoryBackend(), prefix="wikidata-cache")


Expand All @@ -50,4 +52,11 @@ async def startup_event():
frontend.mount_static(app)

if settings.ANALYTICS_API_SECRET:
mount_gradio_app(app, build_analytics_app(), path=f"/admin/{settings.ANALYTICS_API_SECRET}")
app.include_router(analytics_api_router)
mount_gradio_app(
app,
build_analytics_app(),
path="/admin",
auth_dependency=verify_admin_auth,
auth_message="Provide HTTP Basic auth using ANALYTICS_API_SECRET as password.",
)
6 changes: 6 additions & 0 deletions wikidatasearch/routes/admin/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
"""Admin route modules exposed by the API package."""

from .analytics_api import router as analytics_api_router
from .analytics_ui import build_analytics_app

__all__ = ["analytics_api_router", "build_analytics_app"]
Loading
Loading