import logging
import os
import re
from typing import List

from db.db import DATABASE_URL, get_db
from db.models import InvestorTable
from langchain import hub
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent
from schemas.py_schemas import InvestorData, InvestorList
from sqlalchemy.orm import selectinload

logger = logging.getLogger(__name__)

# Shared, import-time setup used by every QueryProcessor.
# NOTE: hub.pull fetches the prompt from the LangChain Hub, so importing this
# module performs a network request.
prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")
db = SQLDatabase.from_uri(DATABASE_URL)

# Compiled once rather than on every response.
# Any standalone integer in the reply.
_ANY_NUMBER_RE = re.compile(r"\b\d+\b")
# Explicitly labelled IDs such as "id: 5" or "id 12" (matched on lowercased text).
_EXPLICIT_ID_RE = re.compile(r"\bid[:\s]*(\d+)")

# Appended to the SQL-agent system prompt so the final answer is a bare,
# easily parsed list of investor IDs.
_ID_ONLY_INSTRUCTIONS = (
    "\n\nIMPORTANT: You must ONLY return the investor IDs (id field) that match the user's criteria. "
    "Do NOT return any other information, explanations, or data. "
    "Your response should be ONLY a comma-separated list of numbers representing the investor IDs.\n"
    "Example format: 1, 5, 12, 23"
)


class QueryProcessor:
    """Answer natural-language investor queries with an LLM-driven SQL agent.

    A LangGraph ReAct agent with the SQLDatabaseToolkit tools does the
    querying and is instructed to reply with only the matching investor IDs.
    Those IDs are parsed from its reply, and the full investor records, with
    their portfolio companies, team members and sectors, are then loaded
    through SQLAlchemy.
    """

    def __init__(self):
        # ChatOpenAI is pointed at OpenRouter's endpoint with an OpenRouter key.
        self.llm = ChatOpenAI(
            api_key=os.getenv("OPENROUTER_API_KEY"),
            base_url="https://openrouter.ai/api/v1",
            model="openai/gpt-4o-mini",
            temperature=0,
        )
        self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)

        # Use the dialect of the database actually connected to, instead of a
        # hard-coded name, then restrict the answer format to investor IDs.
        system_message = (
            prompt_template.format(dialect=db.dialect, top_k=5)
            + _ID_ONLY_INSTRUCTIONS
        )
        self.agent = create_react_agent(
            model=self.llm,
            tools=self.toolkit.get_tools(),
            prompt=system_message,
        )

    def process_query(self, question: str) -> InvestorList:
        """Process a query using the LLM and return investor data.

        Args:
            question: Natural-language question about investors.

        Returns:
            An InvestorList with one InvestorData per matching investor,
            ordered as the agent listed the IDs. The list is empty when the
            reply contains no IDs.
        """
        # The agent performs all database lookups and filtering; we only need
        # its final message, which should contain the investor IDs.
        response = self.agent.invoke({"messages": [("user", question)]})
        messages = response.get("messages") or []
        ai_response = self._message_text(messages[-1].content) if messages else ""

        investor_ids = self._extract_investor_ids_from_response(ai_response)
        return self._fetch_investors_by_ids(investor_ids)

    @staticmethod
    def _message_text(content) -> str:
        """Flatten LangChain message content to plain text.

        Message content is either a string or a list of parts. A part is a
        plain string or a dict such as ``{"type": "text", "text": "..."}``.
        Non-text parts are ignored.
        """
        if isinstance(content, str):
            return content
        texts = []
        for part in content or []:
            if isinstance(part, str):
                texts.append(part)
            elif isinstance(part, dict) and isinstance(part.get("text"), str):
                texts.append(part["text"])
        return " ".join(texts)

    def _extract_investor_ids_from_response(self, ai_response: str) -> List[int]:
        """Extract investor IDs from the agent's reply.

        If the reply labels IDs explicitly ("id: 5", "ID 12"), only the
        labelled numbers are used. Otherwise every standalone integer is
        treated as an ID, which fits the "1, 5, 12" format the prompt asks
        for.

        Returns:
            Unique IDs in the order they first appear. The list is empty for
            an empty reply or if parsing fails.
        """
        if not ai_response:
            return []
        try:
            matches = _EXPLICIT_ID_RE.findall(ai_response.lower()) or _ANY_NUMBER_RE.findall(
                ai_response
            )
            # dict.fromkeys removes duplicates while keeping first-seen order.
            return list(dict.fromkeys(int(match) for match in matches))
        except ValueError:
            # int() rejects pathologically long digit runs (str-to-int limit).
            logger.exception("Error extracting IDs from response")
            return []

    def _fetch_investors_by_ids(self, investor_ids: List[int]) -> InvestorList:
        """Fetch investors with all their relationships from the database.

        Portfolio companies, team members and sectors are eager-loaded with
        ``selectinload``, so they are available before the session closes.
        IDs with no matching investor are skipped. Duplicate IDs yield one
        entry each.

        Returns:
            An InvestorList ordered by the position of each investor's ID in
            ``investor_ids``. ``IN (...)`` alone gives no ordering guarantee.
        """
        unique_ids = list(dict.fromkeys(investor_ids or []))
        if not unique_ids:
            return InvestorList(investors=[])

        # Keep a handle on the generator returned by get_db() so it can be
        # finalized explicitly below. Presumably this is a generator-based
        # session dependency; confirm against db.db.
        session_source = get_db()
        db_session = next(session_source)
        try:
            investors = (
                db_session.query(InvestorTable)
                .options(
                    selectinload(InvestorTable.portfolio_companies),
                    selectinload(InvestorTable.team_members),
                    selectinload(InvestorTable.sectors),
                )
                .filter(InvestorTable.id.in_(unique_ids))
                .all()
            )

            # Restore the agent's ordering. The sort is stable, and any row
            # whose id does not compare equal to a requested ID (e.g. a type
            # mismatch) is kept and placed last rather than dropped.
            position = {investor_id: index for index, investor_id in enumerate(unique_ids)}
            investors.sort(key=lambda inv: position.get(inv.id, len(position)))

            return InvestorList(
                investors=[
                    InvestorData(
                        investor=investor,
                        portfolio_companies=investor.portfolio_companies,
                        team_members=investor.team_members,
                        sectors=investor.sectors,
                    )
                    for investor in investors
                ]
            )
        finally:
            db_session.close()
            # Also finalize the generator so any cleanup it runs after `yield`
            # happens now, not whenever it is garbage-collected.
            close_source = getattr(session_source, "close", None)
            if close_source is not None:
                close_source()