feat: Refactor Fund schema to use many-to-many relationships for investment stages and sectors
- Updated FundTable to replace JSON fields for investment stages and sectors with relationships. - Introduced InvestmentStageTable and fund_investment_stages association table. - Created fund_sectors association table for many-to-many relationship with sectors. - Changed geographic_focus from JSON array to a simple string. - Migrated existing data to new schema, ensuring data integrity and normalization. - Updated related schemas, routers, and services to reflect new structure. - Added migration script to handle data transformation and schema updates. - Implemented tests to verify new relationships and data integrity.
This commit is contained in:
@@ -9,6 +9,7 @@ from db.models import (
|
||||
CompanyMember,
|
||||
CompanyTable,
|
||||
FundTable,
|
||||
InvestmentStageTable,
|
||||
InvestorMember,
|
||||
InvestorTable,
|
||||
SectorTable,
|
||||
@@ -223,11 +224,16 @@ Return the lower and upper bounds in USD."""
|
||||
"check_size_upper": None,
|
||||
"source_url": fund.get("sourceUrl"),
|
||||
"source_provider": fund.get("sourceProvider"),
|
||||
"geographic_focus": fund.get("geographicFocus", []),
|
||||
"investment_stage_focus": fund.get("investmentStageFocus", []),
|
||||
"sector_focus": fund.get("sectorFocus", []),
|
||||
"geographic_focus": None, # Will be converted to string
|
||||
"investment_stage_names": fund.get("investmentStageFocus", []),
|
||||
"sector_names": fund.get("sectorFocus", []),
|
||||
}
|
||||
|
||||
# Convert geographic focus from array to comma-separated string
|
||||
geo_focus = fund.get("geographicFocus", [])
|
||||
if geo_focus and isinstance(geo_focus, list):
|
||||
fund_data["geographic_focus"] = ", ".join(geo_focus)
|
||||
|
||||
# Convert fund size to USD integer
|
||||
fund_size_str = fund.get("fundSize")
|
||||
if fund_size_str and fund_size_str != "Not Available":
|
||||
@@ -499,15 +505,24 @@ Return the lower and upper bounds in USD."""
|
||||
fund_name=fund_data.get("fund_name"),
|
||||
fund_size=fund_data.get("fund_size"), # Now an integer
|
||||
fund_size_source_url=fund_data.get("fund_size_source_url"),
|
||||
check_size_lower=fund_data.get("check_size_lower"), # NEW
|
||||
check_size_upper=fund_data.get("check_size_upper"), # NEW
|
||||
check_size_lower=fund_data.get("check_size_lower"),
|
||||
check_size_upper=fund_data.get("check_size_upper"),
|
||||
source_url=fund_data.get("source_url"),
|
||||
source_provider=fund_data.get("source_provider"),
|
||||
geographic_focus=fund_data.get("geographic_focus"),
|
||||
investment_stage_focus=fund_data.get("investment_stage_focus"),
|
||||
sector_focus=fund_data.get("sector_focus"),
|
||||
geographic_focus=fund_data.get("geographic_focus"), # Now a string
|
||||
)
|
||||
db.add(fund)
|
||||
db.flush() # Get the fund ID
|
||||
|
||||
# Add investment stages (many-to-many)
|
||||
for stage_name in fund_data.get("investment_stage_names", []):
|
||||
stage = self._get_or_create_investment_stage(db, stage_name)
|
||||
fund.investment_stages.append(stage)
|
||||
|
||||
# Add sectors (many-to-many)
|
||||
for sector_name in fund_data.get("sector_names", []):
|
||||
sector = self._get_or_create_sector(db, sector_name)
|
||||
fund.sectors.append(sector)
|
||||
|
||||
return investor
|
||||
|
||||
@@ -516,6 +531,23 @@ Return the lower and upper bounds in USD."""
|
||||
db.rollback()
|
||||
return None
|
||||
|
||||
def _get_or_create_investment_stage(
|
||||
self, db: Session, stage_name: str
|
||||
) -> InvestmentStageTable:
|
||||
"""Get existing investment stage or create new one"""
|
||||
from db.models import InvestmentStageTable
|
||||
|
||||
stage = (
|
||||
db.query(InvestmentStageTable)
|
||||
.filter(InvestmentStageTable.name == stage_name)
|
||||
.first()
|
||||
)
|
||||
if not stage:
|
||||
stage = InvestmentStageTable(name=stage_name)
|
||||
db.add(stage)
|
||||
db.flush() # Get the ID without committing
|
||||
return stage
|
||||
|
||||
def _get_or_create_sector(self, db: Session, sector_name: str) -> SectorTable:
|
||||
"""Get existing sector or create new one"""
|
||||
sector = db.query(SectorTable).filter(SectorTable.name == sector_name).first()
|
||||
|
||||
Reference in New Issue
Block a user