-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
93 lines (77 loc) · 3.12 KB
/
main.py
File metadata and controls
93 lines (77 loc) · 3.12 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
from utils.logging import setup_logger
from utils.analyze_query import analyze_query
from schemas import ResearchRequest, ResearchResponse
from tools.web_search_tool import get_google_search_tool
from tools.content_analyzer_tool import run_content_analyzer_tool
from tools.web_scraper_tool import run_web_scraper_tool
from tools.result_aggregator_tool import run_result_aggregator_tool
from fastapi import FastAPI, Request, Body
from fastapi.responses import JSONResponse
from starlette.middleware.cors import CORSMiddleware
import uvicorn
from utils.get_embeddings import get_embeddings
from utils.get_relevant_urls import get_relevant_urls
import os
from dotenv import load_dotenv
load_dotenv()
# Application-wide logger for the research service.
logger = setup_logger("web-research-agent")

# FastAPI application; CORS is wide open so any front-end origin can call it.
app = FastAPI(title="Web Research Agent")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # NOTE(review): permissive CORS — confirm acceptable for production
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.exception_handler(Exception)
async def global_exception_handler(request: Request, exc: Exception):
    """Catch-all handler: log the failure with its traceback, reply with a generic 500.

    Keeps internal error details out of the response body; the full exception
    (including stack trace, via exc_info=True) goes to the service log instead.
    """
    logger.error(f"Unhandled error occurred during request: {request.url} - {exc}", exc_info=True)
    body = {"detail": "An unexpected error occurred. Please try again later."}
    return JSONResponse(status_code=500, content=body)
@app.get("/")
async def health_check():
    """Liveness probe for the service root; always reports the app as up."""
    return dict(status="ok")
@app.post("/execute-research", response_model=ResearchResponse)
async def execute_research(payload: ResearchRequest = Body(...)):
    """Run the full research pipeline for a single user query.

    Pipeline: analyze the query into subqueries -> Google-search each subquery ->
    rank candidate URLs by embedding relevance -> scrape the top pages -> extract
    relevant chunks -> aggregate chunks into a final answer with sources.

    Returns a dict shaped as {"query": ..., "result": {"content": ..., "sources": [...]}}
    (validated against ResearchResponse by FastAPI).
    """
    query = payload.query
    # query analysis
    query_analysis = analyze_query(query)
    # Short-circuit: the analyzer rejected the query — surface its reason as the answer.
    if query_analysis.get("intent") == "invalid":
        return {
            "query": query,
            "result": {
                "content": query_analysis.get("reason", "Invalid query"),
                "sources": []
            }
        }
    logger.info(f"Query analysis: {query_analysis}")
    # Fall back to the raw query when the analyzer produced no subqueries.
    subqueries = query_analysis.get("subqueries", [query])
    # web search
    google_search_tool = get_google_search_tool()
    search_results = []
    for sq in subqueries:
        # .func suggests a LangChain-style Tool wrapper — TODO confirm tool interface
        results = google_search_tool.func(sq, num_results=10)
        for r in results:
            # Prefer the result snippet; fall back to its title, then empty string.
            snippet = r.get("snippet", r.get("title", ""))
            # assumes every result dict has a "link" key (would raise KeyError otherwise) — verify tool contract
            search_results.append({"url": r["link"], "snippet": snippet, "subquery": sq})
    # logger.info(f"Google search results: {search_results}")
    logger.info(f"Collected {len(search_results)} search results")
    embeddings = get_embeddings()
    # get relevant URLs
    M = 10  # number of top-ranked URLs to keep for scraping
    selected_urls = get_relevant_urls(query, embeddings, search_results, M)
    logger.info(f"Selected top {M} URLs: {selected_urls}")
    # scraping
    docs = await run_web_scraper_tool(selected_urls)
    logger.info(f"Scraped {len(docs)} pages")
    # relevant content from content analysis
    relevant_chunks = run_content_analyzer_tool(docs, embeddings, query_analysis, query)
    logger.info(f'relevant chunks {len(relevant_chunks)}')
    answer, sources = run_result_aggregator_tool(relevant_chunks, query, query_analysis)
    return {
        "query": query,
        "result": {"content": answer, "sources": sources }
    }
if __name__ == "__main__":
    # BUG FIX: os.getenv returns a *string* whenever the PORT env var is set,
    # but uvicorn.run requires an integer port — coerce explicitly. The default
    # 8000 is used when PORT is absent.
    uvicorn.run(app, host="0.0.0.0", port=int(os.getenv('PORT', 8000)))