import asyncio
import hashlib
import logging
import os
import re
from typing import List, Optional

from db.db import get_db
from db.models import FundTable, InvestorTable, ProjectTable
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from schemas.router_schemas import (
    CompanyMinimal,
    InvestmentResponse,
    PaginatedResponse,
    SectorMinimal,
)
from sqlalchemy import text
from sqlalchemy.orm import selectinload
from services.compatibility_score import calculate_project_investor_compatibility

logger = logging.getLogger(__name__)

# Single-quoted SQL string literals ('' is an escaped quote inside a literal).
# They are blanked out before keyword checks so that a LIKE pattern such as
# '%Update%' is not mistaken for an UPDATE statement.
_SQL_STRING_LITERAL = re.compile(r"'(?:[^']|'')*'")

# A generated statement must start with SELECT (or WITH, for a CTE).
_SQL_READ_PREFIX = re.compile(r"\s*(?:select|with)\b", re.IGNORECASE)

# Keywords that write data or change schema/connection state. ``replace`` is
# only matched as ``REPLACE INTO`` so the REPLACE() string function stays legal.
_SQL_WRITE_KEYWORD = re.compile(
    r"\b(?:insert|update|delete|drop|alter|create|truncate|attach|detach"
    r"|pragma|vacuum|reindex|replace\s+into)\b",
    re.IGNORECASE,
)

# System prompt instructing the model to translate a natural-language request
# into a SQLite query that returns fund IDs.
_SQL_SYSTEM_PROMPT = """You are a SQL expert. Generate a SQLite query to find fund IDs based on user requirements.

Database Schema:
- funds: id, fund_name, investor_id, check_size_lower, check_size_upper, geographic_focus
- fund_sectors: fund_id, sector_id
- fund_investment_stages: fund_id, stage_id
- sectors: id, name
- investment_stages: id, name
- investors: id, name, aum

IMPORTANT RULES:
1. ALWAYS return ONLY fund IDs (funds.id) - use SELECT DISTINCT f.id
2. For geography: Be FLEXIBLE - use OR with variations and partial matches
   - 'Europe' → WHERE geographic_focus LIKE '%Europe%' OR geographic_focus LIKE '%European%'
   - 'America' → WHERE geographic_focus LIKE '%America%' OR geographic_focus LIKE '%US%' OR geographic_focus LIKE '%United States%'
   - 'Asia' → WHERE geographic_focus LIKE '%Asia%' OR geographic_focus LIKE '%Asian%'
   - If no geography specified, DON'T filter by geography
3. For stages: Use LEFT JOIN and LIKE for flexible matching with synonyms
   - 'Seed' → s.name LIKE '%Seed%' OR s.name LIKE '%Pre-Seed%' OR s.name LIKE '%Early%'
   - 'Series A' → s.name LIKE '%Series A%' OR s.name LIKE '%A%'
   - 'Growth' → s.name LIKE '%Growth%' OR s.name LIKE '%Late%' OR s.name LIKE '%Expansion%'
   - If stage not specified, include ALL funds
4. For sectors: Use LEFT JOIN and include related terms with OR
   - 'Fintech' → sec.name LIKE '%Fintech%' OR sec.name LIKE '%Finance%' OR sec.name LIKE '%Financial%'
   - 'AI' → sec.name LIKE '%AI%' OR sec.name LIKE '%Artificial Intelligence%' OR sec.name LIKE '%Machine Learning%' OR sec.name LIKE '%ML%'
   - 'Healthcare' → sec.name LIKE '%Healthcare%' OR sec.name LIKE '%Health%' OR sec.name LIKE '%Medical%' OR sec.name LIKE '%Biotech%'
5. For check size filters (be flexible with ranges):
   - "under X" → WHERE (check_size_upper <= X OR check_size_upper IS NULL)
   - "over X" → WHERE (check_size_lower >= X OR check_size_lower IS NULL)
   - "between X and Y" → WHERE check_size_lower >= X AND check_size_upper <= Y
6. Use LEFT JOIN for stages and sectors so funds without tags still match
7. Use DISTINCT to avoid duplicates from joins
8. Be INCLUSIVE - use OR conditions to cast a wider net
9. If query is very simple (e.g., just "seed stage"), don't add unnecessary filters
10. Return a single, complete SELECT query

Example Queries:

Q: "Seed stage investors in Europe"
A: SELECT DISTINCT f.id FROM funds f
LEFT JOIN fund_investment_stages fis ON f.id = fis.fund_id
LEFT JOIN investment_stages s ON fis.stage_id = s.id
WHERE (s.name LIKE '%Seed%' OR s.name LIKE '%Pre-Seed%' OR s.name LIKE '%Early%' OR s.id IS NULL)
AND (f.geographic_focus LIKE '%Europe%' OR f.geographic_focus LIKE '%European%')

Q: "Fintech investors with check size under 5 million"
A: SELECT DISTINCT f.id FROM funds f
LEFT JOIN fund_sectors fs ON f.id = fs.fund_id
LEFT JOIN sectors sec ON fs.sector_id = sec.id
WHERE (sec.name LIKE '%Fintech%' OR sec.name LIKE '%Finance%' OR sec.name LIKE '%Financial%' OR sec.id IS NULL)
AND (f.check_size_upper <= 5000000 OR f.check_size_upper IS NULL)

Q: "Seed stage investors"
A: SELECT DISTINCT f.id FROM funds f
LEFT JOIN fund_investment_stages fis ON f.id = fis.fund_id
LEFT JOIN investment_stages s ON fis.stage_id = s.id
WHERE s.name LIKE '%Seed%' OR s.name LIKE '%Pre-Seed%' OR s.name LIKE '%Early%'

Q: "Growth stage investors"
A: SELECT DISTINCT f.id FROM funds f
LEFT JOIN fund_investment_stages fis ON f.id = fis.fund_id
LEFT JOIN investment_stages s ON fis.stage_id = s.id
WHERE s.name LIKE '%Growth%' OR s.name LIKE '%Late%' OR s.name LIKE '%Expansion%' OR s.name LIKE '%Series C%' OR s.name LIKE '%Series D%'

Q: "AI investors in America"
A: SELECT DISTINCT f.id FROM funds f
LEFT JOIN fund_sectors fs ON f.id = fs.fund_id
LEFT JOIN sectors sec ON fs.sector_id = sec.id
WHERE (sec.name LIKE '%AI%' OR sec.name LIKE '%Artificial Intelligence%' OR sec.name LIKE '%Machine Learning%' OR sec.name LIKE '%ML%')
AND (f.geographic_focus LIKE '%America%' OR f.geographic_focus LIKE '%US%' OR f.geographic_focus LIKE '%United States%' OR f.geographic_focus LIKE '%USA%')

Q: "Healthcare investors"
A: SELECT DISTINCT f.id FROM funds f
LEFT JOIN fund_sectors fs ON f.id = fs.fund_id
LEFT JOIN sectors sec ON fs.sector_id = sec.id
WHERE sec.name LIKE '%Healthcare%' OR sec.name LIKE '%Health%' OR sec.name LIKE '%Medical%' OR sec.name LIKE '%Biotech%' OR sec.name LIKE '%Pharma%'

IMPORTANT: Use LEFT JOIN so funds without sector/stage tags can still match. Include synonym terms with OR for better recall.

Return ONLY the SQL query, no explanations or markdown."""


class QueryProcessor:
    """Translate natural-language investor searches into fund results.

    A question is turned into a SQL statement by an LLM, the statement is
    executed to obtain fund IDs, and those funds are then loaded through the
    ORM and converted to ``InvestmentResponse`` items.
    """

    def __init__(self):
        """Create the LLM client, the in-memory SQL cache and the prompt."""
        self.llm = ChatOpenAI(
            api_key=os.getenv("OPENROUTER_API_KEY"),
            base_url="https://openrouter.ai/api/v1",
            model="openai/gpt-4o-mini",
            temperature=0,
        )

        # Maps a normalized-question hash to SQL that has executed
        # successfully at least once (see ``_process_query_sync``).
        self.query_cache = {}

        # SQL generation prompt
        self.sql_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", _SQL_SYSTEM_PROMPT),
                ("user", "{question}"),
            ]
        )

    def _get_cache_key(self, question: str) -> str:
        """Return the cache key for *question*.

        The question is lower-cased and stripped before hashing so that
        trivially different spellings share one cache entry.
        """
        normalized = question.lower().strip().encode()
        # MD5 is used purely as a cheap fingerprint, not for security;
        # flagging that keeps it usable on FIPS-restricted builds.
        return hashlib.md5(normalized, usedforsecurity=False).hexdigest()

    @staticmethod
    def _empty_response() -> PaginatedResponse[InvestmentResponse]:
        """Return the empty page used for "no results" and failure cases."""
        return PaginatedResponse(
            items=[], total=0, page=1, page_size=10, total_pages=0
        )

    @staticmethod
    def _is_read_only_select(sql_query: str) -> bool:
        """Return True if *sql_query* is a single read-only statement.

        The statement must start with ``SELECT`` or ``WITH``, must not chain
        further statements with ``;`` and must not contain a write/DDL
        keyword outside of string literals.

        NOTE: this is a defensive filter for model-generated SQL, not a
        replacement for running the query under a read-only database role.
        """
        statement = sql_query.strip().rstrip("; \t\r\n")
        if not statement:
            return False
        code_only = _SQL_STRING_LITERAL.sub("''", statement)
        if ";" in code_only:
            return False
        if _SQL_READ_PREFIX.match(code_only) is None:
            return False
        return _SQL_WRITE_KEYWORD.search(code_only) is None

    async def process_query(
        self, question: str, project_id: Optional[int] = None
    ) -> PaginatedResponse[InvestmentResponse]:
        """Answer *question* without blocking the event loop.

        Delegates to ``_process_query_sync`` in a worker thread because the
        LLM call and the database access are blocking.

        Args:
            question: Natural-language description of the desired funds.
            project_id: Optional project ID used for compatibility scoring.
        """
        return await asyncio.to_thread(self._process_query_sync, question, project_id)

    def _process_query_sync(
        self, question: str, project_id: Optional[int] = None
    ) -> PaginatedResponse[InvestmentResponse]:
        """Synchronous implementation of ``process_query``.

        Args:
            question: Natural-language description of the desired funds.
            project_id: Optional project ID used for compatibility scoring.

        Returns:
            The matching funds, or an empty page when the generated SQL is
            rejected, fails to execute, or the funds cannot be loaded.
        """
        cache_key = self._get_cache_key(question)

        sql_query = self.query_cache.get(cache_key)
        if sql_query is not None:
            logger.info("Using cached SQL: %s", sql_query)
        else:
            # Generate SQL query
            messages = self.sql_prompt.format_messages(question=question)
            response = self.llm.invoke(messages)
            sql_query = response.content.strip()

            # Clean up SQL (remove markdown code blocks if present)
            sql_query = sql_query.replace("```sql", "").replace("```", "").strip()
            logger.info("Generated SQL: %s", sql_query)

        # The SQL text ultimately derives from user input, so refuse anything
        # that is not a single read-only statement before it reaches the DB.
        if not self._is_read_only_select(sql_query):
            logger.error("Rejected non-SELECT SQL: %s", sql_query)
            self.query_cache.pop(cache_key, None)
            return self._empty_response()

        # Execute query to get fund IDs
        db_session = next(get_db())
        try:
            result = db_session.execute(text(sql_query))
            fund_ids = [row[0] for row in result.fetchall()]
        except Exception as e:
            logger.error("SQL execution error: %s", e)
            logger.error("Failed SQL: %s", sql_query)
            # Drop the entry so the next identical question regenerates the
            # SQL instead of replaying a statement known to fail.
            self.query_cache.pop(cache_key, None)
            return self._empty_response()
        finally:
            db_session.close()

        # Cache only SQL that actually ran.
        self.query_cache[cache_key] = sql_query

        logger.info(
            "Found %d fund IDs: %s%s",
            len(fund_ids),
            fund_ids[:10],
            "..." if len(fund_ids) > 10 else "",
        )

        # Loading happens after the first session is closed and is reported
        # separately, so ORM/response errors are not mislabelled as SQL errors.
        try:
            return self._fetch_funds_by_ids(fund_ids, project_id)
        except Exception:
            logger.exception("Failed to load funds for generated query")
            return self._empty_response()

    def _fetch_funds_by_ids(
        self, fund_ids: List[int], project_id: Optional[int] = None
    ) -> PaginatedResponse[InvestmentResponse]:
        """Load funds by ID and convert them to ``InvestmentResponse`` items.

        Produces one item per fund; each item carries the fund's investor ID
        and a name combining the investor and fund names.

        Args:
            fund_ids: List of fund IDs to fetch.
            project_id: Optional project ID for compatibility scoring. When
                omitted (or the project is not found) every item gets a
                score of 1.0.

        Returns:
            A single page containing every loaded fund, or the empty page
            when *fund_ids* is empty.
        """
        if not fund_ids:
            return self._empty_response()

        # Get database session
        db_session = next(get_db())

        try:
            # Load project if project_id provided
            project = None
            if project_id is not None:
                project = (
                    db_session.query(ProjectTable)
                    .options(selectinload(ProjectTable.sector))
                    .filter(ProjectTable.id == project_id)
                    .first()
                )

            # Query funds with all necessary relationships loaded
            funds = (
                db_session.query(FundTable)
                .options(
                    selectinload(FundTable.investor).selectinload(
                        InvestorTable.portfolio_companies
                    ),
                    selectinload(FundTable.investor).selectinload(
                        InvestorTable.team_members
                    ),
                    selectinload(FundTable.investor).selectinload(
                        InvestorTable.sectors
                    ),
                    selectinload(FundTable.investment_stages),
                    selectinload(FundTable.sectors),
                )
                .filter(FundTable.id.in_(fund_ids))
                .all()
            )

            # Transform to InvestmentResponse format (one row per fund)
            investment_responses = []
            for fund in funds:
                investor = fund.investor
                if investor is None:
                    # Every response field below hangs off the investor; one
                    # orphaned fund should not void the whole result set.
                    logger.warning("Skipping fund %s: no investor attached", fund.id)
                    continue

                # Calculate compatibility score if project provided
                compatibility_score = 1.0
                if project is not None:
                    compatibility_score = calculate_project_investor_compatibility(
                        project=project, investor=investor, use_funds=True
                    )

                # First 3 portfolio companies (id and name only)
                portfolio_companies = [
                    CompanyMinimal(id=company.id, name=company.name)
                    for company in investor.portfolio_companies[:3]
                ]

                # Stage focus as a comma-separated string
                stage_focus = (
                    ", ".join(stage.name for stage in fund.investment_stages)
                    if fund.investment_stages
                    else None
                )

                # First 3 fund sectors as loaded, then ordered by name
                # (the slice is taken before sorting).
                first_sectors = fund.sectors[:3] if fund.sectors else []
                fund_sectors = [
                    SectorMinimal(id=sector.id, name=sector.name)
                    for sector in sorted(first_sectors, key=lambda s: s.name)
                ]

                investment_responses.append(
                    InvestmentResponse(
                        id=investor.id,
                        name=f"{investor.name} - {fund.fund_name}"
                        if fund.fund_name
                        else investor.name,
                        aum=investor.aum,
                        check_size_lower=fund.check_size_lower,
                        check_size_upper=fund.check_size_upper,
                        geographic_focus=fund.geographic_focus,
                        stage_focus=stage_focus,
                        portfolio_companies=portfolio_companies,
                        sectors=fund_sectors,
                        compatibility_score=compatibility_score,
                    )
                )

            # Everything is returned as one page.
            total_count = len(investment_responses)
            total_pages = 1 if total_count > 0 else 0

            return PaginatedResponse(
                items=investment_responses,
                total=total_count,
                page=1,
                page_size=total_count,
                total_pages=total_pages,
            )

        finally:
            db_session.close()