Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 65 additions & 5 deletions src/database/api_key_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
import logging
import secrets
import string
import time
from collections import OrderedDict
from datetime import datetime
from threading import RLock
from typing import List, Optional, Dict, Any

from src.config import settings
Expand All @@ -18,6 +21,11 @@
# In-memory fallback
_in_memory_api_keys: Dict[str, Dict[str, Any]] = {}

VALIDATION_CACHE_TTL_SECONDS = 120
VALIDATION_CACHE_MAX_SIZE = 2048
_validation_cache: OrderedDict[str, tuple[float, Dict[str, Any]]] = OrderedDict()
_validation_cache_lock = RLock()


class APIKeyStore:
"""MongoDB-backed storage for API key management with in-memory fallback."""
Expand Down Expand Up @@ -79,6 +87,41 @@ def _hash_key(self, key: str) -> str:
"""Create SHA-256 hash of the API key."""
return hashlib.sha256(key.encode()).hexdigest()

def _clear_validation_cache(self) -> None:
"""Clear cached API key validation results."""
with _validation_cache_lock:
_validation_cache.clear()

def _get_cached_validation(self, key_hash: str) -> Optional[Dict[str, Any]]:
"""Return a cached validation result when it is still fresh."""
with _validation_cache_lock:
cached = _validation_cache.get(key_hash)
if not cached:
return None

expires_at, key_doc = cached
if expires_at <= time.monotonic():
_validation_cache.pop(key_hash, None)
return None

result = dict(key_doc)

result["last_used"] = datetime.utcnow()
return result

def _cache_validation(self, key_hash: str, key_doc: Dict[str, Any]) -> None:
"""Cache a sanitized active API key document."""
with _validation_cache_lock:
if key_hash in _validation_cache:
_validation_cache.pop(key_hash, None)
while len(_validation_cache) >= VALIDATION_CACHE_MAX_SIZE:
_validation_cache.popitem(last=False)

_validation_cache[key_hash] = (
time.monotonic() + VALIDATION_CACHE_TTL_SECONDS,
dict(key_doc),
)

def create_api_key(
self,
user_id: str,
Expand Down Expand Up @@ -174,6 +217,10 @@ def validate_api_key(self, key: str) -> Optional[Dict[str, Any]]:
return result
return None

cached_doc = self._get_cached_validation(key_hash)
if cached_doc:
return cached_doc

try:
key_doc = self.api_keys.find_one({
"key_hash": key_hash,
Expand All @@ -185,10 +232,15 @@ def validate_api_key(self, key: str) -> Optional[Dict[str, Any]]:
{"_id": key_doc["_id"]},
{"$set": {"last_used": now}}
)
key_doc["last_used"] = now
key_doc["id"] = str(key_doc.pop("_id"))
key_doc = {
**key_doc,
"id": str(key_doc["_id"]),
"last_used": now,
}
key_doc.pop("_id", None)
key_doc.pop("key_hash", None)
return key_doc
self._cache_validation(key_hash, key_doc)
return dict(key_doc) if key_doc else None
except Exception as e:
logger.error(f"Database error validating API key: {e}")
return None
Expand All @@ -199,6 +251,7 @@ def revoke_api_key(self, user_id: str, key_id: str) -> bool:
if key_id in _in_memory_api_keys:
if _in_memory_api_keys[key_id].get("user_id") == user_id:
_in_memory_api_keys[key_id]["is_active"] = False
self._clear_validation_cache()
return True
return False

Expand All @@ -208,7 +261,10 @@ def revoke_api_key(self, user_id: str, key_id: str) -> bool:
{"_id": ObjectId(key_id), "user_id": user_id},
{"$set": {"is_active": False}}
)
return result.modified_count > 0
success = result.modified_count > 0
if success:
self._clear_validation_cache()
return success
except Exception as e:
logger.error(f"Failed to revoke API key {key_id}: {e}")
return False
Expand All @@ -224,6 +280,7 @@ def update_api_key_name(
if key_id in _in_memory_api_keys:
if _in_memory_api_keys[key_id].get("user_id") == user_id:
_in_memory_api_keys[key_id]["name"] = new_name
self._clear_validation_cache()
return True
return False

Expand All @@ -233,7 +290,10 @@ def update_api_key_name(
{"_id": ObjectId(key_id), "user_id": user_id},
{"$set": {"name": new_name}}
)
return result.modified_count > 0
success = result.modified_count > 0
if success:
self._clear_validation_cache()
return success
except Exception as e:
logger.error(f"Failed to update API key name {key_id}: {e}")
return False
Expand Down
Loading