# Listing metadata captured with this file: 229 lines · 9.7 KiB · Python
import asyncio
import hashlib
import logging
import os
import re
import threading
from collections import OrderedDict
from typing import List

from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from sqlalchemy import text
from sqlalchemy.orm import selectinload

from db.db import get_db
from db.models import CompanyTable
from schemas.router_schemas import CompanyData, PaginatedResponse


# Module-level logger, named after this module's import path (``__name__``).
logger = logging.getLogger(name=__name__)


class CompanyQueryProcessor:
    """Answer natural-language company searches with LLM-generated SQL.

    The question is sent to an LLM (through the OpenRouter endpoint configured
    in ``__init__``) with a schema-aware system prompt. The model returns a
    SQLite query selecting company IDs. That query is checked to be a single
    read-only ``SELECT`` and then executed. The matching companies are loaded
    through the ORM together with their investors, members and sectors.

    Generated SQL is cached per normalized question in a bounded LRU. Only SQL
    that passed validation and executed successfully is cached, so a bad
    generation is retried on the next request instead of being replayed.
    """

    # Maximum number of cached question -> SQL entries. The least recently
    # used entry is evicted first.
    _CACHE_MAX_SIZE = 256

    # Markdown code fences (```sql, ```sqlite, ```SQL, ```) the model may wrap
    # its answer in.
    _CODE_FENCE = re.compile(r"```(?:sqlite|sql)?", re.IGNORECASE)

    # Generated SQL must begin with one of these keywords.
    _READ_ONLY_START = re.compile(r"\s*(?:SELECT|WITH)\b", re.IGNORECASE)

    # A single-quoted SQL string literal ('' is an escaped quote inside one).
    _STRING_LITERAL = re.compile(r"'(?:[^']|'')*'")

    # Write/DDL keywords rejected anywhere outside string literals. In SQLite a
    # WITH clause may precede INSERT/UPDATE/DELETE, so checking only the first
    # keyword is not enough.
    _FORBIDDEN_SQL = re.compile(
        r"\b(?:INSERT|UPDATE|DELETE|REPLACE\s+INTO|DROP|ALTER|CREATE|"
        r"ATTACH|DETACH|PRAGMA|VACUUM|REINDEX)\b",
        re.IGNORECASE,
    )

    def __init__(self):
        """Create the LLM client, the SQL cache and the SQL-generation prompt."""
        self.llm = ChatOpenAI(
            api_key=os.getenv("OPENROUTER_API_KEY"),
            base_url="https://openrouter.ai/api/v1",
            model="openai/gpt-4o-mini",
            temperature=0,
        )

        # Cache key (see _get_cache_key) -> validated SQL, kept in LRU order.
        # process_query runs in worker threads, so every access takes the lock.
        self.query_cache = OrderedDict()
        self._cache_lock = threading.Lock()

        # SQL generation prompt. It must contain no literal braces, because
        # ChatPromptTemplate treats {...} as template variables.
        self.sql_prompt = ChatPromptTemplate.from_messages(
            [
                (
                    "system",
                    """You are a SQL expert. Generate a read-only SQLite query to find company IDs based on user requirements.

Database Schema:
- companies: id, name, industry, location, description, founded_year, website
- company_sector: company_id, sector_id
- sectors: id, name
- investor_companies: investor_id, company_id
- investors: id, name, aum
- team_members: id, company_id, name, title

IMPORTANT RULES:
1. ALWAYS return ONLY company IDs (companies.id) - use SELECT DISTINCT c.id
2. For industry: Check BOTH industry field AND sectors table with synonyms
   - Use LEFT JOIN for sectors so companies without sector tags still match
   - Include related terms: 'Fintech' → c.industry LIKE '%Fintech%' OR c.industry LIKE '%Finance%' OR sec.name LIKE '%Fintech%' OR sec.name LIKE '%Financial%'
   - 'AI' → c.industry GLOB '*AI*' OR c.industry LIKE '%Artificial Intelligence%' OR c.industry LIKE '%Machine Learning%' OR sec.name GLOB '*AI*' OR sec.name GLOB '*ML*'
3. For location: Be FLEXIBLE with variations and abbreviations
   - 'San Francisco' → c.location LIKE '%San Francisco%' OR c.location GLOB '*SF*' OR c.location LIKE '%Bay Area%'
   - 'New York' → c.location LIKE '%New York%' OR c.location GLOB '*NY*'
   - 'Europe' → c.location LIKE '%Europe%' OR c.location GLOB '*UK*' OR c.location LIKE '%London%' OR c.location LIKE '%Berlin%' OR c.location LIKE '%Paris%'
4. For sectors: Use LEFT JOIN and include multiple synonyms
   - 'Healthcare' → sec.name LIKE '%Healthcare%' OR sec.name LIKE '%Health%' OR sec.name LIKE '%Medical%' OR sec.name LIKE '%Biotech%' OR c.industry LIKE '%Health%'
5. For founding year filters (include NULL to be inclusive):
   - "founded after 2020" → WHERE (c.founded_year > 2020 OR c.founded_year IS NULL)
   - "founded before 2018" → WHERE (c.founded_year < 2018 OR c.founded_year IS NULL)
   - "founded in 2020" → WHERE c.founded_year = 2020
6. For investor-related queries: Use JOIN investor_companies
7. Use LEFT JOIN for sectors so companies without tags still match
8. Use DISTINCT to avoid duplicates from joins
9. Be INCLUSIVE - use OR conditions with synonyms and variations
10. Return a single, complete, read-only SELECT query
11. LIKE is case-insensitive in SQLite, so match short uppercase acronyms (AI, ML, SF, NY, UK) with case-sensitive GLOB, e.g. c.industry GLOB '*AI*'. Otherwise '%AI%' would also match 'Retail' and '%NY%' would match 'Germany'.
12. Wrap every group of OR conditions in parentheses before combining it with AND

Example Queries:
Q: "Fintech companies founded in 2020"
A: SELECT DISTINCT c.id FROM companies c
LEFT JOIN company_sector cs ON c.id = cs.company_id
LEFT JOIN sectors sec ON cs.sector_id = sec.id
WHERE (c.industry LIKE '%Fintech%' OR c.industry LIKE '%Finance%' OR c.industry LIKE '%Financial%' OR sec.name LIKE '%Fintech%' OR sec.name LIKE '%Financial Services%')
AND c.founded_year = 2020

Q: "AI companies in San Francisco"
A: SELECT DISTINCT c.id FROM companies c
LEFT JOIN company_sector cs ON c.id = cs.company_id
LEFT JOIN sectors sec ON cs.sector_id = sec.id
WHERE (c.industry GLOB '*AI*' OR c.industry LIKE '%Artificial Intelligence%' OR c.industry LIKE '%Machine Learning%' OR sec.name GLOB '*AI*' OR sec.name LIKE '%Machine Learning%' OR sec.name GLOB '*ML*')
AND (c.location LIKE '%San Francisco%' OR c.location GLOB '*SF*' OR c.location LIKE '%Bay Area%')

Q: "Healthcare companies"
A: SELECT DISTINCT c.id FROM companies c
LEFT JOIN company_sector cs ON c.id = cs.company_id
LEFT JOIN sectors sec ON cs.sector_id = sec.id
WHERE (c.industry LIKE '%Healthcare%' OR c.industry LIKE '%Health%' OR c.industry LIKE '%Medical%' OR sec.name LIKE '%Healthcare%' OR sec.name LIKE '%Medical%' OR sec.name LIKE '%Biotech%' OR sec.name LIKE '%Pharma%')

Q: "Companies funded by Sequoia"
A: SELECT DISTINCT c.id FROM companies c
JOIN investor_companies ic ON c.id = ic.company_id
JOIN investors i ON ic.investor_id = i.id
WHERE i.name LIKE '%Sequoia%'

Q: "European startups founded after 2019"
A: SELECT DISTINCT c.id FROM companies c
WHERE (c.location LIKE '%Europe%' OR c.location GLOB '*UK*' OR c.location LIKE '%London%' OR c.location LIKE '%Germany%' OR c.location LIKE '%Berlin%' OR c.location LIKE '%France%' OR c.location LIKE '%Paris%')
AND (c.founded_year > 2019 OR c.founded_year IS NULL)

Q: "SaaS companies"
A: SELECT DISTINCT c.id FROM companies c
LEFT JOIN company_sector cs ON c.id = cs.company_id
LEFT JOIN sectors sec ON cs.sector_id = sec.id
WHERE (c.industry LIKE '%SaaS%' OR c.industry LIKE '%Software%' OR c.industry LIKE '%Cloud%' OR sec.name LIKE '%SaaS%' OR sec.name LIKE '%Software%')

IMPORTANT:
- Use LEFT JOIN so companies without sector tags still match via industry field
- Use OR conditions with related keywords/synonyms to cast a wider net
- Include NULL checks for optional filters to avoid excluding companies with missing data

Return ONLY the SQL query, no explanations or markdown.""",
                ),
                ("user", "{question}"),
            ]
        )

    def _get_cache_key(self, question: str) -> str:
        """Return the cache key for *question*.

        The question is case-folded and runs of whitespace are collapsed, so
        "AI  Companies " and "ai companies" share one entry. MD5 is used only
        as a compact key, not for security.
        """
        normalized = " ".join(question.casefold().split())
        return hashlib.md5(
            normalized.encode("utf-8"), usedforsecurity=False
        ).hexdigest()

    async def process_query(self, question: str) -> "PaginatedResponse[CompanyData]":
        """Answer *question* without blocking the event loop.

        The blocking LLM call and database work run in a worker thread via
        ``asyncio.to_thread``.
        """
        return await asyncio.to_thread(self._process_query_sync, question)

    def _process_query_sync(self, question: str) -> "PaginatedResponse[CompanyData]":
        """Synchronous implementation of :meth:`process_query`.

        Returns an empty page when the generated SQL is rejected as unsafe or
        fails to execute. Errors while loading the matched companies propagate
        to the caller.
        """
        cache_key = self._get_cache_key(question)

        sql_query = self._cache_get(cache_key)
        if sql_query is not None:
            logger.info("Using cached SQL: %s", sql_query)
        else:
            sql_query = self._generate_sql(question)
            logger.info("Generated SQL: %s", sql_query)

        # The SQL comes from a model fed with user text, so treat it as untrusted.
        if not self._is_safe_select(sql_query):
            logger.error("Rejected generated SQL (not a single read-only SELECT): %s", sql_query)
            return self._empty_response()

        try:
            company_ids = self._run_id_query(sql_query)
        except Exception as e:
            # Best-effort boundary for model-written SQL: report and return an
            # empty page, and evict the entry so a broken query is not replayed.
            logger.error("SQL execution error: %s", e)
            logger.error("Failed SQL: %s", sql_query)
            self._cache_discard(cache_key)
            return self._empty_response()

        # Cache only SQL that is known to be safe and to execute.
        self._cache_put(cache_key, sql_query)
        return self._fetch_companies_by_ids(company_ids)

    def _generate_sql(self, question: str) -> str:
        """Ask the LLM for a SQL query answering *question* and clean its reply."""
        messages = self.sql_prompt.format_messages(question=question)
        response = self.llm.invoke(messages)
        return self._clean_sql(response.content)

    def _run_id_query(self, sql_query: str) -> List[int]:
        """Execute *sql_query* and return the first column of every row.

        The session is closed before this method returns, whether or not the
        query succeeds.
        """
        db_session = next(get_db())
        try:
            rows = db_session.execute(text(sql_query)).fetchall()
        finally:
            db_session.close()

        company_ids = [row[0] for row in rows]
        logger.info(
            "Found %d company IDs: %s%s",
            len(company_ids),
            company_ids[:10],
            "..." if len(company_ids) > 10 else "",
        )
        return company_ids

    @staticmethod
    def _clean_sql(raw: str) -> str:
        """Strip whitespace, Markdown code fences and trailing semicolons from *raw*."""
        sql = CompanyQueryProcessor._CODE_FENCE.sub("", raw.strip()).strip()
        # A trailing ';' is harmless to SQLite, but _is_safe_select rejects any ';'.
        return sql.rstrip(";").strip()

    @classmethod
    def _is_safe_select(cls, sql: str) -> bool:
        """Return True if *sql* looks like a single, read-only SELECT statement.

        String literals are blanked out first, so values such as
        LIKE '%Update%' or LIKE '%a;b%' do not cause false rejections.
        """
        if not cls._READ_ONLY_START.match(sql):
            return False
        skeleton = cls._STRING_LITERAL.sub("''", sql)
        if ";" in skeleton:
            # More than one statement (or trailing text after one).
            return False
        return cls._FORBIDDEN_SQL.search(skeleton) is None

    def _cache_get(self, key: str) -> "str | None":
        """Return the cached SQL for *key* (marking it recently used), or None."""
        with self._cache_lock:
            sql = self.query_cache.get(key)
            if sql is not None:
                self.query_cache.move_to_end(key)
            return sql

    def _cache_put(self, key: str, sql: str) -> None:
        """Store *sql* under *key*, evicting least-recently-used entries over the cap."""
        with self._cache_lock:
            self.query_cache[key] = sql
            self.query_cache.move_to_end(key)
            while len(self.query_cache) > self._CACHE_MAX_SIZE:
                self.query_cache.popitem(last=False)

    def _cache_discard(self, key: str) -> None:
        """Remove *key* from the cache if present."""
        with self._cache_lock:
            self.query_cache.pop(key, None)

    @staticmethod
    def _empty_response() -> "PaginatedResponse[CompanyData]":
        """Return the response used whenever no companies are found."""
        return PaginatedResponse(
            items=[],
            total=0,
            page=1,
            page_size=10,
            total_pages=0,
        )

    def _fetch_companies_by_ids(
        self, company_ids: List[int]
    ) -> "PaginatedResponse[CompanyData]":
        """Fetch companies and their relationships for the given IDs.

        Results are returned as a single page, in the order the IDs were given.
        Duplicate and None IDs are ignored.

        Args:
            company_ids: Company IDs to fetch, e.g. as produced by the generated SQL.
        """
        # Deduplicate while keeping first-seen order. A NULL id can never match a row.
        ordered_ids = [cid for cid in dict.fromkeys(company_ids) if cid is not None]
        if not ordered_ids:
            return self._empty_response()

        db_session = next(get_db())
        try:
            companies = (
                db_session.query(CompanyTable)
                .options(
                    selectinload(CompanyTable.investors),
                    selectinload(CompanyTable.members),
                    selectinload(CompanyTable.sectors),
                )
                .filter(CompanyTable.id.in_(ordered_ids))
                .all()
            )

            # IN (...) returns rows in arbitrary order. Restore the order of the
            # ID list so any ORDER BY in the generated query is honoured.
            position = {cid: idx for idx, cid in enumerate(ordered_ids)}
            companies.sort(key=lambda company: position.get(company.id, len(position)))

            company_data_list = [
                CompanyData(
                    company=company,
                    investors=company.investors,
                    members=company.members,
                    sectors=company.sectors,
                )
                for company in companies
            ]

            if not company_data_list:
                # None of the IDs exist; use the same shape as other empty results.
                return self._empty_response()

            total_count = len(company_data_list)
            return PaginatedResponse(
                items=company_data_list,
                total=total_count,
                page=1,
                page_size=total_count,
                total_pages=1,
            )
        finally:
            db_session.close()
|