import asyncio
import hashlib
import logging
import os
from typing import List

from db.db import get_db
from db.models import CompanyTable
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from schemas.router_schemas import CompanyData, PaginatedResponse
from sqlalchemy import text
from sqlalchemy.orm import selectinload

logger = logging.getLogger(__name__)

# System prompt for SQL generation. It must not contain curly braces:
# ChatPromptTemplate would treat them as template variables.
_SQL_SYSTEM_PROMPT = """You are a SQL expert. Generate a SQLite query to find company IDs based on user requirements.

Database Schema:
- companies: id, name, industry, location, description, founded_year, website
- company_sector: company_id, sector_id
- sectors: id, name
- investor_companies: investor_id, company_id
- investors: id, name, aum
- team_members: id, company_id, name, title

IMPORTANT RULES:
1. ALWAYS return ONLY company IDs (companies.id) - use SELECT DISTINCT c.id
2. For industry: Check BOTH industry field AND sectors table with synonyms
   - Use LEFT JOIN for sectors so companies without sector tags still match
   - Include related terms: 'Fintech' → c.industry LIKE '%Fintech%' OR c.industry LIKE '%Finance%' OR sec.name LIKE '%Fintech%' OR sec.name LIKE '%Financial%'
   - 'AI' → c.industry LIKE '%AI%' OR c.industry LIKE '%Artificial Intelligence%' OR c.industry LIKE '%Machine Learning%' OR sec.name LIKE '%AI%' OR sec.name LIKE '%ML%'
3. For location: Be FLEXIBLE with variations and abbreviations
   - 'San Francisco' → c.location LIKE '%San Francisco%' OR c.location LIKE '%SF%' OR c.location LIKE '%Bay Area%'
   - 'New York' → c.location LIKE '%New York%' OR c.location LIKE '%NYC%' OR c.location LIKE '%NY%'
   - 'Europe' → c.location LIKE '%Europe%' OR c.location LIKE '%UK%' OR c.location LIKE '%London%' OR c.location LIKE '%Berlin%' OR c.location LIKE '%Paris%'
4. For sectors: Use LEFT JOIN and include multiple synonyms
   - 'Healthcare' → sec.name LIKE '%Healthcare%' OR sec.name LIKE '%Health%' OR sec.name LIKE '%Medical%' OR sec.name LIKE '%Biotech%' OR c.industry LIKE '%Health%'
5. For founding year filters (include NULL to be inclusive):
   - "founded after 2020" → WHERE (founded_year > 2020 OR founded_year IS NULL)
   - "founded before 2018" → WHERE (founded_year < 2018 OR founded_year IS NULL)
   - "founded in 2020" → WHERE founded_year = 2020
6. For investor-related queries: Use JOIN investor_companies
7. Use LEFT JOIN for sectors so companies without tags still match
8. Use DISTINCT to avoid duplicates from joins
9. Be INCLUSIVE - use OR conditions with synonyms and variations
10. Return a single, complete SELECT query

Example Queries:

Q: "Fintech companies founded in 2020"
A: SELECT DISTINCT c.id FROM companies c LEFT JOIN company_sector cs ON c.id = cs.company_id LEFT JOIN sectors sec ON cs.sector_id = sec.id WHERE (c.industry LIKE '%Fintech%' OR c.industry LIKE '%Finance%' OR c.industry LIKE '%Financial%' OR sec.name LIKE '%Fintech%' OR sec.name LIKE '%Financial Services%') AND c.founded_year = 2020

Q: "AI companies in San Francisco"
A: SELECT DISTINCT c.id FROM companies c LEFT JOIN company_sector cs ON c.id = cs.company_id LEFT JOIN sectors sec ON cs.sector_id = sec.id WHERE (c.industry LIKE '%AI%' OR c.industry LIKE '%Artificial Intelligence%' OR c.industry LIKE '%Machine Learning%' OR sec.name LIKE '%AI%' OR sec.name LIKE '%Machine Learning%' OR sec.name LIKE '%ML%') AND (c.location LIKE '%San Francisco%' OR c.location LIKE '%SF%' OR c.location LIKE '%Bay Area%')

Q: "Healthcare companies"
A: SELECT DISTINCT c.id FROM companies c LEFT JOIN company_sector cs ON c.id = cs.company_id LEFT JOIN sectors sec ON cs.sector_id = sec.id WHERE c.industry LIKE '%Healthcare%' OR c.industry LIKE '%Health%' OR c.industry LIKE '%Medical%' OR sec.name LIKE '%Healthcare%' OR sec.name LIKE '%Medical%' OR sec.name LIKE '%Biotech%' OR sec.name LIKE '%Pharma%'

Q: "Companies funded by Sequoia"
A: SELECT DISTINCT c.id FROM companies c JOIN investor_companies ic ON c.id = ic.company_id JOIN investors i ON ic.investor_id = i.id WHERE i.name LIKE '%Sequoia%'

Q: "European startups founded after 2019"
A: SELECT DISTINCT c.id FROM companies c WHERE (c.location LIKE '%Europe%' OR c.location LIKE '%UK%' OR c.location LIKE '%London%' OR c.location LIKE '%Germany%' OR c.location LIKE '%Berlin%' OR c.location LIKE '%France%' OR c.location LIKE '%Paris%') AND (c.founded_year > 2019 OR c.founded_year IS NULL)

Q: "SaaS companies"
A: SELECT DISTINCT c.id FROM companies c LEFT JOIN company_sector cs ON c.id = cs.company_id LEFT JOIN sectors sec ON cs.sector_id = sec.id WHERE c.industry LIKE '%SaaS%' OR c.industry LIKE '%Software%' OR c.industry LIKE '%Cloud%' OR sec.name LIKE '%SaaS%' OR sec.name LIKE '%Software%'

IMPORTANT:
- Use LEFT JOIN so companies without sector tags still match via industry field
- Use OR conditions with related keywords/synonyms to cast a wider net
- Include NULL checks for optional filters to avoid excluding companies with missing data

Return ONLY the SQL query, no explanations or markdown."""


class CompanyQueryProcessor:
    """Answer natural-language company questions via LLM-generated SQL.

    A question is turned into a SQL statement by the chat model, the
    statement is executed to obtain company IDs, and the matching
    companies are loaded with their investors, members and sectors.
    """

    def __init__(self):
        """Set up the chat model, the SQL cache and the generation prompt."""
        self.llm = ChatOpenAI(
            api_key=os.getenv("OPENROUTER_API_KEY"),
            base_url="https://openrouter.ai/api/v1",
            model="openai/gpt-4o-mini",
            temperature=0,
        )

        # Query cache for performance: cache key -> SQL text.
        # NOTE: the cache has no size limit or eviction.
        self.query_cache = {}

        # SQL generation prompt
        self.sql_prompt = ChatPromptTemplate.from_messages(
            [
                ("system", _SQL_SYSTEM_PROMPT),
                ("user", "{question}"),
            ]
        )

    def _get_cache_key(self, question: str) -> str:
        """Generate cache key from normalized question.

        The question is lower-cased and stripped before hashing, so
        questions differing only in case or surrounding whitespace share
        a key. MD5 is used purely as a fingerprint, not for security.
        """
        return hashlib.md5(question.lower().strip().encode()).hexdigest()

    @staticmethod
    def _empty_response() -> PaginatedResponse[CompanyData]:
        """Return the empty page used when nothing matches or a query fails."""
        return PaginatedResponse(
            items=[], total=0, page=1, page_size=10, total_pages=0
        )

    @staticmethod
    def _is_single_select(sql_query: str) -> bool:
        """Return True when *sql_query* looks like one lone SELECT statement.

        This is a deliberately conservative textual check, not a SQL
        parser: the statement must begin with the SELECT keyword and must
        not contain a ';' other than trailing ones. A ';' inside a string
        literal is therefore rejected as well.
        """
        statement = sql_query.strip().rstrip(";").rstrip()
        if ";" in statement:
            return False
        following = statement[6:7]
        # Require a keyword boundary so e.g. "SELECTED ..." is not accepted.
        return statement[:6].upper() == "SELECT" and not (
            following.isalnum() or following == "_"
        )

    def _generate_sql(self, question: str) -> str:
        """Ask the LLM for a SQL statement answering *question*.

        Markdown code fences are removed from the model output. Errors
        raised by the LLM client propagate to the caller.
        """
        messages = self.sql_prompt.format_messages(question=question)
        response = self.llm.invoke(messages)
        sql_query = response.content.strip()
        # Clean up SQL (remove markdown code blocks if present)
        return sql_query.replace("```sql", "").replace("```", "").strip()

    # The synchronous implementation lives in `_process_query_sync`; the
    # async wrapper `process_query` runs it in a thread. This keeps the
    # FastAPI event loop non-blocking while reusing the existing sync code.
    async def process_query(self, question: str) -> PaginatedResponse[CompanyData]:
        """Async wrapper around `_process_query_sync`.

        Runs the blocking work in a thread to avoid blocking the event loop.

        Args:
            question: Natural-language description of the companies wanted.

        Returns:
            A single page holding every matching company, or an empty page
            when nothing matches or the generated SQL cannot be executed.
        """
        return await asyncio.to_thread(self._process_query_sync, question)

    def _process_query_sync(self, question: str) -> PaginatedResponse[CompanyData]:
        """Synchronous implementation of `process_query`.

        Looks up (or generates) the SQL for *question*, executes it to get
        company IDs and loads the corresponding companies. SQL that is not
        a single SELECT statement is refused, and any error while executing
        the SQL or loading the companies is logged; both cases yield an
        empty page instead of raising.
        """
        cache_key = self._get_cache_key(question)

        # Check cache first
        sql_query = self.query_cache.get(cache_key)
        if sql_query is not None:
            logger.info("Using cached SQL: %s", sql_query)
        else:
            sql_query = self._generate_sql(question)
            logger.info("Generated SQL: %s", sql_query)

        # The SQL is derived from untrusted user text via the LLM, so only
        # a lone read-only SELECT is ever sent to the database.
        if not self._is_single_select(sql_query):
            logger.error("Refusing to execute non-SELECT SQL: %s", sql_query)
            return self._empty_response()

        # Execute query to get company IDs
        db_session = next(get_db())
        try:
            result = db_session.execute(text(sql_query))
            company_ids = [row[0] for row in result.fetchall()]

            # Cache only SQL that actually executed, so a malformed
            # statement is regenerated rather than replayed forever.
            self.query_cache[cache_key] = sql_query

            logger.info(
                "Found %d company IDs: %s%s",
                len(company_ids),
                company_ids[:10],
                "..." if len(company_ids) > 10 else "",
            )
            return self._fetch_companies_by_ids(company_ids)
        except Exception as exc:
            # Boundary handler: log with traceback and degrade to no results.
            logger.exception("SQL execution error: %s", exc)
            logger.error("Failed SQL: %s", sql_query)
            return self._empty_response()
        finally:
            db_session.close()

    def _fetch_companies_by_ids(
        self, company_ids: List[int]
    ) -> PaginatedResponse[CompanyData]:
        """Fetch companies with all their relationships from the database.

        Args:
            company_ids: List of company IDs to fetch.

        Returns:
            One page containing every company found (``page_size`` equals
            the number of items), or the empty page when *company_ids* is
            empty.
        """
        if not company_ids:
            return self._empty_response()

        # Get database session
        db_session = next(get_db())
        try:
            # Query companies with all necessary relationships loaded
            companies = (
                db_session.query(CompanyTable)
                .options(
                    selectinload(CompanyTable.investors),
                    selectinload(CompanyTable.members),
                    selectinload(CompanyTable.sectors),
                )
                .filter(CompanyTable.id.in_(company_ids))
                .all()
            )

            # Transform to CompanyData format
            company_data_list = [
                CompanyData(
                    company=company,
                    investors=company.investors,
                    members=company.members,
                    sectors=company.sectors,
                )
                for company in companies
            ]

            total_count = len(company_data_list)
            return PaginatedResponse(
                items=company_data_list,
                total=total_count,
                page=1,
                page_size=total_count,
                total_pages=1 if total_count > 0 else 0,
            )
        finally:
            db_session.close()