-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathserver.py
More file actions
151 lines (116 loc) · 5.22 KB
/
server.py
File metadata and controls
151 lines (116 loc) · 5.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import os
import time
import logging
from contextlib import asynccontextmanager
from pathlib import Path
import uvicorn
from fastapi import Depends, FastAPI, Request
from fastapi.exceptions import RequestValidationError
import json
import typing
from fastapi.responses import JSONResponse
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from slowapi import Limiter
from slowapi.util import get_remote_address
from slowapi.errors import RateLimitExceeded
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.cors import CORSMiddleware
from redactor import PIIRedactor
from schemas import ErrorResponse, RedactRequest, RedactResponse
env_file = Path(__file__).parent / ".env"
if env_file.exists():
for line in env_file.read_text().splitlines():
line = line.strip()
if line and not line.startswith("#") and "=" in line:
key, _, value = line.partition("=")
os.environ.setdefault(key.strip(), value.strip())
def _env_flag(name: str) -> bool:
    """Return True when env var *name* equals "true", ignoring case."""
    return os.environ.get(name, "false").lower() == "true"


def _env_csv(name: str) -> list[str]:
    """Split env var *name* on commas, trimming items and dropping blanks."""
    pieces = (piece.strip() for piece in os.environ.get(name, "").split(","))
    return [piece for piece in pieces if piece]


# Runtime configuration, read once from the environment at import time.
DEV_MODE = _env_flag("DEV_MODE")
AUTH_KEYS = _env_csv("AUTH_KEYS")
RATE_LIMIT = os.environ.get("RATE_LIMIT", "60/minute")
CORS_ORIGINS = _env_csv("CORS_ORIGINS")
PRETTY_JSON = _env_flag("PRETTY_JSON")
# Root logging: DEBUG when DEV_MODE is on, INFO otherwise; time-only stamps.
_log_level = logging.DEBUG if DEV_MODE else logging.INFO
logging.basicConfig(
    level=_log_level,
    format="%(asctime)s | %(levelname)-8s | %(message)s",
    datefmt="%H:%M:%S",
)
logger = logging.getLogger("privacy-filter")

# Raise the threshold of the HTTP client loggers to WARNING.
for _noisy_name in ("httpx", "httpcore"):
    logging.getLogger(_noisy_name).setLevel(logging.WARNING)
# Annotation only: the name is bound by lifespan() when the app starts.
redactor: PIIRedactor

# auto_error=False so verify_auth decides how to reject; it accepts None
# for the credentials and raises AuthError itself.
security = HTTPBearer(auto_error=False)

# Rate limiter keyed by get_remote_address, with RATE_LIMIT as the default limit.
limiter = Limiter(
    key_func=get_remote_address,
    default_limits=[RATE_LIMIT],
)
class RequestLoggerMiddleware(BaseHTTPMiddleware):
    """Log method, path, status code and elapsed time for each request."""

    async def dispatch(self, request: Request, call_next):
        """Pass *request* downstream, log one summary line, return the response.

        The log line is written only after ``call_next`` returns; if it
        raises, nothing is logged here.
        """
        began = time.perf_counter()
        response = await call_next(request)
        duration_ms = (time.perf_counter() - began) * 1000

        summary = (
            request.method,
            request.url.path,
            response.status_code,
            duration_ms,
        )
        logger.info("%s %s -> %d (%.1fms)", *summary)
        return response
class AuthError(Exception):
    """Raised when a request fails API-key authentication.

    Attributes:
        status_code: HTTP status code to send back to the client.
        detail: Human-readable explanation of the failure.
    """

    def __init__(self, status_code: int, detail: str):
        # Initialise the base class so str(exc), repr(exc) and exc.args
        # carry the detail instead of being empty in logs and tracebacks.
        super().__init__(detail)
        self.status_code = status_code
        self.detail = detail
async def verify_auth(credentials: HTTPAuthorizationCredentials | None = Depends(security)):
    """Dependency that enforces bearer-token authentication.

    Authentication is disabled when AUTH_KEYS is empty. Otherwise the bearer
    token must equal one of the configured keys.

    Args:
        credentials: Parsed ``Authorization`` header, or ``None``.

    Raises:
        AuthError: With status 401 when the token is missing or matches no key.
    """
    import hmac  # local import keeps this block self-contained

    if not AUTH_KEYS:
        return
    if credentials is None:
        raise AuthError(status_code=401, detail="Invalid or missing API key")

    # Compare as bytes: compare_digest rejects non-ASCII str arguments.
    presented = credentials.credentials.encode("utf-8")
    matched = False
    for key in AUTH_KEYS:
        # Constant-time comparison, and no early exit, so response timing
        # does not reveal how much of a key was guessed or which key matched.
        matched |= hmac.compare_digest(presented, key.encode("utf-8"))
    if not matched:
        raise AuthError(status_code=401, detail="Invalid or missing API key")
async def auth_error_handler(_request: Request, exc: AuthError):
    """Convert an AuthError into a JSON error response.

    Args:
        _request: The failing request (unused).
        exc: The raised AuthError; its status_code and detail are echoed back.

    Returns:
        A JSONResponse with ``error="auth_error"`` and the exception's detail.
    """
    from fastapi.encoders import jsonable_encoder

    body = ErrorResponse(error="auth_error", detail=exc.detail)
    # JSONResponse serialises with json.dumps, which cannot handle a model
    # object. ErrorResponse lives in schemas (not visible here) and is
    # presumably a Pydantic model; jsonable_encoder turns it into plain
    # JSON-compatible data and leaves an already-plain dict unchanged.
    return JSONResponse(status_code=exc.status_code, content=jsonable_encoder(body))
async def rate_limit_handler(_request: Request, _exc: RateLimitExceeded):
    """Convert a RateLimitExceeded error into a 429 JSON error response.

    Args:
        _request: The throttled request (unused).
        _exc: The rate-limit exception (unused).

    Returns:
        A JSONResponse with status 429 and ``error="rate_limit"``.
    """
    from fastapi.encoders import jsonable_encoder

    body = ErrorResponse(error="rate_limit", detail="Too many requests")
    # JSONResponse serialises with json.dumps, which cannot handle a model
    # object. ErrorResponse lives in schemas (not visible here) and is
    # presumably a Pydantic model; jsonable_encoder turns it into plain
    # JSON-compatible data and leaves an already-plain dict unchanged.
    return JSONResponse(status_code=429, content=jsonable_encoder(body))
async def validation_error_handler(_request: Request, exc: RequestValidationError):
    """Convert a request validation failure into a 422 JSON error response.

    Args:
        _request: The invalid request (unused).
        exc: The validation error; its string form becomes the detail.

    Returns:
        A JSONResponse with status 422 and ``error="validation_error"``.
    """
    from fastapi.encoders import jsonable_encoder

    body = ErrorResponse(error="validation_error", detail=str(exc))
    # JSONResponse serialises with json.dumps, which cannot handle a model
    # object. ErrorResponse lives in schemas (not visible here) and is
    # presumably a Pydantic model; jsonable_encoder turns it into plain
    # JSON-compatible data and leaves an already-plain dict unchanged.
    return JSONResponse(status_code=422, content=jsonable_encoder(body))
async def unhandled_error_handler(request: Request, exc: Exception):
    """Log an unexpected exception and answer with a generic 500 JSON body.

    Args:
        request: The request being processed when the error occurred.
        exc: The exception that no other handler claimed.

    Returns:
        A JSONResponse with status 500 and ``error="internal_error"``; the
        exception text is logged but not sent to the client.
    """
    from fastapi.encoders import jsonable_encoder

    # exc_info=exc attaches the traceback; this handler is not running
    # inside an ``except`` block, so it has to be passed explicitly.
    logger.error(
        "Unhandled error on %s %s: %s",
        request.method,
        request.url.path,
        exc,
        exc_info=exc,
    )
    body = ErrorResponse(error="internal_error")
    # JSONResponse serialises with json.dumps, which cannot handle a model
    # object. ErrorResponse lives in schemas (not visible here) and is
    # presumably a Pydantic model; jsonable_encoder turns it into plain
    # JSON-compatible data and leaves an already-plain dict unchanged.
    return JSONResponse(status_code=500, content=jsonable_encoder(body))
@asynccontextmanager
async def lifespan(_app: FastAPI):
    """Application lifespan hook: build the module-level redactor at startup.

    The device is read from the ``DEVICE`` environment variable, defaulting
    to ``"cpu"``. Nothing runs after the ``yield``.
    """
    global redactor

    target_device = os.environ.get("DEVICE", "cpu")
    redactor = PIIRedactor(device=target_device)
    logger.info("Model loaded on %s", target_device)

    yield
class PrettyJSONResponse(JSONResponse):
    """JSONResponse variant that emits 2-space-indented JSON."""

    def render(self, content: typing.Any) -> bytes:
        """Serialise *content* to indented UTF-8 JSON bytes.

        Non-ASCII characters are written as-is, and NaN/Infinity values are
        rejected by ``json.dumps`` because ``allow_nan`` is off.
        """
        rendered = json.dumps(
            content,
            indent=2,
            ensure_ascii=False,
            allow_nan=False,
        )
        return rendered.encode("utf-8")
# Indented JSON responses are opt-in via PRETTY_JSON.
_response_class = PrettyJSONResponse if PRETTY_JSON else JSONResponse
app = FastAPI(
    title="Privacy Filter Proxy",
    lifespan=lifespan,
    default_response_class=_response_class,
)
app.state.limiter = limiter

# Exception type -> handler, registered in this order.
for _exc_type, _exc_handler in (
    (AuthError, auth_error_handler),
    (RateLimitExceeded, rate_limit_handler),
    (RequestValidationError, validation_error_handler),
    (Exception, unhandled_error_handler),
):
    app.add_exception_handler(_exc_type, _exc_handler)

app.add_middleware(RequestLoggerMiddleware)

# CORS is enabled only when at least one origin is configured.
if CORS_ORIGINS:
    _cors_options = {
        "allow_origins": CORS_ORIGINS,
        "allow_credentials": True,
        "allow_methods": ["GET", "POST"],
        "allow_headers": ["Authorization", "Content-Type"],
    }
    app.add_middleware(CORSMiddleware, **_cors_options)
@app.get("/health")
async def health():
    """Liveness probe: always reports ``{"status": "ok"}``."""
    payload = {"status": "ok"}
    return payload
@app.post("/redact", response_model=RedactResponse)
@limiter.limit(RATE_LIMIT)
async def redact(request: Request, req: RedactRequest, _=Depends(verify_auth)):
    """Redact the text of *req* and return the redactor's result.

    Args:
        request: Raw request object; not read in the body, but kept in the
            signature (presumably required by the rate-limit decorator).
        req: Parsed request body; only ``req.text`` is read.
        _: Placeholder for the verify_auth dependency, which returns nothing.

    Returns:
        Whatever ``redactor.redact`` returns, shaped by ``RedactResponse``.
    """
    text = req.text
    logger.debug("Input text length: %d chars", len(text))

    # NOTE(review): this is a synchronous call inside an async handler. If
    # redaction is CPU-heavy it will block the event loop; consider a
    # thread pool once the redactor's thread-safety is confirmed.
    result = redactor.redact(text)

    summary = result.summary
    logger.info("Detected %d spans: %s", summary.total_spans, summary.by_label)
    return result
if __name__ == "__main__":
    # Direct execution: serve the app with uvicorn. HOST and PORT come from
    # the environment, and the log level follows DEV_MODE.
    uvicorn.run(
        app,
        host=os.environ.get("HOST", "0.0.0.0"),
        port=int(os.environ.get("PORT", "8000")),
        log_level="debug" if DEV_MODE else "info",
    )