feat: Update query endpoint to return paginated investment responses with fund details
This commit is contained in:
Binary file not shown.
+7
-2
@@ -6,7 +6,7 @@ from dotenv import load_dotenv
|
||||
from fastapi import FastAPI, File, Form, UploadFile
|
||||
from pydantic import BaseModel
|
||||
from routers import companies, investors, projects
|
||||
from schemas.router_schemas import InvestorList
|
||||
from schemas.router_schemas import InvestmentResponse, PaginatedResponse
|
||||
from services.llm_parser import InvestorProcessor
|
||||
from services.querying import QueryProcessor
|
||||
|
||||
@@ -84,11 +84,16 @@ async def parse_csv(
|
||||
return results
|
||||
|
||||
|
||||
@app.post("/query", response_model=InvestorList, tags=["Querying"])
|
||||
@app.post(
|
||||
"/query", response_model=PaginatedResponse[InvestmentResponse], tags=["Querying"]
|
||||
)
|
||||
async def query_investors(request: QueryRequest):
|
||||
"""
|
||||
Query investors using natural language.
|
||||
|
||||
Returns fund-level matches (one row per fund) with investor details.
|
||||
This ensures only relevant funds are included in the response.
|
||||
|
||||
Supports queries like:
|
||||
- "Show me seed stage investors"
|
||||
- "Find fintech investors in Silicon Valley"
|
||||
|
||||
Binary file not shown.
+100
-43
@@ -2,13 +2,18 @@ import os
|
||||
from typing import List
|
||||
|
||||
from db.db import DATABASE_URL, get_db
|
||||
from db.models import InvestorTable
|
||||
from db.models import FundTable, InvestorTable
|
||||
from langchain import hub
|
||||
from langchain_community.agent_toolkits import SQLDatabaseToolkit
|
||||
from langchain_community.utilities import SQLDatabase
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
from schemas.py_schemas import InvestorData, InvestorList
|
||||
from schemas.router_schemas import (
|
||||
CompanyMinimal,
|
||||
InvestmentResponse,
|
||||
PaginatedResponse,
|
||||
SectorMinimal,
|
||||
)
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
# Connect to SQLite
|
||||
@@ -21,16 +26,16 @@ class QueryProcessor:
|
||||
self.llm = ChatOpenAI(
|
||||
api_key=os.getenv("OPENROUTER_API_KEY"),
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="openai/gpt-4o-mini",
|
||||
model="x-ai/grok-4-fast",
|
||||
temperature=0,
|
||||
)
|
||||
self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)
|
||||
# Update system message to specifically request only investor IDs
|
||||
# Update system message to specifically request only fund IDs
|
||||
system_message_updated = (
|
||||
prompt_template.format(dialect="SQLite", top_k=5)
|
||||
+ "\n\nIMPORTANT: You must ONLY return the investor IDs (id field) that match the user's criteria. "
|
||||
+ "\n\nIMPORTANT: You must ONLY return the fund IDs (id field from the funds table) that match the user's criteria. "
|
||||
+ "Do NOT return any other information, explanations, or data. "
|
||||
+ "Your response should be ONLY a comma-separated list of numbers representing the investor IDs. "
|
||||
+ "Your response should be ONLY a comma-separated list of numbers representing the fund IDs. "
|
||||
+ "Example format: 1, 5, 12, 23"
|
||||
)
|
||||
self.agent = create_react_agent(
|
||||
@@ -39,9 +44,9 @@ class QueryProcessor:
|
||||
prompt=system_message_updated,
|
||||
)
|
||||
|
||||
def process_query(self, question: str) -> InvestorList:
|
||||
"""Process a query using the LLM and return investor data."""
|
||||
# Let the LLM handle all database interactions and filtering to get IDs
|
||||
def process_query(self, question: str) -> PaginatedResponse[InvestmentResponse]:
|
||||
"""Process a query using the LLM and return investment response data."""
|
||||
# Let the LLM handle all database interactions and filtering to get fund IDs
|
||||
response = self.agent.invoke(
|
||||
{"messages": [("user", question)]},
|
||||
)
|
||||
@@ -51,70 +56,122 @@ class QueryProcessor:
|
||||
response["messages"][-1].content if response.get("messages") else ""
|
||||
)
|
||||
|
||||
# Extract investor IDs from the AI response
|
||||
investor_ids = self._extract_investor_ids_from_response(ai_response)
|
||||
# Extract fund IDs from the AI response
|
||||
fund_ids = self._extract_fund_ids_from_response(ai_response)
|
||||
|
||||
# Fetch full investor data using the IDs
|
||||
return self._fetch_investors_by_ids(investor_ids)
|
||||
# Fetch full fund data with investor relationships using the IDs
|
||||
return self._fetch_funds_by_ids(fund_ids)
|
||||
|
||||
def _extract_investor_ids_from_response(self, ai_response: str) -> List[int]:
|
||||
"""Extract investor IDs from AI response."""
|
||||
def _extract_fund_ids_from_response(self, ai_response: str) -> List[int]:
|
||||
"""Extract fund IDs from AI response."""
|
||||
import re
|
||||
|
||||
investor_ids = []
|
||||
fund_ids = []
|
||||
try:
|
||||
# Try multiple patterns to extract IDs from the response
|
||||
# Pattern 1: Simple numbers (assuming they are IDs)
|
||||
numbers = re.findall(r"\b\d+\b", ai_response)
|
||||
investor_ids = [int(num) for num in numbers]
|
||||
fund_ids = [int(num) for num in numbers]
|
||||
|
||||
# Pattern 2: If response contains explicit ID references
|
||||
id_matches = re.findall(r"\bid[:\s]*(\d+)", ai_response.lower())
|
||||
if id_matches:
|
||||
investor_ids = [int(id_str) for id_str in id_matches]
|
||||
fund_ids = [int(id_str) for id_str in id_matches]
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error extracting IDs from response: {e}")
|
||||
return []
|
||||
|
||||
return investor_ids
|
||||
return fund_ids
|
||||
|
||||
def _fetch_investors_by_ids(self, investor_ids: List[int]) -> InvestorList:
|
||||
"""Fetch investors with all their relationships from the database using IDs."""
|
||||
if not investor_ids:
|
||||
return InvestorList(investors=[])
|
||||
def _fetch_funds_by_ids(
|
||||
self, fund_ids: List[int]
|
||||
) -> PaginatedResponse[InvestmentResponse]:
|
||||
"""Fetch funds with all their relationships from the database using fund IDs.
|
||||
Constructs response similar to read_investors but starting from funds."""
|
||||
if not fund_ids:
|
||||
return PaginatedResponse(
|
||||
items=[],
|
||||
total=0,
|
||||
page=1,
|
||||
page_size=len(fund_ids) if fund_ids else 10,
|
||||
total_pages=0,
|
||||
)
|
||||
|
||||
# Get database session
|
||||
db_session = next(get_db())
|
||||
|
||||
try:
|
||||
# Build query with all relationships loaded
|
||||
query = (
|
||||
db_session.query(InvestorTable)
|
||||
# Query funds with all necessary relationships loaded
|
||||
funds = (
|
||||
db_session.query(FundTable)
|
||||
.options(
|
||||
selectinload(InvestorTable.portfolio_companies),
|
||||
selectinload(InvestorTable.team_members),
|
||||
selectinload(InvestorTable.sectors),
|
||||
selectinload(InvestorTable.funds),
|
||||
selectinload(FundTable.investor).selectinload(
|
||||
InvestorTable.portfolio_companies
|
||||
),
|
||||
selectinload(FundTable.investor).selectinload(
|
||||
InvestorTable.team_members
|
||||
),
|
||||
selectinload(FundTable.investor).selectinload(
|
||||
InvestorTable.sectors
|
||||
),
|
||||
selectinload(FundTable.investment_stages),
|
||||
selectinload(FundTable.sectors),
|
||||
)
|
||||
.filter(InvestorTable.id.in_(investor_ids))
|
||||
.filter(FundTable.id.in_(fund_ids))
|
||||
.all()
|
||||
)
|
||||
|
||||
investors = query.all()
|
||||
# Transform to InvestmentResponse format (one row per fund)
|
||||
investment_responses = []
|
||||
for fund in funds:
|
||||
investor = fund.investor
|
||||
|
||||
# Transform to InvestorData format
|
||||
investor_data_list = []
|
||||
for investor in investors:
|
||||
investor_data = InvestorData(
|
||||
investor=investor,
|
||||
portfolio_companies=investor.portfolio_companies,
|
||||
team_members=investor.team_members,
|
||||
sectors=investor.sectors,
|
||||
funds=investor.funds,
|
||||
# Get top 3 portfolio companies (id and name only)
|
||||
portfolio_companies = [
|
||||
CompanyMinimal(id=company.id, name=company.name)
|
||||
for company in investor.portfolio_companies[:3]
|
||||
]
|
||||
|
||||
# Get stage focus as comma-separated string
|
||||
stage_focus = (
|
||||
", ".join([stage.name for stage in fund.investment_stages])
|
||||
if fund.investment_stages
|
||||
else None
|
||||
)
|
||||
investor_data_list.append(investor_data)
|
||||
|
||||
return InvestorList(investors=investor_data_list)
|
||||
# Get top 3 sectors from fund (id and name only)
|
||||
fund_sectors = [
|
||||
SectorMinimal(id=sector.id, name=sector.name)
|
||||
for sector in (fund.sectors[:3] if fund.sectors else [])
|
||||
]
|
||||
|
||||
investment_response = InvestmentResponse(
|
||||
id=investor.id,
|
||||
name=f"{investor.name} - {fund.fund_name}"
|
||||
if fund.fund_name
|
||||
else investor.name,
|
||||
aum=investor.aum,
|
||||
check_size_lower=fund.check_size_lower,
|
||||
check_size_upper=fund.check_size_upper,
|
||||
geographic_focus=fund.geographic_focus,
|
||||
stage_focus=stage_focus,
|
||||
portfolio_companies=portfolio_companies,
|
||||
sectors=fund_sectors,
|
||||
compatibility_score=1.0,
|
||||
)
|
||||
investment_responses.append(investment_response)
|
||||
|
||||
total_count = len(investment_responses)
|
||||
total_pages = 1 if total_count > 0 else 0
|
||||
|
||||
return PaginatedResponse(
|
||||
items=investment_responses,
|
||||
total=total_count,
|
||||
page=1,
|
||||
page_size=total_count,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
|
||||
finally:
|
||||
db_session.close()
|
||||
|
||||
Reference in New Issue
Block a user