made corrections based on feedback

This commit is contained in:
bolade
2025-11-11 20:27:55 +01:00
parent 5e83734acf
commit 215fec2895
4 changed files with 189 additions and 45 deletions
+14 -7
View File
@@ -13,7 +13,6 @@ from schemas.router_schemas import (
SectorMinimal,
)
from services.compatibility_score import (
calculate_project_investor_compatibility,
_calculate_project_fund_compatibility,
_calculate_project_investor_direct_compatibility,
)
@@ -91,7 +90,9 @@ def read_investors(
selectinload(InvestorTable.portfolio_companies),
selectinload(InvestorTable.team_members),
selectinload(InvestorTable.sectors),
selectinload(InvestorTable.funds).selectinload(FundTable.investment_stages),
selectinload(InvestorTable.funds).selectinload(
FundTable.investment_stages
),
selectinload(InvestorTable.funds).selectinload(FundTable.sectors),
)
.all()
@@ -106,7 +107,9 @@ def read_investors(
selectinload(InvestorTable.portfolio_companies),
selectinload(InvestorTable.team_members),
selectinload(InvestorTable.sectors),
selectinload(InvestorTable.funds).selectinload(FundTable.investment_stages),
selectinload(InvestorTable.funds).selectinload(
FundTable.investment_stages
),
selectinload(InvestorTable.funds).selectinload(FundTable.sectors),
)
.offset(offset)
@@ -143,7 +146,9 @@ def read_investors(
# Get top 3 sectors from fund (id and name only) - sorted alphabetically
fund_sectors = [
SectorMinimal(id=sector.id, name=sector.name)
for sector in sorted(fund.sectors[:3] if fund.sectors else [], key=lambda s: s.name)
for sector in sorted(
fund.sectors[:3] if fund.sectors else [], key=lambda s: s.name
)
]
investment_response = InvestmentResponse(
@@ -188,7 +193,7 @@ def read_investors(
if project is not None:
investment_responses.sort(key=lambda x: x.compatibility_score, reverse=True)
# Apply pagination after sorting
investment_responses = investment_responses[offset:offset + page_size]
investment_responses = investment_responses[offset : offset + page_size]
# Calculate total pages
total_pages = (total_count + page_size - 1) // page_size
@@ -320,7 +325,9 @@ def filter_investors(
# Get top 3 sectors from fund (id and name only) - sorted alphabetically
fund_sectors = [
SectorMinimal(id=sector.id, name=sector.name)
for sector in sorted(fund.sectors[:3] if fund.sectors else [], key=lambda s: s.name)
for sector in sorted(
fund.sectors[:3] if fund.sectors else [], key=lambda s: s.name
)
]
investment_response = InvestmentResponse(
@@ -344,7 +351,7 @@ def filter_investors(
investment_responses.sort(key=lambda x: x.compatibility_score, reverse=True)
# Apply pagination after sorting
offset = (page - 1) * page_size
investment_responses = investment_responses[offset:offset + page_size]
investment_responses = investment_responses[offset : offset + page_size]
# Calculate total pages
total_pages = (total_count + page_size - 1) // page_size
+65 -38
View File
@@ -117,41 +117,41 @@ def _calculate_project_fund_compatibility(
# 2. Sector Overlap (30 points)
sector_score = 0
if project.sector and fund.sectors:
project_sectors = [s for s in project.sector if hasattr(s, 'name')]
fund_sectors = [s for s in fund.sectors if hasattr(s, 'name')]
project_sectors = [s for s in project.sector if hasattr(s, "name")]
fund_sectors = [s for s in fund.sectors if hasattr(s, "name")]
if project_sectors and fund_sectors:
# Use fuzzy matching to account for similar but not identical sector names
match_count = 0
total_matches = 0
for proj_sector in project_sectors:
best_match_score = 0
proj_name = proj_sector.name.lower().strip()
for fund_sector in fund_sectors:
fund_name = fund_sector.name.lower().strip()
# Exact match
if proj_name == fund_name:
best_match_score = 1.0
break
# Fuzzy match using sequence matcher
similarity = SequenceMatcher(None, proj_name, fund_name).ratio()
# Also check if one contains the other (substring match)
if proj_name in fund_name or fund_name in proj_name:
similarity = max(similarity, 0.8)
best_match_score = max(best_match_score, similarity)
# Count matches with threshold
# Perfect match (1.0), strong match (>0.75), partial match (>0.6)
if best_match_score >= 0.6:
total_matches += best_match_score
match_count += 1
if match_count > 0:
# Calculate overlap ratio based on fuzzy matches
overlap_ratio = total_matches / len(project_sectors)
@@ -246,40 +246,40 @@ def _calculate_project_investor_direct_compatibility(
# 2. Sector Overlap (30 points)
sector_score = 0
if project.sector and investor.sectors:
project_sectors = [s for s in project.sector if hasattr(s, 'name')]
investor_sectors = [s for s in investor.sectors if hasattr(s, 'name')]
project_sectors = [s for s in project.sector if hasattr(s, "name")]
investor_sectors = [s for s in investor.sectors if hasattr(s, "name")]
if project_sectors and investor_sectors:
# Use fuzzy matching to account for similar but not identical sector names
match_count = 0
total_matches = 0
for proj_sector in project_sectors:
best_match_score = 0
proj_name = proj_sector.name.lower().strip()
for inv_sector in investor_sectors:
inv_name = inv_sector.name.lower().strip()
# Exact match
if proj_name == inv_name:
best_match_score = 1.0
break
# Fuzzy match using sequence matcher
similarity = SequenceMatcher(None, proj_name, inv_name).ratio()
# Also check if one contains the other (substring match)
if proj_name in inv_name or inv_name in proj_name:
similarity = max(similarity, 0.8)
best_match_score = max(best_match_score, similarity)
# Count matches with threshold
if best_match_score >= 0.6:
total_matches += best_match_score
match_count += 1
if match_count > 0:
# Calculate overlap ratio based on fuzzy matches
overlap_ratio = total_matches / len(project_sectors)
@@ -384,43 +384,70 @@ def _check_geographic_overlap(location1: str, location2: str) -> bool:
# Normalize inputs
loc1 = location1.lower().strip()
loc2 = location2.lower().strip()
# Common geographic groupings with broader regional mappings
geo_groups = [
# North America
["usa", "us", "united states", "america", "u.s.", "u.s.a"],
["canada", "canadian"],
["mexico", "mexican"],
# Europe and countries
["europe", "european", "eu", "germany", "france", "uk", "united kingdom",
"britain", "spain", "italy", "netherlands", "belgium", "sweden", "denmark",
"norway", "finland", "poland", "portugal", "austria", "switzerland",
"ireland", "greece", "czech", "romania"],
[
"europe",
"european",
"eu",
"germany",
"france",
"uk",
"united kingdom",
"britain",
"spain",
"italy",
"netherlands",
"belgium",
"sweden",
"denmark",
"norway",
"finland",
"poland",
"portugal",
"austria",
"switzerland",
"ireland",
"greece",
"czech",
"romania",
],
# UK specific
["uk", "united kingdom", "britain", "england", "scotland", "wales", "london"],
# US states
["california", "ca", "san francisco", "los angeles", "silicon valley"],
["new york", "ny", "nyc"],
["texas", "tx"],
["massachusetts", "ma", "boston"],
["washington", "seattle"],
# Asia
["asia", "asian", "china", "japan", "korea", "singapore", "hong kong",
"india", "indonesia", "thailand", "vietnam", "malaysia", "philippines"],
[
"asia",
"asian",
"china",
"japan",
"korea",
"singapore",
"hong kong",
"india",
"indonesia",
"thailand",
"vietnam",
"malaysia",
"philippines",
],
# Middle East
["middle east", "israel", "uae", "dubai", "saudi arabia"],
# Latin America
["latin america", "brazil", "argentina", "chile", "colombia", "mexico"],
# Africa
["africa", "african", "south africa", "nigeria", "kenya", "egypt"],
# Oceania
["australia", "australian", "new zealand"],
]
@@ -431,7 +458,7 @@ def _check_geographic_overlap(location1: str, location2: str) -> bool:
found_in_2 = any(term in loc2 for term in group)
if found_in_1 and found_in_2:
return True
# Check for direct substring match (one contains the other)
if loc1 in loc2 or loc2 in loc1:
return True
BIN
View File
Binary file not shown.
+110
View File
@@ -0,0 +1,110 @@
"""
Migration: Add fields from feedback fixes
Date: 2025-01-07
Adds the following fields:
- projects.is_archived (INTEGER, default 0)
- companies.product_service (TEXT, nullable)
- companies.clients (TEXT, nullable - stored as JSON string)
- investor_members.linkedin (VARCHAR, nullable)
"""
import os
import sys
from pathlib import Path
# Add parent directory to path to import app modules
sys.path.insert(0, str(Path(__file__).parent.parent))
from sqlalchemy import create_engine, text
from app.db.db import DATABASE_URL, engine
def check_column_exists(conn, table_name, column_name):
    """Report whether ``table_name`` already has a column called ``column_name``.

    Executes ``PRAGMA table_info(<table_name>)`` on ``conn`` and compares
    ``column_name`` against the value at index 1 of every returned row.

    Args:
        conn: Connection object exposing ``execute()``.
        table_name: Table to inspect. It is interpolated into the PRAGMA
            statement as-is, so it must be a trusted identifier (the calls
            in this file pass string literals only).
        column_name: Column name to look for.

    Returns:
        True if a row with a matching name is found, otherwise False.
    """
    pragma = text(f"PRAGMA table_info({table_name})")
    table_info = conn.execute(pragma)
    # row[1] is the field compared against the requested column name.
    return any(row[1] == column_name for row in table_info)
def upgrade():
    """Add the feedback-fix columns, skipping any that already exist.

    Every step runs on the same connection inside one ``engine.begin()``
    block. For each (table, column) pair the column is looked up with
    ``check_column_exists``; when it is missing, the step's SQL statements
    are executed in order. Progress is reported on stdout.
    """
    # Each entry: (table, column, statements to run when the column is absent).
    steps = (
        (
            "projects",
            "is_archived",
            (
                "ALTER TABLE projects ADD COLUMN is_archived INTEGER DEFAULT 0 NOT NULL",
                # Backfill kept from the original migration; the column is
                # declared NOT NULL DEFAULT 0, so this is presumably a
                # safety net rather than a required step.
                "UPDATE projects SET is_archived = 0 WHERE is_archived IS NULL",
            ),
        ),
        (
            "companies",
            "product_service",
            ("ALTER TABLE companies ADD COLUMN product_service TEXT",),
        ),
        (
            "companies",
            "clients",
            ("ALTER TABLE companies ADD COLUMN clients TEXT",),
        ),
        (
            "investor_members",
            "linkedin",
            ("ALTER TABLE investor_members ADD COLUMN linkedin VARCHAR",),
        ),
    )

    print("Running migration: Add feedback fixes fields")
    print("=" * 60)

    # begin() provides the transaction scope for all steps.
    with engine.begin() as conn:
        for number, (table, column, statements) in enumerate(steps, start=1):
            print(f"\n{number}. Adding '{column}' column to {table} table...")
            if check_column_exists(conn, table, column):
                print(f" ✓ Column '{column}' already exists. Skipping.")
                continue
            for statement in statements:
                conn.execute(text(statement))
            print(f" ✓ Successfully added '{column}' column to {table} table")

    print("\n" + "=" * 60)
    print("Migration completed successfully!")
def downgrade():
    """Print manual instructions for removing the columns added by ``upgrade``.

    This function does not touch the database; it only writes guidance and
    the list of affected columns to stdout.
    """
    # NOTE(review): the printed warning says SQLite cannot drop columns, but
    # SQLite 3.35.0+ supports ALTER TABLE ... DROP COLUMN. Confirm the target
    # SQLite version before relying on this message.
    manual_steps = (
        "1. Create new tables without the columns",
        "2. Copy data from old tables to new tables",
        "3. Drop old tables and rename new tables",
    )
    added_columns = (
        "projects.is_archived",
        "companies.product_service",
        "companies.clients",
        "investor_members.linkedin",
    )

    print("Running downgrade: Remove feedback fixes fields")
    print("=" * 60)

    print("\nWarning: SQLite doesn't support DROP COLUMN directly.")
    print("To remove these columns, you would need to:")
    for step in manual_steps:
        print(step)

    print("\nColumns to remove:")
    for qualified_name in added_columns:
        print(f" - {qualified_name}")
if __name__ == "__main__":
    # Command-line entry point: one optional positional argument selects the
    # migration direction and defaults to "upgrade".
    import argparse

    cli = argparse.ArgumentParser(description="Run database migration")
    cli.add_argument(
        "direction",
        nargs="?",
        default="upgrade",
        choices=["upgrade", "downgrade"],
        help="Migration direction (default: upgrade)",
    )
    direction = cli.parse_args().direction

    # argparse's ``choices`` limits the value to the two names handled here.
    run_migration = upgrade if direction == "upgrade" else downgrade
    run_migration()