from typing import Optional

from db.db import get_db
from db.models import FundTable, InvestorTable, ProjectTable, SectorTable
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
from schemas.router_schemas import (
    CompanyMinimal,
    InvestmentResponse,
    InvestmentStage,
    InvestorData,
    PaginatedResponse,
    SectorMinimal,
)
from services.compatibility_score import calculate_project_investor_compatibility
from sqlalchemy.orm import Session, selectinload

router = APIRouter(tags=["Investor Routes"])

# How many portfolio companies / fund sectors are embedded in each list row.
_PREVIEW_LIMIT = 3


# Request schemas for creating/updating
class InvestorCreate(BaseModel):
    """Request body for ``POST /investors``; passed as keyword args to ``InvestorTable``."""

    name: str
    description: Optional[str] = None
    website: Optional[str] = None
    headquarters: Optional[str] = None
    aum: int
    check_size_lower: int
    check_size_upper: int
    geographic_focus: str
    number_of_investments: int = 0


class InvestorUpdate(BaseModel):
    """Request body for ``PUT /investors/{investor_id}``.

    Every field is optional; only fields the client explicitly sends are applied.
    """

    name: Optional[str] = None
    description: Optional[str] = None
    website: Optional[str] = None
    headquarters: Optional[str] = None
    aum: Optional[int] = None
    check_size_lower: Optional[int] = None
    check_size_upper: Optional[int] = None
    geographic_focus: Optional[str] = None
    number_of_investments: Optional[int] = None


# ---------------------------------------------------------------------------
# Private helpers shared by the route handlers
# ---------------------------------------------------------------------------


def _load_project(db: Session, project_id: Optional[int]) -> Optional[ProjectTable]:
    """Return the project (with its sector eagerly loaded), or None if no id is given.

    Raises:
        HTTPException: 404 when ``project_id`` is given but matches no project.
    """
    if project_id is None:
        return None
    project = (
        db.query(ProjectTable)
        .options(selectinload(ProjectTable.sector))
        .filter(ProjectTable.id == project_id)
        .first()
    )
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")
    return project


def _compatibility_scorer(project: Optional[ProjectTable]):
    """Build a function mapping an investor to its compatibility score.

    Without a project every investor scores 1.0. With a project the score is
    computed once per investor id and reused for that investor's other fund
    rows. This assumes ``calculate_project_investor_compatibility`` gives the
    same result for the same (project, investor) pair within one request.
    """
    cache = {}

    def score_for(investor) -> float:
        if project is None:
            return 1.0
        if investor.id not in cache:
            cache[investor.id] = calculate_project_investor_compatibility(
                project=project, investor=investor, use_funds=True
            )
        return cache[investor.id]

    return score_for


def _build_investment_response(investor, fund, compatibility_score: float) -> InvestmentResponse:
    """Flatten an investor and one of its funds into a single list row.

    When ``fund`` is None (an investor without funds) the fund-level fields
    are null and ``sectors`` is empty.
    """
    portfolio_companies = [
        CompanyMinimal(id=company.id, name=company.name)
        for company in investor.portfolio_companies[:_PREVIEW_LIMIT]
    ]

    if fund is None:
        return InvestmentResponse(
            id=investor.id,
            name=investor.name,
            aum=investor.aum,
            check_size_lower=None,
            check_size_upper=None,
            geographic_focus=None,
            stage_focus=None,
            portfolio_companies=portfolio_companies,
            sectors=[],
            compatibility_score=compatibility_score,
        )

    # Stage focus as a comma-separated string, or None when the fund has no stages.
    stage_focus = (
        ", ".join(stage.name for stage in fund.investment_stages)
        if fund.investment_stages
        else None
    )
    fund_sectors = [
        SectorMinimal(id=sector.id, name=sector.name)
        for sector in (fund.sectors or [])[:_PREVIEW_LIMIT]
    ]
    return InvestmentResponse(
        id=investor.id,
        name=f"{investor.name} - {fund.fund_name}" if fund.fund_name else investor.name,
        aum=investor.aum,
        check_size_lower=fund.check_size_lower,
        check_size_upper=fund.check_size_upper,
        geographic_focus=fund.geographic_focus,
        stage_focus=stage_focus,
        portfolio_companies=portfolio_companies,
        sectors=fund_sectors,
        compatibility_score=compatibility_score,
    )


def _paginated(items, total_count: int, page: int, page_size: int) -> PaginatedResponse:
    """Wrap one page of ``items`` with totals; ``total_pages`` is ceil(total / page_size)."""
    total_pages = (total_count + page_size - 1) // page_size
    return PaginatedResponse(
        items=items,
        total=total_count,
        page=page,
        page_size=page_size,
        total_pages=total_pages,
    )


def _load_investor_with_relations(db: Session, investor_id: int) -> Optional[InvestorTable]:
    """Fetch one investor with portfolio companies, team, sectors and funds eagerly loaded.

    Returns None when no investor has ``investor_id``.
    """
    return (
        db.query(InvestorTable)
        .options(
            selectinload(InvestorTable.portfolio_companies),
            selectinload(InvestorTable.team_members),
            selectinload(InvestorTable.sectors),
            selectinload(InvestorTable.funds),
        )
        .filter(InvestorTable.id == investor_id)
        .first()
    )


def _to_investor_data(investor) -> InvestorData:
    """Convert a fully loaded investor into the ``InvestorData`` response (incl. funds)."""
    return InvestorData(
        investor=investor,
        portfolio_companies=investor.portfolio_companies,
        team_members=investor.team_members,
        sectors=investor.sectors,
        funds=investor.funds,
    )


def _similarity_score(
    fund,
    target_aum,
    target_sector_ids,
    target_stage_ids,
    target_check_ranges,
    target_geographies,
) -> int:
    """Score how closely ``fund`` (and its investor's AUM) matches a target profile.

    Components: geography up to 20, check-size overlap up to 20, AUM up to 15,
    sector overlap up to 30, stage overlap up to 15.
    ``target_geographies`` must already be lower-cased.
    """
    score = 0

    # Geography: 20 for an exact (case-insensitive) match with ANY target
    # geography, otherwise 10 if one string contains the other.
    if fund.geographic_focus and target_geographies:
        fund_geo = fund.geographic_focus.lower()
        if fund_geo in target_geographies:
            score += 20
        elif any(fund_geo in geo or geo in fund_geo for geo in target_geographies):
            score += 10

    # Check size: best fraction of any target range covered by this fund's range.
    # `is not None` so that a legitimate lower bound of 0 is not treated as missing.
    if (
        fund.check_size_lower is not None
        and fund.check_size_upper is not None
        and target_check_ranges
    ):
        max_overlap_score = 0
        for target_lower, target_upper in target_check_ranges:
            overlap_start = max(fund.check_size_lower, target_lower)
            overlap_end = min(fund.check_size_upper, target_upper)
            if overlap_end > overlap_start:
                overlap = overlap_end - overlap_start
                target_range = target_upper - target_lower
                overlap_ratio = overlap / target_range if target_range > 0 else 0
                max_overlap_score = max(max_overlap_score, int(20 * overlap_ratio))
        score += max_overlap_score

    # AUM: 15 scaled by 1 - |difference| / larger AUM.
    investor_aum = fund.investor.aum
    if investor_aum and target_aum:
        aum_diff = abs(investor_aum - target_aum)
        max_aum = max(investor_aum, target_aum)
        similarity_ratio = 1 - (aum_diff / max_aum) if max_aum > 0 else 0
        score += int(15 * similarity_ratio)

    # Sectors: share of the target's sectors this fund also covers.
    if fund.sectors and target_sector_ids:
        fund_sector_ids = {sector.id for sector in fund.sectors}
        overlap_ratio = len(target_sector_ids & fund_sector_ids) / len(target_sector_ids)
        score += int(30 * overlap_ratio)

    # Stages: share of the target's stages this fund also covers.
    if fund.investment_stages and target_stage_ids:
        fund_stage_ids = {stage.id for stage in fund.investment_stages}
        overlap_ratio = len(target_stage_ids & fund_stage_ids) / len(target_stage_ids)
        score += int(15 * overlap_ratio)

    return score


# ---------------------------------------------------------------------------
# Routes
# NOTE: "/investors/filter" must stay registered before "/investors/{investor_id}",
# otherwise "filter" would be matched as an investor id.
# ---------------------------------------------------------------------------


@router.get("/investors", response_model=PaginatedResponse[InvestmentResponse])
def read_investors(
    page: int = Query(1, ge=1, description="Page number (starts at 1)"),
    page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
    project_id: Optional[int] = Query(
        None, description="Optional project ID for compatibility scoring"
    ),
    db: Session = Depends(get_db),
):
    """Get all investors with their funds as separate entries (paginated)

    Each investor-fund combination is returned as a separate row. An investor
    with 3 funds will appear as 3 entries; an investor without funds appears
    once with null fund fields. Pagination and `total` count these rows, so a
    page never holds more than `page_size` items.

    If project_id is provided, calculates compatibility scores for each investor.
    """
    project = _load_project(db, project_id)
    score_for = _compatibility_scorer(project)

    # LEFT OUTER JOIN: investors without funds still produce one (investor, None) row.
    rows_query = (
        db.query(InvestorTable, FundTable)
        .outerjoin(InvestorTable.funds)
        .options(
            selectinload(InvestorTable.portfolio_companies),
            selectinload(InvestorTable.team_members),
            selectinload(InvestorTable.sectors),
            # Fund details on investor.funds are eager-loaded as in the original
            # handler, presumably for the scorer (use_funds=True).
            selectinload(InvestorTable.funds).selectinload(FundTable.investment_stages),
            selectinload(InvestorTable.funds).selectinload(FundTable.sectors),
            selectinload(FundTable.investment_stages),
            selectinload(FundTable.sectors),
        )
        # A deterministic order is required for OFFSET/LIMIT pages to be stable.
        .order_by(InvestorTable.id, FundTable.id)
    )

    total_count = rows_query.count()
    offset = (page - 1) * page_size
    rows = rows_query.offset(offset).limit(page_size).all()

    items = [
        _build_investment_response(investor, fund, score_for(investor))
        for investor, fund in rows
    ]
    return _paginated(items, total_count, page, page_size)


@router.get("/investors/filter", response_model=PaginatedResponse[InvestmentResponse])
def filter_investors(
    stage: Optional[InvestmentStage] = Query(
        None, description="Filter by investment stage"
    ),
    min_check_size: Optional[int] = Query(None, description="Minimum check size"),
    max_check_size: Optional[int] = Query(None, description="Maximum check size"),
    geography: Optional[str] = Query(
        None, description="Geographic focus (partial match)"
    ),
    sector: Optional[str] = Query(None, description="Sector name (partial match)"),
    min_aum: Optional[int] = Query(None, description="Minimum AUM"),
    max_aum: Optional[int] = Query(None, description="Maximum AUM"),
    page: int = Query(1, ge=1, description="Page number (starts at 1)"),
    page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
    project_id: Optional[int] = Query(
        None, description="Optional project ID for compatibility scoring"
    ),
    db: Session = Depends(get_db),
):
    """Filter investors based on various criteria (paginated)

    Returns investor-fund combinations as separate rows.
    Queries the funds table to find matching funds.
    If project_id is provided, calculates compatibility scores for each investor.
    """
    project = _load_project(db, project_id)
    score_for = _compatibility_scorer(project)

    query = db.query(FundTable).options(
        selectinload(FundTable.investor).selectinload(
            InvestorTable.portfolio_companies
        ),
        selectinload(FundTable.investor).selectinload(InvestorTable.team_members),
        selectinload(FundTable.investor).selectinload(InvestorTable.sectors),
        selectinload(FundTable.investment_stages),
        selectinload(FundTable.sectors),
    )

    # Fund-level filters.
    if stage is not None:
        # The stage model is not imported in this module; resolve it from the relationship.
        stage_model = FundTable.investment_stages.property.mapper.class_
        # NOTE(review): assumes stage rows store the enum's value in `name`
        # (matched case-insensitively); confirm against the seed data.
        stage_name = str(getattr(stage, "value", stage))
        # EXISTS subquery: one row per fund no matter how many stages match.
        query = query.filter(
            FundTable.investment_stages.any(stage_model.name.ilike(stage_name))
        )
    if min_check_size is not None:
        query = query.filter(FundTable.check_size_lower >= min_check_size)
    if max_check_size is not None:
        query = query.filter(FundTable.check_size_upper <= max_check_size)
    if geography:
        query = query.filter(FundTable.geographic_focus.ilike(f"%{geography}%"))

    # Investor-level filters: join the (many-to-one) investor exactly once.
    if min_aum is not None or max_aum is not None:
        query = query.join(FundTable.investor)
        if min_aum is not None:
            query = query.filter(InvestorTable.aum >= min_aum)
        if max_aum is not None:
            query = query.filter(InvestorTable.aum <= max_aum)

    # Sector filter as EXISTS rather than a join, so a fund with several
    # matching sectors is not returned (and counted) more than once.
    if sector:
        query = query.filter(
            FundTable.sectors.any(SectorTable.name.ilike(f"%{sector}%"))
        )

    total_count = query.count()
    offset = (page - 1) * page_size
    funds = query.order_by(FundTable.id).offset(offset).limit(page_size).all()

    items = [
        _build_investment_response(fund.investor, fund, score_for(fund.investor))
        for fund in funds
    ]
    return _paginated(items, total_count, page, page_size)


@router.get("/investors/{investor_id}", response_model=InvestorData)
def read_investor(investor_id: int, db: Session = Depends(get_db)):
    """Get a specific investor by ID with all their funds"""
    investor = _load_investor_with_relations(db, investor_id)
    if not investor:
        raise HTTPException(status_code=404, detail="Investor not found")
    return _to_investor_data(investor)


@router.post("/investors", response_model=InvestorData)
def create_investor(investor: InvestorCreate, db: Session = Depends(get_db)):
    """Create a new investor"""
    db_investor = InvestorTable(**investor.dict())
    db.add(db_investor)
    db.commit()
    db.refresh(db_investor)

    # Reload with relationships eagerly populated for the response.
    return _to_investor_data(_load_investor_with_relations(db, db_investor.id))


@router.put("/investors/{investor_id}", response_model=InvestorData)
def update_investor(
    investor_id: int, investor: InvestorUpdate, db: Session = Depends(get_db)
):
    """Update an existing investor"""
    db_investor = (
        db.query(InvestorTable).filter(InvestorTable.id == investor_id).first()
    )
    if not db_investor:
        raise HTTPException(status_code=404, detail="Investor not found")

    # Apply only the fields the client actually sent (partial update).
    for field, value in investor.dict(exclude_unset=True).items():
        setattr(db_investor, field, value)

    db.commit()
    db.refresh(db_investor)

    # Reload with relationships eagerly populated for the response.
    return _to_investor_data(_load_investor_with_relations(db, investor_id))


@router.delete("/investors/{investor_id}")
def delete_investor(investor_id: int, db: Session = Depends(get_db)):
    """Delete an investor"""
    db_investor = (
        db.query(InvestorTable).filter(InvestorTable.id == investor_id).first()
    )
    if not db_investor:
        raise HTTPException(status_code=404, detail="Investor not found")

    db.delete(db_investor)
    db.commit()
    return {"message": "Investor deleted successfully"}


@router.get(
    "/investors/{investor_id}/similar",
    response_model=PaginatedResponse[InvestmentResponse],
)
def find_similar_investors(
    investor_id: int,
    limit: int = Query(
        10, ge=0, description="Maximum number of similar investors to return"
    ),
    page: int = Query(1, ge=1, description="Page number (starts at 1)"),
    page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
    db: Session = Depends(get_db),
):
    """Find investors similar to a given investor based on characteristics (paginated)

    Returns investor-fund combinations as separate rows.
    Queries the funds table to find matching funds.
    """
    target_investor = (
        db.query(InvestorTable)
        .options(
            selectinload(InvestorTable.portfolio_companies),
            selectinload(InvestorTable.team_members),
            selectinload(InvestorTable.sectors),
            selectinload(InvestorTable.funds).selectinload(FundTable.investment_stages),
            selectinload(InvestorTable.funds).selectinload(FundTable.sectors),
        )
        .filter(InvestorTable.id == investor_id)
        .first()
    )
    if not target_investor:
        raise HTTPException(status_code=404, detail="Investor not found")

    # Aggregate the target's profile across all of its funds.
    target_sector_ids = set()
    target_stage_ids = set()
    target_check_ranges = []
    target_geographies = []
    for fund in target_investor.funds:
        target_sector_ids.update(sector.id for sector in fund.sectors or [])
        target_stage_ids.update(stage.id for stage in fund.investment_stages or [])
        if fund.check_size_lower is not None and fund.check_size_upper is not None:
            target_check_ranges.append((fund.check_size_lower, fund.check_size_upper))
        if fund.geographic_focus:
            target_geographies.append(fund.geographic_focus.lower())

    # All funds belonging to other investors; ordered so that score ties are
    # broken deterministically by the stable sort below.
    candidate_funds = (
        db.query(FundTable)
        .options(
            selectinload(FundTable.investor).selectinload(
                InvestorTable.portfolio_companies
            ),
            selectinload(FundTable.investor).selectinload(InvestorTable.team_members),
            selectinload(FundTable.investor).selectinload(InvestorTable.sectors),
            selectinload(FundTable.investment_stages),
            selectinload(FundTable.sectors),
        )
        .join(FundTable.investor)
        .filter(InvestorTable.id != investor_id)
        .order_by(FundTable.id)
        .all()
    )

    scored_funds = []
    for fund in candidate_funds:
        score = _similarity_score(
            fund,
            target_investor.aum,
            target_sector_ids,
            target_stage_ids,
            target_check_ranges,
            target_geographies,
        )
        if score > 0:  # Only include funds with some similarity
            scored_funds.append((score, fund))

    # Highest score first, capped at `limit`, then paginated within that cap.
    scored_funds.sort(key=lambda pair: pair[0], reverse=True)
    top_similar = scored_funds[:limit]
    total_count = len(top_similar)
    offset = (page - 1) * page_size
    page_funds = [fund for _, fund in top_similar[offset : offset + page_size]]

    items = [_build_investment_response(fund.investor, fund, 1.0) for fund in page_funds]
    return _paginated(items, total_count, page, page_size)