feat: Refactor Fund schema to use many-to-many relationships for investment stages and sectors

- Updated FundTable to replace JSON fields for investment stages and sectors with relationships.
- Introduced InvestmentStageTable and fund_investment_stages association table.
- Created fund_sectors association table for many-to-many relationship with sectors.
- Changed geographic_focus from JSON array to a simple string.
- Migrated existing data to new schema, ensuring data integrity and normalization.
- Updated related schemas, routers, and services to reflect new structure.
- Added migration script to handle data transformation and schema updates.
- Implemented tests to verify new relationships and data integrity.
This commit is contained in:
bolade
2025-10-07 15:57:29 +01:00
parent d341cacb9a
commit a9589e54f3
10 changed files with 1134 additions and 42 deletions
+48 -7
View File
@@ -70,6 +70,22 @@ project_company_association = Table(
Column("company_id", Integer, ForeignKey("companies.id")),
)
# Association table for fund-stage many-to-many
fund_investment_stages_association = Table(
"fund_investment_stages",
Base.metadata,
Column("fund_id", Integer, ForeignKey("funds.id")),
Column("stage_id", Integer, ForeignKey("investment_stages.id")),
)
# Association table for fund-sector many-to-many
fund_sectors_association = Table(
"fund_sectors",
Base.metadata,
Column("fund_id", Integer, ForeignKey("funds.id")),
Column("sector_id", Integer, ForeignKey("sectors.id")),
)
class InvestorTable(Base, TimestampMixin):
__tablename__ = "investors"
@@ -172,13 +188,21 @@ class FundTable(Base, TimestampMixin):
source_url = Column(String, nullable=True)
source_provider = Column(String, nullable=True) # e.g., "Perplexity"
# JSON array fields
geographic_focus = Column(JSON, nullable=True) # Array of regions/countries
investment_stage_focus = Column(JSON, nullable=True) # Array of stages
sector_focus = Column(JSON, nullable=True) # Array of sectors
# Geographic focus as simple string
geographic_focus = Column(String, nullable=True)
# Relationships
investor = relationship("InvestorTable", back_populates="funds")
investment_stages = relationship(
"InvestmentStageTable",
secondary=fund_investment_stages_association,
back_populates="funds",
)
sectors = relationship(
"SectorTable",
secondary=fund_sectors_association,
back_populates="funds",
)
class CompanyTable(Base, TimestampMixin):
@@ -224,26 +248,43 @@ class CompanyMember(Base, TimestampMixin):
company = relationship("CompanyTable", back_populates="members")
class InvestmentStageTable(Base, TimestampMixin):
__tablename__ = "investment_stages"
id = Column(Integer, primary_key=True, index=True)
name = Column(String, nullable=False, unique=True)
# Relationships
funds = relationship(
"FundTable",
secondary=fund_investment_stages_association,
back_populates="investment_stages",
)
class SectorTable(Base, TimestampMixin):
__tablename__ = "sectors"
id = Column(Integer, primary_key=True, index=True)
name = Column(String, nullable=False)
# Add relationship back to investors
# Relationships
investors = relationship(
"InvestorTable",
secondary=investor_sector_association,
back_populates="sectors",
)
companies = relationship(
"CompanyTable", secondary=company_sector_association, back_populates="sectors"
)
projects = relationship(
"ProjectTable", secondary=project_sector_association, back_populates="sector"
)
funds = relationship(
"FundTable",
secondary=fund_sectors_association,
back_populates="sectors",
)
class ProjectTable(Base, TimestampMixin):
+12 -12
View File
@@ -82,8 +82,8 @@ def read_investors(db: Session = Depends(get_db)):
check_size_lower=fund.check_size_lower,
check_size_upper=fund.check_size_upper,
geographic_focus=fund.geographic_focus,
investment_stage_focus=fund.investment_stage_focus,
sector_focus=fund.sector_focus,
fund_investment_stages=fund.investment_stages, # Now a relationship
fund_sectors=fund.sectors, # Now a relationship
# Related data (same for all funds of this investor)
portfolio_companies=investor.portfolio_companies,
team_members=investor.team_members,
@@ -113,8 +113,8 @@ def read_investors(db: Session = Depends(get_db)):
check_size_lower=None,
check_size_upper=None,
geographic_focus=None,
investment_stage_focus=None,
sector_focus=None,
fund_investment_stages=None,
fund_sectors=None,
# Related data
portfolio_companies=investor.portfolio_companies,
team_members=investor.team_members,
@@ -208,8 +208,8 @@ def filter_investors(
check_size_lower=fund.check_size_lower,
check_size_upper=fund.check_size_upper,
geographic_focus=fund.geographic_focus,
investment_stage_focus=fund.investment_stage_focus,
sector_focus=fund.sector_focus,
fund_investment_stages=fund.investment_stages, # Now a relationship
fund_sectors=fund.sectors, # Now a relationship
# Related data
portfolio_companies=investor.portfolio_companies,
team_members=investor.team_members,
@@ -239,8 +239,8 @@ def filter_investors(
check_size_lower=None,
check_size_upper=None,
geographic_focus=None,
investment_stage_focus=None,
sector_focus=None,
fund_investment_stages=None,
fund_sectors=None,
# Related data
portfolio_companies=investor.portfolio_companies,
team_members=investor.team_members,
@@ -502,8 +502,8 @@ def find_similar_investors(
check_size_lower=fund.check_size_lower,
check_size_upper=fund.check_size_upper,
geographic_focus=fund.geographic_focus,
investment_stage_focus=fund.investment_stage_focus,
sector_focus=fund.sector_focus,
fund_investment_stages=fund.investment_stages, # Now a relationship
fund_sectors=fund.sectors, # Now a relationship
# Related data
portfolio_companies=investor.portfolio_companies,
team_members=investor.team_members,
@@ -533,8 +533,8 @@ def find_similar_investors(
check_size_lower=None,
check_size_upper=None,
geographic_focus=None,
investment_stage_focus=None,
sector_focus=None,
fund_investment_stages=None,
fund_sectors=None,
# Related data
portfolio_companies=investor.portfolio_companies,
team_members=investor.team_members,
+16 -6
View File
@@ -22,6 +22,14 @@ class SectorSchema(BaseModel):
from_attributes = True
class InvestmentStageSchema(BaseModel):
id: int
name: str
class Config:
from_attributes = True
class InvestorMemberSchema(BaseModel):
id: int
name: str
@@ -41,9 +49,9 @@ class FundSchema(BaseModel):
check_size_upper: int | None # NEW: Upper bound of check size range
source_url: str | None
source_provider: str | None
geographic_focus: List[str] | None
investment_stage_focus: List[str] | None
sector_focus: List[str] | None
geographic_focus: str | None # Changed from List[str] to string
investment_stages: List[InvestmentStageSchema] | None # Changed to relationship
sectors: List[SectorSchema] | None # Changed to relationship
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
@@ -134,9 +142,11 @@ class InvestorFundData(BaseModel):
fund_size_source_url: str | None
check_size_lower: int | None # NEW: Lower bound of check size range
check_size_upper: int | None # NEW: Upper bound of check size range
geographic_focus: List[str] | None
investment_stage_focus: List[str] | None
sector_focus: List[str] | None
geographic_focus: str | None # Changed from List[str] to string
fund_investment_stages: (
List[InvestmentStageSchema] | None
) # Changed to relationship
fund_sectors: List[SectorSchema] | None # Changed to relationship
# Related data
portfolio_companies: List[CompanySchema]
+40 -8
View File
@@ -9,6 +9,7 @@ from db.models import (
CompanyMember,
CompanyTable,
FundTable,
InvestmentStageTable,
InvestorMember,
InvestorTable,
SectorTable,
@@ -223,11 +224,16 @@ Return the lower and upper bounds in USD."""
"check_size_upper": None,
"source_url": fund.get("sourceUrl"),
"source_provider": fund.get("sourceProvider"),
"geographic_focus": fund.get("geographicFocus", []),
"investment_stage_focus": fund.get("investmentStageFocus", []),
"sector_focus": fund.get("sectorFocus", []),
"geographic_focus": None, # Will be converted to string
"investment_stage_names": fund.get("investmentStageFocus", []),
"sector_names": fund.get("sectorFocus", []),
}
# Convert geographic focus from array to comma-separated string
geo_focus = fund.get("geographicFocus", [])
if geo_focus and isinstance(geo_focus, list):
fund_data["geographic_focus"] = ", ".join(geo_focus)
# Convert fund size to USD integer
fund_size_str = fund.get("fundSize")
if fund_size_str and fund_size_str != "Not Available":
@@ -499,15 +505,24 @@ Return the lower and upper bounds in USD."""
fund_name=fund_data.get("fund_name"),
fund_size=fund_data.get("fund_size"), # Now an integer
fund_size_source_url=fund_data.get("fund_size_source_url"),
check_size_lower=fund_data.get("check_size_lower"), # NEW
check_size_upper=fund_data.get("check_size_upper"), # NEW
check_size_lower=fund_data.get("check_size_lower"),
check_size_upper=fund_data.get("check_size_upper"),
source_url=fund_data.get("source_url"),
source_provider=fund_data.get("source_provider"),
geographic_focus=fund_data.get("geographic_focus"),
investment_stage_focus=fund_data.get("investment_stage_focus"),
sector_focus=fund_data.get("sector_focus"),
geographic_focus=fund_data.get("geographic_focus"), # Now a string
)
db.add(fund)
db.flush() # Get the fund ID
# Add investment stages (many-to-many)
for stage_name in fund_data.get("investment_stage_names", []):
stage = self._get_or_create_investment_stage(db, stage_name)
fund.investment_stages.append(stage)
# Add sectors (many-to-many)
for sector_name in fund_data.get("sector_names", []):
sector = self._get_or_create_sector(db, sector_name)
fund.sectors.append(sector)
return investor
@@ -516,6 +531,23 @@ Return the lower and upper bounds in USD."""
db.rollback()
return None
def _get_or_create_investment_stage(
self, db: Session, stage_name: str
) -> InvestmentStageTable:
"""Get existing investment stage or create new one"""
from db.models import InvestmentStageTable
stage = (
db.query(InvestmentStageTable)
.filter(InvestmentStageTable.name == stage_name)
.first()
)
if not stage:
stage = InvestmentStageTable(name=stage_name)
db.add(stage)
db.flush() # Get the ID without committing
return stage
def _get_or_create_sector(self, db: Session, sector_name: str) -> SectorTable:
"""Get existing sector or create new one"""
sector = db.query(SectorTable).filter(SectorTable.name == sector_name).first()