From 29c05a7774faef837a2ff4a919e6e9e26fe51b0e Mon Sep 17 00:00:00 2001 From: Jeremy Zilar Date: Sat, 13 Jun 2026 14:13:46 -0400 Subject: [PATCH] Add well counts and project filters for the Projects view. Groups list responses now include well counts, and the wells list can filter and sort by linked projects through the virtual groups field. --- api/group.py | 14 ++- .../refine-json-filters-and-virtual-fields.md | 4 + schemas/group.py | 1 + services/group_helper.py | 65 ++++++++++ services/query_helper.py | 111 ++++++++++++++++++ services/thing_helper.py | 1 + tests/test_group.py | 53 ++++++++- tests/test_thing.py | 61 ++++++++++ 8 files changed, 304 insertions(+), 6 deletions(-) create mode 100644 services/group_helper.py diff --git a/api/group.py b/api/group.py index 5399ce103..962870c1e 100644 --- a/api/group.py +++ b/api/group.py @@ -27,10 +27,12 @@ from db.group import Group from schemas.group import UpdateGroup, CreateGroup, GroupResponse from services.crud_helper import model_patcher, model_deleter, model_adder -from services.query_helper import ( - simple_get_by_id, - paginated_all_getter, +from services.group_helper import ( + get_well_counts_by_group_id, + group_to_response, + paginated_groups_getter, ) +from services.query_helper import simple_get_by_id router = APIRouter(prefix="/group", tags=["group"]) @@ -74,7 +76,7 @@ def get_groups( """ Retrieve all groups from the database. """ - return paginated_all_getter(session, Group, filter_=filter_) + return paginated_groups_getter(session, filter_=filter_) @router.get("/{group_id}", summary="Get group by ID") @@ -84,7 +86,9 @@ def get_group_by_id( """ Retrieve a group by ID from the database. """ - return simple_get_by_id(session, Group, group_id) + group = simple_get_by_id(session, Group, group_id) + counts = get_well_counts_by_group_id(session, [group.id]) + return group_to_response(group, counts.get(group.id, 0)) # @router.get( diff --git a/docs/refine-json-filters-and-virtual-fields.md b/docs/refine-json-filters-and-virtual-fields.md index e00a7c209..cd393cdcf 100644 --- a/docs/refine-json-filters-and-virtual-fields.md +++ b/docs/refine-json-filters-and-virtual-fields.md @@ -38,6 +38,7 @@ Associations are stored in **`ThingContactAssociation`** (`thing_id`, `contact_i | List resource | Virtual `field` | Meaning | Implementation sketch | |---------------|------------------|---------|------------------------| | Thing (wells) | `contacts` | “Does **any** linked contact’s **name** match?” | EXISTS over `ThingContactAssociation` joining `Contact`, predicate on **`Contact.name`** | +| Thing (wells) | `groups` | “Does **any** linked project (**Group**) match?” | EXISTS over `GroupThingAssociation` joining `Group`, predicate on **`Group.id`** or **`Group.name`** | | Contact | `things` | “Does **any** linked monitoring site (**thing**) **name** match?” | EXISTS over **`ThingContactAssociation`** joining **`Thing`**, predicate on **`Thing.name`** | We keep naming aligned with ORM accessors (`Thing.contacts`-style summaries in API responses use **contacts**, and **`Contact`** side uses **`things`** for parity with the association proxy). @@ -77,6 +78,7 @@ Those paths previously raised **500**. Virtual sorts are implemented in **`_appl | `monitoring_status`, `well_status`, `datalogger_suitability_status` | Same “latest open” **`StatusHistory.status_value`** subquery as filters; **`lower(...)`**, **`nulls_last`** | | `site_name` | **`ThingIdLink.alternate_id`** where **`alternate_organization = 'NMBGMR'`**, smallest link **`id`** (matches **`Thing.site_name`**) | | `contacts` | **`min(lower(Contact.name))`** over **`ThingContactAssociation`** (first name alphabetically among linked contacts) | +| `groups` | **`min(lower(Group.name))`** over **`GroupThingAssociation`** (first project name alphabetically among linked groups) | | `aquifers` | **`min(lower(AquiferSystem.name))`** over **`ThingAquiferAssociation`** | | `open_status` | Latest open **“Open Status”** row; rank **Open** before **Closed**, then unknown strings, then no row | | `measuring_point_height` | Latest **`MeasuringPointHistory`** row with non-null height (**`start_date` desc**, limit 1) | @@ -104,6 +106,7 @@ Each filter **must** include **`field`**, **`operator`**, and **`value`** keys ( | Merge **`filter_`** + **`filters`**, sorting, pagination hook | **`order_sort_filter`** in **`services/query_helper.py`** | | Dispatch virtual fields | **`_apply_json_filter_clause`** in **`services/query_helper.py`** | | **`Thing` + contacts** | **`_apply_thing_contacts_filter`** | +| **`Thing` + groups** | **`_apply_thing_groups_filter`** | | **`Contact` + things** | **`_apply_contact_things_filter`** | | Contact list accepts repeated **`filter`** | **`GET`** **`/contact`** in **`api/contact.py`**, **`get_db_contacts`** in **`services/contact_helper.py`** | | Wells list pattern (reference) | **`GET`** **`/thing/water-well`** in **`api/thing.py`**, **`get_db_things`** in **`services/thing_helper.py`** | @@ -113,6 +116,7 @@ Each filter **must** include **`field`**, **`operator`**, and **`value`** keys ( - **`tests/test_contact_filters.py`**: **`things`** filters, **`things`** sort, multiple **`filter`** params on **`GET /contact`**. - **`tests/test_thing.py`** (contacts on wells): **`contacts`** **`contains`**, **`ncontains`**, **`nnull`**, and **`sort`** on **`monitoring_status`**, **`site_name`**, **`contacts`**, **`aquifers`**, etc. +- **`tests/test_thing.py`** (groups on wells): **`groups`** **`eq`** by project id or name when filtering wells by project. ## When you change this diff --git a/schemas/group.py b/schemas/group.py index e3cc7488c..2472dc0fa 100644 --- a/schemas/group.py +++ b/schemas/group.py @@ -58,6 +58,7 @@ class GroupResponse(BaseResponseModel): project_area: str | None group_type: GroupType | None parent_group_id: int | None + well_count: int = 0 @model_validator(mode="before") def project_area_to_wkt(self: Self) -> Self: diff --git a/services/group_helper.py b/services/group_helper.py new file mode 100644 index 000000000..b81dd81c2 --- /dev/null +++ b/services/group_helper.py @@ -0,0 +1,65 @@ +# =============================================================================== +# Copyright 2025 ross +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# =============================================================================== +from typing import Any + +from fastapi_pagination.ext.sqlalchemy import paginate +from sqlalchemy import func, select +from sqlalchemy.orm import Session + +from db.group import Group, GroupThingAssociation +from db.thing import Thing +from schemas.group import GroupResponse +from services.query_helper import order_sort_filter + + +def get_well_counts_by_group_id( + session: Session, group_ids: list[int] +) -> dict[int, int]: + if not group_ids: + return {} + + stmt = ( + select( + GroupThingAssociation.group_id, + func.count(Thing.id), + ) + .join(Thing, GroupThingAssociation.thing_id == Thing.id) + .where(GroupThingAssociation.group_id.in_(group_ids)) + .where(Thing.thing_type == "water well") + .group_by(GroupThingAssociation.group_id) + ) + return {row[0]: int(row[1]) for row in session.execute(stmt).all()} + + +def group_to_response(group: Group, well_count: int = 0) -> GroupResponse: + response = GroupResponse.model_validate(group) + return response.model_copy(update={"well_count": well_count}) + + +def paginated_groups_getter( + session: Session, + filter_: str | None = None, + *, + filters: list[str] | None = None, +) -> Any: + sql = select(Group) + sql = order_sort_filter(sql, Group, None, None, filter_, filters=filters) + + def transformer(groups: list[Group]) -> list[GroupResponse]: + counts = get_well_counts_by_group_id(session, [group.id for group in groups]) + return [group_to_response(group, counts.get(group.id, 0)) for group in groups] + + return paginate(query=sql, conn=session, transformer=transformer) diff --git a/services/query_helper.py b/services/query_helper.py index 538e93a0e..aeb142273 100644 --- a/services/query_helper.py +++ b/services/query_helper.py @@ -343,6 +343,25 @@ def _thing_contacts_min_name_sort_scalar(thing_table: type): ) +def _thing_groups_min_name_sort_scalar(thing_table: type): + """Minimum ``lower(Group.name)`` across linked projects (stable proxy for display order).""" + from db.group import Group, GroupThingAssociation + + gta = GroupThingAssociation + g = Group + return ( + select(func.min(func.lower(g.name))) + .select_from(gta) + .join(g, gta.group_id == g.id) + .where( + gta.thing_id == thing_table.id, + g.name.isnot(None), + ) + .correlate(thing_table) + .scalar_subquery() + ) + + def _thing_aquifers_min_name_sort_scalar(thing_table: type): """Minimum ``lower(AquiferSystem.name)`` across linked aquifers.""" from db.aquifer_system import AquiferSystem @@ -417,6 +436,7 @@ def _contact_things_min_name_sort_scalar(contact_table: type): "datalogger_suitability_status", "site_name", "contacts", + "groups", "aquifers", "open_status", "measuring_point_height", @@ -486,6 +506,9 @@ def num_order(expr): if sort == "contacts": return str_order(_thing_contacts_min_name_sort_scalar(thing_table)) + if sort == "groups": + return str_order(_thing_groups_min_name_sort_scalar(thing_table)) + if sort == "aquifers": return str_order(_thing_aquifers_min_name_sort_scalar(thing_table)) @@ -610,6 +633,91 @@ def _linked_contact_select(predicate): return sql.where(exists(_linked_contact_select(pred))) +def _apply_thing_groups_filter( + sql: Select[Any], + thing_table: type, + operator: str, + value: Any, +) -> Select[Any]: + """Filter ``Thing`` rows using linked groups / projects (many-to-many). + + Refine sends ``field=groups`` from the wells list when filtering by project. + Match **any** linked ``Group`` by id (numeric ``eq``) or by ``Group.name``. + """ + from db.group import Group, GroupThingAssociation + + gta = GroupThingAssociation + g = Group + + def _linked_group_select(predicate): + return ( + select(1) + .select_from(gta) + .join(g, gta.group_id == g.id) + .where( + gta.thing_id == thing_table.id, + predicate, + ) + ) + + any_linked_group = ( + select(1) + .select_from(gta) + .join(g, gta.group_id == g.id) + .where(gta.thing_id == thing_table.id) + ) + + if operator == "nnull": + return sql.where(exists(any_linked_group)) + + if operator == "null": + return sql.where(~exists(any_linked_group)) + + if operator == "eq": + + def _eq_predicate(): + try: + group_id = int(value) + return g.id == group_id + except (TypeError, ValueError): + return g.name == str(value) + + return sql.where(exists(_linked_group_select(_eq_predicate()))) + + if operator == "ne": + + def _ne_predicate(): + try: + group_id = int(value) + return g.id == group_id + except (TypeError, ValueError): + return g.name == str(value) + + return sql.where(~exists(_linked_group_select(_ne_predicate()))) + + if operator == "ncontains": + nlg = _linked_group_select(g.name.ilike(f"%{value}%")) + return sql.where(~exists(nlg)) + + if operator == "contains": + pred = g.name.ilike(f"%{value}%") + elif operator == "startswith": + pred = g.name.ilike(f"{value}%") + elif operator == "endswith": + pred = g.name.ilike(f"%{value}") + else: + raise HTTPException( + status_code=400, + detail=( + f"Operator {operator!r} is not supported for groups " + "filters (contains, ncontains, eq, ne, startswith, endswith, " + "null, nnull)" + ), + ) + + return sql.where(exists(_linked_group_select(pred))) + + def _apply_contact_things_filter( sql: Select[Any], contact_table: type, @@ -739,6 +847,9 @@ def _apply_json_filter_clause( if getattr(table, "__name__", None) == "Thing" and field == "contacts": return _apply_thing_contacts_filter(sql, table, operator, value) + if getattr(table, "__name__", None) == "Thing" and field == "groups": + return _apply_thing_groups_filter(sql, table, operator, value) + try: column = getattr(table, field) except AttributeError as exc: diff --git a/services/thing_helper.py b/services/thing_helper.py index 16fdd9a6a..c75000450 100644 --- a/services/thing_helper.py +++ b/services/thing_helper.py @@ -71,6 +71,7 @@ def is_debug_timing_enabled() -> bool: selectinload(Thing.contact_associations).selectinload( ThingContactAssociation.contact ), + selectinload(Thing.group_associations).selectinload(GroupThingAssociation.group), selectinload(Thing.well_purposes), selectinload(Thing.well_casing_materials), selectinload(Thing.links), diff --git a/tests/test_group.py b/tests/test_group.py index d703b0bd5..de4c6672a 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -5,7 +5,8 @@ from pydantic import ValidationError from core.dependencies import admin_function, viewer_function, editor_function -from db import Group +from db import Group, GroupThingAssociation, Thing +from db.engine import session_ctx from main import app from schemas import DT_FMT from schemas.group import ValidateGroup @@ -103,6 +104,55 @@ def test_get_groups(group): assert data["items"][0]["project_area"] == to_shape(group.project_area).wkt assert data["items"][0]["description"] == group.description assert data["items"][0]["parent_group_id"] == group.parent_group_id + assert data["items"][0]["well_count"] == 1 + + +def test_get_groups_well_count_excludes_non_water_wells( + group, water_well_thing, location, spring_thing +): + with session_ctx() as session: + second_well = Thing( + name="Second Test Well", + first_visit_date="2023-03-03", + thing_type="water well", + release_status="draft", + well_depth=10, + hole_depth=10, + well_casing_diameter=5.0, + well_casing_depth=10.0, + ) + session.add(second_well) + session.commit() + session.refresh(second_well) + + for thing_id in (second_well.id, spring_thing.id): + session.add(GroupThingAssociation(group_id=group.id, thing_id=thing_id)) + session.commit() + + response = client.get("/group") + assert response.status_code == 200 + data = response.json() + item = next(item for item in data["items"] if item["id"] == group.id) + assert item["well_count"] == 2 + + +def test_get_groups_well_count_zero_without_associations(): + payload = { + "release_status": "private", + "name": "Empty Project Group", + "description": "No associated wells.", + } + create_response = client.post("/group", json=payload) + assert create_response.status_code == 201 + group_id = create_response.json()["id"] + + response = client.get("/group") + assert response.status_code == 200 + data = response.json() + item = next(item for item in data["items"] if item["id"] == group_id) + assert item["well_count"] == 0 + + cleanup_post_test(Group, group_id) def test_get_group_by_id(group): @@ -118,6 +168,7 @@ def test_get_group_by_id(group): assert data["description"] == group.description assert data["parent_group_id"] == group.parent_group_id assert data["release_status"] == group.release_status + assert data["well_count"] == 1 def test_get_group_by_id_404_not_found(group): diff --git a/tests/test_thing.py b/tests/test_thing.py index d3444a7c2..8c8859c44 100644 --- a/tests/test_thing.py +++ b/tests/test_thing.py @@ -985,6 +985,67 @@ def test_get_water_wells_filter_contacts_nnull(water_well_thing, contact): assert water_well_thing.id in ids +def test_get_water_wells_filter_groups_eq_by_id(group, water_well_thing): + fl = json.dumps( + {"field": "groups", "operator": "eq", "value": str(group.id)}, + ) + response = client.get("/thing/water-well", params=[("filter", fl)]) + assert response.status_code == 200 + data = response.json() + ids = [item["id"] for item in data["items"]] + assert water_well_thing.id in ids + + +def test_get_water_wells_filter_groups_eq_by_id_no_match(group, water_well_thing): + fl = json.dumps( + {"field": "groups", "operator": "eq", "value": "999999"}, + ) + response = client.get("/thing/water-well", params=[("filter", fl)]) + assert response.status_code == 200 + data = response.json() + ids = [item["id"] for item in data["items"]] + assert water_well_thing.id not in ids + + +def test_get_water_wells_filter_groups_eq_by_name(group, water_well_thing): + fl = json.dumps( + {"field": "groups", "operator": "eq", "value": group.name}, + ) + response = client.get("/thing/water-well", params=[("filter", fl)]) + assert response.status_code == 200 + data = response.json() + ids = [item["id"] for item in data["items"]] + assert water_well_thing.id in ids + + +def test_get_water_wells_filter_groups_unsupported_operator(group): + fl = json.dumps( + {"field": "groups", "operator": "gt", "value": group.id}, + ) + response = client.get("/thing/water-well", params=[("filter", fl)]) + assert response.status_code == 400 + + +def test_get_water_wells_list_includes_groups(group, water_well_thing): + response = client.get("/thing/water-well", params={"page": 1, "size": 50}) + assert response.status_code == 200 + data = response.json() + well = next(item for item in data["items"] if item["id"] == water_well_thing.id) + assert len(well["groups"]) >= 1 + assert well["groups"][0]["name"] == group.name + + +def test_get_water_wells_sort_groups_asc(group, water_well_thing): + response = client.get( + "/thing/water-well", + params={ + "sort": "groups", + "order": "asc", + }, + ) + assert response.status_code == 200 + + def test_get_water_wells_sort_monitoring_status_desc(water_well_thing): """Derived status columns are Python properties; sort uses StatusHistory SQL.""" response = client.get(