diff --git a/data_access_layer/topics_db_functions.py b/data_access_layer/topics_db_functions.py index 4021538..d05fafa 100644 --- a/data_access_layer/topics_db_functions.py +++ b/data_access_layer/topics_db_functions.py @@ -33,3 +33,21 @@ def get_topics(): raise f"Connection error occurred: {c}" except Exception as e: raise f"An unexpected error occurred: {e}" + + +def add_new_topic(topic: str): + try: + collection = setup_mongodb(client=None, collection_name='topics') + collection.insert_one({"name": topic}) + except errors.ConnectionFailure as c: + raise f"Connection error occurred: {c}" + except Exception as e: + raise f"An unexpected error occurred: {e}" + +def check_topic_and_delete(topic:str): + collection = setup_mongodb(client=None, collection_name='topics') + topic_exists = collection.find_one({"name": topic}) + if topic_exists: + collection.delete_one({"name": topic}) + return True + return False diff --git a/routes/openai_route.py b/routes/openai_route.py index 2786039..316b745 100644 --- a/routes/openai_route.py +++ b/routes/openai_route.py @@ -1,4 +1,5 @@ from fastapi import APIRouter, Response, HTTPException +from data_access_layer.topics_db_functions import get_topics from globals import globals from data_access_layer import users_db_functions from services.openai_service import get_question_and_answer, evaluate_answer @@ -16,10 +17,14 @@ async def gen_question(body: GenBody, response: Response): max_attempts = 5 attempts = 0 + allowed_topics = get_topics() + if body.topic not in allowed_topics: + raise HTTPException(status_code=400, detail=f"Invalid topic: {body.topic}. Must be one of {allowed_topics}.") + try: while attempts < max_attempts: answer = get_question_and_answer(topic, difficulty, answers_num) - if answers_num == len(answer['Answer']): + if (answers_num and answers_num == len(answer['Answer'])) or (not answers_num and len(answer['Answer']) == 1): return answer attempts += 1 @@ -30,7 +35,6 @@ async def gen_question(body: GenBody, response: Response): response.status_code = 400 return {"error": str(e)} - @router.post('/evaluate') async def evaluate_question(body: QARequest,ai_answer:str, response: Response): try: @@ -41,9 +45,10 @@ async def evaluate_question(body: QARequest,ai_answer:str, response: Response): answer = body.answer evaluation_score = evaluate_answer(question=question_text, answer=answer,ai_answer=ai_answer) users_db_functions.add_user_stats(user_id=user_id, question_text=question_text, answer=answer, topic=topic, - difficulty=difficulty, - score=evaluation_score["Score"], answer_correct=(evaluation_score["Score"] >= 5), - client=globals.mongo_client) + difficulty=difficulty, + score=evaluation_score["Score"], + answer_correct=(evaluation_score["Score"] >= 5), + client=globals.mongo_client) evaluation_score["question"] = question_text evaluation_score["user_answer"] = answer return evaluation_score diff --git a/routes/questions_route.py b/routes/questions_route.py index 00944ca..8d4a3fe 100644 --- a/routes/questions_route.py +++ b/routes/questions_route.py @@ -1,6 +1,6 @@ -from fastapi import APIRouter, Form +from fastapi import APIRouter, Form, HTTPException from data_access_layer import questions_db_functions -from data_access_layer.topics_db_functions import get_topics +from data_access_layer.topics_db_functions import get_topics, add_new_topic from pymongo import errors router = APIRouter() @@ -11,7 +11,7 @@ async def store_data(question: str = Form(...), answer: str = Form(...), topic: explanation: str = Form(...), difficulty: str = Form(...), user_name: str = Form(...), user_id: str = Form(...)): response = questions_db_functions.store_data(question=question, topic=topic, answer=answer, explanation=explanation, - difficulty=difficulty, user_name=user_name, user_id=user_id) + difficulty=difficulty, user_name=user_name, user_id=user_id) return response @@ -25,3 +25,22 @@ async def load_topics(): raise f"Connection error occurred: {c}" except Exception as e: raise f"An unexpected error occurred: {e}" + + +@router.post('/topics') +async def add_topic(topic: str): + try: + topics = get_topics() + if topic not in topics: + add_new_topic(topic) + return f"{topic} added successfully" + else: + raise HTTPException(status_code=409, detail=f"{topic} is already in the list topics") + except HTTPException as h_e: + print(h_e) + raise h_e + except errors.ConnectionFailure as c: + raise f"Connection error occurred: {c}" + except Exception as e: + raise f"An unexpected error occurred: {e}" + diff --git a/tests/test_openai_service.py b/tests/test_openai_service.py index 732a820..e845992 100644 --- a/tests/test_openai_service.py +++ b/tests/test_openai_service.py @@ -70,6 +70,13 @@ def test_question_generation_wrong_difficulty(): response = requests.post(url, json={"topic": "python", "difficulty": "HARD"}) assert response.status_code == 422 +def test_question_generation_invalid_topic(): + server_url = os.getenv("SERVER_URL") + assert server_url is not None + url = f"{server_url}/question/generate" + response = requests.post(url, json={"topic": "invalid_topic"}) + assert response.status_code == 400 + assert "Invalid topic: invalid_topic. Must be one of" in response.json()['detail'] def test_gen_question_answers_num(): server_url = os.getenv("SERVER_URL") diff --git a/tests/test_questions.py b/tests/test_questions.py index d1901de..f5c145e 100644 --- a/tests/test_questions.py +++ b/tests/test_questions.py @@ -2,7 +2,7 @@ import requests from data_access_layer.questions_db_functions import check_question_existence_and_delete, store_data from globals import globals - +from data_access_layer.topics_db_functions import * def test_store_data(): data = { @@ -25,6 +25,7 @@ def test_store_data(): # Check if the question was added to the database and then delete it assert check_question_existence_and_delete(data=data, client=globals.mongo_client) + def test_load_topics_success(): server_url = os.getenv("SERVER_URL") assert server_url is not None @@ -33,3 +34,22 @@ def test_load_topics_success(): assert response.status_code == 200 assert isinstance(response.json(), list) assert all(isinstance(topic, str) for topic in response.json()) + + +def test_add_topic(): + server_url = os.getenv("SERVER_URL") + assert server_url is not None + url = f"{server_url}/question/topics?topic=Java" + response = requests.post(url) + assert response.status_code == 200 + assert response.json() == "Java added successfully" + assert check_topic_and_delete("Java") + +def test_add_exists_topic(): + server_url = os.getenv("SERVER_URL") + assert server_url is not None + url = f"{server_url}/question/topics?topic=python" + response = requests.post(url) + assert response.status_code == 409 + assert response.json()['detail'] == 'python is already in the list topics' +