From 463417666a92df588b3389ab6558bf76f340a97e Mon Sep 17 00:00:00 2001 From: Johnny Bouder Date: Tue, 3 Jun 2025 15:56:10 -0400 Subject: [PATCH 1/2] Updates to simplify endpoint payload. Fix tests. --- .vscode/settings.json | 5 ++++- app/applicants/router.py | 6 +++--- app/applicants/schemas.py | 9 ++++++--- app/applicants/services.py | 8 +++++--- app/cases/router.py | 6 +++--- app/cases/schemas.py | 19 +++++++++---------- app/cases/services.py | 8 +++++--- tests/conftest.py | 12 ++++++++++++ tests/test_applicants.py | 20 ++++++++++---------- tests/test_cases.py | 36 ++++++++++++++---------------------- 10 files changed, 71 insertions(+), 58 deletions(-) diff --git a/.vscode/settings.json b/.vscode/settings.json index b995766..eb789a0 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -5,5 +5,8 @@ "source.fixAll": "explicit", "source.organizeImports": "explicit" } - } + }, + "python.testing.pytestArgs": ["tests"], + "python.testing.unittestEnabled": false, + "python.testing.pytestEnabled": true } diff --git a/app/applicants/router.py b/app/applicants/router.py index f9b4305..79b3971 100644 --- a/app/applicants/router.py +++ b/app/applicants/router.py @@ -5,7 +5,7 @@ from starlette import status import app.applicants.services as service -from app.applicants.schemas import Applicant, ApplicantPayload +from app.applicants.schemas import Applicant, ApplicantBase, ApplicantPayload from app.db import get_db router = APIRouter( @@ -35,7 +35,7 @@ async def get_applicant(id: int, db: db_session): @router.put( "/applicants/{id}", status_code=status.HTTP_200_OK, response_model=Applicant ) -async def update_applicant(id: int, applicant: Applicant, db: db_session): +async def update_applicant(id: int, applicant: ApplicantBase, db: db_session): db_applicant = service.update_item(db, id, applicant) return db_applicant @@ -43,7 +43,7 @@ async def update_applicant(id: int, applicant: Applicant, db: db_session): @router.post( "/applicants", status_code=status.HTTP_201_CREATED, response_model=Applicant ) -async def create_applicant(applicant: Applicant, db: db_session): +async def create_applicant(applicant: ApplicantBase, db: db_session): db_applicant = service.create_item(db, applicant) return db_applicant diff --git a/app/applicants/schemas.py b/app/applicants/schemas.py index e4e485d..0da4c30 100644 --- a/app/applicants/schemas.py +++ b/app/applicants/schemas.py @@ -4,9 +4,7 @@ # Pydantic Models -class Applicant(BaseModel): - model_config = ConfigDict(from_attributes=True) - id: int | None = None +class ApplicantBase(BaseModel): first_name: str last_name: str middle_name: str | None = None @@ -21,6 +19,11 @@ class Applicant(BaseModel): state: str | None = None zip: str | None = None country: str = "USA" + + +class Applicant(ApplicantBase): + model_config = ConfigDict(from_attributes=True) + id: int | None = None created_at: datetime updated_at: datetime diff --git a/app/applicants/services.py b/app/applicants/services.py index 7b728a8..7b78483 100644 --- a/app/applicants/services.py +++ b/app/applicants/services.py @@ -4,7 +4,7 @@ from sqlalchemy.orm import Session from app.applicants.models import DBApplicant -from app.applicants.schemas import Applicant +from app.applicants.schemas import ApplicantBase from app.utils import get_next_page, get_page_count, get_prev_page @@ -25,7 +25,7 @@ def get_item(db: Session, applicant_id: int): return db.query(DBApplicant).where(DBApplicant.id == applicant_id).first() -def update_item(db: Session, id: int, applicant: Applicant): +def update_item(db: Session, id: int, applicant: ApplicantBase): db_applicant = db.query(DBApplicant).filter(DBApplicant.id == id).first() if db_applicant is None: raise HTTPException(status_code=404, detail="Applicant not founds") @@ -53,8 +53,10 @@ def update_item(db: Session, id: int, applicant: Applicant): return db_applicant -def create_item(db: Session, applicant: Applicant): +def create_item(db: Session, applicant: ApplicantBase): db_applicant = DBApplicant(**applicant.model_dump()) + db_applicant.created_at = datetime.now() + db_applicant.updated_at = datetime.now() db.add(db_applicant) db.commit() db.refresh(db_applicant) diff --git a/app/cases/router.py b/app/cases/router.py index 2d539cf..16581ed 100644 --- a/app/cases/router.py +++ b/app/cases/router.py @@ -5,7 +5,7 @@ from starlette import status import app.cases.services as service -from app.cases.schemas import Case, CasePayload, CaseWithApplicant +from app.cases.schemas import Case, CaseBase, CasePayload, CaseWithApplicant from app.db import get_db router = APIRouter( @@ -31,13 +31,13 @@ async def get_case(id: int, db: db_session): @router.put("/cases/{id}", status_code=status.HTTP_200_OK, response_model=Case) -async def update_case(id: int, case: Case, db: db_session): +async def update_case(id: int, case: CaseBase, db: db_session): db_case = service.update_item(db, id, case) return db_case @router.post("/cases", status_code=status.HTTP_201_CREATED, response_model=Case) -async def create_case(case: Case, db: db_session): +async def create_case(case: CaseBase, db: db_session): db_case = service.create_item(db, case) return db_case diff --git a/app/cases/schemas.py b/app/cases/schemas.py index de65b4d..302aa04 100644 --- a/app/cases/schemas.py +++ b/app/cases/schemas.py @@ -1,4 +1,5 @@ from datetime import datetime +from typing import Literal from pydantic import BaseModel, ConfigDict @@ -6,22 +7,20 @@ # Pydantic Models -class Case(BaseModel): - model_config = ConfigDict(from_attributes=True) - id: int | None = None - status: str +class CaseBase(BaseModel): + status: Literal["Not Started", "In Progress", "Approved", "Denied"] assigned_to: str | None = None - created_at: datetime - updated_at: datetime applicant_id: int | None = None -class CaseWithApplicant(BaseModel): - id: int - status: str - assigned_to: str | None = None +class Case(CaseBase): + model_config = ConfigDict(from_attributes=True) + id: int | None = None created_at: datetime updated_at: datetime + + +class CaseWithApplicant(Case): applicant: Applicant | None = None diff --git a/app/cases/services.py b/app/cases/services.py index 571a4e0..957f67e 100644 --- a/app/cases/services.py +++ b/app/cases/services.py @@ -4,7 +4,7 @@ from sqlalchemy.orm import Session, joinedload from app.cases.models import DBCase -from app.cases.schemas import Case +from app.cases.schemas import CaseBase from app.utils import get_next_page, get_page_count, get_prev_page @@ -65,7 +65,7 @@ def get_item(db: Session, case_id: int): } -def update_item(db: Session, id: int, case: Case): +def update_item(db: Session, id: int, case: CaseBase): db_case = db.query(DBCase).filter(DBCase.id == id).first() if db_case is None: raise HTTPException(status_code=404, detail="Case not founds") @@ -80,8 +80,10 @@ def update_item(db: Session, id: int, case: Case): return db_case -def create_item(db: Session, case: Case): +def create_item(db: Session, case: CaseBase): db_case = DBCase(**case.model_dump()) + db_case.created_at = datetime.now() + db_case.updated_at = datetime.now() db.add(db_case) db.commit() db.refresh(db_case) diff --git a/tests/conftest.py b/tests/conftest.py index 6f89ab7..bb35b33 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -8,3 +8,15 @@ def client(): client = TestClient(app) yield client + + +def generalize_json_data(data): + """ + Helper function to generalize JSON data for testing by removing + auto-generated fields like id, created_at, updated_at. + """ + new_data = data.copy() + new_data.pop("id", None) + new_data.pop("created_at", None) + new_data.pop("updated_at", None) + return new_data diff --git a/tests/test_applicants.py b/tests/test_applicants.py index f9a716a..b3c7175 100644 --- a/tests/test_applicants.py +++ b/tests/test_applicants.py @@ -1,5 +1,6 @@ +from tests.conftest import generalize_json_data + base_applicant = { - "id": 0, "first_name": "John", "last_name": "Doe", "middle_name": "A", @@ -14,15 +15,14 @@ "state": "CA", "zip": "12345", "country": "USA", - "created_at": "2023-10-01T00:00:00", - "updated_at": "2023-10-01T00:00:00", } def test_create_applicant(client): response = client.post("/api/applicants/", json=base_applicant) + response_json = generalize_json_data(response.json()) assert response.status_code == 201 - assert response.json() == base_applicant + assert response_json == base_applicant def test_get_all_applicants(client): @@ -38,18 +38,18 @@ def test_get_applicants_paged(client): def test_get_applicant(client): - response = client.get("/api/applicants/0") + response = client.get("/api/applicants/1") + response_json = generalize_json_data(response.json()) assert response.status_code == 200 - assert response.json() == base_applicant + assert response_json == base_applicant def test_update_applicant(client): updated_applicant = base_applicant.copy() updated_applicant["middle_name"] = "test" - response = client.put("/api/applicants/0", json=updated_applicant) - response_json = response.json() - response_json["updated_at"] = updated_applicant["updated_at"] + response = client.put("/api/applicants/1", json=updated_applicant) + response_json = generalize_json_data(response.json()) assert response.status_code == 200 assert response_json == updated_applicant @@ -60,7 +60,7 @@ def test_update_applicant_invalid_id(client): def test_delete_applicant(client): - response = client.delete("/api/applicants/0") + response = client.delete("/api/applicants/1") assert response.status_code == 204 diff --git a/tests/test_cases.py b/tests/test_cases.py index cd50b70..c7b0edf 100644 --- a/tests/test_cases.py +++ b/tests/test_cases.py @@ -1,26 +1,17 @@ +from tests.conftest import generalize_json_data + base_case = { - "id": 0, "status": "Not Started", "assigned_to": "Test User", - "created_at": "2023-10-01T00:00:00", - "updated_at": "2023-10-01T00:00:00", "applicant_id": 0, } -base_case_with_applicant = { - "id": 0, - "status": "Not Started", - "assigned_to": "Test User", - "created_at": "2023-10-01T00:00:00", - "updated_at": "2023-10-01T00:00:00", - "applicant": None, # This will be None if no applicant is associated -} - def test_create_case(client): response = client.post("/api/cases/", json=base_case) + response_json = generalize_json_data(response.json()) assert response.status_code == 201 - assert response.json() == base_case + assert response_json == base_case def test_get_all_cases(client): @@ -35,19 +26,20 @@ def test_get_cases_paged(client): assert len(response.json()) > 0 -def test_get_cases(client): - response = client.get("/api/cases/0") +def test_get_case(client): + response = client.get("/api/cases/1") + response_json = generalize_json_data(response.json()) + response_json["applicant_id"] = 0 + response_json.pop("applicant") assert response.status_code == 200 - assert response.json() == base_case_with_applicant + assert response_json == base_case def test_update_case(client): updated_case = base_case.copy() - updated_case["assigned_to"] = "test user" - - response = client.put("/api/cases/0", json=updated_case) - response_json = response.json() - response_json["updated_at"] = updated_case["updated_at"] + updated_case["status"] = "In Progress" + response = client.put("/api/cases/1", json=updated_case) + response_json = generalize_json_data(response.json()) assert response.status_code == 200 assert response_json == updated_case @@ -58,7 +50,7 @@ def test_update_case_invalid_id(client): def test_delete_case(client): - response = client.delete("/api/cases/0") + response = client.delete("/api/cases/1") assert response.status_code == 204 From 3fc06e2aeec8227a7f483233705285423afe17e9 Mon Sep 17 00:00:00 2001 From: Johnny Bouder Date: Wed, 4 Jun 2025 09:31:49 -0400 Subject: [PATCH 2/2] Updates to use test db for unit tests, instead of real database. --- .gitignore | 3 +- pyproject.toml | 2 +- tests/conftest.py | 80 +++++++++++++++++++++++++++++++++++++--- tests/test_applicants.py | 37 +++++++++++++++---- tests/test_cases.py | 37 +++++++++++++++---- tests/test_users.py | 58 +++++++++++++++++++---------- 6 files changed, 175 insertions(+), 42 deletions(-) diff --git a/.gitignore b/.gitignore index 0233d24..85abf38 100644 --- a/.gitignore +++ b/.gitignore @@ -153,5 +153,6 @@ dmypy.json cython_debug/ # Misc -db.sqlite3: +db.sqlite3 +db.test.sqlite3 .DS_STORE diff --git a/pyproject.toml b/pyproject.toml index e8b3cb7..3921d6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -17,7 +17,7 @@ dependencies = [ ] [project.optional-dependencies] -dev = ["coverage", "pytest", "ruff"] +dev = ["coverage", "pytest", "pytest-asyncio", "ruff"] [tool.setuptools] packages = ["app"] diff --git a/tests/conftest.py b/tests/conftest.py index bb35b33..9cec1ce 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,13 +1,83 @@ +import os +import sys +from collections.abc import Generator +from typing import Any + import pytest +from fastapi import FastAPI from fastapi.testclient import TestClient +from sqlalchemy import create_engine +from sqlalchemy.orm import sessionmaker + +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +# this is to include backend dir in sys.path so that we can import from db,main.py + +from app.admin.router import router as admin_router +from app.applicants.router import router as applicants_router +from app.cases.router import router as cases_router +from app.db import Base, get_db +from app.health.router import router as health_router +from app.users.router import router as users_router + + +def start_application(): + app = FastAPI() + app.include_router(cases_router) + app.include_router(applicants_router) + app.include_router(users_router) + app.include_router(admin_router) + app.include_router(health_router) + return app + + +SQLALCHEMY_DATABASE_URL = "sqlite:///./db.test.sqlite3" +engine = create_engine( + SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False} +) +# Use connect_args parameter only with sqlite +SessionTesting = sessionmaker(autocommit=False, autoflush=False, bind=engine) + -from app.main import app +@pytest.fixture(scope="function") +def app() -> Generator[FastAPI, Any, None]: + """ + Create a fresh database on each test case. + """ + Base.metadata.create_all(engine) # Create the tables. + _app = start_application() + yield _app + Base.metadata.drop_all(engine) + + +@pytest.fixture(scope="function") +def db_session(app: FastAPI) -> Generator[SessionTesting, Any, None]: # type: ignore + connection = engine.connect() + transaction = connection.begin() + session = SessionTesting(bind=connection) + yield session # use the session in tests. + session.close() + transaction.rollback() + connection.close() + + +@pytest.fixture(scope="function") +def client( + app: FastAPI, db_session: SessionTesting # type: ignore +) -> Generator[TestClient, Any, None]: + """ + Create a new FastAPI TestClient that uses the `db_session` fixture to override + the `get_db` dependency that is injected into routes. + """ + def _get_test_db(): + try: + yield db_session + finally: + pass -@pytest.fixture(scope="module") -def client(): - client = TestClient(app) - yield client + app.dependency_overrides[get_db] = _get_test_db + with TestClient(app) as client: + yield client def generalize_json_data(data): diff --git a/tests/test_applicants.py b/tests/test_applicants.py index b3c7175..057e3e6 100644 --- a/tests/test_applicants.py +++ b/tests/test_applicants.py @@ -1,3 +1,5 @@ +import pytest + from tests.conftest import generalize_json_data base_applicant = { @@ -18,33 +20,46 @@ } -def test_create_applicant(client): +async def seed_data(client): + client.post("/api/applicants/", json=base_applicant) + + +@pytest.mark.asyncio +async def test_create_applicant(client): response = client.post("/api/applicants/", json=base_applicant) response_json = generalize_json_data(response.json()) assert response.status_code == 201 assert response_json == base_applicant -def test_get_all_applicants(client): +@pytest.mark.asyncio +async def test_get_all_applicants(client): + await seed_data(client) response = client.get("/api/applicants") assert response.status_code == 200 assert len(response.json()) > 0 -def test_get_applicants_paged(client): +@pytest.mark.asyncio +async def test_get_applicants_paged(client): + await seed_data(client) response = client.get("/api/applicants?page_number=0&page_size=10") assert response.status_code == 200 assert len(response.json()) > 0 -def test_get_applicant(client): +@pytest.mark.asyncio +async def test_get_applicant(client): + await seed_data(client) response = client.get("/api/applicants/1") response_json = generalize_json_data(response.json()) assert response.status_code == 200 assert response_json == base_applicant -def test_update_applicant(client): +@pytest.mark.asyncio +async def test_update_applicant(client): + await seed_data(client) updated_applicant = base_applicant.copy() updated_applicant["middle_name"] = "test" @@ -54,16 +69,22 @@ def test_update_applicant(client): assert response_json == updated_applicant -def test_update_applicant_invalid_id(client): +@pytest.mark.asyncio +async def test_update_applicant_invalid_id(client): + await seed_data(client) response = client.put("/api/applicants/-1", json=base_applicant) assert response.status_code == 404 -def test_delete_applicant(client): +@pytest.mark.asyncio +async def test_delete_applicant(client): + await seed_data(client) response = client.delete("/api/applicants/1") assert response.status_code == 204 -def test_delete_applicant_invalid_id(client): +@pytest.mark.asyncio +async def test_delete_applicant_invalid_id(client): + await seed_data(client) response = client.delete("/api/applicants/-1") assert response.status_code == 404 diff --git a/tests/test_cases.py b/tests/test_cases.py index c7b0edf..57d25a1 100644 --- a/tests/test_cases.py +++ b/tests/test_cases.py @@ -1,3 +1,5 @@ +import pytest + from tests.conftest import generalize_json_data base_case = { @@ -7,26 +9,37 @@ } -def test_create_case(client): +async def seed_data(client): + client.post("/api/cases/", json=base_case) + + +@pytest.mark.asyncio +async def test_create_case(client): response = client.post("/api/cases/", json=base_case) response_json = generalize_json_data(response.json()) assert response.status_code == 201 assert response_json == base_case -def test_get_all_cases(client): +@pytest.mark.asyncio +async def test_get_all_cases(client): + await seed_data(client) response = client.get("/api/cases") assert response.status_code == 200 assert len(response.json()) > 0 -def test_get_cases_paged(client): +@pytest.mark.asyncio +async def test_get_cases_paged(client): + await seed_data(client) response = client.get("/api/cases?page_number=0&page_size=10") assert response.status_code == 200 assert len(response.json()) > 0 -def test_get_case(client): +@pytest.mark.asyncio +async def test_get_case(client): + await seed_data(client) response = client.get("/api/cases/1") response_json = generalize_json_data(response.json()) response_json["applicant_id"] = 0 @@ -35,7 +48,9 @@ def test_get_case(client): assert response_json == base_case -def test_update_case(client): +@pytest.mark.asyncio +async def test_update_case(client): + await seed_data(client) updated_case = base_case.copy() updated_case["status"] = "In Progress" response = client.put("/api/cases/1", json=updated_case) @@ -44,16 +59,22 @@ def test_update_case(client): assert response_json == updated_case -def test_update_case_invalid_id(client): +@pytest.mark.asyncio +async def test_update_case_invalid_id(client): + await seed_data(client) response = client.put("/api/cases/-1", json=base_case) assert response.status_code == 404 -def test_delete_case(client): +@pytest.mark.asyncio +async def test_delete_case(client): + await seed_data(client) response = client.delete("/api/cases/1") assert response.status_code == 204 -def test_delete_case_invalid_id(client): +@pytest.mark.asyncio +async def test_delete_case_invalid_id(client): + await seed_data(client) response = client.delete("/api/cases/-1") assert response.status_code == 404 diff --git a/tests/test_users.py b/tests/test_users.py index 7ce4a70..2660de9 100644 --- a/tests/test_users.py +++ b/tests/test_users.py @@ -1,21 +1,27 @@ +import pytest base_date = "2021-01-01T00:00:00.000000" base_user = { - "id": 0, - "user_id": "testuser", - "first_name": "Test", - "last_name": "User", - "display_name": "Test User", - "email": "testuser1@test.com", - "is_active": True, - "created": base_date, - "created_by": "System Account", - "modified": base_date, - "modified_by": "System Account" + "id": 0, + "user_id": "testuser", + "first_name": "Test", + "last_name": "User", + "display_name": "Test User", + "email": "testuser1@test.com", + "is_active": True, + "created": base_date, + "created_by": "System Account", + "modified": base_date, + "modified_by": "System Account", } -def test_create_user(client): +async def seed_data(client): + client.post("/api/users/", json=base_user) + + +@pytest.mark.asyncio +async def test_create_user(client): response = client.post("/api/users/", json=base_user) response_json = response.json() response_json["created"] = base_date @@ -25,19 +31,25 @@ def test_create_user(client): assert response_json == base_user -def test_get_all_users(client): +@pytest.mark.asyncio +async def test_get_all_users(client): + await seed_data(client) response = client.get("/api/users") assert response.status_code == 200 assert len(response.json()) > 0 -def test_get_users_paged(client): +@pytest.mark.asyncio +async def test_get_users_paged(client): + await seed_data(client) response = client.get("/api/users?page_number=0&page_size=10") assert response.status_code == 200 assert len(response.json()) > 0 -def test_get_user(client): +@pytest.mark.asyncio +async def test_get_user(client): + await seed_data(client) response = client.get("/api/users/0") response_json = response.json() response_json["created"] = base_date @@ -47,7 +59,9 @@ def test_get_user(client): assert response_json == base_user -def test_update_user(client): +@pytest.mark.asyncio +async def test_update_user(client): + await seed_data(client) updated_user = base_user.copy() updated_user["is_active"] = False @@ -60,16 +74,22 @@ def test_update_user(client): assert response_json == updated_user -def test_update_user_invalid_id(client): +@pytest.mark.asyncio +async def test_update_user_invalid_id(client): + await seed_data(client) response = client.put("/api/users/-1", json=base_user) assert response.status_code == 404 -def test_delete_user(client): +@pytest.mark.asyncio +async def test_delete_user(client): + await seed_data(client) response = client.delete("/api/users/0") assert response.status_code == 204 -def test_delete_user_invalid_id(client): +@pytest.mark.asyncio +async def test_delete_user_invalid_id(client): + await seed_data(client) response = client.delete("/api/users/-1") assert response.status_code == 404