feat: Refactor Fund schema to use many-to-many relationships for investment stages and sectors
- Updated FundTable to replace JSON fields for investment stages and sectors with relationships. - Introduced InvestmentStageTable and fund_investment_stages association table. - Created fund_sectors association table for many-to-many relationship with sectors. - Changed geographic_focus from JSON array to a simple string. - Migrated existing data to new schema, ensuring data integrity and normalization. - Updated related schemas, routers, and services to reflect new structure. - Added migration script to handle data transformation and schema updates. - Implemented tests to verify new relationships and data integrity.
This commit is contained in:
+48
-7
@@ -70,6 +70,22 @@ project_company_association = Table(
|
||||
Column("company_id", Integer, ForeignKey("companies.id")),
|
||||
)
|
||||
|
||||
# Association table for fund-stage many-to-many
|
||||
fund_investment_stages_association = Table(
|
||||
"fund_investment_stages",
|
||||
Base.metadata,
|
||||
Column("fund_id", Integer, ForeignKey("funds.id")),
|
||||
Column("stage_id", Integer, ForeignKey("investment_stages.id")),
|
||||
)
|
||||
|
||||
# Association table for fund-sector many-to-many
|
||||
fund_sectors_association = Table(
|
||||
"fund_sectors",
|
||||
Base.metadata,
|
||||
Column("fund_id", Integer, ForeignKey("funds.id")),
|
||||
Column("sector_id", Integer, ForeignKey("sectors.id")),
|
||||
)
|
||||
|
||||
|
||||
class InvestorTable(Base, TimestampMixin):
|
||||
__tablename__ = "investors"
|
||||
@@ -172,13 +188,21 @@ class FundTable(Base, TimestampMixin):
|
||||
source_url = Column(String, nullable=True)
|
||||
source_provider = Column(String, nullable=True) # e.g., "Perplexity"
|
||||
|
||||
# JSON array fields
|
||||
geographic_focus = Column(JSON, nullable=True) # Array of regions/countries
|
||||
investment_stage_focus = Column(JSON, nullable=True) # Array of stages
|
||||
sector_focus = Column(JSON, nullable=True) # Array of sectors
|
||||
# Geographic focus as simple string
|
||||
geographic_focus = Column(String, nullable=True)
|
||||
|
||||
# Relationships
|
||||
investor = relationship("InvestorTable", back_populates="funds")
|
||||
investment_stages = relationship(
|
||||
"InvestmentStageTable",
|
||||
secondary=fund_investment_stages_association,
|
||||
back_populates="funds",
|
||||
)
|
||||
sectors = relationship(
|
||||
"SectorTable",
|
||||
secondary=fund_sectors_association,
|
||||
back_populates="funds",
|
||||
)
|
||||
|
||||
|
||||
class CompanyTable(Base, TimestampMixin):
|
||||
@@ -224,26 +248,43 @@ class CompanyMember(Base, TimestampMixin):
|
||||
company = relationship("CompanyTable", back_populates="members")
|
||||
|
||||
|
||||
class InvestmentStageTable(Base, TimestampMixin):
|
||||
__tablename__ = "investment_stages"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, nullable=False, unique=True)
|
||||
|
||||
# Relationships
|
||||
funds = relationship(
|
||||
"FundTable",
|
||||
secondary=fund_investment_stages_association,
|
||||
back_populates="investment_stages",
|
||||
)
|
||||
|
||||
|
||||
class SectorTable(Base, TimestampMixin):
|
||||
__tablename__ = "sectors"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, nullable=False)
|
||||
|
||||
# Add relationship back to investors
|
||||
# Relationships
|
||||
investors = relationship(
|
||||
"InvestorTable",
|
||||
secondary=investor_sector_association,
|
||||
back_populates="sectors",
|
||||
)
|
||||
|
||||
companies = relationship(
|
||||
"CompanyTable", secondary=company_sector_association, back_populates="sectors"
|
||||
)
|
||||
|
||||
projects = relationship(
|
||||
"ProjectTable", secondary=project_sector_association, back_populates="sector"
|
||||
)
|
||||
funds = relationship(
|
||||
"FundTable",
|
||||
secondary=fund_sectors_association,
|
||||
back_populates="sectors",
|
||||
)
|
||||
|
||||
|
||||
class ProjectTable(Base, TimestampMixin):
|
||||
|
||||
+12
-12
@@ -82,8 +82,8 @@ def read_investors(db: Session = Depends(get_db)):
|
||||
check_size_lower=fund.check_size_lower,
|
||||
check_size_upper=fund.check_size_upper,
|
||||
geographic_focus=fund.geographic_focus,
|
||||
investment_stage_focus=fund.investment_stage_focus,
|
||||
sector_focus=fund.sector_focus,
|
||||
fund_investment_stages=fund.investment_stages, # Now a relationship
|
||||
fund_sectors=fund.sectors, # Now a relationship
|
||||
# Related data (same for all funds of this investor)
|
||||
portfolio_companies=investor.portfolio_companies,
|
||||
team_members=investor.team_members,
|
||||
@@ -113,8 +113,8 @@ def read_investors(db: Session = Depends(get_db)):
|
||||
check_size_lower=None,
|
||||
check_size_upper=None,
|
||||
geographic_focus=None,
|
||||
investment_stage_focus=None,
|
||||
sector_focus=None,
|
||||
fund_investment_stages=None,
|
||||
fund_sectors=None,
|
||||
# Related data
|
||||
portfolio_companies=investor.portfolio_companies,
|
||||
team_members=investor.team_members,
|
||||
@@ -208,8 +208,8 @@ def filter_investors(
|
||||
check_size_lower=fund.check_size_lower,
|
||||
check_size_upper=fund.check_size_upper,
|
||||
geographic_focus=fund.geographic_focus,
|
||||
investment_stage_focus=fund.investment_stage_focus,
|
||||
sector_focus=fund.sector_focus,
|
||||
fund_investment_stages=fund.investment_stages, # Now a relationship
|
||||
fund_sectors=fund.sectors, # Now a relationship
|
||||
# Related data
|
||||
portfolio_companies=investor.portfolio_companies,
|
||||
team_members=investor.team_members,
|
||||
@@ -239,8 +239,8 @@ def filter_investors(
|
||||
check_size_lower=None,
|
||||
check_size_upper=None,
|
||||
geographic_focus=None,
|
||||
investment_stage_focus=None,
|
||||
sector_focus=None,
|
||||
fund_investment_stages=None,
|
||||
fund_sectors=None,
|
||||
# Related data
|
||||
portfolio_companies=investor.portfolio_companies,
|
||||
team_members=investor.team_members,
|
||||
@@ -502,8 +502,8 @@ def find_similar_investors(
|
||||
check_size_lower=fund.check_size_lower,
|
||||
check_size_upper=fund.check_size_upper,
|
||||
geographic_focus=fund.geographic_focus,
|
||||
investment_stage_focus=fund.investment_stage_focus,
|
||||
sector_focus=fund.sector_focus,
|
||||
fund_investment_stages=fund.investment_stages, # Now a relationship
|
||||
fund_sectors=fund.sectors, # Now a relationship
|
||||
# Related data
|
||||
portfolio_companies=investor.portfolio_companies,
|
||||
team_members=investor.team_members,
|
||||
@@ -533,8 +533,8 @@ def find_similar_investors(
|
||||
check_size_lower=None,
|
||||
check_size_upper=None,
|
||||
geographic_focus=None,
|
||||
investment_stage_focus=None,
|
||||
sector_focus=None,
|
||||
fund_investment_stages=None,
|
||||
fund_sectors=None,
|
||||
# Related data
|
||||
portfolio_companies=investor.portfolio_companies,
|
||||
team_members=investor.team_members,
|
||||
|
||||
@@ -22,6 +22,14 @@ class SectorSchema(BaseModel):
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class InvestmentStageSchema(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class InvestorMemberSchema(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
@@ -41,9 +49,9 @@ class FundSchema(BaseModel):
|
||||
check_size_upper: int | None # NEW: Upper bound of check size range
|
||||
source_url: str | None
|
||||
source_provider: str | None
|
||||
geographic_focus: List[str] | None
|
||||
investment_stage_focus: List[str] | None
|
||||
sector_focus: List[str] | None
|
||||
geographic_focus: str | None # Changed from List[str] to string
|
||||
investment_stages: List[InvestmentStageSchema] | None # Changed to relationship
|
||||
sectors: List[SectorSchema] | None # Changed to relationship
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
@@ -134,9 +142,11 @@ class InvestorFundData(BaseModel):
|
||||
fund_size_source_url: str | None
|
||||
check_size_lower: int | None # NEW: Lower bound of check size range
|
||||
check_size_upper: int | None # NEW: Upper bound of check size range
|
||||
geographic_focus: List[str] | None
|
||||
investment_stage_focus: List[str] | None
|
||||
sector_focus: List[str] | None
|
||||
geographic_focus: str | None # Changed from List[str] to string
|
||||
fund_investment_stages: (
|
||||
List[InvestmentStageSchema] | None
|
||||
) # Changed to relationship
|
||||
fund_sectors: List[SectorSchema] | None # Changed to relationship
|
||||
|
||||
# Related data
|
||||
portfolio_companies: List[CompanySchema]
|
||||
|
||||
@@ -9,6 +9,7 @@ from db.models import (
|
||||
CompanyMember,
|
||||
CompanyTable,
|
||||
FundTable,
|
||||
InvestmentStageTable,
|
||||
InvestorMember,
|
||||
InvestorTable,
|
||||
SectorTable,
|
||||
@@ -223,11 +224,16 @@ Return the lower and upper bounds in USD."""
|
||||
"check_size_upper": None,
|
||||
"source_url": fund.get("sourceUrl"),
|
||||
"source_provider": fund.get("sourceProvider"),
|
||||
"geographic_focus": fund.get("geographicFocus", []),
|
||||
"investment_stage_focus": fund.get("investmentStageFocus", []),
|
||||
"sector_focus": fund.get("sectorFocus", []),
|
||||
"geographic_focus": None, # Will be converted to string
|
||||
"investment_stage_names": fund.get("investmentStageFocus", []),
|
||||
"sector_names": fund.get("sectorFocus", []),
|
||||
}
|
||||
|
||||
# Convert geographic focus from array to comma-separated string
|
||||
geo_focus = fund.get("geographicFocus", [])
|
||||
if geo_focus and isinstance(geo_focus, list):
|
||||
fund_data["geographic_focus"] = ", ".join(geo_focus)
|
||||
|
||||
# Convert fund size to USD integer
|
||||
fund_size_str = fund.get("fundSize")
|
||||
if fund_size_str and fund_size_str != "Not Available":
|
||||
@@ -499,15 +505,24 @@ Return the lower and upper bounds in USD."""
|
||||
fund_name=fund_data.get("fund_name"),
|
||||
fund_size=fund_data.get("fund_size"), # Now an integer
|
||||
fund_size_source_url=fund_data.get("fund_size_source_url"),
|
||||
check_size_lower=fund_data.get("check_size_lower"), # NEW
|
||||
check_size_upper=fund_data.get("check_size_upper"), # NEW
|
||||
check_size_lower=fund_data.get("check_size_lower"),
|
||||
check_size_upper=fund_data.get("check_size_upper"),
|
||||
source_url=fund_data.get("source_url"),
|
||||
source_provider=fund_data.get("source_provider"),
|
||||
geographic_focus=fund_data.get("geographic_focus"),
|
||||
investment_stage_focus=fund_data.get("investment_stage_focus"),
|
||||
sector_focus=fund_data.get("sector_focus"),
|
||||
geographic_focus=fund_data.get("geographic_focus"), # Now a string
|
||||
)
|
||||
db.add(fund)
|
||||
db.flush() # Get the fund ID
|
||||
|
||||
# Add investment stages (many-to-many)
|
||||
for stage_name in fund_data.get("investment_stage_names", []):
|
||||
stage = self._get_or_create_investment_stage(db, stage_name)
|
||||
fund.investment_stages.append(stage)
|
||||
|
||||
# Add sectors (many-to-many)
|
||||
for sector_name in fund_data.get("sector_names", []):
|
||||
sector = self._get_or_create_sector(db, sector_name)
|
||||
fund.sectors.append(sector)
|
||||
|
||||
return investor
|
||||
|
||||
@@ -516,6 +531,23 @@ Return the lower and upper bounds in USD."""
|
||||
db.rollback()
|
||||
return None
|
||||
|
||||
def _get_or_create_investment_stage(
|
||||
self, db: Session, stage_name: str
|
||||
) -> InvestmentStageTable:
|
||||
"""Get existing investment stage or create new one"""
|
||||
from db.models import InvestmentStageTable
|
||||
|
||||
stage = (
|
||||
db.query(InvestmentStageTable)
|
||||
.filter(InvestmentStageTable.name == stage_name)
|
||||
.first()
|
||||
)
|
||||
if not stage:
|
||||
stage = InvestmentStageTable(name=stage_name)
|
||||
db.add(stage)
|
||||
db.flush() # Get the ID without committing
|
||||
return stage
|
||||
|
||||
def _get_or_create_sector(self, db: Session, sector_name: str) -> SectorTable:
|
||||
"""Get existing sector or create new one"""
|
||||
sector = db.query(SectorTable).filter(SectorTable.name == sector_name).first()
|
||||
|
||||
Reference in New Issue
Block a user