Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -153,5 +153,6 @@ dmypy.json
cython_debug/

# Misc
db.sqlite3:
db.sqlite3
db.test.sqlite3
.DS_STORE
5 changes: 4 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,8 @@
"source.fixAll": "explicit",
"source.organizeImports": "explicit"
}
}
},
"python.testing.pytestArgs": ["tests"],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
}
6 changes: 3 additions & 3 deletions app/applicants/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from starlette import status

import app.applicants.services as service
from app.applicants.schemas import Applicant, ApplicantPayload
from app.applicants.schemas import Applicant, ApplicantBase, ApplicantPayload
from app.db import get_db

router = APIRouter(
Expand Down Expand Up @@ -35,15 +35,15 @@ async def get_applicant(id: int, db: db_session):
@router.put(
"/applicants/{id}", status_code=status.HTTP_200_OK, response_model=Applicant
)
async def update_applicant(id: int, applicant: Applicant, db: db_session):
async def update_applicant(id: int, applicant: ApplicantBase, db: db_session):
db_applicant = service.update_item(db, id, applicant)
return db_applicant


@router.post(
"/applicants", status_code=status.HTTP_201_CREATED, response_model=Applicant
)
async def create_applicant(applicant: Applicant, db: db_session):
async def create_applicant(applicant: ApplicantBase, db: db_session):
db_applicant = service.create_item(db, applicant)
return db_applicant

Expand Down
9 changes: 6 additions & 3 deletions app/applicants/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@


# Pydantic Models
class Applicant(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: int | None = None
class ApplicantBase(BaseModel):
first_name: str
last_name: str
middle_name: str | None = None
Expand All @@ -21,6 +19,11 @@ class Applicant(BaseModel):
state: str | None = None
zip: str | None = None
country: str = "USA"


class Applicant(ApplicantBase):
model_config = ConfigDict(from_attributes=True)
id: int | None = None
created_at: datetime
updated_at: datetime

Expand Down
8 changes: 5 additions & 3 deletions app/applicants/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from sqlalchemy.orm import Session

from app.applicants.models import DBApplicant
from app.applicants.schemas import Applicant
from app.applicants.schemas import ApplicantBase
from app.utils import get_next_page, get_page_count, get_prev_page


Expand All @@ -25,7 +25,7 @@ def get_item(db: Session, applicant_id: int):
return db.query(DBApplicant).where(DBApplicant.id == applicant_id).first()


def update_item(db: Session, id: int, applicant: Applicant):
def update_item(db: Session, id: int, applicant: ApplicantBase):
db_applicant = db.query(DBApplicant).filter(DBApplicant.id == id).first()
if db_applicant is None:
raise HTTPException(status_code=404, detail="Applicant not founds")
Expand Down Expand Up @@ -53,8 +53,10 @@ def update_item(db: Session, id: int, applicant: Applicant):
return db_applicant


def create_item(db: Session, applicant: Applicant):
def create_item(db: Session, applicant: ApplicantBase):
db_applicant = DBApplicant(**applicant.model_dump())
db_applicant.created_at = datetime.now()
db_applicant.updated_at = datetime.now()
db.add(db_applicant)
db.commit()
db.refresh(db_applicant)
Expand Down
6 changes: 3 additions & 3 deletions app/cases/router.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from starlette import status

import app.cases.services as service
from app.cases.schemas import Case, CasePayload, CaseWithApplicant
from app.cases.schemas import Case, CaseBase, CasePayload, CaseWithApplicant
from app.db import get_db

router = APIRouter(
Expand All @@ -31,13 +31,13 @@ async def get_case(id: int, db: db_session):


@router.put("/cases/{id}", status_code=status.HTTP_200_OK, response_model=Case)
async def update_case(id: int, case: Case, db: db_session):
async def update_case(id: int, case: CaseBase, db: db_session):
db_case = service.update_item(db, id, case)
return db_case


@router.post("/cases", status_code=status.HTTP_201_CREATED, response_model=Case)
async def create_case(case: Case, db: db_session):
async def create_case(case: CaseBase, db: db_session):
db_case = service.create_item(db, case)
return db_case

Expand Down
19 changes: 9 additions & 10 deletions app/cases/schemas.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,26 @@
from datetime import datetime
from typing import Literal

from pydantic import BaseModel, ConfigDict

from app.applicants.schemas import Applicant


# Pydantic Models
class Case(BaseModel):
model_config = ConfigDict(from_attributes=True)
id: int | None = None
status: str
class CaseBase(BaseModel):
status: Literal["Not Started", "In Progress", "Approved", "Denied"]
assigned_to: str | None = None
created_at: datetime
updated_at: datetime
applicant_id: int | None = None


class CaseWithApplicant(BaseModel):
id: int
status: str
assigned_to: str | None = None
class Case(CaseBase):
model_config = ConfigDict(from_attributes=True)
id: int | None = None
created_at: datetime
updated_at: datetime


class CaseWithApplicant(Case):
applicant: Applicant | None = None


Expand Down
8 changes: 5 additions & 3 deletions app/cases/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from sqlalchemy.orm import Session, joinedload

from app.cases.models import DBCase
from app.cases.schemas import Case
from app.cases.schemas import CaseBase
from app.utils import get_next_page, get_page_count, get_prev_page


Expand Down Expand Up @@ -65,7 +65,7 @@ def get_item(db: Session, case_id: int):
}


def update_item(db: Session, id: int, case: Case):
def update_item(db: Session, id: int, case: CaseBase):
db_case = db.query(DBCase).filter(DBCase.id == id).first()
if db_case is None:
raise HTTPException(status_code=404, detail="Case not founds")
Expand All @@ -80,8 +80,10 @@ def update_item(db: Session, id: int, case: Case):
return db_case


def create_item(db: Session, case: Case):
def create_item(db: Session, case: CaseBase):
db_case = DBCase(**case.model_dump())
db_case.created_at = datetime.now()
db_case.updated_at = datetime.now()
db.add(db_case)
db.commit()
db.refresh(db_case)
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ dependencies = [
]

[project.optional-dependencies]
dev = ["coverage", "pytest", "ruff"]
dev = ["coverage", "pytest", "pytest-asyncio", "ruff"]

[tool.setuptools]
packages = ["app"]
Expand Down
92 changes: 87 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,92 @@
import os
import sys
from collections.abc import Generator
from typing import Any

import pytest
from fastapi import FastAPI
from fastapi.testclient import TestClient
from sqlalchemy import create_engine
from sqlalchemy.orm import sessionmaker

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# this is to include backend dir in sys.path so that we can import from db,main.py

from app.admin.router import router as admin_router
from app.applicants.router import router as applicants_router
from app.cases.router import router as cases_router
from app.db import Base, get_db
from app.health.router import router as health_router
from app.users.router import router as users_router


def start_application():
app = FastAPI()
app.include_router(cases_router)
app.include_router(applicants_router)
app.include_router(users_router)
app.include_router(admin_router)
app.include_router(health_router)
return app


SQLALCHEMY_DATABASE_URL = "sqlite:///./db.test.sqlite3"
engine = create_engine(
SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False}
)
# Use connect_args parameter only with sqlite
SessionTesting = sessionmaker(autocommit=False, autoflush=False, bind=engine)


@pytest.fixture(scope="function")
def app() -> Generator[FastAPI, Any, None]:
"""
Create a fresh database on each test case.
"""
Base.metadata.create_all(engine) # Create the tables.
_app = start_application()
yield _app
Base.metadata.drop_all(engine)


@pytest.fixture(scope="function")
def db_session(app: FastAPI) -> Generator[SessionTesting, Any, None]: # type: ignore
connection = engine.connect()
transaction = connection.begin()
session = SessionTesting(bind=connection)
yield session # use the session in tests.
session.close()
transaction.rollback()
connection.close()


@pytest.fixture(scope="function")
def client(
app: FastAPI, db_session: SessionTesting # type: ignore
) -> Generator[TestClient, Any, None]:
"""
Create a new FastAPI TestClient that uses the `db_session` fixture to override
the `get_db` dependency that is injected into routes.
"""

def _get_test_db():
try:
yield db_session
finally:
pass

from app.main import app
app.dependency_overrides[get_db] = _get_test_db
with TestClient(app) as client:
yield client


@pytest.fixture(scope="module")
def client():
client = TestClient(app)
yield client
def generalize_json_data(data):
"""
Helper function to generalize JSON data for testing by removing
auto-generated fields like id, created_at, updated_at.
"""
new_data = data.copy()
new_data.pop("id", None)
new_data.pop("created_at", None)
new_data.pop("updated_at", None)
return new_data
57 changes: 39 additions & 18 deletions tests/test_applicants.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import pytest

from tests.conftest import generalize_json_data

base_applicant = {
"id": 0,
"first_name": "John",
"last_name": "Doe",
"middle_name": "A",
Expand All @@ -14,56 +17,74 @@
"state": "CA",
"zip": "12345",
"country": "USA",
"created_at": "2023-10-01T00:00:00",
"updated_at": "2023-10-01T00:00:00",
}


def test_create_applicant(client):
async def seed_data(client):
client.post("/api/applicants/", json=base_applicant)


@pytest.mark.asyncio
async def test_create_applicant(client):
response = client.post("/api/applicants/", json=base_applicant)
response_json = generalize_json_data(response.json())
assert response.status_code == 201
assert response.json() == base_applicant
assert response_json == base_applicant


def test_get_all_applicants(client):
@pytest.mark.asyncio
async def test_get_all_applicants(client):
await seed_data(client)
response = client.get("/api/applicants")
assert response.status_code == 200
assert len(response.json()) > 0


def test_get_applicants_paged(client):
@pytest.mark.asyncio
async def test_get_applicants_paged(client):
await seed_data(client)
response = client.get("/api/applicants?page_number=0&page_size=10")
assert response.status_code == 200
assert len(response.json()) > 0


def test_get_applicant(client):
response = client.get("/api/applicants/0")
@pytest.mark.asyncio
async def test_get_applicant(client):
await seed_data(client)
response = client.get("/api/applicants/1")
response_json = generalize_json_data(response.json())
assert response.status_code == 200
assert response.json() == base_applicant
assert response_json == base_applicant


def test_update_applicant(client):
@pytest.mark.asyncio
async def test_update_applicant(client):
await seed_data(client)
updated_applicant = base_applicant.copy()
updated_applicant["middle_name"] = "test"

response = client.put("/api/applicants/0", json=updated_applicant)
response_json = response.json()
response_json["updated_at"] = updated_applicant["updated_at"]
response = client.put("/api/applicants/1", json=updated_applicant)
response_json = generalize_json_data(response.json())
assert response.status_code == 200
assert response_json == updated_applicant


def test_update_applicant_invalid_id(client):
@pytest.mark.asyncio
async def test_update_applicant_invalid_id(client):
await seed_data(client)
response = client.put("/api/applicants/-1", json=base_applicant)
assert response.status_code == 404


def test_delete_applicant(client):
response = client.delete("/api/applicants/0")
@pytest.mark.asyncio
async def test_delete_applicant(client):
await seed_data(client)
response = client.delete("/api/applicants/1")
assert response.status_code == 204


def test_delete_applicant_invalid_id(client):
@pytest.mark.asyncio
async def test_delete_applicant_invalid_id(client):
await seed_data(client)
response = client.delete("/api/applicants/-1")
assert response.status_code == 404
Loading