import logging
import os
import re
from typing import List, Optional

from db.db import DATABASE_URL, get_db
from db.models import FundTable, InvestorTable, ProjectTable
from langchain import hub
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent
from schemas.router_schemas import (
    CompanyMinimal,
    InvestmentResponse,
    PaginatedResponse,
    SectorMinimal,
)
from sqlalchemy.orm import selectinload

from services.compatibility_score import calculate_project_investor_compatibility

# Pull the LangChain SQL-agent system prompt and connect to the application
# database. Both happen once, at import time.
prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")
db = SQLDatabase.from_uri(DATABASE_URL)

logger = logging.getLogger(__name__)

# Pattern 1: every standalone integer in the agent's answer (assumed to be an ID).
_NUMBER_RE = re.compile(r"\b\d+\b")
# Pattern 2: explicit "id: N" / "id N" references; matched against lowercased text.
_EXPLICIT_ID_RE = re.compile(r"\bid[:\s]*(\d+)")

# Page size reported when no fund matched (kept from the original behavior).
_EMPTY_PAGE_SIZE = 10


class QueryProcessor:
    """Answer natural-language fund queries with an LLM-driven SQL agent.

    The agent explores the database through LangChain's SQL toolkit and is
    instructed to answer with only a comma-separated list of fund IDs. Those
    IDs are then loaded with SQLAlchemy and shaped into ``InvestmentResponse``
    rows.
    """

    def __init__(self):
        # ChatOpenAI is pointed at OpenRouter via base_url and the OpenRouter key.
        self.llm = ChatOpenAI(
            api_key=os.getenv("OPENROUTER_API_KEY"),
            base_url="https://openrouter.ai/api/v1",
            model="x-ai/grok-4-fast",
            temperature=0,
        )
        self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)

        # Extend the stock system prompt so the agent's final answer is only a
        # list of fund IDs, which _extract_fund_ids_from_response parses.
        system_message_updated = (
            prompt_template.format(dialect="SQLite", top_k=5)
            + "\n\nIMPORTANT: You must ONLY return the fund IDs (id field from the funds table) that match the user's criteria. "
            + "Do NOT return any other information, explanations, or data. "
            + "Your response should be ONLY a comma-separated list of numbers representing the fund IDs. "
            + "Example format: 1, 5, 12, 23"
        )

        self.agent = create_react_agent(
            model=self.llm,
            tools=self.toolkit.get_tools(),
            prompt=system_message_updated,
        )

    def process_query(
        self, question: str, project_id: Optional[int] = None
    ) -> PaginatedResponse[InvestmentResponse]:
        """Process a query using the LLM and return investment response data.

        Args:
            question: The natural language query to process.
            project_id: Optional project ID for compatibility scoring.

        Returns:
            A single page containing every fund the agent selected, in the
            order the agent listed their IDs.
        """
        # Let the LLM handle all database interactions and filtering.
        response = self.agent.invoke({"messages": [("user", question)]})

        messages = response.get("messages")
        # The last message is the agent's final answer. Its content may be a
        # plain string or a list of content parts, so flatten it to text.
        ai_response = self._message_text(messages[-1].content) if messages else ""

        fund_ids = self._extract_fund_ids_from_response(ai_response)
        return self._fetch_funds_by_ids(fund_ids, project_id)

    @staticmethod
    def _message_text(content) -> str:
        """Flatten a LangChain message ``content`` value into plain text.

        Strings pass through unchanged. For a list of parts, string parts and
        the ``"text"`` value of dict parts are joined with spaces, so digits
        from adjacent parts never fuse. Other part types are ignored.
        """
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            pieces = []
            for part in content:
                if isinstance(part, str):
                    pieces.append(part)
                elif isinstance(part, dict) and isinstance(part.get("text"), str):
                    pieces.append(part["text"])
            return " ".join(pieces)
        return "" if content is None else str(content)

    def _extract_fund_ids_from_response(self, ai_response: str) -> List[int]:
        """Extract fund IDs from the agent's final answer.

        Explicit ``id: N`` / ``id N`` references (case-insensitive) win when
        present. Otherwise every standalone integer in the text is treated as
        an ID. Duplicates are dropped and first-seen order is preserved.
        A token that cannot be converted to ``int`` is logged and skipped.
        """
        text = (
            ai_response
            if isinstance(ai_response, str)
            else self._message_text(ai_response)
        )
        matches = _EXPLICIT_ID_RE.findall(text.lower()) or _NUMBER_RE.findall(text)

        fund_ids: List[int] = []
        seen = set()
        for token in matches:
            try:
                fund_id = int(token)
            except ValueError:
                # e.g. a digit run longer than int()'s max-str-digits limit
                logger.warning("Skipping unparsable fund ID token %.30r", token)
                continue
            if fund_id not in seen:
                seen.add(fund_id)
                fund_ids.append(fund_id)
        return fund_ids

    def _fetch_funds_by_ids(
        self, fund_ids: List[int], project_id: Optional[int] = None
    ) -> PaginatedResponse[InvestmentResponse]:
        """Fetch funds with their relationships by ID and build the response.

        Constructs a response similar to read_investors, but starting from
        funds: one ``InvestmentResponse`` row per fund. Rows follow the order
        of ``fund_ids``. IDs with no matching fund are silently absent.

        Args:
            fund_ids: List of fund IDs to fetch.
            project_id: Optional project ID for compatibility scoring.
        """
        if not fund_ids:
            return PaginatedResponse(
                items=[],
                total=0,
                page=1,
                page_size=_EMPTY_PAGE_SIZE,
                total_pages=0,
            )

        # NOTE(review): get_db() is consumed with next(), as in the original,
        # and the session is closed explicitly below. Confirm that skipping
        # the dependency's own cleanup code is intended.
        db_session = next(get_db())
        try:
            project = None
            if project_id is not None:
                project = (
                    db_session.query(ProjectTable)
                    .options(selectinload(ProjectTable.sector))
                    .filter(ProjectTable.id == project_id)
                    .first()
                )

            # Eager-load every relationship the response needs in a few queries.
            funds = (
                db_session.query(FundTable)
                .options(
                    selectinload(FundTable.investor).selectinload(
                        InvestorTable.portfolio_companies
                    ),
                    selectinload(FundTable.investor).selectinload(
                        InvestorTable.team_members
                    ),
                    selectinload(FundTable.investor).selectinload(
                        InvestorTable.sectors
                    ),
                    selectinload(FundTable.investment_stages),
                    selectinload(FundTable.sectors),
                )
                .filter(FundTable.id.in_(fund_ids))
                .all()
            )

            # SQL "IN" imposes no ordering. Restore the order in which the IDs
            # were given, using the first occurrence of each ID.
            position = {}
            for index, fund_id in enumerate(fund_ids):
                position.setdefault(fund_id, index)
            funds.sort(key=lambda fund: position.get(fund.id, len(fund_ids)))

            # Build responses while the session is open (relationships are read).
            investment_responses = [
                self._to_investment_response(fund, project) for fund in funds
            ]
        finally:
            db_session.close()

        # Everything is returned as a single page.
        total_count = len(investment_responses)
        return PaginatedResponse(
            items=investment_responses,
            total=total_count,
            page=1,
            page_size=total_count,
            total_pages=1 if total_count > 0 else 0,
        )

    @staticmethod
    def _to_investment_response(fund, project) -> InvestmentResponse:
        """Build the ``InvestmentResponse`` row for one fund and its investor.

        ``compatibility_score`` is 1.0 unless ``project`` is given, in which
        case it comes from calculate_project_investor_compatibility. Note that
        ``id`` is the investor's ID, not the fund's, so two funds from the same
        investor share an id. NOTE(review): confirm this is intended.
        """
        investor = fund.investor

        compatibility_score = 1.0
        if project is not None:
            compatibility_score = calculate_project_investor_compatibility(
                project=project, investor=investor, use_funds=True
            )

        # First 3 portfolio companies and first 3 fund sectors (id and name only).
        portfolio_companies = [
            CompanyMinimal(id=company.id, name=company.name)
            for company in investor.portfolio_companies[:3]
        ]
        fund_sectors = [
            SectorMinimal(id=sector.id, name=sector.name)
            for sector in (fund.sectors[:3] if fund.sectors else [])
        ]

        # Stage focus as a comma-separated string, or None when there are no stages.
        stage_focus = (
            ", ".join([stage.name for stage in fund.investment_stages])
            if fund.investment_stages
            else None
        )

        return InvestmentResponse(
            id=investor.id,
            name=f"{investor.name} - {fund.fund_name}"
            if fund.fund_name
            else investor.name,
            aum=investor.aum,
            check_size_lower=fund.check_size_lower,
            check_size_upper=fund.check_size_upper,
            geographic_focus=fund.geographic_focus,
            stage_focus=stage_focus,
            portfolio_companies=portfolio_companies,
            sectors=fund_sectors,
            compatibility_score=compatibility_score,
        )