Skip to content

Commit 1760cb5

Browse files
committed
refactor: addressed feedback
1 parent 2c7430d commit 1760cb5

5 files changed

Lines changed: 42 additions & 17 deletions

File tree

api/observation.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,10 @@ def get_groundwater_level_observations(
110110
sql = sql.where(Observation.observation_datetime <= end_time)
111111

112112
sql = order_sort_filter(sql, Observation, sort, order, filter_)
113-
sql = sql.order_by(Observation.observation_datetime.desc())
113+
114+
if not order:
115+
sql = sql.order_by(Observation.observation_datetime.desc())
116+
114117
return paginate(query=sql, conn=session)
115118

116119

api/search.py

Lines changed: 23 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,11 @@
1616
from fastapi import APIRouter, Depends
1717
from sqlalchemy import select
1818
from sqlalchemy.orm import Session
19+
from api.pagination import CustomPage
20+
from fastapi_pagination import paginate
21+
from fastapi_pagination.utils import disable_installed_extensions_check
1922

23+
from core.dependencies import session_dependency
2024
from db import (
2125
Contact,
2226
Email,
@@ -27,12 +31,13 @@
2731
AssetThingAssociation,
2832
search,
2933
)
30-
from db.engine import get_db_session
3134

35+
36+
disable_installed_extensions_check()
3237
router = APIRouter(prefix="/search", tags=["search"])
3338

3439

35-
def _get_contact_results(session: Session, q: str) -> list[dict]:
40+
def _get_contact_results(session: Session, q: str, limit: int) -> list[dict]:
3641
vector = (
3742
Contact.search_vector
3843
| Email.search_vector
@@ -41,7 +46,8 @@ def _get_contact_results(session: Session, q: str) -> list[dict]:
4146
)
4247

4348
query = search(
44-
select(Contact).join(Email).join(Phone).join(Address), q, vector=vector
49+
select(Contact).join(Email).join(Phone).join(Address), q, vector=vector,
50+
limit= limit
4551
)
4652
contacts = session.scalars(query).all()
4753
results = [
@@ -62,13 +68,15 @@ def _get_contact_results(session: Session, q: str) -> list[dict]:
6268
return results
6369

6470

65-
def _get_thing_results(session: Session, q: str) -> list[dict]:
71+
def _get_thing_results(session: Session, q: str, limit: int) -> list[dict]:
6672
vector = Thing.search_vector
6773
water_well_query = search(
68-
select(Thing).where(Thing.thing_type == "water well"), q, vector=vector
74+
select(Thing).where(Thing.thing_type == "water well"), q, vector=vector,
75+
limit=limit
6976
)
7077
spring_well_query = search(
71-
select(Thing).where(Thing.thing_type == "spring"), q, vector=vector
78+
select(Thing).where(Thing.thing_type == "spring"), q, vector=vector,
79+
limit=limit
7280
)
7381

7482
wells = session.scalars(water_well_query).all()
@@ -117,10 +125,11 @@ def make_spring_response(thing: Thing) -> dict:
117125
]
118126

119127

120-
def _get_asset_results(session: Session, q: str) -> list[dict]:
128+
def _get_asset_results(session: Session, q: str, limit: int) -> list[dict]:
121129
vector = Asset.search_vector
122130
query = search(
123-
select(Asset).join(AssetThingAssociation).join(Thing), q, vector=vector
131+
select(Asset).join(AssetThingAssociation).join(Thing), q, vector=vector,
132+
limit=limit
124133
)
125134

126135
assets = session.scalars(query).all()
@@ -143,16 +152,17 @@ def _get_asset_results(session: Session, q: str) -> list[dict]:
143152

144153

145154
@router.get("")
146-
def search_api(q: str, session: Session = Depends(get_db_session)):
155+
def search_api(session: session_dependency, q: str, limit: int=25, ) -> CustomPage[dict]:
147156
"""
148157
Search endpoint for the collaborative network.
149158
"""
150159

151-
results = _get_contact_results(session, q)
152-
results.extend(_get_thing_results(session, q))
153-
results.extend(_get_asset_results(session, q))
160+
results = _get_contact_results(session, q, limit)
161+
results.extend(_get_thing_results(session, q, limit))
162+
results.extend(_get_asset_results(session, q, limit))
154163

155-
return {"items": results, "total": len(results)}
164+
return paginate(results)
165+
# return {"items": results, "total": len(results)}
156166

157167

158168
# ============= EOF =============================================

core/dependencies.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,15 +24,24 @@
2424
session_dependency = Annotated[Session, Depends(get_db_session)]
2525

2626
# authentication functions
27-
# well_user_function = authenticated(permissions=["", "well:write"])
27+
28+
# Admin, can do everything Editor and Viewer can do
29+
# + create new objects
2830
admin_function = authenticated(permissions=["Admin"])
31+
32+
# Editor can do everything Viewer can do
33+
# + edit existing objects
2934
editor_function = authenticated(permissions=["Editor"])
35+
36+
# Viewer can view all "global" entities Location, Sample, Group, Lexicon, etc
3037
viewer_function = authenticated(permissions=["Viewer"])
3138

39+
# AMP specific permissions
3240
amp_admin_function = authenticated(permissions=["AMPAdmin"])
3341
amp_editor_function = authenticated(permissions=["AMPEditor"])
3442
amp_viewer_function = authenticated(permissions=["AMPViewer"])
3543

44+
# for testing
3645
no_permission_function = authenticated(permissions=["NoPermission"])
3746

3847

db/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ def adder(session, table, model, user=None, **kwargs):
7171
return obj
7272

7373

74-
def search(query, search_query, vector=None, regconfig=None, sort=True):
74+
def search(query, search_query, vector=None, regconfig=None, sort=True, limit=None):
7575
if not search_query.strip():
7676
return query
7777

@@ -95,6 +95,9 @@ def search(query, search_query, vector=None, regconfig=None, sort=True):
9595
)
9696
)
9797

98+
if limit:
99+
query = query.limit(limit)
100+
98101
return query.params(term=search_query)
99102

100103

services/query_helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ def simple_all_getter(session, table) -> list[object]:
121121
return session.scalars(sql).all()
122122

123123

124-
def order_sort_filter(sql, table, sort, order, filter_) -> Select[Any]:
124+
def order_sort_filter(sql: Select[Any], table: DeclarativeBase, sort:str, order:str, filter_:str) -> Select[Any]:
125125
if order:
126126
if not sort:
127127
raise ValueError(

0 commit comments

Comments
 (0)