diff --git a/api/location.py b/api/location.py index 1d63b23c5..666e09fe3 100644 --- a/api/location.py +++ b/api/location.py @@ -13,9 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # =============================================================================== -from typing import Union - -from fastapi import Depends, Query +from fastapi import Depends, Query, Response from fastapi_pagination.ext.sqlalchemy import paginate from sqlalchemy import select, func from sqlalchemy.orm import Session @@ -27,10 +25,9 @@ from db.location import Location from db.engine import get_db_session from schemas.location import CreateLocation, LocationResponse, UpdateLocation -from schemas.thing import LocationWellResponse from services.geospatial_helper import make_within_wkt -from services.query_helper import make_query, order_sort_filter -from services.crud_helper import model_patcher +from services.query_helper import make_query, order_sort_filter, simple_get_by_id +from services.crud_helper import model_patcher, model_deleter from fastapi import APIRouter @@ -132,11 +129,10 @@ async def get_location( nearby_distance_km: float = 1, within: str = None, query: str = None, - expand: str = None, sort: str = None, order: str = None, filter_: str = Query(alias="filter", default=None), -) -> CustomPage[Union[LocationResponse, LocationWellResponse]]: +) -> CustomPage[LocationResponse]: """ Retrieve all wells from the database. """ @@ -154,17 +150,9 @@ async def get_location( elif within: sql = make_within_wkt(sql, within) - if expand == "well": - pass - - def transformer(items): - if expand == "well": - return [LocationWellResponse.model_validate(item) for item in items] - return [LocationResponse.model_validate(item) for item in items] - sql = order_sort_filter(sql, Location, sort, order, filter_) - return paginate(query=sql, conn=session, transformer=transformer) + return paginate(query=sql, conn=session) @router.get( @@ -172,24 +160,23 @@ def transformer(items): summary="Get location by ID", ) async def get_location_by_id( - location_id: int, expand: str = None, session: Session = Depends(get_db_session) -) -> LocationResponse | LocationWellResponse: + location_id: int, session: Session = Depends(get_db_session) +) -> LocationResponse: """ Retrieve a sample location by ID from the database. """ - sql = select(Location).where(Location.id == location_id) + location = simple_get_by_id(session, Location, location_id) + return location - result = session.execute(sql) - location = result.scalar_one_or_none() - if not location: - return {"message": "Location not found"} - - response_klass = LocationResponse - if expand == "well": - response_klass = LocationWellResponse - - return response_klass.model_validate(location) +@router.delete("/{location_id}", summary="Delete location by ID") +async def delete_location( + location_id: int, session: Session = Depends(get_db_session) +) -> Response: + """ + Delete a sample location by ID from the database. + """ + return model_deleter(session, Location, location_id) # ============= EOF ============================================= diff --git a/schemas/location.py b/schemas/location.py index af48e8f9d..b81795cec 100644 --- a/schemas/location.py +++ b/schemas/location.py @@ -21,7 +21,7 @@ from schemas import ORMBaseModel """ -REFACTOR TODO +TODO Create common validator classes to be shared amongst create and update schemas. Since many fields are optional in the update schemas set check_fields=False in the field_validator. diff --git a/services/crud_helper.py b/services/crud_helper.py index ac9f37a95..623abfce1 100644 --- a/services/crud_helper.py +++ b/services/crud_helper.py @@ -13,10 +13,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # =============================================================================== -from fastapi import HTTPException +from fastapi import Response from pydantic import BaseModel from sqlalchemy.orm import Session, DeclarativeBase -from starlette.status import HTTP_404_NOT_FOUND +from starlette.status import HTTP_204_NO_CONTENT from services.query_helper import simple_get_by_id @@ -35,4 +35,12 @@ def model_patcher( return item +def model_deleter(session: Session, model: DeclarativeBase, item_id: int): + # simple_get_by_id raises HTTP_404_NOT_FOUND if the item is not found + item = simple_get_by_id(session, model, item_id) + session.delete(item) + session.commit() + return Response(status_code=HTTP_204_NO_CONTENT) + + # ============= EOF ============================================= diff --git a/tests/test_geospatial.py b/tests/test_geospatial.py index e9cff5d34..11d5f9060 100644 --- a/tests/test_geospatial.py +++ b/tests/test_geospatial.py @@ -57,6 +57,13 @@ def populate(): session.add(LocationThingAssociation(location=loc2, thing=thing2)) session.commit() + yield + + # Cleanup + session.delete(loc1) + session.delete(loc2) + session.commit() + def test_get_geojson(): response = client.get("/geospatial", params={"format": "geojson"}) diff --git a/tests/test_location.py b/tests/test_location.py index a73910077..ee98770d1 100644 --- a/tests/test_location.py +++ b/tests/test_location.py @@ -13,40 +13,62 @@ # See the License for the specific language governing permissions and # limitations under the License. # =============================================================================== +from geoalchemy2.shape import to_shape import pytest +from db import Location +from db.engine import session_ctx from tests import client +# ============= module & function fixtures ======================================= + + +@pytest.fixture(scope="function") +def second_location(): + with session_ctx() as session: + location = Location( + name="second location", + point="POINT (10.2 10.2)", + release_status="draft", + ) + session.add(location) + session.commit() + yield location + session.delete(location) + session.commit() + + +# ============= Post tests for locations ====================================== + def test_add_location(): - response = client.post( - "/location", - json={ - "name": "Test Location 3", - "point": "POINT(10.1 10.1)", - # "visible": True, - }, - ) - assert response.status_code == 201 - data = response.json() - assert "id" in data + payload = { + "name": "test location", + "point": "POINT (10.1 10.1)", + "release_status": "draft", + } + response = client.post("/location", json=payload) - response = client.post( - "/location", - json={ - "name": "Test Location 4", - "point": "POINT(50.0 50.0)", - # "visible": False, - }, - ) assert response.status_code == 201 data = response.json() assert "id" in data + assert data["name"] == payload["name"] + assert data["point"] == payload["point"] + assert data["release_status"] == payload["release_status"] + + # cleanup after test + with session_ctx() as session: + session.delete(session.get(Location, data["id"])) + session.commit() -def test_update_location(): +# ============= Patch tests for locations ===================================== + + +def test_update_location(location): + location_id = location.id response = client.patch( - "/location/1", + f"/location/{location_id}", json={ "point": "POINT (10.1 20.2)", "release_status": "draft", @@ -54,35 +76,90 @@ def test_update_location(): ) assert response.status_code == 200 data = response.json() - assert "id" in data + assert data["id"] == location_id assert data["point"] == "POINT (10.1 20.2)" assert data["release_status"] == "draft" + # cleanup after test + with session_ctx() as session: + updated_location = session.get(Location, location_id) + updated_location.point = location.point + updated_location.release_status = location.release_status + session.commit() -@pytest.mark.skip -def test_get_locations_expand(): - response = client.get("/base/location?expand=well") + +def test_patch_location_404_not_found(location): + """ + Testing updating a location that does not exist + """ + bad_location_id = 99999 + location_name_patch = "another test name" + response = client.patch( + f"/location/{bad_location_id}", json={"name": location_name_patch} + ) + data = response.json() + assert response.status_code == 404 + assert data["detail"] == f"Location with ID {bad_location_id} not found." + + +# ============= GET tests for locations ======================================= + + +def test_get_locations(location): + """ + Test retrieving locations + """ + response = client.get("/location") assert response.status_code == 200 data = response.json() - assert "items" in data - assert len(data["items"]) > 0 - for item in data["items"]: - assert "id" in item - assert "point" in item - assert "well" in item + assert data["total"] == 1 + assert data["items"][0]["id"] == location.id + assert data["items"][0]["name"] == location.name + assert data["items"][0]["point"] == to_shape(location.point).wkt + assert data["items"][0]["release_status"] == location.release_status -@pytest.mark.skip -def test_get_location_expand(): - response = client.get("/base/location/1", params={"expand": "well"}) +def test_get_location_by_id(location): + response = client.get(f"/location/{location.id}") assert response.status_code == 200 data = response.json() - assert "id" in data - assert data["id"] == 1 - assert "point" in data - assert data["point"] == "POINT (10.1 10.1)" - assert "well" in data - assert len(data["well"]) == 1 + assert data["id"] == location.id + assert data["name"] == location.name + assert data["point"] == to_shape(location.point).wkt + assert data["release_status"] == location.release_status + + +def test_get_sample_by_id_404_not_found(location): + bad_location_id = 999999999 + response = client.get(f"/location/{bad_location_id}") + data = response.json() + assert response.status_code == 404 + assert data["detail"] == f"Location with ID {bad_location_id} not found." + + +# ============= DELETE tests for locations ==================================== + + +def test_delete_location(second_location): + response = client.delete(f"/location/{second_location.id}") + assert response.status_code == 204 + + # Verify the location is deleted + response = client.get(f"/location/{second_location.id}") + assert response.status_code == 404 + data = response.json() + assert data["detail"] == f"Location with ID {second_location.id} not found." + + +def test_delete_location_404_not_found(second_location): + """ + Testing deleting a location that does not exist + """ + bad_location_id = 99999 + response = client.delete(f"/location/{bad_location_id}") + data = response.json() + assert response.status_code == 404 + assert data["detail"] == f"Location with ID {bad_location_id} not found." # ============= EOF ============================================= diff --git a/tests/test_thing.py b/tests/test_thing.py index f8e79132b..181dcf481 100644 --- a/tests/test_thing.py +++ b/tests/test_thing.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # =============================================================================== -import pytest - from tests import client from main import app from core.dependencies import well_user_function @@ -46,7 +44,7 @@ def test_add_group(): assert data["name"] == "collabnet" -def test_add_well(): +def test_add_well(location): # response = client.post( # "/lexicon/add", json={"term": "Monitoring", "definition": "Monitoring Well"} # ) @@ -60,7 +58,7 @@ def test_add_well(): "/thing", json={ "thing_type": "water well", - "location_id": 1, + "location_id": location.id, "name": "Test Well", "api_id": "1001-0001", "ose_pod_id": "RA-0001", @@ -79,7 +77,7 @@ def test_add_well(): "/thing", json={ "thing_type": "water well", - "location_id": 2, + "location_id": location.id, "name": "Test Well 2", "api_id": "1001-0002", "ose_pod_id": "RA-0002",