From fbe33fe0765ed2eb206c8d3e9443bbf8cbfc4f48 Mon Sep 17 00:00:00 2001 From: Sandra Guerreiro Date: Tue, 3 Mar 2026 11:32:50 +0100 Subject: [PATCH] feat(tutor): clean endpoint --- src/app/api/api_v1/endpoints/tutor.py | 177 +++++-------------------- src/app/services/abst_chat.py | 21 +++ src/app/services/tutor/tutor.py | 11 +- src/app/tests/api/api_v1/test_tutor.py | 45 ------- 4 files changed, 55 insertions(+), 199 deletions(-) delete mode 100644 src/app/tests/api/api_v1/test_tutor.py diff --git a/src/app/api/api_v1/endpoints/tutor.py b/src/app/api/api_v1/endpoints/tutor.py index 48d1209..a8a6325 100644 --- a/src/app/api/api_v1/endpoints/tutor.py +++ b/src/app/api/api_v1/endpoints/tutor.py @@ -12,11 +12,11 @@ ) from src.app.api.dependencies import get_settings +from src.app.core.config import Settings from src.app.models.search import EnhancedSearchQuery from src.app.services.abst_chat import get_chat_service from src.app.services.data_collection import get_data_collection_service from src.app.services.exceptions import NoResultsError -from src.app.services.helpers import extract_json_from_response from src.app.services.search import SearchService, get_search_service from src.app.services.search_helpers import search_multi_inputs from src.app.services.tutor.agents import TEMPLATES @@ -43,8 +43,6 @@ router = APIRouter() -settings = get_settings() - def backoff_hdlr(details): logger.info( @@ -54,17 +52,21 @@ def backoff_hdlr(details): ) +def with_backoff(): + return backoff.on_exception( + wait_gen=backoff.expo, + exception=Exception, + logger=logger, + max_tries=3, + max_time=180, + jitter=backoff.random_jitter, + on_backoff=backoff_hdlr, + factor=2, + ) + + +@with_backoff() @router.post("/files/content") -@backoff.on_exception( - wait_gen=backoff.expo, - exception=Exception, - logger=logger, - max_tries=3, - max_time=180, - jitter=backoff.random_jitter, - on_backoff=backoff_hdlr, - factor=2, -) async def extract_files_content( files: Annotated[list[UploadFile], File()], response: Response, @@ -88,21 +90,13 @@ async def extract_files_content( ] try: - summaries = await chatfactory.chat_client.completion(messages=messages) - - assert isinstance(summaries, str) - json_summaries = extract_json_from_response(summaries) - assert isinstance(json_summaries, list) - - try: - summaries_output = ExtractorOutputList(extracts=json_summaries) - return summaries_output - except Exception: - formatted_output = await chatfactory.json_formatter_agent( - summaries, - "{extracts: [ 'summary': 'Summary', 'themes': [{'theme': 'Theme 1', 'reason': 'Reason for Theme 1'}, {'theme': 'Theme 2', 'reason': 'Reason for Theme 2'}, ...]}, { 'summary': 'Sumamry', 'themes': [{'theme': 'Theme 1', 'reason': 'Reason for Theme 1'}, {'theme': 'Theme 2', 'reason': 'Reason for Theme 2'}, ...]}] }", - ) - return formatted_output + summaries_output = await chatfactory.run_llm_with_json_parsing( + messages, + ExtractorOutputList, + fallback_formatter="{extracts: [ 'summary': 'Summary', 'themes': [{'theme': 'Theme 1', 'reason': 'Reason for Theme 1'}, {'theme': 'Theme 2', 'reason': 'Reason for Theme 2'}, ...]}, { 'summary': 'Sumamry', 'themes': [{'theme': 'Theme 1', 'reason': 'Reason for Theme 1'}, {'theme': 'Theme 2', 'reason': 'Reason for Theme 2'}, ...]}] }", + ) + + return summaries_output except Exception as e: logger.error(f"Error in extractor schema: {e}") @@ -111,6 +105,7 @@ async def extract_files_content( @router.post("/search_extracts") +@with_backoff() async def tutor_search_extract( summaries: SummariesList, background_tasks: BackgroundTasks, @@ -157,121 +152,17 @@ async def tutor_search_extract( return resp -@router.post("/search", deprecated=True) -async def tutor_search( - files: Annotated[list[UploadFile], File()], - response: Response, - sp: SearchService = Depends(get_search_service), - chatfactory=Depends(get_chat_service), -): - files_content = await get_files_content(files) - - doc_list_to_string = "Document {doc_nb}: {content}" - - file_content_str = [ - doc_list_to_string.format( - doc_nb=index + 1, - content=content, - ) - for index, content in enumerate(files_content) - ] - file_content_str = "\n\n".join(file_content_str) - - messages = [ - {"role": "system", "content": extractor_system_prompt}, - {"role": "user", "content": file_content_str}, - ] - - try: - themes_extracted = await chatfactory.chat_client.completion( - messages=messages, response_format=ExtractorOutputList - ) - - jsn = {} - if isinstance(themes_extracted, str): - jsn = extract_json_from_response(themes_extracted) - elif isinstance(themes_extracted, dict): - jsn = themes_extracted - else: - raise ValueError("Unexpected response format") - - print(jsn) - themes_extracted = ExtractorOutputList(**jsn) - - except Exception as e: - logger.error(f"Error in chat schema: {e}") - response.status_code = 204 - return TutorSearchResponse( - extracts=[], - nb_results=0, - documents=[], - ) - - if not themes_extracted or not themes_extracted.extracts: - return TutorSearchResponse( - extracts=[], - nb_results=0, - documents=[], - ) - - inputs = [doc.summary for doc in themes_extracted.extracts] # type: ignore - - try: - qp = EnhancedSearchQuery( - query=inputs, - nb_results=10, - sdg_filter=None, - corpora=None, - ) - - search_results = await search_multi_inputs( - qp=qp, - callback_function=sp.search_handler, - ) - except NoResultsError as e: - response.status_code = 404 - logger.error(f"No results found: {e}") - return TutorSearchResponse( - extracts=themes_extracted.extracts, - nb_results=0, - documents=[], - ) - - if not search_results: - return TutorSearchResponse( - extracts=themes_extracted.extracts, - nb_results=0, - documents=[], - ) - - resp = TutorSearchResponse( - extracts=themes_extracted.extracts, - nb_results=len(search_results), - documents=search_results, - ) - - return resp - - -@backoff.on_exception( - wait_gen=backoff.expo, - exception=Exception, - logger=logger, - max_tries=3, - max_time=180, - jitter=backoff.random_jitter, - on_backoff=backoff_hdlr, - factor=2, -) +@with_backoff() @router.post("/syllabus") async def create_syllabus( request: Request, body: TutorSyllabusRequest, lang: str = "en", data_collection=Depends(get_data_collection_service), + settings: Settings = Depends(get_settings), ) -> SyllabusResponse: session_id = request.headers.get("X-Session-ID") - results = await tutor_manager(body, lang) + results = await tutor_manager(body, lang, settings) # TODO: handle errors @@ -329,16 +220,7 @@ async def create_syllabus( """ -@backoff.on_exception( - wait_gen=backoff.expo, - exception=Exception, - logger=logger, - max_tries=3, - max_time=180, - jitter=backoff.random_jitter, - on_backoff=backoff_hdlr, - factor=2, -) +@with_backoff() @router.post("/syllabus/feedback") async def handle_syllabus_feedback( request: Request, @@ -359,7 +241,7 @@ async def handle_syllabus_feedback( syllabus=body.syllabus[0], feedback=body.feedback, documents=body.documents, - extracts=("/n").join([extract.summary for extract in body.extracts]), + extracts="\n".join([extract.summary for extract in body.extracts]), themes=(", ").join( [ (", ").join([theme["theme"] for theme in extract.themes]) @@ -373,7 +255,8 @@ async def handle_syllabus_feedback( try: syllabus = await chatfactory.chat_client.completion(messages=messages) - assert isinstance(syllabus, str) + if not isinstance(syllabus, str): + raise ValueError("Syllabus feedback response is not a string") await data_collection.register_syllabus_data( session_id=session_id, diff --git a/src/app/services/abst_chat.py b/src/app/services/abst_chat.py index a913dcb..dbaa38d 100644 --- a/src/app/services/abst_chat.py +++ b/src/app/services/abst_chat.py @@ -463,6 +463,27 @@ async def agent_message( return res + async def run_llm_with_json_parsing( + self, + messages: list[dict], + model_class, + fallback_formatter: str | None = None, + ): + raw = await self.chat_client.completion(messages=messages) + + if not isinstance(raw, str): + raise ValueError("LLM response must be string") + + try: + json_data = extract_json_from_response(raw) + if not isinstance(json_data, dict): + raise ValueError("Extracted JSON data is not a dictionary") + return model_class(**json_data) + except Exception: + if fallback_formatter: + return await self.json_formatter_agent(raw, fallback_formatter) + raise + async def get_llm_client(request: Request): return request.app.state.llm diff --git a/src/app/services/tutor/tutor.py b/src/app/services/tutor/tutor.py index 8981026..f6dfba7 100644 --- a/src/app/services/tutor/tutor.py +++ b/src/app/services/tutor/tutor.py @@ -1,6 +1,6 @@ from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel # type: ignore -from src.app.api.dependencies import get_settings +from src.app.core.config import Settings from src.app.services.tutor.agents import ( PedagogicalEngineerAgent, SDGExpertAgent, @@ -13,9 +13,6 @@ ) from src.app.services.tutor.utils import extract_doc_info -settings = get_settings() - - GREENCOMP_COMPETENCIES = ( "Here are the GreenComp competencies: " "url: https://joint-research-centre.ec.europa.eu/greencomp-european-sustainability-competence-framework_en " @@ -46,7 +43,7 @@ ) -def _build_chat_model() -> AzureAIChatCompletionsModel: +def _build_chat_model(settings: Settings) -> AzureAIChatCompletionsModel: return AzureAIChatCompletionsModel( endpoint=settings.AZURE_APIM_API_BASE, credential=settings.AZURE_APIM_API_KEY, @@ -56,7 +53,7 @@ def _build_chat_model() -> AzureAIChatCompletionsModel: async def tutor_manager( - content: TutorSyllabusRequest, lang: str + content: TutorSyllabusRequest, lang: str, settings: Settings ) -> list[SyllabusResponseAgent]: formatted_content = MessageWithResources( lang=lang, @@ -70,7 +67,7 @@ async def tutor_manager( description=content.description, ) - chat_model = _build_chat_model() + chat_model = _build_chat_model(settings=settings) teacher_agent = UniversityTeacherAgent(chat_model, lang) sdg_agent = SDGExpertAgent(chat_model, GREENCOMP_COMPETENCIES, lang) pedagogical_agent = PedagogicalEngineerAgent( diff --git a/src/app/tests/api/api_v1/test_tutor.py b/src/app/tests/api/api_v1/test_tutor.py deleted file mode 100644 index a304740..0000000 --- a/src/app/tests/api/api_v1/test_tutor.py +++ /dev/null @@ -1,45 +0,0 @@ -import io -from unittest import IsolatedAsyncioTestCase, mock - -from fastapi.testclient import TestClient - -from src.app.core.config import settings -from src.main import app - -# client = TestClient(app) - - -@mock.patch("src.app.services.sql_db.sql_service.session_maker") -@mock.patch( - "src.app.services.security.check_api_key_sync", - new=mock.MagicMock(return_value=True), -) -class TutorTests(IsolatedAsyncioTestCase): - def test_tutor_no_files(self, *mocks): - with TestClient(app) as client: - response = client.post( - f"{settings.API_V1_STR}/tutor/search", - files={}, - headers={"x-API-Key": "test"}, - ) - assert response.status_code == 422 - - def test_tutor_empty_file(self, *mocks): - file = io.BytesIO(b"") - with TestClient(app) as client: - response = client.post( - f"{settings.API_V1_STR}/tutor/search", - files={"files": ("test.txt", file)}, - headers={"x-API-Key": "test"}, - ) - self.assertEqual(response.status_code, 400) - - def test_tutor_file(self, *mocks): - file = io.BytesIO(b"this is a test file") - with TestClient(app) as client: - response = client.post( - f"{settings.API_V1_STR}/tutor/search", - files={"files": ("test.txt", file)}, - headers={"x-API-Key": "test"}, - ) - self.assertEqual(response.status_code, 204)