Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions app/ai-service/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,37 @@
INFERENCE_LATENCY = Histogram('inference_latency_seconds', 'Inference latency in seconds', ['task_type'])
PIPELINE_STEP_LATENCY = Histogram('pipeline_step_latency_seconds', 'Pipeline step latency in seconds', ['step_name'])

# Circuit breaker metrics
# State is encoded numerically so it can be plotted over time:
# 0 = CLOSED (healthy), 1 = HALF_OPEN (probing), 2 = OPEN (failing fast).
CIRCUIT_STATE = Gauge(
'circuit_breaker_state',
'Circuit breaker state (0=CLOSED, 1=HALF_OPEN, 2=OPEN)',
['breaker_name'],
)
CIRCUIT_FAILURE_COUNT = Counter(
'circuit_breaker_failure_count_total',
'Total failures recorded by the circuit breaker',
['breaker_name'],
)
CIRCUIT_RECOVERY_TIME = Histogram(
'circuit_breaker_recovery_time_seconds',
'Time spent in the OPEN state before transitioning to HALF_OPEN',
['breaker_name'],
)

# Circuit-breaker state constants. Exported so callers (and tests) can
# compare against the numeric gauge value without hard-coding literals.
CIRCUIT_STATE_CLOSED = 0
CIRCUIT_STATE_HALF_OPEN = 1
CIRCUIT_STATE_OPEN = 2


def set_circuit_state(breaker_name: str, state_value: int) -> None:
"""Helper to update the circuit-state gauge from anywhere."""
CIRCUIT_STATE.labels(breaker_name=breaker_name).set(state_value)


def check_system_resources(memory_threshold_percent: float = 90.0) -> bool:
"""
Check if system RAM or VRAM is above threshold.
Expand Down
32 changes: 30 additions & 2 deletions app/ai-service/services/circuit_breaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,49 @@
import logging
from threading import Lock

from metrics import (
CIRCUIT_STATE_CLOSED,
CIRCUIT_STATE_HALF_OPEN,
CIRCUIT_STATE_OPEN,
CIRCUIT_FAILURE_COUNT,
CIRCUIT_RECOVERY_TIME,
set_circuit_state,
)

logger = logging.getLogger(__name__)


class CircuitBreaker:
"""
A thread-safe implementation of the Circuit Breaker pattern.

States:
- CLOSED: Normal operation. Requests flow through.
- OPEN: Service is failing. Requests fail-fast (return False/raise error).
- HALF_OPEN: Recovery window elapsed. Allow a request to test downstream health.

The breaker publishes Prometheus metrics on every state change:
- CIRCUIT_STATE (Gauge): current state, encoded as 0/1/2.
- CIRCUIT_FAILURE_COUNT (Counter): cumulative failure count.
- CIRCUIT_RECOVERY_TIME (Histogram): time spent OPEN before HALF_OPEN.
Metric updates happen inside the same lock that guards state, so the
exported values can never diverge from the underlying state.
"""

def __init__(self, name: str, failure_threshold: int = 3, recovery_timeout: float = 30.0):
self.name = name
self.failure_threshold = failure_threshold
self.recovery_timeout = recovery_timeout

self.state = "CLOSED" # CLOSED, OPEN, HALF_OPEN
self.failure_count = 0
self.last_state_change = time.time()
self._lock = Lock()

# Publish the initial state so the gauge is always defined for
# every instantiated breaker, even before any traffic flows.
set_circuit_state(self.name, CIRCUIT_STATE_CLOSED)

def allow_request(self) -> bool:
"""
Check if a request is allowed to proceed.
Expand All @@ -34,6 +54,9 @@ def allow_request(self) -> bool:
now = time.time()
if self.state == "OPEN":
if now - self.last_state_change >= self.recovery_timeout:
# Capture recovery time BEFORE updating last_state_change,
# so the histogram reflects how long we were actually OPEN.
recovery_seconds = now - self.last_state_change
logger.info(
"Circuit breaker for provider '%s' transitioning from OPEN to HALF_OPEN "
"(recovery timeout %ss elapsed)",
Expand All @@ -42,6 +65,8 @@ def allow_request(self) -> bool:
)
self.state = "HALF_OPEN"
self.last_state_change = now
set_circuit_state(self.name, CIRCUIT_STATE_HALF_OPEN)
CIRCUIT_RECOVERY_TIME.labels(breaker_name=self.name).observe(recovery_seconds)
return True
return False
return True
Expand All @@ -62,6 +87,7 @@ def record_success(self) -> None:
self.state = "CLOSED"
self.failure_count = 0
self.last_state_change = now
set_circuit_state(self.name, CIRCUIT_STATE_CLOSED)
elif self.state == "CLOSED":
self.failure_count = 0

Expand All @@ -73,6 +99,7 @@ def record_failure(self) -> None:
with self._lock:
now = time.time()
self.failure_count += 1
CIRCUIT_FAILURE_COUNT.labels(breaker_name=self.name).inc()
if self.state == "HALF_OPEN" or self.failure_count >= self.failure_threshold:
logger.warning(
"Circuit breaker for provider '%s' transitioning from %s to OPEN "
Expand All @@ -84,3 +111,4 @@ def record_failure(self) -> None:
)
self.state = "OPEN"
self.last_state_change = now
set_circuit_state(self.name, CIRCUIT_STATE_OPEN)
123 changes: 116 additions & 7 deletions app/ai-service/tests/test_circuit_breaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,39 +7,40 @@
from services.humanitarian_verification import HumanitarianVerificationService
from exceptions import AIServiceError
from config import settings
import metrics
from prometheus_client import REGISTRY


def test_circuit_breaker_basic_transitions():
# Set a short recovery timeout for fast testing
breaker = CircuitBreaker("test-provider", failure_threshold=2, recovery_timeout=0.1)

# 1. Starts CLOSED
assert breaker.state == "CLOSED"
assert breaker.allow_request() is True

# 2. First failure
breaker.record_failure()
assert breaker.state == "CLOSED" # Not tripped yet
assert breaker.allow_request() is True

# 3. Second failure (reaches threshold)
breaker.record_failure()
assert breaker.state == "OPEN"
assert breaker.allow_request() is False # Tripped

# 4. Wait for recovery timeout
time.sleep(0.12)

# 5. Transitions to HALF_OPEN on allow_request check
assert breaker.allow_request() is True
assert breaker.state == "HALF_OPEN"

# 6. Success closes the circuit
breaker.record_success()
assert breaker.state == "CLOSED"
assert breaker.failure_count == 0


def test_circuit_breaker_half_open_failure():
breaker = CircuitBreaker("test-provider", failure_threshold=2, recovery_timeout=0.1)

Expand Down Expand Up @@ -129,3 +130,111 @@ def test_request_timeout_raises_ai_timeout(self, mock_post, monkeypatch):

# The breaker for openai should have recorded the failure
assert self.service.breakers["openai"].failure_count == 2 # Primary & fallback attempts both failed


def _sample(name: str, labels: dict) -> float:
"""Read a sample value directly from the Prometheus registry."""
return REGISTRY.get_sample_value(name, labels)


class TestCircuitBreakerMetrics:
"""Verify that CircuitBreaker publishes the metrics defined in metrics.py."""

def test_initial_state_is_published(self):
# Using a fresh breaker name ensures labels don't collide with other tests.
CircuitBreaker("metrics-initial", failure_threshold=1, recovery_timeout=0.1)
assert (
_sample("circuit_breaker_state", {"breaker_name": "metrics-initial"}) == 0
)

def test_failure_increments_counter_and_trips_state(self):
breaker = CircuitBreaker(
"metrics-failures", failure_threshold=2, recovery_timeout=0.1
)
before = _sample(
"circuit_breaker_failure_count_total", {"breaker_name": "metrics-failures"}
) or 0.0

breaker.record_failure()
after_one = _sample(
"circuit_breaker_failure_count_total", {"breaker_name": "metrics-failures"}
)
assert after_one == before + 1
# State stays CLOSED below threshold
assert _sample("circuit_breaker_state", {"breaker_name": "metrics-failures"}) == 0

breaker.record_failure()
# Threshold reached -> gauge flips to OPEN (2)
assert _sample("circuit_breaker_state", {"breaker_name": "metrics-failures"}) == 2

def test_recovery_updates_histogram_and_state_gauge(self):
breaker = CircuitBreaker(
"metrics-recovery", failure_threshold=1, recovery_timeout=0.05
)
breaker.record_failure() # trips immediately
assert _sample("circuit_breaker_state", {"breaker_name": "metrics-recovery"}) == 2

time.sleep(0.07)
# allow_request triggers the OPEN -> HALF_OPEN transition
assert breaker.allow_request() is True

assert _sample("circuit_breaker_state", {"breaker_name": "metrics-recovery"}) == 1

sum_value = _sample(
"circuit_breaker_recovery_time_seconds_sum",
{"breaker_name": "metrics-recovery"},
)
count_value = _sample(
"circuit_breaker_recovery_time_seconds_count",
{"breaker_name": "metrics-recovery"},
)
assert count_value is not None and count_value >= 1
assert sum_value is not None and sum_value >= 0.05

def test_success_closes_circuit_and_resets_state_gauge(self):
breaker = CircuitBreaker(
"metrics-success", failure_threshold=1, recovery_timeout=0.05
)
breaker.record_failure()
time.sleep(0.07)
breaker.allow_request() # -> HALF_OPEN
breaker.record_success()

assert _sample("circuit_breaker_state", {"breaker_name": "metrics-success"}) == 0

def test_half_open_failure_increments_counter_and_reopens(self):
"""A failure during the HALF_OPEN probe must be counted and reopen the
circuit. Without this, callers would never see the breaker come back
online after a successful probe that subsequently fails."""
breaker = CircuitBreaker(
"metrics-half-open-failure",
failure_threshold=1,
recovery_timeout=0.05,
)
breaker.record_failure() # CLOSED -> OPEN
time.sleep(0.07)
breaker.allow_request() # OPEN -> HALF_OPEN

counter_before_probe = _sample(
"circuit_breaker_failure_count_total",
{"breaker_name": "metrics-half-open-failure"},
)
assert (
_sample("circuit_breaker_state", {"breaker_name": "metrics-half-open-failure"})
== 1
)

breaker.record_failure() # HALF_OPEN -> OPEN
counter_after_probe = _sample(
"circuit_breaker_failure_count_total",
{"breaker_name": "metrics-half-open-failure"},
)
assert counter_after_probe is not None
assert counter_before_probe is not None
assert counter_after_probe == counter_before_probe + 1
assert (
_sample("circuit_breaker_state", {"breaker_name": "metrics-half-open-failure"})
== 2
)
# request rejected because we just re-opened
assert breaker.allow_request() is False
Loading