Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 23 additions & 10 deletions api/thing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ===============================================================================
from typing import Optional
from typing import Annotated, Optional
from fastapi import APIRouter, Query, Request
from fastapi_pagination.ext.sqlalchemy import paginate
from sqlalchemy import select
Expand Down Expand Up @@ -151,24 +151,27 @@ def get_water_wells(
request: Request,
sort: Optional[str] = None,
order: Optional[str] = None,
filter_: str = Query(alias="filter", default=None),
filter_params: Annotated[list[str] | None, Query(alias="filter")] = None,
query: Optional[str] = None,
name: Optional[str] = None,
name_contains: Optional[str] = None,
include_contacts: bool = False,
) -> CustomPage[WellResponse]:
"""
Retrieve all wells from the database.
"""
thing_type = request.url.path.split("/")[2].replace("-", " ")
return get_db_things(
filter_,
None,
order,
query,
session,
sort,
name=name,
thing_type=thing_type,
include_contacts=include_contacts,
filters=filter_params,
name_contains=name_contains,
)


Expand Down Expand Up @@ -293,14 +296,24 @@ def get_springs(
request: Request,
sort: str = None,
order: str = None,
filter_: str = Query(alias="filter", default=None),
filter_params: Annotated[list[str] | None, Query(alias="filter")] = None,
query: str = None,
name_contains: Optional[str] = None,
) -> CustomPage[SpringResponse]:
"""
Retrieve all springs from the database.
"""
thing_type = request.url.path.split("/")[2].replace("-", " ")
return get_db_things(filter_, order, query, session, sort, thing_type=thing_type)
return get_db_things(
None,
order,
query,
session,
sort,
thing_type=thing_type,
filters=filter_params,
name_contains=name_contains,
)


@router.get("/spring/{thing_id}", summary="Get spring by ID", status_code=HTTP_200_OK)
Expand Down Expand Up @@ -359,23 +372,23 @@ def get_things(
sort: Optional[str] = None,
order: Optional[str] = None,
include_contacts: bool = False,
filter_: str = Query(
default=None,
alias="filter",
),
filter_params: Annotated[list[str] | None, Query(alias="filter")] = None,
name_contains: Optional[str] = None,
) -> CustomPage[ThingResponse]:
"""
Retrieve all things or filter by type.
"""

return get_db_things(
filter_,
None,
order,
query,
session,
sort,
within=within,
include_contacts=include_contacts,
filters=filter_params,
name_contains=name_contains,
)


Expand Down
160 changes: 138 additions & 22 deletions services/query_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from fastapi import HTTPException
from fastapi_pagination.ext.sqlalchemy import paginate
from sqlalchemy import select, Float, Integer, Column, Select, func, String
from sqlalchemy import Column, Float, Integer, Select, String, Text, func, not_, select
from sqlalchemy.orm import DeclarativeBase, Session
from sqlalchemy.sql.elements import OperatorExpression
from starlette.status import HTTP_404_NOT_FOUND
Expand Down Expand Up @@ -110,8 +110,130 @@ def simple_all_getter(session, table) -> list[object]:
return session.scalars(sql).all()


def _python_type(column: Any):
try:
return column.type.python_type
except Exception:
return None


def _apply_json_filter_clause(
sql: Select[Any], table: DeclarativeBase, f: dict
) -> Select[Any]:
"""Apply one Refine logical filter dict (field / operator / value) to a SELECT."""
required_keys = {"field", "value", "operator"}
missing = required_keys - f.keys()
if missing:
raise HTTPException(
status_code=422,
detail=f"Missing required filter keys: {', '.join(sorted(missing))}",
)

field = f["field"]
value = f["value"]
operator = f["operator"]

try:
column = getattr(table, field)
except AttributeError as exc:
raise HTTPException(
status_code=400,
detail=f"Unknown filter field {field!r} for {table.__name__}",
) from exc

py_t = _python_type(column)
is_string = py_t is str or isinstance(column.type, (String, Text))

if operator == "contains":
if not is_string:
raise HTTPException(
status_code=400,
detail=f"Operator contains is not supported for field {field!r}",
)
return sql.where(column.ilike(f"%{value}%"))

if operator == "ncontains":
if not is_string:
raise HTTPException(
status_code=400,
detail=f"Operator ncontains is not supported for field {field!r}",
)
return sql.where(not_(column.ilike(f"%{value}%")))

if operator == "startswith":
if not is_string:
raise HTTPException(
status_code=400,
detail=f"Operator startswith is not supported for field {field!r}",
)
return sql.where(column.ilike(f"{value}%"))

if operator == "endswith":
if not is_string:
raise HTTPException(
status_code=400,
detail=f"Operator endswith is not supported for field {field!r}",
)
return sql.where(column.ilike(f"%{value}"))

if operator == "eq":
if py_t is float:
return sql.where(column == float(value))
if py_t is int:
return sql.where(column == int(value))
if is_string:
return sql.where(column == str(value))
return sql.where(column == value)

if operator == "ne":
if py_t is float:
return sql.where(column != float(value))
if py_t is int:
return sql.where(column != int(value))
if is_string:
return sql.where(column != str(value))
return sql.where(column != value)

if operator == "gt":
return sql.where(column > float(value) if py_t is float else column > value)

if operator == "gte":
return sql.where(column >= float(value) if py_t is float else column >= value)

if operator == "lt":
return sql.where(column < float(value) if py_t is float else column < value)

if operator == "lte":
return sql.where(column <= float(value) if py_t is float else column <= value)

if operator == "null":
return sql.where(column.is_(None))

if operator == "nnull":
return sql.where(column.is_not(None))

if operator == "in":
if not isinstance(value, (list, tuple)):
raise HTTPException(
status_code=400,
detail="Operator in requires an array value",
)
return sql.where(column.in_(list(value)))

raise HTTPException(
status_code=400,
detail=f"Unsupported filter operator {operator!r}",
)


def order_sort_filter(
sql: Select[Any], table: DeclarativeBase, sort: str, order: str, filter_: str
sql: Select[Any],
table: DeclarativeBase,
sort: str | None,
order: str | None,
filter_: str | None = None,
*,
filters: list[str] | None = None,
) -> Select[Any]:
if order:
if not sort:
Expand All @@ -132,27 +254,21 @@ def order_sort_filter(
else:
raise ValueError("Invalid order parameter. Use 'asc' or 'desc'.")

filter_jsons: list[str] = []
if filters:
filter_jsons.extend([x for x in filters if x])
if filter_:
required_keys = {"field", "value", "operator"}
if filter_ is not None:
try:
f = json.loads(filter_)
except Exception:
raise HTTPException(status_code=400, detail="Invalid JSON in filter")

missing = required_keys - f.keys()
if missing:
raise HTTPException(
status_code=422,
detail=f"Missing required filter keys: {', '.join(missing)}",
)

field = f["field"]
value = f["value"]
operator = f["operator"]
column = getattr(table, field)
if operator == "contains":
sql = sql.where(column.ilike(f"%{value}%"))
filter_jsons.append(filter_)

for raw in filter_jsons:
try:
f = json.loads(raw)
except Exception as exc:
raise HTTPException(
status_code=400, detail="Invalid JSON in filter"
) from exc

sql = _apply_json_filter_clause(sql, table, f)

return sql

Expand Down
13 changes: 12 additions & 1 deletion services/thing_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ def get_db_things(
within: Optional[str] = None,
name: Optional[str] = None,
include_contacts: bool = False,
filters: Optional[list[str]] = None,
name_contains: Optional[str] = None,
) -> list:

if query:
Expand Down Expand Up @@ -151,6 +153,9 @@ def get_db_things(
if name:
sql = sql.where(Thing.name == name)

if name_contains and name_contains.strip():
sql = sql.where(Thing.name.ilike(f"%{name_contains.strip()}%"))

if within:
latest_assoc = (
select(
Expand All @@ -173,7 +178,13 @@ def get_db_things(
)
sql = make_within_wkt(sql, within)

sql = order_sort_filter(sql, Thing, sort, order, filter_)
merged_filters: list[str] | None = None
if filters:
merged_filters = list(filters)
elif filter_:
merged_filters = [filter_]

sql = order_sort_filter(sql, Thing, sort, order, filters=merged_filters)

return paginate(query=sql, conn=session)

Expand Down
Loading