Files
2025-11-11 20:27:55 +01:00

662 lines
24 KiB
Python

from typing import Optional
from db.db import get_db
from db.models import FundTable, InvestorTable, ProjectTable, SectorTable
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
from schemas.router_schemas import (
CompanyMinimal,
InvestmentResponse,
InvestmentStage,
InvestorData,
PaginatedResponse,
SectorMinimal,
)
from services.compatibility_score import (
_calculate_project_fund_compatibility,
_calculate_project_investor_direct_compatibility,
)
from sqlalchemy.orm import Session, selectinload
# Router collecting the investor endpoints defined in this module; every route
# registered on it is tagged "Investor Routes".
router = APIRouter(tags=["Investor Routes"])
# Request schemas for creating/updating
class InvestorCreate(BaseModel):
    # Request body for POST /investors. The validated fields are unpacked
    # straight into the InvestorTable constructor, so the field names here
    # must stay in sync with that model's keyword arguments.

    # Mandatory display name.
    name: str

    # Optional free-text profile fields; omitted values default to None.
    description: Optional[str] = None
    website: Optional[str] = None
    headquarters: Optional[str] = None

    # Mandatory numeric profile.
    aum: int
    check_size_lower: int
    check_size_upper: int

    # Mandatory geographic focus.
    geographic_focus: str

    # Defaults to zero when the client does not send it.
    number_of_investments: int = 0
class InvestorUpdate(BaseModel):
    # Request body for PUT /investors/{investor_id}. Every field is optional;
    # the update handler only applies the fields the client actually sent
    # (it serialises this model with exclude_unset=True).

    # Text fields.
    name: Optional[str] = None
    description: Optional[str] = None
    website: Optional[str] = None
    headquarters: Optional[str] = None

    # Numeric fields.
    aum: Optional[int] = None
    check_size_lower: Optional[int] = None
    check_size_upper: Optional[int] = None

    # Remaining profile fields.
    geographic_focus: Optional[str] = None
    number_of_investments: Optional[int] = None
@router.get("/investors", response_model=PaginatedResponse[InvestmentResponse])
def read_investors(
    page: int = Query(1, ge=1, description="Page number (starts at 1)"),
    page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
    project_id: Optional[int] = Query(
        None, description="Optional project ID for compatibility scoring"
    ),
    db: Session = Depends(get_db),
):
    """Get all investors with their funds as separate entries (paginated)

    Each investor-fund combination is returned as a separate row.
    An investor with 3 funds will appear as 3 entries.
    If project_id is provided, calculates compatibility scores for each investor.
    """
    offset = (page - 1) * page_size

    # Resolve the optional project first so an unknown id fails fast with 404.
    project = None
    if project_id is not None:
        project = (
            db.query(ProjectTable)
            .options(selectinload(ProjectTable.sector))
            .filter(ProjectTable.id == project_id)
            .first()
        )
        if not project:
            raise HTTPException(status_code=404, detail="Project not found")

    # One query definition shared by both modes. The explicit ORDER BY keeps
    # OFFSET/LIMIT pages stable between requests and makes score ties
    # deterministic (the sort below is stable).
    investor_query = (
        db.query(InvestorTable)
        .options(
            selectinload(InvestorTable.portfolio_companies),
            selectinload(InvestorTable.team_members),
            selectinload(InvestorTable.sectors),
            selectinload(InvestorTable.funds).selectinload(
                FundTable.investment_stages
            ),
            selectinload(InvestorTable.funds).selectinload(FundTable.sectors),
        )
        .order_by(InvestorTable.id)
    )

    if project is None:
        # No scoring: paginate in the database, one page of *investors*.
        # Each investor expands to one row per fund, so a page may contain
        # more than page_size rows; `total` counts investors in this mode.
        total_count = db.query(InvestorTable).count()
        investors = investor_query.offset(offset).limit(page_size).all()
    else:
        # Scoring: every investor must be loaded so the rows can be ranked
        # before the page is cut. `total` is filled in once rows are built.
        total_count = 0
        investors = investor_query.all()

    # Transform to InvestmentResponse format (one row per investor-fund pair).
    investment_responses = []
    for investor in investors:
        # Top 3 portfolio companies (id and name only).
        portfolio_companies = [
            CompanyMinimal(id=company.id, name=company.name)
            for company in investor.portfolio_companies[:3]
        ]

        if not investor.funds:
            # No funds: emit a single row with empty fund-level fields and
            # score against the investor-level data instead.
            compatibility_score = 1.0
            if project is not None:
                compatibility_score = (
                    _calculate_project_investor_direct_compatibility(
                        project=project, investor=investor
                    )
                )
            investment_responses.append(
                InvestmentResponse(
                    id=investor.id,
                    name=investor.name,
                    aum=investor.aum,
                    check_size_lower=None,
                    check_size_upper=None,
                    geographic_focus=None,
                    stage_focus=None,
                    portfolio_companies=portfolio_companies,
                    sectors=[],
                    compatibility_score=compatibility_score,
                )
            )
            continue

        for fund in investor.funds:
            # Score this specific fund; 1.0 is the neutral value without a project.
            compatibility_score = 1.0
            if project is not None:
                compatibility_score = _calculate_project_fund_compatibility(
                    project=project, fund=fund
                )

            # Stage focus as a comma-separated string, or None when empty.
            stage_focus = (
                ", ".join([stage.name for stage in fund.investment_stages])
                if fund.investment_stages
                else None
            )

            # First 3 fund sectors (id and name only), sorted alphabetically.
            fund_sectors = [
                SectorMinimal(id=sector.id, name=sector.name)
                for sector in sorted(
                    fund.sectors[:3] if fund.sectors else [], key=lambda s: s.name
                )
            ]

            investment_responses.append(
                InvestmentResponse(
                    id=investor.id,
                    name=f"{investor.name} - {fund.fund_name}"
                    if fund.fund_name
                    else investor.name,
                    aum=investor.aum,
                    check_size_lower=fund.check_size_lower,
                    check_size_upper=fund.check_size_upper,
                    geographic_focus=fund.geographic_focus,
                    stage_focus=stage_focus,
                    portfolio_companies=portfolio_companies,
                    sectors=fund_sectors,
                    compatibility_score=compatibility_score,
                )
            )

    if project is not None:
        # Rank by compatibility (best first), then paginate over the rows.
        investment_responses.sort(key=lambda row: row.compatibility_score, reverse=True)
        # The slice below is taken over investor-fund rows, so the pagination
        # metadata must count those same rows; counting investors here would
        # under-report total/total_pages and hide the last pages.
        total_count = len(investment_responses)
        investment_responses = investment_responses[offset : offset + page_size]

    # Ceiling division; yields 0 pages for an empty result.
    total_pages = (total_count + page_size - 1) // page_size

    return PaginatedResponse(
        items=investment_responses,
        total=total_count,
        page=page,
        page_size=page_size,
        total_pages=total_pages,
    )
@router.get("/investors/filter", response_model=PaginatedResponse[InvestmentResponse])
def filter_investors(
    stage: Optional[InvestmentStage] = Query(
        None, description="Filter by investment stage"
    ),
    min_check_size: Optional[int] = Query(None, description="Minimum check size"),
    max_check_size: Optional[int] = Query(None, description="Maximum check size"),
    geography: Optional[str] = Query(
        None, description="Geographic focus (partial match)"
    ),
    sector: Optional[str] = Query(None, description="Sector name (partial match)"),
    min_aum: Optional[int] = Query(None, description="Minimum AUM"),
    max_aum: Optional[int] = Query(None, description="Maximum AUM"),
    page: int = Query(1, ge=1, description="Page number (starts at 1)"),
    page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
    project_id: Optional[int] = Query(
        None, description="Optional project ID for compatibility scoring"
    ),
    db: Session = Depends(get_db),
):
    """Filter investors based on various criteria (paginated)

    Returns investor-fund combinations as separate rows.
    Queries the funds table to find matching funds.
    If project_id is provided, calculates compatibility scores for each investor.
    """
    offset = (page - 1) * page_size

    # Resolve the optional project first so an unknown id fails fast with 404.
    project = None
    if project_id is not None:
        project = (
            db.query(ProjectTable)
            .options(selectinload(ProjectTable.sector))
            .filter(ProjectTable.id == project_id)
            .first()
        )
        if not project:
            raise HTTPException(status_code=404, detail="Project not found")

    # Base query on the funds table, eager-loading what the response needs.
    # NOTE(review): no ORDER BY is applied, so without project_id the contents
    # of a page depend on the database's default row order.
    query = db.query(FundTable).options(
        selectinload(FundTable.investor).selectinload(
            InvestorTable.portfolio_companies
        ),
        selectinload(FundTable.investor).selectinload(InvestorTable.team_members),
        selectinload(FundTable.investor).selectinload(InvestorTable.sectors),
        selectinload(FundTable.investment_stages),
        selectinload(FundTable.sectors),
    )

    # --- Fund-level filters -------------------------------------------------
    if stage is not None:
        # The stage parameter used to be accepted but silently ignored.
        # NOTE(review): assumes the InvestmentStage value equals the `name`
        # stored on the fund's investment-stage rows -- confirm against the
        # schema/seed data.
        stage_name = getattr(stage, "value", stage)
        query = query.filter(FundTable.investment_stages.any(name=stage_name))
    if min_check_size is not None:
        query = query.filter(FundTable.check_size_lower >= min_check_size)
    if max_check_size is not None:
        query = query.filter(FundTable.check_size_upper <= max_check_size)
    if geography:
        query = query.filter(FundTable.geographic_focus.ilike(f"%{geography}%"))
    if sector:
        # EXISTS instead of a JOIN: joining the sector collection yields one
        # row per matching sector, which inflates count() and skews
        # OFFSET/LIMIT when a fund matches the pattern more than once.
        query = query.filter(
            FundTable.sectors.any(SectorTable.name.ilike(f"%{sector}%"))
        )

    # --- Investor-level filters (single join, only when needed) -------------
    if min_aum is not None or max_aum is not None:
        query = query.join(FundTable.investor)
        if min_aum is not None:
            query = query.filter(InvestorTable.aum >= min_aum)
        if max_aum is not None:
            query = query.filter(InvestorTable.aum <= max_aum)

    # Total number of matching funds, before pagination.
    total_count = query.count()

    if project is not None:
        # Scoring: load every match so rows can be ranked before paging.
        funds = query.all()
    else:
        # No scoring: let the database cut the page.
        funds = query.offset(offset).limit(page_size).all()

    # Transform to InvestmentResponse format (one row per fund).
    investment_responses = []
    for fund in funds:
        investor = fund.investor

        # Score this specific fund; 1.0 is the neutral value without a project.
        compatibility_score = 1.0
        if project is not None:
            compatibility_score = _calculate_project_fund_compatibility(
                project=project, fund=fund
            )

        # Top 3 portfolio companies (id and name only).
        portfolio_companies = [
            CompanyMinimal(id=company.id, name=company.name)
            for company in investor.portfolio_companies[:3]
        ]

        # Stage focus as a comma-separated string, or None when empty.
        stage_focus = (
            ", ".join([fund_stage.name for fund_stage in fund.investment_stages])
            if fund.investment_stages
            else None
        )

        # First 3 fund sectors (id and name only), sorted alphabetically.
        fund_sectors = [
            SectorMinimal(id=fund_sector.id, name=fund_sector.name)
            for fund_sector in sorted(
                fund.sectors[:3] if fund.sectors else [], key=lambda s: s.name
            )
        ]

        investment_responses.append(
            InvestmentResponse(
                id=investor.id,
                name=f"{investor.name} - {fund.fund_name}"
                if fund.fund_name
                else investor.name,
                aum=investor.aum,
                check_size_lower=fund.check_size_lower,
                check_size_upper=fund.check_size_upper,
                geographic_focus=fund.geographic_focus,
                stage_focus=stage_focus,
                portfolio_companies=portfolio_companies,
                sectors=fund_sectors,
                compatibility_score=compatibility_score,
            )
        )

    if project is not None:
        # Rank by compatibility (best first), then paginate in memory.
        investment_responses.sort(key=lambda row: row.compatibility_score, reverse=True)
        investment_responses = investment_responses[offset : offset + page_size]

    # Ceiling division; yields 0 pages for an empty result.
    total_pages = (total_count + page_size - 1) // page_size

    return PaginatedResponse(
        items=investment_responses,
        total=total_count,
        page=page,
        page_size=page_size,
        total_pages=total_pages,
    )
@router.get("/investors/{investor_id}", response_model=InvestorData)
def read_investor(investor_id: int, db: Session = Depends(get_db)):
    """Get a specific investor by ID with all their funds"""
    # Relationships copied into the response; loaded eagerly alongside the row.
    eager_loads = (
        selectinload(InvestorTable.portfolio_companies),
        selectinload(InvestorTable.team_members),
        selectinload(InvestorTable.sectors),
        selectinload(InvestorTable.funds),
    )
    lookup = db.query(InvestorTable).options(*eager_loads)
    record = lookup.filter(InvestorTable.id == investor_id).first()

    # Unknown id -> 404 rather than an empty payload.
    if not record:
        raise HTTPException(status_code=404, detail="Investor not found")

    # Assemble the InvestorData envelope (the investor plus its collections).
    payload = InvestorData(
        investor=record,
        portfolio_companies=record.portfolio_companies,
        team_members=record.team_members,
        sectors=record.sectors,
        funds=record.funds,
    )
    return payload
@router.post("/investors", response_model=InvestorData)
def create_investor(investor: InvestorCreate, db: Session = Depends(get_db)):
    """Create a new investor"""
    # Persist the new row built from the validated request fields.
    field_values = investor.dict()
    new_row = InvestorTable(**field_values)
    db.add(new_row)
    db.commit()
    db.refresh(new_row)

    # Fetch the stored row again, this time with its relationships eagerly
    # loaded, so the response collections are populated up front.
    eager_loads = (
        selectinload(InvestorTable.portfolio_companies),
        selectinload(InvestorTable.team_members),
        selectinload(InvestorTable.sectors),
        selectinload(InvestorTable.funds),
    )
    stored = (
        db.query(InvestorTable)
        .options(*eager_loads)
        .filter(InvestorTable.id == new_row.id)
        .first()
    )

    # Assemble the InvestorData envelope.
    payload = InvestorData(
        investor=stored,
        portfolio_companies=stored.portfolio_companies,
        team_members=stored.team_members,
        sectors=stored.sectors,
        funds=stored.funds,
    )
    return payload
@router.put("/investors/{investor_id}", response_model=InvestorData)
def update_investor(
    investor_id: int, investor: InvestorUpdate, db: Session = Depends(get_db)
):
    """Update an existing investor"""
    by_id = InvestorTable.id == investor_id

    existing = db.query(InvestorTable).filter(by_id).first()
    if not existing:
        raise HTTPException(status_code=404, detail="Investor not found")

    # Apply only the fields the client actually sent; unset fields keep
    # their stored values.
    changes = investor.dict(exclude_unset=True)
    for attr_name, new_value in changes.items():
        setattr(existing, attr_name, new_value)
    db.commit()
    db.refresh(existing)

    # Fetch the row again with its relationships eagerly loaded for the response.
    eager_loads = (
        selectinload(InvestorTable.portfolio_companies),
        selectinload(InvestorTable.team_members),
        selectinload(InvestorTable.sectors),
        selectinload(InvestorTable.funds),
    )
    refreshed = db.query(InvestorTable).options(*eager_loads).filter(by_id).first()

    # Assemble the InvestorData envelope.
    payload = InvestorData(
        investor=refreshed,
        portfolio_companies=refreshed.portfolio_companies,
        team_members=refreshed.team_members,
        sectors=refreshed.sectors,
        funds=refreshed.funds,
    )
    return payload
@router.delete("/investors/{investor_id}")
def delete_investor(investor_id: int, db: Session = Depends(get_db)):
    """Delete an investor"""
    lookup = db.query(InvestorTable).filter(InvestorTable.id == investor_id)
    doomed = lookup.first()

    # Nothing to delete -> 404.
    if not doomed:
        raise HTTPException(status_code=404, detail="Investor not found")

    db.delete(doomed)
    db.commit()

    confirmation = {"message": "Investor deleted successfully"}
    return confirmation
def _score_fund_similarity(
    fund,
    target_aum,
    target_sector_ids,
    target_stage_ids,
    target_check_ranges,
    target_geographies,
):
    """Return an integer 0-100 similarity score for one candidate fund.

    The score is the sum of five independent components:

    * geographic focus -- 20 for an exact match, 10 for a substring match
    * check-size overlap -- up to 20, relative to the target's range
    * AUM similarity -- up to 15
    * sector overlap -- up to 30, relative to the target's sectors
    * investment-stage overlap -- up to 15, relative to the target's stages

    Args:
        fund: Candidate fund; its ``investor`` relationship must be loadable.
        target_aum: AUM of the investor being compared against.
        target_sector_ids: Sector ids collected from the target's funds.
        target_stage_ids: Investment-stage ids collected from the target's funds.
        target_check_ranges: ``(lower, upper)`` check-size pairs of the target's funds.
        target_geographies: Lower-cased geographic focus strings of the target's funds.
    """
    score = 0

    # Geographic focus. Look for an exact match across *all* target
    # geographies before settling for a partial one; stopping at the first
    # hit would let an early substring match hide a later exact match.
    if fund.geographic_focus and target_geographies:
        fund_geo = fund.geographic_focus.lower()
        if fund_geo in target_geographies:
            score += 20
        elif any(fund_geo in geo or geo in fund_geo for geo in target_geographies):
            score += 10

    # Check-size overlap: best overlap against any of the target's ranges.
    # Compare with None explicitly -- a lower bound of 0 is a valid value.
    if (
        fund.check_size_lower is not None
        and fund.check_size_upper is not None
        and target_check_ranges
    ):
        best_overlap_score = 0
        for target_lower, target_upper in target_check_ranges:
            overlap_start = max(fund.check_size_lower, target_lower)
            overlap_end = min(fund.check_size_upper, target_upper)
            if overlap_end > overlap_start:
                overlap = overlap_end - overlap_start
                target_range = target_upper - target_lower
                overlap_ratio = overlap / target_range if target_range > 0 else 0
                best_overlap_score = max(best_overlap_score, int(20 * overlap_ratio))
        score += best_overlap_score

    # AUM similarity: 1 - relative difference, skipped when either side has
    # no (or zero) AUM.
    candidate_aum = fund.investor.aum
    if candidate_aum and target_aum:
        larger_aum = max(candidate_aum, target_aum)
        aum_diff = abs(candidate_aum - target_aum)
        similarity_ratio = 1 - (aum_diff / larger_aum) if larger_aum > 0 else 0
        score += int(15 * similarity_ratio)

    # Sector overlap: share of the target's sectors this fund also covers.
    if fund.sectors and target_sector_ids:
        fund_sector_ids = {sector.id for sector in fund.sectors}
        common_sectors = target_sector_ids.intersection(fund_sector_ids)
        score += int(30 * (len(common_sectors) / len(target_sector_ids)))

    # Stage overlap: share of the target's stages this fund also covers.
    if fund.investment_stages and target_stage_ids:
        fund_stage_ids = {stage.id for stage in fund.investment_stages}
        common_stages = target_stage_ids.intersection(fund_stage_ids)
        score += int(15 * (len(common_stages) / len(target_stage_ids)))

    return score


@router.get(
    "/investors/{investor_id}/similar",
    response_model=PaginatedResponse[InvestmentResponse],
)
def find_similar_investors(
    investor_id: int,
    limit: int = Query(10, description="Maximum number of similar investors to return"),
    page: int = Query(1, ge=1, description="Page number (starts at 1)"),
    page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
    db: Session = Depends(get_db),
):
    """Find investors similar to a given investor based on characteristics (paginated)

    Returns investor-fund combinations as separate rows.
    Queries the funds table to find matching funds.
    """
    # Load the target investor together with the fund data used for comparison.
    target_investor = (
        db.query(InvestorTable)
        .options(
            selectinload(InvestorTable.portfolio_companies),
            selectinload(InvestorTable.team_members),
            selectinload(InvestorTable.sectors),
            selectinload(InvestorTable.funds).selectinload(FundTable.investment_stages),
            selectinload(InvestorTable.funds).selectinload(FundTable.sectors),
        )
        .filter(InvestorTable.id == investor_id)
        .first()
    )
    if not target_investor:
        raise HTTPException(status_code=404, detail="Investor not found")

    # Aggregate the comparison profile across all of the target's funds.
    target_sector_ids = set()
    target_stage_ids = set()
    target_check_ranges = []
    target_geographies = []
    for target_fund in target_investor.funds:
        if target_fund.sectors:
            target_sector_ids.update(sector.id for sector in target_fund.sectors)
        if target_fund.investment_stages:
            target_stage_ids.update(
                stage.id for stage in target_fund.investment_stages
            )
        # Explicit None checks: a lower bound of 0 is a legitimate range start.
        if (
            target_fund.check_size_lower is not None
            and target_fund.check_size_upper is not None
        ):
            target_check_ranges.append(
                (target_fund.check_size_lower, target_fund.check_size_upper)
            )
        if target_fund.geographic_focus:
            target_geographies.append(target_fund.geographic_focus.lower())

    # Every fund that belongs to a different investor is a candidate.
    candidate_funds = (
        db.query(FundTable)
        .options(
            selectinload(FundTable.investor).selectinload(
                InvestorTable.portfolio_companies
            ),
            selectinload(FundTable.investor).selectinload(InvestorTable.team_members),
            selectinload(FundTable.investor).selectinload(InvestorTable.sectors),
            selectinload(FundTable.investment_stages),
            selectinload(FundTable.sectors),
        )
        .join(FundTable.investor)
        .filter(InvestorTable.id != investor_id)
        .all()
    )

    # Score each candidate and keep only those with some similarity.
    scored_funds = []
    for candidate in candidate_funds:
        score = _score_fund_similarity(
            candidate,
            target_investor.aum,
            target_sector_ids,
            target_stage_ids,
            target_check_ranges,
            target_geographies,
        )
        if score > 0:
            scored_funds.append((score, candidate))

    # Best matches first; the sort is stable, so ties keep query order.
    scored_funds.sort(key=lambda pair: pair[0], reverse=True)
    # Clamp so a negative limit yields nothing instead of slicing off the
    # tail of the ranking (list[:-n] semantics).
    top_similar = scored_funds[: max(limit, 0)]

    # Paginate within the top-N window.
    total_count = len(top_similar)
    offset = (page - 1) * page_size
    page_of_funds = [fund for _, fund in top_similar[offset : offset + page_size]]

    # Transform to InvestmentResponse format (one row per fund).
    investment_responses = []
    for fund in page_of_funds:
        investor = fund.investor

        # Top 3 portfolio companies (id and name only).
        portfolio_companies = [
            CompanyMinimal(id=company.id, name=company.name)
            for company in investor.portfolio_companies[:3]
        ]

        # Stage focus as a comma-separated string, or None when empty.
        stage_focus = (
            ", ".join([stage.name for stage in fund.investment_stages])
            if fund.investment_stages
            else None
        )

        # First 3 fund sectors (id and name only), in relationship order.
        fund_sectors = [
            SectorMinimal(id=sector.id, name=sector.name)
            for sector in (fund.sectors[:3] if fund.sectors else [])
        ]

        investment_responses.append(
            InvestmentResponse(
                id=investor.id,
                name=f"{investor.name} - {fund.fund_name}"
                if fund.fund_name
                else investor.name,
                aum=investor.aum,
                check_size_lower=fund.check_size_lower,
                check_size_upper=fund.check_size_upper,
                geographic_focus=fund.geographic_focus,
                stage_focus=stage_focus,
                portfolio_companies=portfolio_companies,
                sectors=fund_sectors,
                # The similarity score is used for ranking only; the response
                # keeps the neutral compatibility value.
                compatibility_score=1.0,
            )
        )

    # Ceiling division; yields 0 pages for an empty result.
    total_pages = (total_count + page_size - 1) // page_size

    return PaginatedResponse(
        items=investment_responses,
        total=total_count,
        page=page,
        page_size=page_size,
        total_pages=total_pages,
    )