Refactor investor-related schemas and models; implement investor CRUD operations and update stage_focus values to uppercase
This commit is contained in:
Binary file not shown.
Binary file not shown.
+230
-5
@@ -1,8 +1,233 @@
|
||||
from fastapi import APIRouter
|
||||
from typing import List, Optional
|
||||
|
||||
router = APIRouter()
|
||||
from db.db import get_db
|
||||
from db.models import InvestorTable, SectorTable
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from py_schemas import InvestmentStage, InvestorData
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
@router.get("/investors")
|
||||
def read_investors():
|
||||
return {"message": "list of investors"}
|
||||
router = APIRouter(tags=["Investor Routes"])
|
||||
|
||||
|
||||
# Request schemas for creating/updating
|
||||
class InvestorCreate(BaseModel):
|
||||
name: str
|
||||
description: str = None
|
||||
aum: int
|
||||
check_size_lower: int
|
||||
check_size_upper: int
|
||||
geographic_focus: str
|
||||
stage_focus: InvestmentStage
|
||||
number_of_investments: int = 0
|
||||
|
||||
|
||||
class InvestorUpdate(BaseModel):
|
||||
name: str = None
|
||||
description: str = None
|
||||
aum: int = None
|
||||
check_size_lower: int = None
|
||||
check_size_upper: int = None
|
||||
geographic_focus: str = None
|
||||
stage_focus: InvestmentStage = None
|
||||
number_of_investments: int = None
|
||||
|
||||
|
||||
@router.get("/investors", response_model=List[InvestorData])
|
||||
def read_investors(db: Session = Depends(get_db)):
|
||||
"""Get all investors with their related data"""
|
||||
investors = (
|
||||
db.query(InvestorTable)
|
||||
.options(
|
||||
selectinload(InvestorTable.portfolio_companies),
|
||||
selectinload(InvestorTable.team_members),
|
||||
selectinload(InvestorTable.sectors),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Transform InvestorTable objects to InvestorData format
|
||||
investor_data_list = []
|
||||
for investor in investors:
|
||||
investor_data = InvestorData(
|
||||
investor=investor, # This maps to InvestorSchema
|
||||
portfolio_companies=investor.portfolio_companies,
|
||||
team_members=investor.team_members,
|
||||
sectors=investor.sectors,
|
||||
)
|
||||
investor_data_list.append(investor_data)
|
||||
|
||||
return investor_data_list
|
||||
|
||||
|
||||
@router.get("/investors/filter", response_model=List[InvestorData])
|
||||
def filter_investors(
|
||||
stage: Optional[InvestmentStage] = Query(
|
||||
None, description="Filter by investment stage"
|
||||
),
|
||||
min_check_size: Optional[int] = Query(None, description="Minimum check size"),
|
||||
max_check_size: Optional[int] = Query(None, description="Maximum check size"),
|
||||
geography: Optional[str] = Query(
|
||||
None, description="Geographic focus (partial match)"
|
||||
),
|
||||
sector: Optional[str] = Query(None, description="Sector name (partial match)"),
|
||||
min_aum: Optional[int] = Query(None, description="Minimum AUM"),
|
||||
max_aum: Optional[int] = Query(None, description="Maximum AUM"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Filter investors based on various criteria"""
|
||||
|
||||
# Start with base query
|
||||
query = db.query(InvestorTable).options(
|
||||
selectinload(InvestorTable.portfolio_companies),
|
||||
selectinload(InvestorTable.team_members),
|
||||
selectinload(InvestorTable.sectors),
|
||||
)
|
||||
|
||||
# Apply filters
|
||||
if stage:
|
||||
query = query.filter(InvestorTable.stage_focus == stage)
|
||||
|
||||
if min_check_size is not None:
|
||||
query = query.filter(InvestorTable.check_size_lower >= min_check_size)
|
||||
|
||||
if max_check_size is not None:
|
||||
query = query.filter(InvestorTable.check_size_upper <= max_check_size)
|
||||
|
||||
if geography:
|
||||
query = query.filter(InvestorTable.geographic_focus.ilike(f"%{geography}%"))
|
||||
|
||||
if min_aum is not None:
|
||||
query = query.filter(InvestorTable.aum >= min_aum)
|
||||
|
||||
if max_aum is not None:
|
||||
query = query.filter(InvestorTable.aum <= max_aum)
|
||||
|
||||
# Filter by sector if provided
|
||||
if sector:
|
||||
query = query.join(InvestorTable.sectors).filter(
|
||||
SectorTable.name.ilike(f"%{sector}%")
|
||||
)
|
||||
|
||||
investors = query.all()
|
||||
|
||||
# Transform to InvestorData format
|
||||
investor_data_list = []
|
||||
for investor in investors:
|
||||
investor_data = InvestorData(
|
||||
investor=investor,
|
||||
portfolio_companies=investor.portfolio_companies,
|
||||
team_members=investor.team_members,
|
||||
sectors=investor.sectors,
|
||||
)
|
||||
investor_data_list.append(investor_data)
|
||||
|
||||
return investor_data_list
|
||||
|
||||
|
||||
@router.get("/investors/{investor_id}", response_model=InvestorData)
|
||||
def read_investor(investor_id: int, db: Session = Depends(get_db)):
|
||||
"""Get a specific investor by ID"""
|
||||
investor = (
|
||||
db.query(InvestorTable)
|
||||
.options(
|
||||
selectinload(InvestorTable.portfolio_companies),
|
||||
selectinload(InvestorTable.team_members),
|
||||
selectinload(InvestorTable.sectors),
|
||||
)
|
||||
.filter(InvestorTable.id == investor_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not investor:
|
||||
raise HTTPException(status_code=404, detail="Investor not found")
|
||||
|
||||
# Transform to InvestorData format
|
||||
return InvestorData(
|
||||
investor=investor,
|
||||
portfolio_companies=investor.portfolio_companies,
|
||||
team_members=investor.team_members,
|
||||
sectors=investor.sectors,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/investors", response_model=InvestorData)
|
||||
def create_investor(investor: InvestorCreate, db: Session = Depends(get_db)):
|
||||
"""Create a new investor"""
|
||||
db_investor = InvestorTable(**investor.dict())
|
||||
db.add(db_investor)
|
||||
db.commit()
|
||||
db.refresh(db_investor)
|
||||
|
||||
# Reload with relationships
|
||||
investor_with_relations = (
|
||||
db.query(InvestorTable)
|
||||
.options(
|
||||
selectinload(InvestorTable.portfolio_companies),
|
||||
selectinload(InvestorTable.team_members),
|
||||
selectinload(InvestorTable.sectors),
|
||||
)
|
||||
.filter(InvestorTable.id == db_investor.id)
|
||||
.first()
|
||||
)
|
||||
|
||||
# Transform to InvestorData format
|
||||
return InvestorData(
|
||||
investor=investor_with_relations,
|
||||
portfolio_companies=investor_with_relations.portfolio_companies,
|
||||
team_members=investor_with_relations.team_members,
|
||||
sectors=investor_with_relations.sectors,
|
||||
)
|
||||
|
||||
|
||||
@router.put("/investors/{investor_id}", response_model=InvestorData)
|
||||
def update_investor(
|
||||
investor_id: int, investor: InvestorUpdate, db: Session = Depends(get_db)
|
||||
):
|
||||
"""Update an existing investor"""
|
||||
db_investor = (
|
||||
db.query(InvestorTable).filter(InvestorTable.id == investor_id).first()
|
||||
)
|
||||
if not db_investor:
|
||||
raise HTTPException(status_code=404, detail="Investor not found")
|
||||
|
||||
update_data = investor.dict(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(db_investor, field, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_investor)
|
||||
|
||||
# Reload with relationships
|
||||
investor_with_relations = (
|
||||
db.query(InvestorTable)
|
||||
.options(
|
||||
selectinload(InvestorTable.portfolio_companies),
|
||||
selectinload(InvestorTable.team_members),
|
||||
selectinload(InvestorTable.sectors),
|
||||
)
|
||||
.filter(InvestorTable.id == investor_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
# Transform to InvestorData format
|
||||
return InvestorData(
|
||||
investor=investor_with_relations,
|
||||
portfolio_companies=investor_with_relations.portfolio_companies,
|
||||
team_members=investor_with_relations.team_members,
|
||||
sectors=investor_with_relations.sectors,
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/investors/{investor_id}")
|
||||
def delete_investor(investor_id: int, db: Session = Depends(get_db)):
|
||||
"""Delete an investor"""
|
||||
db_investor = (
|
||||
db.query(InvestorTable).filter(InvestorTable.id == investor_id).first()
|
||||
)
|
||||
if not db_investor:
|
||||
raise HTTPException(status_code=404, detail="Investor not found")
|
||||
|
||||
db.delete(db_investor)
|
||||
db.commit()
|
||||
return {"message": "Investor deleted successfully"}
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
from sqlalchemy.orm import Session
|
||||
from db.models import InvestorTable
|
||||
from db.db import get_db
|
||||
|
||||
def update_stage_focus_values():
|
||||
"""Update existing stage_focus values from lowercase to uppercase"""
|
||||
db = next(get_db())
|
||||
|
||||
try:
|
||||
# Mapping of old lowercase values to new uppercase values
|
||||
stage_mappings = {
|
||||
'seed': 'SEED',
|
||||
'series_a': 'SERIES_A',
|
||||
'series_b': 'SERIES_B',
|
||||
'series_c': 'SERIES_C',
|
||||
'growth': 'GROWTH',
|
||||
'late_stage': 'LATE_STAGE'
|
||||
}
|
||||
|
||||
updated_count = 0
|
||||
|
||||
for old_value, new_value in stage_mappings.items():
|
||||
# Update records with the old value
|
||||
result = db.query(InvestorTable).filter(
|
||||
InvestorTable.stage_focus == old_value
|
||||
).update(
|
||||
{InvestorTable.stage_focus: new_value},
|
||||
synchronize_session=False
|
||||
)
|
||||
|
||||
updated_count += result
|
||||
print(f"Updated {result} records from '{old_value}' to '{new_value}'")
|
||||
|
||||
db.commit()
|
||||
print(f"Successfully updated {updated_count} total records")
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
print(f"Error updating stage_focus values: {e}")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# Run the update
|
||||
if __name__ == "__main__":
|
||||
update_stage_focus_values()
|
||||
Binary file not shown.
+6
-6
@@ -9,12 +9,12 @@ from db.db import Base
|
||||
|
||||
|
||||
class InvestmentStage(enum.Enum):
|
||||
SEED = "seed"
|
||||
SERIES_A = "series_a"
|
||||
SERIES_B = "series_b"
|
||||
SERIES_C = "series_c"
|
||||
GROWTH = "growth"
|
||||
LATE_STAGE = "late_stage"
|
||||
SEED = "SEED"
|
||||
SERIES_A = "SERIES_A"
|
||||
SERIES_B = "SERIES_B"
|
||||
SERIES_C = "SERIES_C"
|
||||
GROWTH = "GROWTH"
|
||||
LATE_STAGE = "LATE_STAGE"
|
||||
|
||||
|
||||
# Association table for many-to-many relationship between investors and companies
|
||||
|
||||
+12
-10
@@ -1,16 +1,17 @@
|
||||
from pydantic import BaseModel
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class InvestmentStage(str, Enum):
|
||||
SEED = "seed"
|
||||
SERIES_A = "series_a"
|
||||
SERIES_B = "series_b"
|
||||
SERIES_C = "series_c"
|
||||
GROWTH = "growth"
|
||||
LATE_STAGE = "late_stage"
|
||||
SEED = "SEED"
|
||||
SERIES_A = "SERIES_A"
|
||||
SERIES_B = "SERIES_B"
|
||||
SERIES_C = "SERIES_C"
|
||||
GROWTH = "GROWTH"
|
||||
LATE_STAGE = "LATE_STAGE"
|
||||
|
||||
|
||||
class SectorSchema(BaseModel):
|
||||
@@ -64,6 +65,7 @@ class InvestorSchema(BaseModel):
|
||||
|
||||
class InvestorData(BaseModel):
|
||||
"""Comprehensive investor data schema for LLM processing"""
|
||||
|
||||
investor: InvestorSchema
|
||||
portfolio_companies: List[CompanySchema] = []
|
||||
team_members: List[InvestorTeamMemberSchema] = []
|
||||
@@ -71,7 +73,7 @@ class InvestorData(BaseModel):
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
|
||||
class InvestorList(BaseModel):
|
||||
investors: List[InvestorData]
|
||||
investors: List[InvestorData]
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Reference in New Issue
Block a user