Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 30 additions & 147 deletions src/app/api/api_v1/endpoints/tutor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@
)

from src.app.api.dependencies import get_settings
from src.app.core.config import Settings
from src.app.models.search import EnhancedSearchQuery
from src.app.services.abst_chat import get_chat_service
from src.app.services.data_collection import get_data_collection_service
from src.app.services.exceptions import NoResultsError
from src.app.services.helpers import extract_json_from_response
from src.app.services.search import SearchService, get_search_service
from src.app.services.search_helpers import search_multi_inputs
from src.app.services.tutor.agents import TEMPLATES
Expand All @@ -43,8 +43,6 @@

router = APIRouter()

settings = get_settings()


def backoff_hdlr(details):
logger.info(
Expand All @@ -54,17 +52,21 @@ def backoff_hdlr(details):
)


def with_backoff(max_tries: int = 3, max_time: int = 180):
    """Build a retry decorator with exponential backoff and jitter.

    Args:
        max_tries: Maximum number of attempts before giving up.
        max_time: Maximum total time in seconds to keep retrying.

    Returns:
        A decorator (from ``backoff.on_exception``) that retries the wrapped
        callable on any ``Exception``, doubling the wait between attempts
        (``factor=2``) with random jitter, and logging each retry through
        ``backoff_hdlr``.
    """
    # NOTE(review): retrying on bare Exception also retries non-transient
    # failures (validation errors, programming bugs) — consider narrowing
    # the exception type to the transient errors actually expected here.
    return backoff.on_exception(
        wait_gen=backoff.expo,
        exception=Exception,
        logger=logger,
        max_tries=max_tries,
        max_time=max_time,
        jitter=backoff.random_jitter,
        on_backoff=backoff_hdlr,
        factor=2,
    )


@with_backoff()
@router.post("/files/content")
@backoff.on_exception(
wait_gen=backoff.expo,
exception=Exception,
logger=logger,
max_tries=3,
max_time=180,
jitter=backoff.random_jitter,
on_backoff=backoff_hdlr,
factor=2,
)
async def extract_files_content(
files: Annotated[list[UploadFile], File()],
response: Response,
Expand All @@ -88,21 +90,13 @@ async def extract_files_content(
]

try:
summaries = await chatfactory.chat_client.completion(messages=messages)

assert isinstance(summaries, str)
json_summaries = extract_json_from_response(summaries)
assert isinstance(json_summaries, list)

try:
summaries_output = ExtractorOutputList(extracts=json_summaries)
return summaries_output
except Exception:
formatted_output = await chatfactory.json_formatter_agent(
summaries,
"{extracts: [ 'summary': 'Summary', 'themes': [{'theme': 'Theme 1', 'reason': 'Reason for Theme 1'}, {'theme': 'Theme 2', 'reason': 'Reason for Theme 2'}, ...]}, { 'summary': 'Sumamry', 'themes': [{'theme': 'Theme 1', 'reason': 'Reason for Theme 1'}, {'theme': 'Theme 2', 'reason': 'Reason for Theme 2'}, ...]}] }",
)
return formatted_output
summaries_output = await chatfactory.run_llm_with_json_parsing(
messages,
ExtractorOutputList,
fallback_formatter="{extracts: [ 'summary': 'Summary', 'themes': [{'theme': 'Theme 1', 'reason': 'Reason for Theme 1'}, {'theme': 'Theme 2', 'reason': 'Reason for Theme 2'}, ...]}, { 'summary': 'Sumamry', 'themes': [{'theme': 'Theme 1', 'reason': 'Reason for Theme 1'}, {'theme': 'Theme 2', 'reason': 'Reason for Theme 2'}, ...]}] }",
)

return summaries_output

except Exception as e:
logger.error(f"Error in extractor schema: {e}")
Expand All @@ -111,6 +105,7 @@ async def extract_files_content(


@router.post("/search_extracts")
@with_backoff()
async def tutor_search_extract(
summaries: SummariesList,
background_tasks: BackgroundTasks,
Expand Down Expand Up @@ -157,121 +152,17 @@ async def tutor_search_extract(
return resp


@router.post("/search", deprecated=True)
async def tutor_search(
    files: Annotated[list[UploadFile], File()],
    response: Response,
    sp: SearchService = Depends(get_search_service),
    chatfactory=Depends(get_chat_service),
):
    """Deprecated: extract themes from uploaded files, then search for matching documents.

    Pipeline:
      1. Read the uploaded files and join their contents into one numbered prompt.
      2. Ask the LLM (via ``chatfactory``) for summaries/themes conforming to
         ``ExtractorOutputList``.
      3. Run a multi-input search with one query per extracted summary.

    Returns:
        A ``TutorSearchResponse``. On extraction failure, responds with an
        empty payload and status 204; when the search finds nothing, 404.
    """
    files_content = await get_files_content(files)

    # Number each document so the LLM can reference them individually.
    doc_template = "Document {doc_nb}: {content}"
    file_content_str = "\n\n".join(
        doc_template.format(doc_nb=index + 1, content=content)
        for index, content in enumerate(files_content)
    )

    messages = [
        {"role": "system", "content": extractor_system_prompt},
        {"role": "user", "content": file_content_str},
    ]

    try:
        themes_extracted = await chatfactory.chat_client.completion(
            messages=messages, response_format=ExtractorOutputList
        )

        # The client may hand back either raw text or an already-parsed dict.
        if isinstance(themes_extracted, str):
            jsn = extract_json_from_response(themes_extracted)
        elif isinstance(themes_extracted, dict):
            jsn = themes_extracted
        else:
            raise ValueError("Unexpected response format")

        # Was a bare print(); keep the payload visible for debugging
        # without writing to stdout in production.
        logger.debug(f"Extractor JSON payload: {jsn}")
        themes_extracted = ExtractorOutputList(**jsn)

    except Exception as e:
        logger.error(f"Error in chat schema: {e}")
        # NOTE(review): a 204 response must not carry a body (RFC 9110),
        # yet a TutorSearchResponse is returned here — confirm the
        # intended contract with API consumers.
        response.status_code = 204
        return TutorSearchResponse(extracts=[], nb_results=0, documents=[])

    if not themes_extracted or not themes_extracted.extracts:
        return TutorSearchResponse(extracts=[], nb_results=0, documents=[])

    inputs = [doc.summary for doc in themes_extracted.extracts]  # type: ignore

    try:
        qp = EnhancedSearchQuery(
            query=inputs,
            nb_results=10,
            sdg_filter=None,
            corpora=None,
        )
        search_results = await search_multi_inputs(
            qp=qp,
            callback_function=sp.search_handler,
        )
    except NoResultsError as e:
        response.status_code = 404
        logger.error(f"No results found: {e}")
        return TutorSearchResponse(
            extracts=themes_extracted.extracts,
            nb_results=0,
            documents=[],
        )

    if not search_results:
        return TutorSearchResponse(
            extracts=themes_extracted.extracts,
            nb_results=0,
            documents=[],
        )

    return TutorSearchResponse(
        extracts=themes_extracted.extracts,
        nb_results=len(search_results),
        documents=search_results,
    )


@backoff.on_exception(
wait_gen=backoff.expo,
exception=Exception,
logger=logger,
max_tries=3,
max_time=180,
jitter=backoff.random_jitter,
on_backoff=backoff_hdlr,
factor=2,
)
@with_backoff()
@router.post("/syllabus")
async def create_syllabus(
request: Request,
body: TutorSyllabusRequest,
lang: str = "en",
data_collection=Depends(get_data_collection_service),
settings: Settings = Depends(get_settings),
) -> SyllabusResponse:
session_id = request.headers.get("X-Session-ID")
results = await tutor_manager(body, lang)
results = await tutor_manager(body, lang, settings)

# TODO: handle errors

Expand Down Expand Up @@ -329,16 +220,7 @@ async def create_syllabus(
"""


@backoff.on_exception(
wait_gen=backoff.expo,
exception=Exception,
logger=logger,
max_tries=3,
max_time=180,
jitter=backoff.random_jitter,
on_backoff=backoff_hdlr,
factor=2,
)
@with_backoff()
@router.post("/syllabus/feedback")
async def handle_syllabus_feedback(
request: Request,
Expand All @@ -359,7 +241,7 @@ async def handle_syllabus_feedback(
syllabus=body.syllabus[0],
feedback=body.feedback,
documents=body.documents,
extracts=("/n").join([extract.summary for extract in body.extracts]),
extracts="\n".join([extract.summary for extract in body.extracts]),
themes=(", ").join(
[
(", ").join([theme["theme"] for theme in extract.themes])
Expand All @@ -373,7 +255,8 @@ async def handle_syllabus_feedback(
try:
syllabus = await chatfactory.chat_client.completion(messages=messages)

assert isinstance(syllabus, str)
if not isinstance(syllabus, str):
raise ValueError("Syllabus feedback response is not a string")

await data_collection.register_syllabus_data(
session_id=session_id,
Expand Down
21 changes: 21 additions & 0 deletions src/app/services/abst_chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -463,6 +463,27 @@ async def agent_message(

return res

async def run_llm_with_json_parsing(
    self,
    messages: list[dict],
    model_class: type,
    fallback_formatter: str | None = None,
):
    """Run a chat completion and parse the reply into ``model_class``.

    The raw LLM reply is passed through ``extract_json_from_response``; if
    that yields a dict, it is unpacked as keyword arguments to
    ``model_class``. On any extraction or instantiation failure, the raw
    reply is re-formatted via ``self.json_formatter_agent`` when
    ``fallback_formatter`` is provided; otherwise the exception propagates.

    Args:
        messages: Chat messages forwarded to ``self.chat_client.completion``.
        model_class: Class instantiated with the parsed JSON as kwargs
            (presumably a Pydantic model — verify against callers).
        fallback_formatter: Optional schema/example string handed to
            ``self.json_formatter_agent`` as a last-resort formatter.

    Raises:
        ValueError: If the completion is not a string, or (when no fallback
            is given) if the extracted JSON is not a dictionary.
    """
    raw = await self.chat_client.completion(messages=messages)

    if not isinstance(raw, str):
        raise ValueError("LLM response must be string")

    try:
        json_data = extract_json_from_response(raw)
        # NOTE(review): a top-level JSON *list* reply lands here and falls
        # through to the fallback — pre-refactor call sites wrapped a bare
        # list as ``model_class(extracts=...)``; confirm this is intended.
        if not isinstance(json_data, dict):
            raise ValueError("Extracted JSON data is not a dictionary")
        return model_class(**json_data)
    except Exception:
        # Broad catch is deliberate: any extraction or model-validation
        # error should trigger the formatter fallback when one is given.
        if fallback_formatter:
            return await self.json_formatter_agent(raw, fallback_formatter)
        raise


async def get_llm_client(request: Request):
return request.app.state.llm
Expand Down
11 changes: 4 additions & 7 deletions src/app/services/tutor/tutor.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from langchain_azure_ai.chat_models import AzureAIChatCompletionsModel # type: ignore

from src.app.api.dependencies import get_settings
from src.app.core.config import Settings
from src.app.services.tutor.agents import (
PedagogicalEngineerAgent,
SDGExpertAgent,
Expand All @@ -13,9 +13,6 @@
)
from src.app.services.tutor.utils import extract_doc_info

settings = get_settings()


GREENCOMP_COMPETENCIES = (
"Here are the GreenComp competencies: "
"url: https://joint-research-centre.ec.europa.eu/greencomp-european-sustainability-competence-framework_en "
Expand Down Expand Up @@ -46,7 +43,7 @@
)


def _build_chat_model() -> AzureAIChatCompletionsModel:
def _build_chat_model(settings: Settings) -> AzureAIChatCompletionsModel:
return AzureAIChatCompletionsModel(
endpoint=settings.AZURE_APIM_API_BASE,
credential=settings.AZURE_APIM_API_KEY,
Expand All @@ -56,7 +53,7 @@ def _build_chat_model() -> AzureAIChatCompletionsModel:


async def tutor_manager(
content: TutorSyllabusRequest, lang: str
content: TutorSyllabusRequest, lang: str, settings: Settings
) -> list[SyllabusResponseAgent]:
formatted_content = MessageWithResources(
lang=lang,
Expand All @@ -70,7 +67,7 @@ async def tutor_manager(
description=content.description,
)

chat_model = _build_chat_model()
chat_model = _build_chat_model(settings=settings)
teacher_agent = UniversityTeacherAgent(chat_model, lang)
sdg_agent = SDGExpertAgent(chat_model, GREENCOMP_COMPETENCIES, lang)
pedagogical_agent = PedagogicalEngineerAgent(
Expand Down
45 changes: 0 additions & 45 deletions src/app/tests/api/api_v1/test_tutor.py

This file was deleted.