296 lines
12 KiB
Python
296 lines
12 KiB
Python
import asyncio
|
|
import hashlib
|
|
import logging
|
|
import os
|
|
from typing import List, Optional
|
|
|
|
from db.db import get_db
|
|
from db.models import FundTable, InvestorTable, ProjectTable
|
|
from langchain_core.prompts import ChatPromptTemplate
|
|
from langchain_openai import ChatOpenAI
|
|
from schemas.router_schemas import (
|
|
CompanyMinimal,
|
|
InvestmentResponse,
|
|
PaginatedResponse,
|
|
SectorMinimal,
|
|
)
|
|
from sqlalchemy import text
|
|
from sqlalchemy.orm import selectinload
|
|
|
|
from services.compatibility_score import calculate_project_investor_compatibility
|
|
|
|
# Module-level logger named after this module; used for the SQL/caching
# diagnostics emitted by QueryProcessor below.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class QueryProcessor:
|
|
def __init__(self):
    """Create the chat model client, the SQL cache and the SQL prompt.

    Attributes set here:
        query_cache: dict that maps a question's cache key to the SQL text
            previously generated for it.
        llm: ``ChatOpenAI`` client pointed at the OpenRouter endpoint, with
            the API key read from the ``OPENROUTER_API_KEY`` environment
            variable and temperature 0.
        sql_prompt: two-message chat template (system instructions plus the
            user's ``{question}``) used to request a fund-ID SQL query.
    """
    # Generated SQL is remembered per question so repeated questions can
    # skip the LLM round trip.
    self.query_cache = {}

    openrouter_key = os.environ.get("OPENROUTER_API_KEY")
    self.llm = ChatOpenAI(
        model="openai/gpt-4o-mini",
        base_url="https://openrouter.ai/api/v1",
        api_key=openrouter_key,
        temperature=0,
    )

    # System instructions: schema description, matching rules and worked
    # examples. The text must not contain curly braces other than template
    # variables, because it is parsed as a prompt template.
    system_instructions = """You are a SQL expert. Generate a SQLite query to find fund IDs based on user requirements.

Database Schema:
- funds: id, fund_name, investor_id, check_size_lower, check_size_upper, geographic_focus
- fund_sectors: fund_id, sector_id
- fund_investment_stages: fund_id, stage_id
- sectors: id, name
- investment_stages: id, name
- investors: id, name, aum

IMPORTANT RULES:
1. ALWAYS return ONLY fund IDs (funds.id) - use SELECT DISTINCT f.id
2. For geography: Be FLEXIBLE - use OR with variations and partial matches
- 'Europe' → WHERE geographic_focus LIKE '%Europe%' OR geographic_focus LIKE '%European%'
- 'America' → WHERE geographic_focus LIKE '%America%' OR geographic_focus LIKE '%US%' OR geographic_focus LIKE '%United States%'
- 'Asia' → WHERE geographic_focus LIKE '%Asia%' OR geographic_focus LIKE '%Asian%'
- If no geography specified, DON'T filter by geography
3. For stages: Use LEFT JOIN and LIKE for flexible matching with synonyms
- 'Seed' → s.name LIKE '%Seed%' OR s.name LIKE '%Pre-Seed%' OR s.name LIKE '%Early%'
- 'Series A' → s.name LIKE '%Series A%' OR s.name LIKE '%A%'
- 'Growth' → s.name LIKE '%Growth%' OR s.name LIKE '%Late%' OR s.name LIKE '%Expansion%'
- If stage not specified, include ALL funds
4. For sectors: Use LEFT JOIN and include related terms with OR
- 'Fintech' → sec.name LIKE '%Fintech%' OR sec.name LIKE '%Finance%' OR sec.name LIKE '%Financial%'
- 'AI' → sec.name LIKE '%AI%' OR sec.name LIKE '%Artificial Intelligence%' OR sec.name LIKE '%Machine Learning%' OR sec.name LIKE '%ML%'
- 'Healthcare' → sec.name LIKE '%Healthcare%' OR sec.name LIKE '%Health%' OR sec.name LIKE '%Medical%' OR sec.name LIKE '%Biotech%'
5. For check size filters (be flexible with ranges):
- "under X" → WHERE (check_size_upper <= X OR check_size_upper IS NULL)
- "over X" → WHERE (check_size_lower >= X OR check_size_lower IS NULL)
- "between X and Y" → WHERE check_size_lower >= X AND check_size_upper <= Y
6. Use LEFT JOIN for stages and sectors so funds without tags still match
7. Use DISTINCT to avoid duplicates from joins
8. Be INCLUSIVE - use OR conditions to cast a wider net
9. If query is very simple (e.g., just "seed stage"), don't add unnecessary filters
10. Return a single, complete SELECT query

Example Queries:
Q: "Seed stage investors in Europe"
A: SELECT DISTINCT f.id FROM funds f
LEFT JOIN fund_investment_stages fis ON f.id = fis.fund_id
LEFT JOIN investment_stages s ON fis.stage_id = s.id
WHERE (s.name LIKE '%Seed%' OR s.name LIKE '%Pre-Seed%' OR s.name LIKE '%Early%' OR s.id IS NULL)
AND (f.geographic_focus LIKE '%Europe%' OR f.geographic_focus LIKE '%European%')

Q: "Fintech investors with check size under 5 million"
A: SELECT DISTINCT f.id FROM funds f
LEFT JOIN fund_sectors fs ON f.id = fs.fund_id
LEFT JOIN sectors sec ON fs.sector_id = sec.id
WHERE (sec.name LIKE '%Fintech%' OR sec.name LIKE '%Finance%' OR sec.name LIKE '%Financial%' OR sec.id IS NULL)
AND (f.check_size_upper <= 5000000 OR f.check_size_upper IS NULL)

Q: "Seed stage investors"
A: SELECT DISTINCT f.id FROM funds f
LEFT JOIN fund_investment_stages fis ON f.id = fis.fund_id
LEFT JOIN investment_stages s ON fis.stage_id = s.id
WHERE s.name LIKE '%Seed%' OR s.name LIKE '%Pre-Seed%' OR s.name LIKE '%Early%'

Q: "Growth stage investors"
A: SELECT DISTINCT f.id FROM funds f
LEFT JOIN fund_investment_stages fis ON f.id = fis.fund_id
LEFT JOIN investment_stages s ON fis.stage_id = s.id
WHERE s.name LIKE '%Growth%' OR s.name LIKE '%Late%' OR s.name LIKE '%Expansion%' OR s.name LIKE '%Series C%' OR s.name LIKE '%Series D%'

Q: "AI investors in America"
A: SELECT DISTINCT f.id FROM funds f
LEFT JOIN fund_sectors fs ON f.id = fs.fund_id
LEFT JOIN sectors sec ON fs.sector_id = sec.id
WHERE (sec.name LIKE '%AI%' OR sec.name LIKE '%Artificial Intelligence%' OR sec.name LIKE '%Machine Learning%' OR sec.name LIKE '%ML%')
AND (f.geographic_focus LIKE '%America%' OR f.geographic_focus LIKE '%US%' OR f.geographic_focus LIKE '%United States%' OR f.geographic_focus LIKE '%USA%')

Q: "Healthcare investors"
A: SELECT DISTINCT f.id FROM funds f
LEFT JOIN fund_sectors fs ON f.id = fs.fund_id
LEFT JOIN sectors sec ON fs.sector_id = sec.id
WHERE sec.name LIKE '%Healthcare%' OR sec.name LIKE '%Health%' OR sec.name LIKE '%Medical%' OR sec.name LIKE '%Biotech%' OR sec.name LIKE '%Pharma%'

IMPORTANT: Use LEFT JOIN so funds without sector/stage tags can still match. Include synonym terms with OR for better recall.

Return ONLY the SQL query, no explanations or markdown."""

    self.sql_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", system_instructions),
            ("user", "{question}"),
        ]
    )
|
|
|
|
def _get_cache_key(self, question: str) -> str:
|
|
"""Generate cache key from normalized question."""
|
|
return hashlib.md5(question.lower().strip().encode()).hexdigest()
|
|
|
|
async def process_query(
    self, question: str, project_id: Optional[int] = None
) -> PaginatedResponse[InvestmentResponse]:
    """Answer ``question`` without blocking the event loop.

    The blocking work (``_process_query_sync``) is handed to
    ``asyncio.to_thread`` and its result is awaited and returned unchanged.

    Args:
        question: Natural-language search request.
        project_id: Optional project ID passed through to the sync
            implementation.

    Returns:
        Whatever ``_process_query_sync`` returns for the same arguments.
    """
    blocking_impl = self._process_query_sync
    page = await asyncio.to_thread(blocking_impl, question, project_id)
    return page
|
|
|
|
def _process_query_sync(
    self, question: str, project_id: Optional[int] = None
) -> PaginatedResponse[InvestmentResponse]:
    """Turn ``question`` into SQL, run it, and return the matching funds.

    Blocking implementation behind ``process_query``.

    Steps:
      1. Reuse the SQL cached for this question, or ask the LLM for it.
      2. Refuse anything that is not a single ``SELECT`` statement.
      3. Execute the SQL and take the first column of each row as a fund ID.
      4. Build the response through ``_fetch_funds_by_ids``.

    SQL is cached only after it has executed successfully, so one bad
    generation cannot make a question fail permanently. A cached statement
    that is rejected or fails to execute is evicted so the next call
    regenerates it.

    Args:
        question: Natural-language search request.
        project_id: Optional project ID, forwarded to
            ``_fetch_funds_by_ids``.

    Returns:
        The page built by ``_fetch_funds_by_ids``, or an empty page
        (``total=0``, ``page=1``, ``page_size=10``) when the SQL is
        rejected, fails to execute, or the fund lookup raises.
    """

    def empty_page():
        # Shared fallback for every failure path below.
        return PaginatedResponse(
            items=[], total=0, page=1, page_size=10, total_pages=0
        )

    cache_key = self._get_cache_key(question)

    sql_query = self.query_cache.get(cache_key)
    if sql_query is not None:
        logger.info(f"Using cached SQL: {sql_query}")
    else:
        # Generate SQL query
        messages = self.sql_prompt.format_messages(question=question)
        response = self.llm.invoke(messages)
        sql_query = response.content.strip()

        # Clean up SQL (remove markdown code blocks if present)
        sql_query = sql_query.replace("```sql", "").replace("```", "").strip()
        logger.info(f"Generated SQL: {sql_query}")

    # The SQL text originates from an LLM driven by user input, so only a
    # single read-only SELECT is allowed to reach the database.
    statement = sql_query.strip().rstrip(";").strip()
    is_single_select = (
        statement[:6].lower() == "select" and ";" not in statement
    )
    if not is_single_select:
        logger.error(f"Rejected non-SELECT SQL: {sql_query}")
        self.query_cache.pop(cache_key, None)
        return empty_page()

    # Execute query to get fund IDs
    db_session = next(get_db())
    try:
        result = db_session.execute(text(sql_query))
        fund_ids = [row[0] for row in result.fetchall()]
    except Exception as e:
        logger.error(f"SQL execution error: {e}")
        logger.error(f"Failed SQL: {sql_query}")
        # Never keep SQL that does not run.
        self.query_cache.pop(cache_key, None)
        return empty_page()
    finally:
        # Release this session before the fund lookup opens its own.
        db_session.close()

    # Only SQL that has executed successfully is worth remembering.
    self.query_cache[cache_key] = sql_query
    logger.info(
        f"Found {len(fund_ids)} fund IDs: {fund_ids[:10]}{'...' if len(fund_ids) > 10 else ''}"
    )

    try:
        return self._fetch_funds_by_ids(fund_ids, project_id)
    except Exception as e:
        # Callers have always received an empty page rather than an
        # exception when the lookup fails; keep that contract.
        logger.error(f"Fund lookup error: {e}")
        return empty_page()
|
|
|
|
def _fetch_funds_by_ids(
    self, fund_ids: List[int], project_id: Optional[int] = None
) -> PaginatedResponse[InvestmentResponse]:
    """Load the given funds and shape them into a single-page response.

    One ``InvestmentResponse`` is produced per fund. Investor-level fields
    (id, name, AUM, portfolio companies) come from the fund's investor;
    check size, geography, stages and sectors come from the fund itself.

    Args:
        fund_ids: IDs of the funds to load.
        project_id: When given, the project is loaded and each row's
            compatibility score is computed against it. Otherwise (or if
            the project is not found) every row gets a score of 1.0.

    Returns:
        A response with ``page=1`` holding every loaded fund, or an empty
        response (``page_size=10``) when ``fund_ids`` is empty.
    """
    if not fund_ids:
        # Nothing to look up; skip the database entirely.
        return PaginatedResponse(
            items=[],
            total=0,
            page=1,
            page_size=10,
            total_pages=0,
        )

    session = next(get_db())
    try:
        # Optional project used for compatibility scoring.
        project = None
        if project_id is not None:
            project_query = session.query(ProjectTable).options(
                selectinload(ProjectTable.sector)
            )
            project = project_query.filter(
                ProjectTable.id == project_id
            ).first()

        # Eager-load every relationship that is read while building rows.
        investor_loads = [
            selectinload(FundTable.investor).selectinload(relationship)
            for relationship in (
                InvestorTable.portfolio_companies,
                InvestorTable.team_members,
                InvestorTable.sectors,
            )
        ]
        fund_query = session.query(FundTable).options(
            *investor_loads,
            selectinload(FundTable.investment_stages),
            selectinload(FundTable.sectors),
        )
        funds = fund_query.filter(FundTable.id.in_(fund_ids)).all()

        rows = []
        for fund in funds:
            investor = fund.investor

            # Score against the project when one was loaded.
            if project is not None:
                score = calculate_project_investor_compatibility(
                    project=project, investor=investor, use_funds=True
                )
            else:
                score = 1.0

            # First three portfolio companies, reduced to id and name.
            companies = []
            for company in investor.portfolio_companies[:3]:
                companies.append(
                    CompanyMinimal(id=company.id, name=company.name)
                )

            # Stage names joined into one string; None when there are none.
            stages = fund.investment_stages
            if stages:
                stage_focus = ", ".join(stage.name for stage in stages)
            else:
                stage_focus = None

            # The first three sectors are taken in relationship order and
            # only then sorted by name.
            leading_sectors = list(fund.sectors[:3]) if fund.sectors else []
            leading_sectors.sort(key=lambda sector: sector.name)
            sector_items = [
                SectorMinimal(id=sector.id, name=sector.name)
                for sector in leading_sectors
            ]

            # Show "<investor> - <fund>" when the fund has a name of its own.
            if fund.fund_name:
                display_name = f"{investor.name} - {fund.fund_name}"
            else:
                display_name = investor.name

            rows.append(
                InvestmentResponse(
                    id=investor.id,
                    name=display_name,
                    aum=investor.aum,
                    check_size_lower=fund.check_size_lower,
                    check_size_upper=fund.check_size_upper,
                    geographic_focus=fund.geographic_focus,
                    stage_focus=stage_focus,
                    portfolio_companies=companies,
                    sectors=sector_items,
                    compatibility_score=score,
                )
            )

        # Everything is returned on one page sized to the result set.
        row_count = len(rows)
        return PaginatedResponse(
            items=rows,
            total=row_count,
            page=1,
            page_size=row_count,
            total_pages=1 if rows else 0,
        )
    finally:
        session.close()
|