Refactor investor and company management API with FastAPI integration
- Updated README.md to reflect new features and architecture. - Implemented company management routes in app/api/companies.py. - Enhanced main FastAPI application in app/main.py to include company routes and query processing. - Improved querying capabilities in app/services/querying.py with natural language processing for investor searches. - Updated requirements.txt to include necessary dependencies for FastAPI and related libraries. - Added comprehensive error handling and response formatting for API endpoints.
This commit is contained in:
Binary file not shown.
Binary file not shown.
+205
-5
@@ -1,8 +1,208 @@
|
||||
from fastapi.routing import apirouter
|
||||
from typing import List, Optional
|
||||
|
||||
router = apirouter()
|
||||
from db.db import get_db
|
||||
from db.models import CompanyTable, InvestorTable
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from py_schemas import CompanySchema
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
@router.get("/companies")
|
||||
def read_companies():
|
||||
return {"message": "list of companies"}
|
||||
router = APIRouter(tags=["Company Routes"])
|
||||
|
||||
|
||||
# Request schemas for creating/updating
|
||||
class CompanyCreate(BaseModel):
|
||||
name: str
|
||||
industry: str
|
||||
location: str
|
||||
founded_year: Optional[int] = None
|
||||
website: Optional[str] = None
|
||||
|
||||
|
||||
class CompanyUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
industry: Optional[str] = None
|
||||
location: Optional[str] = None
|
||||
founded_year: Optional[int] = None
|
||||
website: Optional[str] = None
|
||||
|
||||
|
||||
# Response schema with relationships
|
||||
class CompanyData(BaseModel):
|
||||
"""Comprehensive company data schema"""
|
||||
|
||||
company: CompanySchema
|
||||
investors: List["InvestorBasic"] = []
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class InvestorBasic(BaseModel):
|
||||
"""Basic investor info for company responses"""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
geographic_focus: str
|
||||
stage_focus: str
|
||||
check_size_lower: int
|
||||
check_size_upper: int
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
@router.get("/companies", response_model=List[CompanyData])
|
||||
def read_companies(db: Session = Depends(get_db)):
|
||||
"""Get all companies with their investor relationships"""
|
||||
companies = (
|
||||
db.query(CompanyTable).options(selectinload(CompanyTable.investors)).all()
|
||||
)
|
||||
|
||||
# Transform CompanyTable objects to CompanyData format
|
||||
company_data_list = []
|
||||
for company in companies:
|
||||
company_data = CompanyData(company=company, investors=company.investors)
|
||||
company_data_list.append(company_data)
|
||||
|
||||
return company_data_list
|
||||
|
||||
|
||||
@router.get("/companies/filter", response_model=List[CompanyData])
|
||||
def filter_companies(
|
||||
industry: Optional[str] = Query(
|
||||
None, description="Filter by industry (partial match)"
|
||||
),
|
||||
location: Optional[str] = Query(
|
||||
None, description="Filter by location (partial match)"
|
||||
),
|
||||
founded_after: Optional[int] = Query(None, description="Founded after year"),
|
||||
founded_before: Optional[int] = Query(None, description="Founded before year"),
|
||||
has_website: Optional[bool] = Query(
|
||||
None, description="Filter companies with/without website"
|
||||
),
|
||||
investor_name: Optional[str] = Query(
|
||||
None, description="Filter by investor name (partial match)"
|
||||
),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Filter companies based on various criteria"""
|
||||
|
||||
# Start with base query
|
||||
query = db.query(CompanyTable).options(selectinload(CompanyTable.investors))
|
||||
|
||||
# Apply filters
|
||||
if industry:
|
||||
query = query.filter(CompanyTable.industry.ilike(f"%{industry}%"))
|
||||
|
||||
if location:
|
||||
query = query.filter(CompanyTable.location.ilike(f"%{location}%"))
|
||||
|
||||
if founded_after is not None:
|
||||
query = query.filter(CompanyTable.founded_year >= founded_after)
|
||||
|
||||
if founded_before is not None:
|
||||
query = query.filter(CompanyTable.founded_year <= founded_before)
|
||||
|
||||
if has_website is not None:
|
||||
if has_website:
|
||||
query = query.filter(CompanyTable.website.isnot(None))
|
||||
else:
|
||||
query = query.filter(CompanyTable.website.is_(None))
|
||||
|
||||
# Filter by investor if provided
|
||||
if investor_name:
|
||||
query = query.join(CompanyTable.investors).filter(
|
||||
InvestorTable.name.ilike(f"%{investor_name}%")
|
||||
)
|
||||
|
||||
companies = query.all()
|
||||
|
||||
# Transform to CompanyData format
|
||||
company_data_list = []
|
||||
for company in companies:
|
||||
company_data = CompanyData(company=company, investors=company.investors)
|
||||
company_data_list.append(company_data)
|
||||
|
||||
return company_data_list
|
||||
|
||||
|
||||
@router.get("/companies/{company_id}", response_model=CompanyData)
|
||||
def read_company(company_id: int, db: Session = Depends(get_db)):
|
||||
"""Get a specific company by ID with its investors"""
|
||||
company = (
|
||||
db.query(CompanyTable)
|
||||
.options(selectinload(CompanyTable.investors))
|
||||
.filter(CompanyTable.id == company_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not company:
|
||||
raise HTTPException(status_code=404, detail="Company not found")
|
||||
|
||||
# Transform to CompanyData format
|
||||
return CompanyData(company=company, investors=company.investors)
|
||||
|
||||
|
||||
@router.post("/companies", response_model=CompanyData)
|
||||
def create_company(company: CompanyCreate, db: Session = Depends(get_db)):
|
||||
"""Create a new company"""
|
||||
db_company = CompanyTable(**company.dict())
|
||||
db.add(db_company)
|
||||
db.commit()
|
||||
db.refresh(db_company)
|
||||
|
||||
# Reload with relationships
|
||||
company_with_relations = (
|
||||
db.query(CompanyTable)
|
||||
.options(selectinload(CompanyTable.investors))
|
||||
.filter(CompanyTable.id == db_company.id)
|
||||
.first()
|
||||
)
|
||||
|
||||
# Transform to CompanyData format
|
||||
return CompanyData(
|
||||
company=company_with_relations, investors=company_with_relations.investors
|
||||
)
|
||||
|
||||
|
||||
@router.put("/companies/{company_id}", response_model=CompanyData)
|
||||
def update_company(
|
||||
company_id: int, company: CompanyUpdate, db: Session = Depends(get_db)
|
||||
):
|
||||
"""Update an existing company"""
|
||||
db_company = db.query(CompanyTable).filter(CompanyTable.id == company_id).first()
|
||||
if not db_company:
|
||||
raise HTTPException(status_code=404, detail="Company not found")
|
||||
|
||||
update_data = company.dict(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(db_company, field, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_company)
|
||||
|
||||
# Reload with relationships
|
||||
company_with_relations = (
|
||||
db.query(CompanyTable)
|
||||
.options(selectinload(CompanyTable.investors))
|
||||
.filter(CompanyTable.id == company_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
# Transform to CompanyData format
|
||||
return CompanyData(
|
||||
company=company_with_relations, investors=company_with_relations.investors
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/companies/{company_id}")
|
||||
def delete_company(company_id: int, db: Session = Depends(get_db)):
|
||||
"""Delete a company"""
|
||||
db_company = db.query(CompanyTable).filter(CompanyTable.id == company_id).first()
|
||||
if not db_company:
|
||||
raise HTTPException(status_code=404, detail="Company not found")
|
||||
|
||||
db.delete(db_company)
|
||||
db.commit()
|
||||
return {"message": "Company deleted successfully"}
|
||||
|
||||
+33
-9
@@ -1,23 +1,36 @@
|
||||
import io
|
||||
|
||||
import pandas as pd
|
||||
from api import investors
|
||||
from api import companies, investors
|
||||
from db.db import db_dependency, init_database
|
||||
from fastapi import FastAPI, File, UploadFile
|
||||
from py_schemas import InvestorList
|
||||
from pydantic import BaseModel
|
||||
from services.openrouter import InvestorProcessor
|
||||
from services.querying import QueryProcessor
|
||||
|
||||
app = FastAPI()
|
||||
app.include_router(investors.router)
|
||||
init_database()
|
||||
|
||||
|
||||
# Request models
|
||||
class QueryRequest(BaseModel):
|
||||
question: str
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"question": "Show me growth stage fintech investors in the US with check sizes over $1 million"
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@app.get("/")
|
||||
def read_root():
|
||||
def health():
|
||||
return {"Hello": "World"}
|
||||
|
||||
|
||||
@app.post("/parse-csv")
|
||||
@app.post("/parse-csv", tags=["CSV Upload"], response_model=list[dict])
|
||||
async def parse_csv(db: db_dependency, file: UploadFile = File(...)):
|
||||
# Read uploaded CSV with pandas
|
||||
content = await file.read()
|
||||
@@ -28,16 +41,27 @@ async def parse_csv(db: db_dependency, file: UploadFile = File(...)):
|
||||
results = await processor.process_csv(df)
|
||||
|
||||
# Convert Pydantic objects to dictionaries
|
||||
return {"results": [r.dict() for r in results]}
|
||||
return [r.model_dump() for r in results]
|
||||
|
||||
|
||||
@app.post("/query")
|
||||
async def query_investors(db: db_dependency, question: str):
|
||||
@app.post("/query", response_model=InvestorList, tags=["Querying"])
|
||||
async def query_investors(db: db_dependency, request: QueryRequest):
|
||||
"""
|
||||
Query investors using natural language.
|
||||
|
||||
Supports queries like:
|
||||
- "Show me seed stage investors"
|
||||
- "Find fintech investors in Silicon Valley"
|
||||
- "Growth stage investors with $5M+ check sizes"
|
||||
- "Healthcare investors in Europe"
|
||||
"""
|
||||
processor = QueryProcessor(sql_session=db)
|
||||
results = processor.process_query(question)
|
||||
return {"results": results}
|
||||
results = processor.process_query(request.question)
|
||||
return results
|
||||
|
||||
|
||||
app.include_router(investors.router)
|
||||
app.include_router(companies.router)
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
|
||||
Binary file not shown.
+202
-5
@@ -1,18 +1,20 @@
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
import chromadb
|
||||
from db.models import InvestorTable
|
||||
from langchain import hub
|
||||
from langchain_community.agent_toolkits import SQLDatabaseToolkit
|
||||
from langchain_community.utilities import SQLDatabase
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
from py_schemas import InvestorList
|
||||
from py_schemas import InvestorData, InvestorList
|
||||
from settings import settings
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
# Connect to SQLite
|
||||
|
||||
prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")
|
||||
db = SQLDatabase.from_uri("sqlite:///investors.db")
|
||||
db = SQLDatabase.from_uri("sqlite:///investors_2.db")
|
||||
system_message = (
|
||||
prompt_template.format(dialect="SQLite", top_k=5)
|
||||
+ "\n Get answers from the Sql database and the vector database"
|
||||
@@ -25,6 +27,7 @@ class QueryProcessor:
|
||||
sql_session: Optional[object] = None,
|
||||
vector_db_client: Optional[object] = None,
|
||||
):
|
||||
self.sql_session = sql_session
|
||||
self.llm = ChatOpenAI(
|
||||
api_key=settings.OPENROUTER_API_KEY,
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
@@ -36,7 +39,6 @@ class QueryProcessor:
|
||||
model=self.llm,
|
||||
tools=self.toolkit.get_tools() + [self.query_vector_database],
|
||||
prompt=system_message,
|
||||
response_format=InvestorList,
|
||||
)
|
||||
self.vector_db_client = vector_db_client
|
||||
|
||||
@@ -77,7 +79,202 @@ class QueryProcessor:
|
||||
|
||||
def process_query(self, question: str) -> InvestorList:
|
||||
"""Process a query using the LLM and return structured investor data."""
|
||||
# Extract filters from the query first
|
||||
filters = self._extract_filters_from_query(question)
|
||||
|
||||
# Get AI response for additional context
|
||||
response = self.agent.invoke(
|
||||
{"messages": [("user", question)]},
|
||||
)
|
||||
return response
|
||||
|
||||
# Extract the actual message content
|
||||
ai_response = (
|
||||
response["messages"][-1].content if response.get("messages") else ""
|
||||
)
|
||||
|
||||
# Try to extract investor IDs or names from the AI response
|
||||
investor_ids = self._extract_investor_info_from_response(ai_response)
|
||||
|
||||
# Fetch filtered investor data with relationships from database
|
||||
return self._fetch_investors_with_relationships(investor_ids, filters)
|
||||
|
||||
def _extract_investor_info_from_response(self, ai_response: str) -> List[int]:
|
||||
"""Extract investor IDs from AI response. This is a simple implementation."""
|
||||
# This is a basic implementation - you might want to make it more sophisticated
|
||||
# based on how your AI formats responses
|
||||
investor_ids = []
|
||||
|
||||
# If the AI can't provide structured data, fall back to getting all investors
|
||||
# that match basic criteria
|
||||
try:
|
||||
# Try to extract numbers that might be IDs
|
||||
import re
|
||||
|
||||
ids = re.findall(r"\bid:\s*(\d+)", ai_response.lower())
|
||||
investor_ids = [int(id_str) for id_str in ids]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return investor_ids if investor_ids else []
|
||||
|
||||
def _extract_filters_from_query(self, question: str) -> dict:
|
||||
"""Extract filter criteria from natural language query."""
|
||||
question_lower = question.lower()
|
||||
filters = {}
|
||||
|
||||
# Extract stage filters
|
||||
if any(
|
||||
stage in question_lower
|
||||
for stage in [
|
||||
"seed",
|
||||
"series a",
|
||||
"series b",
|
||||
"series c",
|
||||
"growth",
|
||||
"late stage",
|
||||
]
|
||||
):
|
||||
if "seed" in question_lower:
|
||||
filters["stage"] = "SEED"
|
||||
elif "series a" in question_lower:
|
||||
filters["stage"] = "SERIES_A"
|
||||
elif "series b" in question_lower:
|
||||
filters["stage"] = "SERIES_B"
|
||||
elif "series c" in question_lower:
|
||||
filters["stage"] = "SERIES_C"
|
||||
elif "growth" in question_lower:
|
||||
filters["stage"] = "GROWTH"
|
||||
elif "late stage" in question_lower:
|
||||
filters["stage"] = "LATE_STAGE"
|
||||
|
||||
# Extract geographic filters
|
||||
if any(
|
||||
geo in question_lower
|
||||
for geo in [
|
||||
"us",
|
||||
"usa",
|
||||
"united states",
|
||||
"europe",
|
||||
"asia",
|
||||
"silicon valley",
|
||||
"bay area",
|
||||
]
|
||||
):
|
||||
if (
|
||||
"us" in question_lower
|
||||
or "usa" in question_lower
|
||||
or "united states" in question_lower
|
||||
):
|
||||
filters["geography"] = "US"
|
||||
elif "europe" in question_lower:
|
||||
filters["geography"] = "Europe"
|
||||
elif "asia" in question_lower:
|
||||
filters["geography"] = "Asia"
|
||||
elif "silicon valley" in question_lower or "bay area" in question_lower:
|
||||
filters["geography"] = "Silicon Valley"
|
||||
|
||||
# Extract sector filters
|
||||
sectors = [
|
||||
"fintech",
|
||||
"healthcare",
|
||||
"saas",
|
||||
"ai",
|
||||
"biotech",
|
||||
"consumer",
|
||||
"enterprise",
|
||||
"crypto",
|
||||
"blockchain",
|
||||
]
|
||||
for sector in sectors:
|
||||
if sector in question_lower:
|
||||
filters["sector"] = sector
|
||||
break
|
||||
|
||||
# Extract check size filters (simple patterns)
|
||||
import re
|
||||
|
||||
amounts = re.findall(
|
||||
r"\$?(\d+(?:,\d{3})*(?:\.\d+)?)\s*(?:million|m|k|thousand)", question_lower
|
||||
)
|
||||
if amounts:
|
||||
amount = amounts[0].replace(",", "")
|
||||
if "million" in question_lower or "m" in question_lower:
|
||||
filters["min_check_size"] = int(float(amount) * 1000000)
|
||||
elif "thousand" in question_lower or "k" in question_lower:
|
||||
filters["min_check_size"] = int(float(amount) * 1000)
|
||||
|
||||
return filters
|
||||
|
||||
def _fetch_investors_with_relationships(
|
||||
self, investor_ids: List[int] = None, filters: dict = None
|
||||
) -> InvestorList:
|
||||
"""Fetch investors with all their relationships from the database."""
|
||||
if not self.sql_session:
|
||||
return InvestorList(investors=[])
|
||||
|
||||
# Import here to avoid circular imports
|
||||
from db.models import SectorTable
|
||||
|
||||
# Build query with all relationships loaded
|
||||
query = self.sql_session.query(InvestorTable).options(
|
||||
selectinload(InvestorTable.portfolio_companies),
|
||||
selectinload(InvestorTable.team_members),
|
||||
selectinload(InvestorTable.sectors),
|
||||
)
|
||||
|
||||
# Apply filters if provided
|
||||
if filters:
|
||||
if "stage" in filters:
|
||||
from db.models import InvestmentStage
|
||||
|
||||
stage_enum = getattr(InvestmentStage, filters["stage"])
|
||||
query = query.filter(InvestorTable.stage_focus == stage_enum)
|
||||
|
||||
if "geography" in filters:
|
||||
query = query.filter(
|
||||
InvestorTable.geographic_focus.ilike(f"%{filters['geography']}%")
|
||||
)
|
||||
|
||||
if "min_check_size" in filters:
|
||||
query = query.filter(
|
||||
InvestorTable.check_size_lower >= filters["min_check_size"]
|
||||
)
|
||||
|
||||
if "max_check_size" in filters:
|
||||
query = query.filter(
|
||||
InvestorTable.check_size_upper <= filters["max_check_size"]
|
||||
)
|
||||
|
||||
if "min_aum" in filters:
|
||||
query = query.filter(InvestorTable.aum >= filters["min_aum"])
|
||||
|
||||
if "max_aum" in filters:
|
||||
query = query.filter(InvestorTable.aum <= filters["max_aum"])
|
||||
|
||||
if "sector" in filters:
|
||||
query = query.join(InvestorTable.sectors).filter(
|
||||
SectorTable.name.ilike(f"%{filters['sector']}%")
|
||||
)
|
||||
|
||||
# Filter by IDs if provided
|
||||
if investor_ids:
|
||||
query = query.filter(InvestorTable.id.in_(investor_ids))
|
||||
else:
|
||||
# If no specific IDs and no filters, limit to prevent overwhelming response
|
||||
if not filters:
|
||||
query = query.limit(10)
|
||||
|
||||
investors = query.all()
|
||||
|
||||
# Transform to InvestorData format
|
||||
investor_data_list = []
|
||||
for investor in investors:
|
||||
investor_data = InvestorData(
|
||||
investor=investor,
|
||||
portfolio_companies=investor.portfolio_companies,
|
||||
team_members=investor.team_members,
|
||||
sectors=investor.sectors,
|
||||
)
|
||||
investor_data_list.append(investor_data)
|
||||
|
||||
return InvestorList(investors=investor_data_list)
|
||||
|
||||
Reference in New Issue
Block a user