119 lines
4.4 KiB
Python
119 lines
4.4 KiB
Python
import logging
import os
import re
from typing import List

from langchain import hub
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent
from sqlalchemy.orm import selectinload

from db.db import DATABASE_URL, get_db
from db.models import InvestorTable
from schemas.py_schemas import InvestorData, InvestorList
|
|
|
|
# Module-level setup: both statements below run once, when this module is imported.

# Pull the SQL-agent system prompt template from the LangChain Hub.
prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")

# Open the database located by DATABASE_URL for the agent's SQL tools.
db = SQLDatabase.from_uri(DATABASE_URL)
|
|
|
|
|
|
class QueryProcessor:
    """Answer natural-language investor questions through an LLM SQL agent.

    The agent is instructed to reply with investor IDs only. Those IDs are
    parsed out of the reply and used to load the matching ``InvestorTable``
    rows, together with their relationships, through the ORM session.
    """

    _log = logging.getLogger(__name__)

    # Any standalone integer in the reply (the format the system prompt asks for).
    _BARE_NUMBER_RE = re.compile(r"\b\d+\b")
    # Explicit "id: 12" / "id 12" references; applied to lower-cased text.
    _EXPLICIT_ID_RE = re.compile(r"\bid[:\s]*(\d+)")

    def __init__(self):
        """Build the chat model, the SQL toolkit and the ReAct agent."""
        # ChatOpenAI client pointed at the OpenRouter endpoint; the key is
        # read from the OPENROUTER_API_KEY environment variable.
        self.llm = ChatOpenAI(
            api_key=os.getenv("OPENROUTER_API_KEY"),
            base_url="https://openrouter.ai/api/v1",
            model="openai/gpt-4o-mini",
            temperature=0,
        )
        self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)

        # Extend the pulled system prompt so the agent answers with IDs only.
        # NOTE(review): the dialect is hard-coded to "SQLite" while the
        # database itself comes from DATABASE_URL - confirm they always agree.
        system_message_updated = (
            prompt_template.format(dialect="SQLite", top_k=5)
            + "\n\nIMPORTANT: You must ONLY return the investor IDs (id field) that match the user's criteria. "
            + "Do NOT return any other information, explanations, or data. "
            + "Your response should be ONLY a comma-separated list of numbers representing the investor IDs. "
            + "Example format: 1, 5, 12, 23"
        )
        self.agent = create_react_agent(
            model=self.llm,
            tools=self.toolkit.get_tools(),
            prompt=system_message_updated,
        )

    def process_query(self, question: str) -> "InvestorList":
        """Run ``question`` through the agent and return the matching investors.

        Args:
            question: Natural-language question from the user.

        Returns:
            An ``InvestorList`` built from the IDs found in the agent's final
            message; it is empty when no IDs could be extracted.
        """
        response = self.agent.invoke(
            {"messages": [("user", question)]},
        )

        # The agent's answer is the last message of the returned conversation.
        messages = response.get("messages")
        ai_response = messages[-1].content if messages else ""

        investor_ids = self._extract_investor_ids_from_response(ai_response)
        return self._fetch_investors_by_ids(investor_ids)

    @staticmethod
    def _message_text(content) -> str:
        """Flatten message content into plain text.

        Accepts a string, ``None``, or a list of content blocks (plain strings
        or dicts carrying a ``"text"`` entry). Anything else is converted with
        ``str()``.
        """
        if content is None:
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, (list, tuple)):
            parts = []
            for block in content:
                if isinstance(block, str):
                    parts.append(block)
                elif isinstance(block, dict) and isinstance(block.get("text"), str):
                    parts.append(block["text"])
            return " ".join(parts)
        return str(content)

    def _extract_investor_ids_from_response(self, ai_response: str) -> List[int]:
        """Extract investor IDs from the agent's reply.

        Explicit ``id: N`` references take precedence; when there are none,
        every standalone integer in the reply is treated as an ID. Duplicates
        are removed while the order of first appearance is kept.

        Args:
            ai_response: Content of the agent's final message. Non-string
                content (``None`` or a list of content blocks) is tolerated.

        Returns:
            The unique IDs in reply order, or ``[]`` when none are found.
        """
        text = self._message_text(ai_response)
        if not text:
            return []

        tokens = self._EXPLICIT_ID_RE.findall(text.lower())
        if not tokens:
            tokens = self._BARE_NUMBER_RE.findall(text)
        if not tokens:
            self._log.warning("No investor IDs found in agent reply: %.200r", text)
            return []

        # dict.fromkeys de-duplicates while preserving first-seen order.
        return list(dict.fromkeys(int(token) for token in tokens))

    def _fetch_investors_by_ids(self, investor_ids: List[int]) -> "InvestorList":
        """Fetch investors and their relationships from the database by ID.

        Args:
            investor_ids: IDs to load; duplicates are ignored.

        Returns:
            An ``InvestorList`` with one ``InvestorData`` per row found,
            ordered like ``investor_ids``. IDs with no matching row are
            skipped; an empty input yields an empty list without touching
            the database.
        """
        if not investor_ids:
            return InvestorList(investors=[])

        wanted_ids = list(dict.fromkeys(investor_ids))

        # Keep a handle on the iterator returned by get_db() so it can be
        # closed explicitly once the query is done.
        session_source = get_db()
        db_session = next(session_source)

        try:
            # selectinload pulls each relationship eagerly, so the rows stay
            # usable after the session has been closed.
            investors = (
                db_session.query(InvestorTable)
                .options(
                    selectinload(InvestorTable.portfolio_companies),
                    selectinload(InvestorTable.team_members),
                    selectinload(InvestorTable.sectors),
                )
                .filter(InvestorTable.id.in_(wanted_ids))
                .all()
            )

            # IN (...) gives no ordering guarantee: restore the order in which
            # the agent listed the IDs. Rows whose id is not in the ranking
            # (unexpected) are kept and sorted last rather than dropped.
            rank = {investor_id: pos for pos, investor_id in enumerate(wanted_ids)}
            investors.sort(key=lambda row: rank.get(row.id, len(rank)))

            investor_data_list = [
                InvestorData(
                    investor=investor,
                    portfolio_companies=investor.portfolio_companies,
                    team_members=investor.team_members,
                    sectors=investor.sectors,
                )
                for investor in investors
            ]
            return InvestorList(investors=investor_data_list)
        finally:
            db_session.close()
            # get_db() is consumed with next(), so it is presumably a
            # generator-style dependency; closing it lets any cleanup after
            # its yield run now instead of at garbage collection.
            close_source = getattr(session_source, "close", None)
            if callable(close_source):
                close_source()
|