-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapi_server.py
More file actions
153 lines (121 loc) · 4.87 KB
/
api_server.py
File metadata and controls
153 lines (121 loc) · 4.87 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
"""Network Anomaly Detection REST API Service

Provides model inference, health check, and batch prediction API endpoints.
"""
from __future__ import annotations

import os
from pathlib import Path

import pandas as pd
from flask import Flask, jsonify, request

from src.config import DEFAULT_API_PORT, DEFAULT_MODEL_PATH, FLOW_FEATURE_COLUMNS
from src.feature_extractor import FeatureExtractor
from src.model import AnomalyDetector
class ModelUnavailableError(RuntimeError):
    """Signals that the anomaly-detection model could not be loaded or is not ready."""
def create_app(model_path: str | Path | None = None) -> Flask:
    """Create and configure the Flask API application.

    Args:
        model_path: Optional path to the serialized model. When omitted, the
            MODEL_PATH environment variable is consulted, falling back to
            DEFAULT_MODEL_PATH.

    Returns:
        A configured Flask app exposing GET /api/v1/health,
        POST /api/v1/predict, and GET /api/v1/model/info.
    """
    app = Flask(__name__)
    # Resolve the model location once at app-creation time; precedence is
    # explicit argument > MODEL_PATH env var > project default.
    resolved_model_path = Path(
        model_path or os.environ.get("MODEL_PATH", str(DEFAULT_MODEL_PATH))
    ).resolve()
    detector = AnomalyDetector()
    extractor = FeatureExtractor()
    # Mutable closure state: the model is loaded lazily on first request, so
    # app creation itself never fails even if the model file is absent.
    state = {"model_loaded": False}
    app.config["MODEL_PATH"] = str(resolved_model_path)

    def ensure_model():
        """Load the model if needed and surface a structured readiness error."""
        if not state["model_loaded"]:
            try:
                detector.load_model(str(resolved_model_path))
            except Exception as exc:
                # Wrap any loader failure in the domain error so every
                # endpoint can map it uniformly to a 503 response.
                raise ModelUnavailableError(str(exc)) from exc
            state["model_loaded"] = True

    def model_unavailable_response(error_message: str):
        """Return a consistent 503 JSON response when the model is not ready."""
        return (
            jsonify(
                {
                    "status": "unhealthy",
                    "ready": False,
                    "error": "Model unavailable",
                    "details": error_message,
                    "model_path": str(resolved_model_path),
                }
            ),
            503,
        )

    @app.route("/api/v1/health", methods=["GET"])
    def health():
        """Readiness check endpoint: 200 when the model is loaded, 503 otherwise."""
        try:
            ensure_model()
        except ModelUnavailableError as exc:
            return model_unavailable_response(str(exc))
        return jsonify(
            {
                "status": "healthy",
                "ready": True,
                "model_loaded": state["model_loaded"],
                "model_path": str(resolved_model_path),
            }
        )

    @app.route("/api/v1/predict", methods=["POST"])
    def predict():
        """Perform anomaly detection on a batch of flow records.

        Expects a JSON body of the form {"records": [...]} with 1–1000 items.
        Returns per-record labels and, when the model supports predict_proba,
        an anomaly probability rounded to 4 decimal places.
        """
        try:
            ensure_model()
        except ModelUnavailableError as exc:
            return model_unavailable_response(str(exc))
        # force=True parses the body as JSON regardless of Content-Type;
        # malformed JSON is rejected by Flask with a generic 400.
        data = request.get_json(force=True)
        if not data or "records" not in data:
            return jsonify({"error": "Request body must contain a 'records' array"}), 400
        records = data["records"]
        if not isinstance(records, list) or len(records) == 0:
            return jsonify({"error": "'records' must be a non-empty array"}), 400
        if len(records) > 1000:
            # Hard cap to bound per-request memory use and latency.
            return jsonify({"error": "Maximum 1000 records per request"}), 400
        df = pd.DataFrame(records)
        features = extractor.extract_flow_features(df)
        # Reject the whole batch when any model input column is still missing
        # after feature extraction.
        missing_cols = [column for column in FLOW_FEATURE_COLUMNS if column not in features.columns]
        if missing_cols:
            return jsonify({"error": f"Missing required fields: {missing_cols}"}), 400
        X = features[FLOW_FEATURE_COLUMNS].fillna(0)
        predictions = detector.predict(X).tolist()
        results = []
        # Probabilities exist only for models exposing predict_proba; column 1
        # is taken as the positive (anomaly) class probability.
        has_proba = hasattr(detector.model, "predict_proba")
        probabilities = detector.predict_proba(X)[:, 1].tolist() if has_proba else None
        for index, prediction in enumerate(predictions):
            result = {
                "index": index,
                "prediction": "anomaly" if prediction == 1 else "normal",
            }
            if probabilities is not None:
                result["anomaly_probability"] = round(probabilities[index], 4)
            results.append(result)
        return jsonify({"results": results, "total": len(results)})

    @app.route("/api/v1/model/info", methods=["GET"])
    def model_info():
        """Return metadata of the currently loaded model (confusion matrix omitted)."""
        try:
            ensure_model()
        except ModelUnavailableError as exc:
            return model_unavailable_response(str(exc))
        # NOTE(review): reaches into AnomalyDetector's private attributes
        # (_feature_columns, _metadata); a public accessor on the detector
        # would be cleaner — confirm against src/model.py before changing.
        return jsonify(
            {
                "model_type": detector.model_type,
                "feature_columns": detector._feature_columns,
                "metadata": {
                    key: value
                    for key, value in detector._metadata.items()
                    if key != "confusion_matrix"
                },
            }
        )

    return app
# Module-level app instance for WSGI servers (e.g. `gunicorn api_server:app`).
# Safe at import time: the model is loaded lazily on first request.
app = create_app()


def main():
    """Run the API server.

    Environment variables:
        HOST: bind address (default "0.0.0.0", i.e. all interfaces).
        PORT: listening port (default DEFAULT_API_PORT).
        FLASK_DEBUG: set to "1" to enable the Flask debugger/reloader.
    """
    # NOTE: the default binds every network interface; set HOST=127.0.0.1
    # for local-only deployments.
    host = os.environ.get("HOST", "0.0.0.0")
    port = int(os.environ.get("PORT", DEFAULT_API_PORT))
    debug = os.environ.get("FLASK_DEBUG", "0") == "1"
    app.run(host=host, port=port, debug=debug)


if __name__ == "__main__":
    main()