diff --git a/app/ai-service/metrics.py b/app/ai-service/metrics.py index 3d29dad..a94e534 100644 --- a/app/ai-service/metrics.py +++ b/app/ai-service/metrics.py @@ -17,6 +17,40 @@ INFERENCE_LATENCY = Histogram('inference_latency_seconds', 'Inference latency in seconds', ['task_type']) PIPELINE_STEP_LATENCY = Histogram('pipeline_step_latency_seconds', 'Pipeline step latency in seconds', ['step_name']) +# Circuit breaker metrics +# State is encoded numerically so it can be plotted over time: +# 0 = CLOSED (healthy), 1 = HALF_OPEN (probing), 2 = OPEN (failing fast). +CIRCUIT_STATE = Gauge( + 'circuit_breaker_state', + 'Circuit breaker state (0=CLOSED, 1=HALF_OPEN, 2=OPEN)', + ['breaker_name'], +) +CIRCUIT_FAILURE_COUNT = Counter( + 'circuit_breaker_failure_count_total', + 'Total failures recorded by the circuit breaker', + ['breaker_name'], +) +CIRCUIT_RECOVERY_TIME = Histogram( + 'circuit_breaker_recovery_time_seconds', + 'Time spent in the OPEN state before transitioning to HALF_OPEN', + ['breaker_name'], + # An observation is never shorter than the breaker's recovery_timeout, so the + # library's default latency buckets (10s ceiling) would lump every sample into the overflow bucket. + buckets=(1, 5, 10, 30, 60, 120, 300, 600, 1800, 3600), +) + +# Circuit-breaker state constants. Exported so callers (and tests) can +# compare against the numeric gauge value without hard-coding literals. +CIRCUIT_STATE_CLOSED = 0 +CIRCUIT_STATE_HALF_OPEN = 1 +CIRCUIT_STATE_OPEN = 2 + + +def set_circuit_state(breaker_name: str, state_value: int) -> None: + """Helper to update the circuit-state gauge from anywhere.""" + CIRCUIT_STATE.labels(breaker_name=breaker_name).set(state_value) + + def check_system_resources(memory_threshold_percent: float = 90.0) -> bool: """ Check if system RAM or VRAM is above threshold. 
diff --git a/app/ai-service/services/circuit_breaker.py b/app/ai-service/services/circuit_breaker.py index 805b4f7..605f871 100644 --- a/app/ai-service/services/circuit_breaker.py +++ b/app/ai-service/services/circuit_breaker.py @@ -2,29 +2,49 @@ import logging from threading import Lock +from metrics import ( + CIRCUIT_STATE_CLOSED, + CIRCUIT_STATE_HALF_OPEN, + CIRCUIT_STATE_OPEN, + CIRCUIT_FAILURE_COUNT, + CIRCUIT_RECOVERY_TIME, + set_circuit_state, +) + logger = logging.getLogger(__name__) class CircuitBreaker: """ A thread-safe implementation of the Circuit Breaker pattern. - + States: - CLOSED: Normal operation. Requests flow through. - OPEN: Service is failing. Requests fail-fast (return False/raise error). - HALF_OPEN: Recovery window elapsed. Allow a request to test downstream health. + + The breaker publishes Prometheus metrics as it records failures and changes state: + - CIRCUIT_STATE (Gauge): current state, encoded as 0/1/2. + - CIRCUIT_FAILURE_COUNT (Counter): cumulative failures (never reset, unlike failure_count). + - CIRCUIT_RECOVERY_TIME (Histogram): time spent OPEN before HALF_OPEN. + Each metric is updated at the point where the state itself changes. Breakers + that share a name write to the same labelled series, so keep names unique. """ def __init__(self, name: str, failure_threshold: int = 3, recovery_timeout: float = 30.0): self.name = name self.failure_threshold = failure_threshold self.recovery_timeout = recovery_timeout - + self.state = "CLOSED" # CLOSED, OPEN, HALF_OPEN self.failure_count = 0 self.last_state_change = time.time() self._lock = Lock() + # Publish the initial state so the gauge is always defined for + # every instantiated breaker, even before any traffic flows. + set_circuit_state(self.name, CIRCUIT_STATE_CLOSED) + def allow_request(self) -> bool: """ Check if a request is allowed to proceed. 
@@ -34,6 +54,9 @@ def allow_request(self) -> bool: now = time.time() if self.state == "OPEN": if now - self.last_state_change >= self.recovery_timeout: + # Capture recovery time BEFORE updating last_state_change, + # so the histogram reflects how long we were actually OPEN. + recovery_seconds = now - self.last_state_change logger.info( "Circuit breaker for provider '%s' transitioning from OPEN to HALF_OPEN " "(recovery timeout %ss elapsed)", @@ -42,6 +65,8 @@ def allow_request(self) -> bool: ) self.state = "HALF_OPEN" self.last_state_change = now + set_circuit_state(self.name, CIRCUIT_STATE_HALF_OPEN) + CIRCUIT_RECOVERY_TIME.labels(breaker_name=self.name).observe(recovery_seconds) return True return False return True @@ -62,6 +87,7 @@ def record_success(self) -> None: self.state = "CLOSED" self.failure_count = 0 self.last_state_change = now + set_circuit_state(self.name, CIRCUIT_STATE_CLOSED) elif self.state == "CLOSED": self.failure_count = 0 @@ -73,6 +99,7 @@ def record_failure(self) -> None: with self._lock: now = time.time() self.failure_count += 1 + CIRCUIT_FAILURE_COUNT.labels(breaker_name=self.name).inc() if self.state == "HALF_OPEN" or self.failure_count >= self.failure_threshold: logger.warning( "Circuit breaker for provider '%s' transitioning from %s to OPEN " @@ -84,3 +111,4 @@ def record_failure(self) -> None: ) self.state = "OPEN" self.last_state_change = now + set_circuit_state(self.name, CIRCUIT_STATE_OPEN) diff --git a/app/ai-service/tests/test_circuit_breaker.py b/app/ai-service/tests/test_circuit_breaker.py index dfb6d70..0de524b 100644 --- a/app/ai-service/tests/test_circuit_breaker.py +++ b/app/ai-service/tests/test_circuit_breaker.py @@ -7,39 +7,40 @@ from services.humanitarian_verification import HumanitarianVerificationService from exceptions import AIServiceError from config import settings +import metrics +from prometheus_client import REGISTRY def test_circuit_breaker_basic_transitions(): # Set a short recovery timeout for fast testing 
breaker = CircuitBreaker("test-provider", failure_threshold=2, recovery_timeout=0.1) - + # 1. Starts CLOSED assert breaker.state == "CLOSED" assert breaker.allow_request() is True - + # 2. First failure breaker.record_failure() assert breaker.state == "CLOSED" # Not tripped yet assert breaker.allow_request() is True - + # 3. Second failure (reaches threshold) breaker.record_failure() assert breaker.state == "OPEN" assert breaker.allow_request() is False # Tripped - + # 4. Wait for recovery timeout time.sleep(0.12) - + # 5. Transitions to HALF_OPEN on allow_request check assert breaker.allow_request() is True assert breaker.state == "HALF_OPEN" - + # 6. Success closes the circuit breaker.record_success() assert breaker.state == "CLOSED" assert breaker.failure_count == 0 - def test_circuit_breaker_half_open_failure(): breaker = CircuitBreaker("test-provider", failure_threshold=2, recovery_timeout=0.1) @@ -129,3 +130,111 @@ def test_request_timeout_raises_ai_timeout(self, mock_post, monkeypatch): # The breaker for openai should have recorded the failure assert self.service.breakers["openai"].failure_count == 2 # Primary & fallback attempts both failed + + +def _sample(name: str, labels: dict) -> "float | None": + """Return a sample's value from the Prometheus registry, or None if it is absent.""" + return REGISTRY.get_sample_value(name, labels) + + +class TestCircuitBreakerMetrics: + """Verify that CircuitBreaker publishes the metrics defined in metrics.py.""" + + def test_initial_state_is_published(self): + # Using a fresh breaker name ensures labels don't collide with other tests. 
+ CircuitBreaker("metrics-initial", failure_threshold=1, recovery_timeout=0.1) + assert ( + _sample("circuit_breaker_state", {"breaker_name": "metrics-initial"}) == 0 + ) + + def test_failure_increments_counter_and_trips_state(self): + breaker = CircuitBreaker( + "metrics-failures", failure_threshold=2, recovery_timeout=0.1 + ) + before = _sample( + "circuit_breaker_failure_count_total", {"breaker_name": "metrics-failures"} + ) or 0.0 + + breaker.record_failure() + after_one = _sample( + "circuit_breaker_failure_count_total", {"breaker_name": "metrics-failures"} + ) + assert after_one == before + 1 + # State stays CLOSED below threshold + assert _sample("circuit_breaker_state", {"breaker_name": "metrics-failures"}) == 0 + + breaker.record_failure() + # Threshold reached -> gauge flips to OPEN (2) + assert _sample("circuit_breaker_state", {"breaker_name": "metrics-failures"}) == 2 + + def test_recovery_updates_histogram_and_state_gauge(self): + breaker = CircuitBreaker( + "metrics-recovery", failure_threshold=1, recovery_timeout=0.05 + ) + breaker.record_failure() # trips immediately + assert _sample("circuit_breaker_state", {"breaker_name": "metrics-recovery"}) == 2 + + time.sleep(0.07) + # allow_request triggers the OPEN -> HALF_OPEN transition + assert breaker.allow_request() is True + + assert _sample("circuit_breaker_state", {"breaker_name": "metrics-recovery"}) == 1 + + sum_value = _sample( + "circuit_breaker_recovery_time_seconds_sum", + {"breaker_name": "metrics-recovery"}, + ) + count_value = _sample( + "circuit_breaker_recovery_time_seconds_count", + {"breaker_name": "metrics-recovery"}, + ) + assert count_value is not None and count_value >= 1 + assert sum_value is not None and sum_value >= 0.05 + + def test_success_closes_circuit_and_resets_state_gauge(self): + breaker = CircuitBreaker( + "metrics-success", failure_threshold=1, recovery_timeout=0.05 + ) + breaker.record_failure() + time.sleep(0.07) + breaker.allow_request() # -> HALF_OPEN + 
breaker.record_success() + + assert _sample("circuit_breaker_state", {"breaker_name": "metrics-success"}) == 0 + + def test_half_open_failure_increments_counter_and_reopens(self): + """A failure during the HALF_OPEN probe must be counted and reopen the + circuit. Otherwise a failed probe would leave the breaker HALF_OPEN and + requests would keep being allowed through to the failing dependency.""" + breaker = CircuitBreaker( + "metrics-half-open-failure", + failure_threshold=1, + recovery_timeout=0.05, + ) + breaker.record_failure() # CLOSED -> OPEN + time.sleep(0.07) + breaker.allow_request() # OPEN -> HALF_OPEN + + counter_before_probe = _sample( + "circuit_breaker_failure_count_total", + {"breaker_name": "metrics-half-open-failure"}, + ) + assert ( + _sample("circuit_breaker_state", {"breaker_name": "metrics-half-open-failure"}) + == 1 + ) + + breaker.record_failure() # HALF_OPEN -> OPEN + counter_after_probe = _sample( + "circuit_breaker_failure_count_total", + {"breaker_name": "metrics-half-open-failure"}, + ) + assert counter_after_probe is not None + assert counter_before_probe is not None + assert counter_after_probe == counter_before_probe + 1 + assert ( + _sample("circuit_breaker_state", {"breaker_name": "metrics-half-open-failure"}) + == 2 + ) + # request rejected because we just re-opened + assert breaker.allow_request() is False