209 lines
7.8 KiB
Python
209 lines
7.8 KiB
Python
import os
|
|
from typing import List, Optional
|
|
|
|
from db.db import DATABASE_URL, get_db
|
|
from db.models import FundTable, InvestorTable, ProjectTable
|
|
from langchain import hub
|
|
from langchain_community.agent_toolkits import SQLDatabaseToolkit
|
|
from langchain_community.utilities import SQLDatabase
|
|
from langchain_openai import ChatOpenAI
|
|
from langgraph.prebuilt import create_react_agent
|
|
from schemas.router_schemas import (
|
|
CompanyMinimal,
|
|
InvestmentResponse,
|
|
PaginatedResponse,
|
|
SectorMinimal,
|
|
)
|
|
from sqlalchemy.orm import selectinload
|
|
|
|
from services.compatibility_score import calculate_project_investor_compatibility
|
|
|
|
# Wrap the application database (located by DATABASE_URL) so LangChain's SQL
# tooling can inspect and query it.
db = SQLDatabase.from_uri(DATABASE_URL)

# Pull the shared SQL-agent system prompt from the LangChain Hub; it is
# formatted with a dialect and a top_k value when the agent is built.
prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")
|
|
|
|
|
|
class QueryProcessor:
    """Answer natural-language fund queries through an LLM-driven SQL agent.

    The agent is instructed to reply with nothing but the IDs of the matching
    funds. Those IDs are then loaded through the ORM and converted into
    ``InvestmentResponse`` rows (one row per fund).
    """

    # Page size reported whenever a query produces no rows at all.
    _EMPTY_PAGE_SIZE = 10

    def __init__(self):
        """Build the chat model, the SQL toolkit and the ReAct agent."""
        self.llm = ChatOpenAI(
            api_key=os.getenv("OPENROUTER_API_KEY"),
            base_url="https://openrouter.ai/api/v1",
            model="x-ai/grok-4-fast",
            temperature=0,
        )
        self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)

        # Extend the stock system prompt so the agent answers with fund IDs only.
        # The dialect is taken from the connected database rather than being
        # hard-coded, so the prompt stays correct if DATABASE_URL changes engine.
        system_message_updated = (
            prompt_template.format(dialect=db.dialect, top_k=5)
            + "\n\nIMPORTANT: You must ONLY return the fund IDs (id field from the funds table) that match the user's criteria. "
            + "Do NOT return any other information, explanations, or data. "
            + "Your response should be ONLY a comma-separated list of numbers representing the fund IDs. "
            + "Example format: 1, 5, 12, 23"
        )

        self.agent = create_react_agent(
            model=self.llm,
            tools=self.toolkit.get_tools(),
            prompt=system_message_updated,
        )

    def process_query(
        self, question: str, project_id: Optional[int] = None
    ) -> "PaginatedResponse[InvestmentResponse]":
        """Process a query using the LLM and return investment response data.

        Args:
            question: The natural language query to process.
            project_id: Optional project ID for compatibility scoring.

        Returns:
            A single-page ``PaginatedResponse`` with one ``InvestmentResponse``
            per fund the agent selected.
        """
        # The agent performs all database exploration and filtering itself.
        response = self.agent.invoke(
            {"messages": [("user", question)]},
        )

        # The final message of the run carries the agent's answer.
        messages = response.get("messages")
        ai_response = messages[-1].content if messages else ""

        fund_ids = self._extract_fund_ids_from_response(ai_response)
        return self._fetch_funds_by_ids(fund_ids, project_id)

    @staticmethod
    def _message_text(content: object) -> str:
        """Flatten message content (text or a list of content parts) to text."""
        if content is None:
            return ""
        if isinstance(content, str):
            return content
        if isinstance(content, (list, tuple)):
            parts = []
            for part in content:
                if isinstance(part, str):
                    parts.append(part)
                elif isinstance(part, dict) and isinstance(part.get("text"), str):
                    parts.append(part["text"])
            return "\n".join(parts)
        return str(content)

    def _extract_fund_ids_from_response(self, ai_response: str) -> List[int]:
        """Extract fund IDs from the agent's reply.

        Explicit references such as ``id: 1, 5`` or ``id 1, id 5`` win when
        present, so unrelated numbers in surrounding prose are ignored;
        otherwise every standalone integer in the reply is treated as an ID.

        Args:
            ai_response: The agent's final message content. ``None`` and
                list-style content are tolerated.

        Returns:
            The IDs in first-seen order, without duplicates. Empty when the
            reply contains no numbers.
        """
        import re

        text = self._message_text(ai_response)
        if not text:
            return []

        # Capture the whole comma-separated run after "id"/"ids", not just its
        # first number, so "id: 1, 5, 12" keeps all three IDs.
        explicit_runs = re.findall(
            r"\bids?[:\s]*(\d+(?:\s*,\s*\d+)*)", text.lower()
        )
        if explicit_runs:
            digits = [
                number
                for run in explicit_runs
                for number in re.findall(r"\d+", run)
            ]
        else:
            digits = re.findall(r"\b\d+\b", text)

        # dict.fromkeys removes repeats while keeping first-seen order.
        return list(dict.fromkeys(int(number) for number in digits))

    def _empty_page(self) -> "PaginatedResponse[InvestmentResponse]":
        """Return the response used whenever there is nothing to show."""
        return PaginatedResponse(
            items=[],
            total=0,
            page=1,
            page_size=self._EMPTY_PAGE_SIZE,
            total_pages=0,
        )

    @staticmethod
    def _build_investment_response(fund, investor, compatibility_score):
        """Shape one fund and its investor into an ``InvestmentResponse``."""
        # Top 3 portfolio companies (id and name only).
        portfolio_companies = [
            CompanyMinimal(id=company.id, name=company.name)
            for company in (investor.portfolio_companies or [])[:3]
        ]

        # Stage focus as a comma-separated string, or None when no stages are set.
        stage_focus = (
            ", ".join(stage.name for stage in fund.investment_stages)
            if fund.investment_stages
            else None
        )

        # Top 3 sectors of the fund (id and name only).
        fund_sectors = [
            SectorMinimal(id=sector.id, name=sector.name)
            for sector in (fund.sectors or [])[:3]
        ]

        display_name = (
            f"{investor.name} - {fund.fund_name}"
            if fund.fund_name
            else investor.name
        )

        return InvestmentResponse(
            id=investor.id,
            name=display_name,
            aum=investor.aum,
            check_size_lower=fund.check_size_lower,
            check_size_upper=fund.check_size_upper,
            geographic_focus=fund.geographic_focus,
            stage_focus=stage_focus,
            portfolio_companies=portfolio_companies,
            sectors=fund_sectors,
            compatibility_score=compatibility_score,
        )

    def _fetch_funds_by_ids(
        self, fund_ids: List[int], project_id: Optional[int] = None
    ) -> "PaginatedResponse[InvestmentResponse]":
        """Fetch funds with all their relationships from the database using fund IDs.

        Constructs a response similar to read_investors but starting from funds.

        Args:
            fund_ids: Fund IDs to fetch. Duplicates are ignored and the order of
                first appearance determines the order of the returned rows.
            project_id: Optional project ID for compatibility scoring. When it
                is omitted, or no such project exists, every row gets a score
                of 1.0.

        Returns:
            A single-page ``PaginatedResponse``; an empty page when no IDs are
            given or none of them yields a row.
        """
        # Remember where each distinct ID first appeared so the result can
        # follow the order the caller (ultimately the agent) listed them in.
        rank = {}
        for fund_id in fund_ids or []:
            rank.setdefault(fund_id, len(rank))
        if not rank:
            return self._empty_page()

        # NOTE(review): get_db() is assumed to yield a session - confirm in db.db.
        db_session = next(get_db())

        try:
            # Load the project only when compatibility scoring was requested.
            project = None
            if project_id is not None:
                project = (
                    db_session.query(ProjectTable)
                    .options(selectinload(ProjectTable.sector))
                    .filter(ProjectTable.id == project_id)
                    .first()
                )

            # Eager-load every relationship read below so nothing is fetched
            # lazily per row.
            funds = (
                db_session.query(FundTable)
                .options(
                    selectinload(FundTable.investor).selectinload(
                        InvestorTable.portfolio_companies
                    ),
                    selectinload(FundTable.investor).selectinload(
                        InvestorTable.team_members
                    ),
                    selectinload(FundTable.investor).selectinload(
                        InvestorTable.sectors
                    ),
                    selectinload(FundTable.investment_stages),
                    selectinload(FundTable.sectors),
                )
                .filter(FundTable.id.in_(list(rank)))
                .all()
            )
            # IN (...) gives no ordering guarantee; restore the requested order.
            funds = sorted(funds, key=lambda fund: rank.get(fund.id, len(rank)))

            # Transform to InvestmentResponse format (one row per fund).
            investment_responses = []
            for fund in funds:
                investor = fund.investor
                if investor is None:
                    # A fund without an investor cannot be presented; skip it
                    # instead of failing the whole query.
                    continue

                compatibility_score = 1.0
                if project is not None:
                    compatibility_score = calculate_project_investor_compatibility(
                        project=project, investor=investor, use_funds=True
                    )

                investment_responses.append(
                    self._build_investment_response(
                        fund, investor, compatibility_score
                    )
                )

            if not investment_responses:
                return self._empty_page()

            total_count = len(investment_responses)
            return PaginatedResponse(
                items=investment_responses,
                total=total_count,
                page=1,
                page_size=total_count,
                total_pages=1,
            )

        finally:
            db_session.close()
|