"""Project routes: CRUD, archiving, filtering and association management."""

from typing import Any, Dict, List, Optional

from db.db import get_db
from db.models import (
    CompanyTable,
    InvestorTable,
    ProjectTable,
    SectorTable,
)
from fastapi import APIRouter, Depends, HTTPException, Query
from schemas.project_schemas import (
    InvestmentStage,
    ProjectCreate,
    ProjectData,
    ProjectUpdate,
)
from schemas.router_schemas import PaginatedResponse
from sqlalchemy.orm import Session, selectinload

router = APIRouter(tags=["Project Routes"])


# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------


def _with_relations(query):
    """Return *query* with sector, investors and companies eager-loaded.

    Eager loading means building ``ProjectData`` for each project does not
    trigger a separate lazy load per project per relationship.
    """
    return query.options(
        selectinload(ProjectTable.sector),
        selectinload(ProjectTable.investors),
        selectinload(ProjectTable.companies),
    )


def _to_project_data(project: ProjectTable) -> ProjectData:
    """Wrap an ORM project and its related rows in the ``ProjectData`` schema."""
    return ProjectData(
        project=project,
        sector=project.sector,
        investors=project.investors,
        companies=project.companies,
    )


def _load_project(db: Session, project_id: int) -> Optional[ProjectTable]:
    """Fetch one project with its relationships eager-loaded, or ``None``."""
    return (
        _with_relations(db.query(ProjectTable))
        .filter(ProjectTable.id == project_id)
        .first()
    )


def _get_or_404(db: Session, model: Any, obj_id: int, label: str) -> Any:
    """Return the first ``model`` row whose ``id`` equals ``obj_id``.

    Raises:
        HTTPException: 404 with detail ``"<label> not found"`` if no row exists.
    """
    obj = db.query(model).filter(model.id == obj_id).first()
    if obj is None:
        raise HTTPException(status_code=404, detail=f"{label} not found")
    return obj


def _paginate(query, page: int, page_size: int) -> PaginatedResponse:
    """Execute *query* for one page of projects and wrap it in ``PaginatedResponse``.

    Rows are ordered by ``ProjectTable.id``. Without an ORDER BY, the database
    may return OFFSET/LIMIT windows in any order, so items could repeat or be
    skipped across pages.
    """
    total_count = query.count()
    projects = (
        _with_relations(query)
        .order_by(ProjectTable.id)
        .offset((page - 1) * page_size)
        .limit(page_size)
        .all()
    )
    return PaginatedResponse(
        items=[_to_project_data(project) for project in projects],
        total=total_count,
        page=page,
        page_size=page_size,
        # Ceiling division: a partially filled last page still counts.
        total_pages=(total_count + page_size - 1) // page_size,
    )


def _add_association(
    db: Session, project_id: int, model: Any, item_id: int, attr: str, label: str
) -> Dict[str, str]:
    """Append the ``model`` row ``item_id`` to ``project.<attr>`` and commit.

    Raises:
        HTTPException: 404 if the project or the item does not exist,
            400 if the item is already in the collection.
    """
    project = _get_or_404(db, ProjectTable, project_id, "Project")
    item = _get_or_404(db, model, item_id, label)
    collection = getattr(project, attr)
    if item in collection:
        raise HTTPException(
            status_code=400, detail=f"{label} already associated with project"
        )
    collection.append(item)
    db.commit()
    return {"message": f"{label} added to project successfully"}


def _remove_association(
    db: Session, project_id: int, model: Any, item_id: int, attr: str, label: str
) -> Dict[str, str]:
    """Remove the ``model`` row ``item_id`` from ``project.<attr>`` and commit.

    Raises:
        HTTPException: 404 if the project or the item does not exist,
            400 if the item is not in the collection.
    """
    project = _get_or_404(db, ProjectTable, project_id, "Project")
    item = _get_or_404(db, model, item_id, label)
    collection = getattr(project, attr)
    if item not in collection:
        raise HTTPException(
            status_code=400, detail=f"{label} not associated with project"
        )
    collection.remove(item)
    db.commit()
    return {"message": f"{label} removed from project successfully"}


def _add_many_associations(
    db: Session,
    project_id: int,
    model: Any,
    item_ids: List[int],
    attr: str,
    plural: str,
) -> Dict[str, str]:
    """Append every ``model`` row in ``item_ids`` to ``project.<attr>``.

    Rows already in the collection are skipped. Changes are committed once.

    Raises:
        HTTPException: 404 if the project or any requested row does not exist.
    """
    project = _get_or_404(db, ProjectTable, project_id, "Project")
    # De-duplicate first: a repeated ID must not make the existence check
    # below fail when every requested row actually exists.
    unique_ids = set(item_ids)
    items = (
        db.query(model).filter(model.id.in_(sorted(unique_ids))).all()
        if unique_ids
        else []
    )
    if len(items) != len(unique_ids):
        raise HTTPException(status_code=404, detail=f"One or more {plural} not found")

    collection = getattr(project, attr)
    added_count = 0
    for item in items:
        if item not in collection:
            collection.append(item)
            added_count += 1
    db.commit()
    return {"message": f"Added {added_count} {plural} to project successfully"}


# ---------------------------------------------------------------------------
# Listing routes
#
# NOTE: the fixed paths /projects/archived and /projects/filter MUST be
# registered before /projects/{project_id}. Routes match in declaration
# order, so otherwise "archived"/"filter" is captured as project_id and
# rejected with a 422 validation error.
# ---------------------------------------------------------------------------


@router.get("/projects", response_model=PaginatedResponse[ProjectData])
def read_projects(
    page: int = Query(1, ge=1, description="Page number (starts at 1)"),
    page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
    include_archived: bool = Query(False, description="Include archived projects"),
    db: Session = Depends(get_db),
):
    """Get all projects with their related data (paginated)

    By default, archived projects are excluded. Set include_archived=True to include them.
    """
    query = db.query(ProjectTable)
    if not include_archived:
        query = query.filter(ProjectTable.is_archived == 0)
    return _paginate(query, page, page_size)


@router.get("/projects/archived", response_model=PaginatedResponse[ProjectData])
def read_archived_projects(
    page: int = Query(1, ge=1, description="Page number (starts at 1)"),
    page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
    db: Session = Depends(get_db),
):
    """Get all archived projects (paginated)"""
    query = db.query(ProjectTable).filter(ProjectTable.is_archived == 1)
    return _paginate(query, page, page_size)


@router.get("/projects/filter", response_model=PaginatedResponse[ProjectData])
def filter_projects(
    stage: Optional[InvestmentStage] = Query(
        None, description="Filter by project stage"
    ),
    min_valuation: Optional[int] = Query(None, description="Minimum valuation"),
    max_valuation: Optional[int] = Query(None, description="Maximum valuation"),
    location: Optional[str] = Query(None, description="Location (partial match)"),
    industry: Optional[str] = Query(None, description="Industry (partial match)"),
    sector: Optional[str] = Query(None, description="Sector name (partial match)"),
    investor_name: Optional[str] = Query(
        None, description="Investor name (partial match)"
    ),
    company_name: Optional[str] = Query(
        None, description="Company name (partial match)"
    ),
    page: int = Query(1, ge=1, description="Page number (starts at 1)"),
    page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
    db: Session = Depends(get_db),
):
    """Filter projects based on various criteria (paginated)

    All supplied criteria are combined with AND. Text criteria are
    case-insensitive substring matches. Archived projects are not excluded.
    """
    query = db.query(ProjectTable)

    if stage is not None:
        query = query.filter(ProjectTable.stage == stage)
    if min_valuation is not None:
        query = query.filter(ProjectTable.valuation >= min_valuation)
    if max_valuation is not None:
        query = query.filter(ProjectTable.valuation <= max_valuation)
    if location:
        query = query.filter(ProjectTable.location.ilike(f"%{location}%"))
    if industry:
        query = query.filter(ProjectTable.industry.ilike(f"%{industry}%"))

    # Relationship filters use EXISTS (``.any``) rather than JOIN. A JOIN on a
    # to-many relationship yields one row per matching related row, which
    # inflates the total count and repeats projects within a page.
    if sector:
        query = query.filter(
            ProjectTable.sector.any(SectorTable.name.ilike(f"%{sector}%"))
        )
    if investor_name:
        query = query.filter(
            ProjectTable.investors.any(InvestorTable.name.ilike(f"%{investor_name}%"))
        )
    if company_name:
        query = query.filter(
            ProjectTable.companies.any(CompanyTable.name.ilike(f"%{company_name}%"))
        )

    return _paginate(query, page, page_size)


# ---------------------------------------------------------------------------
# Single-project CRUD
# ---------------------------------------------------------------------------


@router.get("/projects/{project_id}", response_model=ProjectData)
def read_project(project_id: int, db: Session = Depends(get_db)):
    """Get a specific project by ID"""
    project = _load_project(db, project_id)
    if project is None:
        raise HTTPException(status_code=404, detail="Project not found")
    return _to_project_data(project)


@router.post("/projects", response_model=ProjectData)
def create_project(project: ProjectCreate, db: Session = Depends(get_db)):
    """Create a new project and return it with its related data."""
    db_project = ProjectTable(**project.dict())
    db.add(db_project)
    db.commit()
    db.refresh(db_project)
    # Re-query so the relationships are eager-loaded for the response.
    return _to_project_data(_load_project(db, db_project.id))


@router.put("/projects/{project_id}", response_model=ProjectData)
def update_project(
    project_id: int, project: ProjectUpdate, db: Session = Depends(get_db)
):
    """Update an existing project

    Only fields explicitly present in the request body are written.
    """
    db_project = _get_or_404(db, ProjectTable, project_id, "Project")
    for key, value in project.dict(exclude_unset=True).items():
        setattr(db_project, key, value)
    db.commit()
    # Re-query so the relationships are eager-loaded for the response.
    return _to_project_data(_load_project(db, project_id))


@router.delete("/projects/{project_id}")
def delete_project(project_id: int, db: Session = Depends(get_db)):
    """Delete a project permanently"""
    db_project = _get_or_404(db, ProjectTable, project_id, "Project")
    db.delete(db_project)
    db.commit()
    return {"message": "Project deleted successfully"}


@router.post("/projects/{project_id}/archive")
def archive_project(project_id: int, db: Session = Depends(get_db)):
    """Archive a project (soft delete) by setting ``is_archived`` to 1."""
    db_project = _get_or_404(db, ProjectTable, project_id, "Project")
    db_project.is_archived = 1
    db.commit()
    return {"message": "Project archived successfully", "project_id": project_id}


@router.post("/projects/{project_id}/unarchive")
def unarchive_project(project_id: int, db: Session = Depends(get_db)):
    """Unarchive a project (restore from archive) by setting ``is_archived`` to 0."""
    db_project = _get_or_404(db, ProjectTable, project_id, "Project")
    db_project.is_archived = 0
    db.commit()
    return {"message": "Project unarchived successfully", "project_id": project_id}


# ---------------------------------------------------------------------------
# Association management routes
# ---------------------------------------------------------------------------


@router.post("/projects/{project_id}/investors/{investor_id}")
def add_investor_to_project(
    project_id: int, investor_id: int, db: Session = Depends(get_db)
):
    """Add an investor to a project"""
    return _add_association(
        db, project_id, InvestorTable, investor_id, "investors", "Investor"
    )


@router.delete("/projects/{project_id}/investors/{investor_id}")
def remove_investor_from_project(
    project_id: int, investor_id: int, db: Session = Depends(get_db)
):
    """Remove an investor from a project"""
    return _remove_association(
        db, project_id, InvestorTable, investor_id, "investors", "Investor"
    )


@router.post("/projects/{project_id}/companies/{company_id}")
def add_company_to_project(
    project_id: int, company_id: int, db: Session = Depends(get_db)
):
    """Add a company to a project"""
    return _add_association(
        db, project_id, CompanyTable, company_id, "companies", "Company"
    )


@router.delete("/projects/{project_id}/companies/{company_id}")
def remove_company_from_project(
    project_id: int, company_id: int, db: Session = Depends(get_db)
):
    """Remove a company from a project"""
    return _remove_association(
        db, project_id, CompanyTable, company_id, "companies", "Company"
    )


@router.post("/projects/{project_id}/sectors/{sector_id}")
def add_sector_to_project(
    project_id: int, sector_id: int, db: Session = Depends(get_db)
):
    """Add a sector to a project"""
    return _add_association(db, project_id, SectorTable, sector_id, "sector", "Sector")


@router.delete("/projects/{project_id}/sectors/{sector_id}")
def remove_sector_from_project(
    project_id: int, sector_id: int, db: Session = Depends(get_db)
):
    """Remove a sector from a project"""
    return _remove_association(
        db, project_id, SectorTable, sector_id, "sector", "Sector"
    )


# ---------------------------------------------------------------------------
# Bulk association management
# ---------------------------------------------------------------------------


@router.post("/projects/{project_id}/investors")
def add_multiple_investors_to_project(
    project_id: int, investor_ids: List[int], db: Session = Depends(get_db)
):
    """Add multiple investors to a project

    Duplicate IDs are ignored. Investors already linked are skipped.
    """
    return _add_many_associations(
        db, project_id, InvestorTable, investor_ids, "investors", "investors"
    )


@router.post("/projects/{project_id}/companies")
def add_multiple_companies_to_project(
    project_id: int, company_ids: List[int], db: Session = Depends(get_db)
):
    """Add multiple companies to a project

    Duplicate IDs are ignored. Companies already linked are skipped.
    """
    return _add_many_associations(
        db, project_id, CompanyTable, company_ids, "companies", "companies"
    )