Refactor investor-related schemas and models; implement investor CRUD operations and update stage_focus values to uppercase

This commit is contained in:
bolade
2025-09-03 09:41:19 +01:00
parent 7b58834316
commit 84cbb888e6
9 changed files with 294 additions and 21 deletions
Binary file not shown.
Binary file not shown.
+230 -5
View File
@@ -1,8 +1,233 @@
from fastapi import APIRouter from typing import List, Optional
router = APIRouter() from db.db import get_db
from db.models import InvestorTable, SectorTable
from fastapi import APIRouter, Depends, HTTPException, Query
from py_schemas import InvestmentStage, InvestorData
from pydantic import BaseModel
from sqlalchemy.orm import Session, selectinload
@router.get("/investors") router = APIRouter(tags=["Investor Routes"])
def read_investors():
return {"message": "list of investors"}
# Request schemas for creating/updating
class InvestorCreate(BaseModel):
name: str
description: str = None
aum: int
check_size_lower: int
check_size_upper: int
geographic_focus: str
stage_focus: InvestmentStage
number_of_investments: int = 0
class InvestorUpdate(BaseModel):
name: str = None
description: str = None
aum: int = None
check_size_lower: int = None
check_size_upper: int = None
geographic_focus: str = None
stage_focus: InvestmentStage = None
number_of_investments: int = None
@router.get("/investors", response_model=List[InvestorData])
def read_investors(db: Session = Depends(get_db)):
"""Get all investors with their related data"""
investors = (
db.query(InvestorTable)
.options(
selectinload(InvestorTable.portfolio_companies),
selectinload(InvestorTable.team_members),
selectinload(InvestorTable.sectors),
)
.all()
)
# Transform InvestorTable objects to InvestorData format
investor_data_list = []
for investor in investors:
investor_data = InvestorData(
investor=investor, # This maps to InvestorSchema
portfolio_companies=investor.portfolio_companies,
team_members=investor.team_members,
sectors=investor.sectors,
)
investor_data_list.append(investor_data)
return investor_data_list
@router.get("/investors/filter", response_model=List[InvestorData])
def filter_investors(
stage: Optional[InvestmentStage] = Query(
None, description="Filter by investment stage"
),
min_check_size: Optional[int] = Query(None, description="Minimum check size"),
max_check_size: Optional[int] = Query(None, description="Maximum check size"),
geography: Optional[str] = Query(
None, description="Geographic focus (partial match)"
),
sector: Optional[str] = Query(None, description="Sector name (partial match)"),
min_aum: Optional[int] = Query(None, description="Minimum AUM"),
max_aum: Optional[int] = Query(None, description="Maximum AUM"),
db: Session = Depends(get_db),
):
"""Filter investors based on various criteria"""
# Start with base query
query = db.query(InvestorTable).options(
selectinload(InvestorTable.portfolio_companies),
selectinload(InvestorTable.team_members),
selectinload(InvestorTable.sectors),
)
# Apply filters
if stage:
query = query.filter(InvestorTable.stage_focus == stage)
if min_check_size is not None:
query = query.filter(InvestorTable.check_size_lower >= min_check_size)
if max_check_size is not None:
query = query.filter(InvestorTable.check_size_upper <= max_check_size)
if geography:
query = query.filter(InvestorTable.geographic_focus.ilike(f"%{geography}%"))
if min_aum is not None:
query = query.filter(InvestorTable.aum >= min_aum)
if max_aum is not None:
query = query.filter(InvestorTable.aum <= max_aum)
# Filter by sector if provided
if sector:
query = query.join(InvestorTable.sectors).filter(
SectorTable.name.ilike(f"%{sector}%")
)
investors = query.all()
# Transform to InvestorData format
investor_data_list = []
for investor in investors:
investor_data = InvestorData(
investor=investor,
portfolio_companies=investor.portfolio_companies,
team_members=investor.team_members,
sectors=investor.sectors,
)
investor_data_list.append(investor_data)
return investor_data_list
@router.get("/investors/{investor_id}", response_model=InvestorData)
def read_investor(investor_id: int, db: Session = Depends(get_db)):
"""Get a specific investor by ID"""
investor = (
db.query(InvestorTable)
.options(
selectinload(InvestorTable.portfolio_companies),
selectinload(InvestorTable.team_members),
selectinload(InvestorTable.sectors),
)
.filter(InvestorTable.id == investor_id)
.first()
)
if not investor:
raise HTTPException(status_code=404, detail="Investor not found")
# Transform to InvestorData format
return InvestorData(
investor=investor,
portfolio_companies=investor.portfolio_companies,
team_members=investor.team_members,
sectors=investor.sectors,
)
@router.post("/investors", response_model=InvestorData)
def create_investor(investor: InvestorCreate, db: Session = Depends(get_db)):
"""Create a new investor"""
db_investor = InvestorTable(**investor.dict())
db.add(db_investor)
db.commit()
db.refresh(db_investor)
# Reload with relationships
investor_with_relations = (
db.query(InvestorTable)
.options(
selectinload(InvestorTable.portfolio_companies),
selectinload(InvestorTable.team_members),
selectinload(InvestorTable.sectors),
)
.filter(InvestorTable.id == db_investor.id)
.first()
)
# Transform to InvestorData format
return InvestorData(
investor=investor_with_relations,
portfolio_companies=investor_with_relations.portfolio_companies,
team_members=investor_with_relations.team_members,
sectors=investor_with_relations.sectors,
)
@router.put("/investors/{investor_id}", response_model=InvestorData)
def update_investor(
investor_id: int, investor: InvestorUpdate, db: Session = Depends(get_db)
):
"""Update an existing investor"""
db_investor = (
db.query(InvestorTable).filter(InvestorTable.id == investor_id).first()
)
if not db_investor:
raise HTTPException(status_code=404, detail="Investor not found")
update_data = investor.dict(exclude_unset=True)
for field, value in update_data.items():
setattr(db_investor, field, value)
db.commit()
db.refresh(db_investor)
# Reload with relationships
investor_with_relations = (
db.query(InvestorTable)
.options(
selectinload(InvestorTable.portfolio_companies),
selectinload(InvestorTable.team_members),
selectinload(InvestorTable.sectors),
)
.filter(InvestorTable.id == investor_id)
.first()
)
# Transform to InvestorData format
return InvestorData(
investor=investor_with_relations,
portfolio_companies=investor_with_relations.portfolio_companies,
team_members=investor_with_relations.team_members,
sectors=investor_with_relations.sectors,
)
@router.delete("/investors/{investor_id}")
def delete_investor(investor_id: int, db: Session = Depends(get_db)):
"""Delete an investor"""
db_investor = (
db.query(InvestorTable).filter(InvestorTable.id == investor_id).first()
)
if not db_investor:
raise HTTPException(status_code=404, detail="Investor not found")
db.delete(db_investor)
db.commit()
return {"message": "Investor deleted successfully"}
+46
View File
@@ -0,0 +1,46 @@
from sqlalchemy.orm import Session
from db.models import InvestorTable
from db.db import get_db
def update_stage_focus_values():
"""Update existing stage_focus values from lowercase to uppercase"""
db = next(get_db())
try:
# Mapping of old lowercase values to new uppercase values
stage_mappings = {
'seed': 'SEED',
'series_a': 'SERIES_A',
'series_b': 'SERIES_B',
'series_c': 'SERIES_C',
'growth': 'GROWTH',
'late_stage': 'LATE_STAGE'
}
updated_count = 0
for old_value, new_value in stage_mappings.items():
# Update records with the old value
result = db.query(InvestorTable).filter(
InvestorTable.stage_focus == old_value
).update(
{InvestorTable.stage_focus: new_value},
synchronize_session=False
)
updated_count += result
print(f"Updated {result} records from '{old_value}' to '{new_value}'")
db.commit()
print(f"Successfully updated {updated_count} total records")
except Exception as e:
db.rollback()
print(f"Error updating stage_focus values: {e}")
raise
finally:
db.close()
# Run the update
if __name__ == "__main__":
update_stage_focus_values()
Binary file not shown.
+6 -6
View File
@@ -9,12 +9,12 @@ from db.db import Base
class InvestmentStage(enum.Enum): class InvestmentStage(enum.Enum):
SEED = "seed" SEED = "SEED"
SERIES_A = "series_a" SERIES_A = "SERIES_A"
SERIES_B = "series_b" SERIES_B = "SERIES_B"
SERIES_C = "series_c" SERIES_C = "SERIES_C"
GROWTH = "growth" GROWTH = "GROWTH"
LATE_STAGE = "late_stage" LATE_STAGE = "LATE_STAGE"
# Association table for many-to-many relationship between investors and companies # Association table for many-to-many relationship between investors and companies
+10 -8
View File
@@ -1,16 +1,17 @@
from pydantic import BaseModel
from datetime import datetime from datetime import datetime
from typing import List, Optional
from enum import Enum from enum import Enum
from typing import List, Optional
from pydantic import BaseModel
class InvestmentStage(str, Enum): class InvestmentStage(str, Enum):
SEED = "seed" SEED = "SEED"
SERIES_A = "series_a" SERIES_A = "SERIES_A"
SERIES_B = "series_b" SERIES_B = "SERIES_B"
SERIES_C = "series_c" SERIES_C = "SERIES_C"
GROWTH = "growth" GROWTH = "GROWTH"
LATE_STAGE = "late_stage" LATE_STAGE = "LATE_STAGE"
class SectorSchema(BaseModel): class SectorSchema(BaseModel):
@@ -64,6 +65,7 @@ class InvestorSchema(BaseModel):
class InvestorData(BaseModel): class InvestorData(BaseModel):
"""Comprehensive investor data schema for LLM processing""" """Comprehensive investor data schema for LLM processing"""
investor: InvestorSchema investor: InvestorSchema
portfolio_companies: List[CompanySchema] = [] portfolio_companies: List[CompanySchema] = []
team_members: List[InvestorTeamMemberSchema] = [] team_members: List[InvestorTeamMemberSchema] = []
Binary file not shown.
Binary file not shown.