import logging
import os
import re
from typing import List

from db.db import DATABASE_URL, get_db
from db.models import CompanyTable
from langchain import hub
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent
from schemas.router_schemas import CompanyData, PaginatedResponse
from sqlalchemy.orm import selectinload

logger = logging.getLogger(__name__)

# Created once at import time: the agent's base system prompt (pulled from
# LangChain Hub) and the SQLDatabase wrapper that the SQL toolkit queries.
prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")
db = SQLDatabase.from_uri(DATABASE_URL)

# Name of the SQLDatabaseToolkit tool that executes queries; its tool
# messages carry the raw result rows.
_SQL_QUERY_TOOL_NAME = "sql_db_query"

# First integer of each result-row tuple: "[(1,), (5, 'Acme')]" -> 1, 5.
# The "(" must not follow a word character or a dot, so type names such as
# "VARCHAR(255)" or calls such as "date(2020, 1, 1)" are not read as rows.
_SQL_ROW_ID_PATTERN = re.compile(r"(?<![\w.])\((\d+)\s*[,)]")
# Single-integer tuples such as "(1,)" or "(1)" in the model's final reply.
_REPLY_TUPLE_ID_PATTERN = re.compile(r"\((\d+),?\)")
# Any standalone integer; last resort when parsing the model's reply.
_BARE_INT_PATTERN = re.compile(r"\b\d+\b")
# Fallback integers at or above this are treated as tokens/timestamps, not IDs.
_MAX_FALLBACK_ID = 100000
# page_size reported when there are no results to size the page by.
_DEFAULT_PAGE_SIZE = 10


class CompanyQueryProcessor:
    """Answer natural-language company questions with an LLM SQL agent.

    A LangGraph ReAct agent writes and runs SQL to find matching company
    IDs. The full company records, with investors, members and sectors,
    are then loaded through the ORM.
    """

    # LangGraph recursion limit for a single agent run.
    RECURSION_LIMIT = 50

    def __init__(self, top_k: int = 5):
        """Build the LLM client, SQL toolkit and ReAct agent.

        Args:
            top_k: Value substituted for ``top_k`` in the hub system prompt.
                Presumably this is the per-query row limit the prompt asks
                the model to respect (verify against the prompt text).
                Defaults to 5, the previously hard-coded value.
        """
        self.llm = ChatOpenAI(
            api_key=os.getenv("OPENROUTER_API_KEY"),
            base_url="https://openrouter.ai/api/v1",
            model="openai/gpt-4o-mini",
            temperature=0,
        )
        self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)

        # Extend the base prompt so the agent replies with company IDs only.
        system_message_updated = (
            prompt_template.format(dialect="SQLite", top_k=top_k)
            + "\n\n=== CRITICAL INSTRUCTIONS ==="
            + "\n- Your ONLY task is to run SQL queries and extract company IDs"
            + "\n- When you get SQL results with company IDs, return them EXACTLY as shown"
            + "\n- If the SQL query returns rows with company IDs like [(1,), (5,), (9,)], return all those IDs"
            + "\n- Do NOT add any explanations, just list the IDs"
            + "\n- If a query returns NO ROWS (empty result), then respond with 'NO_RESULTS'"
            + "\n\n=== QUERY GUIDELINES ==="
            + "\n1. For sector searches: SELECT companies.id FROM companies JOIN company_sector ON companies.id = company_sector.company_id JOIN sectors ON company_sector.sector_id = sectors.id WHERE sectors.name LIKE '%sector_name%'"
            + "\n2. For industry searches: WHERE companies.industry LIKE '%search_term%'"
            + "\n3. For location searches: WHERE companies.location LIKE '%location%'"
            + "\n4. For founding year searches: WHERE companies.founded_year >= year"
            + "\n5. For investor-related: JOIN investor_companies table"
        )

        self.agent = create_react_agent(
            model=self.llm,
            tools=self.toolkit.get_tools(),
            prompt=system_message_updated,
        )

    def process_query(self, question: str) -> PaginatedResponse[CompanyData]:
        """Process a query using the LLM and return company response data.

        IDs are taken from the raw rows returned by the SQL query tool when
        possible, because they are exact. Only if no tool result yields IDs
        is the model's final reply parsed instead.

        Args:
            question: The natural language query to process.

        Returns:
            A single-page ``PaginatedResponse`` of the matching companies.
        """
        response = self.agent.invoke(
            {"messages": [("user", question)]},
            config={"recursion_limit": self.RECURSION_LIMIT},
        )
        messages = response["messages"]
        logger.debug("Agent response: %s", response)

        # Walk backwards so the most recent successful query result wins.
        # Earlier calls may be exploratory or may have failed.
        company_ids: List[int] = []
        for message in reversed(messages):
            if not self._is_query_result(message):
                continue
            company_ids = self._extract_company_ids_from_response(
                self._message_text(message), sql_rows=True
            )
            if company_ids:
                logger.info(
                    "Extracted %d company IDs from SQL results", len(company_ids)
                )
                break

        # No usable tool rows: fall back to parsing the agent's final reply.
        if not company_ids and messages:
            final_message_content = self._message_text(messages[-1])
            logger.info("AI Response: \n%s", final_message_content)
            company_ids = self._extract_company_ids_from_response(
                final_message_content
            )

        return self._fetch_companies_by_ids(company_ids)

    @staticmethod
    def _message_text(message) -> str:
        """Return a chat message's content as plain text.

        Content may be a string or a list of content blocks (strings, or
        dicts carrying a ``"text"`` key). Text blocks are joined with
        newlines; any other blocks are ignored.
        """
        content = getattr(message, "content", None)
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for block in content:
                if isinstance(block, str):
                    parts.append(block)
                elif isinstance(block, dict) and "text" in block:
                    parts.append(str(block["text"]))
            return "\n".join(parts)
        return str(content) if content else ""

    def _is_query_result(self, message) -> bool:
        """Return True if *message* is a row-list result of the SQL query tool."""
        if getattr(message, "type", None) != "tool":
            return False
        # Tool messages are normally named after their tool. If the name is
        # missing, rely on the shape check below.
        name = getattr(message, "name", None)
        if name not in (None, _SQL_QUERY_TOOL_NAME):
            return False
        # Row results are rendered as a list of tuples ("[(1,), ...]").
        # Presumably errors ("Error: ...") and empty results ("") are not.
        return self._message_text(message).lstrip().startswith("[")

    def _extract_company_ids_from_response(
        self, ai_response: str, sql_rows: bool = False
    ) -> List[int]:
        """Extract company IDs from the agent's reply or raw SQL result rows.

        Args:
            ai_response: Text to scan.
            sql_rows: True when *ai_response* is raw query-tool output. Then
                the first integer of each row tuple is taken and no fallback
                applies, since other numbers in a row are not IDs. When
                False, the text is the model's reply: if it contains
                ``NO_RESULTS``, nothing is returned. Otherwise single-integer
                tuples are taken, else every standalone integer below
                ``_MAX_FALLBACK_ID``.

        Returns:
            Unique IDs in first-seen order (empty if none are found).
        """
        if not ai_response:
            return []

        if sql_rows:
            matches = _SQL_ROW_ID_PATTERN.findall(ai_response)
        else:
            if "NO_RESULTS" in ai_response.upper():
                return []
            matches = _REPLY_TUPLE_ID_PATTERN.findall(ai_response)
            if not matches:
                # Fallback: any bare integer, minus very large values that
                # are likely tokens or timestamps.
                matches = [
                    number
                    for number in _BARE_INT_PATTERN.findall(ai_response)
                    if int(number) < _MAX_FALLBACK_ID
                ]

        # Joins can repeat a company once per matching row; keep first-seen order.
        return list(dict.fromkeys(int(match) for match in matches))

    def _fetch_companies_by_ids(
        self, company_ids: List[int]
    ) -> PaginatedResponse[CompanyData]:
        """Load companies and their relationships for the given IDs.

        Companies are returned in the order their IDs appear in
        *company_ids*. IDs not found in the database are skipped.

        Args:
            company_ids: Company IDs to fetch; duplicates are ignored.

        Returns:
            A single-page ``PaginatedResponse`` of ``CompanyData`` items.
        """
        unique_ids = list(dict.fromkeys(company_ids or []))
        if not unique_ids:
            return PaginatedResponse(
                items=[],
                total=0,
                page=1,
                page_size=_DEFAULT_PAGE_SIZE,
                total_pages=0,
            )

        # get_db() is consumed with next(), so it is an iterator of sessions,
        # most likely a generator dependency. Keep a reference so it is not
        # finalized (running any cleanup it has) before we are done, and
        # close it explicitly afterwards.
        session_source = get_db()
        db_session = next(session_source)
        try:
            # Eager-load relationships so they are populated before close.
            companies = (
                db_session.query(CompanyTable)
                .options(
                    selectinload(CompanyTable.investors),
                    selectinload(CompanyTable.members),
                    selectinload(CompanyTable.sectors),
                )
                .filter(CompanyTable.id.in_(unique_ids))
                .all()
            )

            # IN (...) returns rows in arbitrary order; restore the order of
            # the SQL results the IDs came from.
            companies_by_id = {company.id: company for company in companies}
            company_data_list = []
            for company_id in unique_ids:
                company = companies_by_id.get(company_id)
                if company is None:
                    continue
                company_data_list.append(
                    CompanyData(
                        company=company,
                        investors=company.investors,
                        members=company.members,
                        sectors=company.sectors,
                    )
                )

            total_count = len(company_data_list)
            return PaginatedResponse(
                items=company_data_list,
                total=total_count,
                page=1,
                # Everything fits on one page; avoid a zero page size when
                # none of the IDs exist, matching the empty-input case.
                page_size=total_count or _DEFAULT_PAGE_SIZE,
                total_pages=1 if total_count > 0 else 0,
            )
        finally:
            db_session.close()
            close_source = getattr(session_source, "close", None)
            if close_source is not None:
                close_source()