import json
import logging
import os
import re
from typing import List, Optional

from db.db import DATABASE_URL, get_db
from db.models import FundTable, InvestorTable, ProjectTable
from langchain import hub
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent
from schemas.router_schemas import (
    CompanyMinimal,
    InvestmentResponse,
    PaginatedResponse,
    SectorMinimal,
)
from sqlalchemy.orm import selectinload

from services.compatibility_score import calculate_project_investor_compatibility

logger = logging.getLogger(__name__)

# Page size reported when a response carries no items.
_DEFAULT_PAGE_SIZE = 10

# Sentinel the system prompt (see QueryProcessor.__init__) tells the agent to
# answer with when no fund matches.
_NO_RESULTS_SENTINEL = "NO_RESULTS"

# Any standalone integer in the agent's answer.
_NUMBER_PATTERN = re.compile(r"\b\d+\b")
# Integers explicitly labelled as an id, e.g. "id: 12" (matched on lowercased text).
_EXPLICIT_ID_PATTERN = re.compile(r"\bid[:\s]*(\d+)")

# Base system prompt for the SQL agent, pulled from the LangChain hub at import time.
prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")
# Database handle (built from DATABASE_URL) shared by every QueryProcessor.
db = SQLDatabase.from_uri(DATABASE_URL)


class QueryProcessor:
    """Answer natural-language fund searches with an LLM-driven SQL agent.

    The agent is instructed to return only fund IDs; the matching rows are then
    loaded through the ORM and converted to ``InvestmentResponse`` items.
    """

    def __init__(self):
        """Build the chat model, the SQL toolkit and the ReAct agent."""
        self.llm = ChatOpenAI(
            api_key=os.getenv("OPENROUTER_API_KEY"),
            base_url="https://openrouter.ai/api/v1",
            model="openai/gpt-4o-mini",
            temperature=0,
        )

        self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)

        # Extend the hub prompt so the agent answers with fund IDs only.
        system_message_updated = (
            prompt_template.format(dialect="SQLite", top_k=5)
            + "\n\n=== IMPORTANT TERMINOLOGY ==="
            + "\n- When users say 'investors' or 'find me investors', they mean FUNDS"
            + "\n- Always query the 'funds' table for investment opportunities"
            + "\n- The 'investors' table is for parent company information only"
            + "\n- Relationship: investors (1) -> (many) funds"
            + "\n\n=== YOUR TASK ==="
            + "\nReturn ONLY fund IDs (funds.id) that match the user's criteria."
            + "\nFormat: comma-separated numbers only (e.g., 1, 5, 12, 23)"
            + "\nNo explanations, no other data."
            + "\n\n=== QUERY GUIDELINES ==="
            + "\n1. For geographic searches: use funds.geographic_focus"
            + "\n2. For sector searches: JOIN with fund_sectors table"
            + "\n3. For stage searches: JOIN with fund_investment_stages table"
            + "\n4. If no results: respond with 'NO_RESULTS'"
            + "\n5. Never repeat the same failed query"
        )

        self.agent = create_react_agent(
            model=self.llm,
            tools=self.toolkit.get_tools(),
            prompt=system_message_updated,
        )

    def process_query(
        self, question: str, project_id: Optional[int] = None
    ) -> PaginatedResponse[InvestmentResponse]:
        """Process a query using the LLM and return investment response data.

        Args:
            question: The natural language query to process
            project_id: Optional project ID for compatibility scoring

        Returns:
            A single-page ``PaginatedResponse`` with one item per matching fund.
        """
        # Let the LLM handle all database interactions and filtering to get fund IDs
        response = self.agent.invoke(
            {"messages": [("user", question)]},
            config={"recursion_limit": 50},
        )
        logger.info("%s", response)

        # The last message of the run is treated as the agent's final answer.
        final_message_content = response["messages"][-1].content
        logger.info("AI Response: \n%s", final_message_content)

        fund_ids = self._extract_fund_ids_from_response(final_message_content)

        # Fetch full fund data with investor relationships using the IDs
        return self._fetch_funds_by_ids(fund_ids, project_id)

    def _extract_fund_ids_from_response(self, ai_response: str) -> List[int]:
        """Extract fund IDs from the agent's final answer.

        Explicitly labelled IDs (``id: 12``) take precedence; otherwise every
        standalone integer in the text is treated as an ID.

        Args:
            ai_response: Text content of the agent's last message.

        Returns:
            Fund IDs in order of first appearance, without duplicates. Empty
            when the answer is not text, contains the ``NO_RESULTS`` sentinel,
            or holds no usable number.
        """
        if not isinstance(ai_response, str):
            # Non-text content cannot be scanned with the regexes below.
            logger.warning(
                "Expected a text AI response, got %s; no fund IDs extracted",
                type(ai_response).__name__,
            )
            return []

        # Honour the sentinel the prompt asks for, so stray digits in a
        # "no results" answer are not mistaken for fund IDs.
        if _NO_RESULTS_SENTINEL in ai_response.upper():
            return []

        labelled = _EXPLICIT_ID_PATTERN.findall(ai_response.lower())
        raw_ids = labelled or _NUMBER_PATTERN.findall(ai_response)

        try:
            # dict.fromkeys drops duplicates while keeping first-seen order.
            return list(dict.fromkeys(int(raw) for raw in raw_ids))
        except ValueError:
            logger.exception("Error extracting IDs from response")
            return []

    def _build_investment_response(
        self, fund: FundTable, project: Optional[ProjectTable]
    ) -> InvestmentResponse:
        """Convert one fund row (with a loaded investor) to an InvestmentResponse."""
        investor = fund.investor

        # Score defaults to 1.0 when there is no project to compare against.
        compatibility_score = 1.0
        if project is not None:
            compatibility_score = calculate_project_investor_compatibility(
                project=project, investor=investor, use_funds=True
            )

        # First 3 portfolio companies of the investor (id and name only)
        portfolio_companies = [
            CompanyMinimal(id=company.id, name=company.name)
            for company in investor.portfolio_companies[:3]
        ]

        # Stage focus as a comma-separated string, or None without stages
        stage_focus = (
            ", ".join(stage.name for stage in fund.investment_stages)
            if fund.investment_stages
            else None
        )

        # First 3 sectors of the fund (id and name only)
        fund_sectors = [
            SectorMinimal(id=sector.id, name=sector.name)
            for sector in (fund.sectors[:3] if fund.sectors else [])
        ]

        # NOTE(review): the response id is the investor's id, not the fund's,
        # so two funds of the same investor share an id - confirm intended.
        return InvestmentResponse(
            id=investor.id,
            name=f"{investor.name} - {fund.fund_name}"
            if fund.fund_name
            else investor.name,
            aum=investor.aum,
            check_size_lower=fund.check_size_lower,
            check_size_upper=fund.check_size_upper,
            geographic_focus=fund.geographic_focus,
            stage_focus=stage_focus,
            portfolio_companies=portfolio_companies,
            sectors=fund_sectors,
            compatibility_score=compatibility_score,
        )

    def _fetch_funds_by_ids(
        self, fund_ids: List[int], project_id: Optional[int] = None
    ) -> PaginatedResponse[InvestmentResponse]:
        """Fetch funds with all their relationships from the database using fund IDs.

        Constructs response similar to read_investors but starting from funds.

        Args:
            fund_ids: List of fund IDs to fetch
            project_id: Optional project ID for compatibility scoring

        Returns:
            A single page holding one ``InvestmentResponse`` per fund found.
            When nothing is found the page is empty, with ``total_pages`` of 0
            and the default page size.
        """
        if not fund_ids:
            return PaginatedResponse(
                items=[],
                total=0,
                page=1,
                page_size=_DEFAULT_PAGE_SIZE,
                total_pages=0,
            )

        # Keep a reference to the provider so it can be finished explicitly.
        # NOTE(review): get_db is defined elsewhere; presumably a generator
        # that yields a session and cleans up afterwards - confirm.
        session_source = get_db()
        db_session = next(session_source)
        try:
            # Load project if project_id provided
            project = None
            if project_id is not None:
                project = (
                    db_session.query(ProjectTable)
                    .options(selectinload(ProjectTable.sector))
                    .filter(ProjectTable.id == project_id)
                    .first()
                )

            # Query funds with all necessary relationships loaded
            funds = (
                db_session.query(FundTable)
                .options(
                    selectinload(FundTable.investor).selectinload(
                        InvestorTable.portfolio_companies
                    ),
                    selectinload(FundTable.investor).selectinload(
                        InvestorTable.team_members
                    ),
                    selectinload(FundTable.investor).selectinload(
                        InvestorTable.sectors
                    ),
                    selectinload(FundTable.investment_stages),
                    selectinload(FundTable.sectors),
                )
                .filter(FundTable.id.in_(fund_ids))
                .all()
            )

            # Transform to InvestmentResponse format (one row per fund)
            investment_responses = []
            for fund in funds:
                if fund.investor is None:
                    # Every response field group needs the investor; skip the
                    # row instead of failing the whole request.
                    logger.warning("Fund %s has no investor; skipping", fund.id)
                    continue
                investment_responses.append(
                    self._build_investment_response(fund, project)
                )

            total_count = len(investment_responses)
            return PaginatedResponse(
                items=investment_responses,
                total=total_count,
                page=1,
                # Same page size as the "no IDs" case when nothing matched.
                page_size=total_count or _DEFAULT_PAGE_SIZE,
                total_pages=1 if total_count > 0 else 0,
            )
        finally:
            db_session.close()
            # Finish the provider too, so any cleanup it defines runs now
            # rather than whenever the iterator is garbage-collected.
            close_source = getattr(session_source, "close", None)
            if callable(close_source):
                close_source()