Files
Anton_wireframe/app/services/company_querying.py
T

177 lines
6.9 KiB
Python
Raw Normal View History

import logging
import os
import re
from typing import List

from langchain import hub
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent
from sqlalchemy.orm import selectinload

from db.db import DATABASE_URL, get_db
from db.models import CompanyTable
from schemas.router_schemas import CompanyData, PaginatedResponse

logger = logging.getLogger(__name__)

# Module-level setup, executed once when this module is imported:
# - pull the "langchain-ai/sql-agent-system-prompt" template from LangChain Hub;
# - wrap the application database at DATABASE_URL in a LangChain SQLDatabase,
#   which the agent's SQL tools query. The agent prompt is formatted with
#   dialect="SQLite", so DATABASE_URL is assumed to point at SQLite.
prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")
db = SQLDatabase.from_uri(DATABASE_URL)
class CompanyQueryProcessor:
    """Answer natural-language company questions with an LLM SQL agent.

    A LangGraph ReAct agent, equipped with the ``SQLDatabaseToolkit`` tools
    over the module-level ``db``, writes and runs SQL. The company IDs are
    then mined from the agent's transcript, and the matching ``CompanyTable``
    rows (with investors, members and sectors) are loaded from the
    application database.
    """

    # Name of the SQLDatabaseToolkit tool that executes SQL. Only its output
    # is treated as query results when mining the transcript for IDs.
    _QUERY_TOOL_NAME = "sql_db_query"
    # Single-int tuples as printed for one-column rows, e.g. "(1,)" or "(7)".
    _TUPLE_ID_PATTERN = re.compile(r"\((\d+),?\)")
    # Fallback: any standalone run of digits.
    _NUMBER_PATTERN = re.compile(r"\b\d+\b")
    # Numbers at or above this are discarded as unlikely IDs (tokens, timestamps).
    _MAX_PLAUSIBLE_ID = 100000
    # page_size reported in an empty response.
    _EMPTY_PAGE_SIZE = 10
    # Appended to the hub prompt. It forces the agent to answer with bare
    # company IDs (or NO_RESULTS) and suggests query shapes for this schema.
    _EXTRA_INSTRUCTIONS = (
        "\n\n=== CRITICAL INSTRUCTIONS ==="
        "\n- Your ONLY task is to run SQL queries and extract company IDs"
        "\n- When you get SQL results with company IDs, return them EXACTLY as shown"
        "\n- If the SQL query returns rows with company IDs like [(1,), (5,), (9,)], return all those IDs"
        "\n- Do NOT add any explanations, just list the IDs"
        "\n- If a query returns NO ROWS (empty result), then respond with 'NO_RESULTS'"
        "\n\n=== QUERY GUIDELINES ==="
        "\n1. For sector searches: SELECT companies.id FROM companies JOIN company_sector ON companies.id = company_sector.company_id JOIN sectors ON company_sector.sector_id = sectors.id WHERE sectors.name LIKE '%sector_name%'"
        "\n2. For industry searches: WHERE companies.industry LIKE '%search_term%'"
        "\n3. For location searches: WHERE companies.location LIKE '%location%'"
        "\n4. For founding year searches: WHERE companies.founded_year >= year"
        "\n5. For investor-related: JOIN investor_companies table"
    )

    def __init__(self, top_k: int = 5):
        """Create the LLM client, the SQL toolkit and the ReAct agent.

        Args:
            top_k: Value substituted for the hub prompt's ``top_k``
                placeholder (presumably the default row limit the prompt asks
                the agent to apply; confirm against the hub prompt).
                Defaults to 5, the previously hard-coded value.
        """
        # OpenRouter serves an OpenAI-compatible API, hence ChatOpenAI with
        # a custom base_url.
        self.llm = ChatOpenAI(
            api_key=os.getenv("OPENROUTER_API_KEY"),
            base_url="https://openrouter.ai/api/v1",
            model="openai/gpt-4o-mini",
            temperature=0,
        )
        self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)
        system_message = (
            prompt_template.format(dialect="SQLite", top_k=top_k)
            + self._EXTRA_INSTRUCTIONS
        )
        self.agent = create_react_agent(
            model=self.llm,
            tools=self.toolkit.get_tools(),
            prompt=system_message,
        )

    def process_query(self, question: str) -> PaginatedResponse[CompanyData]:
        """Run *question* through the SQL agent and load the matching companies.

        IDs are taken from the most recent successful ``sql_db_query`` tool
        result. If that yields none, they are parsed from the agent's final
        message instead.

        Args:
            question: The natural language query to process.

        Returns:
            A single-page ``PaginatedResponse`` of ``CompanyData``.
        """
        response = self.agent.invoke(
            {"messages": [("user", question)]},
            config={"recursion_limit": 50},
        )
        # The full transcript can be large, so it is only logged at debug level.
        logger.debug("Agent response: %s", response)
        messages = response["messages"]

        company_ids = self._extract_ids_from_query_results(messages)
        if company_ids:
            logger.info("Extracted %d company IDs from results", len(company_ids))
        else:
            final_message_content = self._content_text(messages[-1].content)
            logger.info("AI Response: \n%s", final_message_content)
            company_ids = self._extract_company_ids_from_response(
                final_message_content
            )
        return self._fetch_companies_by_ids(company_ids)

    def _extract_ids_from_query_results(self, messages) -> List[int]:
        """Return the IDs in the latest successful ``sql_db_query`` result.

        Only the SQL tool's output is inspected. Other messages would
        otherwise be misread as IDs, for example:

        * the user's question, such as "(after 2010, ...)";
        * the schema dump, such as "VARCHAR(255)";
        * the query checker's echo of the SQL.

        Failed queries are skipped. The latest successful query is
        authoritative even when it returned no rows, so an earlier
        exploratory query cannot supply stale IDs.

        Args:
            messages: The agent's message list, inspected via the ``type``,
                ``name``, ``status`` and ``content`` attributes.

        Returns:
            The extracted IDs, or ``[]`` when no successful query exists.
        """
        for message in reversed(messages):
            if getattr(message, "type", None) != "tool":
                continue
            if getattr(message, "name", None) != self._QUERY_TOOL_NAME:
                continue
            text = self._content_text(getattr(message, "content", ""))
            # The SQL tool reports failures as "Error: ..." strings.
            if getattr(message, "status", None) == "error" or text.startswith(
                "Error:"
            ):
                continue
            return self._extract_company_ids_from_response(text)
        return []

    @staticmethod
    def _content_text(content) -> str:
        """Flatten message content to plain text.

        LangChain message content may be a string or a list of parts. String
        parts and the ``"text"`` entry of dict parts are joined with
        newlines. ``None`` becomes ``""``. Anything else is ``str()``-ed.
        """
        if content is None:
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, list):
            parts = []
            for part in content:
                if isinstance(part, str):
                    parts.append(part)
                elif isinstance(part, dict) and "text" in part:
                    parts.append(str(part["text"]))
            return "\n".join(parts)
        return str(content)

    def _extract_company_ids_from_response(self, ai_response: str) -> List[int]:
        """Extract company IDs from SQL output or an AI answer.

        Single-int tuples such as ``[(1,), (5,)]`` are preferred. If none are
        present, every standalone integer below ``_MAX_PLAUSIBLE_ID`` is
        taken.

        Args:
            ai_response: Text to parse.

        Returns:
            The IDs in first-seen order, without duplicates (JOINs can repeat
            a company). Returns ``[]`` if the text contains ``NO_RESULTS``
            (any case).
        """
        if "NO_RESULTS" in ai_response.upper():
            return []
        matches = self._TUPLE_ID_PATTERN.findall(ai_response)
        if matches:
            company_ids = [int(match) for match in matches]
        else:
            company_ids = [
                value
                for value in map(int, self._NUMBER_PATTERN.findall(ai_response))
                if value < self._MAX_PLAUSIBLE_ID
            ]
        return list(dict.fromkeys(company_ids))

    def _empty_response(self) -> PaginatedResponse[CompanyData]:
        """Return the single empty page used whenever no company is found."""
        return PaginatedResponse(
            items=[],
            total=0,
            page=1,
            page_size=self._EMPTY_PAGE_SIZE,
            total_pages=0,
        )

    def _fetch_companies_by_ids(
        self, company_ids: List[int]
    ) -> PaginatedResponse[CompanyData]:
        """Load companies and their relationships, in the order of *company_ids*.

        Args:
            company_ids: List of company IDs to fetch. Unknown IDs are ignored.

        Returns:
            A single page holding every found company. If nothing is found
            (no IDs, or no matching rows), the empty response is returned
            with ``page_size`` 10.
        """
        if not company_ids:
            return self._empty_response()

        # get_db() is treated as an iterator yielding one session. That
        # session is closed explicitly in the finally block below.
        db_session = next(get_db())
        try:
            companies = (
                db_session.query(CompanyTable)
                .options(
                    selectinload(CompanyTable.investors),
                    selectinload(CompanyTable.members),
                    selectinload(CompanyTable.sectors),
                )
                .filter(CompanyTable.id.in_(company_ids))
                .all()
            )
            # IN (...) returns rows in arbitrary order. Restore the agent's
            # order so any ORDER BY in its SQL is preserved.
            rank = {cid: pos for pos, cid in enumerate(dict.fromkeys(company_ids))}
            companies.sort(key=lambda company: rank.get(company.id, len(rank)))

            # CompanyData is built while the session is still open.
            company_data_list = [
                CompanyData(
                    company=company,
                    investors=company.investors,
                    members=company.members,
                    sectors=company.sectors,
                )
                for company in companies
            ]
            if not company_data_list:
                return self._empty_response()

            total_count = len(company_data_list)
            return PaginatedResponse(
                items=company_data_list,
                total=total_count,
                page=1,
                page_size=total_count,
                total_pages=1,
            )
        finally:
            db_session.close()