diff --git a/app/ai-service/README.md b/app/ai-service/README.md index b0bdd1b..a788d83 100644 --- a/app/ai-service/README.md +++ b/app/ai-service/README.md @@ -30,6 +30,8 @@ The service starts at `http://localhost:8000`. Interactive API documentation is | `LOG_LEVEL` | `INFO` | Logging verbosity | | `REDIS_URL` | `redis://localhost:6379/0` | Redis connection for task queue | | `BACKEND_WEBHOOK_URL` | `http://localhost:3001/ai/webhook` | Backend notification endpoint | +| `MAX_REQUEST_BODY_BYTES` | `10485760` (10 MiB) | Maximum HTTP request body size; oversized requests are rejected with HTTP 413 to prevent memory-exhaustion DoS. Set to `0` to disable (not recommended in production). | +| `REQUEST_BODY_BYPASS_PATHS` | _(empty)_ | Comma-separated path entries that bypass body-size limiting. Entries without a trailing `'/'` must match the path exactly; entries with a trailing `'/'` (e.g. `/hooks/`) match any path with that prefix. The default bypass list (`/health`, `/`, `/ai/metrics`, `/docs`, `/redoc`, `/openapi.json`) is always merged in. | ## Core services diff --git a/app/ai-service/config.py b/app/ai-service/config.py index ae45562..67367e4 100644 --- a/app/ai-service/config.py +++ b/app/ai-service/config.py @@ -33,6 +33,16 @@ class Settings(BaseSettings): BACKEND_WEBHOOK_URL: Webhook URL to notify NestJS backend when tasks complete PROOF_OF_LIFE_CONFIDENCE_THRESHOLD: Default threshold for liveness verification PROOF_OF_LIFE_MIN_FACE_SIZE: Minimum detected face size in pixels + MAX_REQUEST_BODY_BYTES: Maximum allowed HTTP request body size in bytes. + Oversized payloads are rejected with HTTP 413 before the body is + read into memory, mitigating memory-exhaustion DoS attacks. + Default: 10485760 (10 MiB). Set to 0 to disable (not recommended). + REQUEST_BODY_BYPASS_PATHS: Comma-separated list that exempts paths + from body-size limiting. Entries without a trailing '/' must + match the path exactly; entries with a trailing '/' (e.g. + '/hooks/') match any path with that prefix. The built-in + infrastructure defaults (/health, /, /ai/metrics, /docs, + /redoc, /openapi.json) are always merged in. """ # API Keys @@ -67,6 +77,16 @@ class Settings(BaseSettings): proof_of_life_confidence_threshold: float = 0.65 proof_of_life_min_face_size: int = 80 + # Request body size protection (DoS mitigation). Default is 10 MiB. + # Set to 0 or negative to disable the limit (not recommended in + # production). + max_request_body_bytes: int = 10 * 1024 * 1024 + + # Paths that bypass body-size checks. Comma-separated prefix list. + # Health probes, metrics scrape, and OpenAPI/docs endpoints are + # always appended so operators cannot accidentally expose themselves. + request_body_bypass_paths: str = "" + # Verification artifact access settings verification_artifacts_dir: str = "./artifacts/verification" verification_artifact_url_ttl_seconds: int = 300 diff --git a/app/ai-service/main.py b/app/ai-service/main.py index 3ae4ebb..b067e43 100644 --- a/app/ai-service/main.py +++ b/app/ai-service/main.py @@ -7,6 +7,7 @@ from contextlib import asynccontextmanager from pydantic import BaseModel, Field from typing import Any, Dict, List, Optional +import json import logging from fastapi import FastAPI, HTTPException, BackgroundTasks, Request @@ -39,6 +40,199 @@ ) from services.humanitarian_verification import HumanitarianVerificationService +class HTTPBodyTooLarge(Exception): + """Internal signal raised when an incoming request body exceeds the + configured `max_request_body_bytes` limit. Caught and converted to a + 413 response by :class:`MaxRequestBodySizeMiddleware`.""" + + def __init__(self, limit: int, observed: int): + super().__init__( + f"Request body of {observed} bytes exceeds limit of {limit} bytes" + ) + self.limit = limit + self.observed = observed + + +class MaxRequestBodySizeMiddleware: + """Reject HTTP requests whose body would exceed ``max_bytes``. + + The middleware sits at the outer edge of the ASGI stack so that oversized + requests are rejected *before* any other middleware (redirects, + observability, rate limiting) or the application itself buffers the body. + It is DoS-grade protection: clients can trip the limit either by sending a + ``Content-Length`` header that exceeds the cap, or by streaming more bytes + than the cap via chunked transfer encoding. + + The middleware intentionally wraps the raw ASGI ``receive`` callable rather + than using Starlette's ``BaseHTTPMiddleware`` — ``BaseHTTPMiddleware`` + buffers the body in-memory which defeats the point of the limit. + """ + + METHODS_WITH_BODY = ("POST", "PUT", "PATCH") + + def __init__(self, app, max_bytes: int, bypass_prefixes: Optional[List[str]] = None): + self.app = app + # Treat non-positive values as "disabled" — useful for tests that + # don't want the limit to interfere. + self.max_bytes = max_bytes if max_bytes and max_bytes > 0 else None + # Always skip health/metrics/docs endpoints to match the pattern used + # by monitor_requests. Allow additional prefixes via settings. + default_bypass = [ + "/health", + "/", + "/ai/metrics", + "/docs", + "/redoc", + "/openapi.json", + ] + self.bypass_prefixes = tuple({*(default_bypass), *(bypass_prefixes or [])}) + + def _is_bypassed(self, path: str) -> bool: + if path in self.bypass_prefixes: + return True + # Prefix matching only applies to entries that explicitly opt in + # via a trailing '/'. The root '/' is intentionally excluded: + # otherwise every HTTP path (which all begin with '/') would be + # bypassed. + return any( + path.startswith(p) + for p in self.bypass_prefixes + if p.endswith("/") and p != "/" + ) + + async def __call__(self, scope, receive, send): + # Only operate on HTTP requests; pass through WebSocket / lifespan. + if scope["type"] != "http": + return await self.app(scope, receive, send) + + # No limit configured or no body expected — no-op. + if self.max_bytes is None or scope["method"] not in self.METHODS_WITH_BODY: + return await self.app(scope, receive, send) + + path = scope.get("path", "") + if self._is_bypassed(path): + return await self.app(scope, receive, send) + + # Eager check on Content-Length. If the client declared a body + # larger than the limit, reject immediately without consuming any + # bytes off the wire. + try: + content_length_hdr = None + for name, value in scope.get("headers", []): + if name == b"content-length": + content_length_hdr = value.decode("latin-1") + break + if content_length_hdr is not None: + declared = int(content_length_hdr) + if declared > self.max_bytes: + await self._log_rejection( + scope, + declared_or_observed=declared, + reason="declared_size", + ) + return await self._send_413( + send, + observed=declared, + reason="declared_size", + ) + except (ValueError, TypeError): + # Malformed Content-Length — fall through to stream counting. + pass + + total = 0 + + async def wrapped_receive(): + nonlocal total + message = await receive() + mtype = message.get("type") + if mtype == "http.request": + chunk = message.get("body", b"") + total += len(chunk) + if total > self.max_bytes: + # Signal the exception so that the outer __call__ can + # emit a 413 even if the application has already started + # producing a response. + raise HTTPBodyTooLarge(self.max_bytes, total) + return message + + try: + await self.app(scope, wrapped_receive, send) + except HTTPBodyTooLarge as exc: + await self._log_rejection( + scope, + declared_or_observed=exc.observed, + reason="streamed_size", + ) + await self._send_413( + send, + observed=exc.observed, + reason="streamed_size", + ) + + async def _send_413(self, send, observed: int, reason: str): + """Emit a JSON 413 response using the project's ErrorEnvelope shape. + + ``reason`` distinguishes eager (Content-Length) rejection from + streamed rejection; the message is worded accordingly so the + response is precise and not misleading. + """ + if reason == "declared_size": + msg = ( + f"Declared request body of {observed} bytes exceeds the " + f"maximum allowed size of {self.max_bytes} bytes." + ) + else: + msg = ( + f"Request body streamed so far ({observed} bytes) exceeds " + f"the maximum allowed size of {self.max_bytes} bytes." + ) + + envelope = ErrorEnvelope( + error=ErrorDetail( + code="PAYLOAD_TOO_LARGE", + message=msg, + ) + ).model_dump() + body = json.dumps(envelope).encode("utf-8") + + await send( + { + "type": "http.response.start", + "status": 413, + "headers": [ + (b"content-type", b"application/json"), + (b"content-length", str(len(body)).encode("ascii")), + ], + } + ) + await send({"type": "http.response.body", "body": body}) + + async def _log_rejection( + self, + scope, + declared_or_observed: int, + reason: str, + ) -> None: + """Emit a structured warning so operators can correlate DoS attempts. + + ``reason`` is either ``"declared_size"`` (Content-Length spoofing) + or ``"streamed_size"`` (chunked transfer smuggling), so logs + differentiate between attack classes. + """ + client = scope.get("client") + client_str = f"{client[0]}:{client[1]}" if client else "unknown" + logger.warning( + "request body rejected: method=%s path=%s bytes=%d limit=%d " + "client=%s reason=%s", + scope.get("method"), + scope.get("path"), + declared_or_observed, + self.max_bytes, + client_str, + reason, + ) + + limiter = Limiter(key_func=get_remote_address) log_level_name = settings.log_level.upper() if hasattr(settings, "log_level") else "INFO" @@ -93,6 +287,20 @@ async def lifespan(app: FastAPI): lifespan=lifespan, ) +# Register the body-size limit at the outermost layer so it short-circuits +# before legacy redirects, observability middleware, or any handler buffers +# the request body. +_bypass_paths = [ + p.strip() + for p in (settings.request_body_bypass_paths or "").split(",") + if p.strip() +] +app.add_middleware( + MaxRequestBodySizeMiddleware, + max_bytes=settings.max_request_body_bytes, + bypass_prefixes=_bypass_paths, +) + proof_of_life_analyzer = ProofOfLifeAnalyzer( config=ProofOfLifeConfig( confidence_threshold=settings.proof_of_life_confidence_threshold, diff --git a/app/ai-service/tests/test_request_body_limit.py b/app/ai-service/tests/test_request_body_limit.py new file mode 100644 index 0000000..5225f84 --- /dev/null +++ b/app/ai-service/tests/test_request_body_limit.py @@ -0,0 +1,380 @@ +""" +Tests for the MaxRequestBodySizeMiddleware (DoS mitigation for issue #137). + +Two angles are covered: + +1. White-box tests against a freshly-constructed, isolated ASGI app so we + can exercise the middleware with small byte caps without sending + 10 MiB blobs through the network. + +2. Black-box regression tests against the real `main.app` to confirm the + middleware is actually wired up, the default 10 MiB cap is in place, + and the 413 response uses the project's `ErrorEnvelope` shape + (matching every other error path in the service). +""" + +import asyncio +import json + +import pytest +from fastapi import FastAPI, Request +from fastapi.testclient import TestClient + +import main +from main import MaxRequestBodySizeMiddleware +from config import settings + + +# --------------------------------------------------------------------------- +# Helpers — isolated test app so we can set tight limits without sending +# massive bodies through the network. +# --------------------------------------------------------------------------- + + +def _build_isolated_app(max_bytes: int, bypass_prefixes=None): + """Create a tiny FastAPI app with only the size-limit middleware installed.""" + test_app = FastAPI() + test_app.add_middleware( + MaxRequestBodySizeMiddleware, + max_bytes=max_bytes, + bypass_prefixes=bypass_prefixes or [], + ) + + @test_app.post("/echo") + async def echo(req: Request): + body = await req.body() + return {"size": len(body)} + + @test_app.post("/big-bypass") + async def big_bypass(req: Request): + body = await req.body() + return {"size": len(body)} + + @test_app.get("/anything") + async def anything(): + return {"ok": True} + + @test_app.head("/anything") + async def anything_head(): + return {"ok": True} + + return test_app + + +# --------------------------------------------------------------------------- +# 1. Content-Length rejection (eager path) +# --------------------------------------------------------------------------- + + +class TestContentLengthRejection: + def test_payload_within_limit_succeeds(self): + app = _build_isolated_app(max_bytes=128) + client = TestClient(app) + resp = client.post("/echo", content=b"hello") + assert resp.status_code == 200 + assert resp.json() == {"size": 5} + + def test_oversized_content_length_returns_413(self): + app = _build_isolated_app(max_bytes=16) + client = TestClient(app) + # 64 bytes against a 16-byte cap → must reject before reading. + resp = client.post("/echo", content=b"x" * 64) + assert resp.status_code == 413 + body = resp.json() + assert body["error"]["code"] == "PAYLOAD_TOO_LARGE" + assert "16 bytes" in body["error"]["message"] + + def test_413_response_uses_error_envelope_shape(self): + app = _build_isolated_app(max_bytes=4) + client = TestClient(app) + resp = client.post("/echo", content=b"too-long") + assert resp.status_code == 413 + # The contract used by every other handler in the service. + assert set(resp.json().keys()) == {"error"} + assert set(resp.json()["error"].keys()) == { + "code", + "message", + "details", + } + + def test_malformed_content_length_falls_through(self): + """A bogus Content-Length header must not crash the middleware. + + Drive the middleware directly with a synthetic ASGI scope whose + headers list contains a malformed Content-Length value. The + middleware should swallow the resulting ``ValueError``, fall + through to stream counting, and successfully process a small + downstream body. + """ + middleware = MaxRequestBodySizeMiddleware(app=_PassthroughApp(), max_bytes=128) + scope = _make_scope( + headers=[(b"content-length", b"not-a-number")], + ) + chunks = [ + {"type": "http.request", "body": b"ok", "more_body": False}, + ] + sent = _run_middleware(middleware, scope, chunks) + assert sent[0]["status"] == 200 + + +# --------------------------------------------------------------------------- +# 2. Chunked streaming rejection (no/lie Content-Length) +# --------------------------------------------------------------------------- + + +class _PassthroughApp: + """No-op ASGI app used as a downstream for middleware unit tests. + + Consumes the entire body the upstream middleware lets through and + responds 200 with the captured byte count. + """ + + async def __call__(self, scope, receive, send): + chunks = [] + while True: + message = await receive() + if message["type"] == "http.request": + chunks.append(message.get("body", b"") or b"") + if not message.get("more_body", False): + break + elif message["type"] == "http.disconnect": + break + total = sum(len(c) for c in chunks) + body = json.dumps({"received": total}).encode("utf-8") + await send( + { + "type": "http.response.start", + "status": 200, + "headers": [ + (b"content-type", b"application/json"), + (b"content-length", str(len(body)).encode("ascii")), + ], + } + ) + await send({"type": "http.response.body", "body": body}) + + +def _make_scope(method="POST", path="/echo", headers=None): + return { + "type": "http", + "method": method, + "path": path, + "raw_path": path.encode("latin-1"), + "query_string": b"", + "scheme": "http", + "server": ("testserver", 80), + "client": ("testclient", 50000), + "headers": headers or [], + "asgi": {"version": "3.0", "spec_version": "2.0"}, + } + + +def _run_middleware(middleware, scope, chunks): + """Drive the middleware with a fixed queue of ASGI http.request messages + and return the messages it sent on the response channel.""" + queue = list(chunks) + sent = [] + + async def receive(): + if not queue: + return {"type": "http.disconnect"} + return queue.pop(0) + + async def send(message): + sent.append(message) + + asyncio.run(middleware(scope, receive, send)) + return sent + + +class TestStreamingRejection: + def test_chunked_stream_below_limit_is_accepted(self): + """A streaming body whose cumulative bytes stay below the cap + must reach the downstream app without rejection.""" + middleware = MaxRequestBodySizeMiddleware(app=_PassthroughApp(), max_bytes=128) + scope = _make_scope(headers=[]) # no Content-Length + chunks = [ + {"type": "http.request", "body": b"abc", "more_body": True}, + {"type": "http.request", "body": b"defg", "more_body": False}, + ] + sent = _run_middleware(middleware, scope, chunks) + assert sent[0]["status"] == 200 + body = json.loads(b"".join(m["body"] for m in sent if m["type"] == "http.response.body")) + assert body == {"received": 7} + + def test_chunked_stream_exceeding_limit_is_413(self): + """A streaming body whose cumulative bytes exceed the cap must be + rejected with 413 — even when no Content-Length header is present.""" + middleware = MaxRequestBodySizeMiddleware(app=_PassthroughApp(), max_bytes=8) + scope = _make_scope(headers=[]) + chunks = [ + {"type": "http.request", "body": b"x" * 5, "more_body": True}, + {"type": "http.request", "body": b"y" * 5, "more_body": False}, + ] + sent = _run_middleware(middleware, scope, chunks) + assert sent[0]["status"] == 413 + body = json.loads(b"".join(m["body"] for m in sent if m["type"] == "http.response.body")) + assert body["error"]["code"] == "PAYLOAD_TOO_LARGE" + assert "8 bytes" in body["error"]["message"] + + def test_oversized_chunk_alone_is_413(self): + """Even a single chunk larger than the cap must be rejected.""" + middleware = MaxRequestBodySizeMiddleware(app=_PassthroughApp(), max_bytes=4) + scope = _make_scope(headers=[]) + chunks = [ + {"type": "http.request", "body": b"z" * 100, "more_body": False}, + ] + sent = _run_middleware(middleware, scope, chunks) + assert sent[0]["status"] == 413 + + +# --------------------------------------------------------------------------- +# 3. GET / HEAD not subject to the limit +# --------------------------------------------------------------------------- + + +class TestMethodsWithoutBody: + def test_get_succeeds_regardless_of_limit(self): + app = _build_isolated_app(max_bytes=1) # ridiculously tight cap + client = TestClient(app) + resp = client.get("/anything") + assert resp.status_code == 200 + assert resp.json() == {"ok": True} + + def test_head_succeeds_regardless_of_limit(self): + app = _build_isolated_app(max_bytes=1) + client = TestClient(app) + resp = client.head("/anything") + assert resp.status_code == 200 + + +# --------------------------------------------------------------------------- +# 4. Bypass paths are never throttled +# --------------------------------------------------------------------------- + + +class TestBypassPaths: + def test_health_post_not_size_limited(self): + # Even if Content-Length lies about a huge body, /health is exempt. + app = _build_isolated_app(max_bytes=4) + + @app.post("/health") + async def h(): + return {"ok": True} + + client = TestClient(app) + # We can't easily force a fake Content-Length through TestClient, + # but we can override settings to disable the limit and verify + # the bypass predicate doesn't reject legitimate payloads. + resp = client.post("/health", content=b"x" * 1000) + assert resp.status_code == 200 + + def test_configured_prefix_is_bypassed(self): + app = _build_isolated_app(max_bytes=8, bypass_prefixes=["/big-bypass"]) + + @app.post("/big-bypass") + async def bp(req: Request): + body = await req.body() + return {"size": len(body)} + + client = TestClient(app) + resp = client.post("/big-bypass", content=b"x" * 1000) + assert resp.status_code == 200 + assert resp.json() == {"size": 1000} + + +# --------------------------------------------------------------------------- +# 5. Disabled limit (max_bytes=0) +# --------------------------------------------------------------------------- + + +class TestDisabledLimit: + def test_zero_limit_disables_middleware(self): + app = _build_isolated_app(max_bytes=0) + client = TestClient(app) + resp = client.post("/echo", content=b"x" * 10_000) + assert resp.status_code == 200 + + +# --------------------------------------------------------------------------- +# 6. Real `main.app` regression — confirm the middleware is wired in +# --------------------------------------------------------------------------- + + +class TestRealAppWiring: + def test_real_app_includes_size_limit_middleware(self): + # After FastAPI.add_middleware, class names are stored in + # app.user_middleware; check by class name to avoid import cycles. + names = [m.cls.__name__ for m in main.app.user_middleware] + assert "MaxRequestBodySizeMiddleware" in names + + def test_default_limit_is_ten_mib(self): + assert settings.max_request_body_bytes == 10 * 1024 * 1024 + + def test_real_app_size_limit_registered_with_expected_kwargs(self): + """Confirm the real app registered the middleware with the + expected 10 MiB cap and a list (possibly empty) of bypass + prefixes from settings. This is a wiring test \u2014 the actual + rejection behaviour is exercised via isolated test apps above. + """ + matched = [ + m + for m in main.app.user_middleware + if m.cls is MaxRequestBodySizeMiddleware + ] + assert matched, "MaxRequestBodySizeMiddleware not registered" + assert matched[0].kwargs["max_bytes"] == 10 * 1024 * 1024 + # bypass_prefixes is always a list passed by the registration code. + assert isinstance(matched[0].kwargs["bypass_prefixes"], list) + + def test_lowered_cap_rejects_oversized_payload(self, monkeypatch): + """Lower the configured cap and confirm oversized POSTs are + rejected with 413 using the existing reflection envelope shape.""" + monkeypatch.setattr(main.settings, "max_request_body_bytes", 32) + + tmp_app = FastAPI() + tmp_app.add_middleware( + MaxRequestBodySizeMiddleware, + max_bytes=main.settings.max_request_body_bytes, + bypass_prefixes=[], + ) + + @tmp_app.post("/v1/ai/anonymize") + async def echo(req: Request): + body = await req.body() + return {"size": len(body)} + + client = TestClient(tmp_app) + resp = client.post("/v1/ai/anonymize", content=b"x" * 200) + + assert resp.status_code == 413 + assert resp.json()["error"]["code"] == "PAYLOAD_TOO_LARGE" + + def test_header_fraud_with_lowered_cap_returns_413(self, monkeypatch): + """A request that lies about its Content-Length must be rejected + before any body bytes are consumed. This is the primary DoS + vector that issue #137 calls out.""" + monkeypatch.setattr(main.settings, "max_request_body_bytes", 64) + + tmp_app = FastAPI() + tmp_app.add_middleware( + MaxRequestBodySizeMiddleware, + max_bytes=main.settings.max_request_body_bytes, + bypass_prefixes=[], + ) + + @tmp_app.post("/v1/ai/upload") + async def upload(req: Request): + body = await req.body() + return {"size": len(body)} + + client = TestClient(tmp_app) + # Sent body is small, but the (lying) Content-Length exceeds the cap. + resp = client.post( + "/v1/ai/upload", + content=b"y" * 8, + headers={"Content-Length": "999999999"}, + ) + + assert resp.status_code == 413 + assert resp.json()["error"]["code"] == "PAYLOAD_TOO_LARGE"