from typing import Optional

from db.db import get_db
from db.models import CompanyTable, InvestorTable
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
from schemas.router_schemas import CompanyData, PaginatedResponse
from sqlalchemy.orm import Session, selectinload

router = APIRouter(tags=["Company Routes"])


# Request schemas for creating/updating


# Payload accepted by POST /companies. name, industry and location are
# required; the remaining fields default to None when omitted.
class CompanyCreate(BaseModel):
    name: str
    industry: str
    location: str
    description: Optional[str] = None
    founded_year: Optional[int] = None
    website: Optional[str] = None


# Payload accepted by PUT /companies/{company_id}. Every field is optional;
# only fields explicitly present in the request are applied (see
# update_company, which uses exclude_unset=True).
class CompanyUpdate(BaseModel):
    name: Optional[str] = None
    industry: Optional[str] = None
    location: Optional[str] = None
    description: Optional[str] = None
    founded_year: Optional[int] = None
    website: Optional[str] = None


def _relationship_loaders():
    """Return the eager-load options shared by every company query.

    Built on each call so every query receives its own option objects.
    """
    return (
        selectinload(CompanyTable.investors),
        selectinload(CompanyTable.members),
        selectinload(CompanyTable.sectors),
    )


def _to_company_data(company: CompanyTable) -> CompanyData:
    """Wrap a loaded company row in the ``CompanyData`` response schema.

    Sectors are sorted alphabetically by ``name``; a missing or empty sector
    collection becomes an empty list.
    """
    return CompanyData(
        company=company,
        investors=company.investors,
        members=company.members,
        sectors=sorted(company.sectors or [], key=lambda sector: sector.name),
    )


def _load_company(db: Session, company_id: int) -> Optional[CompanyTable]:
    """Fetch one company by primary key with its relationships eagerly loaded.

    Returns ``None`` when no row has the given id.
    """
    return (
        db.query(CompanyTable)
        .options(*_relationship_loaders())
        .filter(CompanyTable.id == company_id)
        .first()
    )


def _paginate(query, page: int, page_size: int) -> PaginatedResponse:
    """Run ``query`` for a single page and build the paginated response.

    Args:
        query: A ``CompanyTable`` query with all filters already applied but
            without ordering, eager-loading, offset or limit.
        page: 1-based page number.
        page_size: Number of items per page (must be >= 1).

    Returns:
        A ``PaginatedResponse`` carrying the page items and the totals.
    """
    # Count before ordering/eager-loading is attached; neither affects the
    # number of rows and the count statement stays minimal.
    total_count = query.count()

    # OFFSET/LIMIT without ORDER BY lets the database return rows in any
    # order, so pages could overlap or skip rows between requests. Ordering
    # by primary key makes pagination stable.
    companies = (
        query.options(*_relationship_loaders())
        .order_by(CompanyTable.id)
        .offset((page - 1) * page_size)
        .limit(page_size)
        .all()
    )

    return PaginatedResponse(
        items=[_to_company_data(company) for company in companies],
        total=total_count,
        page=page,
        page_size=page_size,
        # Ceiling division: a partially filled last page still counts.
        total_pages=(total_count + page_size - 1) // page_size,
    )


@router.get("/companies", response_model=PaginatedResponse[CompanyData])
def read_companies(
    page: int = Query(1, ge=1, description="Page number (starts at 1)"),
    page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
    db: Session = Depends(get_db),
):
    """Get all companies with their investor relationships (paginated)"""
    # Only companies that have both a name and a description are listed.
    query = db.query(CompanyTable).filter(
        CompanyTable.name.isnot(None),
        CompanyTable.description.isnot(None),
    )
    return _paginate(query, page, page_size)


@router.get("/companies/filter", response_model=PaginatedResponse[CompanyData])
def filter_companies(
    industry: Optional[str] = Query(
        None, description="Filter by industry (partial match)"
    ),
    location: Optional[str] = Query(
        None, description="Filter by location (partial match)"
    ),
    founded_after: Optional[int] = Query(None, description="Founded after year"),
    founded_before: Optional[int] = Query(None, description="Founded before year"),
    has_website: Optional[bool] = Query(
        None, description="Filter companies with/without website"
    ),
    investor_name: Optional[str] = Query(
        None, description="Filter by investor name (partial match)"
    ),
    page: int = Query(1, ge=1, description="Page number (starts at 1)"),
    page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
    db: Session = Depends(get_db),
):
    """Filter companies based on various criteria (paginated)"""
    query = db.query(CompanyTable)

    # Text filters are case-insensitive substring matches; empty strings are
    # treated the same as "not provided".
    if industry:
        query = query.filter(CompanyTable.industry.ilike(f"%{industry}%"))
    if location:
        query = query.filter(CompanyTable.location.ilike(f"%{location}%"))

    # Year bounds are inclusive on both ends.
    if founded_after is not None:
        query = query.filter(CompanyTable.founded_year >= founded_after)
    if founded_before is not None:
        query = query.filter(CompanyTable.founded_year <= founded_before)

    if has_website is not None:
        if has_website:
            query = query.filter(CompanyTable.website.isnot(None))
        else:
            query = query.filter(CompanyTable.website.is_(None))

    if investor_name:
        # Use EXISTS rather than a JOIN: a join produces one row per matching
        # investor, which inflates the total count and lets LIMIT be consumed
        # by duplicate rows of the same company.
        query = query.filter(
            CompanyTable.investors.any(
                InvestorTable.name.ilike(f"%{investor_name}%")
            )
        )

    return _paginate(query, page, page_size)


@router.get("/companies/{company_id}", response_model=CompanyData)
def read_company(company_id: int, db: Session = Depends(get_db)):
    """Get a specific company by ID with its investors"""
    company = _load_company(db, company_id)
    if not company:
        raise HTTPException(status_code=404, detail="Company not found")
    return _to_company_data(company)


@router.post("/companies", response_model=CompanyData)
def create_company(company: CompanyCreate, db: Session = Depends(get_db)):
    """Create a new company"""
    db_company = CompanyTable(**company.dict())
    db.add(db_company)
    db.commit()
    db.refresh(db_company)

    # Reload with relationships so the response has the same shape (including
    # alphabetically sorted sectors) as every other endpoint.
    return _to_company_data(_load_company(db, db_company.id))


@router.put("/companies/{company_id}", response_model=CompanyData)
def update_company(
    company_id: int, company: CompanyUpdate, db: Session = Depends(get_db)
):
    """Update an existing company"""
    db_company = db.query(CompanyTable).filter(CompanyTable.id == company_id).first()
    if not db_company:
        raise HTTPException(status_code=404, detail="Company not found")

    # exclude_unset: only fields the client actually sent are written, so an
    # omitted field keeps its stored value while an explicit null clears it.
    for field, value in company.dict(exclude_unset=True).items():
        setattr(db_company, field, value)

    db.commit()
    db.refresh(db_company)

    # Reload with relationships for the response.
    return _to_company_data(_load_company(db, company_id))


@router.delete("/companies/{company_id}")
def delete_company(company_id: int, db: Session = Depends(get_db)):
    """Delete a company"""
    db_company = db.query(CompanyTable).filter(CompanyTable.id == company_id).first()
    if not db_company:
        raise HTTPException(status_code=404, detail="Company not found")

    db.delete(db_company)
    db.commit()
    return {"message": "Company deleted successfully"}