# Anton_wireframe/app/routers/companies.py
# Company CRUD and filtering routes (FastAPI).
from typing import List, Optional
from db.db import get_db
from db.models import CompanyTable, InvestorTable
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
# 2025-09-25 17:00:38 +01:00
from schemas.router_schemas import CompanyData
from sqlalchemy.orm import Session, selectinload
# Router holding every company endpoint defined in this module; all of its
# operations carry the "Company Routes" tag (FastAPI uses tags to group
# operations in the generated OpenAPI docs).
router = APIRouter(
    tags=["Company Routes"],
)
# Request schemas for creating/updating companies.
class CompanyCreate(BaseModel):
    # Payload for POST /companies. The endpoint expands it directly into
    # CompanyTable(**payload), so each field name must match a column name.

    # Required fields.
    name: str
    industry: str
    location: str

    # Optional fields; omitted values default to None.
    description: Optional[str] = None
    founded_year: Optional[int] = None
    website: Optional[str] = None
class CompanyUpdate(BaseModel):
    # Payload for PUT /companies/{company_id}, used as a partial update.
    # Every field is optional. The endpoint applies only the fields the client
    # actually sent (``exclude_unset=True``), so an omitted field is left
    # untouched, while an explicit null is written through as None.
    name: Optional[str] = None
    industry: Optional[str] = None
    location: Optional[str] = None
    description: Optional[str] = None
    founded_year: Optional[int] = None
    website: Optional[str] = None
@router.get("/companies", response_model=List[CompanyData])
def read_companies(db: Session = Depends(get_db)):
    """Get all companies that have both a name and a description.

    Each company is returned together with its investors, members and
    sectors. Those relationships are eager-loaded with ``selectinload``, so
    building the response does not lazy-load them one company at a time.
    """
    companies = (
        db.query(CompanyTable)
        # Incomplete records (no name or no description) are left out of
        # this listing.
        .filter(
            CompanyTable.name.isnot(None),
            CompanyTable.description.isnot(None),
        )
        .options(
            selectinload(CompanyTable.investors),
            selectinload(CompanyTable.members),
            selectinload(CompanyTable.sectors),
        )
        .all()
    )
    # Wrap each ORM row together with its related collections.
    return [
        CompanyData(
            company=company,
            investors=company.investors,
            members=company.members,
            sectors=company.sectors,
        )
        for company in companies
    ]
def _ilike_contains_pattern(term: str) -> str:
    """Return a LIKE pattern matching ``term`` anywhere, with ``term`` taken literally.

    ``%`` and ``_`` (LIKE wildcards) and ``/`` (the escape character used
    here) are escaped, so user input cannot inject wildcards. The result must
    be used together with ``.ilike(pattern, escape="/")``.
    """
    # Escape the escape character first so later escapes are not doubled.
    escaped = term.replace("/", "//").replace("%", "/%").replace("_", "/_")
    return f"%{escaped}%"


# NOTE: this route must stay registered before "/companies/{company_id}".
# FastAPI matches routes in declaration order, so otherwise "filter" would be
# captured as the company_id path parameter.
@router.get("/companies/filter", response_model=List[CompanyData])
def filter_companies(
    industry: Optional[str] = Query(
        None, description="Filter by industry (partial match)"
    ),
    location: Optional[str] = Query(
        None, description="Filter by location (partial match)"
    ),
    founded_after: Optional[int] = Query(
        None, description="Founded in or after this year (inclusive)"
    ),
    founded_before: Optional[int] = Query(
        None, description="Founded in or before this year (inclusive)"
    ),
    has_website: Optional[bool] = Query(
        None, description="Filter companies with/without website"
    ),
    investor_name: Optional[str] = Query(
        None, description="Filter by investor name (partial match)"
    ),
    db: Session = Depends(get_db),
):
    """Filter companies by any combination of the optional criteria.

    Text filters (industry, location, investor_name) are case-insensitive
    substring matches, and ``%``/``_`` in the input are matched literally.
    Year bounds are inclusive. Filters that are omitted (or, for the text
    filters, empty) are not applied.
    """
    # NOTE(review): unlike GET /companies, this endpoint does not exclude
    # companies with a missing name/description; confirm that is intended.
    query = db.query(CompanyTable).options(
        selectinload(CompanyTable.investors),
        selectinload(CompanyTable.members),
        selectinload(CompanyTable.sectors),
    )

    if industry:
        query = query.filter(
            CompanyTable.industry.ilike(
                _ilike_contains_pattern(industry), escape="/"
            )
        )
    if location:
        query = query.filter(
            CompanyTable.location.ilike(
                _ilike_contains_pattern(location), escape="/"
            )
        )
    if founded_after is not None:
        query = query.filter(CompanyTable.founded_year >= founded_after)
    if founded_before is not None:
        query = query.filter(CompanyTable.founded_year <= founded_before)
    if has_website is not None:
        if has_website:
            query = query.filter(CompanyTable.website.isnot(None))
        else:
            query = query.filter(CompanyTable.website.is_(None))

    if investor_name:
        # Inner join: only companies with at least one matching investor are
        # kept. NOTE(review): a company with several matching investors is
        # presumably collapsed to one result by the legacy Query's entity
        # de-duplication; verify if this is migrated to select().
        query = query.join(CompanyTable.investors).filter(
            InvestorTable.name.ilike(
                _ilike_contains_pattern(investor_name), escape="/"
            )
        )

    # Wrap each ORM row together with its related collections.
    return [
        CompanyData(
            company=company,
            investors=company.investors,
            members=company.members,
            sectors=company.sectors,
        )
        for company in query.all()
    ]
@router.get("/companies/{company_id}", response_model=CompanyData)
def read_company(company_id: int, db: Session = Depends(get_db)):
    """Get one company by ID together with its investors, members and sectors.

    Raises:
        HTTPException: 404 if no company has the given ID.
    """
    company = (
        db.query(CompanyTable)
        .options(
            selectinload(CompanyTable.investors),
            selectinload(CompanyTable.members),
            selectinload(CompanyTable.sectors),
        )
        .filter(CompanyTable.id == company_id)
        .first()
    )
    if company is None:
        raise HTTPException(status_code=404, detail="Company not found")

    return CompanyData(
        company=company,
        investors=company.investors,
        members=company.members,
        sectors=company.sectors,
    )
@router.post("/companies", response_model=CompanyData)
def create_company(company: CompanyCreate, db: Session = Depends(get_db)):
    """Create a new company and return it with its investors, members and sectors."""
    # Payload field names map 1:1 onto CompanyTable columns.
    db_company = CompanyTable(**company.dict())
    db.add(db_company)
    db.commit()
    # Reload the committed row so database-generated values (presumably the
    # autoincrement id) are available on the instance.
    db.refresh(db_company)

    # Re-query with eager loading so the related collections are populated
    # for the response.
    company_with_relations = (
        db.query(CompanyTable)
        .options(
            selectinload(CompanyTable.investors),
            selectinload(CompanyTable.members),
            selectinload(CompanyTable.sectors),
        )
        .filter(CompanyTable.id == db_company.id)
        .first()
    )

    return CompanyData(
        company=company_with_relations,
        investors=company_with_relations.investors,
        members=company_with_relations.members,
        sectors=company_with_relations.sectors,
    )
@router.put("/companies/{company_id}", response_model=CompanyData)
def update_company(
    company_id: int, company: CompanyUpdate, db: Session = Depends(get_db)
):
    """Partially update an existing company.

    Only the fields present in the request body are written. The updated
    company is returned with its investors, members and sectors.

    Raises:
        HTTPException: 404 if no company has the given ID.
    """
    db_company = db.query(CompanyTable).filter(CompanyTable.id == company_id).first()
    if db_company is None:
        raise HTTPException(status_code=404, detail="Company not found")

    # exclude_unset keeps only fields the client sent. An explicit null is
    # kept and written as None. NOTE(review): if name/industry/location are
    # NOT NULL columns, such a request fails at commit; consider rejecting it.
    update_data = company.dict(exclude_unset=True)
    for field, value in update_data.items():
        setattr(db_company, field, value)
    db.commit()

    # No separate db.refresh(): this eager-loading query reloads the
    # committed row into the (expired) instance and populates its
    # relationships in one step.
    company_with_relations = (
        db.query(CompanyTable)
        .options(
            selectinload(CompanyTable.investors),
            selectinload(CompanyTable.members),
            selectinload(CompanyTable.sectors),
        )
        .filter(CompanyTable.id == company_id)
        .first()
    )

    return CompanyData(
        company=company_with_relations,
        investors=company_with_relations.investors,
        members=company_with_relations.members,
        sectors=company_with_relations.sectors,
    )
@router.delete("/companies/{company_id}")
def delete_company(company_id: int, db: Session = Depends(get_db)):
    """Delete a company by ID and return a confirmation message.

    Raises:
        HTTPException: 404 if no company has the given ID.
    """
    db_company = db.query(CompanyTable).filter(CompanyTable.id == company_id).first()
    if db_company is None:
        raise HTTPException(status_code=404, detail="Company not found")

    # NOTE(review): what happens to linked investors/members/sectors depends
    # on the relationship cascade settings in db.models, which are not visible
    # here.
    db.delete(db_company)
    db.commit()
    return {"message": "Company deleted successfully"}