diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index de1f0749c..608a59e9b 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -14,6 +14,7 @@ repos: '--show-source', '--statistics' ] + exclude: ^db/__init__.py$ # all models need to be imported for Alembic, but are not used directly # - repo: https://github.com/pre-commit/mirrors-mypy # rev: v1.10.0 # Use the latest stable version or pin to your preference diff --git a/alembic/env.py b/alembic/env.py index d94484ccd..d02a07101 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -46,7 +46,7 @@ def include_object(object, name, type_, reflected, compare_to): # only include tables in sql alchemy model, not auto-generated tables from PostGIS or TIGER - if type_ == "table": + if type_ == "table" or name.endswith("_version") or name == "transaction": return name in model_tables return True diff --git a/alembic/versions/66ac1af4ba69_initial_migration.py b/alembic/versions/66ac1af4ba69_initial_migration.py index 7bd24acdf..97a31f5d5 100644 --- a/alembic/versions/66ac1af4ba69_initial_migration.py +++ b/alembic/versions/66ac1af4ba69_initial_migration.py @@ -8,10 +8,7 @@ from typing import Sequence, Union -from alembic import op -import geoalchemy2 -import sqlalchemy as sa -import sqlalchemy_utils +# from alembic import op from sqlalchemy.orm import configure_mappers # revision identifiers, used by Alembic. @@ -39,6 +36,23 @@ def upgrade() -> None: # It is here as a record of the initial database state. # Actual initial database creation should be done through the Base.metadata.create_all(engine) call above. + """ + TODO + The following code will need to be regenerated by Alembic since configure_mappers() is now called + in db/__init__.py to ensure all models are loaded before creating the database schema. This is + require for SQL Alchemy continuum. + + The following code will also need to be added: + + - op.drop_index("idx_location_version_point", table_name="location_version", if_exists=True) + - before calling op.create_index("idx_location_version_point", "location_version", ["point"], unique=False, postgresql_using="gist",) + - op.drop_index("idx_location_point", table_name="location", if_exists=True) + - before calling op.create_index("idx_location_point", "location", ["point"], unique=False, postgresql_using="gist",) + + We will also need to figure out how to handle the SQL Alchemy searchable columns in the models, as they are not currently handled by Alembic. + There is some documentation about sync_triggers, but that has not yet been tested. + """ + # ### commands auto generated by Alembic - please adjust! ### # op.create_table('asset', # sa.Column('name', sa.String(), nullable=False), diff --git a/api/sample.py b/api/sample.py index c7a4f2928..e36daafd6 100644 --- a/api/sample.py +++ b/api/sample.py @@ -14,20 +14,19 @@ # limitations under the License. # =============================================================================== -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, HTTPException from sqlalchemy.orm import Session -from starlette.status import HTTP_201_CREATED +from starlette.status import HTTP_201_CREATED, HTTP_404_NOT_FOUND from api.pagination import CustomPage from core.dependencies import session_dependency from db import adder from db.engine import get_db_session from db.sample import Sample -from schemas_v2.sample import ( - SampleResponse, - CreateSample, -) +from schemas_v2 import ResourceNotFoundResponse +from schemas_v2.sample import SampleResponse, CreateSample, UpdateSample from services.query_helper import paginated_all_getter +from services.crud_helper import model_patcher router = APIRouter( prefix="/sample", @@ -36,7 +35,7 @@ # ============= Post ============================================= @router.post("", status_code=HTTP_201_CREATED) -def add_sample(sample_data: CreateSample, session: Session = Depends(get_db_session)): +def add_sample(sample_data: CreateSample, session: session_dependency): """ Endpoint to add a sample. """ @@ -69,6 +68,33 @@ def add_sample(sample_data: CreateSample, session: Session = Depends(get_db_sess # return adder(session, GeothermalSample, sample_data) +# ============= Update ============================================= +@router.patch("/{sample_id}", summary="Update Sample") +def update_sample( + sample_id: int, + sample_data: UpdateSample, + session: Session = Depends(get_db_session), +) -> SampleResponse | ResourceNotFoundResponse: + """ + Endpoint to update a sample. + """ + + """ + Development notes: + + What do we do if the field is nullable and the schema defaults to None? + If that occurs, then we update the field to None, which may not have + been the intension of the user. We could set some string to indicate + DO NOT UPDATE. Perhaps coordination between the front and backends? + """ + if session.get(Sample, sample_id) is None: + raise HTTPException( + status_code=HTTP_404_NOT_FOUND, + detail=f"Sample with ID {sample_id} not found.", + ) + return model_patcher(session, Sample, sample_id, sample_data) + + # ============= Get ============================================= @router.get("", summary="Get Samples") def get_samples(session: session_dependency) -> CustomPage[SampleResponse]: @@ -102,11 +128,20 @@ def get_samples(session: session_dependency) -> CustomPage[SampleResponse]: # ============= Get by ID ============================================= @router.get("/{sample_id}", summary="Get Sample by ID") -def get_sample_by_id(sample_id: int, session: session_dependency) -> SampleResponse: +def get_sample_by_id( + sample_id: int, session: session_dependency +) -> SampleResponse | ResourceNotFoundResponse: """ Endpoint to retrieve a sample by its ID. """ - return session.get(Sample, sample_id) + sample = session.get(Sample, sample_id) + if sample is None: + raise HTTPException( + status_code=HTTP_404_NOT_FOUND, + detail=f"Sample with ID {sample_id} not found.", + ) + else: + return sample # @router.get("/{sample_id}", summary="Get Geochemical Sample by ID") diff --git a/db/__init__.py b/db/__init__.py index 2aed05165..40af68323 100644 --- a/db/__init__.py +++ b/db/__init__.py @@ -14,20 +14,25 @@ # limitations under the License. # =============================================================================== -from db.asset import * +# import all models from db package so that Alembic can discover them + +from db.base import * from db.base import Base + +from db.asset import * from db.collabnet import * +from db.contact import * from db.geochronology import * +from db.geothermal import * +from db.group import * from db.lexicon import * from db.location import * from db.observation import * from db.publication import * from db.sample import * -from db.sensor import * +from db.sensor.groundwaterlevel import * +from db.sensor.sensor import * from db.thing import * -from db.contact import * -from db.group import * - from sqlalchemy import ( func, @@ -40,6 +45,9 @@ inspect_search_vectors, search_manager, ) +from sqlalchemy.orm import configure_mappers + +configure_mappers() def adder(session, table, model, **kwargs): diff --git a/schemas_v2/__init__.py b/schemas_v2/__init__.py index 8e546ddc2..a117e7097 100644 --- a/schemas_v2/__init__.py +++ b/schemas_v2/__init__.py @@ -13,5 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # =============================================================================== +from pydantic import BaseModel + + +class ResourceNotFoundResponse(BaseModel): + detail: str + # ============= EOF ============================================= diff --git a/schemas_v2/sample.py b/schemas_v2/sample.py index 47c45db03..f9d431a1b 100644 --- a/schemas_v2/sample.py +++ b/schemas_v2/sample.py @@ -13,9 +13,11 @@ # See the License for the specific language governing permissions and # limitations under the License. # =============================================================================== -from datetime import datetime +from datetime import datetime, timezone +from pydantic import BaseModel, field_validator -from pydantic import BaseModel +from db.engine import get_db_session +from db import Thing # -------- CREATE ---------- @@ -47,8 +49,39 @@ class CreateGeothermalSample(BaseModel): # -------- RESPONSE ---------- class SampleResponse(BaseModel): id: int + collection_timestamp: datetime + collection_method: str + thing_id: int # -------- UPDATE ---------- +class UpdateSample(BaseModel): + collection_timestamp: datetime | None = None + collection_method: str | None = None + thing_id: int | None = None + + @field_validator("thing_id") + def validate_thing_id_exists(cls, thing_id: int) -> int: + """ + Validate that the thing_id exists in the database. + """ + with next(get_db_session()) as session: + thing = session.get(Thing, thing_id) + if not thing: + raise ValueError(f"Thing with ID {thing_id} does not exist.") + return thing_id + + @field_validator("collection_timestamp") + def validate_collection_timestamp(cls, collection_timestamp: datetime) -> datetime: + """ + Validate that the collection_timestamp is not in the future. + """ + if collection_timestamp: + if collection_timestamp > datetime.now(tz=timezone.utc): + raise ValueError( + f"Collection timestamp {collection_timestamp} cannot be in the future." + ) + return collection_timestamp + # ============= EOF ============================================= diff --git a/services/query_helper.py b/services/query_helper.py index 350276fbc..373c543df 100644 --- a/services/query_helper.py +++ b/services/query_helper.py @@ -18,7 +18,7 @@ from fastapi import HTTPException from fastapi_pagination.ext.sqlalchemy import paginate -from sqlalchemy import select, Float, Integer, Column, Select, func +from sqlalchemy import select, Float, Integer, Column, Select from sqlalchemy.orm import DeclarativeBase from sqlalchemy.sql.elements import OperatorExpression @@ -100,6 +100,12 @@ def simple_get_by_id(session, table, item_id) -> object | None: """ Helper function to get a record by ID from the database. """ + """ + REFACTOR NOTE/TODO: this function replicates the functionality of + session.get(table, item_id), which is a SQL Alchemy method to retrieve + a record by its primary key. This function can be replaced with + session.get(table, item_id). + """ sql = select(table).where(table.id == item_id) result = session.execute(sql) return result.scalar_one_or_none() diff --git a/tests/__init__.py b/tests/__init__.py index bdca88b10..87d334000 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -13,22 +13,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # =============================================================================== -import uuid - -import pytest from fastapi.testclient import TestClient -from sqlalchemy.orm import configure_mappers -from core.app import init_lexicon, init_hypertables -from db.location import Location -from db.base import Base -from db.sample import Sample -from db.sensor import Sensor +from core.app import init_lexicon +from db import Base +from db.engine import engine from main import app -from db.engine import engine, session_ctx -from services.thing_helper import add_thing -configure_mappers() Base.metadata.drop_all(engine) Base.metadata.create_all(engine) @@ -38,66 +29,4 @@ client = TestClient(app) -@pytest.fixture(scope="session") -def location(): - with session_ctx() as session: - loc = Location(point="SRID=4326;POINT(0 0)") - session.add(loc) - session.commit() - session.refresh(loc) - yield loc - - session.close() - - -@pytest.fixture(scope="session") -def thing(location): - with session_ctx() as session: - # loc = Location(point='SRID=4326;POINT(0 0)') - # session.add(loc) - # session.commit() - # session.refresh(loc) - - wt = add_thing( - session, - { - "location_id": location.id, - "name": "Test Well", - }, - "water well", - ) - - yield wt - - session.close() - - -@pytest.fixture(scope="session") -def sample(thing): - with session_ctx() as session: - sample = Sample( - collection_timestamp="2025-01-01T00:00:00Z", - collection_method="manual", - thing_id=thing.id, - sample_type="groundwater", - sampler="Test Sampler", - ) - session.add(sample) - session.commit() - yield sample - - session.close() - - -@pytest.fixture(scope="session") -def sensor(): - with session_ctx() as session: - sensor = Sensor(name=f"Test Sensor {uuid.uuid4()}") - session.add(sensor) - session.commit() - yield sensor - - session.close() - - # ============= EOF ============================================= diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..7c14541df --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,62 @@ +import pytest +import uuid + +from db import * +from db.engine import session_ctx +from services.thing_helper import add_thing + + +@pytest.fixture(scope="session") +def location(): + with session_ctx() as session: + loc = Location(point="SRID=4326;POINT(0 0)") + session.add(loc) + session.commit() + session.refresh(loc) + yield loc + + session.close() + + +@pytest.fixture(scope="session") +def thing(location): + with session_ctx() as session: + wt = add_thing( + session, + { + "location_id": location.id, + "name": "Test Well", + }, + "water well", + ) + + yield wt + + session.close() + + +@pytest.fixture(scope="session") +def sample(thing): + with session_ctx() as session: + sample = Sample( + collection_timestamp="2025-01-01T00:00:00", + collection_method="manual", + thing_id=thing.id, + sample_type="groundwater", + sampler="Test Sampler", + ) + session.add(sample) + session.commit() + yield sample + + session.close() + + +@pytest.fixture(scope="session") +def sensor(): + with session_ctx() as session: + sensor = Sensor(name=f"Test Sensor {uuid.uuid4()}") + session.add(sensor) + session.commit() + yield sensor + session.close() diff --git a/tests/test_asset.py b/tests/test_asset.py index b5dfcf9e4..90803ea72 100644 --- a/tests/test_asset.py +++ b/tests/test_asset.py @@ -15,7 +15,7 @@ # =============================================================================== from api.asset import get_storage_bucket from core.app import app -from tests import client, thing, location +from tests import client class MockBlob: diff --git a/tests/test_contact.py b/tests/test_contact.py index 33bfcbdd7..7657fbb27 100644 --- a/tests/test_contact.py +++ b/tests/test_contact.py @@ -1,11 +1,6 @@ # from fastapi.testclient import TestClient # from main import app # from models import Base, engine -import pytest - -from db import Thing -from db.engine import get_db_session, session_ctx - # Base.metadata.drop_all(engine) # Base.metadata.create_all(engine) @@ -17,24 +12,13 @@ # ADD tests ====================================================== -@pytest.fixture(scope="function") -def thing(): - with session_ctx() as session: - thing = Thing(name="Test Thing", thing_type="water well") - session.add(thing) - session.commit() - yield - - session.close() - - def test_add_contact(thing): response = client.post( "/contact", json={ "name": "Test Contact", "role": "Owner", - "thing_id": 1, + "thing_id": thing.id, "emails": [{"email": "fasdfasdf@gmail.com", "email_type": "Primary"}], "phones": [{"phone_number": "+12345678901", "phone_type": "Primary"}], "addresses": [ @@ -83,7 +67,7 @@ def test_add_contact(thing): # assert data["phone"] == f"+1234567890{i}" -def test_phone_validation_fail(): +def test_phone_validation_fail(thing): for phone in [ "definitely not a phone", # "1234567890", @@ -100,7 +84,7 @@ def test_phone_validation_fail(): "/contact", json={ "name": "Test Contact 2", - "thing_id": 1, + "thing_id": thing.id, "role": "Primary", "emails": [{"email": "fasdfasdf@gmail.com", "email_type": "Primary"}], "phones": [{"phone_number": phone, "phone_type": "Primary"}], @@ -124,7 +108,7 @@ def test_phone_validation_fail(): assert detail["msg"] == f"Value error, Invalid phone number. {phone}" -def test_email_validation_fail(): +def test_email_validation_fail(thing): for email in [ "", @@ -137,7 +121,7 @@ def test_email_validation_fail(): "/contact", json={ "name": "Test ContactX", - "thing_id": 1, + "thing_id": thing.id, "role": "Primary", "emails": [{"email": email, "email_type": "Primary"}], "phones": [{"phone_number": "+12345678901", "phone_type": "Primary"}], diff --git a/tests/test_geospatial.py b/tests/test_geospatial.py index ba1c5bec0..9f54bedd3 100644 --- a/tests/test_geospatial.py +++ b/tests/test_geospatial.py @@ -13,6 +13,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # =============================================================================== +from pathlib import Path import pytest from db import Thing, Location, LocationThingAssociation @@ -76,6 +77,9 @@ def test_get_shapefile(): 'attachment; filename="things.zip"' == response.headers["Content-Disposition"] ) + for shapefile_ending in [".shp", ".shx", ".dbf", ".prj", ".zip"]: + Path(f"things{shapefile_ending}").unlink(missing_ok=True) + @pytest.mark.skip def test_get_locations_expand(): diff --git a/tests/test_group.py b/tests/test_group.py index 800fa78b1..6b7d88e71 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -1,5 +1,5 @@ import pytest -from tests import client, thing, location +from tests import client # ADD tests ====================================================== diff --git a/tests/test_observation.py b/tests/test_observation.py index 8dc31499e..cd43afba2 100644 --- a/tests/test_observation.py +++ b/tests/test_observation.py @@ -14,7 +14,7 @@ # limitations under the License. # =============================================================================== -from tests import client, sample, sensor, thing, location +from tests import client import pytest diff --git a/tests/test_sample.py b/tests/test_sample.py index 69fd2f276..76c7ab81b 100644 --- a/tests/test_sample.py +++ b/tests/test_sample.py @@ -14,18 +14,22 @@ # limitations under the License. # =============================================================================== import pytest +from datetime import datetime +from db.engine import session_ctx +from db.sample import Sample from tests import client -def test_add_sample(): +# ============= Post tests for samples ============================================= +def test_add_sample(thing): """ Test adding a sample to the collaborative network. """ response = client.post( "/sample", json={ - "thing_id": 1, + "thing_id": thing.id, "collection_timestamp": "2025-01-01T00:00:00Z", "collection_method": "manual", "release_status": "draft", @@ -36,7 +40,13 @@ def test_add_sample(): data = response.json() assert response.status_code == 201 assert data["id"] is not None - assert data["thing_id"] == 1 + assert data["thing_id"] == thing.id + + # cleanup after adding the sample + sample_id = data["id"] + with session_ctx() as session: + session.query(Sample).where(Sample.id == sample_id).delete() + session.commit() @pytest.mark.skip(reason="Geochemical sample endpoint not implemented yet") @@ -73,16 +83,123 @@ def test_add_geothermal_sample(): assert data["sample_id"] == 1 +# ============= Patch tests for samples ============================================= +def test_patch_sample(sample): + """ + Test updating a sample in the collaborative network. + """ + original_method_patch = sample.collection_method + original_timestamp_patch = sample.collection_timestamp + + collection_method_patch = "continuous" + collection_timestamp_patch = "2025-01-02T00:00:00+00:00" + response = client.patch( + f"/sample/{sample.id}", + json={ + "collection_method": collection_method_patch, + "collection_timestamp": collection_timestamp_patch, + }, + ) + assert response.status_code == 200 + data = response.json() + assert data == { + "id": sample.id, + "collection_timestamp": collection_timestamp_patch.split("+")[0], + "collection_method": collection_method_patch, + "thing_id": sample.thing_id, + } + + # cleanup after patching the sample + with session_ctx() as session: + updated_sample = session.query(Sample).filter(Sample.id == sample.id).one() + updated_sample.collection_method = original_method_patch + updated_sample.collection_timestamp = original_timestamp_patch + session.commit() + + +def test_patch_sample_404_not_found(sample): + """ + Test updating a sample that does not exist in the collaborative network. + """ + collection_method_patch = "continuous" + response = client.patch( + "/sample/999", + json={ + "collection_method": collection_method_patch, + }, + ) + assert response.status_code == 404 + data = response.json() + assert data["detail"] == "Sample with ID 999 not found." + + +def test_patch_sample_422_thing_id_not_found(sample): + """ + Test updating a sample with a thing_id that does not exist + """ + bad_thing_id = 999 + response = client.patch( + f"/sample/{sample.id}", + json={ + "thing_id": bad_thing_id, + }, + ) + assert response.status_code == 422 + data = response.json() + assert data["detail"] == [ + { + "type": "value_error", + "loc": ["body", "thing_id"], + "msg": f"Value error, Thing with ID {bad_thing_id} does not exist.", + "input": bad_thing_id, + "ctx": {"error": {}}, + } + ] + + +def test_patch_sample_422_invalid_timestamp(sample): + """ + Test updating a sample with an invalid collection timestamp. + """ + bad_collection_timestamp = "3500-01-01T00:00:00Z" + bad_collection_timestamp_dt = datetime.fromisoformat( + bad_collection_timestamp.replace("Z", "+00:00") + ) + response = client.patch( + f"/sample/{sample.id}", + json={ + "collection_timestamp": bad_collection_timestamp, # Invalid date + }, + ) + assert response.status_code == 422 + data = response.json() + assert data["detail"] == [ + { + "type": "value_error", + "loc": ["body", "collection_timestamp"], + "msg": f"Value error, Collection timestamp {bad_collection_timestamp_dt} cannot be in the future.", + "input": bad_collection_timestamp, + "ctx": {"error": {}}, + } + ] + + # ============= Get tests for samples ============================================= -def test_get_samples(): +def test_get_samples(sample): """ Test retrieving samples from the collaborative network. """ response = client.get("/sample") assert response.status_code == 200 data = response.json() - assert "items" in data - assert len(data["items"]) > 0 + assert data["items"] == [ + { + "id": sample.id, + "collection_timestamp": sample.collection_timestamp, + "collection_method": sample.collection_method, + "thing_id": sample.thing_id, + } + ] @pytest.mark.skip(reason="Geochemical samples endpoint not implemented yet") @@ -109,14 +226,29 @@ def test_get_geothermal_samples(): assert len(data["items"]) > 0 -def test_get_sample_by_id(): +def test_get_sample_by_id(sample): """ Test retrieving a sample from the collaborative network. """ - response = client.get("/sample/1") + response = client.get(f"/sample/{sample.id}") assert response.status_code == 200 data = response.json() - assert data["id"] == 1 + assert data == { + "id": sample.id, + "collection_timestamp": sample.collection_timestamp, + "collection_method": sample.collection_method, + "thing_id": sample.thing_id, + } + + +def test_get_sample_by_id_404_not_found(sample): + """ + Test retrieving a sample from the collaborative network. + """ + response = client.get("/sample/999") + assert response.status_code == 404 + data = response.json() + assert data["detail"] == "Sample with ID 999 not found." # ============= EOF ============================================= diff --git a/tests/test_search.py b/tests/test_search.py index 9f1927de7..cb4364d70 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -13,8 +13,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # =============================================================================== -import pprint - import pytest from sqlalchemy import select @@ -24,13 +22,13 @@ from tests import client -def test_search_api(): +def test_search_api(thing, sample): response = client.get("/search", params={"q": "Test"}) assert response.status_code == 200 data = response.json() assert isinstance(data, list) - assert len(data) == 5 + assert len(data) == 2 @pytest.mark.skip(reason="This test is not working .")