from typing import List, Optional

from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
from sqlalchemy.orm import Session, selectinload

from db.db import get_db
from db.models import InvestorTable, SectorTable
from schemas.router_schemas import InvestmentStage, InvestorData

router = APIRouter(tags=["Investor Routes"])


# Request schemas for creating/updating
class InvestorCreate(BaseModel):
    """Request body for creating an investor.

    Every field except ``description`` and ``number_of_investments`` is
    required; ``number_of_investments`` defaults to 0.
    """

    name: str
    description: Optional[str] = None
    aum: int
    check_size_lower: int
    check_size_upper: int
    geographic_focus: str
    stage_focus: InvestmentStage
    number_of_investments: int = 0


class InvestorUpdate(BaseModel):
    """Request body for a partial investor update.

    All fields are optional. ``update_investor`` applies only the fields the
    client actually sent (it dumps the model with ``exclude_unset=True``).
    """

    name: Optional[str] = None
    description: Optional[str] = None
    aum: Optional[int] = None
    check_size_lower: Optional[int] = None
    check_size_upper: Optional[int] = None
    geographic_focus: Optional[str] = None
    stage_focus: Optional[InvestmentStage] = None
    number_of_investments: Optional[int] = None


def _investor_query(db: Session):
    """Return a base ``InvestorTable`` query that eager-loads its relationships.

    Portfolio companies, team members and sectors are loaded with
    ``selectinload`` so that building the response does not trigger one lazy
    load per investor and relationship.
    """
    return db.query(InvestorTable).options(
        selectinload(InvestorTable.portfolio_companies),
        selectinload(InvestorTable.team_members),
        selectinload(InvestorTable.sectors),
    )


def _load_investor(db: Session, investor_id: int) -> Optional[InvestorTable]:
    """Fetch one investor by primary key with relationships loaded, or None."""
    return _investor_query(db).filter(InvestorTable.id == investor_id).first()


def _to_investor_data(investor: InvestorTable) -> InvestorData:
    """Wrap an ``InvestorTable`` row and its relationships in ``InvestorData``."""
    return InvestorData(
        investor=investor,
        portfolio_companies=investor.portfolio_companies,
        team_members=investor.team_members,
        sectors=investor.sectors,
    )


@router.get("/investors", response_model=List[InvestorData])
def read_investors(db: Session = Depends(get_db)):
    """Get all investors with their related data."""
    return [_to_investor_data(row) for row in _investor_query(db).all()]


# NOTE: this route must stay registered before "/investors/{investor_id}";
# FastAPI matches path operations in declaration order, so otherwise "filter"
# would be captured as an investor_id.
@router.get("/investors/filter", response_model=List[InvestorData])
def filter_investors(
    stage: Optional[InvestmentStage] = Query(
        None, description="Filter by investment stage"
    ),
    min_check_size: Optional[int] = Query(None, description="Minimum check size"),
    max_check_size: Optional[int] = Query(None, description="Maximum check size"),
    geography: Optional[str] = Query(
        None, description="Geographic focus (partial match)"
    ),
    sector: Optional[str] = Query(None, description="Sector name (partial match)"),
    min_aum: Optional[int] = Query(None, description="Minimum AUM"),
    max_aum: Optional[int] = Query(None, description="Maximum AUM"),
    db: Session = Depends(get_db),
):
    """Filter investors based on various criteria.

    Every supplied criterion is AND-ed onto the query; omitted criteria are
    ignored. ``geography`` and ``sector`` are case-insensitive substring
    matches, and empty strings are treated as "not supplied".

    NOTE(review): ``%`` and ``_`` inside ``geography`` / ``sector`` are passed
    through to ``ILIKE`` unescaped and therefore act as wildcards.
    """
    query = _investor_query(db)

    # Use explicit None checks so that falsy-but-valid values (e.g. 0) still
    # apply their filter.
    if stage is not None:
        query = query.filter(InvestorTable.stage_focus == stage)
    if min_check_size is not None:
        query = query.filter(InvestorTable.check_size_lower >= min_check_size)
    if max_check_size is not None:
        query = query.filter(InvestorTable.check_size_upper <= max_check_size)
    if geography:
        query = query.filter(InvestorTable.geographic_focus.ilike(f"%{geography}%"))
    if min_aum is not None:
        query = query.filter(InvestorTable.aum >= min_aum)
    if max_aum is not None:
        query = query.filter(InvestorTable.aum <= max_aum)

    # Sector lives on a related table, so it needs a join rather than a
    # plain column filter.
    if sector:
        query = query.join(InvestorTable.sectors).filter(
            SectorTable.name.ilike(f"%{sector}%")
        )

    return [_to_investor_data(row) for row in query.all()]


@router.get("/investors/{investor_id}", response_model=InvestorData)
def read_investor(investor_id: int, db: Session = Depends(get_db)):
    """Get a specific investor by ID.

    Raises:
        HTTPException: 404 if no investor has the given ID.
    """
    investor = _load_investor(db, investor_id)
    if not investor:
        raise HTTPException(status_code=404, detail="Investor not found")
    return _to_investor_data(investor)


@router.post("/investors", response_model=InvestorData)
def create_investor(investor: InvestorCreate, db: Session = Depends(get_db)):
    """Create a new investor and return it with its related data."""
    db_investor = InvestorTable(**investor.dict())
    db.add(db_investor)
    db.commit()
    # Refresh to pick up the database-generated primary key.
    db.refresh(db_investor)

    # Reload through the eager-loading query so relationships are populated.
    return _to_investor_data(_load_investor(db, db_investor.id))


@router.put("/investors/{investor_id}", response_model=InvestorData)
def update_investor(
    investor_id: int, investor: InvestorUpdate, db: Session = Depends(get_db)
):
    """Update an existing investor.

    Only fields present in the request body are written; a field explicitly
    sent as ``null`` is included and overwrites the stored value with None.

    Raises:
        HTTPException: 404 if no investor has the given ID.
    """
    db_investor = (
        db.query(InvestorTable).filter(InvestorTable.id == investor_id).first()
    )
    if not db_investor:
        raise HTTPException(status_code=404, detail="Investor not found")

    for field, value in investor.dict(exclude_unset=True).items():
        setattr(db_investor, field, value)

    db.commit()
    db.refresh(db_investor)

    # Reload through the eager-loading query so relationships are populated.
    return _to_investor_data(_load_investor(db, investor_id))


@router.delete("/investors/{investor_id}")
def delete_investor(investor_id: int, db: Session = Depends(get_db)):
    """Delete an investor.

    Returns:
        A dict with a confirmation ``message``.

    Raises:
        HTTPException: 404 if no investor has the given ID.
    """
    db_investor = (
        db.query(InvestorTable).filter(InvestorTable.id == investor_id).first()
    )
    if not db_investor:
        raise HTTPException(status_code=404, detail="Investor not found")

    db.delete(db_investor)
    db.commit()
    return {"message": "Investor deleted successfully"}