11 Commits

17 changed files with 1795 additions and 77 deletions
Binary file not shown.
+4
View File
@@ -162,6 +162,7 @@ class InvestorMember(Base, TimestampMixin):
role = Column(String, nullable=True)
title = Column(String, nullable=True) # Alternative to role
email = Column(String, nullable=True)
linkedin = Column(String, nullable=True) # LinkedIn profile URL
source_url = Column(String, nullable=True) # URL where member info was found
investor_id = Column(Integer, ForeignKey("investors.id"))
@@ -215,6 +216,8 @@ class CompanyTable(Base, TimestampMixin):
description = Column(String, nullable=True)
founded_year = Column(Integer, nullable=True)
website = Column(String, nullable=True)
product_service = Column(Text, nullable=True) # Product/service description
clients = Column(JSON, nullable=True) # List of client names or client information
members = relationship(
"CompanyMember", back_populates="company", cascade="all, delete-orphan"
@@ -300,6 +303,7 @@ class ProjectTable(Base, TimestampMixin):
description = Column(Text, nullable=True)
start_date = Column(DateTime, nullable=True)
end_date = Column(DateTime, nullable=True)
is_archived = Column(Integer, default=0, nullable=False) # 0 = active, 1 = archived
sector = relationship(
"SectorTable", secondary=project_sector_association, back_populates="projects"
+730
View File
@@ -0,0 +1,730 @@
"""
LinkedIn Profile Scraper for Investor Members
This module uses crawl4ai to scrape team pages and find LinkedIn profiles.
Strategies:
1. Crawl the source_url (team pages) to extract LinkedIn profile links
2. Use LLM-powered web search to find LinkedIn profiles by name
Key advantages of crawl4ai:
- Handles JavaScript-rendered pages
- Better at extracting content from modern websites
- More reliable than simple requests
"""
import asyncio
import logging
import os
import re
from typing import Dict, List, Optional
from crawl4ai import AsyncWebCrawler
from ddgs import DDGS
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
# Setup logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s"
)
logger = logging.getLogger("linkedin_scraper")
load_dotenv()
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
class LinkedInProfileScraper:
"""
LinkedIn profile finder using crawl4ai and LLM-powered web search.
Strategies:
1. Crawl source URLs (team pages) to extract LinkedIn links
2. Use LLM-powered web search to find profiles by name
"""
def __init__(
self,
rate_limit_delay: float = 0.5,
use_cache: bool = True,
use_llm_search: bool = True,
):
"""
Initialize the scraper
Args:
rate_limit_delay: Delay between requests in seconds
use_cache: Whether to cache crawled pages
use_llm_search: Whether to use LLM-powered web search as fallback
"""
self.rate_limit_delay = rate_limit_delay
self.use_cache = use_cache
self.use_llm_search = use_llm_search and OPENROUTER_API_KEY
self.page_cache: Dict[str, str] = {} # Cache crawled pages by URL
self.html_cache: Dict[str, str] = {} # Cache HTML separately
self.profile_cache: Dict[str, Dict] = {} # Cache results by member
# Initialize LLM agent if API key available
if self.use_llm_search:
self._init_llm_agent()
else:
self.llm = None
self.agent = None
self.ddg_search = None
logger.info("LLM search disabled (no OPENROUTER_API_KEY)")
def _init_llm_agent(self):
"""Initialize LLM agent for web search"""
try:
self.llm = ChatOpenAI(
api_key=OPENROUTER_API_KEY,
base_url="https://openrouter.ai/api/v1",
model="x-ai/grok-4.1-fast:free",
temperature=0,
)
self.ddg_search = DDGS()
logger.info("LLM search agent initialized")
except Exception as e:
logger.error(f"Failed to initialize LLM agent: {e}")
self.llm = None
self.ddg_search = None
def web_search(self, query: str) -> List[Dict]:
"""Tool to search the web using DuckDuckGo"""
if not self.ddg_search:
return []
try:
results = list(self.ddg_search.text(query, max_results=10))
return results
except Exception as e:
logger.error(f"Web search error: {e}")
return []
async def crawl_page(self, url: str) -> Optional[str]:
"""
Crawl a webpage and return its content.
Args:
url: URL to crawl
Returns:
Page content as markdown/text, or None if failed
"""
if not url:
return None
# Check cache first
if self.use_cache and url in self.page_cache:
logger.debug(f"Using cached page for {url}")
return self.page_cache[url]
try:
logger.info(f"Crawling: {url}")
async with AsyncWebCrawler() as crawler:
result = await crawler.arun(url)
if result and result.markdown:
content = result.markdown
# Also get HTML for better link extraction
html_content = result.html if hasattr(result, "html") else ""
# Cache the results
if self.use_cache:
self.page_cache[url] = content
self.html_cache[url] = html_content
return content
except Exception as e:
logger.error(f"Error crawling {url}: {e}")
return None
def extract_linkedin_urls_from_content(self, content: str) -> List[Dict[str, str]]:
"""
Extract all LinkedIn profile URLs from content (HTML or markdown).
Returns:
List of dicts with 'url', 'context', and 'username'
"""
linkedin_links = []
# Pattern for LinkedIn profile URLs (handles country-specific domains)
linkedin_pattern = (
r"https?://(?:www\.)?(?:[a-z]{2,3}\.)?linkedin\.com/in/([a-zA-Z0-9_-]+)/?"
)
# Find all LinkedIn URLs
matches = list(re.finditer(linkedin_pattern, content, re.IGNORECASE))
for match in matches:
url = match.group(0).rstrip("/")
# Normalize URL
url = self._normalize_linkedin_url(url)
# Get surrounding context (200 chars before and after)
start = max(0, match.start() - 200)
end = min(len(content), match.end() + 200)
context = content[start:end]
# Clean up context (remove HTML tags for readability)
context = re.sub(r"<[^>]+>", " ", context)
context = " ".join(context.split()) # Normalize whitespace
linkedin_links.append(
{"url": url, "context": context, "username": match.group(1)}
)
# Remove duplicates while preserving order
seen_urls = set()
unique_links = []
for link in linkedin_links:
if link["url"] not in seen_urls:
seen_urls.add(link["url"])
unique_links.append(link)
return unique_links
def _normalize_linkedin_url(self, url: str) -> str:
"""Normalize LinkedIn URL to standard format"""
# Remove trailing slashes
url = url.rstrip("/")
# Convert country-specific to www
url = re.sub(
r"https?://[a-z]{2,3}\.linkedin\.com", "https://www.linkedin.com", url
)
# Ensure https
if url.startswith("http://"):
url = url.replace("http://", "https://")
return url
def _name_matches_context(self, name: str, context: str) -> float:
"""
Check if a person's name appears in the context around a LinkedIn URL.
Returns:
Confidence score 0-100
"""
if not name or not context:
return 0
context_lower = context.lower()
name_lower = name.lower()
# Split name into parts (handle multiple spaces, titles like "Dr.", etc.)
name_parts = [p for p in name_lower.replace(".", " ").split() if len(p) > 1]
# Check for full name match
if name_lower in context_lower:
return 95
# Check for name parts in context
matches = sum(
1 for part in name_parts if part in context_lower and len(part) > 2
)
if len(name_parts) > 0:
if matches == len(name_parts):
return 90 # All name parts found
elif matches >= 2:
return 75 # At least 2 parts found (first + last typically)
elif matches == 1 and len(name_parts) <= 2:
return 50 # Only one part found but name is short
elif matches == 1:
return 35 # Only one part found
return 0
def _name_matches_username(self, name: str, username: str) -> float:
"""
Check if LinkedIn username contains parts of the name.
Returns:
Confidence score 0-100
"""
if not name or not username:
return 0
name_lower = name.lower()
username_lower = username.lower().replace("-", " ").replace("_", " ")
name_parts = [p for p in name_lower.replace(".", " ").split() if len(p) > 2]
matches = sum(1 for part in name_parts if part in username_lower)
if len(name_parts) > 0:
if matches == len(name_parts) and len(name_parts) >= 2:
return 85 # Full name in username
elif matches >= 2:
return 70 # Multiple parts match
elif matches == 1:
return 35 # Only one part matches
return 0
async def find_linkedin_from_source(
self, name: str, source_url: str, role: Optional[str] = None
) -> Dict:
"""
Find LinkedIn profile by crawling the source URL (team page).
Args:
name: Person's name
source_url: URL of the team/about page
role: Person's role (for additional context matching)
Returns:
Dict with linkedin_url, confidence, method, notes
"""
if not source_url:
return {
"linkedin_url": None,
"confidence": 0,
"method": "source_crawl",
"notes": "No source URL provided",
}
# Crawl the page
content = await self.crawl_page(source_url)
if not content:
return {
"linkedin_url": None,
"confidence": 0,
"method": "source_crawl",
"notes": f"Failed to crawl {source_url}",
}
# Get HTML for better link extraction
html = self.html_cache.get(source_url, content)
# Extract all LinkedIn URLs from both HTML and markdown
linkedin_links = self.extract_linkedin_urls_from_content(html)
if not linkedin_links:
linkedin_links = self.extract_linkedin_urls_from_content(content)
if not linkedin_links:
return {
"linkedin_url": None,
"confidence": 0,
"method": "source_crawl",
"notes": f"No LinkedIn URLs found on {source_url}",
}
# Score each LinkedIn URL based on name matching
best_match = None
best_score = 0
for link in linkedin_links:
# Score based on context matching
context_score = self._name_matches_context(name, link["context"])
# Score based on username matching
username_score = self._name_matches_username(name, link["username"])
# Also check if role appears in context
role_bonus = 0
if role and role.lower() in link["context"].lower():
role_bonus = 10
# Combined score (take best of context or username, plus role bonus)
total_score = max(context_score, username_score) + role_bonus
logger.debug(
f" {name} -> {link['url']}: context={context_score}, username={username_score}, role={role_bonus}, total={total_score}"
)
if total_score > best_score:
best_score = total_score
best_match = link
if best_match and best_score >= 30: # Minimum threshold
return {
"linkedin_url": best_match["url"],
"confidence": min(best_score, 100),
"method": "source_crawl",
"notes": f"Found on {source_url}",
}
return {
"linkedin_url": None,
"confidence": 0,
"method": "source_crawl",
"notes": f'No matching LinkedIn profile found for "{name}" on {source_url}',
}
async def find_linkedin_via_search(
self, name: str, company: str, role: Optional[str] = None
) -> Dict:
"""
Find LinkedIn profile using web search.
Args:
name: Person's name
company: Company/investor name
role: Person's role (optional)
Returns:
Dict with linkedin_url, confidence, method, notes
"""
if not self.ddg_search:
return {
"linkedin_url": None,
"confidence": 0,
"method": "web_search",
"notes": "Web search not available",
}
try:
# Build search query - search for LinkedIn profile
query = f"{name} {company} site:linkedin.com/in"
if role:
query = f"{name} {role} {company} site:linkedin.com/in"
logger.debug(f"Searching: {query}")
results = self.web_search(query)
if results:
# Look for LinkedIn profile URLs in results
linkedin_pattern = r"https?://(?:www\.)?(?:[a-z]{2,3}\.)?linkedin\.com/in/([a-zA-Z0-9_-]+)"
for result in results:
url = result.get("href") or result.get("link") or ""
title = result.get("title", "").lower()
body = result.get("body", "").lower()
match = re.search(linkedin_pattern, url, re.IGNORECASE)
if match:
linkedin_url = self._normalize_linkedin_url(match.group(0))
username = match.group(1)
# Score based on name matching in title/body and username
context = f"{title} {body}"
context_score = self._name_matches_context(name, context)
username_score = self._name_matches_username(name, username)
total_score = max(context_score, username_score)
if total_score >= 30:
return {
"linkedin_url": linkedin_url,
"confidence": min(
total_score, 90
), # Cap at 90 for search results
"method": "web_search",
"notes": "Found via web search",
}
return {
"linkedin_url": None,
"confidence": 0,
"method": "web_search",
"notes": "No matching profile found in search results",
}
except Exception as e:
logger.error(f"Web search error for {name}: {e}")
return {
"linkedin_url": None,
"confidence": 0,
"method": "web_search",
"notes": f"Search error: {str(e)}",
}
async def find_linkedin_profile(
self,
name: str,
company: str,
role: Optional[str] = None,
source_url: Optional[str] = None,
) -> Dict:
"""
Find LinkedIn profile for a person.
Primary strategy: Crawl source URL to find LinkedIn links.
Args:
name: Person's name
company: Company/investor name
role: Person's role/title (optional)
source_url: URL where person info was found (optional)
Returns:
Dict with:
- linkedin_url: Found LinkedIn URL or None
- confidence: Confidence score (0-100)
- method: Method used to find the profile
- notes: Additional information
"""
cache_key = f"{name}|{company}"
# Check cache
if self.use_cache and cache_key in self.profile_cache:
logger.debug(f"Using cached result for {name}")
return self.profile_cache[cache_key]
result = {"linkedin_url": None, "confidence": 0, "method": "none", "notes": ""}
# Primary strategy: Crawl source URL
if source_url:
result = await self.find_linkedin_from_source(name, source_url, role)
if result["linkedin_url"]:
if self.use_cache:
self.profile_cache[cache_key] = result
return result
# Fallback strategy: Web search (if enabled and no result from source crawl)
if self.use_llm_search and not result.get("linkedin_url"):
search_result = await self.find_linkedin_via_search(name, company, role)
if search_result["linkedin_url"]:
if self.use_cache:
self.profile_cache[cache_key] = search_result
return search_result
# If no source URL or no match found
if not result["linkedin_url"]:
result = {
"linkedin_url": None,
"confidence": 0,
"method": "none",
"notes": "No source URL available"
if not source_url
else result.get("notes", "Not found"),
}
if self.use_cache:
self.profile_cache[cache_key] = result
return result
async def batch_find_profiles(
self, members: List[Dict], progress_callback=None, db_callback=None
) -> List[Dict]:
"""
Find LinkedIn profiles for multiple members efficiently.
Groups members by source_url to minimize crawling the same page multiple times.
Args:
members: List of dicts with 'name', 'company', 'role', 'source_url', 'id'
progress_callback: Optional callback function(current, total, result)
db_callback: Optional callback to save to database immediately when profile found
Signature: db_callback(member_id, linkedin_url) -> bool
Returns:
List of results for each member
"""
results = []
total = len(members)
# Group members by source_url for efficient crawling
url_groups: Dict[str, List[Dict]] = {}
no_url_members = []
for member in members:
url = member.get("source_url")
if url:
if url not in url_groups:
url_groups[url] = []
url_groups[url].append(member)
else:
no_url_members.append(member)
logger.info(
f"Processing {len(url_groups)} unique source URLs for {total} members"
)
logger.info(f"Members with source URLs: {total - len(no_url_members)}")
logger.info(f"Members without source URLs: {len(no_url_members)}")
if self.use_llm_search:
logger.info("Web search fallback: ENABLED")
else:
logger.info("Web search fallback: DISABLED")
processed = 0
# Process members grouped by URL (efficient - one crawl per page)
for url, group_members in url_groups.items():
# Crawl the page once
content = await self.crawl_page(url)
html = self.html_cache.get(url, content or "")
# Extract all LinkedIn URLs from this page
linkedin_links = []
if content:
linkedin_links = self.extract_linkedin_urls_from_content(html)
if not linkedin_links:
linkedin_links = self.extract_linkedin_urls_from_content(content)
# Match each member in this group
for member in group_members:
processed += 1
result = None
found_linkedin = False
if linkedin_links:
# Find best matching LinkedIn for this member
best_match = None
best_score = 0
for link in linkedin_links:
context_score = self._name_matches_context(
member["name"], link["context"]
)
username_score = self._name_matches_username(
member["name"], link["username"]
)
role_bonus = (
10
if member.get("role")
and member["role"].lower() in link["context"].lower()
else 0
)
total_score = max(context_score, username_score) + role_bonus
if total_score > best_score:
best_score = total_score
best_match = link
if best_match and best_score >= 30:
result = {
"linkedin_url": best_match["url"],
"confidence": min(best_score, 100),
"method": "source_crawl",
"notes": f"Found on {url}",
"member_id": member.get("id"),
"member_name": member["name"],
}
found_linkedin = True
# Save to database immediately if callback provided
if db_callback and member.get("id"):
db_callback(member["id"], best_match["url"])
# If no result from source crawl, try web search IMMEDIATELY
if not found_linkedin and self.use_llm_search:
search_result = await self.find_linkedin_via_search(
member["name"], member["company"], member.get("role")
)
if search_result["linkedin_url"]:
result = {
"linkedin_url": search_result["linkedin_url"],
"confidence": search_result["confidence"],
"method": "web_search",
"notes": search_result.get("notes", "Found via web search"),
"member_id": member.get("id"),
"member_name": member["name"],
}
found_linkedin = True
# Save to database immediately
if db_callback and member.get("id"):
db_callback(member["id"], search_result["linkedin_url"])
# If still no result, record as not found
if not found_linkedin:
result = {
"linkedin_url": None,
"confidence": 0,
"method": "source_crawl" if content else "none",
"notes": f"No match on {url}"
if linkedin_links
else (
f"No LinkedIn URLs on {url}"
if content
else f"Failed to crawl {url}"
),
"member_id": member.get("id"),
"member_name": member["name"],
}
results.append(result)
if progress_callback:
progress_callback(processed, total, result)
# Small delay between different URLs
await asyncio.sleep(self.rate_limit_delay)
# Process members without source URLs - do web search immediately for each
for member in no_url_members:
processed += 1
result = None
# Try web search immediately
if self.use_llm_search:
search_result = await self.find_linkedin_via_search(
member["name"], member["company"], member.get("role")
)
if search_result["linkedin_url"]:
result = {
"linkedin_url": search_result["linkedin_url"],
"confidence": search_result["confidence"],
"method": "web_search",
"notes": search_result.get("notes", "Found via web search"),
"member_id": member.get("id"),
"member_name": member["name"],
}
# Save to database immediately
if db_callback and member.get("id"):
db_callback(member["id"], search_result["linkedin_url"])
# If no result from search
if not result:
result = {
"linkedin_url": None,
"confidence": 0,
"method": "web_search" if self.use_llm_search else "none",
"notes": "No LinkedIn profile found"
if self.use_llm_search
else "No source URL available",
"member_id": member.get("id"),
"member_name": member["name"],
}
results.append(result)
if progress_callback:
progress_callback(processed, total, result)
# Rate limit between searches
await asyncio.sleep(self.rate_limit_delay)
return results
def format_linkedin_url(url: str) -> str:
"""Normalize LinkedIn URL format"""
if not url:
return url
# Remove trailing slashes
url = url.rstrip("/")
# Ensure https and normalize to www
url = re.sub(r"https?://[a-z]{2,3}\.linkedin\.com", "https://www.linkedin.com", url)
if url.startswith("http://"):
url = url.replace("http://", "https://")
return url
# Async wrapper for sync contexts
def run_batch_scraper(
members: List[Dict], rate_limit: float = 0.5, progress_callback=None
) -> List[Dict]:
"""
Synchronous wrapper for batch_find_profiles.
Args:
members: List of member dicts
rate_limit: Delay between URL crawls
progress_callback: Optional progress callback
Returns:
List of results
"""
scraper = LinkedInProfileScraper(rate_limit_delay=rate_limit)
return asyncio.run(scraper.batch_find_profiles(members, progress_callback))
+2
View File
@@ -8,6 +8,7 @@ from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from routers import (
addition,
companies,
folk_crm,
insight_route,
@@ -154,6 +155,7 @@ app.include_router(projects.router)
app.include_router(folk_crm.router)
app.include_router(insight_route.router)
app.include_router(report_route.router)
app.include_router(addition.router)
if __name__ == "__main__":
import uvicorn
+370
View File
@@ -0,0 +1,370 @@
from typing import Optional
from db.db import get_db
from db.models import FundTable, InvestorTable, SectorTable
from fastapi import APIRouter, Depends
from pydantic import BaseModel
from sqlalchemy.orm import Session
router = APIRouter(tags=["Additional Routes"])
# Response schemas
class SectorsResponse(BaseModel):
sectors: list[str]
total: int
class CountryInfo(BaseModel):
name: str
class ContinentInfo(BaseModel):
name: str
countries: list[str]
class GeographyResponse(BaseModel):
continents: list[ContinentInfo]
total_continents: int
total_countries: int
# Mapping of countries to continents
COUNTRY_TO_CONTINENT = {
# Africa
"Algeria": "Africa",
"Angola": "Africa",
"Benin": "Africa",
"Botswana": "Africa",
"Burkina Faso": "Africa",
"Burundi": "Africa",
"Cameroon": "Africa",
"Cape Verde": "Africa",
"Central African Republic": "Africa",
"Chad": "Africa",
"Comoros": "Africa",
"Congo": "Africa",
"Democratic Republic of the Congo": "Africa",
"Djibouti": "Africa",
"Egypt": "Africa",
"Equatorial Guinea": "Africa",
"Eritrea": "Africa",
"Eswatini": "Africa",
"Ethiopia": "Africa",
"Gabon": "Africa",
"Gambia": "Africa",
"Ghana": "Africa",
"Guinea": "Africa",
"Guinea-Bissau": "Africa",
"Ivory Coast": "Africa",
"Kenya": "Africa",
"Lesotho": "Africa",
"Liberia": "Africa",
"Libya": "Africa",
"Madagascar": "Africa",
"Malawi": "Africa",
"Mali": "Africa",
"Mauritania": "Africa",
"Mauritius": "Africa",
"Morocco": "Africa",
"Mozambique": "Africa",
"Namibia": "Africa",
"Niger": "Africa",
"Nigeria": "Africa",
"Rwanda": "Africa",
"Sao Tome and Principe": "Africa",
"Senegal": "Africa",
"Seychelles": "Africa",
"Sierra Leone": "Africa",
"Somalia": "Africa",
"South Africa": "Africa",
"South Sudan": "Africa",
"Sudan": "Africa",
"Tanzania": "Africa",
"Togo": "Africa",
"Tunisia": "Africa",
"Uganda": "Africa",
"Zambia": "Africa",
"Zimbabwe": "Africa",
# Asia
"Afghanistan": "Asia",
"Armenia": "Asia",
"Azerbaijan": "Asia",
"Bahrain": "Asia",
"Bangladesh": "Asia",
"Bhutan": "Asia",
"Brunei": "Asia",
"Cambodia": "Asia",
"China": "Asia",
"Cyprus": "Asia",
"Georgia": "Asia",
"Hong Kong": "Asia",
"India": "Asia",
"Indonesia": "Asia",
"Iran": "Asia",
"Iraq": "Asia",
"Israel": "Asia",
"Japan": "Asia",
"Jordan": "Asia",
"Kazakhstan": "Asia",
"Kuwait": "Asia",
"Kyrgyzstan": "Asia",
"Laos": "Asia",
"Lebanon": "Asia",
"Malaysia": "Asia",
"Maldives": "Asia",
"Mongolia": "Asia",
"Myanmar": "Asia",
"Nepal": "Asia",
"North Korea": "Asia",
"Oman": "Asia",
"Pakistan": "Asia",
"Palestine": "Asia",
"Philippines": "Asia",
"Qatar": "Asia",
"Saudi Arabia": "Asia",
"Singapore": "Asia",
"South Korea": "Asia",
"Sri Lanka": "Asia",
"Syria": "Asia",
"Taiwan": "Asia",
"Tajikistan": "Asia",
"Thailand": "Asia",
"Timor-Leste": "Asia",
"Turkey": "Asia",
"Turkmenistan": "Asia",
"United Arab Emirates": "Asia",
"UAE": "Asia",
"Uzbekistan": "Asia",
"Vietnam": "Asia",
"Yemen": "Asia",
# Europe
"Albania": "Europe",
"Andorra": "Europe",
"Austria": "Europe",
"Belarus": "Europe",
"Belgium": "Europe",
"Bosnia and Herzegovina": "Europe",
"Bulgaria": "Europe",
"Croatia": "Europe",
"Czech Republic": "Europe",
"Czechia": "Europe",
"Denmark": "Europe",
"Estonia": "Europe",
"Finland": "Europe",
"France": "Europe",
"Germany": "Europe",
"Greece": "Europe",
"Hungary": "Europe",
"Iceland": "Europe",
"Ireland": "Europe",
"Italy": "Europe",
"Kosovo": "Europe",
"Latvia": "Europe",
"Liechtenstein": "Europe",
"Lithuania": "Europe",
"Luxembourg": "Europe",
"Malta": "Europe",
"Moldova": "Europe",
"Monaco": "Europe",
"Montenegro": "Europe",
"Netherlands": "Europe",
"North Macedonia": "Europe",
"Norway": "Europe",
"Poland": "Europe",
"Portugal": "Europe",
"Romania": "Europe",
"Russia": "Europe",
"San Marino": "Europe",
"Serbia": "Europe",
"Slovakia": "Europe",
"Slovenia": "Europe",
"Spain": "Europe",
"Sweden": "Europe",
"Switzerland": "Europe",
"Ukraine": "Europe",
"United Kingdom": "Europe",
"UK": "Europe",
"Vatican City": "Europe",
# North America
"Antigua and Barbuda": "North America",
"Bahamas": "North America",
"Barbados": "North America",
"Belize": "North America",
"Canada": "North America",
"Costa Rica": "North America",
"Cuba": "North America",
"Dominica": "North America",
"Dominican Republic": "North America",
"El Salvador": "North America",
"Grenada": "North America",
"Guatemala": "North America",
"Haiti": "North America",
"Honduras": "North America",
"Jamaica": "North America",
"Mexico": "North America",
"Nicaragua": "North America",
"Panama": "North America",
"Saint Kitts and Nevis": "North America",
"Saint Lucia": "North America",
"Saint Vincent and the Grenadines": "North America",
"Trinidad and Tobago": "North America",
"United States": "North America",
"USA": "North America",
"US": "North America",
# South America
"Argentina": "South America",
"Bolivia": "South America",
"Brazil": "South America",
"Chile": "South America",
"Colombia": "South America",
"Ecuador": "South America",
"Guyana": "South America",
"Paraguay": "South America",
"Peru": "South America",
"Suriname": "South America",
"Uruguay": "South America",
"Venezuela": "South America",
# Oceania
"Australia": "Oceania",
"Fiji": "Oceania",
"Kiribati": "Oceania",
"Marshall Islands": "Oceania",
"Micronesia": "Oceania",
"Nauru": "Oceania",
"New Zealand": "Oceania",
"Palau": "Oceania",
"Papua New Guinea": "Oceania",
"Samoa": "Oceania",
"Solomon Islands": "Oceania",
"Tonga": "Oceania",
"Tuvalu": "Oceania",
"Vanuatu": "Oceania",
}
# Valid continent names for direct matching
VALID_CONTINENTS = {
"Africa",
"Asia",
"Europe",
"North America",
"South America",
"Oceania",
"Antarctica",
}
def extract_countries_from_geographic_focus(geographic_focus: str) -> set[str]:
"""
Extract country names from a geographic_focus string.
Handles comma-separated values, slashes, and various formats.
"""
if not geographic_focus:
return set()
countries = set()
# Split by common delimiters
parts = geographic_focus.replace("/", ",").replace(";", ",").split(",")
for part in parts:
cleaned = part.strip()
if cleaned:
# Check if it's a known country
if cleaned in COUNTRY_TO_CONTINENT:
countries.add(cleaned)
# Check for partial matches (e.g., "United States of America" -> "United States")
else:
for country in COUNTRY_TO_CONTINENT.keys():
if country.lower() in cleaned.lower() or cleaned.lower() in country.lower():
countries.add(country)
break
return countries
def organize_geography(geographic_data: list[str]) -> dict[str, set[str]]:
"""
Organize geographic data into continents and their countries.
Returns a dict with continent names as keys and sets of countries as values.
"""
continent_countries: dict[str, set[str]] = {}
for geo_focus in geographic_data:
if not geo_focus:
continue
# Extract countries from the geographic focus string
countries = extract_countries_from_geographic_focus(geo_focus)
for country in countries:
continent = COUNTRY_TO_CONTINENT.get(country)
if continent:
if continent not in continent_countries:
continent_countries[continent] = set()
continent_countries[continent].add(country)
# Also check if the geographic focus itself is a continent
cleaned_geo = geo_focus.strip()
if cleaned_geo in VALID_CONTINENTS:
if cleaned_geo not in continent_countries:
continent_countries[cleaned_geo] = set()
return continent_countries
@router.get("/sectors", response_model=SectorsResponse)
def get_unique_sectors(db: Session = Depends(get_db)):
"""
Get all unique sectors from the database.
Returns a list of sector names sorted alphabetically.
"""
sectors = db.query(SectorTable.name).distinct().order_by(SectorTable.name).all()
sector_names = [s[0] for s in sectors if s[0]]
return SectorsResponse(sectors=sector_names, total=len(sector_names))
@router.get("/geography", response_model=GeographyResponse)
def get_arranged_geography(db: Session = Depends(get_db)):
"""
Get all unique geographic locations arranged by continent and countries.
Extracts geography from both investors and funds tables.
Returns continents with their associated countries.
"""
# Collect all geographic focus data from investors
investor_geo = (
db.query(InvestorTable.geographic_focus)
.filter(InvestorTable.geographic_focus.isnot(None))
.distinct()
.all()
)
# Collect all geographic focus data from funds
fund_geo = (
db.query(FundTable.geographic_focus)
.filter(FundTable.geographic_focus.isnot(None))
.distinct()
.all()
)
# Combine all geographic data
all_geo_data = [g[0] for g in investor_geo] + [g[0] for g in fund_geo]
# Organize into continents and countries
continent_countries = organize_geography(all_geo_data)
# Build response
continents = []
total_countries = 0
for continent_name in sorted(continent_countries.keys()):
countries = sorted(continent_countries[continent_name])
total_countries += len(countries)
continents.append(ContinentInfo(name=continent_name, countries=countries))
return GeographyResponse(
continents=continents,
total_continents=len(continents),
total_countries=total_countries,
)
+14 -4
View File
@@ -63,11 +63,13 @@ def read_companies(
# Transform CompanyTable objects to CompanyData format
company_data_list = []
for company in companies:
# Sort sectors alphabetically
sorted_sectors = sorted(company.sectors, key=lambda s: s.name) if company.sectors else []
company_data = CompanyData(
company=company,
investors=company.investors,
members=company.members,
sectors=company.sectors,
sectors=sorted_sectors,
)
company_data_list.append(company_data)
@@ -147,11 +149,13 @@ def filter_companies(
# Transform to CompanyData format
company_data_list = []
for company in companies:
# Sort sectors alphabetically
sorted_sectors = sorted(company.sectors, key=lambda s: s.name) if company.sectors else []
company_data = CompanyData(
company=company,
investors=company.investors,
members=company.members,
sectors=company.sectors,
sectors=sorted_sectors,
)
company_data_list.append(company_data)
@@ -184,12 +188,15 @@ def read_company(company_id: int, db: Session = Depends(get_db)):
if not company:
raise HTTPException(status_code=404, detail="Company not found")
# Sort sectors alphabetically
sorted_sectors = sorted(company.sectors, key=lambda s: s.name) if company.sectors else []
# Transform to CompanyData format
return CompanyData(
company=company,
investors=company.investors,
members=company.members,
sectors=company.sectors,
sectors=sorted_sectors,
)
@@ -250,12 +257,15 @@ def update_company(
.first()
)
# Sort sectors alphabetically
sorted_sectors = sorted(company_with_relations.sectors, key=lambda s: s.name) if company_with_relations.sectors else []
# Transform to CompanyData format
return CompanyData(
company=company_with_relations,
investors=company_with_relations.investors,
members=company_with_relations.members,
sectors=company_with_relations.sectors,
sectors=sorted_sectors,
)
+6
View File
@@ -136,6 +136,11 @@ def sync_investors_to_folk(
if hasattr(member, "source_url") and member.source_url:
urls_list = [member.source_url]
# Get LinkedIn URL if available
linkedin_url = None
if hasattr(member, "linkedin") and member.linkedin:
linkedin_url = member.linkedin
# Build job title from title or role
job_title = None
if hasattr(member, "title") and member.title:
@@ -149,6 +154,7 @@ def sync_investors_to_folk(
email=member.email,
company_id=company_id,
group_id=request.group_id,
linkedin_url=linkedin_url,
urls=urls_list,
jobTitle=job_title,
)
+66 -21
View File
@@ -13,7 +13,6 @@ from schemas.router_schemas import (
SectorMinimal,
)
from services.compatibility_score import (
calculate_project_investor_compatibility,
_calculate_project_fund_compatibility,
_calculate_project_investor_direct_compatibility,
)
@@ -81,20 +80,42 @@ def read_investors(
if not project:
raise HTTPException(status_code=404, detail="Project not found")
# Get paginated results
investors = (
db.query(InvestorTable)
.options(
selectinload(InvestorTable.portfolio_companies),
selectinload(InvestorTable.team_members),
selectinload(InvestorTable.sectors),
selectinload(InvestorTable.funds).selectinload(FundTable.investment_stages),
selectinload(InvestorTable.funds).selectinload(FundTable.sectors),
# When project_id is provided, we need to get all investors first to sort by compatibility score
# Otherwise, we can paginate at the database level
if project is not None:
# Get all investors (we'll sort by compatibility score, then paginate)
all_investors = (
db.query(InvestorTable)
.options(
selectinload(InvestorTable.portfolio_companies),
selectinload(InvestorTable.team_members),
selectinload(InvestorTable.sectors),
selectinload(InvestorTable.funds).selectinload(
FundTable.investment_stages
),
selectinload(InvestorTable.funds).selectinload(FundTable.sectors),
)
.all()
)
# We'll paginate after sorting by compatibility score
investors = all_investors
else:
# Get paginated results (no sorting needed)
investors = (
db.query(InvestorTable)
.options(
selectinload(InvestorTable.portfolio_companies),
selectinload(InvestorTable.team_members),
selectinload(InvestorTable.sectors),
selectinload(InvestorTable.funds).selectinload(
FundTable.investment_stages
),
selectinload(InvestorTable.funds).selectinload(FundTable.sectors),
)
.offset(offset)
.limit(page_size)
.all()
)
.offset(offset)
.limit(page_size)
.all()
)
# Transform to InvestmentResponse format (one row per investor-fund combination)
investment_responses = []
@@ -122,10 +143,12 @@ def read_investors(
else None
)
# Get top 3 sectors from fund (id and name only)
# Get top 3 sectors from fund (id and name only) - sorted alphabetically
fund_sectors = [
SectorMinimal(id=sector.id, name=sector.name)
for sector in (fund.sectors[:3] if fund.sectors else [])
for sector in sorted(
fund.sectors[:3] if fund.sectors else [], key=lambda s: s.name
)
]
investment_response = InvestmentResponse(
@@ -166,6 +189,12 @@ def read_investors(
)
investment_responses.append(investment_response)
# Sort by compatibility score (descending) when project_id is provided
if project is not None:
investment_responses.sort(key=lambda x: x.compatibility_score, reverse=True)
# Apply pagination after sorting
investment_responses = investment_responses[offset : offset + page_size]
# Calculate total pages
total_pages = (total_count + page_size - 1) // page_size
@@ -257,9 +286,16 @@ def filter_investors(
# Get total count before pagination
total_count = query.count()
# Calculate offset and apply pagination
offset = (page - 1) * page_size
funds = query.offset(offset).limit(page_size).all()
# When project_id is provided, we need to get all funds first to sort by compatibility score
# Otherwise, we can paginate at the database level
if project is not None:
# Get all funds (we'll sort by compatibility score, then paginate)
all_funds = query.all()
funds = all_funds
else:
# Calculate offset and apply pagination (no sorting needed)
offset = (page - 1) * page_size
funds = query.offset(offset).limit(page_size).all()
# Transform to InvestmentResponse format (one row per fund)
investment_responses = []
@@ -286,10 +322,12 @@ def filter_investors(
else None
)
# Get top 3 sectors from fund (id and name only)
# Get top 3 sectors from fund (id and name only) - sorted alphabetically
fund_sectors = [
SectorMinimal(id=sector.id, name=sector.name)
for sector in (fund.sectors[:3] if fund.sectors else [])
for sector in sorted(
fund.sectors[:3] if fund.sectors else [], key=lambda s: s.name
)
]
investment_response = InvestmentResponse(
@@ -308,6 +346,13 @@ def filter_investors(
)
investment_responses.append(investment_response)
# Sort by compatibility score (descending) when project_id is provided
if project is not None:
investment_responses.sort(key=lambda x: x.compatibility_score, reverse=True)
# Apply pagination after sorting
offset = (page - 1) * page_size
investment_responses = investment_responses[offset : offset + page_size]
# Calculate total pages
total_pages = (total_count + page_size - 1) // page_size
+96 -5
View File
@@ -24,19 +24,29 @@ router = APIRouter(tags=["Project Routes"])
def read_projects(
page: int = Query(1, ge=1, description="Page number (starts at 1)"),
page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
include_archived: bool = Query(False, description="Include archived projects"),
db: Session = Depends(get_db),
):
"""Get all projects with their related data (paginated)"""
"""Get all projects with their related data (paginated)
By default, archived projects are excluded. Set include_archived=True to include them.
"""
# Calculate offset
offset = (page - 1) * page_size
# Start with base query
query = db.query(ProjectTable)
# Filter out archived projects by default
if not include_archived:
query = query.filter(ProjectTable.is_archived == 0)
# Get total count
total_count = db.query(ProjectTable).count()
total_count = query.count()
# Get paginated results
projects = (
db.query(ProjectTable)
.options(
query.options(
selectinload(ProjectTable.sector),
selectinload(ProjectTable.investors),
selectinload(ProjectTable.companies),
@@ -162,7 +172,7 @@ def update_project(
@router.delete("/projects/{project_id}")
def delete_project(project_id: int, db: Session = Depends(get_db)):
"""Delete a project"""
"""Delete a project permanently"""
db_project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
if not db_project:
@@ -174,6 +184,87 @@ def delete_project(project_id: int, db: Session = Depends(get_db)):
return {"message": "Project deleted successfully"}
@router.post("/projects/{project_id}/archive")
def archive_project(project_id: int, db: Session = Depends(get_db)):
"""Archive a project (soft delete)"""
db_project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
if not db_project:
raise HTTPException(status_code=404, detail="Project not found")
db_project.is_archived = 1
db.commit()
db.refresh(db_project)
return {"message": "Project archived successfully", "project_id": project_id}
@router.post("/projects/{project_id}/unarchive")
def unarchive_project(project_id: int, db: Session = Depends(get_db)):
"""Unarchive a project (restore from archive)"""
db_project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
if not db_project:
raise HTTPException(status_code=404, detail="Project not found")
db_project.is_archived = 0
db.commit()
db.refresh(db_project)
return {"message": "Project unarchived successfully", "project_id": project_id}
@router.get("/projects/archived", response_model=PaginatedResponse[ProjectData])
def read_archived_projects(
page: int = Query(1, ge=1, description="Page number (starts at 1)"),
page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
db: Session = Depends(get_db),
):
"""Get all archived projects (paginated)"""
# Calculate offset
offset = (page - 1) * page_size
# Query only archived projects
query = db.query(ProjectTable).filter(ProjectTable.is_archived == 1)
# Get total count
total_count = query.count()
# Get paginated results
projects = (
query.options(
selectinload(ProjectTable.sector),
selectinload(ProjectTable.investors),
selectinload(ProjectTable.companies),
)
.offset(offset)
.limit(page_size)
.all()
)
# Transform ProjectTable objects to ProjectData format
project_data_list = []
for project in projects:
project_data = ProjectData(
project=project,
sector=project.sector,
investors=project.investors,
companies=project.companies,
)
project_data_list.append(project_data)
# Calculate total pages
total_pages = (total_count + page_size - 1) // page_size
return PaginatedResponse(
items=project_data_list,
total=total_count,
page=page,
page_size=page_size,
total_pages=total_pages,
)
@router.get("/projects/filter", response_model=PaginatedResponse[ProjectData])
def filter_projects(
stage: Optional[InvestmentStage] = Query(
+2 -1
View File
@@ -38,6 +38,7 @@ class InvestorMemberSchema(BaseModel):
name: str
role: str | None
email: str | None
linkedin: str | None
class Config:
from_attributes = True
@@ -194,7 +195,7 @@ class CompanySchemaMinimal(BaseModel):
class CompanyData(BaseModel): # Renamed from CompaniesData for consistency
company: CompanySchemaMinimal
investors: List[InvestorMinimal]
# members: List[CompanyMemberSchema] = []
members: List[CompanyMemberSchema] = []
sectors: List[SectorSchema] = []
class Config:
+71 -42
View File
@@ -117,41 +117,41 @@ def _calculate_project_fund_compatibility(
# 2. Sector Overlap (30 points)
sector_score = 0
if project.sector and fund.sectors:
project_sectors = [s for s in project.sector if hasattr(s, 'name')]
fund_sectors = [s for s in fund.sectors if hasattr(s, 'name')]
project_sectors = [s for s in project.sector if hasattr(s, "name")]
fund_sectors = [s for s in fund.sectors if hasattr(s, "name")]
if project_sectors and fund_sectors:
# Use fuzzy matching to account for similar but not identical sector names
match_count = 0
total_matches = 0
for proj_sector in project_sectors:
best_match_score = 0
proj_name = proj_sector.name.lower().strip()
for fund_sector in fund_sectors:
fund_name = fund_sector.name.lower().strip()
# Exact match
if proj_name == fund_name:
best_match_score = 1.0
break
# Fuzzy match using sequence matcher
similarity = SequenceMatcher(None, proj_name, fund_name).ratio()
# Also check if one contains the other (substring match)
if proj_name in fund_name or fund_name in proj_name:
similarity = max(similarity, 0.8)
best_match_score = max(best_match_score, similarity)
# Count matches with threshold
# Perfect match (1.0), strong match (>0.75), partial match (>0.6)
if best_match_score >= 0.6:
total_matches += best_match_score
match_count += 1
if match_count > 0:
# Calculate overlap ratio based on fuzzy matches
overlap_ratio = total_matches / len(project_sectors)
@@ -174,9 +174,10 @@ def _calculate_project_fund_compatibility(
or fund_geo_lower in project_location_lower
):
geo_score = 15
# Check for common geographic terms or regional overlap
# Check for common geographic terms or regional overlap (continent/country matching)
elif _check_geographic_overlap(project_location_lower, fund_geo_lower):
geo_score = 12
# Give higher score for continent/country matches (e.g., Germany -> Europe)
geo_score = 18
total_score += geo_score
@@ -245,40 +246,40 @@ def _calculate_project_investor_direct_compatibility(
# 2. Sector Overlap (30 points)
sector_score = 0
if project.sector and investor.sectors:
project_sectors = [s for s in project.sector if hasattr(s, 'name')]
investor_sectors = [s for s in investor.sectors if hasattr(s, 'name')]
project_sectors = [s for s in project.sector if hasattr(s, "name")]
investor_sectors = [s for s in investor.sectors if hasattr(s, "name")]
if project_sectors and investor_sectors:
# Use fuzzy matching to account for similar but not identical sector names
match_count = 0
total_matches = 0
for proj_sector in project_sectors:
best_match_score = 0
proj_name = proj_sector.name.lower().strip()
for inv_sector in investor_sectors:
inv_name = inv_sector.name.lower().strip()
# Exact match
if proj_name == inv_name:
best_match_score = 1.0
break
# Fuzzy match using sequence matcher
similarity = SequenceMatcher(None, proj_name, inv_name).ratio()
# Also check if one contains the other (substring match)
if proj_name in inv_name or inv_name in proj_name:
similarity = max(similarity, 0.8)
best_match_score = max(best_match_score, similarity)
# Count matches with threshold
if best_match_score >= 0.6:
total_matches += best_match_score
match_count += 1
if match_count > 0:
# Calculate overlap ratio based on fuzzy matches
overlap_ratio = total_matches / len(project_sectors)
@@ -298,9 +299,10 @@ def _calculate_project_investor_direct_compatibility(
project_location_lower in investor_geo_lower
or investor_geo_lower in project_location_lower
):
geo_score = 10
geo_score = 15
elif _check_geographic_overlap(project_location_lower, investor_geo_lower):
geo_score = 5
# Give higher score for continent/country matches (e.g., Germany -> Europe)
geo_score = 18
total_score += geo_score
@@ -382,43 +384,70 @@ def _check_geographic_overlap(location1: str, location2: str) -> bool:
# Normalize inputs
loc1 = location1.lower().strip()
loc2 = location2.lower().strip()
# Common geographic groupings with broader regional mappings
geo_groups = [
# North America
["usa", "us", "united states", "america", "u.s.", "u.s.a"],
["canada", "canadian"],
["mexico", "mexican"],
# Europe and countries
["europe", "european", "eu", "germany", "france", "uk", "united kingdom",
"britain", "spain", "italy", "netherlands", "belgium", "sweden", "denmark",
"norway", "finland", "poland", "portugal", "austria", "switzerland",
"ireland", "greece", "czech", "romania"],
[
"europe",
"european",
"eu",
"germany",
"france",
"uk",
"united kingdom",
"britain",
"spain",
"italy",
"netherlands",
"belgium",
"sweden",
"denmark",
"norway",
"finland",
"poland",
"portugal",
"austria",
"switzerland",
"ireland",
"greece",
"czech",
"romania",
],
# UK specific
["uk", "united kingdom", "britain", "england", "scotland", "wales", "london"],
# US states
["california", "ca", "san francisco", "los angeles", "silicon valley"],
["new york", "ny", "nyc"],
["texas", "tx"],
["massachusetts", "ma", "boston"],
["washington", "seattle"],
# Asia
["asia", "asian", "china", "japan", "korea", "singapore", "hong kong",
"india", "indonesia", "thailand", "vietnam", "malaysia", "philippines"],
[
"asia",
"asian",
"china",
"japan",
"korea",
"singapore",
"hong kong",
"india",
"indonesia",
"thailand",
"vietnam",
"malaysia",
"philippines",
],
# Middle East
["middle east", "israel", "uae", "dubai", "saudi arabia"],
# Latin America
["latin america", "brazil", "argentina", "chile", "colombia", "mexico"],
# Africa
["africa", "african", "south africa", "nigeria", "kenya", "egypt"],
# Oceania
["australia", "australian", "new zealand"],
]
@@ -429,7 +458,7 @@ def _check_geographic_overlap(location1: str, location2: str) -> bool:
found_in_2 = any(term in loc2 for term in group)
if found_in_1 and found_in_2:
return True
# Check for direct substring match (one contains the other)
if loc1 in loc2 or loc2 in loc1:
return True
+4 -1
View File
@@ -119,6 +119,7 @@ class FolkAPI:
email: str = None,
company_id: str = None,
group_id: str = None,
linkedin_url: str = None,
companies=None,
emails=None,
phones=None,
@@ -184,7 +185,9 @@ class FolkAPI:
addresses_list = _to_list(addresses)
if addresses_list:
data["addresses"] = addresses_list
urls_list = _to_list(urls)
urls_list = _to_list(urls) or []
if linkedin_url:
urls_list.append(linkedin_url)
if urls_list:
data["urls"] = urls_list
+1 -1
View File
@@ -49,7 +49,7 @@ class QueryProcessor:
"""Tool to search the web using google, provide the relevant query to get the information"""
logger.info(f"\nWeb Search Tool Called with query: {query}")
if query:
result = self.ddg_search.text(query, max_results=10, backend="google")
result = self.ddg_search.text(query, max_results=10)
return result
return "No query provided."
+2 -2
View File
@@ -258,10 +258,10 @@ Return ONLY the SQL query, no explanations or markdown.""",
else None
)
# Get top 3 sectors from fund (id and name only)
# Get top 3 sectors from fund (id and name only) - sorted alphabetically
fund_sectors = [
SectorMinimal(id=sector.id, name=sector.name)
for sector in (fund.sectors[:3] if fund.sectors else [])
for sector in sorted(fund.sectors[:3] if fund.sectors else [], key=lambda s: s.name)
]
investment_response = InvestmentResponse(
BIN
View File
Binary file not shown.
+117
View File
@@ -0,0 +1,117 @@
"""
Migration: Add fields from feedback fixes
Date: 2025-01-07
Adds the following fields:
- projects.is_archived (INTEGER, default 0)
- companies.product_service (TEXT, nullable)
- companies.clients (TEXT, nullable - stored as JSON string)
- investor_members.linkedin (VARCHAR, nullable)
"""
import sys
from pathlib import Path
# Add parent directory to path to import app modules
sys.path.insert(0, str(Path(__file__).parent.parent))
from sqlalchemy import text
from app.db.db import engine
def check_column_exists(conn, table_name, column_name):
"""Check if a column exists in a table"""
result = conn.execute(text(f"PRAGMA table_info({table_name})"))
columns = [row[1] for row in result]
return column_name in columns
def upgrade():
"""Add new columns to tables"""
print("Running migration: Add feedback fixes fields")
print("=" * 60)
with engine.begin() as conn: # Use begin() for transaction management
# 1. Add is_archived to projects table
print("\n1. Adding 'is_archived' column to projects table...")
if check_column_exists(conn, "projects", "is_archived"):
print(" ✓ Column 'is_archived' already exists. Skipping.")
else:
conn.execute(
text(
"ALTER TABLE projects ADD COLUMN is_archived INTEGER DEFAULT 0 NOT NULL"
)
)
# Set default value for existing rows
conn.execute(
text("UPDATE projects SET is_archived = 0 WHERE is_archived IS NULL")
)
print(" ✓ Successfully added 'is_archived' column to projects table")
# 2. Add product_service to companies table
print("\n2. Adding 'product_service' column to companies table...")
if check_column_exists(conn, "companies", "product_service"):
print(" ✓ Column 'product_service' already exists. Skipping.")
else:
conn.execute(text("ALTER TABLE companies ADD COLUMN product_service TEXT"))
print(" ✓ Successfully added 'product_service' column to companies table")
# 3. Add clients to companies table
print("\n3. Adding 'clients' column to companies table...")
if check_column_exists(conn, "companies", "clients"):
print(" ✓ Column 'clients' already exists. Skipping.")
else:
conn.execute(text("ALTER TABLE companies ADD COLUMN clients TEXT"))
print(" ✓ Successfully added 'clients' column to companies table")
# 4. Add linkedin to investor_members table
print("\n4. Adding 'linkedin' column to investor_members table...")
if check_column_exists(conn, "investor_members", "linkedin"):
print(" ✓ Column 'linkedin' already exists. Skipping.")
else:
conn.execute(
text("ALTER TABLE investor_members ADD COLUMN linkedin VARCHAR")
)
print(" ✓ Successfully added 'linkedin' column to investor_members table")
print("\n" + "=" * 60)
print("Migration completed successfully!")
def downgrade():
"""Remove added columns from tables"""
print("Running downgrade: Remove feedback fixes fields")
print("=" * 60)
# Note: SQLite doesn't support DROP COLUMN directly
print("\nWarning: SQLite doesn't support DROP COLUMN directly.")
print("To remove these columns, you would need to:")
print("1. Create new tables without the columns")
print("2. Copy data from old tables to new tables")
print("3. Drop old tables and rename new tables")
print("\nColumns to remove:")
print(" - projects.is_archived")
print(" - companies.product_service")
print(" - companies.clients")
print(" - investor_members.linkedin")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Run database migration")
parser.add_argument(
"direction",
choices=["upgrade", "downgrade"],
default="upgrade",
nargs="?",
help="Migration direction (default: upgrade)",
)
args = parser.parse_args()
if args.direction == "upgrade":
upgrade()
else:
downgrade()
+310
View File
@@ -0,0 +1,310 @@
#!/usr/bin/env python3
"""
Update Investor Members LinkedIn Profiles Script
This script finds and updates LinkedIn profile URLs for investor members in the database.
Uses crawl4ai to efficiently scrape team pages and extract LinkedIn URLs.
Usage:
python update_linkedin_profiles.py [--test] [--limit N] [--skip-existing]
Options:
--test Test mode: process only 10 records and don't update database
--limit N Process only N records (default: all)
--skip-existing Skip members that already have LinkedIn URLs
--start-from N Start from record N (for resuming)
"""
import argparse
import asyncio
import json
import os
import sys
from datetime import datetime
# Add app to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "app"))
from db.db import get_db_session
from db.models import InvestorMember, InvestorTable
from linkedin_scraper import LinkedInProfileScraper, format_linkedin_url
def progress_callback(current, total, result):
"""Print progress updates"""
percent = (current / total) * 100
status = "" if result["linkedin_url"] else ""
print(f"[{current}/{total} - {percent:.1f}%] {status} {result['member_name']}")
if result["linkedin_url"]:
print(
f"{result['linkedin_url']} (confidence: {result['confidence']}%, method: {result['method']})"
)
def create_db_callback(test_mode=False):
"""
Create a callback function that saves LinkedIn profiles to the database immediately.
This allows stopping and resuming without losing progress.
"""
saved_count = {"count": 0} # Use dict to allow modification in closure
def db_callback(member_id: int, linkedin_url: str) -> bool:
"""Save LinkedIn URL to database immediately"""
if test_mode:
print(f" [TEST] Would save to DB: member {member_id}")
saved_count["count"] += 1
return True
try:
db = get_db_session()
member = db.query(InvestorMember).filter_by(id=member_id).first()
if member:
member.linkedin = format_linkedin_url(linkedin_url)
db.commit()
saved_count["count"] += 1
return True
except Exception as e:
print(f" ⚠️ DB Error for member {member_id}: {e}")
try:
db.rollback()
except Exception:
pass
return False
finally:
try:
db.close()
except Exception:
pass
return False
return db_callback, saved_count
def update_database(members_data, test_mode=False):
"""Update database with found LinkedIn profiles"""
db = get_db_session()
try:
updated_count = 0
for data in members_data:
if data["linkedin_url"] and data["member_id"]:
if not test_mode:
member = (
db.query(InvestorMember).filter_by(id=data["member_id"]).first()
)
if member:
member.linkedin = format_linkedin_url(data["linkedin_url"])
updated_count += 1
else:
print(
f" [TEST MODE] Would update member {data['member_id']}: {data['linkedin_url']}"
)
updated_count += 1
if not test_mode:
db.commit()
print(f"\n✓ Successfully updated {updated_count} records in database")
else:
print(f"\n[TEST MODE] Would have updated {updated_count} records")
return updated_count
except Exception as e:
db.rollback()
print(f"\n✗ Error updating database: {e}")
raise
finally:
db.close()
def save_results(results, filename="linkedin_scraping_results.json"):
"""Save results to JSON file for backup/analysis"""
output = {
"timestamp": datetime.now().isoformat(),
"total_processed": len(results),
"found_count": sum(1 for r in results if r["linkedin_url"]),
"results": results,
}
with open(filename, "w") as f:
json.dump(output, f, indent=2)
print(f"\n✓ Results saved to {filename}")
def print_summary(results):
"""Print summary statistics"""
total = len(results)
found = sum(1 for r in results if r["linkedin_url"])
not_found = total - found
# Count by method
methods = {}
for r in results:
if r["linkedin_url"]:
method = r["method"]
methods[method] = methods.get(method, 0) + 1
# Average confidence for found profiles
avg_confidence = (
sum(r["confidence"] for r in results if r["linkedin_url"]) / found
if found > 0
else 0
)
print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)
print(f"Total processed: {total}")
print(f"LinkedIn found: {found} ({found / total * 100:.1f}%)")
print(f"Not found: {not_found} ({not_found / total * 100:.1f}%)")
print(f"\nAverage confidence: {avg_confidence:.1f}%")
print("\nMethods used:")
for method, count in sorted(methods.items(), key=lambda x: x[1], reverse=True):
print(f" {method:20s} {count:5d} ({count / found * 100:.1f}%)")
print("=" * 60)
def main():
parser = argparse.ArgumentParser(
description="Update LinkedIn profiles for investor members"
)
parser.add_argument(
"--test",
action="store_true",
help="Test mode: process only 10 records without updating database",
)
parser.add_argument("--limit", type=int, help="Limit number of records to process")
parser.add_argument(
"--skip-existing",
action="store_true",
help="Skip members that already have LinkedIn URLs",
)
parser.add_argument(
"--start-from",
type=int,
default=0,
help="Start from record N (for resuming interrupted runs)",
)
parser.add_argument(
"--rate-limit",
type=float,
default=0.5,
help="Delay between URL crawls in seconds (default: 0.5)",
)
args = parser.parse_args()
# Test mode overrides limit
if args.test and not args.limit:
args.limit = 10
print("=" * 60)
print("LinkedIn Profile Scraper for Investor Members (crawl4ai)")
print("=" * 60)
if args.test:
print("\n⚠️ TEST MODE - No database changes will be made")
# Initialize database and scraper
db = get_db_session()
try:
# Build query
query = db.query(InvestorMember, InvestorTable).join(
InvestorTable, InvestorMember.investor_id == InvestorTable.id
)
# Filter existing if requested
if args.skip_existing:
query = query.filter(
(InvestorMember.linkedin.is_(None)) | (InvestorMember.linkedin == "")
)
print("\n✓ Filtering to members without LinkedIn profiles")
# Get total count
total_available = query.count()
print(f"\n✓ Found {total_available} members to process")
# Apply offset and limit
if args.start_from > 0:
query = query.offset(args.start_from)
print(f"✓ Starting from record {args.start_from}")
if args.limit:
query = query.limit(args.limit)
print(f"✓ Processing {args.limit} records")
# Fetch members
members_data = []
for member, investor in query.all():
members_data.append(
{
"id": member.id,
"name": member.name,
"company": investor.name,
"role": member.role,
"source_url": member.source_url,
}
)
if not members_data:
print("\n⚠️ No members to process")
return
# Count unique source URLs
unique_urls = len(set(m["source_url"] for m in members_data if m["source_url"]))
with_urls = sum(1 for m in members_data if m["source_url"])
print(f"\n✓ Loaded {len(members_data)} members")
print(
f"{with_urls} members have source URLs ({unique_urls} unique pages to crawl)"
)
print(f"{len(members_data) - with_urls} members without source URLs")
print(f"✓ Rate limit: {args.rate_limit}s between page crawls")
print("\nStarting LinkedIn profile search using crawl4ai...\n")
finally:
db.close()
# Initialize scraper
scraper = LinkedInProfileScraper(rate_limit_delay=args.rate_limit, use_cache=True)
print("️ Using crawl4ai to scrape team pages and extract LinkedIn URLs")
print(
"️ Profiles are saved to database IMMEDIATELY when found - safe to stop anytime!\n"
)
# Create database callback for real-time saving
db_callback, saved_count = create_db_callback(test_mode=args.test)
# Process members asynchronously with real-time DB saving
results = asyncio.run(
scraper.batch_find_profiles(
members_data, progress_callback=progress_callback, db_callback=db_callback
)
)
# Print summary
print_summary(results)
# Save results
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
results_file = f"linkedin_results_{timestamp}.json"
save_results(results, results_file)
# Show database update summary
if not args.test:
print(
f"\n✓ Database updated in real-time: {saved_count['count']} profiles saved"
)
else:
print(
f"\n[TEST MODE] Would have saved {saved_count['count']} profiles to database"
)
print("\n✓ Done! You can resume anytime with --skip-existing")
if __name__ == "__main__":
main()