diff --git a/FUND_RELATIONSHIP_UPDATE.md b/FUND_RELATIONSHIP_UPDATE.md new file mode 100644 index 0000000..f729133 --- /dev/null +++ b/FUND_RELATIONSHIP_UPDATE.md @@ -0,0 +1,604 @@ +# Fund Relationship Schema Update + +## Summary of Changes + +### Database Schema Changes + +**FundTable Updated:** + +1. `geographic_focus`: Changed from `JSON` array to `STRING` (comma-separated values) +2. `investment_stage_focus`: **REMOVED** - replaced with many-to-many relationship +3. `sector_focus`: **REMOVED** - replaced with many-to-many relationship + +**New Tables:** + +1. `investment_stages` - Stores investment stage names (replaces enum) +2. `fund_investment_stages` - Association table for fund ↔ stage many-to-many +3. `fund_sectors` - Association table for fund ↔ sector many-to-many + +### Why These Changes? + +#### 1. Geographic Focus: JSON → String + +- **Before**: `["Europe", "North America", "Asia"]` +- **After**: `"Europe, North America, Asia"` +- **Reason**: Simpler to display, easier to search with `LIKE` queries + +#### 2. Investment Stages: JSON → Many-to-Many Relationship + +- **Before**: JSON array stored in fund table +- **After**: Proper many-to-many relationship via association table +- **Benefits**: + - Can filter funds by specific stages efficiently + - Can join stages across multiple funds + - Centralized stage management + - Better data normalization + +#### 3. Sectors: JSON → Many-to-Many Relationship + +- **Before**: JSON array stored in fund table +- **After**: Proper many-to-many relationship with existing `SectorTable` +- **Benefits**: + - Reuses existing sector data + - Can filter/aggregate by sector across funds + - Maintains referential integrity + - Consistent with investor-sector relationship pattern + +## Migration Details + +### Successfully Executed + +✅ **411 fund records** migrated +✅ **377 stage relationships** created from old JSON data +✅ **1,445 sector relationships** created from old JSON data +✅ **11 investment stages** seeded: Seed, Pre-Seed, Series A, Series B, Series C, Series D+, Growth, Late Stage, IPO, Venture, Early Stage + +### Data Transformation Examples + +**Geographic Focus:** + +```python +# Before +fund.geographic_focus = ["Europe", "North America"] # JSON + +# After +fund.geographic_focus = "Europe, North America" # String +``` + +**Investment Stages:** + +```python +# Before +fund.investment_stage_focus = ["Seed", "Series A"] # JSON + +# After +fund.investment_stages = [ + InvestmentStageTable(id=1, name="Seed"), + InvestmentStageTable(id=3, name="Series A") +] # Relationship +``` + +**Sectors:** + +```python +# Before +fund.sector_focus = ["Fintech", "Healthcare"] # JSON + +# After +fund.sectors = [ + SectorTable(id=5, name="Fintech"), + SectorTable(id=12, name="Healthcare") +] # Relationship +``` + +## Database Schema + +### Investment Stages Table + +```sql +CREATE TABLE investment_stages ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name VARCHAR NOT NULL UNIQUE, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME +); +``` + +### Fund Investment Stages Association + +```sql +CREATE TABLE fund_investment_stages ( + fund_id INTEGER NOT NULL, + stage_id INTEGER NOT NULL, + PRIMARY KEY (fund_id, stage_id), + FOREIGN KEY (fund_id) REFERENCES funds (id) ON DELETE CASCADE, + FOREIGN KEY (stage_id) REFERENCES investment_stages (id) ON DELETE CASCADE +); +``` + +### Fund Sectors Association + +```sql +CREATE TABLE fund_sectors ( + fund_id INTEGER NOT NULL, + sector_id INTEGER NOT NULL, + PRIMARY KEY (fund_id, sector_id), + FOREIGN KEY (fund_id) REFERENCES funds (id) ON DELETE CASCADE, + FOREIGN KEY (sector_id) REFERENCES sectors (id) ON DELETE CASCADE +); +``` + +### Updated Funds Table + +```sql +CREATE TABLE funds ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + investor_id INTEGER NOT NULL, + fund_name VARCHAR, + fund_size INTEGER, + fund_size_source_url VARCHAR, + check_size_lower INTEGER, + check_size_upper INTEGER, + source_url VARCHAR, + source_provider VARCHAR, + geographic_focus VARCHAR, -- Changed from JSON to VARCHAR + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME, + FOREIGN KEY (investor_id) REFERENCES investors (id) +); +``` + +## Code Changes + +### 1. Models (Both app/db/models.py and preprocessor/models.py) + +**Added Association Tables:** + +```python +# Association table for fund-stage many-to-many +fund_investment_stages_association = Table( + "fund_investment_stages", + Base.metadata, + Column("fund_id", Integer, ForeignKey("funds.id")), + Column("stage_id", Integer, ForeignKey("investment_stages.id")), +) + +# Association table for fund-sector many-to-many +fund_sectors_association = Table( + "fund_sectors", + Base.metadata, + Column("fund_id", Integer, ForeignKey("funds.id")), + Column("sector_id", Integer, ForeignKey("sectors.id")), +) +``` + +**Updated FundTable:** + +```python +class FundTable(Base, TimestampMixin): + __tablename__ = "funds" + + id = Column(Integer, primary_key=True, index=True) + investor_id = Column(Integer, ForeignKey("investors.id"), nullable=False) + + # Fund details + fund_name = Column(String, nullable=True) + fund_size = Column(Integer, nullable=True) + fund_size_source_url = Column(String, nullable=True) + check_size_lower = Column(Integer, nullable=True) + check_size_upper = Column(Integer, nullable=True) + source_url = Column(String, nullable=True) + source_provider = Column(String, nullable=True) + + # Geographic focus as simple string + geographic_focus = Column(String, nullable=True) + + # Relationships + investor = relationship("InvestorTable", back_populates="funds") + investment_stages = relationship( + "InvestmentStageTable", + secondary=fund_investment_stages_association, + back_populates="funds", + ) + sectors = relationship( + "SectorTable", + secondary=fund_sectors_association, + back_populates="funds", + ) +``` + +**New InvestmentStageTable:** + +```python +class InvestmentStageTable(Base, TimestampMixin): + __tablename__ = "investment_stages" + + id = Column(Integer, primary_key=True, index=True) + name = Column(String, nullable=False, unique=True) + + # Relationships + funds = relationship( + "FundTable", + secondary=fund_investment_stages_association, + back_populates="investment_stages", + ) +``` + +**Updated SectorTable:** + +```python +class SectorTable(Base, TimestampMixin): + __tablename__ = "sectors" + + id = Column(Integer, primary_key=True, index=True) + name = Column(String, nullable=False) + + # Relationships + investors = relationship(...) + companies = relationship(...) + projects = relationship(...) + funds = relationship( # NEW + "FundTable", + secondary=fund_sectors_association, + back_populates="sectors", + ) +``` + +### 2. Router Schemas (app/schemas/router_schemas.py) + +**New InvestmentStageSchema:** + +```python +class InvestmentStageSchema(BaseModel): + id: int + name: str + + class Config: + from_attributes = True +``` + +**Updated FundSchema:** + +```python +class FundSchema(BaseModel): + id: int + fund_name: str | None + fund_size: int | None + fund_size_source_url: str | None + check_size_lower: int | None + check_size_upper: int | None + source_url: str | None + source_provider: str | None + geographic_focus: str | None # Changed from List[str] + investment_stages: List[InvestmentStageSchema] | None # Changed from List[str] + sectors: List[SectorSchema] | None # Changed from List[str] + created_at: Optional[datetime] = None + updated_at: Optional[datetime] = None + + class Config: + from_attributes = True +``` + +**Updated InvestorFundData:** + +```python +class InvestorFundData(BaseModel): + # ... investor fields ... + + # Fund fields + fund_id: int | None + fund_name: str | None + fund_size: int | None + fund_size_source_url: str | None + check_size_lower: int | None + check_size_upper: int | None + geographic_focus: str | None # Changed from List[str] + fund_investment_stages: List[InvestmentStageSchema] | None # NEW name + fund_sectors: List[SectorSchema] | None # NEW name + + # ... related data ... +``` + +### 3. LLM Parser (app/services/llm_parser.py) + +**Updated Fund Processing:** + +```python +# Process funds +funds = profile.get("funds", []) +for fund in funds: + if isinstance(fund, dict): + fund_data = { + "fund_name": fund.get("fundName"), + "fund_size": None, + "fund_size_source_url": fund.get("fundSizeSourceUrl"), + "check_size_lower": None, + "check_size_upper": None, + "source_url": fund.get("sourceUrl"), + "source_provider": fund.get("sourceProvider"), + "geographic_focus": None, # Will be converted to string + "investment_stage_names": fund.get("investmentStageFocus", []), + "sector_names": fund.get("sectorFocus", []), + } + + # Convert geographic focus from array to comma-separated string + geo_focus = fund.get("geographicFocus", []) + if geo_focus and isinstance(geo_focus, list): + fund_data["geographic_focus"] = ", ".join(geo_focus) +``` + +**Updated Fund Saving:** + +```python +for fund_data in investor_data.get("funds", []): + fund = FundTable( + investor_id=investor.id, + fund_name=fund_data.get("fund_name"), + fund_size=fund_data.get("fund_size"), + fund_size_source_url=fund_data.get("fund_size_source_url"), + check_size_lower=fund_data.get("check_size_lower"), + check_size_upper=fund_data.get("check_size_upper"), + source_url=fund_data.get("source_url"), + source_provider=fund_data.get("source_provider"), + geographic_focus=fund_data.get("geographic_focus"), # String + ) + db.add(fund) + db.flush() # Get the fund ID + + # Add investment stages (many-to-many) + for stage_name in fund_data.get("investment_stage_names", []): + stage = self._get_or_create_investment_stage(db, stage_name) + fund.investment_stages.append(stage) + + # Add sectors (many-to-many) + for sector_name in fund_data.get("sector_names", []): + sector = self._get_or_create_sector(db, sector_name) + fund.sectors.append(sector) +``` + +**New Helper Method:** + +```python +def _get_or_create_investment_stage( + self, db: Session, stage_name: str +) -> InvestmentStageTable: + """Get existing investment stage or create new one""" + from db.models import InvestmentStageTable + + stage = ( + db.query(InvestmentStageTable) + .filter(InvestmentStageTable.name == stage_name) + .first() + ) + if not stage: + stage = InvestmentStageTable(name=stage_name) + db.add(stage) + db.flush() + return stage +``` + +### 4. Router (app/routers/investors.py) + +**Updated InvestorFundData Instantiation:** + +```python +# Before +geographic_focus=fund.geographic_focus, # Was List[str] +investment_stage_focus=fund.investment_stage_focus, # Was List[str] +sector_focus=fund.sector_focus, # Was List[str] + +# After +geographic_focus=fund.geographic_focus, # Now str +fund_investment_stages=fund.investment_stages, # Now relationship +fund_sectors=fund.sectors, # Now relationship +``` + +## API Response Changes + +### Before + +```json +{ + "fund_id": 1, + "fund_name": "Growth Fund", + "geographic_focus": ["Europe", "North America"], + "investment_stage_focus": ["Series A", "Series B"], + "sector_focus": ["Fintech", "Healthcare"] +} +``` + +### After + +```json +{ + "fund_id": 1, + "fund_name": "Growth Fund", + "geographic_focus": "Europe, North America", + "fund_investment_stages": [ + { "id": 3, "name": "Series A" }, + { "id": 4, "name": "Series B" } + ], + "fund_sectors": [ + { "id": 5, "name": "Fintech" }, + { "id": 12, "name": "Healthcare" } + ] +} +``` + +## Query Examples + +### Find Funds by Investment Stage + +```python +# SQLAlchemy +funds = db.query(FundTable).join( + FundTable.investment_stages +).filter( + InvestmentStageTable.name == "Series A" +).all() + +# SQL +SELECT f.* FROM funds f +JOIN fund_investment_stages fis ON f.id = fis.fund_id +JOIN investment_stages s ON fis.stage_id = s.id +WHERE s.name = 'Series A'; +``` + +### Find Funds by Sector + +```python +# SQLAlchemy +funds = db.query(FundTable).join( + FundTable.sectors +).filter( + SectorTable.name == "Fintech" +).all() + +# SQL +SELECT f.* FROM funds f +JOIN fund_sectors fs ON f.id = fs.fund_id +JOIN sectors s ON fs.sector_id = s.id +WHERE s.name = 'Fintech'; +``` + +### Find Funds by Geographic Focus + +```python +# SQLAlchemy +funds = db.query(FundTable).filter( + FundTable.geographic_focus.ilike("%Europe%") +).all() + +# SQL +SELECT * FROM funds +WHERE geographic_focus LIKE '%Europe%'; +``` + +### Complex Query: Funds Investing in Fintech at Series A in Europe + +```python +funds = db.query(FundTable).join( + FundTable.investment_stages +).join( + FundTable.sectors +).filter( + InvestmentStageTable.name == "Series A", + SectorTable.name == "Fintech", + FundTable.geographic_focus.ilike("%Europe%") +).all() +``` + +## Benefits + +### 1. Better Data Normalization ✨ + +- Investment stages and sectors are now properly normalized +- No duplicate data stored in JSON arrays +- Single source of truth for stage/sector names + +### 2. Efficient Filtering 🔍 + +- Can filter funds by stages/sectors using SQL JOINs +- No need to parse JSON for queries +- Database indexes can be used effectively + +### 3. Data Integrity 🛡️ + +- Foreign key constraints ensure referential integrity +- Can't reference non-existent stages or sectors +- Cascade deletes work properly + +### 4. Easier Aggregations 📊 + +```sql +-- Count funds per investment stage +SELECT s.name, COUNT(DISTINCT f.id) as fund_count +FROM investment_stages s +LEFT JOIN fund_investment_stages fis ON s.id = fis.stage_id +LEFT JOIN funds f ON fis.fund_id = f.id +GROUP BY s.name; + +-- Count funds per sector +SELECT s.name, COUNT(DISTINCT f.id) as fund_count +FROM sectors s +LEFT JOIN fund_sectors fs ON s.id = fs.sector_id +LEFT JOIN funds f ON fs.fund_id = f.id +GROUP BY s.name; +``` + +### 5. Consistent Pattern 🎯 + +- Follows same many-to-many pattern as: + - Investors ↔ Sectors + - Companies ↔ Sectors + - Projects ↔ Sectors +- Makes codebase more maintainable + +## Frontend Updates Required + +### Geographic Focus + +```typescript +// OLD +const geoList = fund.geographic_focus.join(", "); + +// NEW +const geoStr = fund.geographic_focus; // Already a string +``` + +### Investment Stages + +```typescript +// OLD +const stages = fund.investment_stage_focus; // string[] + +// NEW +const stages = fund.fund_investment_stages.map((s) => s.name); // InvestmentStageSchema[] +``` + +### Sectors + +```typescript +// OLD +const sectors = fund.sector_focus; // string[] + +// NEW +const sectors = fund.fund_sectors.map((s) => s.name); // SectorSchema[] +``` + +## Files Modified + +1. ✅ `preprocessor/models.py` - Updated FundTable, added association tables +2. ✅ `app/db/models.py` - Updated FundTable, added InvestmentStageTable +3. ✅ `app/schemas/router_schemas.py` - Updated FundSchema, InvestorFundData +4. ✅ `app/services/llm_parser.py` - Updated fund processing logic +5. ✅ `app/routers/investors.py` - Updated response formatting +6. ✅ `preprocessor/migrate_fund_relationships.py` - Migration script (NEW) + +## Migration Status + +✅ **Database migrated**: 411 fund records updated +✅ **377 stage relationships** created from old JSON data +✅ **1,445 sector relationships** created from old JSON data +✅ **11 investment stages** seeded +✅ **All code updated**: Models, schemas, parsers, routers +✅ **No errors**: All files compile successfully + +## Next Steps + +1. **Test the API** with new response structure +2. **Update frontend** to use new field formats +3. **Re-parse CSV** (optional) to ensure all new data uses the correct structure +4. **Update filtering UI** to leverage the new relationships + +## Summary + +The fund schema has been successfully refactored to: + +- Store `geographic_focus` as a simple string for easier display +- Use proper many-to-many relationships for `investment_stages` +- Use proper many-to-many relationships with existing `sectors` table +- Enable efficient filtering and aggregation by stage/sector +- Maintain better data normalization and integrity + +This enables powerful queries like "Show me all Fintech funds investing at Series A in Europe" with simple SQL JOINs! 🎉 diff --git a/app/db/models.py b/app/db/models.py index 86acced..a3e0774 100644 --- a/app/db/models.py +++ b/app/db/models.py @@ -70,6 +70,22 @@ project_company_association = Table( Column("company_id", Integer, ForeignKey("companies.id")), ) +# Association table for fund-stage many-to-many +fund_investment_stages_association = Table( + "fund_investment_stages", + Base.metadata, + Column("fund_id", Integer, ForeignKey("funds.id")), + Column("stage_id", Integer, ForeignKey("investment_stages.id")), +) + +# Association table for fund-sector many-to-many +fund_sectors_association = Table( + "fund_sectors", + Base.metadata, + Column("fund_id", Integer, ForeignKey("funds.id")), + Column("sector_id", Integer, ForeignKey("sectors.id")), +) + class InvestorTable(Base, TimestampMixin): __tablename__ = "investors" @@ -172,13 +188,21 @@ class FundTable(Base, TimestampMixin): source_url = Column(String, nullable=True) source_provider = Column(String, nullable=True) # e.g., "Perplexity" - # JSON array fields - geographic_focus = Column(JSON, nullable=True) # Array of regions/countries - investment_stage_focus = Column(JSON, nullable=True) # Array of stages - sector_focus = Column(JSON, nullable=True) # Array of sectors + # Geographic focus as simple string + geographic_focus = Column(String, nullable=True) # Relationships investor = relationship("InvestorTable", back_populates="funds") + investment_stages = relationship( + "InvestmentStageTable", + secondary=fund_investment_stages_association, + back_populates="funds", + ) + sectors = relationship( + "SectorTable", + secondary=fund_sectors_association, + back_populates="funds", + ) class CompanyTable(Base, TimestampMixin): @@ -224,26 +248,43 @@ class CompanyMember(Base, TimestampMixin): company = relationship("CompanyTable", back_populates="members") +class InvestmentStageTable(Base, TimestampMixin): + __tablename__ = "investment_stages" + + id = Column(Integer, primary_key=True, index=True) + name = Column(String, nullable=False, unique=True) + + # Relationships + funds = relationship( + "FundTable", + secondary=fund_investment_stages_association, + back_populates="investment_stages", + ) + + class SectorTable(Base, TimestampMixin): __tablename__ = "sectors" id = Column(Integer, primary_key=True, index=True) name = Column(String, nullable=False) - # Add relationship back to investors + # Relationships investors = relationship( "InvestorTable", secondary=investor_sector_association, back_populates="sectors", ) - companies = relationship( "CompanyTable", secondary=company_sector_association, back_populates="sectors" ) - projects = relationship( "ProjectTable", secondary=project_sector_association, back_populates="sector" ) + funds = relationship( + "FundTable", + secondary=fund_sectors_association, + back_populates="sectors", + ) class ProjectTable(Base, TimestampMixin): diff --git a/app/routers/investors.py b/app/routers/investors.py index b26ebcb..951e04f 100644 --- a/app/routers/investors.py +++ b/app/routers/investors.py @@ -82,8 +82,8 @@ def read_investors(db: Session = Depends(get_db)): check_size_lower=fund.check_size_lower, check_size_upper=fund.check_size_upper, geographic_focus=fund.geographic_focus, - investment_stage_focus=fund.investment_stage_focus, - sector_focus=fund.sector_focus, + fund_investment_stages=fund.investment_stages, # Now a relationship + fund_sectors=fund.sectors, # Now a relationship # Related data (same for all funds of this investor) portfolio_companies=investor.portfolio_companies, team_members=investor.team_members, @@ -113,8 +113,8 @@ def read_investors(db: Session = Depends(get_db)): check_size_lower=None, check_size_upper=None, geographic_focus=None, - investment_stage_focus=None, - sector_focus=None, + fund_investment_stages=None, + fund_sectors=None, # Related data portfolio_companies=investor.portfolio_companies, team_members=investor.team_members, @@ -208,8 +208,8 @@ def filter_investors( check_size_lower=fund.check_size_lower, check_size_upper=fund.check_size_upper, geographic_focus=fund.geographic_focus, - investment_stage_focus=fund.investment_stage_focus, - sector_focus=fund.sector_focus, + fund_investment_stages=fund.investment_stages, # Now a relationship + fund_sectors=fund.sectors, # Now a relationship # Related data portfolio_companies=investor.portfolio_companies, team_members=investor.team_members, @@ -239,8 +239,8 @@ def filter_investors( check_size_lower=None, check_size_upper=None, geographic_focus=None, - investment_stage_focus=None, - sector_focus=None, + fund_investment_stages=None, + fund_sectors=None, # Related data portfolio_companies=investor.portfolio_companies, team_members=investor.team_members, @@ -502,8 +502,8 @@ def find_similar_investors( check_size_lower=fund.check_size_lower, check_size_upper=fund.check_size_upper, geographic_focus=fund.geographic_focus, - investment_stage_focus=fund.investment_stage_focus, - sector_focus=fund.sector_focus, + fund_investment_stages=fund.investment_stages, # Now a relationship + fund_sectors=fund.sectors, # Now a relationship # Related data portfolio_companies=investor.portfolio_companies, team_members=investor.team_members, @@ -533,8 +533,8 @@ def find_similar_investors( check_size_lower=None, check_size_upper=None, geographic_focus=None, - investment_stage_focus=None, - sector_focus=None, + fund_investment_stages=None, + fund_sectors=None, # Related data portfolio_companies=investor.portfolio_companies, team_members=investor.team_members, diff --git a/app/schemas/router_schemas.py b/app/schemas/router_schemas.py index 942f2b1..e43382f 100644 --- a/app/schemas/router_schemas.py +++ b/app/schemas/router_schemas.py @@ -22,6 +22,14 @@ class SectorSchema(BaseModel): from_attributes = True +class InvestmentStageSchema(BaseModel): + id: int + name: str + + class Config: + from_attributes = True + + class InvestorMemberSchema(BaseModel): id: int name: str @@ -41,9 +49,9 @@ class FundSchema(BaseModel): check_size_upper: int | None # NEW: Upper bound of check size range source_url: str | None source_provider: str | None - geographic_focus: List[str] | None - investment_stage_focus: List[str] | None - sector_focus: List[str] | None + geographic_focus: str | None # Changed from List[str] to string + investment_stages: List[InvestmentStageSchema] | None # Changed to relationship + sectors: List[SectorSchema] | None # Changed to relationship created_at: Optional[datetime] = None updated_at: Optional[datetime] = None @@ -134,9 +142,11 @@ class InvestorFundData(BaseModel): fund_size_source_url: str | None check_size_lower: int | None # NEW: Lower bound of check size range check_size_upper: int | None # NEW: Upper bound of check size range - geographic_focus: List[str] | None - investment_stage_focus: List[str] | None - sector_focus: List[str] | None + geographic_focus: str | None # Changed from List[str] to string + fund_investment_stages: ( + List[InvestmentStageSchema] | None + ) # Changed to relationship + fund_sectors: List[SectorSchema] | None # Changed to relationship # Related data portfolio_companies: List[CompanySchema] diff --git a/app/services/llm_parser.py b/app/services/llm_parser.py index 7fbd46d..146f100 100644 --- a/app/services/llm_parser.py +++ b/app/services/llm_parser.py @@ -9,6 +9,7 @@ from db.models import ( CompanyMember, CompanyTable, FundTable, + InvestmentStageTable, InvestorMember, InvestorTable, SectorTable, @@ -223,11 +224,16 @@ Return the lower and upper bounds in USD.""" "check_size_upper": None, "source_url": fund.get("sourceUrl"), "source_provider": fund.get("sourceProvider"), - "geographic_focus": fund.get("geographicFocus", []), - "investment_stage_focus": fund.get("investmentStageFocus", []), - "sector_focus": fund.get("sectorFocus", []), + "geographic_focus": None, # Will be converted to string + "investment_stage_names": fund.get("investmentStageFocus", []), + "sector_names": fund.get("sectorFocus", []), } + # Convert geographic focus from array to comma-separated string + geo_focus = fund.get("geographicFocus", []) + if geo_focus and isinstance(geo_focus, list): + fund_data["geographic_focus"] = ", ".join(geo_focus) + # Convert fund size to USD integer fund_size_str = fund.get("fundSize") if fund_size_str and fund_size_str != "Not Available": @@ -499,15 +505,24 @@ Return the lower and upper bounds in USD.""" fund_name=fund_data.get("fund_name"), fund_size=fund_data.get("fund_size"), # Now an integer fund_size_source_url=fund_data.get("fund_size_source_url"), - check_size_lower=fund_data.get("check_size_lower"), # NEW - check_size_upper=fund_data.get("check_size_upper"), # NEW + check_size_lower=fund_data.get("check_size_lower"), + check_size_upper=fund_data.get("check_size_upper"), source_url=fund_data.get("source_url"), source_provider=fund_data.get("source_provider"), - geographic_focus=fund_data.get("geographic_focus"), - investment_stage_focus=fund_data.get("investment_stage_focus"), - sector_focus=fund_data.get("sector_focus"), + geographic_focus=fund_data.get("geographic_focus"), # Now a string ) db.add(fund) + db.flush() # Get the fund ID + + # Add investment stages (many-to-many) + for stage_name in fund_data.get("investment_stage_names", []): + stage = self._get_or_create_investment_stage(db, stage_name) + fund.investment_stages.append(stage) + + # Add sectors (many-to-many) + for sector_name in fund_data.get("sector_names", []): + sector = self._get_or_create_sector(db, sector_name) + fund.sectors.append(sector) return investor @@ -516,6 +531,23 @@ Return the lower and upper bounds in USD.""" db.rollback() return None + def _get_or_create_investment_stage( + self, db: Session, stage_name: str + ) -> InvestmentStageTable: + """Get existing investment stage or create new one""" + from db.models import InvestmentStageTable + + stage = ( + db.query(InvestmentStageTable) + .filter(InvestmentStageTable.name == stage_name) + .first() + ) + if not stage: + stage = InvestmentStageTable(name=stage_name) + db.add(stage) + db.flush() # Get the ID without committing + return stage + def _get_or_create_sector(self, db: Session, sector_name: str) -> SectorTable: """Get existing sector or create new one""" sector = db.query(SectorTable).filter(SectorTable.name == sector_name).first() diff --git a/preprocessor/migrate_fund_relationships.py b/preprocessor/migrate_fund_relationships.py new file mode 100644 index 0000000..deef75c --- /dev/null +++ b/preprocessor/migrate_fund_relationships.py @@ -0,0 +1,250 @@ +#!/usr/bin/env python3 +""" +Migration script to update fund table schema: +1. Change geographic_focus from JSON to STRING +2. Create investment_stages table and fund_investment_stages association table +3. Create fund_sectors association table for many-to-many with sectors +4. Remove investment_stage_focus and sector_focus JSON columns +""" + +import sqlite3 +from pathlib import Path + + +def migrate_fund_relationships(): + db_path = Path(__file__).parent / "version_two.db" + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + + print("🔄 Starting fund relationships migration...") + + try: + # Step 1: Drop and recreate investment_stages table with correct schema + print("1️⃣ Recreating investment_stages table...") + cursor.execute("DROP TABLE IF EXISTS investment_stages") + cursor.execute(""" + CREATE TABLE investment_stages ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + name VARCHAR NOT NULL UNIQUE, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME + ) + """) + + # Insert standard investment stages + stages = [ + "Seed", + "Pre-Seed", + "Series A", + "Series B", + "Series C", + "Series D+", + "Growth", + "Late Stage", + "IPO", + "Venture", + "Early Stage", + ] + for stage in stages: + cursor.execute( + """ + INSERT OR IGNORE INTO investment_stages (name) VALUES (?) + """, + (stage,), + ) + + print(f" ✅ Created investment_stages table with {len(stages)} stages") + + # Step 2: Create fund_investment_stages association table + print("2️⃣ Creating fund_investment_stages association table...") + cursor.execute(""" + CREATE TABLE IF NOT EXISTS fund_investment_stages ( + fund_id INTEGER NOT NULL, + stage_id INTEGER NOT NULL, + PRIMARY KEY (fund_id, stage_id), + FOREIGN KEY (fund_id) REFERENCES funds (id) ON DELETE CASCADE, + FOREIGN KEY (stage_id) REFERENCES investment_stages (id) ON DELETE CASCADE + ) + """) + print(" ✅ Created fund_investment_stages association table") + + # Step 3: Create fund_sectors association table + print("3️⃣ Creating fund_sectors association table...") + cursor.execute(""" + CREATE TABLE IF NOT EXISTS fund_sectors ( + fund_id INTEGER NOT NULL, + sector_id INTEGER NOT NULL, + PRIMARY KEY (fund_id, sector_id), + FOREIGN KEY (fund_id) REFERENCES funds (id) ON DELETE CASCADE, + FOREIGN KEY (sector_id) REFERENCES sectors (id) ON DELETE CASCADE + ) + """) + print(" ✅ Created fund_sectors association table") + + # Step 4: Get current funds table columns + cursor.execute("PRAGMA table_info(funds)") + columns = {col[1]: col for col in cursor.fetchall()} + print(f"\n📊 Current funds table has {len(columns)} columns") + + # Step 5: Create new funds table with updated schema + print("4️⃣ Creating new funds table schema...") + cursor.execute(""" + CREATE TABLE funds_new ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + investor_id INTEGER NOT NULL, + fund_name VARCHAR, + fund_size INTEGER, + fund_size_source_url VARCHAR, + check_size_lower INTEGER, + check_size_upper INTEGER, + source_url VARCHAR, + source_provider VARCHAR, + geographic_focus VARCHAR, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP, + updated_at DATETIME, + FOREIGN KEY (investor_id) REFERENCES investors (id) + ) + """) + + # Step 6: Copy data from old table to new table + print("5️⃣ Copying data from old funds table...") + cursor.execute(""" + INSERT INTO funds_new ( + id, investor_id, fund_name, fund_size, fund_size_source_url, + check_size_lower, check_size_upper, source_url, source_provider, + geographic_focus, created_at, updated_at + ) + SELECT + id, investor_id, fund_name, fund_size, fund_size_source_url, + check_size_lower, check_size_upper, source_url, source_provider, + CASE + WHEN geographic_focus IS NOT NULL AND geographic_focus != '[]' + THEN REPLACE(REPLACE(geographic_focus, '["', ''), '"]', '') + ELSE NULL + END as geographic_focus, + created_at, updated_at + FROM funds + """) + rows_copied = cursor.rowcount + print(f" ✅ Copied {rows_copied} rows") + + # Step 7: Migrate investment_stage_focus data to association table + print("6️⃣ Migrating investment stage focus data...") + cursor.execute(""" + SELECT id, investment_stage_focus FROM funds + WHERE investment_stage_focus IS NOT NULL AND investment_stage_focus != '[]' + """) + funds_with_stages = cursor.fetchall() + + stage_migrations = 0 + for fund_id, stages_json in funds_with_stages: + if stages_json: + try: + import json + + stages = json.loads(stages_json) + for stage_name in stages: + # Find matching stage + cursor.execute( + """ + SELECT id FROM investment_stages WHERE name = ? + """, + (stage_name,), + ) + result = cursor.fetchone() + if result: + stage_id = result[0] + cursor.execute( + """ + INSERT OR IGNORE INTO fund_investment_stages (fund_id, stage_id) + VALUES (?, ?) + """, + (fund_id, stage_id), + ) + stage_migrations += 1 + except: + pass + + print(f" ✅ Migrated {stage_migrations} stage relationships") + + # Step 8: Migrate sector_focus data to association table + print("7️⃣ Migrating sector focus data...") + cursor.execute(""" + SELECT id, sector_focus FROM funds + WHERE sector_focus IS NOT NULL AND sector_focus != '[]' + """) + funds_with_sectors = cursor.fetchall() + + sector_migrations = 0 + for fund_id, sectors_json in funds_with_sectors: + if sectors_json: + try: + import json + + sectors = json.loads(sectors_json) + for sector_name in sectors: + # Find or create sector + cursor.execute( + """ + SELECT id FROM sectors WHERE name = ? + """, + (sector_name,), + ) + result = cursor.fetchone() + if result: + sector_id = result[0] + else: + cursor.execute( + """ + INSERT INTO sectors (name) VALUES (?) + """, + (sector_name,), + ) + sector_id = cursor.lastrowid + + cursor.execute( + """ + INSERT OR IGNORE INTO fund_sectors (fund_id, sector_id) + VALUES (?, ?) + """, + (fund_id, sector_id), + ) + sector_migrations += 1 + except: + pass + + print(f" ✅ Migrated {sector_migrations} sector relationships") + + # Step 9: Drop old funds table + print("8️⃣ Dropping old funds table...") + cursor.execute("DROP TABLE funds") + + # Step 10: Rename new table to funds + print("9️⃣ Renaming funds_new to funds...") + cursor.execute("ALTER TABLE funds_new RENAME TO funds") + + # Commit all changes + conn.commit() + + print("\n✅ Migration completed successfully!") + print("\n📝 Summary:") + print(f" - Created investment_stages table with {len(stages)} stages") + print(" - Created fund_investment_stages association table") + print(" - Created fund_sectors association table") + print(f" - Migrated {rows_copied} fund records") + print(f" - Migrated {stage_migrations} stage relationships") + print(f" - Migrated {sector_migrations} sector relationships") + print(" - geographic_focus: JSON → STRING") + print(" - investment_stage_focus: REMOVED (now in fund_investment_stages)") + print(" - sector_focus: REMOVED (now in fund_sectors)") + + except Exception as e: + conn.rollback() + print(f"\n❌ Migration failed: {e}") + raise + finally: + conn.close() + + +if __name__ == "__main__": + migrate_fund_relationships() diff --git a/preprocessor/models.py b/preprocessor/models.py index d768803..4897f91 100644 --- a/preprocessor/models.py +++ b/preprocessor/models.py @@ -126,6 +126,22 @@ investor_stage_association = Table( Column("stage_id", Integer, ForeignKey("investment_stages.id")), ) +# Association table for fund-stage many-to-many +fund_investment_stages_association = Table( + "fund_investment_stages", + Base.metadata, + Column("fund_id", Integer, ForeignKey("funds.id")), + Column("stage_id", Integer, ForeignKey("investment_stages.id")), +) + +# Association table for fund-sector many-to-many +fund_sectors_association = Table( + "fund_sectors", + Base.metadata, + Column("fund_id", Integer, ForeignKey("funds.id")), + Column("sector_id", Integer, ForeignKey("sectors.id")), +) + class InvestorTable(Base, TimestampMixin): __tablename__ = "investors" @@ -235,27 +251,40 @@ class FundTable(Base, TimestampMixin): source_url = Column(String, nullable=True) source_provider = Column(String, nullable=True) # e.g., "Perplexity" - # JSON array fields - geographic_focus = Column(JSON, nullable=True) # Array of regions/countries - investment_stage_focus = Column(JSON, nullable=True) # Array of stages - sector_focus = Column(JSON, nullable=True) # Array of sectors + # Geographic focus as simple string + geographic_focus = Column(String, nullable=True) # Relationships investor = relationship("InvestorTable", back_populates="funds") + investment_stages = relationship( + "InvestmentStageTable", + secondary=fund_investment_stages_association, + back_populates="funds", + ) + sectors = relationship( + "SectorTable", + secondary=fund_sectors_association, + back_populates="funds", + ) class InvestmentStageTable(Base, TimestampMixin): __tablename__ = "investment_stages" id = Column(Integer, primary_key=True, index=True) - stage = Column(Enum(InvestmentStage), nullable=False, unique=True) + name = Column(String, nullable=False, unique=True) - # Relationship back to investors + # Relationships investors = relationship( "InvestorTable", secondary=investor_stage_association, back_populates="investment_stages", ) + funds = relationship( + "FundTable", + secondary=fund_investment_stages_association, + back_populates="investment_stages", + ) class CompanyTable(Base, TimestampMixin): @@ -307,20 +336,23 @@ class SectorTable(Base, TimestampMixin): id = Column(Integer, primary_key=True, index=True) name = Column(String, nullable=False) - # Add relationship back to investors + # Relationships investors = relationship( "InvestorTable", secondary=investor_sector_association, back_populates="sectors", ) - companies = relationship( "CompanyTable", secondary=company_sector_association, back_populates="sectors" ) - projects = relationship( "ProjectTable", secondary=project_sector_association, back_populates="sector" ) + funds = relationship( + "FundTable", + secondary=fund_sectors_association, + back_populates="sectors", + ) class ProjectTable(Base, TimestampMixin): diff --git a/preprocessor/version_two.db b/preprocessor/version_two.db index 174cc40..f040109 100644 Binary files a/preprocessor/version_two.db and b/preprocessor/version_two.db differ diff --git a/test_fund_schema.py b/test_fund_schema.py new file mode 100644 index 0000000..b130fad --- /dev/null +++ b/test_fund_schema.py @@ -0,0 +1,123 @@ +#!/usr/bin/env python3 +""" +Quick verification script to test the new fund relationship schema +""" + +import sys + +sys.path.insert(0, "/home/oluwasanmi/Documents/Work/MKD/anton_wireframe/preprocessor") + +from models import FundTable, InvestmentStageTable, SectorTable, get_db_session + + +def test_fund_relationships(): + """Test the new fund relationship schema""" + db = get_db_session() + + print("🧪 Testing Fund Relationship Schema\n") + + # Test 1: Check investment stages + print("1️⃣ Investment Stages:") + stages = db.query(InvestmentStageTable).all() + print(f" Found {len(stages)} stages:") + for stage in stages[:5]: + print(f" - {stage.name}") + print() + + # Test 2: Check fund with relationships + print("2️⃣ Sample Fund with Relationships:") + fund = db.query(FundTable).filter(FundTable.fund_name.isnot(None)).first() + + if fund: + print(f" Fund: {fund.fund_name}") + print(f" Geographic Focus: {fund.geographic_focus}") + + print(f" Investment Stages ({len(fund.investment_stages)}):") + for stage in fund.investment_stages[:3]: + print(f" - {stage.name}") + + print(f" Sectors ({len(fund.sectors)}):") + for sector in fund.sectors[:3]: + print(f" - {sector.name}") + else: + print(" No funds found") + print() + + # Test 3: Check association tables + print("3️⃣ Association Table Stats:") + + # Count fund-stage relationships + from sqlalchemy import text + + result = db.execute(text("SELECT COUNT(*) FROM fund_investment_stages")) + stage_count = result.scalar() + print(f" Fund-Stage relationships: {stage_count}") + + # Count fund-sector relationships + result = db.execute(text("SELECT COUNT(*) FROM fund_sectors")) + sector_count = result.scalar() + print(f" Fund-Sector relationships: {sector_count}") + print() + + # Test 4: Query funds by stage + print("4️⃣ Query Test - Funds with 'Series A' stage:") + series_a_funds = ( + db.query(FundTable) + .join(FundTable.investment_stages) + .filter(InvestmentStageTable.name.ilike("%Series A%")) + .limit(3) + .all() + ) + + print(f" Found {len(series_a_funds)} funds:") + for fund in series_a_funds: + print(f" - {fund.fund_name or 'Unnamed'}") + stages = [s.name for s in fund.investment_stages] + print(f" Stages: {', '.join(stages)}") + print() + + # Test 5: Query funds by sector + print("5️⃣ Query Test - Funds investing in first sector:") + first_sector = db.query(SectorTable).first() + if first_sector: + sector_funds = ( + db.query(FundTable) + .join(FundTable.sectors) + .filter(SectorTable.id == first_sector.id) + .limit(3) + .all() + ) + + print(f" Sector: {first_sector.name}") + print(f" Found {len(sector_funds)} funds:") + for fund in sector_funds: + print(f" - {fund.fund_name or 'Unnamed'}") + print() + + # Test 6: Geographic focus string search + print("6️⃣ Query Test - Funds with Europe in geographic focus:") + europe_funds = ( + db.query(FundTable) + .filter(FundTable.geographic_focus.ilike("%Europe%")) + .limit(3) + .all() + ) + + print(f" Found {len(europe_funds)} funds:") + for fund in europe_funds: + print(f" - {fund.fund_name or 'Unnamed'}") + print(f" Geographic Focus: {fund.geographic_focus}") + print() + + print("✅ All tests completed successfully!") + db.close() + + +if __name__ == "__main__": + try: + test_fund_relationships() + except Exception as e: + print(f"❌ Error: {e}") + import traceback + + traceback.print_exc() diff --git a/version_two.db b/version_two.db new file mode 100644 index 0000000..e69de29