Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 69 additions & 9 deletions src/stac_auth_proxy/middleware/Cql2RewriteLinksFilterMiddleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,13 @@ async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
return await self.app(scope, receive, send)

request = Request(scope)
original_filter = request.query_params.get("filter")
cql2_filter: Optional[Expr] = getattr(request.state, self.state_key, None)
if cql2_filter is None:
# No filter set, just pass through
return await self.app(scope, receive, send)

user_filter, receive = await self._extract_user_filter(request, receive)

# Intercept the response
response_start = None
body_chunks = []
Expand All @@ -46,19 +47,74 @@ async def send_wrapper(message: Message):
more_body = message.get("more_body", False)
if not more_body:
await self._process_and_send_response(
response_start, body_chunks, send, original_filter
response_start, body_chunks, send, user_filter
)
else:
await send(message)

await self.app(scope, receive, send_wrapper)

async def _extract_user_filter(
self, request: Request, receive: Receive
) -> tuple[Optional[Expr], Receive]:
"""
Recover the user's original filter from either the query string or JSON body.

For methods that may carry a JSON body (POST/PUT/PATCH), the body is buffered
and a replacement ``receive`` is returned so downstream consumers still see it.
"""
query_filter = request.query_params.get("filter")
if query_filter:
try:
return Expr(query_filter), receive
except Exception:
logger.warning("Failed to parse user filter from query string")
return None, receive

if request.method not in ("POST", "PUT", "PATCH"):
return None, receive

body = b""
more_body = True
while more_body:
message = await receive()
if message["type"] == "http.request":
body += message.get("body", b"")
more_body = message.get("more_body", False)
else:
# e.g. http.disconnect - stop reading; downstream will get the same.
break

async def replay_receive() -> Message:
return {"type": "http.request", "body": body, "more_body": False}

if not body:
return None, replay_receive

try:
body_json = json.loads(body)
except json.JSONDecodeError:
return None, replay_receive

if not isinstance(body_json, dict):
return None, replay_receive

body_filter = body_json.get("filter")
if body_filter is None:
return None, replay_receive

try:
return Expr(body_filter), replay_receive
except Exception:
logger.warning("Failed to parse user filter from request body")
return None, replay_receive

async def _process_and_send_response(
self,
response_start: Message,
body_chunks: list[bytes],
send: Send,
original_filter: Optional[str],
user_filter: Optional[Expr],
):
body = b"".join(body_chunks)
try:
Expand All @@ -68,7 +124,6 @@ async def _process_and_send_response(
await send({"type": "http.response.body", "body": body, "more_body": False})
return

cql2_filter = Expr(original_filter) if original_filter else None
links = data.get("links")
if isinstance(links, list):
for link in links:
Expand All @@ -77,19 +132,24 @@ async def _process_and_send_response(
url = urlparse(link["href"])
qs = parse_qs(url.query)
if "filter" in qs:
if cql2_filter:
qs["filter"] = [cql2_filter.to_text()]
if user_filter is not None:
qs["filter"] = [user_filter.to_text()]
else:
qs.pop("filter", None)
qs.pop("filter-lang", None)
new_query = urlencode(qs, doseq=True)
link["href"] = urlunparse(url._replace(query=new_query))

# Handle filter in body (for POST links)
# Handle filter in body (for POST links). The spec only
# requires cql2-json for POST bodies, but if the link advertises
# cql2-text we preserve that lang on the way out.
if "body" in link and isinstance(link["body"], dict):
if "filter" in link["body"]:
if cql2_filter:
link["body"]["filter"] = cql2_filter.to_json()
if user_filter is not None:
if link["body"].get("filter-lang") == "cql2-text":
link["body"]["filter"] = user_filter.to_text()
else:
link["body"]["filter"] = user_filter.to_json()
else:
link["body"].pop("filter", None)
link["body"].pop("filter-lang", None)
Expand Down
139 changes: 139 additions & 0 deletions tests/test_cql2_rewrite_links_filter_middleware.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
"""Test Cql2RewriteLinksFilterMiddleware."""

import json
import re

import pytest
Expand All @@ -12,6 +13,24 @@
)


def _install_middlewares(app: FastAPI, system_filter: str) -> None:
"""Attach the rewrite middleware behind a mock build-filter middleware."""

class MockBuildFilterMiddleware:
def __init__(self, app, state_key="cql2_filter"):
self.app = app
self.state_key = state_key

async def __call__(self, scope, receive, send):
if scope["type"] == "http":
request = Request(scope)
setattr(request.state, self.state_key, Expr(system_filter))
await self.app(scope, receive, send)

app.add_middleware(Cql2RewriteLinksFilterMiddleware)
app.add_middleware(MockBuildFilterMiddleware)


def test_non_json_response():
"""Test middleware behavior with non-JSON responses."""
app = FastAPI()
Expand Down Expand Up @@ -335,3 +354,123 @@ async def test_endpoint(request: Request):

# Other data should be preserved
assert body["other_data"] == "preserved"


class TestPostBodyClientFilterPreservation:
"""
Client filters sent in a POST search body must be preserved in next-link bodies.

Regression: the middleware previously read the user's filter only from the
query string, silently dropping filters supplied via POST body.
"""

@pytest.mark.parametrize(
"client_filter,client_filter_lang",
[
(
{"op": "<", "args": [{"property": "cloud_coverage"}, 50]},
"cql2-json",
),
("cloud_coverage < 30", "cql2-text"),
(None, None),
],
)
def test_preserves_client_filter_from_post_body(
self, client_filter, client_filter_lang
):
"""Filter supplied in the POST body must be preserved in the next link."""
app = FastAPI()
_install_middlewares(app, system_filter="private = false")

@app.post("/search")
async def search_endpoint(request: Request):
body_json = await request.json()
system_expr = getattr(request.state, "cql2_filter", None)
user_filter = body_json.get("filter")
user_filter_lang = body_json.get("filter-lang")

combined = None
if system_expr is not None and user_filter is not None:
combined = system_expr + Expr(user_filter)
elif system_expr is not None:
combined = system_expr
elif user_filter is not None:
combined = Expr(user_filter)

next_body = {"token": "next-token"}
if combined is not None:
lang = user_filter_lang or "cql2-json"
next_body["filter-lang"] = lang
next_body["filter"] = (
combined.to_text() if lang == "cql2-text" else combined.to_json()
)

return {
"links": [
{
"rel": "next",
"method": "POST",
"href": "http://example.com/search",
"body": next_body,
}
],
}

request_body = {}
if client_filter is not None:
request_body["filter"] = client_filter
request_body["filter-lang"] = client_filter_lang

response = TestClient(app).post("/search", json=request_body)
assert response.status_code == 200, response.text
body = response.json()["links"][0]["body"]

assert body["token"] == "next-token"

if client_filter is None:
# No client filter → system filter must not leak into next link.
assert "filter" not in body
assert "filter-lang" not in body
else:
# Compare semantically: cql2-python may re-emit equivalent text
# with different formatting (e.g. added parens).
assert Expr(body["filter"]).to_json() == Expr(client_filter).to_json()
assert body["filter-lang"] == client_filter_lang

def test_request_body_is_intact_for_inner_app(self):
"""Body capture must replay the exact original bytes to the inner app."""
app = FastAPI()
_install_middlewares(app, system_filter="private = false")

@app.post("/search")
async def search_endpoint(request: Request):
return {"echo": json.loads(await request.body())}

request_body = {
"collections": ["a", "b"],
"filter": {"op": "=", "args": [{"property": "x"}, 1]},
"filter-lang": "cql2-json",
}
response = TestClient(app).post("/search", json=request_body)
assert response.status_code == 200, response.text
assert response.json()["echo"] == request_body

def test_malformed_json_body_does_not_break_middleware(self):
"""An unparseable body must pass through without the middleware crashing."""
app = FastAPI()
_install_middlewares(app, system_filter="private = false")

@app.post("/search")
async def search_endpoint(request: Request):
return Response(
content=await request.body(),
media_type="application/octet-stream",
)

response = TestClient(app).post(
"/search",
content=b"not json",
headers={"content-type": "application/json"},
)
assert response.status_code == 200
assert response.content == b"not json"
Loading