diff --git a/api/group.py b/api/group.py index 5399ce10..962870c1 100644 --- a/api/group.py +++ b/api/group.py @@ -27,10 +27,12 @@ from db.group import Group from schemas.group import UpdateGroup, CreateGroup, GroupResponse from services.crud_helper import model_patcher, model_deleter, model_adder -from services.query_helper import ( - simple_get_by_id, - paginated_all_getter, +from services.group_helper import ( + get_well_counts_by_group_id, + group_to_response, + paginated_groups_getter, ) +from services.query_helper import simple_get_by_id router = APIRouter(prefix="/group", tags=["group"]) @@ -74,7 +76,7 @@ def get_groups( """ Retrieve all groups from the database. """ - return paginated_all_getter(session, Group, filter_=filter_) + return paginated_groups_getter(session, filter_=filter_) @router.get("/{group_id}", summary="Get group by ID") @@ -84,7 +86,9 @@ def get_group_by_id( """ Retrieve a group by ID from the database. """ - return simple_get_by_id(session, Group, group_id) + group = simple_get_by_id(session, Group, group_id) + counts = get_well_counts_by_group_id(session, [group.id]) + return group_to_response(group, counts.get(group.id, 0)) # @router.get( diff --git a/docs/refine-json-filters-and-virtual-fields.md b/docs/refine-json-filters-and-virtual-fields.md index e00a7c20..cd393cdc 100644 --- a/docs/refine-json-filters-and-virtual-fields.md +++ b/docs/refine-json-filters-and-virtual-fields.md @@ -38,6 +38,7 @@ Associations are stored in **`ThingContactAssociation`** (`thing_id`, `contact_i | List resource | Virtual `field` | Meaning | Implementation sketch | |---------------|------------------|---------|------------------------| | Thing (wells) | `contacts` | “Does **any** linked contact’s **name** match?” | EXISTS over `ThingContactAssociation` joining `Contact`, predicate on **`Contact.name`** | +| Thing (wells) | `groups` | “Does **any** linked project (**Group**) match?” | EXISTS over `GroupThingAssociation` joining `Group`, predicate on **`Group.id`** or **`Group.name`** | | Contact | `things` | “Does **any** linked monitoring site (**thing**) **name** match?” | EXISTS over **`ThingContactAssociation`** joining **`Thing`**, predicate on **`Thing.name`** | We keep naming aligned with ORM accessors (`Thing.contacts`-style summaries in API responses use **contacts**, and **`Contact`** side uses **`things`** for parity with the association proxy). @@ -77,6 +78,7 @@ Those paths previously raised **500**. Virtual sorts are implemented in **`_appl | `monitoring_status`, `well_status`, `datalogger_suitability_status` | Same “latest open” **`StatusHistory.status_value`** subquery as filters; **`lower(...)`**, **`nulls_last`** | | `site_name` | **`ThingIdLink.alternate_id`** where **`alternate_organization = 'NMBGMR'`**, smallest link **`id`** (matches **`Thing.site_name`**) | | `contacts` | **`min(lower(Contact.name))`** over **`ThingContactAssociation`** (first name alphabetically among linked contacts) | +| `groups` | **`min(lower(Group.name))`** over **`GroupThingAssociation`** (first project name alphabetically among linked groups) | | `aquifers` | **`min(lower(AquiferSystem.name))`** over **`ThingAquiferAssociation`** | | `open_status` | Latest open **“Open Status”** row; rank **Open** before **Closed**, then unknown strings, then no row | | `measuring_point_height` | Latest **`MeasuringPointHistory`** row with non-null height (**`start_date` desc**, limit 1) | @@ -104,6 +106,7 @@ Each filter **must** include **`field`**, **`operator`**, and **`value`** keys ( | Merge **`filter_`** + **`filters`**, sorting, pagination hook | **`order_sort_filter`** in **`services/query_helper.py`** | | Dispatch virtual fields | **`_apply_json_filter_clause`** in **`services/query_helper.py`** | | **`Thing` + contacts** | **`_apply_thing_contacts_filter`** | +| **`Thing` + groups** | **`_apply_thing_groups_filter`** | | **`Contact` + things** | **`_apply_contact_things_filter`** | | Contact list accepts repeated **`filter`** | **`GET`** **`/contact`** in **`api/contact.py`**, **`get_db_contacts`** in **`services/contact_helper.py`** | | Wells list pattern (reference) | **`GET`** **`/thing/water-well`** in **`api/thing.py`**, **`get_db_things`** in **`services/thing_helper.py`** | @@ -113,6 +116,7 @@ Each filter **must** include **`field`**, **`operator`**, and **`value`** keys ( - **`tests/test_contact_filters.py`**: **`things`** filters, **`things`** sort, multiple **`filter`** params on **`GET /contact`**. - **`tests/test_thing.py`** (contacts on wells): **`contacts`** **`contains`**, **`ncontains`**, **`nnull`**, and **`sort`** on **`monitoring_status`**, **`site_name`**, **`contacts`**, **`aquifers`**, etc. +- **`tests/test_thing.py`** (groups on wells): **`groups`** **`eq`** by project id or name when filtering wells by project. ## When you change this diff --git a/schemas/group.py b/schemas/group.py index e3cc7488..2472dc0f 100644 --- a/schemas/group.py +++ b/schemas/group.py @@ -58,6 +58,7 @@ class GroupResponse(BaseResponseModel): project_area: str | None group_type: GroupType | None parent_group_id: int | None + well_count: int = 0 @model_validator(mode="before") def project_area_to_wkt(self: Self) -> Self: diff --git a/services/group_helper.py b/services/group_helper.py new file mode 100644 index 00000000..b81dd81c --- /dev/null +++ b/services/group_helper.py @@ -0,0 +1,65 @@ +# =============================================================================== +# Copyright 2025 ross +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =============================================================================== +from typing import Any + +from fastapi_pagination.ext.sqlalchemy import paginate +from sqlalchemy import func, select +from sqlalchemy.orm import Session + +from db.group import Group, GroupThingAssociation +from db.thing import Thing +from schemas.group import GroupResponse +from services.query_helper import order_sort_filter + + +def get_well_counts_by_group_id( + session: Session, group_ids: list[int] +) -> dict[int, int]: + if not group_ids: + return {} + + stmt = ( + select( + GroupThingAssociation.group_id, + func.count(Thing.id), + ) + .join(Thing, GroupThingAssociation.thing_id == Thing.id) + .where(GroupThingAssociation.group_id.in_(group_ids)) + .where(Thing.thing_type == "water well") + .group_by(GroupThingAssociation.group_id) + ) + return {row[0]: int(row[1]) for row in session.execute(stmt).all()} + + +def group_to_response(group: Group, well_count: int = 0) -> GroupResponse: + response = GroupResponse.model_validate(group) + return response.model_copy(update={"well_count": well_count}) + + +def paginated_groups_getter( + session: Session, + filter_: str | None = None, + *, + filters: list[str] | None = None, +) -> Any: + sql = select(Group) + sql = order_sort_filter(sql, Group, None, None, filter_, filters=filters) + + def transformer(groups: list[Group]) -> list[GroupResponse]: + counts = get_well_counts_by_group_id(session, [group.id for group in groups]) + return [group_to_response(group, counts.get(group.id, 0)) for group in groups] + + return paginate(query=sql, conn=session, transformer=transformer) diff --git a/services/query_helper.py b/services/query_helper.py index 538e93a0..aeb14227 100644 --- a/services/query_helper.py +++ b/services/query_helper.py @@ -343,6 +343,25 @@ def _thing_contacts_min_name_sort_scalar(thing_table: type): ) +def _thing_groups_min_name_sort_scalar(thing_table: type): + """Minimum ``lower(Group.name)`` across linked projects (stable proxy for display order).""" + from db.group import Group, GroupThingAssociation + + gta = GroupThingAssociation + g = Group + return ( + select(func.min(func.lower(g.name))) + .select_from(gta) + .join(g, gta.group_id == g.id) + .where( + gta.thing_id == thing_table.id, + g.name.isnot(None), + ) + .correlate(thing_table) + .scalar_subquery() + ) + + def _thing_aquifers_min_name_sort_scalar(thing_table: type): """Minimum ``lower(AquiferSystem.name)`` across linked aquifers.""" from db.aquifer_system import AquiferSystem @@ -417,6 +436,7 @@ def _contact_things_min_name_sort_scalar(contact_table: type): "datalogger_suitability_status", "site_name", "contacts", + "groups", "aquifers", "open_status", "measuring_point_height", @@ -486,6 +506,9 @@ def num_order(expr): if sort == "contacts": return str_order(_thing_contacts_min_name_sort_scalar(thing_table)) + if sort == "groups": + return str_order(_thing_groups_min_name_sort_scalar(thing_table)) + if sort == "aquifers": return str_order(_thing_aquifers_min_name_sort_scalar(thing_table)) @@ -610,6 +633,91 @@ def _linked_contact_select(predicate): return sql.where(exists(_linked_contact_select(pred))) +def _apply_thing_groups_filter( + sql: Select[Any], + thing_table: type, + operator: str, + value: Any, +) -> Select[Any]: + """Filter ``Thing`` rows using linked groups / projects (many-to-many). + + Refine sends ``field=groups`` from the wells list when filtering by project. + Match **any** linked ``Group`` by id (numeric ``eq``) or by ``Group.name``. + """ + from db.group import Group, GroupThingAssociation + + gta = GroupThingAssociation + g = Group + + def _linked_group_select(predicate): + return ( + select(1) + .select_from(gta) + .join(g, gta.group_id == g.id) + .where( + gta.thing_id == thing_table.id, + predicate, + ) + ) + + any_linked_group = ( + select(1) + .select_from(gta) + .join(g, gta.group_id == g.id) + .where(gta.thing_id == thing_table.id) + ) + + if operator == "nnull": + return sql.where(exists(any_linked_group)) + + if operator == "null": + return sql.where(~exists(any_linked_group)) + + if operator == "eq": + + def _eq_predicate(): + try: + group_id = int(value) + return g.id == group_id + except (TypeError, ValueError): + return g.name == str(value) + + return sql.where(exists(_linked_group_select(_eq_predicate()))) + + if operator == "ne": + + def _ne_predicate(): + try: + group_id = int(value) + return g.id == group_id + except (TypeError, ValueError): + return g.name == str(value) + + return sql.where(~exists(_linked_group_select(_ne_predicate()))) + + if operator == "ncontains": + nlg = _linked_group_select(g.name.ilike(f"%{value}%")) + return sql.where(~exists(nlg)) + + if operator == "contains": + pred = g.name.ilike(f"%{value}%") + elif operator == "startswith": + pred = g.name.ilike(f"{value}%") + elif operator == "endswith": + pred = g.name.ilike(f"%{value}") + else: + raise HTTPException( + status_code=400, + detail=( + f"Operator {operator!r} is not supported for groups " + "filters (contains, ncontains, eq, ne, startswith, endswith, " + "null, nnull)" + ), + ) + + return sql.where(exists(_linked_group_select(pred))) + + def _apply_contact_things_filter( sql: Select[Any], contact_table: type, @@ -739,6 +847,9 @@ def _apply_json_filter_clause( if getattr(table, "__name__", None) == "Thing" and field == "contacts": return _apply_thing_contacts_filter(sql, table, operator, value) + if getattr(table, "__name__", None) == "Thing" and field == "groups": + return _apply_thing_groups_filter(sql, table, operator, value) + try: column = getattr(table, field) except AttributeError as exc: diff --git a/services/thing_helper.py b/services/thing_helper.py index 16fdd9a6..c7500045 100644 --- a/services/thing_helper.py +++ b/services/thing_helper.py @@ -71,6 +71,7 @@ def is_debug_timing_enabled() -> bool: selectinload(Thing.contact_associations).selectinload( ThingContactAssociation.contact ), + selectinload(Thing.group_associations).selectinload(GroupThingAssociation.group), selectinload(Thing.well_purposes), selectinload(Thing.well_casing_materials), selectinload(Thing.links), diff --git a/tests/test_group.py b/tests/test_group.py index d703b0bd..de4c6672 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -5,7 +5,8 @@ from pydantic import ValidationError from core.dependencies import admin_function, viewer_function, editor_function -from db import Group +from db import Group, GroupThingAssociation, Thing +from db.engine import session_ctx from main import app from schemas import DT_FMT from schemas.group import ValidateGroup @@ -103,6 +104,55 @@ def test_get_groups(group): assert data["items"][0]["project_area"] == to_shape(group.project_area).wkt assert data["items"][0]["description"] == group.description assert data["items"][0]["parent_group_id"] == group.parent_group_id + assert data["items"][0]["well_count"] == 1 + + +def test_get_groups_well_count_excludes_non_water_wells( + group, water_well_thing, location, spring_thing +): + with session_ctx() as session: + second_well = Thing( + name="Second Test Well", + first_visit_date="2023-03-03", + thing_type="water well", + release_status="draft", + well_depth=10, + hole_depth=10, + well_casing_diameter=5.0, + well_casing_depth=10.0, + ) + session.add(second_well) + session.commit() + session.refresh(second_well) + + for thing_id in (second_well.id, spring_thing.id): + session.add(GroupThingAssociation(group_id=group.id, thing_id=thing_id)) + session.commit() + + response = client.get("/group") + assert response.status_code == 200 + data = response.json() + item = next(item for item in data["items"] if item["id"] == group.id) + assert item["well_count"] == 2 + + +def test_get_groups_well_count_zero_without_associations(): + payload = { + "release_status": "private", + "name": "Empty Project Group", + "description": "No associated wells.", + } + create_response = client.post("/group", json=payload) + assert create_response.status_code == 201 + group_id = create_response.json()["id"] + + response = client.get("/group") + assert response.status_code == 200 + data = response.json() + item = next(item for item in data["items"] if item["id"] == group_id) + assert item["well_count"] == 0 + + cleanup_post_test(Group, group_id) def test_get_group_by_id(group): @@ -118,6 +168,7 @@ def test_get_group_by_id(group): assert data["description"] == group.description assert data["parent_group_id"] == group.parent_group_id assert data["release_status"] == group.release_status + assert data["well_count"] == 1 def test_get_group_by_id_404_not_found(group): diff --git a/tests/test_thing.py b/tests/test_thing.py index d3444a7c..8c8859c4 100644 --- a/tests/test_thing.py +++ b/tests/test_thing.py @@ -985,6 +985,67 @@ def test_get_water_wells_filter_contacts_nnull(water_well_thing, contact): assert water_well_thing.id in ids +def test_get_water_wells_filter_groups_eq_by_id(group, water_well_thing): + fl = json.dumps( + {"field": "groups", "operator": "eq", "value": str(group.id)}, + ) + response = client.get("/thing/water-well", params=[("filter", fl)]) + assert response.status_code == 200 + data = response.json() + ids = [item["id"] for item in data["items"]] + assert water_well_thing.id in ids + + +def test_get_water_wells_filter_groups_eq_by_id_no_match(group, water_well_thing): + fl = json.dumps( + {"field": "groups", "operator": "eq", "value": "999999"}, + ) + response = client.get("/thing/water-well", params=[("filter", fl)]) + assert response.status_code == 200 + data = response.json() + ids = [item["id"] for item in data["items"]] + assert water_well_thing.id not in ids + + +def test_get_water_wells_filter_groups_eq_by_name(group, water_well_thing): + fl = json.dumps( + {"field": "groups", "operator": "eq", "value": group.name}, + ) + response = client.get("/thing/water-well", params=[("filter", fl)]) + assert response.status_code == 200 + data = response.json() + ids = [item["id"] for item in data["items"]] + assert water_well_thing.id in ids + + +def test_get_water_wells_filter_groups_unsupported_operator(group): + fl = json.dumps( + {"field": "groups", "operator": "gt", "value": group.id}, + ) + response = client.get("/thing/water-well", params=[("filter", fl)]) + assert response.status_code == 400 + + +def test_get_water_wells_list_includes_groups(group, water_well_thing): + response = client.get("/thing/water-well", params={"page": 1, "size": 50}) + assert response.status_code == 200 + data = response.json() + well = next(item for item in data["items"] if item["id"] == water_well_thing.id) + assert len(well["groups"]) >= 1 + assert well["groups"][0]["name"] == group.name + + +def test_get_water_wells_sort_groups_asc(group, water_well_thing): + response = client.get( + "/thing/water-well", + params={ + "sort": "groups", + "order": "asc", + }, + ) + assert response.status_code == 200 + + def test_get_water_wells_sort_monitoring_status_desc(water_well_thing): """Derived status columns are Python properties; sort uses StatusHistory SQL.""" response = client.get(