diff --git a/app/routers/investors.py b/app/routers/investors.py index 5d41df0..19d10e1 100644 --- a/app/routers/investors.py +++ b/app/routers/investors.py @@ -13,7 +13,6 @@ from schemas.router_schemas import ( SectorMinimal, ) from services.compatibility_score import ( - calculate_project_investor_compatibility, _calculate_project_fund_compatibility, _calculate_project_investor_direct_compatibility, ) @@ -91,7 +90,9 @@ def read_investors( selectinload(InvestorTable.portfolio_companies), selectinload(InvestorTable.team_members), selectinload(InvestorTable.sectors), - selectinload(InvestorTable.funds).selectinload(FundTable.investment_stages), + selectinload(InvestorTable.funds).selectinload( + FundTable.investment_stages + ), selectinload(InvestorTable.funds).selectinload(FundTable.sectors), ) .all() @@ -106,7 +107,9 @@ def read_investors( selectinload(InvestorTable.portfolio_companies), selectinload(InvestorTable.team_members), selectinload(InvestorTable.sectors), - selectinload(InvestorTable.funds).selectinload(FundTable.investment_stages), + selectinload(InvestorTable.funds).selectinload( + FundTable.investment_stages + ), selectinload(InvestorTable.funds).selectinload(FundTable.sectors), ) .offset(offset) @@ -143,7 +146,9 @@ def read_investors( # Get top 3 sectors from fund (id and name only) - sorted alphabetically fund_sectors = [ SectorMinimal(id=sector.id, name=sector.name) - for sector in sorted(fund.sectors[:3] if fund.sectors else [], key=lambda s: s.name) + for sector in sorted( + fund.sectors[:3] if fund.sectors else [], key=lambda s: s.name + ) ] investment_response = InvestmentResponse( @@ -188,7 +193,7 @@ def read_investors( if project is not None: investment_responses.sort(key=lambda x: x.compatibility_score, reverse=True) # Apply pagination after sorting - investment_responses = investment_responses[offset:offset + page_size] + investment_responses = investment_responses[offset : offset + page_size] # Calculate total pages total_pages = (total_count + page_size - 1) // page_size @@ -320,7 +325,9 @@ def filter_investors( # Get top 3 sectors from fund (id and name only) - sorted alphabetically fund_sectors = [ SectorMinimal(id=sector.id, name=sector.name) - for sector in sorted(fund.sectors[:3] if fund.sectors else [], key=lambda s: s.name) + for sector in sorted( + fund.sectors[:3] if fund.sectors else [], key=lambda s: s.name + ) ] investment_response = InvestmentResponse( @@ -344,7 +351,7 @@ def filter_investors( investment_responses.sort(key=lambda x: x.compatibility_score, reverse=True) # Apply pagination after sorting offset = (page - 1) * page_size - investment_responses = investment_responses[offset:offset + page_size] + investment_responses = investment_responses[offset : offset + page_size] # Calculate total pages total_pages = (total_count + page_size - 1) // page_size diff --git a/app/services/compatibility_score.py b/app/services/compatibility_score.py index 707c0ba..2253bf4 100644 --- a/app/services/compatibility_score.py +++ b/app/services/compatibility_score.py @@ -117,41 +117,41 @@ def _calculate_project_fund_compatibility( # 2. Sector Overlap (30 points) sector_score = 0 if project.sector and fund.sectors: - project_sectors = [s for s in project.sector if hasattr(s, 'name')] - fund_sectors = [s for s in fund.sectors if hasattr(s, 'name')] - + project_sectors = [s for s in project.sector if hasattr(s, "name")] + fund_sectors = [s for s in fund.sectors if hasattr(s, "name")] + if project_sectors and fund_sectors: # Use fuzzy matching to account for similar but not identical sector names match_count = 0 total_matches = 0 - + for proj_sector in project_sectors: best_match_score = 0 proj_name = proj_sector.name.lower().strip() - + for fund_sector in fund_sectors: fund_name = fund_sector.name.lower().strip() - + # Exact match if proj_name == fund_name: best_match_score = 1.0 break - + # Fuzzy match using sequence matcher similarity = SequenceMatcher(None, proj_name, fund_name).ratio() - + # Also check if one contains the other (substring match) if proj_name in fund_name or fund_name in proj_name: similarity = max(similarity, 0.8) - + best_match_score = max(best_match_score, similarity) - + # Count matches with threshold # Perfect match (1.0), strong match (>0.75), partial match (>0.6) if best_match_score >= 0.6: total_matches += best_match_score match_count += 1 - + if match_count > 0: # Calculate overlap ratio based on fuzzy matches overlap_ratio = total_matches / len(project_sectors) @@ -246,40 +246,40 @@ def _calculate_project_investor_direct_compatibility( # 2. Sector Overlap (30 points) sector_score = 0 if project.sector and investor.sectors: - project_sectors = [s for s in project.sector if hasattr(s, 'name')] - investor_sectors = [s for s in investor.sectors if hasattr(s, 'name')] - + project_sectors = [s for s in project.sector if hasattr(s, "name")] + investor_sectors = [s for s in investor.sectors if hasattr(s, "name")] + if project_sectors and investor_sectors: # Use fuzzy matching to account for similar but not identical sector names match_count = 0 total_matches = 0 - + for proj_sector in project_sectors: best_match_score = 0 proj_name = proj_sector.name.lower().strip() - + for inv_sector in investor_sectors: inv_name = inv_sector.name.lower().strip() - + # Exact match if proj_name == inv_name: best_match_score = 1.0 break - + # Fuzzy match using sequence matcher similarity = SequenceMatcher(None, proj_name, inv_name).ratio() - + # Also check if one contains the other (substring match) if proj_name in inv_name or inv_name in proj_name: similarity = max(similarity, 0.8) - + best_match_score = max(best_match_score, similarity) - + # Count matches with threshold if best_match_score >= 0.6: total_matches += best_match_score match_count += 1 - + if match_count > 0: # Calculate overlap ratio based on fuzzy matches overlap_ratio = total_matches / len(project_sectors) @@ -384,43 +384,70 @@ def _check_geographic_overlap(location1: str, location2: str) -> bool: # Normalize inputs loc1 = location1.lower().strip() loc2 = location2.lower().strip() - + # Common geographic groupings with broader regional mappings geo_groups = [ # North America ["usa", "us", "united states", "america", "u.s.", "u.s.a"], ["canada", "canadian"], ["mexico", "mexican"], - # Europe and countries - ["europe", "european", "eu", "germany", "france", "uk", "united kingdom", - "britain", "spain", "italy", "netherlands", "belgium", "sweden", "denmark", - "norway", "finland", "poland", "portugal", "austria", "switzerland", - "ireland", "greece", "czech", "romania"], - + [ + "europe", + "european", + "eu", + "germany", + "france", + "uk", + "united kingdom", + "britain", + "spain", + "italy", + "netherlands", + "belgium", + "sweden", + "denmark", + "norway", + "finland", + "poland", + "portugal", + "austria", + "switzerland", + "ireland", + "greece", + "czech", + "romania", + ], # UK specific ["uk", "united kingdom", "britain", "england", "scotland", "wales", "london"], - # US states ["california", "ca", "san francisco", "los angeles", "silicon valley"], ["new york", "ny", "nyc"], ["texas", "tx"], ["massachusetts", "ma", "boston"], ["washington", "seattle"], - # Asia - ["asia", "asian", "china", "japan", "korea", "singapore", "hong kong", - "india", "indonesia", "thailand", "vietnam", "malaysia", "philippines"], - + [ + "asia", + "asian", + "china", + "japan", + "korea", + "singapore", + "hong kong", + "india", + "indonesia", + "thailand", + "vietnam", + "malaysia", + "philippines", + ], # Middle East ["middle east", "israel", "uae", "dubai", "saudi arabia"], - # Latin America ["latin america", "brazil", "argentina", "chile", "colombia", "mexico"], - # Africa ["africa", "african", "south africa", "nigeria", "kenya", "egypt"], - # Oceania ["australia", "australian", "new zealand"], ] @@ -431,7 +458,7 @@ def _check_geographic_overlap(location1: str, location2: str) -> bool: found_in_2 = any(term in loc2 for term in group) if found_in_1 and found_in_2: return True - + # Check for direct substring match (one contains the other) if loc1 in loc2 or loc2 in loc1: return True diff --git a/investors.db b/investors.db index 6ec347e..dceae79 100644 Binary files a/investors.db and b/investors.db differ diff --git a/migrations/add_feedback_fixes_2025_01_07.py b/migrations/add_feedback_fixes_2025_01_07.py new file mode 100644 index 0000000..d8056f2 --- /dev/null +++ b/migrations/add_feedback_fixes_2025_01_07.py @@ -0,0 +1,110 @@ +""" +Migration: Add fields from feedback fixes +Date: 2025-01-07 + +Adds the following fields: +- projects.is_archived (INTEGER, default 0) +- companies.product_service (TEXT, nullable) +- companies.clients (TEXT, nullable - stored as JSON string) +- investor_members.linkedin (VARCHAR, nullable) +""" + +import os +import sys +from pathlib import Path + +# Add parent directory to path to import app modules +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from sqlalchemy import create_engine, text +from app.db.db import DATABASE_URL, engine + + +def check_column_exists(conn, table_name, column_name): + """Check if a column exists in a table""" + result = conn.execute(text(f"PRAGMA table_info({table_name})")) + columns = [row[1] for row in result] + return column_name in columns + + +def upgrade(): + """Add new columns to tables""" + print("Running migration: Add feedback fixes fields") + print("=" * 60) + + with engine.begin() as conn: # Use begin() for transaction management + # 1. Add is_archived to projects table + print("\n1. Adding 'is_archived' column to projects table...") + if check_column_exists(conn, "projects", "is_archived"): + print(" ✓ Column 'is_archived' already exists. Skipping.") + else: + conn.execute(text("ALTER TABLE projects ADD COLUMN is_archived INTEGER DEFAULT 0 NOT NULL")) + # Set default value for existing rows + conn.execute(text("UPDATE projects SET is_archived = 0 WHERE is_archived IS NULL")) + print(" ✓ Successfully added 'is_archived' column to projects table") + + # 2. Add product_service to companies table + print("\n2. Adding 'product_service' column to companies table...") + if check_column_exists(conn, "companies", "product_service"): + print(" ✓ Column 'product_service' already exists. Skipping.") + else: + conn.execute(text("ALTER TABLE companies ADD COLUMN product_service TEXT")) + print(" ✓ Successfully added 'product_service' column to companies table") + + # 3. Add clients to companies table + print("\n3. Adding 'clients' column to companies table...") + if check_column_exists(conn, "companies", "clients"): + print(" ✓ Column 'clients' already exists. Skipping.") + else: + conn.execute(text("ALTER TABLE companies ADD COLUMN clients TEXT")) + print(" ✓ Successfully added 'clients' column to companies table") + + # 4. Add linkedin to investor_members table + print("\n4. Adding 'linkedin' column to investor_members table...") + if check_column_exists(conn, "investor_members", "linkedin"): + print(" ✓ Column 'linkedin' already exists. Skipping.") + else: + conn.execute(text("ALTER TABLE investor_members ADD COLUMN linkedin VARCHAR")) + print(" ✓ Successfully added 'linkedin' column to investor_members table") + + print("\n" + "=" * 60) + print("Migration completed successfully!") + + +def downgrade(): + """Remove added columns from tables""" + print("Running downgrade: Remove feedback fixes fields") + print("=" * 60) + + # Note: SQLite doesn't support DROP COLUMN directly + print("\nWarning: SQLite doesn't support DROP COLUMN directly.") + print("To remove these columns, you would need to:") + print("1. Create new tables without the columns") + print("2. Copy data from old tables to new tables") + print("3. Drop old tables and rename new tables") + print("\nColumns to remove:") + print(" - projects.is_archived") + print(" - companies.product_service") + print(" - companies.clients") + print(" - investor_members.linkedin") + + +if __name__ == "__main__": + import argparse + + parser = argparse.ArgumentParser(description="Run database migration") + parser.add_argument( + "direction", + choices=["upgrade", "downgrade"], + default="upgrade", + nargs="?", + help="Migration direction (default: upgrade)" + ) + + args = parser.parse_args() + + if args.direction == "upgrade": + upgrade() + else: + downgrade() +