177 lines
6.9 KiB
Python
177 lines
6.9 KiB
Python
|
|
import logging
|
||
|
|
import os
|
||
|
|
from typing import List
|
||
|
|
|
||
|
|
from db.db import DATABASE_URL, get_db
|
||
|
|
from db.models import CompanyTable
|
||
|
|
from langchain import hub
|
||
|
|
from langchain_community.agent_toolkits import SQLDatabaseToolkit
|
||
|
|
from langchain_community.utilities import SQLDatabase
|
||
|
|
from langchain_openai import ChatOpenAI
|
||
|
|
from langgraph.prebuilt import create_react_agent
|
||
|
|
from schemas.router_schemas import CompanyData, PaginatedResponse
|
||
|
|
from sqlalchemy.orm import selectinload
|
||
|
|
|
||
|
|
logger = logging.getLogger(__name__)

# Base system prompt for the SQL agent, pulled from LangChain Hub when this
# module is imported.
prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")

# Database handle used by the agent's SQL tools (the project targets SQLite).
db = SQLDatabase.from_uri(DATABASE_URL)
|
||
|
|
|
||
|
|
|
||
|
|
class CompanyQueryProcessor:
    """Turn natural-language questions into company records via a SQL agent.

    A ReAct agent equipped with LangChain's SQL toolkit is instructed to
    answer with company IDs only. The matching rows are then loaded through
    the ORM together with their investors, members and sectors.
    """

    # Name under which SQLDatabaseToolkit registers the tool that executes
    # SQL. Only that tool's output holds result rows; the schema and
    # list-tables tools also emit text full of parentheses and digits
    # (e.g. "VARCHAR(255)") that must not be mistaken for IDs.
    _SQL_QUERY_TOOL_NAME = "sql_db_query"

    # Leading integer of a result tuple: "(1,)", "(1)" or "(1, 'Acme')".
    _ROW_ID_PATTERN = r"\(\s*(\d+)\s*(?:,|\))"

    # The bare-number fallback ignores values at or above this bound, since
    # they are more likely token counts or timestamps than company IDs.
    _MAX_FALLBACK_ID = 100000

    def __init__(self):
        """Build the LLM client, the SQL toolkit and the ReAct agent."""
        self.llm = ChatOpenAI(
            api_key=os.getenv("OPENROUTER_API_KEY"),
            base_url="https://openrouter.ai/api/v1",
            model="openai/gpt-4o-mini",
            temperature=0,
        )
        self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)

        # Extend the hub prompt so the agent answers with company IDs only.
        system_message_updated = (
            prompt_template.format(dialect="SQLite", top_k=5)
            + "\n\n=== CRITICAL INSTRUCTIONS ==="
            + "\n- Your ONLY task is to run SQL queries and extract company IDs"
            + "\n- When you get SQL results with company IDs, return them EXACTLY as shown"
            + "\n- If the SQL query returns rows with company IDs like [(1,), (5,), (9,)], return all those IDs"
            + "\n- Do NOT add any explanations, just list the IDs"
            + "\n- If a query returns NO ROWS (empty result), then respond with 'NO_RESULTS'"
            + "\n\n=== QUERY GUIDELINES ==="
            + "\n1. For sector searches: SELECT companies.id FROM companies JOIN company_sector ON companies.id = company_sector.company_id JOIN sectors ON company_sector.sector_id = sectors.id WHERE sectors.name LIKE '%sector_name%'"
            + "\n2. For industry searches: WHERE companies.industry LIKE '%search_term%'"
            + "\n3. For location searches: WHERE companies.location LIKE '%location%'"
            + "\n4. For founding year searches: WHERE companies.founded_year >= year"
            + "\n5. For investor-related: JOIN investor_companies table"
        )
        self.agent = create_react_agent(
            model=self.llm,
            tools=self.toolkit.get_tools(),
            prompt=system_message_updated,
        )

    def process_query(self, question: str) -> PaginatedResponse[CompanyData]:
        """Process a query using the LLM and return company response data.

        Args:
            question: The natural language query to process

        Returns:
            A single-page response with the companies whose IDs the agent
            found; an empty response when no IDs could be extracted.
        """
        # Let the LLM handle all database interactions and filtering.
        response = self.agent.invoke(
            {"messages": [("user", question)]},
            config={"recursion_limit": 50},
        )
        logger.info("%s", response)
        messages = response["messages"]

        # Prefer the raw rows of the last executed SQL query: they are not
        # subject to the model paraphrasing or truncating the ID list.
        company_ids = self._extract_ids_from_tool_messages(messages)
        if company_ids:
            logger.info("Extracted %d company IDs from results", len(company_ids))
        else:
            # No usable query output; fall back to the agent's final answer.
            final_message_content = messages[-1].content
            logger.info("AI Response: \n%s", final_message_content)
            company_ids = self._extract_company_ids_from_response(final_message_content)

        # Fetch full company data with relationships using the IDs.
        return self._fetch_companies_by_ids(company_ids)

    def _extract_ids_from_tool_messages(self, messages) -> List[int]:
        """Return the IDs found in the most recent SQL query tool result.

        Only messages with ``type == "tool"`` and the SQL query tool's name
        are considered, so the user question and schema/table-list output
        are never parsed for IDs. Earlier query results are ignored on
        purpose: they may be exploratory lookups on other tables.

        Args:
            messages: The agent's message list, oldest first.

        Returns:
            De-duplicated IDs in row order, or an empty list when there is
            no query result or it does not look like a row list.
        """
        for message in reversed(messages):
            if getattr(message, "type", None) != "tool":
                continue
            if getattr(message, "name", None) != self._SQL_QUERY_TOOL_NAME:
                continue
            text = str(getattr(message, "content", "") or "")
            # Rows are rendered as a Python list of tuples; anything else
            # (empty string, "Error: ...") carries no IDs.
            if not text.lstrip().startswith("["):
                return []
            try:
                return self._parse_id_tuples(text)
            except ValueError as e:
                logger.error(f"Error extracting IDs from response: {e}")
                return []
        return []

    @staticmethod
    def _parse_id_tuples(text: str) -> List[int]:
        """Return the leading integer of every tuple in *text*, de-duplicated."""
        import re

        matches = re.findall(CompanyQueryProcessor._ROW_ID_PATTERN, text)
        # dict.fromkeys removes duplicates while keeping first-seen order.
        return list(dict.fromkeys(int(match) for match in matches))

    def _extract_company_ids_from_response(self, ai_response: str) -> List[int]:
        """Extract company IDs from AI response.

        Tuples such as ``(1,)`` are tried first; if none are present, every
        standalone number below ``_MAX_FALLBACK_ID`` is taken instead.

        Args:
            ai_response: Message content; non-string content is stringified
                and ``None`` yields no IDs.

        Returns:
            De-duplicated IDs in order of appearance; empty when the text
            contains ``NO_RESULTS`` (case-insensitive) or no numbers.
        """
        import re

        if ai_response is None:
            return []
        # Message content is not guaranteed to be a plain string.
        text = ai_response if isinstance(ai_response, str) else str(ai_response)

        if "NO_RESULTS" in text.upper():
            return []

        try:
            company_ids = self._parse_id_tuples(text)
            if not company_ids:
                # Fallback: the model listed bare numbers ("1, 5, 9").
                numbers = re.findall(r"\b\d+\b", text)
                company_ids = list(
                    dict.fromkeys(
                        int(num)
                        for num in numbers
                        if int(num) < self._MAX_FALLBACK_ID
                    )
                )
        except ValueError as e:
            # int() rejects extremely long digit runs; treat as "no IDs".
            logger.error(f"Error extracting IDs from response: {e}")
            return []

        return company_ids

    def _fetch_companies_by_ids(
        self, company_ids: List[int]
    ) -> PaginatedResponse[CompanyData]:
        """Fetch companies with all their relationships from the database using company IDs.

        Args:
            company_ids: List of company IDs to fetch

        Returns:
            A single page holding every company found, in the order the IDs
            were given. IDs with no matching row are silently skipped.
        """
        if not company_ids:
            return PaginatedResponse(
                items=[],
                total=0,
                page=1,
                page_size=10,
                total_pages=0,
            )

        # Position of each distinct ID, used to restore the caller's order.
        rank = {cid: pos for pos, cid in enumerate(dict.fromkeys(company_ids))}

        # Get database session.
        # NOTE(review): get_db is presumably a generator-style session
        # provider; the session is closed explicitly below either way.
        db_session = next(get_db())

        try:
            # Query companies with all necessary relationships loaded.
            companies = (
                db_session.query(CompanyTable)
                .options(
                    selectinload(CompanyTable.investors),
                    selectinload(CompanyTable.members),
                    selectinload(CompanyTable.sectors),
                )
                .filter(CompanyTable.id.in_(list(rank)))
                .all()
            )

            # IN (...) gives no ordering guarantee; keep the agent's ranking.
            companies = sorted(
                companies, key=lambda company: rank.get(company.id, len(rank))
            )

            # Transform to CompanyData format.
            company_data_list = [
                CompanyData(
                    company=company,
                    investors=company.investors,
                    members=company.members,
                    sectors=company.sectors,
                )
                for company in companies
            ]

            total_count = len(company_data_list)
            total_pages = 1 if total_count > 0 else 0

            return PaginatedResponse(
                items=company_data_list,
                total=total_count,
                page=1,
                page_size=total_count,
                total_pages=total_pages,
            )
        finally:
            db_session.close()
|