feat: Update query endpoint to return paginated investment responses with fund details

This commit is contained in:
bolade
2025-10-08 14:19:36 +01:00
parent 58722f1102
commit cefe89bb67
8 changed files with 107 additions and 45 deletions
Binary file not shown.
+7 -2
View File
@@ -6,7 +6,7 @@ from dotenv import load_dotenv
from fastapi import FastAPI, File, Form, UploadFile from fastapi import FastAPI, File, Form, UploadFile
from pydantic import BaseModel from pydantic import BaseModel
from routers import companies, investors, projects from routers import companies, investors, projects
from schemas.router_schemas import InvestorList from schemas.router_schemas import InvestmentResponse, PaginatedResponse
from services.llm_parser import InvestorProcessor from services.llm_parser import InvestorProcessor
from services.querying import QueryProcessor from services.querying import QueryProcessor
@@ -84,11 +84,16 @@ async def parse_csv(
return results return results
@app.post("/query", response_model=InvestorList, tags=["Querying"]) @app.post(
"/query", response_model=PaginatedResponse[InvestmentResponse], tags=["Querying"]
)
async def query_investors(request: QueryRequest): async def query_investors(request: QueryRequest):
""" """
Query investors using natural language. Query investors using natural language.
Returns fund-level matches (one row per fund) with investor details.
This ensures only relevant funds are included in the response.
Supports queries like: Supports queries like:
- "Show me seed stage investors" - "Show me seed stage investors"
- "Find fintech investors in Silicon Valley" - "Find fintech investors in Silicon Valley"
Binary file not shown.
View File
View File
View File
+100 -43
View File
@@ -2,13 +2,18 @@ import os
from typing import List from typing import List
from db.db import DATABASE_URL, get_db from db.db import DATABASE_URL, get_db
from db.models import InvestorTable from db.models import FundTable, InvestorTable
from langchain import hub from langchain import hub
from langchain_community.agent_toolkits import SQLDatabaseToolkit from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase from langchain_community.utilities import SQLDatabase
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent from langgraph.prebuilt import create_react_agent
from schemas.py_schemas import InvestorData, InvestorList from schemas.router_schemas import (
CompanyMinimal,
InvestmentResponse,
PaginatedResponse,
SectorMinimal,
)
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
# Connect to SQLite # Connect to SQLite
@@ -21,16 +26,16 @@ class QueryProcessor:
self.llm = ChatOpenAI( self.llm = ChatOpenAI(
api_key=os.getenv("OPENROUTER_API_KEY"), api_key=os.getenv("OPENROUTER_API_KEY"),
base_url="https://openrouter.ai/api/v1", base_url="https://openrouter.ai/api/v1",
model="openai/gpt-4o-mini", model="x-ai/grok-4-fast",
temperature=0, temperature=0,
) )
self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm) self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)
# Update system message to specifically request only investor IDs # Update system message to specifically request only fund IDs
system_message_updated = ( system_message_updated = (
prompt_template.format(dialect="SQLite", top_k=5) prompt_template.format(dialect="SQLite", top_k=5)
+ "\n\nIMPORTANT: You must ONLY return the investor IDs (id field) that match the user's criteria. " + "\n\nIMPORTANT: You must ONLY return the fund IDs (id field from the funds table) that match the user's criteria. "
+ "Do NOT return any other information, explanations, or data. " + "Do NOT return any other information, explanations, or data. "
+ "Your response should be ONLY a comma-separated list of numbers representing the investor IDs. " + "Your response should be ONLY a comma-separated list of numbers representing the fund IDs. "
+ "Example format: 1, 5, 12, 23" + "Example format: 1, 5, 12, 23"
) )
self.agent = create_react_agent( self.agent = create_react_agent(
@@ -39,9 +44,9 @@ class QueryProcessor:
prompt=system_message_updated, prompt=system_message_updated,
) )
def process_query(self, question: str) -> InvestorList: def process_query(self, question: str) -> PaginatedResponse[InvestmentResponse]:
"""Process a query using the LLM and return investor data.""" """Process a query using the LLM and return investment response data."""
# Let the LLM handle all database interactions and filtering to get IDs # Let the LLM handle all database interactions and filtering to get fund IDs
response = self.agent.invoke( response = self.agent.invoke(
{"messages": [("user", question)]}, {"messages": [("user", question)]},
) )
@@ -51,70 +56,122 @@ class QueryProcessor:
response["messages"][-1].content if response.get("messages") else "" response["messages"][-1].content if response.get("messages") else ""
) )
# Extract investor IDs from the AI response # Extract fund IDs from the AI response
investor_ids = self._extract_investor_ids_from_response(ai_response) fund_ids = self._extract_fund_ids_from_response(ai_response)
# Fetch full investor data using the IDs # Fetch full fund data with investor relationships using the IDs
return self._fetch_investors_by_ids(investor_ids) return self._fetch_funds_by_ids(fund_ids)
def _extract_investor_ids_from_response(self, ai_response: str) -> List[int]: def _extract_fund_ids_from_response(self, ai_response: str) -> List[int]:
"""Extract investor IDs from AI response.""" """Extract fund IDs from AI response."""
import re import re
investor_ids = [] fund_ids = []
try: try:
# Try multiple patterns to extract IDs from the response # Try multiple patterns to extract IDs from the response
# Pattern 1: Simple numbers (assuming they are IDs) # Pattern 1: Simple numbers (assuming they are IDs)
numbers = re.findall(r"\b\d+\b", ai_response) numbers = re.findall(r"\b\d+\b", ai_response)
investor_ids = [int(num) for num in numbers] fund_ids = [int(num) for num in numbers]
# Pattern 2: If response contains explicit ID references # Pattern 2: If response contains explicit ID references
id_matches = re.findall(r"\bid[:\s]*(\d+)", ai_response.lower()) id_matches = re.findall(r"\bid[:\s]*(\d+)", ai_response.lower())
if id_matches: if id_matches:
investor_ids = [int(id_str) for id_str in id_matches] fund_ids = [int(id_str) for id_str in id_matches]
except Exception as e: except Exception as e:
print(f"Error extracting IDs from response: {e}") print(f"Error extracting IDs from response: {e}")
return [] return []
return investor_ids return fund_ids
def _fetch_investors_by_ids(self, investor_ids: List[int]) -> InvestorList: def _fetch_funds_by_ids(
"""Fetch investors with all their relationships from the database using IDs.""" self, fund_ids: List[int]
if not investor_ids: ) -> PaginatedResponse[InvestmentResponse]:
return InvestorList(investors=[]) """Fetch funds with all their relationships from the database using fund IDs.
Constructs response similar to read_investors but starting from funds."""
if not fund_ids:
return PaginatedResponse(
items=[],
total=0,
page=1,
page_size=len(fund_ids) if fund_ids else 10,
total_pages=0,
)
# Get database session # Get database session
db_session = next(get_db()) db_session = next(get_db())
try: try:
# Build query with all relationships loaded # Query funds with all necessary relationships loaded
query = ( funds = (
db_session.query(InvestorTable) db_session.query(FundTable)
.options( .options(
selectinload(InvestorTable.portfolio_companies), selectinload(FundTable.investor).selectinload(
selectinload(InvestorTable.team_members), InvestorTable.portfolio_companies
selectinload(InvestorTable.sectors), ),
selectinload(InvestorTable.funds), selectinload(FundTable.investor).selectinload(
InvestorTable.team_members
),
selectinload(FundTable.investor).selectinload(
InvestorTable.sectors
),
selectinload(FundTable.investment_stages),
selectinload(FundTable.sectors),
) )
.filter(InvestorTable.id.in_(investor_ids)) .filter(FundTable.id.in_(fund_ids))
.all()
) )
investors = query.all() # Transform to InvestmentResponse format (one row per fund)
investment_responses = []
for fund in funds:
investor = fund.investor
# Transform to InvestorData format # Get top 3 portfolio companies (id and name only)
investor_data_list = [] portfolio_companies = [
for investor in investors: CompanyMinimal(id=company.id, name=company.name)
investor_data = InvestorData( for company in investor.portfolio_companies[:3]
investor=investor, ]
portfolio_companies=investor.portfolio_companies,
team_members=investor.team_members, # Get stage focus as comma-separated string
sectors=investor.sectors, stage_focus = (
funds=investor.funds, ", ".join([stage.name for stage in fund.investment_stages])
if fund.investment_stages
else None
) )
investor_data_list.append(investor_data)
return InvestorList(investors=investor_data_list) # Get top 3 sectors from fund (id and name only)
fund_sectors = [
SectorMinimal(id=sector.id, name=sector.name)
for sector in (fund.sectors[:3] if fund.sectors else [])
]
investment_response = InvestmentResponse(
id=investor.id,
name=f"{investor.name} - {fund.fund_name}"
if fund.fund_name
else investor.name,
aum=investor.aum,
check_size_lower=fund.check_size_lower,
check_size_upper=fund.check_size_upper,
geographic_focus=fund.geographic_focus,
stage_focus=stage_focus,
portfolio_companies=portfolio_companies,
sectors=fund_sectors,
compatibility_score=1.0,
)
investment_responses.append(investment_response)
total_count = len(investment_responses)
total_pages = 1 if total_count > 0 else 0
return PaginatedResponse(
items=investment_responses,
total=total_count,
page=1,
page_size=total_count,
total_pages=total_pages,
)
finally: finally:
db_session.close() db_session.close()
View File