1 Commits

44 changed files with 356 additions and 31150 deletions
+2 -5
View File
@@ -10,11 +10,8 @@
*__pycache__
/*.db
*.cypython
nohup.out
server.log
server.pid
/preprocessor
Binary file not shown.
Binary file not shown.
Binary file not shown.
-5
View File
@@ -9,10 +9,6 @@ from sqlalchemy.orm import Session, sessionmaker
Base = declarative_base()
# Database configuration
# Use the preprocessor's database for consistency
# Get absolute path to the preprocessor database
# APP_DIR = Path(__file__).parent.parent
# PREPROCESSOR_DB = APP_DIR.parent / "preprocessor" / "version_two.db"
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./investors.db")
# Create engine
@@ -42,7 +38,6 @@ def get_session_sync() -> Session:
"""Get a database session for synchronous operations"""
return SessionLocal()
def get_db_session():
"""Get a database session for direct use."""
return SessionLocal()
+10 -143
View File
@@ -2,7 +2,7 @@ import enum
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, Table, Text, func
from sqlalchemy.orm import declarative_mixin, relationship
from sqlalchemy.types import JSON, Enum
from sqlalchemy.types import Enum
from db.db import Base
@@ -70,22 +70,6 @@ project_company_association = Table(
Column("company_id", Integer, ForeignKey("companies.id")),
)
# Association table for fund-stage many-to-many
fund_investment_stages_association = Table(
"fund_investment_stages",
Base.metadata,
Column("fund_id", Integer, ForeignKey("funds.id")),
Column("stage_id", Integer, ForeignKey("investment_stages.id")),
)
# Association table for fund-sector many-to-many
fund_sectors_association = Table(
"fund_sectors",
Base.metadata,
Column("fund_id", Integer, ForeignKey("funds.id")),
Column("sector_id", Integer, ForeignKey("sectors.id")),
)
class InvestorTable(Base, TimestampMixin):
__tablename__ = "investors"
@@ -93,47 +77,14 @@ class InvestorTable(Base, TimestampMixin):
id = Column(Integer, primary_key=True, index=True)
name = Column(String, nullable=False)
description = Column(Text, nullable=True)
# Basic investor info
website = Column(String, nullable=True)
headquarters = Column(String, nullable=True)
# AUM fields
aum = Column(Integer, nullable=True) # Store as integer for numerical filtering
aum_as_of_date = Column(String, nullable=True)
aum_source_url = Column(String, nullable=True)
# Check size (deprecated in favor of fund-level data, but keeping for backward compatibility)
check_size_lower = Column(Integer, nullable=True)
check_size_upper = Column(Integer, nullable=True)
# Geographic focus (deprecated in favor of fund-level, but keeping for backward compatibility)
aum = Column(Integer, nullable=True) # Assets Under Management
check_size_lower = Column(Integer, nullable=True) # Lower bound
check_size_upper = Column(Integer, nullable=True) # Upper bound
geographic_focus = Column(String, nullable=True)
# Investment thesis and portfolio
investment_thesis = Column(JSON, nullable=True) # Array of thesis statements
portfolio_highlights = Column(
JSON, nullable=True
) # Array of portfolio company names
linked_documents = Column(JSON, nullable=True) # Array of document URLs
# Research metadata
researcher_notes = Column(Text, nullable=True)
missing_important_fields = Column(
JSON, nullable=True
) # Array of missing field names
sources = Column(JSON, nullable=True) # JSON object with source URLs
# Portfolio info
stage_focus = Column(Enum(InvestmentStage), nullable=True)
number_of_investments = Column(Integer, default=0, nullable=True)
# Relationships
team_members = relationship(
"InvestorMember", back_populates="investor", cascade="all, delete-orphan"
)
funds = relationship(
"FundTable", back_populates="investor", cascade="all, delete-orphan"
)
team_members = relationship("InvestorMember", back_populates="investor")
# Relationship to portfolio companies
portfolio_companies = relationship(
@@ -160,52 +111,12 @@ class InvestorMember(Base, TimestampMixin):
id = Column(Integer, primary_key=True, index=True)
name = Column(String, nullable=False)
role = Column(String, nullable=True)
title = Column(String, nullable=True) # Alternative to role
email = Column(String, nullable=True)
linkedin = Column(String, nullable=True) # LinkedIn profile URL
source_url = Column(String, nullable=True) # URL where member info was found
investor_id = Column(Integer, ForeignKey("investors.id"))
investor = relationship("InvestorTable", back_populates="team_members")
class FundTable(Base, TimestampMixin):
__tablename__ = "funds"
id = Column(Integer, primary_key=True, index=True)
investor_id = Column(Integer, ForeignKey("investors.id"), nullable=False)
# Fund details
fund_name = Column(String, nullable=True)
fund_size = Column(
Integer, nullable=True
) # Store as integer for numerical filtering
fund_size_source_url = Column(String, nullable=True)
# Check size range (parsed from estimated_investment_size by LLM)
check_size_lower = Column(Integer, nullable=True)
check_size_upper = Column(Integer, nullable=True)
source_url = Column(String, nullable=True)
source_provider = Column(String, nullable=True) # e.g., "Perplexity"
# Geographic focus as simple string
geographic_focus = Column(String, nullable=True)
# Relationships
investor = relationship("InvestorTable", back_populates="funds")
investment_stages = relationship(
"InvestmentStageTable",
secondary=fund_investment_stages_association,
back_populates="funds",
)
sectors = relationship(
"SectorTable",
secondary=fund_sectors_association,
back_populates="funds",
)
class CompanyTable(Base, TimestampMixin):
__tablename__ = "companies"
@@ -216,12 +127,8 @@ class CompanyTable(Base, TimestampMixin):
description = Column(String, nullable=True)
founded_year = Column(Integer, nullable=True)
website = Column(String, nullable=True)
product_service = Column(Text, nullable=True) # Product/service description
clients = Column(JSON, nullable=True) # List of client names or client information
members = relationship(
"CompanyMember", back_populates="company", cascade="all, delete-orphan"
)
members = relationship("CompanyMember", back_populates="company")
# Relationship back to investors
investors = relationship(
"InvestorTable",
@@ -251,43 +158,26 @@ class CompanyMember(Base, TimestampMixin):
company = relationship("CompanyTable", back_populates="members")
class InvestmentStageTable(Base, TimestampMixin):
__tablename__ = "investment_stages"
id = Column(Integer, primary_key=True, index=True)
name = Column(String, nullable=False, unique=True)
# Relationships
funds = relationship(
"FundTable",
secondary=fund_investment_stages_association,
back_populates="investment_stages",
)
class SectorTable(Base, TimestampMixin):
__tablename__ = "sectors"
id = Column(Integer, primary_key=True, index=True)
name = Column(String, nullable=False)
# Relationships
# Add relationship back to investors
investors = relationship(
"InvestorTable",
secondary=investor_sector_association,
back_populates="sectors",
)
companies = relationship(
"CompanyTable", secondary=company_sector_association, back_populates="sectors"
)
projects = relationship(
"ProjectTable", secondary=project_sector_association, back_populates="sector"
)
funds = relationship(
"FundTable",
secondary=fund_sectors_association,
back_populates="sectors",
)
class ProjectTable(Base, TimestampMixin):
@@ -299,11 +189,9 @@ class ProjectTable(Base, TimestampMixin):
stage = Column(Enum(InvestmentStage), nullable=True)
location = Column(String, nullable=True)
industry = Column(String, nullable=True)
description = Column(Text, nullable=True)
start_date = Column(DateTime, nullable=True)
end_date = Column(DateTime, nullable=True)
is_archived = Column(Integer, default=0, nullable=False) # 0 = active, 1 = archived
sector = relationship(
"SectorTable", secondary=project_sector_association, back_populates="projects"
@@ -316,24 +204,3 @@ class ProjectTable(Base, TimestampMixin):
companies = relationship(
"CompanyTable", secondary=project_company_association, back_populates="projects"
)
class InvestorInsightCache(Base, TimestampMixin):
__tablename__ = "investor_insight_cache"
id = Column(Integer, primary_key=True, index=True)
investor_id = Column(
Integer, ForeignKey("investors.id"), nullable=False, unique=True
)
# Cached insights
investment_pattern_analysis = Column(Text, nullable=False)
market_position = Column(Text, nullable=False)
# Cache management
last_refreshed = Column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)
# Relationship to investor
investor = relationship("InvestorTable")
-730
View File
@@ -1,730 +0,0 @@
"""
LinkedIn Profile Scraper for Investor Members
This module uses crawl4ai to scrape team pages and find LinkedIn profiles.
Strategies:
1. Crawl the source_url (team pages) to extract LinkedIn profile links
2. Use LLM-powered web search to find LinkedIn profiles by name
Key advantages of crawl4ai:
- Handles JavaScript-rendered pages
- Better at extracting content from modern websites
- More reliable than simple requests
"""
import asyncio
import logging
import os
import re
from typing import Dict, List, Optional
from crawl4ai import AsyncWebCrawler
from ddgs import DDGS
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
# Setup logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s"
)
logger = logging.getLogger("linkedin_scraper")
load_dotenv()
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
class LinkedInProfileScraper:
"""
LinkedIn profile finder using crawl4ai and LLM-powered web search.
Strategies:
1. Crawl source URLs (team pages) to extract LinkedIn links
2. Use LLM-powered web search to find profiles by name
"""
def __init__(
self,
rate_limit_delay: float = 0.5,
use_cache: bool = True,
use_llm_search: bool = True,
):
"""
Initialize the scraper
Args:
rate_limit_delay: Delay between requests in seconds
use_cache: Whether to cache crawled pages
use_llm_search: Whether to use LLM-powered web search as fallback
"""
self.rate_limit_delay = rate_limit_delay
self.use_cache = use_cache
self.use_llm_search = use_llm_search and OPENROUTER_API_KEY
self.page_cache: Dict[str, str] = {} # Cache crawled pages by URL
self.html_cache: Dict[str, str] = {} # Cache HTML separately
self.profile_cache: Dict[str, Dict] = {} # Cache results by member
# Initialize LLM agent if API key available
if self.use_llm_search:
self._init_llm_agent()
else:
self.llm = None
self.agent = None
self.ddg_search = None
logger.info("LLM search disabled (no OPENROUTER_API_KEY)")
def _init_llm_agent(self):
"""Initialize LLM agent for web search"""
try:
self.llm = ChatOpenAI(
api_key=OPENROUTER_API_KEY,
base_url="https://openrouter.ai/api/v1",
model="x-ai/grok-4.1-fast:free",
temperature=0,
)
self.ddg_search = DDGS()
logger.info("LLM search agent initialized")
except Exception as e:
logger.error(f"Failed to initialize LLM agent: {e}")
self.llm = None
self.ddg_search = None
def web_search(self, query: str) -> List[Dict]:
"""Tool to search the web using DuckDuckGo"""
if not self.ddg_search:
return []
try:
results = list(self.ddg_search.text(query, max_results=10))
return results
except Exception as e:
logger.error(f"Web search error: {e}")
return []
async def crawl_page(self, url: str) -> Optional[str]:
"""
Crawl a webpage and return its content.
Args:
url: URL to crawl
Returns:
Page content as markdown/text, or None if failed
"""
if not url:
return None
# Check cache first
if self.use_cache and url in self.page_cache:
logger.debug(f"Using cached page for {url}")
return self.page_cache[url]
try:
logger.info(f"Crawling: {url}")
async with AsyncWebCrawler() as crawler:
result = await crawler.arun(url)
if result and result.markdown:
content = result.markdown
# Also get HTML for better link extraction
html_content = result.html if hasattr(result, "html") else ""
# Cache the results
if self.use_cache:
self.page_cache[url] = content
self.html_cache[url] = html_content
return content
except Exception as e:
logger.error(f"Error crawling {url}: {e}")
return None
def extract_linkedin_urls_from_content(self, content: str) -> List[Dict[str, str]]:
"""
Extract all LinkedIn profile URLs from content (HTML or markdown).
Returns:
List of dicts with 'url', 'context', and 'username'
"""
linkedin_links = []
# Pattern for LinkedIn profile URLs (handles country-specific domains)
linkedin_pattern = (
r"https?://(?:www\.)?(?:[a-z]{2,3}\.)?linkedin\.com/in/([a-zA-Z0-9_-]+)/?"
)
# Find all LinkedIn URLs
matches = list(re.finditer(linkedin_pattern, content, re.IGNORECASE))
for match in matches:
url = match.group(0).rstrip("/")
# Normalize URL
url = self._normalize_linkedin_url(url)
# Get surrounding context (200 chars before and after)
start = max(0, match.start() - 200)
end = min(len(content), match.end() + 200)
context = content[start:end]
# Clean up context (remove HTML tags for readability)
context = re.sub(r"<[^>]+>", " ", context)
context = " ".join(context.split()) # Normalize whitespace
linkedin_links.append(
{"url": url, "context": context, "username": match.group(1)}
)
# Remove duplicates while preserving order
seen_urls = set()
unique_links = []
for link in linkedin_links:
if link["url"] not in seen_urls:
seen_urls.add(link["url"])
unique_links.append(link)
return unique_links
def _normalize_linkedin_url(self, url: str) -> str:
"""Normalize LinkedIn URL to standard format"""
# Remove trailing slashes
url = url.rstrip("/")
# Convert country-specific to www
url = re.sub(
r"https?://[a-z]{2,3}\.linkedin\.com", "https://www.linkedin.com", url
)
# Ensure https
if url.startswith("http://"):
url = url.replace("http://", "https://")
return url
def _name_matches_context(self, name: str, context: str) -> float:
"""
Check if a person's name appears in the context around a LinkedIn URL.
Returns:
Confidence score 0-100
"""
if not name or not context:
return 0
context_lower = context.lower()
name_lower = name.lower()
# Split name into parts (handle multiple spaces, titles like "Dr.", etc.)
name_parts = [p for p in name_lower.replace(".", " ").split() if len(p) > 1]
# Check for full name match
if name_lower in context_lower:
return 95
# Check for name parts in context
matches = sum(
1 for part in name_parts if part in context_lower and len(part) > 2
)
if len(name_parts) > 0:
if matches == len(name_parts):
return 90 # All name parts found
elif matches >= 2:
return 75 # At least 2 parts found (first + last typically)
elif matches == 1 and len(name_parts) <= 2:
return 50 # Only one part found but name is short
elif matches == 1:
return 35 # Only one part found
return 0
def _name_matches_username(self, name: str, username: str) -> float:
"""
Check if LinkedIn username contains parts of the name.
Returns:
Confidence score 0-100
"""
if not name or not username:
return 0
name_lower = name.lower()
username_lower = username.lower().replace("-", " ").replace("_", " ")
name_parts = [p for p in name_lower.replace(".", " ").split() if len(p) > 2]
matches = sum(1 for part in name_parts if part in username_lower)
if len(name_parts) > 0:
if matches == len(name_parts) and len(name_parts) >= 2:
return 85 # Full name in username
elif matches >= 2:
return 70 # Multiple parts match
elif matches == 1:
return 35 # Only one part matches
return 0
async def find_linkedin_from_source(
self, name: str, source_url: str, role: Optional[str] = None
) -> Dict:
"""
Find LinkedIn profile by crawling the source URL (team page).
Args:
name: Person's name
source_url: URL of the team/about page
role: Person's role (for additional context matching)
Returns:
Dict with linkedin_url, confidence, method, notes
"""
if not source_url:
return {
"linkedin_url": None,
"confidence": 0,
"method": "source_crawl",
"notes": "No source URL provided",
}
# Crawl the page
content = await self.crawl_page(source_url)
if not content:
return {
"linkedin_url": None,
"confidence": 0,
"method": "source_crawl",
"notes": f"Failed to crawl {source_url}",
}
# Get HTML for better link extraction
html = self.html_cache.get(source_url, content)
# Extract all LinkedIn URLs from both HTML and markdown
linkedin_links = self.extract_linkedin_urls_from_content(html)
if not linkedin_links:
linkedin_links = self.extract_linkedin_urls_from_content(content)
if not linkedin_links:
return {
"linkedin_url": None,
"confidence": 0,
"method": "source_crawl",
"notes": f"No LinkedIn URLs found on {source_url}",
}
# Score each LinkedIn URL based on name matching
best_match = None
best_score = 0
for link in linkedin_links:
# Score based on context matching
context_score = self._name_matches_context(name, link["context"])
# Score based on username matching
username_score = self._name_matches_username(name, link["username"])
# Also check if role appears in context
role_bonus = 0
if role and role.lower() in link["context"].lower():
role_bonus = 10
# Combined score (take best of context or username, plus role bonus)
total_score = max(context_score, username_score) + role_bonus
logger.debug(
f" {name} -> {link['url']}: context={context_score}, username={username_score}, role={role_bonus}, total={total_score}"
)
if total_score > best_score:
best_score = total_score
best_match = link
if best_match and best_score >= 30: # Minimum threshold
return {
"linkedin_url": best_match["url"],
"confidence": min(best_score, 100),
"method": "source_crawl",
"notes": f"Found on {source_url}",
}
return {
"linkedin_url": None,
"confidence": 0,
"method": "source_crawl",
"notes": f'No matching LinkedIn profile found for "{name}" on {source_url}',
}
async def find_linkedin_via_search(
self, name: str, company: str, role: Optional[str] = None
) -> Dict:
"""
Find LinkedIn profile using web search.
Args:
name: Person's name
company: Company/investor name
role: Person's role (optional)
Returns:
Dict with linkedin_url, confidence, method, notes
"""
if not self.ddg_search:
return {
"linkedin_url": None,
"confidence": 0,
"method": "web_search",
"notes": "Web search not available",
}
try:
# Build search query - search for LinkedIn profile
query = f"{name} {company} site:linkedin.com/in"
if role:
query = f"{name} {role} {company} site:linkedin.com/in"
logger.debug(f"Searching: {query}")
results = self.web_search(query)
if results:
# Look for LinkedIn profile URLs in results
linkedin_pattern = r"https?://(?:www\.)?(?:[a-z]{2,3}\.)?linkedin\.com/in/([a-zA-Z0-9_-]+)"
for result in results:
url = result.get("href") or result.get("link") or ""
title = result.get("title", "").lower()
body = result.get("body", "").lower()
match = re.search(linkedin_pattern, url, re.IGNORECASE)
if match:
linkedin_url = self._normalize_linkedin_url(match.group(0))
username = match.group(1)
# Score based on name matching in title/body and username
context = f"{title} {body}"
context_score = self._name_matches_context(name, context)
username_score = self._name_matches_username(name, username)
total_score = max(context_score, username_score)
if total_score >= 30:
return {
"linkedin_url": linkedin_url,
"confidence": min(
total_score, 90
), # Cap at 90 for search results
"method": "web_search",
"notes": "Found via web search",
}
return {
"linkedin_url": None,
"confidence": 0,
"method": "web_search",
"notes": "No matching profile found in search results",
}
except Exception as e:
logger.error(f"Web search error for {name}: {e}")
return {
"linkedin_url": None,
"confidence": 0,
"method": "web_search",
"notes": f"Search error: {str(e)}",
}
async def find_linkedin_profile(
self,
name: str,
company: str,
role: Optional[str] = None,
source_url: Optional[str] = None,
) -> Dict:
"""
Find LinkedIn profile for a person.
Primary strategy: Crawl source URL to find LinkedIn links.
Args:
name: Person's name
company: Company/investor name
role: Person's role/title (optional)
source_url: URL where person info was found (optional)
Returns:
Dict with:
- linkedin_url: Found LinkedIn URL or None
- confidence: Confidence score (0-100)
- method: Method used to find the profile
- notes: Additional information
"""
cache_key = f"{name}|{company}"
# Check cache
if self.use_cache and cache_key in self.profile_cache:
logger.debug(f"Using cached result for {name}")
return self.profile_cache[cache_key]
result = {"linkedin_url": None, "confidence": 0, "method": "none", "notes": ""}
# Primary strategy: Crawl source URL
if source_url:
result = await self.find_linkedin_from_source(name, source_url, role)
if result["linkedin_url"]:
if self.use_cache:
self.profile_cache[cache_key] = result
return result
# Fallback strategy: Web search (if enabled and no result from source crawl)
if self.use_llm_search and not result.get("linkedin_url"):
search_result = await self.find_linkedin_via_search(name, company, role)
if search_result["linkedin_url"]:
if self.use_cache:
self.profile_cache[cache_key] = search_result
return search_result
# If no source URL or no match found
if not result["linkedin_url"]:
result = {
"linkedin_url": None,
"confidence": 0,
"method": "none",
"notes": "No source URL available"
if not source_url
else result.get("notes", "Not found"),
}
if self.use_cache:
self.profile_cache[cache_key] = result
return result
async def batch_find_profiles(
self, members: List[Dict], progress_callback=None, db_callback=None
) -> List[Dict]:
"""
Find LinkedIn profiles for multiple members efficiently.
Groups members by source_url to minimize crawling the same page multiple times.
Args:
members: List of dicts with 'name', 'company', 'role', 'source_url', 'id'
progress_callback: Optional callback function(current, total, result)
db_callback: Optional callback to save to database immediately when profile found
Signature: db_callback(member_id, linkedin_url) -> bool
Returns:
List of results for each member
"""
results = []
total = len(members)
# Group members by source_url for efficient crawling
url_groups: Dict[str, List[Dict]] = {}
no_url_members = []
for member in members:
url = member.get("source_url")
if url:
if url not in url_groups:
url_groups[url] = []
url_groups[url].append(member)
else:
no_url_members.append(member)
logger.info(
f"Processing {len(url_groups)} unique source URLs for {total} members"
)
logger.info(f"Members with source URLs: {total - len(no_url_members)}")
logger.info(f"Members without source URLs: {len(no_url_members)}")
if self.use_llm_search:
logger.info("Web search fallback: ENABLED")
else:
logger.info("Web search fallback: DISABLED")
processed = 0
# Process members grouped by URL (efficient - one crawl per page)
for url, group_members in url_groups.items():
# Crawl the page once
content = await self.crawl_page(url)
html = self.html_cache.get(url, content or "")
# Extract all LinkedIn URLs from this page
linkedin_links = []
if content:
linkedin_links = self.extract_linkedin_urls_from_content(html)
if not linkedin_links:
linkedin_links = self.extract_linkedin_urls_from_content(content)
# Match each member in this group
for member in group_members:
processed += 1
result = None
found_linkedin = False
if linkedin_links:
# Find best matching LinkedIn for this member
best_match = None
best_score = 0
for link in linkedin_links:
context_score = self._name_matches_context(
member["name"], link["context"]
)
username_score = self._name_matches_username(
member["name"], link["username"]
)
role_bonus = (
10
if member.get("role")
and member["role"].lower() in link["context"].lower()
else 0
)
total_score = max(context_score, username_score) + role_bonus
if total_score > best_score:
best_score = total_score
best_match = link
if best_match and best_score >= 30:
result = {
"linkedin_url": best_match["url"],
"confidence": min(best_score, 100),
"method": "source_crawl",
"notes": f"Found on {url}",
"member_id": member.get("id"),
"member_name": member["name"],
}
found_linkedin = True
# Save to database immediately if callback provided
if db_callback and member.get("id"):
db_callback(member["id"], best_match["url"])
# If no result from source crawl, try web search IMMEDIATELY
if not found_linkedin and self.use_llm_search:
search_result = await self.find_linkedin_via_search(
member["name"], member["company"], member.get("role")
)
if search_result["linkedin_url"]:
result = {
"linkedin_url": search_result["linkedin_url"],
"confidence": search_result["confidence"],
"method": "web_search",
"notes": search_result.get("notes", "Found via web search"),
"member_id": member.get("id"),
"member_name": member["name"],
}
found_linkedin = True
# Save to database immediately
if db_callback and member.get("id"):
db_callback(member["id"], search_result["linkedin_url"])
# If still no result, record as not found
if not found_linkedin:
result = {
"linkedin_url": None,
"confidence": 0,
"method": "source_crawl" if content else "none",
"notes": f"No match on {url}"
if linkedin_links
else (
f"No LinkedIn URLs on {url}"
if content
else f"Failed to crawl {url}"
),
"member_id": member.get("id"),
"member_name": member["name"],
}
results.append(result)
if progress_callback:
progress_callback(processed, total, result)
# Small delay between different URLs
await asyncio.sleep(self.rate_limit_delay)
# Process members without source URLs - do web search immediately for each
for member in no_url_members:
processed += 1
result = None
# Try web search immediately
if self.use_llm_search:
search_result = await self.find_linkedin_via_search(
member["name"], member["company"], member.get("role")
)
if search_result["linkedin_url"]:
result = {
"linkedin_url": search_result["linkedin_url"],
"confidence": search_result["confidence"],
"method": "web_search",
"notes": search_result.get("notes", "Found via web search"),
"member_id": member.get("id"),
"member_name": member["name"],
}
# Save to database immediately
if db_callback and member.get("id"):
db_callback(member["id"], search_result["linkedin_url"])
# If no result from search
if not result:
result = {
"linkedin_url": None,
"confidence": 0,
"method": "web_search" if self.use_llm_search else "none",
"notes": "No LinkedIn profile found"
if self.use_llm_search
else "No source URL available",
"member_id": member.get("id"),
"member_name": member["name"],
}
results.append(result)
if progress_callback:
progress_callback(processed, total, result)
# Rate limit between searches
await asyncio.sleep(self.rate_limit_delay)
return results
def format_linkedin_url(url: str) -> str:
"""Normalize LinkedIn URL format"""
if not url:
return url
# Remove trailing slashes
url = url.rstrip("/")
# Ensure https and normalize to www
url = re.sub(r"https?://[a-z]{2,3}\.linkedin\.com", "https://www.linkedin.com", url)
if url.startswith("http://"):
url = url.replace("http://", "https://")
return url
# Async wrapper for sync contexts
def run_batch_scraper(
members: List[Dict], rate_limit: float = 0.5, progress_callback=None
) -> List[Dict]:
"""
Synchronous wrapper for batch_find_profiles.
Args:
members: List of member dicts
rate_limit: Delay between URL crawls
progress_callback: Optional progress callback
Returns:
List of results
"""
scraper = LinkedInProfileScraper(rate_limit_delay=rate_limit)
return asyncio.run(scraper.batch_find_profiles(members, progress_callback))
+21 -98
View File
@@ -1,23 +1,12 @@
import io
import logging
import pandas as pd
from db.db import Base, db_dependency, engine
from dotenv import load_dotenv
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from fastapi import FastAPI, File, Form, UploadFile
from pydantic import BaseModel
from routers import (
addition,
companies,
folk_crm,
insight_route,
investors,
projects,
report_route,
)
from schemas.router_schemas import CompanyData, InvestmentResponse, PaginatedResponse
from services.company_querying import CompanyQueryProcessor
from routers import companies, investors, projects
from schemas.router_schemas import InvestorList
from services.llm_parser import InvestorProcessor
from services.querying import QueryProcessor
@@ -29,21 +18,10 @@ def init_database():
Base.metadata.create_all(bind=engine)
logger = logging.getLogger(__name__)
init_database()
app = FastAPI()
# Add CORS middleware to allow frontend requests
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # In production, replace with specific origins
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Request models
class QueryRequest(BaseModel):
@@ -57,17 +35,6 @@ class QueryRequest(BaseModel):
}
class CompanyQueryRequest(BaseModel):
question: str
class Config:
json_schema_extra = {
"example": {
"question": "Find me companies in the fintech sector located in San Francisco."
}
}
@app.get("/")
def health():
return {"Hello": "World"}
@@ -77,29 +44,6 @@ def health():
async def parse_csv(
db: db_dependency, file: UploadFile = File(...), is_investor: int = Form(...)
):
"""
Parse and import CSV data into the database.
**For investors:**
- Expected columns: Name, Website, Final Investor Profile, Final Profile sourcing
- Manually parses JSON profiles for efficiency
- Uses LLM only for currency conversion to USD
- Handles AUM, fund sizes, and check sizes as integers
**For companies:**
- Expected columns: Name, Website, Perplexity Gap Output (or Final Investor Profile)
- 100% manual JSON parsing - no LLM needed
- **Only extracts:** founded_year and key_executives
- **Only updates companies already in the database** (syncs with existing records)
- Skips companies not found in the database
**Benefits:**
- Fast processing (5-10s per record)
- Low cost (minimal or no LLM usage)
- Accurate data extraction
- Automatic database persistence
- Safe: won't create duplicate companies
"""
# Read uploaded CSV with pandas
content = await file.read()
df = pd.read_csv(io.StringIO(content.decode("utf-8")))
@@ -108,56 +52,35 @@ async def parse_csv(
processor = InvestorProcessor()
if is_investor == 1:
# Manual parser with LLM currency conversion
results = await processor.parse_investors(df, save_to_db=True)
# Results are already dicts from the new parser
return results
results = await processor.parse_investors(df)
else:
# Manual parser for companies (no LLM needed)
results = await processor.parse_companies(df, save_to_db=True)
# Results are already dicts from the new parser
return results
results = await processor.parse_companies(df)
# Convert Pydantic objects to dictionaries
return [r.model_dump() for r in results]
@app.post(
"/query", response_model=PaginatedResponse[InvestmentResponse], tags=["Querying"]
)
@app.post("/query", response_model=InvestorList, tags=["Querying"])
async def query_investors(request: QueryRequest):
"""Query investors/funds using natural language"""
try:
processor = QueryProcessor()
result = await processor.process_query(request.question)
logger.info(f"Query completed successfully with {result.total} results")
return result
except Exception as e:
logger.error(f"Error in query_investors: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
"""
Query investors using natural language.
@app.post(
"/query-companies", response_model=PaginatedResponse[CompanyData], tags=["Querying"]
)
async def query_companies(request: CompanyQueryRequest):
"""Query companies using natural language"""
try:
processor = CompanyQueryProcessor()
result = await processor.process_query(request.question)
logger.info(f"Company query completed successfully with {result.total} results")
return result
except Exception as e:
logger.error(f"Error in query_companies: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
Supports queries like:
- "Show me seed stage investors"
- "Find fintech investors in Silicon Valley"
- "Growth stage investors with $5M+ check sizes"
- "Healthcare investors in Europe"
"""
processor = QueryProcessor()
results = processor.process_query(request.question)
return results
app.include_router(investors.router)
app.include_router(companies.router)
app.include_router(projects.router)
app.include_router(folk_crm.router)
app.include_router(insight_route.router)
app.include_router(report_route.router)
app.include_router(addition.router)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app="main:app", host="0.0.0.0", port=8585)
uvicorn.run(app="main:app", host="0.0.0.0", port=8585, reload=True)
Binary file not shown.
Binary file not shown.
-370
View File
@@ -1,370 +0,0 @@
from typing import Optional
from db.db import get_db
from db.models import FundTable, InvestorTable, SectorTable
from fastapi import APIRouter, Depends
from pydantic import BaseModel
from sqlalchemy.orm import Session
router = APIRouter(tags=["Additional Routes"])
# Response schemas
class SectorsResponse(BaseModel):
sectors: list[str]
total: int
class CountryInfo(BaseModel):
name: str
class ContinentInfo(BaseModel):
name: str
countries: list[str]
class GeographyResponse(BaseModel):
continents: list[ContinentInfo]
total_continents: int
total_countries: int
# Mapping of countries to continents
COUNTRY_TO_CONTINENT = {
# Africa
"Algeria": "Africa",
"Angola": "Africa",
"Benin": "Africa",
"Botswana": "Africa",
"Burkina Faso": "Africa",
"Burundi": "Africa",
"Cameroon": "Africa",
"Cape Verde": "Africa",
"Central African Republic": "Africa",
"Chad": "Africa",
"Comoros": "Africa",
"Congo": "Africa",
"Democratic Republic of the Congo": "Africa",
"Djibouti": "Africa",
"Egypt": "Africa",
"Equatorial Guinea": "Africa",
"Eritrea": "Africa",
"Eswatini": "Africa",
"Ethiopia": "Africa",
"Gabon": "Africa",
"Gambia": "Africa",
"Ghana": "Africa",
"Guinea": "Africa",
"Guinea-Bissau": "Africa",
"Ivory Coast": "Africa",
"Kenya": "Africa",
"Lesotho": "Africa",
"Liberia": "Africa",
"Libya": "Africa",
"Madagascar": "Africa",
"Malawi": "Africa",
"Mali": "Africa",
"Mauritania": "Africa",
"Mauritius": "Africa",
"Morocco": "Africa",
"Mozambique": "Africa",
"Namibia": "Africa",
"Niger": "Africa",
"Nigeria": "Africa",
"Rwanda": "Africa",
"Sao Tome and Principe": "Africa",
"Senegal": "Africa",
"Seychelles": "Africa",
"Sierra Leone": "Africa",
"Somalia": "Africa",
"South Africa": "Africa",
"South Sudan": "Africa",
"Sudan": "Africa",
"Tanzania": "Africa",
"Togo": "Africa",
"Tunisia": "Africa",
"Uganda": "Africa",
"Zambia": "Africa",
"Zimbabwe": "Africa",
# Asia
"Afghanistan": "Asia",
"Armenia": "Asia",
"Azerbaijan": "Asia",
"Bahrain": "Asia",
"Bangladesh": "Asia",
"Bhutan": "Asia",
"Brunei": "Asia",
"Cambodia": "Asia",
"China": "Asia",
"Cyprus": "Asia",
"Georgia": "Asia",
"Hong Kong": "Asia",
"India": "Asia",
"Indonesia": "Asia",
"Iran": "Asia",
"Iraq": "Asia",
"Israel": "Asia",
"Japan": "Asia",
"Jordan": "Asia",
"Kazakhstan": "Asia",
"Kuwait": "Asia",
"Kyrgyzstan": "Asia",
"Laos": "Asia",
"Lebanon": "Asia",
"Malaysia": "Asia",
"Maldives": "Asia",
"Mongolia": "Asia",
"Myanmar": "Asia",
"Nepal": "Asia",
"North Korea": "Asia",
"Oman": "Asia",
"Pakistan": "Asia",
"Palestine": "Asia",
"Philippines": "Asia",
"Qatar": "Asia",
"Saudi Arabia": "Asia",
"Singapore": "Asia",
"South Korea": "Asia",
"Sri Lanka": "Asia",
"Syria": "Asia",
"Taiwan": "Asia",
"Tajikistan": "Asia",
"Thailand": "Asia",
"Timor-Leste": "Asia",
"Turkey": "Asia",
"Turkmenistan": "Asia",
"United Arab Emirates": "Asia",
"UAE": "Asia",
"Uzbekistan": "Asia",
"Vietnam": "Asia",
"Yemen": "Asia",
# Europe
"Albania": "Europe",
"Andorra": "Europe",
"Austria": "Europe",
"Belarus": "Europe",
"Belgium": "Europe",
"Bosnia and Herzegovina": "Europe",
"Bulgaria": "Europe",
"Croatia": "Europe",
"Czech Republic": "Europe",
"Czechia": "Europe",
"Denmark": "Europe",
"Estonia": "Europe",
"Finland": "Europe",
"France": "Europe",
"Germany": "Europe",
"Greece": "Europe",
"Hungary": "Europe",
"Iceland": "Europe",
"Ireland": "Europe",
"Italy": "Europe",
"Kosovo": "Europe",
"Latvia": "Europe",
"Liechtenstein": "Europe",
"Lithuania": "Europe",
"Luxembourg": "Europe",
"Malta": "Europe",
"Moldova": "Europe",
"Monaco": "Europe",
"Montenegro": "Europe",
"Netherlands": "Europe",
"North Macedonia": "Europe",
"Norway": "Europe",
"Poland": "Europe",
"Portugal": "Europe",
"Romania": "Europe",
"Russia": "Europe",
"San Marino": "Europe",
"Serbia": "Europe",
"Slovakia": "Europe",
"Slovenia": "Europe",
"Spain": "Europe",
"Sweden": "Europe",
"Switzerland": "Europe",
"Ukraine": "Europe",
"United Kingdom": "Europe",
"UK": "Europe",
"Vatican City": "Europe",
# North America
"Antigua and Barbuda": "North America",
"Bahamas": "North America",
"Barbados": "North America",
"Belize": "North America",
"Canada": "North America",
"Costa Rica": "North America",
"Cuba": "North America",
"Dominica": "North America",
"Dominican Republic": "North America",
"El Salvador": "North America",
"Grenada": "North America",
"Guatemala": "North America",
"Haiti": "North America",
"Honduras": "North America",
"Jamaica": "North America",
"Mexico": "North America",
"Nicaragua": "North America",
"Panama": "North America",
"Saint Kitts and Nevis": "North America",
"Saint Lucia": "North America",
"Saint Vincent and the Grenadines": "North America",
"Trinidad and Tobago": "North America",
"United States": "North America",
"USA": "North America",
"US": "North America",
# South America
"Argentina": "South America",
"Bolivia": "South America",
"Brazil": "South America",
"Chile": "South America",
"Colombia": "South America",
"Ecuador": "South America",
"Guyana": "South America",
"Paraguay": "South America",
"Peru": "South America",
"Suriname": "South America",
"Uruguay": "South America",
"Venezuela": "South America",
# Oceania
"Australia": "Oceania",
"Fiji": "Oceania",
"Kiribati": "Oceania",
"Marshall Islands": "Oceania",
"Micronesia": "Oceania",
"Nauru": "Oceania",
"New Zealand": "Oceania",
"Palau": "Oceania",
"Papua New Guinea": "Oceania",
"Samoa": "Oceania",
"Solomon Islands": "Oceania",
"Tonga": "Oceania",
"Tuvalu": "Oceania",
"Vanuatu": "Oceania",
}
# Valid continent names for direct matching
VALID_CONTINENTS = {
"Africa",
"Asia",
"Europe",
"North America",
"South America",
"Oceania",
"Antarctica",
}
def extract_countries_from_geographic_focus(geographic_focus: str) -> set[str]:
"""
Extract country names from a geographic_focus string.
Handles comma-separated values, slashes, and various formats.
"""
if not geographic_focus:
return set()
countries = set()
# Split by common delimiters
parts = geographic_focus.replace("/", ",").replace(";", ",").split(",")
for part in parts:
cleaned = part.strip()
if cleaned:
# Check if it's a known country
if cleaned in COUNTRY_TO_CONTINENT:
countries.add(cleaned)
# Check for partial matches (e.g., "United States of America" -> "United States")
else:
for country in COUNTRY_TO_CONTINENT.keys():
if country.lower() in cleaned.lower() or cleaned.lower() in country.lower():
countries.add(country)
break
return countries
def organize_geography(geographic_data: list[str]) -> dict[str, set[str]]:
"""
Organize geographic data into continents and their countries.
Returns a dict with continent names as keys and sets of countries as values.
"""
continent_countries: dict[str, set[str]] = {}
for geo_focus in geographic_data:
if not geo_focus:
continue
# Extract countries from the geographic focus string
countries = extract_countries_from_geographic_focus(geo_focus)
for country in countries:
continent = COUNTRY_TO_CONTINENT.get(country)
if continent:
if continent not in continent_countries:
continent_countries[continent] = set()
continent_countries[continent].add(country)
# Also check if the geographic focus itself is a continent
cleaned_geo = geo_focus.strip()
if cleaned_geo in VALID_CONTINENTS:
if cleaned_geo not in continent_countries:
continent_countries[cleaned_geo] = set()
return continent_countries
@router.get("/sectors", response_model=SectorsResponse)
def get_unique_sectors(db: Session = Depends(get_db)):
"""
Get all unique sectors from the database.
Returns a list of sector names sorted alphabetically.
"""
sectors = db.query(SectorTable.name).distinct().order_by(SectorTable.name).all()
sector_names = [s[0] for s in sectors if s[0]]
return SectorsResponse(sectors=sector_names, total=len(sector_names))
@router.get("/geography", response_model=GeographyResponse)
def get_arranged_geography(db: Session = Depends(get_db)):
"""
Get all unique geographic locations arranged by continent and countries.
Extracts geography from both investors and funds tables.
Returns continents with their associated countries.
"""
# Collect all geographic focus data from investors
investor_geo = (
db.query(InvestorTable.geographic_focus)
.filter(InvestorTable.geographic_focus.isnot(None))
.distinct()
.all()
)
# Collect all geographic focus data from funds
fund_geo = (
db.query(FundTable.geographic_focus)
.filter(FundTable.geographic_focus.isnot(None))
.distinct()
.all()
)
# Combine all geographic data
all_geo_data = [g[0] for g in investor_geo] + [g[0] for g in fund_geo]
# Organize into continents and countries
continent_countries = organize_geography(all_geo_data)
# Build response
continents = []
total_countries = 0
for continent_name in sorted(continent_countries.keys()):
countries = sorted(continent_countries[continent_name])
total_countries += len(countries)
continents.append(ContinentInfo(name=continent_name, countries=countries))
return GeographyResponse(
continents=continents,
total_continents=len(continents),
total_countries=total_countries,
)
+18 -67
View File
@@ -1,10 +1,10 @@
from typing import Optional
from typing import List, Optional
from db.db import get_db
from db.models import CompanyTable, InvestorTable
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
from schemas.router_schemas import CompanyData, PaginatedResponse
from schemas.router_schemas import CompanyData
from sqlalchemy.orm import Session, selectinload
router = APIRouter(tags=["Company Routes"])
@@ -29,63 +29,38 @@ class CompanyUpdate(BaseModel):
website: Optional[str] = None
@router.get("/companies", response_model=PaginatedResponse[CompanyData])
def read_companies(
page: int = Query(1, ge=1, description="Page number (starts at 1)"),
page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
db: Session = Depends(get_db),
):
"""Get all companies with their investor relationships (paginated)"""
# Calculate offset
offset = (page - 1) * page_size
# Get total count
total_count = (
db.query(CompanyTable)
.filter(CompanyTable.name.isnot(None), CompanyTable.description.isnot(None))
.count()
)
# Get paginated results
@router.get("/companies", response_model=List[CompanyData])
def read_companies(db: Session = Depends(get_db)):
"""Get all companies with their investor relationships"""
companies = (
db.query(CompanyTable)
.filter(CompanyTable.name.isnot(None), CompanyTable.description.isnot(None))
.filter(
CompanyTable.name.isnot(None),
CompanyTable.description.isnot(None)
)
.options(
selectinload(CompanyTable.investors),
selectinload(CompanyTable.members),
selectinload(CompanyTable.sectors),
)
.offset(offset)
.limit(page_size)
.all()
)
# Transform CompanyTable objects to CompanyData format
company_data_list = []
for company in companies:
# Sort sectors alphabetically
sorted_sectors = sorted(company.sectors, key=lambda s: s.name) if company.sectors else []
company_data = CompanyData(
company=company,
investors=company.investors,
members=company.members,
sectors=sorted_sectors,
sectors=company.sectors,
)
company_data_list.append(company_data)
# Calculate total pages
total_pages = (total_count + page_size - 1) // page_size
return PaginatedResponse(
items=company_data_list,
total=total_count,
page=page,
page_size=page_size,
total_pages=total_pages,
)
return company_data_list
@router.get("/companies/filter", response_model=PaginatedResponse[CompanyData])
@router.get("/companies/filter", response_model=List[CompanyData])
def filter_companies(
industry: Optional[str] = Query(
None, description="Filter by industry (partial match)"
@@ -101,11 +76,9 @@ def filter_companies(
investor_name: Optional[str] = Query(
None, description="Filter by investor name (partial match)"
),
page: int = Query(1, ge=1, description="Page number (starts at 1)"),
page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
db: Session = Depends(get_db),
):
"""Filter companies based on various criteria (paginated)"""
"""Filter companies based on various criteria"""
# Start with base query
query = db.query(CompanyTable).options(
@@ -139,36 +112,20 @@ def filter_companies(
InvestorTable.name.ilike(f"%{investor_name}%")
)
# Get total count before pagination
total_count = query.count()
# Calculate offset and apply pagination
offset = (page - 1) * page_size
companies = query.offset(offset).limit(page_size).all()
companies = query.all()
# Transform to CompanyData format
company_data_list = []
for company in companies:
# Sort sectors alphabetically
sorted_sectors = sorted(company.sectors, key=lambda s: s.name) if company.sectors else []
company_data = CompanyData(
company=company,
investors=company.investors,
members=company.members,
sectors=sorted_sectors,
sectors=company.sectors,
)
company_data_list.append(company_data)
# Calculate total pages
total_pages = (total_count + page_size - 1) // page_size
return PaginatedResponse(
items=company_data_list,
total=total_count,
page=page,
page_size=page_size,
total_pages=total_pages,
)
return company_data_list
@router.get("/companies/{company_id}", response_model=CompanyData)
@@ -188,15 +145,12 @@ def read_company(company_id: int, db: Session = Depends(get_db)):
if not company:
raise HTTPException(status_code=404, detail="Company not found")
# Sort sectors alphabetically
sorted_sectors = sorted(company.sectors, key=lambda s: s.name) if company.sectors else []
# Transform to CompanyData format
return CompanyData(
company=company,
investors=company.investors,
members=company.members,
sectors=sorted_sectors,
sectors=company.sectors,
)
@@ -257,15 +211,12 @@ def update_company(
.first()
)
# Sort sectors alphabetically
sorted_sectors = sorted(company_with_relations.sectors, key=lambda s: s.name) if company_with_relations.sectors else []
# Transform to CompanyData format
return CompanyData(
company=company_with_relations,
investors=company_with_relations.investors,
members=company_with_relations.members,
sectors=sorted_sectors,
sectors=company_with_relations.sectors,
)
-204
View File
@@ -1,204 +0,0 @@
import os
from typing import List
from db.db import get_db
from db.models import InvestorTable
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from services.crm import FolkAPI
from sqlalchemy.orm import Session, selectinload
router = APIRouter(prefix="/folk", tags=["Folk CRM"])
def get_folk_client():
"""Get Folk API client with loaded environment variables"""
return FolkAPI(api_key=os.environ.get("FOLK_API_KEY", ""))
class GroupResponse(BaseModel):
id: str
name: str
class SyncInvestorsRequest(BaseModel):
investor_ids: List[int]
group_id: str
class SyncResult(BaseModel):
investor_id: int
investor_name: str
company_id: str
company_name: str
team_members_synced: int
person_ids: List[str]
class SyncInvestorsResponse(BaseModel):
success: bool
synced_count: int
results: List[SyncResult]
errors: List[dict]
@router.get("/groups", response_model=List[GroupResponse])
def get_folk_groups():
"""Get all groups from Folk CRM.
Returns a list of groups with their id and name that can be used
to sync investors to Folk.
"""
try:
folk = get_folk_client()
groups_data = folk.get_groups()
items = groups_data.get("data", {}).get("items", [])
return [GroupResponse(id=item["id"], name=item["name"]) for item in items]
except Exception as e:
raise HTTPException(
status_code=500, detail=f"Failed to fetch groups from Folk: {str(e)}"
)
@router.post("/sync-investors", response_model=SyncInvestorsResponse)
def sync_investors_to_folk(
request: SyncInvestorsRequest, db: Session = Depends(get_db)
):
"""Sync investors to Folk CRM as companies with their team members as people.
Takes a list of investor IDs and a Folk group ID, then:
1. Creates each investor as a company in the specified Folk group
2. Creates each team member as a person linked to that company
Args:
investor_ids: List of investor IDs from the database
group_id: Folk group ID where investors should be added
Returns:
Summary of sync operation including successes and errors
"""
folk = get_folk_client()
# Fetch investors with their team members
investors = (
db.query(InvestorTable)
.options(
selectinload(InvestorTable.team_members),
selectinload(InvestorTable.sectors),
)
.filter(InvestorTable.id.in_(request.investor_ids))
.all()
)
if not investors:
raise HTTPException(
status_code=404, detail="No investors found with the provided IDs"
)
results = []
errors = []
for investor in investors:
try:
# Create company in Folk
company_data = folk.create_company(
name=investor.name,
group_id=request.group_id,
website=investor.website,
description=investor.description,
addresses=[investor.headquarters] if investor.headquarters else None,
)
company_id = company_data.get("data", {}).get("id")
if not company_id:
errors.append(
{
"investor_id": investor.id,
"investor_name": investor.name,
"error": "No company ID returned from Folk API",
}
)
continue
# Create team members as people
person_ids = []
team_members_synced = 0
for member in investor.team_members:
try:
# Extract first name and last name from full name
name_parts = member.name.split(maxsplit=1)
first_name = name_parts[0] if name_parts else member.name
last_name = name_parts[1] if len(name_parts) > 1 else ""
# Build URLs list from source_url if available
urls_list = None
if hasattr(member, "source_url") and member.source_url:
urls_list = [member.source_url]
# Get LinkedIn URL if available
linkedin_url = None
if hasattr(member, "linkedin") and member.linkedin:
linkedin_url = member.linkedin
# Build job title from title or role
job_title = None
if hasattr(member, "title") and member.title:
job_title = member.title
elif hasattr(member, "role") and member.role:
job_title = member.role
person_data = folk.create_person(
first_name=first_name,
last_name=last_name,
email=member.email,
company_id=company_id,
group_id=request.group_id,
linkedin_url=linkedin_url,
urls=urls_list,
jobTitle=job_title,
)
person_id = person_data.get("data", {}).get("id")
if person_id:
person_ids.append(person_id)
team_members_synced += 1
except Exception as person_error:
# Log person creation error but continue with other members
errors.append(
{
"investor_id": investor.id,
"investor_name": investor.name,
"team_member_name": member.name,
"error": f"Failed to create person: {str(person_error)}",
}
)
results.append(
SyncResult(
investor_id=investor.id,
investor_name=investor.name,
company_id=company_id,
company_name=company_data.get("data", {}).get(
"name", investor.name
),
team_members_synced=team_members_synced,
person_ids=person_ids,
)
)
except Exception as e:
errors.append(
{
"investor_id": investor.id,
"investor_name": investor.name,
"error": str(e),
}
)
return SyncInvestorsResponse(
success=len(results) > 0,
synced_count=len(results),
results=results,
errors=errors,
)
-122
View File
@@ -1,122 +0,0 @@
from datetime import datetime, timedelta
from typing import Optional
from db.db import get_db
from db.models import InvestorInsightCache, InvestorTable, ProjectTable
from fastapi import APIRouter, Depends, HTTPException
from schemas.insight_schema import InsightResponse
from services.compatibility_score import (
calculate_project_investor_compatibility,
generate_compatibility_explanation,
)
from services.insight import QueryProcessor
from sqlalchemy.orm import Session
router = APIRouter()
@router.get(
"/insights/{investor_id}", response_model=InsightResponse, tags=["Insights"]
)
async def get_insights(
investor_id: int, project_id: Optional[int] = None, db: Session = Depends(get_db)
):
"""
Get investor insights including investment pattern analysis, market position,
and optionally compatibility score with a project.
Args:
investor_id: The ID of the investor to analyze
project_id: Optional project ID to calculate compatibility score
Returns:
InsightResponse with investment_pattern_analysis, market_position,
and compatibility_score (if project_id provided)
"""
# Get investor from database
investor = db.query(InvestorTable).filter(InvestorTable.id == investor_id).first()
if not investor:
raise HTTPException(
status_code=404, detail=f"Investor with id {investor_id} not found"
)
# Check if we have cached insights
cached_insights = (
db.query(InvestorInsightCache)
.filter(InvestorInsightCache.investor_id == investor_id)
.first()
)
# Determine if cache needs refresh (older than 1 month)
needs_refresh = True
if cached_insights:
# Calculate if cache is older than 1 month
cache_age = (
datetime.now(cached_insights.last_refreshed.tzinfo)
- cached_insights.last_refreshed
)
needs_refresh = cache_age > timedelta(days=30)
# Fetch new insights if needed
if needs_refresh:
# Initialize the query processor for insights
query_processor = QueryProcessor()
# Get investment pattern analysis and market position using web search
insights = await query_processor.get_investor_insights(
investor_name=investor.name,
investor_website=investor.website,
investor_description=investor.description,
investor_headquarters=investor.headquarters,
investment_thesis=investor.investment_thesis,
portfolio_highlights=investor.portfolio_highlights,
)
# Update or create cache entry
if cached_insights:
# Update existing cache
cached_insights.investment_pattern_analysis = insights[
"investment_pattern_analysis"
]
cached_insights.market_position = insights["market_position"]
cached_insights.last_refreshed = datetime.now(
cached_insights.last_refreshed.tzinfo
)
else:
# Create new cache entry
cached_insights = InvestorInsightCache(
investor_id=investor_id,
investment_pattern_analysis=insights["investment_pattern_analysis"],
market_position=insights["market_position"],
)
db.add(cached_insights)
db.commit()
db.refresh(cached_insights)
# Calculate compatibility score if project_id is provided
compatibility_score = None
if project_id:
project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
if not project:
raise HTTPException(
status_code=404, detail=f"Project with id {project_id} not found"
)
# Calculate the compatibility score
score = calculate_project_investor_compatibility(
project, investor, use_funds=True
)
# Generate detailed explanation
compatibility_score = generate_compatibility_explanation(
project, investor, score, use_funds=True
)
else:
compatibility_score = "Select a project to see compatibility analysis"
return InsightResponse(
investment_pattern_analysis=cached_insights.investment_pattern_analysis,
market_position=cached_insights.market_position,
compatibility_score=compatibility_score,
)
+114 -494
View File
@@ -1,21 +1,11 @@
from typing import Optional
from typing import List, Optional
from db.db import get_db
from db.models import FundTable, InvestorTable, ProjectTable, SectorTable
from db.models import InvestorTable, SectorTable
from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel
from schemas.router_schemas import (
CompanyMinimal,
InvestmentResponse,
InvestmentStage,
InvestorData,
PaginatedResponse,
SectorMinimal,
)
from services.compatibility_score import (
_calculate_project_fund_compatibility,
_calculate_project_investor_direct_compatibility,
)
from schemas.router_schemas import InvestmentStage, InvestorData
from services.querying import QueryProcessor
from sqlalchemy.orm import Session, selectinload
router = APIRouter(tags=["Investor Routes"])
@@ -25,189 +15,53 @@ router = APIRouter(tags=["Investor Routes"])
class InvestorCreate(BaseModel):
name: str
description: Optional[str] = None
website: Optional[str] = None
headquarters: Optional[str] = None
aum: int
check_size_lower: int
check_size_upper: int
geographic_focus: str
stage_focus: InvestmentStage
number_of_investments: int = 0
class InvestorUpdate(BaseModel):
name: Optional[str] = None
description: Optional[str] = None
website: Optional[str] = None
headquarters: Optional[str] = None
aum: Optional[int] = None
check_size_lower: Optional[int] = None
check_size_upper: Optional[int] = None
geographic_focus: Optional[str] = None
stage_focus: Optional[InvestmentStage] = None
number_of_investments: Optional[int] = None
@router.get("/investors", response_model=PaginatedResponse[InvestmentResponse])
def read_investors(
page: int = Query(1, ge=1, description="Page number (starts at 1)"),
page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
project_id: Optional[int] = Query(
None, description="Optional project ID for compatibility scoring"
),
db: Session = Depends(get_db),
):
"""Get all investors with their funds as separate entries (paginated)
Each investor-fund combination is returned as a separate row.
An investor with 3 funds will appear as 3 entries.
If project_id is provided, calculates compatibility scores for each investor.
"""
# Calculate offset
offset = (page - 1) * page_size
# Get total count
total_count = db.query(InvestorTable).count()
# Load project if project_id provided
project = None
if project_id is not None:
project = (
db.query(ProjectTable)
.options(selectinload(ProjectTable.sector))
.filter(ProjectTable.id == project_id)
.first()
@router.get("/investors", response_model=List[InvestorData])
def read_investors(db: Session = Depends(get_db)):
"""Get all investors with their related data"""
investors = (
db.query(InvestorTable)
.options(
selectinload(InvestorTable.portfolio_companies),
selectinload(InvestorTable.team_members),
selectinload(InvestorTable.sectors),
)
if not project:
raise HTTPException(status_code=404, detail="Project not found")
# When project_id is provided, we need to get all investors first to sort by compatibility score
# Otherwise, we can paginate at the database level
if project is not None:
# Get all investors (we'll sort by compatibility score, then paginate)
all_investors = (
db.query(InvestorTable)
.options(
selectinload(InvestorTable.portfolio_companies),
selectinload(InvestorTable.team_members),
selectinload(InvestorTable.sectors),
selectinload(InvestorTable.funds).selectinload(
FundTable.investment_stages
),
selectinload(InvestorTable.funds).selectinload(FundTable.sectors),
)
.all()
)
# We'll paginate after sorting by compatibility score
investors = all_investors
else:
# Get paginated results (no sorting needed)
investors = (
db.query(InvestorTable)
.options(
selectinload(InvestorTable.portfolio_companies),
selectinload(InvestorTable.team_members),
selectinload(InvestorTable.sectors),
selectinload(InvestorTable.funds).selectinload(
FundTable.investment_stages
),
selectinload(InvestorTable.funds).selectinload(FundTable.sectors),
)
.offset(offset)
.limit(page_size)
.all()
)
# Transform to InvestmentResponse format (one row per investor-fund combination)
investment_responses = []
for investor in investors:
# Get top 3 portfolio companies (id and name only)
portfolio_companies = [
CompanyMinimal(id=company.id, name=company.name)
for company in investor.portfolio_companies[:3]
]
# If investor has funds, create one entry per fund
if investor.funds:
for fund in investor.funds:
# Calculate compatibility score for this specific fund
compatibility_score = 1.0
if project is not None:
compatibility_score = _calculate_project_fund_compatibility(
project=project, fund=fund
)
# Get stage focus as comma-separated string
stage_focus = (
", ".join([stage.name for stage in fund.investment_stages])
if fund.investment_stages
else None
)
# Get top 3 sectors from fund (id and name only) - sorted alphabetically
fund_sectors = [
SectorMinimal(id=sector.id, name=sector.name)
for sector in sorted(
fund.sectors[:3] if fund.sectors else [], key=lambda s: s.name
)
]
investment_response = InvestmentResponse(
id=investor.id,
name=f"{investor.name} - {fund.fund_name}"
if fund.fund_name
else investor.name,
aum=investor.aum,
check_size_lower=fund.check_size_lower,
check_size_upper=fund.check_size_upper,
geographic_focus=fund.geographic_focus,
stage_focus=stage_focus,
portfolio_companies=portfolio_companies,
sectors=fund_sectors,
compatibility_score=compatibility_score,
)
investment_responses.append(investment_response)
else:
# If no funds, create one entry with null fund fields
# Calculate compatibility using investor-level data
compatibility_score = 1.0
if project is not None:
compatibility_score = _calculate_project_investor_direct_compatibility(
project=project, investor=investor
)
investment_response = InvestmentResponse(
id=investor.id,
name=investor.name,
aum=investor.aum,
check_size_lower=None,
check_size_upper=None,
geographic_focus=None,
stage_focus=None,
portfolio_companies=portfolio_companies,
sectors=[],
compatibility_score=compatibility_score,
)
investment_responses.append(investment_response)
# Sort by compatibility score (descending) when project_id is provided
if project is not None:
investment_responses.sort(key=lambda x: x.compatibility_score, reverse=True)
# Apply pagination after sorting
investment_responses = investment_responses[offset : offset + page_size]
# Calculate total pages
total_pages = (total_count + page_size - 1) // page_size
return PaginatedResponse(
items=investment_responses,
total=total_count,
page=page,
page_size=page_size,
total_pages=total_pages,
.all()
)
# Transform InvestorTable objects to InvestorData format
investor_data_list = []
for investor in investors:
investor_data = InvestorData(
investor=investor, # This maps to InvestorSchema
portfolio_companies=investor.portfolio_companies,
team_members=investor.team_members,
sectors=investor.sectors,
)
investor_data_list.append(investor_data)
@router.get("/investors/filter", response_model=PaginatedResponse[InvestmentResponse])
return investor_data_list
@router.get("/investors/filter", response_model=List[InvestorData])
def filter_investors(
stage: Optional[InvestmentStage] = Query(
None, description="Filter by investment stage"
@@ -220,161 +74,67 @@ def filter_investors(
sector: Optional[str] = Query(None, description="Sector name (partial match)"),
min_aum: Optional[int] = Query(None, description="Minimum AUM"),
max_aum: Optional[int] = Query(None, description="Maximum AUM"),
page: int = Query(1, ge=1, description="Page number (starts at 1)"),
page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
project_id: Optional[int] = Query(
None, description="Optional project ID for compatibility scoring"
),
db: Session = Depends(get_db),
):
"""Filter investors based on various criteria (paginated)
"""Filter investors based on various criteria"""
Returns investor-fund combinations as separate rows.
Queries the funds table to find matching funds.
If project_id is provided, calculates compatibility scores for each investor.
"""
# Load project if project_id provided
project = None
if project_id is not None:
project = (
db.query(ProjectTable)
.options(selectinload(ProjectTable.sector))
.filter(ProjectTable.id == project_id)
.first()
)
if not project:
raise HTTPException(status_code=404, detail="Project not found")
# Start with base query on funds table
query = db.query(FundTable).options(
selectinload(FundTable.investor).selectinload(
InvestorTable.portfolio_companies
),
selectinload(FundTable.investor).selectinload(InvestorTable.team_members),
selectinload(FundTable.investor).selectinload(InvestorTable.sectors),
selectinload(FundTable.investment_stages),
selectinload(FundTable.sectors),
# Start with base query
query = db.query(InvestorTable).options(
selectinload(InvestorTable.portfolio_companies),
selectinload(InvestorTable.team_members),
selectinload(InvestorTable.sectors),
)
# Apply filters at fund level
# Apply filters
if stage:
query = query.filter(InvestorTable.stage_focus == stage)
if min_check_size is not None:
query = query.filter(FundTable.check_size_lower >= min_check_size)
query = query.filter(InvestorTable.check_size_lower >= min_check_size)
if max_check_size is not None:
query = query.filter(FundTable.check_size_upper <= max_check_size)
query = query.filter(InvestorTable.check_size_upper <= max_check_size)
if geography:
query = query.filter(FundTable.geographic_focus.ilike(f"%{geography}%"))
query = query.filter(InvestorTable.geographic_focus.ilike(f"%{geography}%"))
# Apply filters at investor level (through relationship)
if min_aum is not None:
query = query.join(FundTable.investor).filter(InvestorTable.aum >= min_aum)
query = query.filter(InvestorTable.aum >= min_aum)
if max_aum is not None:
if min_aum is None: # Only join if not already joined
query = query.join(FundTable.investor)
query = query.filter(InvestorTable.aum <= max_aum)
# Filter by sector if provided (at fund level)
# Filter by sector if provided
if sector:
query = query.join(FundTable.sectors).filter(
query = query.join(InvestorTable.sectors).filter(
SectorTable.name.ilike(f"%{sector}%")
)
# Get total count before pagination
total_count = query.count()
investors = query.all()
# When project_id is provided, we need to get all funds first to sort by compatibility score
# Otherwise, we can paginate at the database level
if project is not None:
# Get all funds (we'll sort by compatibility score, then paginate)
all_funds = query.all()
funds = all_funds
else:
# Calculate offset and apply pagination (no sorting needed)
offset = (page - 1) * page_size
funds = query.offset(offset).limit(page_size).all()
# Transform to InvestmentResponse format (one row per fund)
investment_responses = []
for fund in funds:
investor = fund.investor
# Calculate compatibility score for this specific fund
compatibility_score = 1.0
if project is not None:
compatibility_score = _calculate_project_fund_compatibility(
project=project, fund=fund
)
# Get top 3 portfolio companies (id and name only)
portfolio_companies = [
CompanyMinimal(id=company.id, name=company.name)
for company in investor.portfolio_companies[:3]
]
# Get stage focus as comma-separated string
stage_focus = (
", ".join([stage.name for stage in fund.investment_stages])
if fund.investment_stages
else None
# Transform to InvestorData format
investor_data_list = []
for investor in investors:
investor_data = InvestorData(
investor=investor,
portfolio_companies=investor.portfolio_companies,
team_members=investor.team_members,
sectors=investor.sectors,
)
investor_data_list.append(investor_data)
# Get top 3 sectors from fund (id and name only) - sorted alphabetically
fund_sectors = [
SectorMinimal(id=sector.id, name=sector.name)
for sector in sorted(
fund.sectors[:3] if fund.sectors else [], key=lambda s: s.name
)
]
investment_response = InvestmentResponse(
id=investor.id,
name=f"{investor.name} - {fund.fund_name}"
if fund.fund_name
else investor.name,
aum=investor.aum,
check_size_lower=fund.check_size_lower,
check_size_upper=fund.check_size_upper,
geographic_focus=fund.geographic_focus,
stage_focus=stage_focus,
portfolio_companies=portfolio_companies,
sectors=fund_sectors,
compatibility_score=compatibility_score,
)
investment_responses.append(investment_response)
# Sort by compatibility score (descending) when project_id is provided
if project is not None:
investment_responses.sort(key=lambda x: x.compatibility_score, reverse=True)
# Apply pagination after sorting
offset = (page - 1) * page_size
investment_responses = investment_responses[offset : offset + page_size]
# Calculate total pages
total_pages = (total_count + page_size - 1) // page_size
return PaginatedResponse(
items=investment_responses,
total=total_count,
page=page,
page_size=page_size,
total_pages=total_pages,
)
return investor_data_list
@router.get("/investors/{investor_id}", response_model=InvestorData)
def read_investor(investor_id: int, db: Session = Depends(get_db)):
"""Get a specific investor by ID with all their funds"""
"""Get a specific investor by ID"""
investor = (
db.query(InvestorTable)
.options(
selectinload(InvestorTable.portfolio_companies),
selectinload(InvestorTable.team_members),
selectinload(InvestorTable.sectors),
selectinload(InvestorTable.funds),
)
.filter(InvestorTable.id == investor_id)
.first()
@@ -383,13 +143,12 @@ def read_investor(investor_id: int, db: Session = Depends(get_db)):
if not investor:
raise HTTPException(status_code=404, detail="Investor not found")
# Transform to InvestorData format (includes funds array)
# Transform to InvestorData format
return InvestorData(
investor=investor,
portfolio_companies=investor.portfolio_companies,
team_members=investor.team_members,
sectors=investor.sectors,
funds=investor.funds,
)
@@ -408,7 +167,6 @@ def create_investor(investor: InvestorCreate, db: Session = Depends(get_db)):
selectinload(InvestorTable.portfolio_companies),
selectinload(InvestorTable.team_members),
selectinload(InvestorTable.sectors),
selectinload(InvestorTable.funds),
)
.filter(InvestorTable.id == db_investor.id)
.first()
@@ -420,91 +178,24 @@ def create_investor(investor: InvestorCreate, db: Session = Depends(get_db)):
portfolio_companies=investor_with_relations.portfolio_companies,
team_members=investor_with_relations.team_members,
sectors=investor_with_relations.sectors,
funds=investor_with_relations.funds,
)
@router.put("/investors/{investor_id}", response_model=InvestorData)
def update_investor(
investor_id: int, investor: InvestorUpdate, db: Session = Depends(get_db)
):
"""Update an existing investor"""
db_investor = (
db.query(InvestorTable).filter(InvestorTable.id == investor_id).first()
)
if not db_investor:
raise HTTPException(status_code=404, detail="Investor not found")
update_data = investor.dict(exclude_unset=True)
for field, value in update_data.items():
setattr(db_investor, field, value)
db.commit()
db.refresh(db_investor)
# Reload with relationships
investor_with_relations = (
db.query(InvestorTable)
.options(
selectinload(InvestorTable.portfolio_companies),
selectinload(InvestorTable.team_members),
selectinload(InvestorTable.sectors),
selectinload(InvestorTable.funds),
)
.filter(InvestorTable.id == investor_id)
.first()
)
# Transform to InvestorData format
return InvestorData(
investor=investor_with_relations,
portfolio_companies=investor_with_relations.portfolio_companies,
team_members=investor_with_relations.team_members,
sectors=investor_with_relations.sectors,
funds=investor_with_relations.funds,
)
@router.delete("/investors/{investor_id}")
def delete_investor(investor_id: int, db: Session = Depends(get_db)):
"""Delete an investor"""
db_investor = (
db.query(InvestorTable).filter(InvestorTable.id == investor_id).first()
)
if not db_investor:
raise HTTPException(status_code=404, detail="Investor not found")
db.delete(db_investor)
db.commit()
return {"message": "Investor deleted successfully"}
@router.get(
"/investors/{investor_id}/similar",
response_model=PaginatedResponse[InvestmentResponse],
)
@router.get("/investors/{investor_id}/similar", response_model=List[InvestorData])
def find_similar_investors(
investor_id: int,
limit: int = Query(10, description="Maximum number of similar investors to return"),
page: int = Query(1, ge=1, description="Page number (starts at 1)"),
page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
db: Session = Depends(get_db),
db: Session = Depends(get_db)
):
"""Find investors similar to a given investor based on characteristics (paginated)
"""Find investors similar to a given investor based on characteristics"""
Returns investor-fund combinations as separate rows.
Queries the funds table to find matching funds.
"""
# Get the target investor to get their funds for comparison
# Get the target investor
target_investor = (
db.query(InvestorTable)
.options(
selectinload(InvestorTable.portfolio_companies),
selectinload(InvestorTable.team_members),
selectinload(InvestorTable.sectors),
selectinload(InvestorTable.funds).selectinload(FundTable.investment_stages),
selectinload(InvestorTable.funds).selectinload(FundTable.sectors),
)
.filter(InvestorTable.id == investor_id)
.first()
@@ -513,149 +204,78 @@ def find_similar_investors(
if not target_investor:
raise HTTPException(status_code=404, detail="Investor not found")
# Get target investor's sector IDs for comparison (from their funds)
target_sector_ids = set()
target_stage_ids = set()
target_check_ranges = []
target_geographies = []
# Get target investor's sector IDs for comparison
target_sector_ids = {sector.id for sector in target_investor.sectors}
for fund in target_investor.funds:
if fund.sectors:
target_sector_ids.update({sector.id for sector in fund.sectors})
if fund.investment_stages:
target_stage_ids.update({stage.id for stage in fund.investment_stages})
if fund.check_size_lower and fund.check_size_upper:
target_check_ranges.append((fund.check_size_lower, fund.check_size_upper))
if fund.geographic_focus:
target_geographies.append(fund.geographic_focus.lower())
# Query all funds from other investors
candidate_funds = (
db.query(FundTable)
# Query all other investors with their relationships
candidates = (
db.query(InvestorTable)
.options(
selectinload(FundTable.investor).selectinload(
InvestorTable.portfolio_companies
),
selectinload(FundTable.investor).selectinload(InvestorTable.team_members),
selectinload(FundTable.investor).selectinload(InvestorTable.sectors),
selectinload(FundTable.investment_stages),
selectinload(FundTable.sectors),
selectinload(InvestorTable.portfolio_companies),
selectinload(InvestorTable.team_members),
selectinload(InvestorTable.sectors),
)
.join(FundTable.investor)
.filter(InvestorTable.id != investor_id)
.all()
)
# Calculate similarity scores for each fund
scored_funds = []
for fund in candidate_funds:
# Calculate similarity scores
scored_investors = []
for candidate in candidates:
score = 0
# Stage focus match (30 points)
if candidate.stage_focus == target_investor.stage_focus:
score += 30
# Geographic focus match (20 points for exact, 10 for partial)
if fund.geographic_focus and target_geographies:
fund_geo_lower = fund.geographic_focus.lower()
for target_geo in target_geographies:
if fund_geo_lower == target_geo:
score += 20
break
elif fund_geo_lower in target_geo or target_geo in fund_geo_lower:
score += 10
break
if candidate.geographic_focus and target_investor.geographic_focus:
if candidate.geographic_focus.lower() == target_investor.geographic_focus.lower():
score += 20
elif (candidate.geographic_focus.lower() in target_investor.geographic_focus.lower() or
target_investor.geographic_focus.lower() in candidate.geographic_focus.lower()):
score += 10
# Check size overlap (20 points max)
if fund.check_size_lower and fund.check_size_upper and target_check_ranges:
max_overlap_score = 0
for target_lower, target_upper in target_check_ranges:
overlap_start = max(fund.check_size_lower, target_lower)
overlap_end = min(fund.check_size_upper, target_upper)
if overlap_end > overlap_start:
overlap = overlap_end - overlap_start
target_range = target_upper - target_lower
overlap_ratio = overlap / target_range if target_range > 0 else 0
max_overlap_score = max(max_overlap_score, int(20 * overlap_ratio))
score += max_overlap_score
if (candidate.check_size_lower and candidate.check_size_upper and
target_investor.check_size_lower and target_investor.check_size_upper):
# Calculate overlap percentage
overlap_start = max(candidate.check_size_lower, target_investor.check_size_lower)
overlap_end = min(candidate.check_size_upper, target_investor.check_size_upper)
if overlap_end > overlap_start:
overlap = overlap_end - overlap_start
target_range = target_investor.check_size_upper - target_investor.check_size_lower
overlap_ratio = overlap / target_range if target_range > 0 else 0
score += int(20 * overlap_ratio)
# AUM similarity (15 points max)
if fund.investor.aum and target_investor.aum:
aum_diff = abs(fund.investor.aum - target_investor.aum)
max_aum = max(fund.investor.aum, target_investor.aum)
if candidate.aum and target_investor.aum:
aum_diff = abs(candidate.aum - target_investor.aum)
max_aum = max(candidate.aum, target_investor.aum)
similarity_ratio = 1 - (aum_diff / max_aum) if max_aum > 0 else 0
score += int(15 * similarity_ratio)
# Sector overlap (30 points max)
if fund.sectors and target_sector_ids:
fund_sector_ids = {sector.id for sector in fund.sectors}
common_sectors = target_sector_ids.intersection(fund_sector_ids)
candidate_sector_ids = {sector.id for sector in candidate.sectors}
if target_sector_ids and candidate_sector_ids:
common_sectors = target_sector_ids.intersection(candidate_sector_ids)
overlap_ratio = len(common_sectors) / len(target_sector_ids)
score += int(30 * overlap_ratio)
# Investment stage match (15 points max)
if fund.investment_stages and target_stage_ids:
fund_stage_ids = {stage.id for stage in fund.investment_stages}
common_stages = target_stage_ids.intersection(fund_stage_ids)
overlap_ratio = len(common_stages) / len(target_stage_ids)
score += int(15 * overlap_ratio)
if score > 0: # Only include investors with some similarity
scored_investors.append((score, candidate))
if score > 0: # Only include funds with some similarity
scored_funds.append((score, fund))
# Sort by score (descending) and take top N
scored_investors.sort(key=lambda x: x[0], reverse=True)
similar_investors = [inv for score, inv in scored_investors[:limit]]
# Sort by score (descending) and take top N based on limit
scored_funds.sort(key=lambda x: x[0], reverse=True)
top_similar = scored_funds[:limit]
# Apply pagination to the top similar funds
total_count = len(top_similar)
offset = (page - 1) * page_size
paginated_similar = top_similar[offset : offset + page_size]
similar_funds = [fund for score, fund in paginated_similar]
# Transform to InvestmentResponse format (one row per fund)
investment_responses = []
for fund in similar_funds:
investor = fund.investor
# Get top 3 portfolio companies (id and name only)
portfolio_companies = [
CompanyMinimal(id=company.id, name=company.name)
for company in investor.portfolio_companies[:3]
]
# Get stage focus as comma-separated string
stage_focus = (
", ".join([stage.name for stage in fund.investment_stages])
if fund.investment_stages
else None
# Transform to InvestorData format
return [
InvestorData(
investor=inv,
portfolio_companies=inv.portfolio_companies,
team_members=inv.team_members,
sectors=inv.sectors,
)
# Get top 3 sectors from fund (id and name only)
fund_sectors = [
SectorMinimal(id=sector.id, name=sector.name)
for sector in (fund.sectors[:3] if fund.sectors else [])
]
investment_response = InvestmentResponse(
id=investor.id,
name=f"{investor.name} - {fund.fund_name}"
if fund.fund_name
else investor.name,
aum=investor.aum,
check_size_lower=fund.check_size_lower,
check_size_upper=fund.check_size_upper,
geographic_focus=fund.geographic_focus,
stage_focus=stage_focus,
portfolio_companies=portfolio_companies,
sectors=fund_sectors,
compatibility_score=1.0,
)
investment_responses.append(investment_response)
# Calculate total pages
total_pages = (total_count + page_size - 1) // page_size
return PaginatedResponse(
items=investment_responses,
total=total_count,
page=page,
page_size=page_size,
total_pages=total_pages,
)
for inv in similar_investors
]
+11 -145
View File
@@ -14,45 +14,21 @@ from schemas.project_schemas import (
ProjectData,
ProjectUpdate,
)
from schemas.router_schemas import PaginatedResponse
from sqlalchemy.orm import Session, selectinload
router = APIRouter(tags=["Project Routes"])
@router.get("/projects", response_model=PaginatedResponse[ProjectData])
def read_projects(
page: int = Query(1, ge=1, description="Page number (starts at 1)"),
page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
include_archived: bool = Query(False, description="Include archived projects"),
db: Session = Depends(get_db),
):
"""Get all projects with their related data (paginated)
By default, archived projects are excluded. Set include_archived=True to include them.
"""
# Calculate offset
offset = (page - 1) * page_size
# Start with base query
query = db.query(ProjectTable)
# Filter out archived projects by default
if not include_archived:
query = query.filter(ProjectTable.is_archived == 0)
# Get total count
total_count = query.count()
# Get paginated results
@router.get("/projects", response_model=List[ProjectData])
def read_projects(db: Session = Depends(get_db)):
"""Get all projects with their related data"""
projects = (
query.options(
db.query(ProjectTable)
.options(
selectinload(ProjectTable.sector),
selectinload(ProjectTable.investors),
selectinload(ProjectTable.companies),
)
.offset(offset)
.limit(page_size)
.all()
)
@@ -67,16 +43,7 @@ def read_projects(
)
project_data_list.append(project_data)
# Calculate total pages
total_pages = (total_count + page_size - 1) // page_size
return PaginatedResponse(
items=project_data_list,
total=total_count,
page=page,
page_size=page_size,
total_pages=total_pages,
)
return project_data_list
@router.get("/projects/{project_id}", response_model=ProjectData)
@@ -172,7 +139,7 @@ def update_project(
@router.delete("/projects/{project_id}")
def delete_project(project_id: int, db: Session = Depends(get_db)):
"""Delete a project permanently"""
"""Delete a project"""
db_project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
if not db_project:
@@ -184,88 +151,7 @@ def delete_project(project_id: int, db: Session = Depends(get_db)):
return {"message": "Project deleted successfully"}
@router.post("/projects/{project_id}/archive")
def archive_project(project_id: int, db: Session = Depends(get_db)):
"""Archive a project (soft delete)"""
db_project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
if not db_project:
raise HTTPException(status_code=404, detail="Project not found")
db_project.is_archived = 1
db.commit()
db.refresh(db_project)
return {"message": "Project archived successfully", "project_id": project_id}
@router.post("/projects/{project_id}/unarchive")
def unarchive_project(project_id: int, db: Session = Depends(get_db)):
"""Unarchive a project (restore from archive)"""
db_project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
if not db_project:
raise HTTPException(status_code=404, detail="Project not found")
db_project.is_archived = 0
db.commit()
db.refresh(db_project)
return {"message": "Project unarchived successfully", "project_id": project_id}
@router.get("/projects/archived", response_model=PaginatedResponse[ProjectData])
def read_archived_projects(
page: int = Query(1, ge=1, description="Page number (starts at 1)"),
page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
db: Session = Depends(get_db),
):
"""Get all archived projects (paginated)"""
# Calculate offset
offset = (page - 1) * page_size
# Query only archived projects
query = db.query(ProjectTable).filter(ProjectTable.is_archived == 1)
# Get total count
total_count = query.count()
# Get paginated results
projects = (
query.options(
selectinload(ProjectTable.sector),
selectinload(ProjectTable.investors),
selectinload(ProjectTable.companies),
)
.offset(offset)
.limit(page_size)
.all()
)
# Transform ProjectTable objects to ProjectData format
project_data_list = []
for project in projects:
project_data = ProjectData(
project=project,
sector=project.sector,
investors=project.investors,
companies=project.companies,
)
project_data_list.append(project_data)
# Calculate total pages
total_pages = (total_count + page_size - 1) // page_size
return PaginatedResponse(
items=project_data_list,
total=total_count,
page=page,
page_size=page_size,
total_pages=total_pages,
)
@router.get("/projects/filter", response_model=PaginatedResponse[ProjectData])
@router.get("/projects/filter", response_model=List[ProjectData])
def filter_projects(
stage: Optional[InvestmentStage] = Query(
None, description="Filter by project stage"
@@ -273,7 +159,6 @@ def filter_projects(
min_valuation: Optional[int] = Query(None, description="Minimum valuation"),
max_valuation: Optional[int] = Query(None, description="Maximum valuation"),
location: Optional[str] = Query(None, description="Location (partial match)"),
industry: Optional[str] = Query(None, description="Industry (partial match)"),
sector: Optional[str] = Query(None, description="Sector name (partial match)"),
investor_name: Optional[str] = Query(
None, description="Investor name (partial match)"
@@ -281,11 +166,9 @@ def filter_projects(
company_name: Optional[str] = Query(
None, description="Company name (partial match)"
),
page: int = Query(1, ge=1, description="Page number (starts at 1)"),
page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
db: Session = Depends(get_db),
):
"""Filter projects based on various criteria (paginated)"""
"""Filter projects based on various criteria"""
# Start with base query
query = db.query(ProjectTable).options(
@@ -307,9 +190,6 @@ def filter_projects(
if location:
query = query.filter(ProjectTable.location.ilike(f"%{location}%"))
if industry:
query = query.filter(ProjectTable.industry.ilike(f"%{industry}%"))
if sector:
query = query.join(ProjectTable.sector).filter(
SectorTable.name.ilike(f"%{sector}%")
@@ -325,12 +205,7 @@ def filter_projects(
CompanyTable.name.ilike(f"%{company_name}%")
)
# Get total count before pagination
total_count = query.count()
# Calculate offset and apply pagination
offset = (page - 1) * page_size
projects = query.offset(offset).limit(page_size).all()
projects = query.all()
# Transform to ProjectData format
project_data_list = []
@@ -343,16 +218,7 @@ def filter_projects(
)
project_data_list.append(project_data)
# Calculate total pages
total_pages = (total_count + page_size - 1) // page_size
return PaginatedResponse(
items=project_data_list,
total=total_count,
page=page,
page_size=page_size,
total_pages=total_pages,
)
return project_data_list
# Association management routes
-118
View File
@@ -1,118 +0,0 @@
from typing import Optional
from db.db import get_db
from db.models import FundTable, InvestorTable, ProjectTable
from fastapi import APIRouter, Depends, HTTPException, Query
from fastapi.responses import Response
from services.report_gen import ReportGenerator
from sqlalchemy.orm import Session, selectinload
router = APIRouter(tags=["Report Generation"])
@router.get("/report/investor/{investor_id}")
async def generate_investor_report(
investor_id: int,
project_id: Optional[int] = Query(
None, description="Optional project ID for compatibility analysis"
),
db: Session = Depends(get_db),
):
"""
Generate a PDF report for an investor profile.
Args:
investor_id: The ID of the investor to generate a report for
project_id: Optional project ID to include mandate match analysis
Returns:
PDF file as a downloadable response
"""
# Fetch investor data with all relationships
investor = (
db.query(InvestorTable)
.options(
selectinload(InvestorTable.portfolio_companies),
selectinload(InvestorTable.team_members),
selectinload(InvestorTable.sectors),
selectinload(InvestorTable.funds).selectinload(FundTable.investment_stages),
selectinload(InvestorTable.funds).selectinload(FundTable.sectors),
)
.filter(InvestorTable.id == investor_id)
.first()
)
if not investor:
raise HTTPException(status_code=404, detail="Investor not found")
# Prepare investor data dictionary
investor_data = {
"name": investor.name,
"description": investor.description,
"website": investor.website,
"headquarters": investor.headquarters,
"aum": investor.aum,
"portfolio_highlights": investor.portfolio_highlights or [],
"investment_thesis": investor.investment_thesis or [],
"sectors": [sector.name for sector in investor.sectors],
"team_members": [
{
"name": member.name,
"role": member.role,
"title": member.title,
"email": member.email,
}
for member in investor.team_members
],
"funds": [],
}
# Get all funds with their data
if investor.funds:
for fund in investor.funds:
fund_data = {
"fund_name": fund.fund_name,
"fund_size": fund.fund_size,
"check_size_lower": fund.check_size_lower,
"check_size_upper": fund.check_size_upper,
"geographic_focus": fund.geographic_focus,
"investment_stages": [stage.name for stage in fund.investment_stages],
"sectors": [sector.name for sector in fund.sectors],
}
investor_data["funds"].append(fund_data)
# Fetch project data if project_id is provided
project_data = None
if project_id:
project = (
db.query(ProjectTable)
.options(selectinload(ProjectTable.sector))
.filter(ProjectTable.id == project_id)
.first()
)
if not project:
raise HTTPException(status_code=404, detail="Project not found")
project_data = {
"name": project.name,
"description": project.description,
"location": project.location,
"valuation": project.valuation,
"stage": project.stage.name if project.stage else None,
"sectors": [sector.name for sector in project.sector],
}
# Generate PDF report
report_generator = ReportGenerator()
pdf_bytes = await report_generator.generate_investor_report(
investor_data, project_data, investor_model=investor, project_model=project
)
# Return PDF as downloadable file
filename = f"{investor.name.replace(' ', '_')}_Report.pdf"
return Response(
content=pdf_bytes,
media_type="application/pdf",
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
)
Binary file not shown.
-18
View File
@@ -1,18 +0,0 @@
from typing import Optional
from pydantic import BaseModel
class InsightResponse(BaseModel):
investment_pattern_analysis: str
market_position: str
compatibility_score: Optional[str] = None
class Config:
json_schema_extra = {
"example": {
"investment_pattern_analysis": "Sequoia has been increasingly active in AI/ML startups (43% increase in last 18 months). Their average investment size has grown 23% year-over-year, indicating confidence in larger rounds. Peak activity in Q2-Q3, suggesting seasonal investment patterns.",
"market_position": "Top 3 most active VC in enterprise software deals. Strong presence in unicorn companies (47 portfolio unicorns). Consistently leads or co-leads rounds, indicating decision-making influence.",
"compatibility_score": "0.85",
}
}
+1 -4
View File
@@ -30,7 +30,7 @@ class InvestorSchema(BaseModel):
check_size_lower: int | None
check_size_upper: int | None
geographic_focus: str | None
stage_focus: Optional[InvestmentStage] = None
stage_focus: InvestmentStage
number_of_investments: int | None
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
@@ -60,7 +60,6 @@ class ProjectSchema(BaseModel):
valuation: int | None
stage: InvestmentStage | None
location: str | None
industry: str | None
description: Optional[str]
start_date: Optional[datetime]
end_date: Optional[datetime]
@@ -76,7 +75,6 @@ class ProjectCreate(BaseModel):
valuation: Optional[int] = None
stage: Optional[InvestmentStage] = None
location: Optional[str] = None
industry: Optional[str] = None
description: Optional[str] = None
start_date: Optional[datetime] = None
end_date: Optional[datetime] = None
@@ -87,7 +85,6 @@ class ProjectUpdate(BaseModel):
valuation: Optional[int] = None
stage: Optional[InvestmentStage] = None
location: Optional[str] = None
industry: Optional[str] = None
description: Optional[str] = None
start_date: Optional[datetime] = None
end_date: Optional[datetime] = None
+4
View File
@@ -258,6 +258,10 @@ class InvestorSchema(BaseModel):
default=None,
description="Geographic investment focus. Do not return any special characters, Just locations separated by commas. Leave empty if not clearly identifiable.",
)
stage_focus: InvestmentStage = Field(
default=InvestmentStage.SEED,
description="Investment stage focus. Use SEED as default if uncertain.",
)
number_of_investments: Optional[int] = Field(
default=None,
ge=0,
+7 -174
View File
@@ -1,12 +1,9 @@
from datetime import datetime
from enum import Enum
from typing import Any, Generic, List, Optional, TypeVar
from typing import List, Optional
from pydantic import BaseModel
# Generic type for pagination
T = TypeVar("T")
class InvestmentStage(str, Enum):
SEED = "SEED"
@@ -25,39 +22,11 @@ class SectorSchema(BaseModel):
from_attributes = True
class InvestmentStageSchema(BaseModel):
id: int
name: str
class Config:
from_attributes = True
class InvestorMemberSchema(BaseModel):
id: int
name: str
role: str | None
email: str | None
linkedin: str | None
class Config:
from_attributes = True
class FundSchema(BaseModel):
id: int
fund_name: str | None
fund_size: int | None # Changed to int for numerical filtering
fund_size_source_url: str | None
check_size_lower: int | None # NEW: Lower bound of check size range
check_size_upper: int | None # NEW: Upper bound of check size range
source_url: str | None
source_provider: str | None
geographic_focus: str | None # Changed from List[str] to string
investment_stages: List[InvestmentStageSchema] | None # Changed to relationship
sectors: List[SectorSchema] | None # Changed to relationship
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
class Config:
from_attributes = True
@@ -93,20 +62,11 @@ class InvestorSchema(BaseModel):
id: int
name: str
description: Optional[str]
website: Optional[str] = None
headquarters: Optional[str] = None
aum: int | None
aum_as_of_date: str | None = None
aum_source_url: str | None = None
check_size_lower: int | None
check_size_upper: int | None
geographic_focus: str | None
investment_thesis: Any = (
None # Flexible JSON field - can be list, dict, or list of dicts
)
portfolio_highlights: Any = (
None # Flexible JSON field - can be list, dict, or list of dicts
)
stage_focus: InvestmentStage
number_of_investments: int | None
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
@@ -116,87 +76,22 @@ class InvestorSchema(BaseModel):
class InvestorData(BaseModel):
"""Comprehensive investor data schema - used for individual investor requests"""
"""Comprehensive investor data schema for LLM processing"""
investor: InvestorSchema
portfolio_companies: List[CompanySchema]
team_members: List[InvestorMemberSchema]
sectors: List[SectorSchema]
funds: List[FundSchema]
class Config:
from_attributes = True
class InvestorFundData(BaseModel):
"""Investor-Fund combined data - used for list/filter requests
Each row represents one investor-fund combination.
An investor with 3 funds will appear as 3 separate entries.
"""
# Investor fields
investor_id: int
investor_name: str
investor_description: Optional[str]
investor_website: Optional[str]
investor_headquarters: Optional[str]
aum: int | None
aum_as_of_date: str | None
aum_source_url: str | None
investment_thesis: Any = None # Flexible JSON field
portfolio_highlights: Any = None # Flexible JSON field
number_of_investments: int | None
# Fund fields
fund_id: int | None
fund_name: str | None
fund_size: int | None # Changed to int for numerical filtering
fund_size_source_url: str | None
check_size_lower: int | None # NEW: Lower bound of check size range
check_size_upper: int | None # NEW: Upper bound of check size range
geographic_focus: str | None # Changed from List[str] to string
fund_investment_stages: (
List[InvestmentStageSchema] | None
) # Changed to relationship
fund_sectors: List[SectorSchema] | None # Changed to relationship
# Related data
portfolio_companies: List[CompanySchema]
team_members: List[InvestorMemberSchema]
sectors: List[SectorSchema]
class Config:
from_attributes = True
class InvestorMinimal(BaseModel):
"""Minimal investor info with just id and name"""
id: int
name: str
class Config:
from_attributes = True
class CompanySchemaMinimal(BaseModel):
id: int
name: str
industry: str | None
location: str | None
founded_year: Optional[int]
website: Optional[str]
class Config:
from_attributes = True
class CompanyData(BaseModel): # Renamed from CompaniesData for consistency
company: CompanySchemaMinimal
investors: List[InvestorMinimal]
members: List[CompanyMemberSchema] = []
sectors: List[SectorSchema] = []
company: CompanySchema
sectors: List[SectorSchema]
members: List[CompanyMemberSchema]
investors: List[InvestorSchema]
class Config:
from_attributes = True
@@ -204,65 +99,3 @@ class CompanyData(BaseModel): # Renamed from CompaniesData for consistency
class InvestorList(BaseModel):
investors: List[InvestorData]
class InvestorFundList(BaseModel):
"""List of investor-fund combinations"""
investor_funds: List[InvestorFundData]
class CompanyMinimal(BaseModel):
"""Minimal company info with just id and name"""
id: int
name: str
class Config:
from_attributes = True
class SectorMinimal(BaseModel):
"""Minimal sector info with just id and name"""
id: int
name: str
class Config:
from_attributes = True
class InvestmentResponse(BaseModel):
"""Simplified investment response schema
One row per investor-fund combination with streamlined data
"""
id: int # Investor ID
name: (
str # Combination of investor name and fund name (e.g., "Investor A - Fund A")
)
aum: int | None # From investor
check_size_lower: int | None # From fund
check_size_upper: int | None # From fund
geographic_focus: str | None # From fund
stage_focus: str | None # Comma-separated stages from fund
portfolio_companies: List[CompanyMinimal] # Top 3 companies from investor
sectors: List[SectorMinimal] # Top 3 sectors from fund
compatibility_score: float # 0 to 1 (default 1 for now)
class Config:
from_attributes = True
class PaginatedResponse(BaseModel, Generic[T]):
"""Generic paginated response schema"""
items: List[T]
total: int
page: int
page_size: int
total_pages: int
class Config:
from_attributes = True
Binary file not shown.
Binary file not shown.
-228
View File
@@ -1,228 +0,0 @@
import asyncio
import hashlib
import logging
import os
from typing import List
from db.db import get_db
from db.models import CompanyTable
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from schemas.router_schemas import CompanyData, PaginatedResponse
from sqlalchemy import text
from sqlalchemy.orm import selectinload
logger = logging.getLogger(__name__)
class CompanyQueryProcessor:
def __init__(self):
self.llm = ChatOpenAI(
api_key=os.getenv("OPENROUTER_API_KEY"),
base_url="https://openrouter.ai/api/v1",
model="openai/gpt-4o-mini",
temperature=0,
)
# Query cache for performance
self.query_cache = {}
# SQL generation prompt
self.sql_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"""You are a SQL expert. Generate a SQLite query to find company IDs based on user requirements.
Database Schema:
- companies: id, name, industry, location, description, founded_year, website
- company_sector: company_id, sector_id
- sectors: id, name
- investor_companies: investor_id, company_id
- investors: id, name, aum
- team_members: id, company_id, name, title
IMPORTANT RULES:
1. ALWAYS return ONLY company IDs (companies.id) - use SELECT DISTINCT c.id
2. For industry: Check BOTH industry field AND sectors table with synonyms
- Use LEFT JOIN for sectors so companies without sector tags still match
- Include related terms: 'Fintech' → c.industry LIKE '%Fintech%' OR c.industry LIKE '%Finance%' OR sec.name LIKE '%Fintech%' OR sec.name LIKE '%Financial%'
- 'AI' → c.industry LIKE '%AI%' OR c.industry LIKE '%Artificial Intelligence%' OR c.industry LIKE '%Machine Learning%' OR sec.name LIKE '%AI%' OR sec.name LIKE '%ML%'
3. For location: Be FLEXIBLE with variations and abbreviations
- 'San Francisco' → c.location LIKE '%San Francisco%' OR c.location LIKE '%SF%' OR c.location LIKE '%Bay Area%'
- 'New York' → c.location LIKE '%New York%' OR c.location LIKE '%NYC%' OR c.location LIKE '%NY%'
- 'Europe' → c.location LIKE '%Europe%' OR c.location LIKE '%UK%' OR c.location LIKE '%London%' OR c.location LIKE '%Berlin%' OR c.location LIKE '%Paris%'
4. For sectors: Use LEFT JOIN and include multiple synonyms
- 'Healthcare' → sec.name LIKE '%Healthcare%' OR sec.name LIKE '%Health%' OR sec.name LIKE '%Medical%' OR sec.name LIKE '%Biotech%' OR c.industry LIKE '%Health%'
5. For founding year filters (include NULL to be inclusive):
- "founded after 2020" → WHERE (founded_year >= 2020 OR founded_year IS NULL)
- "founded before 2018" → WHERE (founded_year <= 2018 OR founded_year IS NULL)
- "founded in 2020" → WHERE founded_year = 2020
6. For investor-related queries: Use JOIN investor_companies
7. Use LEFT JOIN for sectors so companies without tags still match
8. Use DISTINCT to avoid duplicates from joins
9. Be INCLUSIVE - use OR conditions with synonyms and variations
10. Return a single, complete SELECT query
Example Queries:
Q: "Fintech companies founded in 2020"
A: SELECT DISTINCT c.id FROM companies c
LEFT JOIN company_sector cs ON c.id = cs.company_id
LEFT JOIN sectors sec ON cs.sector_id = sec.id
WHERE (c.industry LIKE '%Fintech%' OR c.industry LIKE '%Finance%' OR c.industry LIKE '%Financial%' OR sec.name LIKE '%Fintech%' OR sec.name LIKE '%Financial Services%')
AND c.founded_year = 2020
Q: "AI companies in San Francisco"
A: SELECT DISTINCT c.id FROM companies c
LEFT JOIN company_sector cs ON c.id = cs.company_id
LEFT JOIN sectors sec ON cs.sector_id = sec.id
WHERE (c.industry LIKE '%AI%' OR c.industry LIKE '%Artificial Intelligence%' OR c.industry LIKE '%Machine Learning%' OR sec.name LIKE '%AI%' OR sec.name LIKE '%Machine Learning%' OR sec.name LIKE '%ML%')
AND (c.location LIKE '%San Francisco%' OR c.location LIKE '%SF%' OR c.location LIKE '%Bay Area%')
Q: "Healthcare companies"
A: SELECT DISTINCT c.id FROM companies c
LEFT JOIN company_sector cs ON c.id = cs.company_id
LEFT JOIN sectors sec ON cs.sector_id = sec.id
WHERE c.industry LIKE '%Healthcare%' OR c.industry LIKE '%Health%' OR c.industry LIKE '%Medical%' OR sec.name LIKE '%Healthcare%' OR sec.name LIKE '%Medical%' OR sec.name LIKE '%Biotech%' OR sec.name LIKE '%Pharma%'
Q: "Companies funded by Sequoia"
A: SELECT DISTINCT c.id FROM companies c
JOIN investor_companies ic ON c.id = ic.company_id
JOIN investors i ON ic.investor_id = i.id
WHERE i.name LIKE '%Sequoia%'
Q: "European startups founded after 2019"
A: SELECT DISTINCT c.id FROM companies c
WHERE (c.location LIKE '%Europe%' OR c.location LIKE '%UK%' OR c.location LIKE '%London%' OR c.location LIKE '%Germany%' OR c.location LIKE '%Berlin%' OR c.location LIKE '%France%' OR c.location LIKE '%Paris%')
AND (c.founded_year > 2019 OR c.founded_year IS NULL)
Q: "SaaS companies"
A: SELECT DISTINCT c.id FROM companies c
LEFT JOIN company_sector cs ON c.id = cs.company_id
LEFT JOIN sectors sec ON cs.sector_id = sec.id
WHERE c.industry LIKE '%SaaS%' OR c.industry LIKE '%Software%' OR c.industry LIKE '%Cloud%' OR sec.name LIKE '%SaaS%' OR sec.name LIKE '%Software%'
IMPORTANT:
- Use LEFT JOIN so companies without sector tags still match via industry field
- Use OR conditions with related keywords/synonyms to cast a wider net
- Include NULL checks for optional filters to avoid excluding companies with missing data
Return ONLY the SQL query, no explanations or markdown.""",
),
("user", "{question}"),
]
)
def _get_cache_key(self, question: str) -> str:
"""Generate cache key from normalized question."""
return hashlib.md5(question.lower().strip().encode()).hexdigest()
# synchronous helper is provided below as `_process_query_sync` and an
# async wrapper `process_query` runs it in a thread. This keeps the
# FastAPI event loop non-blocking while reusing the existing sync code.
async def process_query(self, question: str) -> PaginatedResponse[CompanyData]:
"""Async wrapper for process_query. Runs blocking work in a thread to avoid
blocking the event loop.
"""
return await asyncio.to_thread(self._process_query_sync, question)
def _process_query_sync(self, question: str) -> PaginatedResponse[CompanyData]:
"""Synchronous implementation of process_query. This is run in a thread by
the async wrapper above.
"""
cache_key = self._get_cache_key(question)
# Check cache first
if cache_key in self.query_cache:
sql_query = self.query_cache[cache_key]
logger.info(f"Using cached SQL: {sql_query}")
else:
# Generate SQL query
messages = self.sql_prompt.format_messages(question=question)
response = self.llm.invoke(messages)
sql_query = response.content.strip()
# Clean up SQL (remove markdown code blocks if present)
sql_query = sql_query.replace("```sql", "").replace("```", "").strip()
# Cache the query
self.query_cache[cache_key] = sql_query
logger.info(f"Generated SQL: {sql_query}")
# Execute query to get company IDs
db_session = next(get_db())
try:
result = db_session.execute(text(sql_query))
company_ids = [row[0] for row in result.fetchall()]
logger.info(
f"Found {len(company_ids)} company IDs: {company_ids[:10]}{'...' if len(company_ids) > 10 else ''}"
)
return self._fetch_companies_by_ids(company_ids)
except Exception as e:
logger.error(f"SQL execution error: {e}")
logger.error(f"Failed SQL: {sql_query}")
# Return empty result
return PaginatedResponse(
items=[], total=0, page=1, page_size=10, total_pages=0
)
finally:
db_session.close()
def _fetch_companies_by_ids(
self, company_ids: List[int]
) -> PaginatedResponse[CompanyData]:
"""Fetch companies with all their relationships from the database using company IDs.
Args:
company_ids: List of company IDs to fetch
"""
if not company_ids:
return PaginatedResponse(
items=[],
total=0,
page=1,
page_size=10,
total_pages=0,
)
# Get database session
db_session = next(get_db())
try:
# Query companies with all necessary relationships loaded
companies = (
db_session.query(CompanyTable)
.options(
selectinload(CompanyTable.investors),
selectinload(CompanyTable.members),
selectinload(CompanyTable.sectors),
)
.filter(CompanyTable.id.in_(company_ids))
.all()
)
# Transform to CompanyData format
company_data_list = []
for company in companies:
company_data = CompanyData(
company=company,
investors=company.investors,
members=company.members,
sectors=company.sectors,
)
company_data_list.append(company_data)
total_count = len(company_data_list)
total_pages = 1 if total_count > 0 else 0
return PaginatedResponse(
items=company_data_list,
total=total_count,
page=1,
page_size=total_count,
total_pages=total_pages,
)
finally:
db_session.close()
-785
View File
@@ -1,785 +0,0 @@
"""
Compatibility Score Service
This module calculates compatibility scores between projects and investors.
The scoring system evaluates multiple dimensions to determine how well a project
matches with an investor's investment criteria.
"""
from difflib import SequenceMatcher
from typing import List, Optional, Tuple
from db.models import FundTable, InvestorTable, ProjectTable
def calculate_project_investor_compatibility(
project: ProjectTable, investor: InvestorTable, use_funds: bool = True
) -> float:
"""
Calculate compatibility score between a project and an investor.
Args:
project: The project to evaluate
investor: The investor to compare against
use_funds: If True, evaluates against investor's funds. If False, uses investor-level data.
Returns:
A score between 0 and 1, where 1 is perfect match
Scoring breakdown (out of 100 points):
- Investment Stage Match: 30 points
- Sector Overlap: 30 points
- Geographic Match: 20 points
- Valuation/Check Size Fit: 20 points
"""
if use_funds and investor.funds:
# Calculate score for each fund and return the highest
max_score = 0.0
for fund in investor.funds:
fund_score = _calculate_project_fund_compatibility(project, fund)
max_score = max(max_score, fund_score)
return max_score
else:
# Use investor-level data (fallback)
return _calculate_project_investor_direct_compatibility(project, investor)
def calculate_project_investors_compatibility(
project: ProjectTable, investors: List[InvestorTable], use_funds: bool = True
) -> List[Tuple[InvestorTable, float]]:
"""
Calculate compatibility scores between a project and multiple investors.
Args:
project: The project to evaluate
investors: List of investors to compare against
use_funds: If True, evaluates against investors' funds. If False, uses investor-level data.
Returns:
List of tuples (investor, score) sorted by score descending
"""
scored_investors = []
for investor in investors:
score = calculate_project_investor_compatibility(project, investor, use_funds)
scored_investors.append((investor, score))
# Sort by score descending
scored_investors.sort(key=lambda x: x[1], reverse=True)
return scored_investors
def _calculate_project_fund_compatibility(
project: ProjectTable, fund: FundTable
) -> float:
"""
Calculate compatibility score between a project and a specific fund.
Scoring breakdown:
- Investment Stage Match: 30 points (all or nothing if stage exists)
- Sector Overlap: 30 points (proportional to overlap)
- Geographic Match: 20 points (exact=20, partial=10, none=0)
- Valuation/Check Size Fit: 20 points (proportional to fit)
Returns:
A score between 0 and 1
"""
total_score = 0
max_score = 100
# 1. Investment Stage Match (30 points)
stage_score = 0
if project.stage and fund.investment_stages:
# Check if project stage matches any of the fund's investment stages
fund_stage_names = {stage.name for stage in fund.investment_stages}
# Convert project.stage enum to string for comparison
project_stage_name = (
project.stage.value
if hasattr(project.stage, "value")
else str(project.stage)
)
# Normalize both for case-insensitive comparison
project_stage_normalized = project_stage_name.upper().strip()
fund_stages_normalized = {name.upper().strip() for name in fund_stage_names}
if project_stage_normalized in fund_stages_normalized:
stage_score = 30
else:
# Partial credit for adjacent stages
stage_score = _calculate_stage_proximity(
project_stage_normalized, fund_stages_normalized
)
total_score += stage_score
# 2. Sector Overlap (30 points)
sector_score = 0
if project.sector and fund.sectors:
project_sectors = [s for s in project.sector if hasattr(s, "name")]
fund_sectors = [s for s in fund.sectors if hasattr(s, "name")]
if project_sectors and fund_sectors:
# Use fuzzy matching to account for similar but not identical sector names
match_count = 0
total_matches = 0
for proj_sector in project_sectors:
best_match_score = 0
proj_name = proj_sector.name.lower().strip()
for fund_sector in fund_sectors:
fund_name = fund_sector.name.lower().strip()
# Exact match
if proj_name == fund_name:
best_match_score = 1.0
break
# Fuzzy match using sequence matcher
similarity = SequenceMatcher(None, proj_name, fund_name).ratio()
# Also check if one contains the other (substring match)
if proj_name in fund_name or fund_name in proj_name:
similarity = max(similarity, 0.8)
best_match_score = max(best_match_score, similarity)
# Count matches with threshold
# Perfect match (1.0), strong match (>0.75), partial match (>0.6)
if best_match_score >= 0.6:
total_matches += best_match_score
match_count += 1
if match_count > 0:
# Calculate overlap ratio based on fuzzy matches
overlap_ratio = total_matches / len(project_sectors)
sector_score = int(30 * overlap_ratio)
total_score += sector_score
# 3. Geographic Match (20 points)
geo_score = 0
if project.location and fund.geographic_focus:
project_location_lower = project.location.lower().strip()
fund_geo_lower = (fund.geographic_focus or "").lower().strip()
# Exact match
if project_location_lower == fund_geo_lower:
geo_score = 20
# Partial match (one contains the other)
elif (
project_location_lower in fund_geo_lower
or fund_geo_lower in project_location_lower
):
geo_score = 15
# Check for common geographic terms or regional overlap (continent/country matching)
elif _check_geographic_overlap(project_location_lower, fund_geo_lower):
# Give higher score for continent/country matches (e.g., Germany -> Europe)
geo_score = 18
total_score += geo_score
# 4. Valuation/Check Size Fit (20 points)
valuation_score = 0
if project.valuation and fund.check_size_lower and fund.check_size_upper:
# Check if project valuation falls within or near the check size range
# Typically, check size is a fraction of valuation (e.g., 10-20%)
# We'll assume check size represents potential investment amount
if fund.check_size_lower <= project.valuation <= fund.check_size_upper:
# Valuation is within the check size range (might be too small)
valuation_score = 10
else:
# Check if the check size is reasonable for this valuation
# Typical investment is 10-30% of valuation
reasonable_valuation_min = fund.check_size_lower * 3 # Investing ~33%
reasonable_valuation_max = fund.check_size_upper * 10 # Investing ~10%
if (
reasonable_valuation_min
<= project.valuation
<= reasonable_valuation_max
):
# Perfect fit
valuation_score = 20
elif project.valuation < reasonable_valuation_min:
# Project might be too small
ratio = (
project.valuation / reasonable_valuation_min
if reasonable_valuation_min > 0
else 0
)
valuation_score = int(10 * ratio)
else:
# Project might be too large
ratio = (
reasonable_valuation_max / project.valuation
if project.valuation > 0
else 0
)
valuation_score = int(10 * ratio)
total_score += valuation_score
# Convert to 0-1 scale
return total_score / max_score
def _calculate_project_investor_direct_compatibility(
project: ProjectTable, investor: InvestorTable
) -> float:
"""
Calculate compatibility using investor-level data (fallback when no funds available).
Uses the same scoring system but with investor-level attributes.
"""
total_score = 0
max_score = 100
# 1. Investment Stage - Skip this since investors don't have a direct stage field
# We could add 30 points to other categories, but for consistency, we'll leave it as 0
stage_score = 0
total_score += stage_score
# 2. Sector Overlap (30 points)
sector_score = 0
if project.sector and investor.sectors:
project_sectors = [s for s in project.sector if hasattr(s, "name")]
investor_sectors = [s for s in investor.sectors if hasattr(s, "name")]
if project_sectors and investor_sectors:
# Use fuzzy matching to account for similar but not identical sector names
match_count = 0
total_matches = 0
for proj_sector in project_sectors:
best_match_score = 0
proj_name = proj_sector.name.lower().strip()
for inv_sector in investor_sectors:
inv_name = inv_sector.name.lower().strip()
# Exact match
if proj_name == inv_name:
best_match_score = 1.0
break
# Fuzzy match using sequence matcher
similarity = SequenceMatcher(None, proj_name, inv_name).ratio()
# Also check if one contains the other (substring match)
if proj_name in inv_name or inv_name in proj_name:
similarity = max(similarity, 0.8)
best_match_score = max(best_match_score, similarity)
# Count matches with threshold
if best_match_score >= 0.6:
total_matches += best_match_score
match_count += 1
if match_count > 0:
# Calculate overlap ratio based on fuzzy matches
overlap_ratio = total_matches / len(project_sectors)
sector_score = int(30 * overlap_ratio)
total_score += sector_score
# 3. Geographic Match (20 points)
geo_score = 0
if project.location and investor.geographic_focus:
project_location_lower = project.location.lower()
investor_geo_lower = (investor.geographic_focus or "").lower()
if project_location_lower == investor_geo_lower:
geo_score = 20
elif (
project_location_lower in investor_geo_lower
or investor_geo_lower in project_location_lower
):
geo_score = 15
elif _check_geographic_overlap(project_location_lower, investor_geo_lower):
# Give higher score for continent/country matches (e.g., Germany -> Europe)
geo_score = 18
total_score += geo_score
# 4. Valuation/Check Size Fit (20 points)
valuation_score = 0
if project.valuation and investor.check_size_lower and investor.check_size_upper:
reasonable_valuation_min = investor.check_size_lower * 3
reasonable_valuation_max = investor.check_size_upper * 10
if reasonable_valuation_min <= project.valuation <= reasonable_valuation_max:
valuation_score = 20
elif project.valuation < reasonable_valuation_min:
ratio = (
project.valuation / reasonable_valuation_min
if reasonable_valuation_min > 0
else 0
)
valuation_score = int(10 * ratio)
else:
ratio = (
reasonable_valuation_max / project.valuation
if project.valuation > 0
else 0
)
valuation_score = int(10 * ratio)
total_score += valuation_score
# Convert to 0-1 scale
return total_score / max_score
def _calculate_stage_proximity(project_stage: str, fund_stages: set) -> int:
"""
Calculate proximity score between project stage and fund stages.
Awards partial credit for adjacent investment stages.
Stage progression: SEED -> SERIES_A -> SERIES_B -> SERIES_C -> GROWTH -> LATE_STAGE
Returns:
Score from 0-15 (half credit for adjacent stages)
"""
stage_order = ["SEED", "SERIES_A", "SERIES_B", "SERIES_C", "GROWTH", "LATE_STAGE"]
# Normalize project stage for comparison
project_stage_normalized = project_stage.upper().strip()
try:
project_idx = stage_order.index(project_stage_normalized)
except ValueError:
return 0
# Check for adjacent stages
adjacent_stages = []
if project_idx > 0:
adjacent_stages.append(stage_order[project_idx - 1])
if project_idx < len(stage_order) - 1:
adjacent_stages.append(stage_order[project_idx + 1])
# Normalize fund stages and check for matches
for stage in fund_stages:
stage_normalized = stage.upper().strip()
if stage_normalized in adjacent_stages:
return 15 # Half credit for adjacent stage
return 0
def _check_geographic_overlap(location1: str, location2: str) -> bool:
"""
Check for common geographic terms between two locations.
Examples:
- "San Francisco, CA" and "California" -> True
- "New York" and "USA" -> True (if both contain USA/US)
- "London, UK" and "United Kingdom" -> True
- "Germany" and "Europe" -> True
"""
# Normalize inputs
loc1 = location1.lower().strip()
loc2 = location2.lower().strip()
# Common geographic groupings with broader regional mappings
geo_groups = [
# North America
["usa", "us", "united states", "america", "u.s.", "u.s.a"],
["canada", "canadian"],
["mexico", "mexican"],
# Europe and countries
[
"europe",
"european",
"eu",
"germany",
"france",
"uk",
"united kingdom",
"britain",
"spain",
"italy",
"netherlands",
"belgium",
"sweden",
"denmark",
"norway",
"finland",
"poland",
"portugal",
"austria",
"switzerland",
"ireland",
"greece",
"czech",
"romania",
],
# UK specific
["uk", "united kingdom", "britain", "england", "scotland", "wales", "london"],
# US states
["california", "ca", "san francisco", "los angeles", "silicon valley"],
["new york", "ny", "nyc"],
["texas", "tx"],
["massachusetts", "ma", "boston"],
["washington", "seattle"],
# Asia
[
"asia",
"asian",
"china",
"japan",
"korea",
"singapore",
"hong kong",
"india",
"indonesia",
"thailand",
"vietnam",
"malaysia",
"philippines",
],
# Middle East
["middle east", "israel", "uae", "dubai", "saudi arabia"],
# Latin America
["latin america", "brazil", "argentina", "chile", "colombia", "mexico"],
# Africa
["africa", "african", "south africa", "nigeria", "kenya", "egypt"],
# Oceania
["australia", "australian", "new zealand"],
]
# Check if both locations match any group
for group in geo_groups:
found_in_1 = any(term in loc1 for term in group)
found_in_2 = any(term in loc2 for term in group)
if found_in_1 and found_in_2:
return True
# Check for direct substring match (one contains the other)
if loc1 in loc2 or loc2 in loc1:
return True
return False
def get_top_compatible_investors(
project: ProjectTable,
investors: List[InvestorTable],
limit: int = 10,
min_score: float = 0.0,
use_funds: bool = True,
) -> List[Tuple[InvestorTable, float]]:
"""
Get the top N most compatible investors for a project.
Args:
project: The project to find investors for
investors: List of all available investors
limit: Maximum number of investors to return
min_score: Minimum compatibility score threshold (0-1)
use_funds: If True, evaluates against investors' funds
Returns:
List of tuples (investor, score) sorted by score descending,
limited to 'limit' items and filtered by min_score
"""
scored_investors = calculate_project_investors_compatibility(
project, investors, use_funds
)
# Filter by minimum score
filtered_investors = [
(investor, score) for investor, score in scored_investors if score >= min_score
]
# Return top N
return filtered_investors[:limit]
def get_compatibility_score_breakdown(
project: ProjectTable, investor: InvestorTable, fund: Optional[FundTable] = None
) -> dict:
"""
Get a detailed breakdown of the compatibility score components.
Useful for debugging or showing users why a particular score was calculated.
Returns:
Dictionary with score components and explanations
"""
if fund:
total_score = 0
# Stage score
stage_score = 0
stage_match = False
if project.stage and fund.investment_stages:
fund_stage_names = {stage.name for stage in fund.investment_stages}
project_stage_name = (
project.stage.value
if hasattr(project.stage, "value")
else str(project.stage)
)
if project_stage_name in fund_stage_names:
stage_score = 30
stage_match = True
else:
stage_score = _calculate_stage_proximity(
project_stage_name, fund_stage_names
)
# Sector score
sector_score = 0
matching_sectors = []
if project.sector and fund.sectors:
project_sector_ids = {sector.id for sector in project.sector}
fund_sector_ids = {sector.id for sector in fund.sectors}
if project_sector_ids and fund_sector_ids:
common_sectors = project_sector_ids.intersection(fund_sector_ids)
matching_sectors = [
s.name for s in fund.sectors if s.id in common_sectors
]
overlap_ratio = len(common_sectors) / len(project_sector_ids)
sector_score = int(30 * overlap_ratio)
# Geographic score
geo_score = 0
geo_match_type = "none"
if project.location and fund.geographic_focus:
project_location_lower = project.location.lower()
fund_geo_lower = fund.geographic_focus.lower()
if project_location_lower == fund_geo_lower:
geo_score = 20
geo_match_type = "exact"
elif (
project_location_lower in fund_geo_lower
or fund_geo_lower in project_location_lower
):
geo_score = 10
geo_match_type = "partial"
elif _check_geographic_overlap(project_location_lower, fund_geo_lower):
geo_score = 5
geo_match_type = "regional"
# Valuation score
valuation_score = 0
valuation_fit = "unknown"
if project.valuation and fund.check_size_lower and fund.check_size_upper:
reasonable_valuation_min = fund.check_size_lower * 3
reasonable_valuation_max = fund.check_size_upper * 10
if (
reasonable_valuation_min
<= project.valuation
<= reasonable_valuation_max
):
valuation_score = 20
valuation_fit = "perfect"
elif project.valuation < reasonable_valuation_min:
ratio = (
project.valuation / reasonable_valuation_min
if reasonable_valuation_min > 0
else 0
)
valuation_score = int(10 * ratio)
valuation_fit = "too_small"
else:
ratio = (
reasonable_valuation_max / project.valuation
if project.valuation > 0
else 0
)
valuation_score = int(10 * ratio)
valuation_fit = "too_large"
total_score = stage_score + sector_score + geo_score + valuation_score
return {
"total_score": total_score / 100,
"breakdown": {
"stage": {
"score": stage_score,
"max_score": 30,
"match": stage_match,
"project_stage": project.stage.value if project.stage else None,
"fund_stages": [s.name for s in fund.investment_stages]
if fund.investment_stages
else [],
},
"sector": {
"score": sector_score,
"max_score": 30,
"matching_sectors": matching_sectors,
"project_sectors": [s.name for s in project.sector]
if project.sector
else [],
"fund_sectors": [s.name for s in fund.sectors]
if fund.sectors
else [],
},
"geography": {
"score": geo_score,
"max_score": 20,
"match_type": geo_match_type,
"project_location": project.location,
"fund_geography": fund.geographic_focus,
},
"valuation": {
"score": valuation_score,
"max_score": 20,
"fit": valuation_fit,
"project_valuation": project.valuation,
"fund_check_size_range": f"{fund.check_size_lower}-{fund.check_size_upper}"
if fund.check_size_lower
else None,
},
},
}
else:
# Investor-level breakdown (simplified)
return {
"total_score": _calculate_project_investor_direct_compatibility(
project, investor
),
"note": "Using investor-level data (no specific fund selected)",
}
def generate_compatibility_explanation(
project: ProjectTable, investor: InvestorTable, score: float, use_funds: bool = True
) -> str:
"""
Generate a detailed, natural language explanation of the compatibility score.
Args:
project: The project being evaluated
investor: The investor being compared against
score: The calculated compatibility score (0-1)
use_funds: Whether fund-level data was used
Returns:
A formatted string with the compatibility score and detailed explanation
"""
score_percentage = int(score * 100)
# Determine match quality
if score_percentage >= 80:
match_level = "Excellent match"
elif score_percentage >= 65:
match_level = "Strong match"
elif score_percentage >= 50:
match_level = "Good match"
elif score_percentage >= 35:
match_level = "Moderate match"
else:
match_level = "Limited match"
# Collect alignment factors
alignment_factors = []
recommendations = []
# Get the best matching fund if using funds
best_fund = None
if use_funds and investor.funds:
best_score = 0
for fund in investor.funds:
fund_score = _calculate_project_fund_compatibility(project, fund)
if fund_score > best_score:
best_score = fund_score
best_fund = fund
# Analyze sector alignment
if project.sector:
project_sectors = [s.name for s in project.sector if hasattr(s, "name")]
if best_fund and best_fund.sectors:
fund_sectors = {s.name for s in best_fund.sectors if hasattr(s, "name")}
common_sectors = set(project_sectors) & fund_sectors
if common_sectors:
sectors_str = ", ".join(list(common_sectors)[:2])
alignment_factors.append(f"{sectors_str} sector focus")
elif project_sectors:
recommendations.append(
f"Consider emphasizing any {project_sectors[0]} industry connections"
)
elif investor.sectors:
investor_sectors = {s.name for s in investor.sectors if hasattr(s, "name")}
common_sectors = set(project_sectors) & investor_sectors
if common_sectors:
sectors_str = ", ".join(list(common_sectors)[:2])
alignment_factors.append(f"{sectors_str} sector focus")
# Analyze stage alignment
if project.stage:
stage_name = (
project.stage.value
if hasattr(project.stage, "value")
else str(project.stage)
)
stage_display = stage_name.replace("_", " ").title()
if best_fund and best_fund.investment_stages:
fund_stage_names = {
s.name for s in best_fund.investment_stages if hasattr(s, "name")
}
if stage_name in fund_stage_names:
alignment_factors.append(f"{stage_display} stage")
else:
recommendations.append(
"Investor typically focuses on different stages; highlight your traction and growth metrics"
)
if not best_fund:
alignment_factors.append(f"{stage_display} stage")
# Analyze geographic alignment
if project.location:
if best_fund and best_fund.geographic_focus:
if (
project.location.lower() in best_fund.geographic_focus.lower()
or best_fund.geographic_focus.lower() in project.location.lower()
):
alignment_factors.append(f"{project.location} presence")
elif investor.headquarters:
if (
project.location.lower() in investor.headquarters.lower()
or investor.headquarters.lower() in project.location.lower()
):
alignment_factors.append(f"{project.location} market presence")
# Analyze valuation/check size fit
if project.valuation:
if best_fund and best_fund.check_size_lower and best_fund.check_size_upper:
reasonable_min = best_fund.check_size_lower * 3
reasonable_max = best_fund.check_size_upper * 10
if reasonable_min <= project.valuation <= reasonable_max:
alignment_factors.append("appropriate funding stage")
elif project.valuation < reasonable_min:
recommendations.append(
"You may be early for this investor; consider approaching at a later stage"
)
else:
recommendations.append(
"Consider highlighting your growth trajectory and market opportunity"
)
# Build the explanation
explanation_parts = [f"Based on your startup profile: {score_percentage}% match"]
if alignment_factors:
alignment_text = ", ".join(alignment_factors)
explanation_parts.append(f"{match_level}: {alignment_text}.")
else:
explanation_parts.append(f"{match_level}.")
if recommendations:
rec_text = recommendations[0] # Show the most important recommendation
explanation_parts.append(rec_text + ".")
return " ".join(explanation_parts)
-205
View File
@@ -1,205 +0,0 @@
import logging
import os
import sys
import requests
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
handlers=[logging.StreamHandler()],
)
logger = logging.getLogger(__name__)
class FolkAPI:
BASE_URL = "https://api.folk.app/v1"
def __init__(self, api_key: str):
api_key = os.environ.get("FOLK_API_KEY", api_key)
self.headers = {"Authorization": f"Bearer {api_key}"}
logger.info(f"FolkAPI initialized with API key: {api_key[:4]}***")
def get_groups(self):
"""Fetch all groups from Folk."""
url = f"{self.BASE_URL}/groups"
response = requests.get(url, headers=self.headers)
response.raise_for_status()
return response.json()
def create_company(
self,
name: str,
group_id: str = None,
website: str = None,
linkedin_url: str = None,
description: str = None,
emails=None,
phones=None,
addresses=None,
urls=None,
custom_field_values=None,
groups=None,
**kwargs,
):
"""Create a company (investor) in a specific group.
This method builds a payload matching Folk's Create Company API:
https://developer.folk.app/api-reference/companies/create-a-company
It keeps backward compatibility with the previous `group_id`,
`website` and `linkedin_url` arguments.
"""
url = f"{self.BASE_URL}/companies"
# Build the top-level payload expected by Folk
data = {"name": name}
if description:
data["description"] = description
# Groups: prefer explicit `groups`, else fall back to `group_id`
if groups:
# Accept either list of ids or list of dicts
formatted = []
for g in groups:
if isinstance(g, dict) and g.get("id"):
formatted.append({"id": g["id"]})
else:
formatted.append({"id": str(g)})
data["groups"] = formatted
elif group_id:
data["groups"] = [{"id": group_id}]
# Helper to normalize single or multiple inputs into lists
def _to_list(val):
if val is None:
return None
if isinstance(val, (list, tuple)):
return [v for v in val if v is not None]
return [val]
# URLs: include website and linkedin_url if provided and merge with urls
urls_list = _to_list(urls) or []
if website:
urls_list.append(website)
if linkedin_url:
urls_list.append(linkedin_url)
if urls_list:
data["urls"] = urls_list
# Emails/phones/addresses
emails_list = _to_list(emails)
if emails_list:
data["emails"] = emails_list
phones_list = _to_list(phones)
if phones_list:
data["phones"] = phones_list
addresses_list = _to_list(addresses)
if addresses_list:
data["addresses"] = addresses_list
# Custom field values follow the API's structure
if custom_field_values:
data["customFieldValues"] = custom_field_values
# Allow passing any additional top-level fields via kwargs (careful)
for k, v in kwargs.items():
# don't overwrite keys we explicitly set
if k not in data:
data[k] = v
response = requests.post(url, headers=self.headers, json=data)
response.raise_for_status()
return response.json()
def create_person(
self,
first_name: str,
last_name: str,
email: str = None,
company_id: str = None,
group_id: str = None,
linkedin_url: str = None,
companies=None,
emails=None,
phones=None,
addresses=None,
urls=None,
custom_field_values=None,
groups=None,
**kwargs,
):
"""Create a person in the workspace.
Builds payload matching Folk's Create Person API: use camelCase
keys (firstName, lastName, groups, companies, emails, etc.).
Keeps backward compatibility with `company_id` and `group_id`.
"""
url = f"{self.BASE_URL}/people"
data = {"firstName": first_name, "lastName": last_name}
# Groups: explicit `groups` preferred, else fallback to `group_id`
if groups:
formatted = []
for g in groups:
if isinstance(g, dict) and g.get("id"):
formatted.append({"id": g["id"]})
else:
formatted.append({"id": str(g)})
data["groups"] = formatted
elif group_id:
data["groups"] = [{"id": group_id}]
# Companies: keep backward compatibility with company_id
if companies:
formatted = []
for c in companies:
if isinstance(c, dict):
formatted.append(c)
elif isinstance(c, str):
# treat as id
formatted.append({"id": c})
if formatted:
data["companies"] = formatted
elif company_id:
data["companies"] = [{"id": company_id}]
# Helper to normalize into lists
def _to_list(val):
if val is None:
return None
if isinstance(val, (list, tuple)):
return [v for v in val if v is not None]
return [val]
emails_list = _to_list(emails) or []
if email:
emails_list.insert(0, email)
if emails_list:
data["emails"] = emails_list
phones_list = _to_list(phones)
if phones_list:
data["phones"] = phones_list
addresses_list = _to_list(addresses)
if addresses_list:
data["addresses"] = addresses_list
urls_list = _to_list(urls) or []
if linkedin_url:
urls_list.append(linkedin_url)
if urls_list:
data["urls"] = urls_list
if custom_field_values:
data["customFieldValues"] = custom_field_values
# Allow passthrough of other top-level fields in kwargs
for k, v in kwargs.items():
if k not in data:
data[k] = v
response = requests.post(url, headers=self.headers, json=data)
response.raise_for_status()
return response.json()
-181
View File
@@ -1,181 +0,0 @@
import asyncio
import logging
import os
from crawl4ai import AsyncWebCrawler
from ddgs import DDGS
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent
from schemas.insight_schema import InsightResponse
logging.basicConfig(
level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s"
)
logger = logging.getLogger("web_search_agent")
load_dotenv()
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
if not OPENROUTER_API_KEY:
logger.warning("OPENROUTER_API_KEY not set. LLM calls will fail if invoked.")
class QueryProcessor:
def __init__(self):
self.llm = ChatOpenAI(
api_key=OPENROUTER_API_KEY,
base_url="https://openrouter.ai/api/v1",
model="openai/gpt-4o-mini",
temperature=0,
)
self.agent = create_react_agent(
model=self.llm,
tools=[self.web_search],
response_format=InsightResponse,
)
self.ddg_search = DDGS()
async def crawl(self, url: str):
"""Tool to search the web using a web crawler. given the url"""
logger.info(f"\nCrawl tool called with url: {url}")
async with AsyncWebCrawler() as crawler:
results = await crawler.arun(url)
return results.markdown
def web_search(self, query: str):
"""Tool to search the web using google, provide the relevant query to get the information"""
logger.info(f"\nWeb Search Tool Called with query: {query}")
if query:
result = self.ddg_search.text(query, max_results=10)
return result
return "No query provided."
async def get_investor_insights(
self,
investor_name: str,
investor_website: str = None,
investor_description: str = None,
investor_headquarters: str = None,
investment_thesis: list = None,
portfolio_highlights: list = None,
) -> dict:
"""
Get investment pattern analysis and market position for an investor.
Args:
investor_name: Name of the investor/VC firm
investor_website: Website URL of the investor
investor_description: Description of the investor
investor_headquarters: Headquarters location
investment_thesis: List of investment thesis statements
portfolio_highlights: List of notable portfolio companies
Returns:
Dictionary with investment_pattern_analysis and market_position
"""
logger.info(f"Getting insights for investor: {investor_name}")
# Build context information
context_parts = [f'Investment Firm: "{investor_name}"']
if investor_website:
context_parts.append(f"Website: {investor_website}")
if investor_headquarters:
context_parts.append(f"Location: {investor_headquarters}")
if investor_description:
context_parts.append(f"Description: {investor_description}")
if investment_thesis and isinstance(investment_thesis, list):
thesis_str = ", ".join(
str(item) for item in investment_thesis[:3]
) # Limit to first 3
context_parts.append(f"Investment Focus: {thesis_str}")
if portfolio_highlights and isinstance(portfolio_highlights, list):
portfolio_str = ", ".join(
str(item) for item in portfolio_highlights[:5]
) # Limit to first 5
context_parts.append(f"Notable Portfolio Companies: {portfolio_str}")
context = "\n".join(context_parts)
prompt = f"""
Research and analyze the following investment firm:
{context}
CRITICAL INSTRUCTIONS:
- You MUST provide concrete, data-driven insights with specific numbers and percentages
- Use the web_search tool to find recent news, press releases, and investment databases (Crunchbase, PitchBook, etc.)
- If you cannot find sufficient data after searching, make reasonable inferences based on available information
- DO NOT state that data is unavailable or ambiguous - provide the best analysis possible with what you find
- Focus on ACTIONABLE insights, not disclaimers
- Only call the tool twice at most, be strategic in your searches
- Summarize your findings concisely and clearly
Provide insights in the InsightResponse schema format:
1. investment_pattern_analysis (MAX 3 SENTENCES):
- Recent investment activity and trends in the last 12-18 months
- Investment size ranges, deal frequency, and sector preferences
- Notable patterns (e.g., "increased AI investments by 40%", "average check size $5-10M")
- If specific numbers aren't available, provide reasonable estimates based on portfolio and market position
2. market_position (MAX 3 SENTENCES):
- Standing in the venture capital market
- Activity level in specific sectors and notable unicorn investments
- Deal leadership roles (lead vs co-lead) and market influence
- Regional or global market presence and competitive positioning
Use the web_search tool strategically. Search for:
- "{investor_name}" recent investments 2024 2025
- "{investor_name}" portfolio Crunchbase
- "{investor_name}" funding rounds news
- Specific portfolio companies if mentioned above
"""
try:
result = await self.agent.ainvoke({"messages": [("user", prompt)]})
# The agent with response_format=InsightResponse returns structured output
logger.info(f"Raw agent result keys: {result.keys()}")
# Check if structured_response exists and is an InsightResponse object
if "structured_response" in result:
structured = result["structured_response"]
logger.info(f"Structured response type: {type(structured)}")
# If it's already an InsightResponse object, convert to dict
if isinstance(structured, InsightResponse):
return structured.model_dump()
# If it's already a dict, return it
elif isinstance(structured, dict):
return structured
# Fallback: shouldn't reach here, but handle it gracefully
logger.warning("No structured_response found in result, using fallback")
return {
"investment_pattern_analysis": "Unable to retrieve investment pattern analysis at this time.",
"market_position": "Unable to retrieve market position at this time.",
}
except Exception as e:
logger.error(f"Error getting insights for {investor_name}: {e}")
logger.exception("Full exception details:")
return {
"investment_pattern_analysis": "Unable to retrieve investment pattern analysis at this time.",
"market_position": "Unable to retrieve market position at this time.",
}
async def main():
qp = QueryProcessor()
result = await qp.agent.ainvoke(
{"messages": [("user", "Can you tell me about 3T Finance investment company")]}
)
final_message = result["messages"][-1].content
print(final_message)
if __name__ == "__main__":
asyncio.run(main())
+83 -751
View File
@@ -1,7 +1,5 @@
import asyncio
import json
import os
import re
from typing import Optional
import pandas as pd
@@ -9,35 +7,15 @@ from db.db import get_db_session
from db.models import (
CompanyMember,
CompanyTable,
FundTable,
InvestmentStageTable,
InvestorMember,
InvestorTable,
SectorTable,
)
from langchain_openai import ChatOpenAI
from pydantic import BaseModel
from schemas.py_schemas import CompanyData, InvestorData
from sqlalchemy.orm import Session
class CurrencyConversion(BaseModel):
"""Schema for LLM currency conversion responses"""
amount_usd: int = 0
confidence: str = "high" # high, medium, low
notes: str = ""
class CheckSizeRange(BaseModel):
"""Schema for LLM check size range parsing from estimated investment size"""
lower_bound_usd: int = 0
upper_bound_usd: int = 0
confidence: str = "high" # high, medium, low
notes: str = ""
class InvestorProcessor:
def __init__(self):
self.llm = ChatOpenAI(
@@ -47,534 +25,9 @@ class InvestorProcessor:
temperature=0,
)
# Structured LLMs for specific parsing tasks
self.currency_converter_llm = self.llm.with_structured_output(
CurrencyConversion
)
self.check_size_parser_llm = self.llm.with_structured_output(CheckSizeRange)
# Keep legacy structured LLMs for backward compatibility
self.investor_structured_llm = self.llm.with_structured_output(InvestorData)
self.company_structured_llm = self.llm.with_structured_output(CompanyData)
async def convert_to_usd(self, amount_str: str) -> Optional[int]:
"""
Use LLM to convert currency amounts to USD integers.
Handles formats like:
- "EUR 850,000,000"
- "$5M"
- "GBP 10-20 million"
- "Approximately EUR 100 million"
"""
if not amount_str or amount_str == "Not Available" or amount_str == "0":
return None
try:
prompt = f"""Convert this amount to USD as an integer (whole number, no decimals).
If it's a range, use the midpoint. If already in USD, just extract the number.
Remove all commas and convert millions/billions to actual numbers.
Amount: {amount_str}
Examples:
- "EUR 850,000,000" -> 935000000 (assuming EUR to USD rate ~1.10)
- "$5M" -> 5000000
- "GBP 10-20 million" -> 18000000 (midpoint 15M * 1.20 rate)
- "Approximately EUR 100 million" -> 110000000
Return only the USD integer amount with current exchange rates."""
result = await self.currency_converter_llm.ainvoke(prompt)
return result.amount_usd if result.amount_usd > 0 else None
except Exception as e:
print(f"Error converting currency '{amount_str}': {e}")
return None
async def parse_check_size_range(
self, estimated_investment_str: str
) -> tuple[Optional[int], Optional[int]]:
"""
Use LLM to parse check size range from estimated investment size string.
Returns tuple of (lower_bound_usd, upper_bound_usd).
Handles formats like:
- "EUR 1,000 to 2,000"
- "$100K-$500K"
- "Between $1M and $5M"
- "Up to EUR 10 million"
- "$2M typical"
"""
if (
not estimated_investment_str
or estimated_investment_str == "Not Available"
or estimated_investment_str == "0"
):
return None, None
try:
prompt = f"""Parse this check size/investment range into lower and upper bounds in USD as integers.
Input: {estimated_investment_str}
Instructions:
- If it's a range (e.g., "EUR 1M to 5M"), extract both bounds
- If it's a single amount (e.g., "$2M typical"), use it as both lower and upper
- If it says "up to X", use 0 as lower and X as upper
- Convert all currencies to USD using current exchange rates
- Return integers (whole numbers, no decimals)
Examples:
- "EUR 1,000 to 2,000" -> lower: 1100, upper: 2200
- "$100K-$500K" -> lower: 100000, upper: 500000
- "Between $1M and $5M" -> lower: 1000000, upper: 5000000
- "Up to EUR 10 million" -> lower: 0, upper: 11000000
- "$2M typical" -> lower: 2000000, upper: 2000000
- "GBP 500K-2M" -> lower: 600000, upper: 2400000
Return the lower and upper bounds in USD."""
result = await self.check_size_parser_llm.ainvoke(prompt)
lower = result.lower_bound_usd if result.lower_bound_usd > 0 else None
upper = result.upper_bound_usd if result.upper_bound_usd > 0 else None
return lower, upper
except Exception as e:
print(f"Error parsing check size range '{estimated_investment_str}': {e}")
return None, None
def parse_json_profile(self, json_str: str) -> Optional[dict]:
"""
Manually parse the JSON profile from the CSV.
Returns a cleaned dictionary with the investor profile data.
Handles JSON wrapped in markdown code blocks (```json ... ```).
Handles trailing quotes and extra data after JSON.
"""
if not json_str or pd.isna(json_str):
return None
try:
# Clean the JSON string
cleaned_json = json_str.strip()
# Check if it's plain text (no JSON structure)
if not cleaned_json.startswith(("{", "```", "'")):
print(" ⚠️ No JSON structure found - skipping")
return None
# Remove markdown code block markers if present
if cleaned_json.startswith("```"):
# Remove opening marker (```json or ```Json or ```)
lines = cleaned_json.split("\n")
if lines[0].startswith("```"):
lines = lines[1:] # Remove first line
# Remove closing marker (``` or ```')
if lines and lines[-1].strip() in ("```", "```'", '```"'):
lines = lines[:-1] # Remove last line
cleaned_json = "\n".join(lines).strip()
# Remove trailing quotes that might be left over
if cleaned_json.endswith(("'", '"')):
cleaned_json = cleaned_json[:-1].strip()
# Try to find JSON boundaries if there's extra data
# Look for the first { and the last }
start_idx = cleaned_json.find("{")
if start_idx == -1:
print(" ⚠️ No opening brace found - not valid JSON")
return None
# Find the matching closing brace
# We need to count braces to find the actual end
brace_count = 0
end_idx = -1
for i in range(start_idx, len(cleaned_json)):
if cleaned_json[i] == "{":
brace_count += 1
elif cleaned_json[i] == "}":
brace_count -= 1
if brace_count == 0:
end_idx = i + 1
break
if end_idx == -1:
print(" ⚠️ No matching closing brace found")
return None
# Extract just the JSON part
cleaned_json = cleaned_json[start_idx:end_idx]
# Parse JSON string
profile = json.loads(cleaned_json)
return profile
except json.JSONDecodeError as e:
print(f" ❌ JSON parsing error: {e}")
# Print first 200 chars for debugging
preview = json_str[:200] if len(json_str) > 200 else json_str
print(f" Preview: {preview}...")
return None
except Exception as e:
print(f" ❌ Unexpected error: {e}")
return None
async def process_investor_profile(
self, name: str, website: str, profile_json: str
) -> Optional[dict]:
"""
Process investor profile from CSV data.
Manually extracts fields and uses LLM only for currency conversion.
"""
profile = self.parse_json_profile(profile_json)
if not profile:
return None
try:
# Extract basic info
investor_data = {
"name": name.strip() if name else None,
"website": website.strip() if website else None,
"headquarters": profile.get("headquarters"),
"description": profile.get("investorDescription"),
"aum": None,
"aum_as_of_date": None,
"aum_source_url": None,
"investment_thesis": profile.get("investmentThesisFocus", []),
"portfolio_highlights": profile.get("portfolioHighlights", []),
"linked_documents": profile.get("linkedDocuments", []),
"researcher_notes": profile.get("researcherNotes"),
"missing_important_fields": profile.get("missingImportantFields", []),
"sources": profile.get("sources", {}),
"team_members": [],
"funds": [],
}
# Process AUM
aum_data = profile.get("overallAssetsUnderManagement", {})
if aum_data and isinstance(aum_data, dict):
aum_amount = aum_data.get("aumAmount")
if aum_amount and aum_amount != "Not Available":
# Convert AUM to USD integer
aum_usd = await self.convert_to_usd(aum_amount)
investor_data["aum"] = aum_usd
investor_data["aum_as_of_date"] = aum_data.get("asOfDate")
investor_data["aum_source_url"] = aum_data.get("sourceUrl")
# Process senior leadership
senior_leadership = profile.get("seniorLeadership", [])
for member in senior_leadership:
if isinstance(member, dict) and member.get("name"):
investor_data["team_members"].append(
{
"name": member.get("name"),
"title": member.get("title"),
"role": member.get("title"), # Use title as role
"email": None,
"source_url": member.get("sourceUrl"),
}
)
# Process funds
funds = profile.get("funds", [])
for fund in funds:
if isinstance(fund, dict):
fund_data = {
"fund_name": fund.get("fundName"),
"fund_size": None,
"fund_size_source_url": fund.get("fundSizeSourceUrl"),
"check_size_lower": None,
"check_size_upper": None,
"source_url": fund.get("sourceUrl"),
"source_provider": fund.get("sourceProvider"),
"geographic_focus": None, # Will be converted to string
"investment_stage_names": fund.get("investmentStageFocus", []),
"sector_names": fund.get("sectorFocus", []),
}
# Convert geographic focus from array to comma-separated string
geo_focus = fund.get("geographicFocus", [])
if geo_focus and isinstance(geo_focus, list):
fund_data["geographic_focus"] = ", ".join(geo_focus)
# Convert fund size to USD integer
fund_size_str = fund.get("fundSize")
if fund_size_str and fund_size_str != "Not Available":
fund_size_usd = await self.convert_to_usd(fund_size_str)
if fund_size_usd:
fund_data["fund_size"] = fund_size_usd # Store as integer
# Parse check size range from estimated investment size
est_size_str = fund.get("estimatedInvestmentSize")
if est_size_str and est_size_str != "Not Available":
check_lower, check_upper = await self.parse_check_size_range(
est_size_str
)
if check_lower is not None:
fund_data["check_size_lower"] = check_lower
if check_upper is not None:
fund_data["check_size_upper"] = check_upper
investor_data["funds"].append(fund_data)
return investor_data
except Exception as e:
print(f"Error processing investor profile for {name}: {e}")
return None
async def process_company_profile(
self, name: str, website: str, profile_json: str, investor_names: str = None
) -> Optional[dict]:
"""
Process company profile from CSV data.
Only extracts founded_year and key_executives - rest is in base database.
"""
profile = self.parse_json_profile(profile_json)
if not profile:
return None
try:
# Only extract founded_year and key_executives
company_data = {
"name": name.strip() if name else None,
"founded_year": None,
"key_executives": [],
}
# Process key executives/leadership
key_executives = profile.get("keyExecutives", [])
if not key_executives:
# Try alternative field names
key_executives = profile.get("seniorLeadership", [])
for exec_member in key_executives:
if isinstance(exec_member, dict) and exec_member.get("name"):
company_data["key_executives"].append(
{
"name": exec_member.get("name"),
"title": exec_member.get("title"),
"source_url": exec_member.get("sourceUrl"),
}
)
# Try to extract founding year from description
description = profile.get("companyDescription", "")
if description:
# Look for patterns like "founded in 2020", "Gegründet 2020", "founded 2020"
year_patterns = [
r"founded in (\d{4})",
r"founded (\d{4})",
r"Gegründet (\d{4})",
r"established in (\d{4})",
r"since (\d{4})",
r"\((\d{4})\)", # Year in parentheses
]
for pattern in year_patterns:
match = re.search(pattern, description, re.IGNORECASE)
if match:
try:
year = int(match.group(1))
if 1900 <= year <= 2025: # Sanity check
company_data["founded_year"] = year
break
except Exception:
continue
return company_data
except Exception as e:
print(f"Error processing company profile for {name}: {e}")
return None
def _save_parsed_company_to_db(
self, db: Session, company_data: dict
) -> Optional[CompanyTable]:
"""Save manually parsed company data to database - only updates founded_year and key_executives"""
try:
# Check if company already exists (should exist in base database)
existing_company = (
db.query(CompanyTable).filter_by(name=company_data["name"]).first()
)
if existing_company:
# Update only founded_year on existing company
company = existing_company
updated_fields = []
if company_data.get("founded_year"):
company.founded_year = company_data["founded_year"]
updated_fields.append(
f"founded_year: {company_data['founded_year']}"
)
# Add/update company members (key executives)
# First, remove existing members if updating
db.query(CompanyMember).filter_by(company_id=company.id).delete()
exec_count = 0
for exec_data in company_data.get("key_executives", []):
member = CompanyMember(
name=exec_data.get("name"),
role=exec_data.get("title"),
linkedin=exec_data.get(
"source_url"
), # Store source URL in linkedin field
company_id=company.id,
)
db.add(member)
exec_count += 1
if exec_count > 0:
updated_fields.append(f"{exec_count} executives")
if updated_fields:
print(f" 📝 Updated: {', '.join(updated_fields)}")
return company
else:
# Company not found in base database, skip
print(" ⚠️ Not in database - skipping")
return None
except Exception as e:
print(f" ❌ Error saving: {e}")
db.rollback()
return None
def _save_parsed_investor_to_db(
self, db: Session, investor_data: dict
) -> Optional[InvestorTable]:
"""Save manually parsed investor data to database"""
try:
# Check if investor already exists
existing_investor = (
db.query(InvestorTable).filter_by(name=investor_data["name"]).first()
)
if existing_investor:
# Update existing investor
investor = existing_investor
investor.website = investor_data.get("website") or investor.website
investor.headquarters = (
investor_data.get("headquarters") or investor.headquarters
)
investor.description = (
investor_data.get("description") or investor.description
)
investor.aum = investor_data.get("aum") or investor.aum
investor.aum_as_of_date = (
investor_data.get("aum_as_of_date") or investor.aum_as_of_date
)
investor.aum_source_url = (
investor_data.get("aum_source_url") or investor.aum_source_url
)
investor.investment_thesis = (
investor_data.get("investment_thesis") or investor.investment_thesis
)
investor.portfolio_highlights = (
investor_data.get("portfolio_highlights")
or investor.portfolio_highlights
)
investor.linked_documents = (
investor_data.get("linked_documents") or investor.linked_documents
)
investor.researcher_notes = (
investor_data.get("researcher_notes") or investor.researcher_notes
)
investor.missing_important_fields = (
investor_data.get("missing_important_fields")
or investor.missing_important_fields
)
investor.sources = investor_data.get("sources") or investor.sources
else:
# Create new investor
investor = InvestorTable(
name=investor_data["name"],
website=investor_data.get("website"),
headquarters=investor_data.get("headquarters"),
description=investor_data.get("description"),
aum=investor_data.get("aum"),
aum_as_of_date=investor_data.get("aum_as_of_date"),
aum_source_url=investor_data.get("aum_source_url"),
investment_thesis=investor_data.get("investment_thesis"),
portfolio_highlights=investor_data.get("portfolio_highlights"),
linked_documents=investor_data.get("linked_documents"),
researcher_notes=investor_data.get("researcher_notes"),
missing_important_fields=investor_data.get(
"missing_important_fields"
),
sources=investor_data.get("sources"),
)
db.add(investor)
db.flush()
# Add/update team members
# First, remove existing team members if updating
if existing_investor:
db.query(InvestorMember).filter_by(investor_id=investor.id).delete()
for member_data in investor_data.get("team_members", []):
member = InvestorMember(
name=member_data.get("name"),
role=member_data.get("role"),
title=member_data.get("title"),
email=member_data.get("email"),
source_url=member_data.get("source_url"),
investor_id=investor.id,
)
db.add(member)
# Add/update funds
# First, remove existing funds if updating
if existing_investor:
db.query(FundTable).filter_by(investor_id=investor.id).delete()
for fund_data in investor_data.get("funds", []):
fund = FundTable(
investor_id=investor.id,
fund_name=fund_data.get("fund_name"),
fund_size=fund_data.get("fund_size"), # Now an integer
fund_size_source_url=fund_data.get("fund_size_source_url"),
check_size_lower=fund_data.get("check_size_lower"),
check_size_upper=fund_data.get("check_size_upper"),
source_url=fund_data.get("source_url"),
source_provider=fund_data.get("source_provider"),
geographic_focus=fund_data.get("geographic_focus"), # Now a string
)
db.add(fund)
db.flush() # Get the fund ID
# Add investment stages (many-to-many)
for stage_name in fund_data.get("investment_stage_names", []):
stage = self._get_or_create_investment_stage(db, stage_name)
fund.investment_stages.append(stage)
# Add sectors (many-to-many)
for sector_name in fund_data.get("sector_names", []):
sector = self._get_or_create_sector(db, sector_name)
fund.sectors.append(sector)
return investor
except Exception as e:
print(f"Error saving investor to database: {e}")
db.rollback()
return None
def _get_or_create_investment_stage(
self, db: Session, stage_name: str
) -> InvestmentStageTable:
"""Get existing investment stage or create new one"""
from db.models import InvestmentStageTable
stage = (
db.query(InvestmentStageTable)
.filter(InvestmentStageTable.name == stage_name)
.first()
)
if not stage:
stage = InvestmentStageTable(name=stage_name)
db.add(stage)
db.flush() # Get the ID without committing
return stage
def _get_or_create_sector(self, db: Session, sector_name: str) -> SectorTable:
"""Get existing sector or create new one"""
sector = db.query(SectorTable).filter(SectorTable.name == sector_name).first()
@@ -596,6 +49,7 @@ Return the lower and upper bounds in USD."""
check_size_lower=investor_data.investor.check_size_lower,
check_size_upper=investor_data.investor.check_size_upper,
geographic_focus=investor_data.investor.geographic_focus,
stage_focus=investor_data.investor.stage_focus,
number_of_investments=investor_data.investor.number_of_investments,
)
db.add(investor)
@@ -719,263 +173,141 @@ Return the lower and upper bounds in USD."""
print(f"Error processing row {row_idx + 1}: {e}")
return None
async def _process_single_investor(
self, idx: int, row: pd.Series, total_rows: int
) -> Optional[dict]:
"""Process a single investor row"""
try:
name = row.get("Name", "").strip() if pd.notna(row.get("Name")) else None
website = (
row.get("Website", "").strip() if pd.notna(row.get("Website")) else None
)
profile_json = (
row.get("Final Investor Profile", "")
if pd.notna(row.get("Final Investor Profile"))
else None
)
if not name or not profile_json:
print(f"⚠️ Row {idx + 1}: Skipping - missing name or profile")
return None
print(f"📊 Processing {idx + 1}/{total_rows}: {name}")
# Process the investor profile
investor_data = await self.process_investor_profile(
name, website, profile_json
)
if investor_data:
print(f"{name} parsed successfully")
return investor_data
else:
print(f" ⚠️ {name} failed to process")
return None
except Exception as e:
print(f"❌ Error processing row {idx + 1}: {e}")
return None
async def parse_investors(
self, df: pd.DataFrame, save_to_db: bool = True, batch_size: int = 10
):
"""
Parse investors from DataFrame using manual JSON parsing and LLM for currency conversion.
Processes multiple investors concurrently for better performance.
Expected CSV columns: Name, Website, Final Investor Profile, Final Profile sourcing
Args:
df: DataFrame with investor data
save_to_db: Whether to save to database
batch_size: Number of investors to process concurrently (default: 10)
"""
results = []
async def parse_investors(self, df, save_to_db: bool = True):
"""Parse investors from DataFrame and optionally save to database"""
investors = []
df = df[20:]
db = None
if save_to_db:
db = get_db_session()
try:
total_rows = len(df)
print(
f"\n🚀 Starting to process {total_rows} investors with batch size {batch_size}..."
)
# Process rows in batches asynchronously
batch_size = 20 # Adjust batch size as needed
rows = [(idx, row) for idx, row in df.iterrows()]
# Process in batches
for batch_start in range(0, total_rows, batch_size):
batch_end = min(batch_start + batch_size, total_rows)
print(
f"\n🔄 Processing batch {batch_start + 1}-{batch_end} of {total_rows}..."
)
for i in range(0, len(rows), batch_size):
batch = rows[i : i + batch_size]
# Create tasks for concurrent processing
tasks = []
for idx in range(batch_start, batch_end):
row = df.iloc[idx]
task = self._process_single_investor(idx, row, total_rows)
tasks.append(task)
# Process batch asynchronously
tasks = [
self._process_row(row, idx, is_investor=True) for idx, row in batch
]
# Process batch concurrently
batch_results = await asyncio.gather(*tasks, return_exceptions=True)
# Filter out None results and exceptions, then save to database
for investor_data in batch_results:
if investor_data and not isinstance(investor_data, Exception):
results.append(investor_data)
# Handle results from batch
for (idx, row), result in zip(batch, batch_results):
if isinstance(result, Exception):
print(f"Error processing row {idx}: {result}")
if db:
db.rollback()
continue
# Save to database
if result:
# Convert dict to InvestorData if needed
if isinstance(result, dict):
investor_data = InvestorData(**result)
else:
investor_data = result
investors.append(investor_data)
# Save to database if requested
if save_to_db and db:
try:
saved_investor = self._save_parsed_investor_to_db(
saved_investor = self._save_investor_to_db(
db, investor_data
)
if saved_investor:
print(
f" ✅ Saved {investor_data['name']} to database (ID: {saved_investor.id})"
)
else:
print(
f" ❌ Failed to save {investor_data['name']} to database"
)
db.commit()
print(
f"✅ Saved investor '{saved_investor.name}' to database"
)
except Exception as e:
db.rollback()
print(
f" ❌ Database error for {investor_data['name']}: {e}"
)
elif isinstance(investor_data, Exception):
print(f" ❌ Exception occurred: {investor_data}")
print(f"❌ Failed to save investor to database: {e}")
# Commit batch to database
if save_to_db and db:
try:
db.commit()
print(f"💾 Committed batch {batch_start + 1}-{batch_end}")
except Exception as e:
db.rollback()
print(f"❌ Failed to commit batch: {e}")
print(
f"Completed batch {i // batch_size + 1} of {(len(rows) + batch_size - 1) // batch_size}"
)
except Exception as e:
print(f"❌ Fatal error in parse_investors: {e}")
print(f"Error in batch processing: {e}")
if db:
db.rollback()
finally:
if db:
db.close()
print(f"\n🎉 Completed! Processed {len(results)}/{total_rows} investors")
return results
return investors
async def _process_single_company(
self, idx: int, row: pd.Series, total_rows: int
) -> Optional[dict]:
"""Process a single company row"""
try:
name = row.get("Name", "").strip() if pd.notna(row.get("Name")) else None
website = (
row.get("Website", "").strip() if pd.notna(row.get("Website")) else None
)
investor_names = (
row.get("Investor", "").strip()
if pd.notna(row.get("Investor"))
else None
)
# Try both column names for flexibility
profile_json = (
row.get("Perplexity Gap Output", "")
if pd.notna(row.get("Perplexity Gap Output"))
else row.get("Final Investor Profile", "")
if pd.notna(row.get("Final Investor Profile"))
else None
)
if not name or not profile_json:
print(f"⚠️ Row {idx + 1}: Skipping - missing name or profile")
return None
print(f"📊 Processing {idx + 1}/{total_rows}: {name}")
# Process the company profile
company_data = await self.process_company_profile(
name, website, profile_json, investor_names
)
if company_data:
print(f"{name} parsed successfully")
return company_data
else:
print(f" ⚠️ {name} failed to process")
return None
except Exception as e:
print(f"❌ Error processing row {idx + 1}: {e}")
return None
async def parse_companies(
self, df: pd.DataFrame, save_to_db: bool = True, batch_size: int = 10
):
"""
Parse companies from DataFrame using manual JSON parsing.
Processes multiple companies concurrently for better performance.
Expected CSV columns: Name, Website, Investor, Final Investor Profile (actually company profile)
Args:
df: DataFrame with company data
save_to_db: Whether to save to database
batch_size: Number of companies to process concurrently (default: 10)
"""
results = []
async def parse_companies(self, df, save_to_db: bool = True):
"""Parse companies from DataFrame and optionally save to database"""
companies = []
df = df[20:]
db = None
if save_to_db:
db = get_db_session()
try:
total_rows = len(df)
print(
f"\n🚀 Starting to process {total_rows} companies with batch size {batch_size}..."
)
# Process rows in batches asynchronously
batch_size = 20 # Adjust batch size as needed
rows = [(idx, row) for idx, row in df.iterrows()]
# Process in batches
for batch_start in range(0, total_rows, batch_size):
batch_end = min(batch_start + batch_size, total_rows)
print(
f"\n🔄 Processing batch {batch_start + 1}-{batch_end} of {total_rows}..."
)
for i in range(0, len(rows), batch_size):
batch = rows[i : i + batch_size]
# Create tasks for concurrent processing
tasks = []
for idx in range(batch_start, batch_end):
row = df.iloc[idx]
task = self._process_single_company(idx, row, total_rows)
tasks.append(task)
# Process batch asynchronously
tasks = [
self._process_row(row, idx, is_investor=False) for idx, row in batch
]
# Process batch concurrently
batch_results = await asyncio.gather(*tasks, return_exceptions=True)
# Filter out None results and exceptions, then save to database
for company_data in batch_results:
if company_data and not isinstance(company_data, Exception):
results.append(company_data)
# Handle results from batch
for (idx, row), result in zip(batch, batch_results):
if isinstance(result, Exception):
print(f"Error processing row {idx}: {result}")
if db:
db.rollback()
continue
# Save to database
if result:
# Convert dict to CompanyData if needed
if isinstance(result, dict):
company_data = CompanyData(**result)
else:
company_data = result
companies.append(company_data)
# Save to database if requested
if save_to_db and db:
try:
saved_company = self._save_parsed_company_to_db(
saved_company = self._save_company_to_db(
db, company_data
)
if saved_company:
print(
f" ✅ Saved {company_data['name']} to database (ID: {saved_company.id})"
)
else:
print(
f" ❌ Failed to save {company_data['name']} to database"
)
db.commit()
print(
f"✅ Saved company '{saved_company.name}' to database"
)
except Exception as e:
db.rollback()
print(
f" ❌ Database error for {company_data['name']}: {e}"
)
elif isinstance(company_data, Exception):
print(f" ❌ Exception occurred: {company_data}")
print(f"❌ Failed to save company to database: {e}")
# Commit batch to database
if save_to_db and db:
try:
db.commit()
print(f"💾 Committed batch {batch_start + 1}-{batch_end}")
except Exception as e:
db.rollback()
print(f"❌ Failed to commit batch: {e}")
print(
f"Completed batch {i // batch_size + 1} of {(len(rows) + batch_size - 1) // batch_size}"
)
except Exception as e:
print(f"❌ Fatal error in parse_companies: {e}")
print(f"Error processing row {idx}: {e}")
if db:
db.rollback()
finally:
if db:
db.close()
print(f"\n🎉 Completed! Processed {len(results)}/{total_rows} companies")
return results
return companies
# async def main():
+76 -253
View File
@@ -1,25 +1,19 @@
import asyncio
import hashlib
import logging
import os
from typing import List, Optional
from typing import List
from db.db import get_db
from db.models import FundTable, InvestorTable, ProjectTable
from langchain_core.prompts import ChatPromptTemplate
from db.db import DATABASE_URL, get_db
from db.models import InvestorTable
from langchain import hub
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain_openai import ChatOpenAI
from schemas.router_schemas import (
CompanyMinimal,
InvestmentResponse,
PaginatedResponse,
SectorMinimal,
)
from sqlalchemy import text
from langgraph.prebuilt import create_react_agent
from schemas.py_schemas import InvestorData, InvestorList
from sqlalchemy.orm import selectinload
from services.compatibility_score import calculate_project_investor_compatibility
logger = logging.getLogger(__name__)
# Connect to SQLite
prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")
db = SQLDatabase.from_uri(DATABASE_URL)
class QueryProcessor:
@@ -30,266 +24,95 @@ class QueryProcessor:
model="openai/gpt-4o-mini",
temperature=0,
)
# Query cache for performance
self.query_cache = {}
# SQL generation prompt
self.sql_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"""You are a SQL expert. Generate a SQLite query to find fund IDs based on user requirements.
Database Schema:
- funds: id, fund_name, investor_id, check_size_lower, check_size_upper, geographic_focus
- fund_sectors: fund_id, sector_id
- fund_investment_stages: fund_id, stage_id
- sectors: id, name
- investment_stages: id, name
- investors: id, name, aum
IMPORTANT RULES:
1. ALWAYS return ONLY fund IDs (funds.id) - use SELECT DISTINCT f.id
2. For geography: Be FLEXIBLE - use OR with variations and partial matches
- 'Europe' → WHERE geographic_focus LIKE '%Europe%' OR geographic_focus LIKE '%European%'
- 'America' → WHERE geographic_focus LIKE '%America%' OR geographic_focus LIKE '%US%' OR geographic_focus LIKE '%United States%'
- 'Asia' → WHERE geographic_focus LIKE '%Asia%' OR geographic_focus LIKE '%Asian%'
- If no geography specified, DON'T filter by geography
3. For stages: Use LEFT JOIN and LIKE for flexible matching with synonyms
- 'Seed' → s.name LIKE '%Seed%' OR s.name LIKE '%Pre-Seed%' OR s.name LIKE '%Early%'
- 'Series A' → s.name LIKE '%Series A%' OR s.name LIKE '%A%'
- 'Growth' → s.name LIKE '%Growth%' OR s.name LIKE '%Late%' OR s.name LIKE '%Expansion%'
- If stage not specified, include ALL funds
4. For sectors: Use LEFT JOIN and include related terms with OR
- 'Fintech' → sec.name LIKE '%Fintech%' OR sec.name LIKE '%Finance%' OR sec.name LIKE '%Financial%'
- 'AI' → sec.name LIKE '%AI%' OR sec.name LIKE '%Artificial Intelligence%' OR sec.name LIKE '%Machine Learning%' OR sec.name LIKE '%ML%'
- 'Healthcare' → sec.name LIKE '%Healthcare%' OR sec.name LIKE '%Health%' OR sec.name LIKE '%Medical%' OR sec.name LIKE '%Biotech%'
5. For check size filters (be flexible with ranges):
- "under X" → WHERE (check_size_upper <= X OR check_size_upper IS NULL)
- "over X" → WHERE (check_size_lower >= X OR check_size_lower IS NULL)
- "between X and Y" → WHERE check_size_lower >= X AND check_size_upper <= Y
6. Use LEFT JOIN for stages and sectors so funds without tags still match
7. Use DISTINCT to avoid duplicates from joins
8. Be INCLUSIVE - use OR conditions to cast a wider net
9. If query is very simple (e.g., just "seed stage"), don't add unnecessary filters
10. Return a single, complete SELECT query
Example Queries:
Q: "Seed stage investors in Europe"
A: SELECT DISTINCT f.id FROM funds f
LEFT JOIN fund_investment_stages fis ON f.id = fis.fund_id
LEFT JOIN investment_stages s ON fis.stage_id = s.id
WHERE (s.name LIKE '%Seed%' OR s.name LIKE '%Pre-Seed%' OR s.name LIKE '%Early%' OR s.id IS NULL)
AND (f.geographic_focus LIKE '%Europe%' OR f.geographic_focus LIKE '%European%')
Q: "Fintech investors with check size under 5 million"
A: SELECT DISTINCT f.id FROM funds f
LEFT JOIN fund_sectors fs ON f.id = fs.fund_id
LEFT JOIN sectors sec ON fs.sector_id = sec.id
WHERE (sec.name LIKE '%Fintech%' OR sec.name LIKE '%Finance%' OR sec.name LIKE '%Financial%' OR sec.id IS NULL)
AND (f.check_size_upper <= 5000000 OR f.check_size_upper IS NULL)
Q: "Seed stage investors"
A: SELECT DISTINCT f.id FROM funds f
LEFT JOIN fund_investment_stages fis ON f.id = fis.fund_id
LEFT JOIN investment_stages s ON fis.stage_id = s.id
WHERE s.name LIKE '%Seed%' OR s.name LIKE '%Pre-Seed%' OR s.name LIKE '%Early%'
Q: "Growth stage investors"
A: SELECT DISTINCT f.id FROM funds f
LEFT JOIN fund_investment_stages fis ON f.id = fis.fund_id
LEFT JOIN investment_stages s ON fis.stage_id = s.id
WHERE s.name LIKE '%Growth%' OR s.name LIKE '%Late%' OR s.name LIKE '%Expansion%' OR s.name LIKE '%Series C%' OR s.name LIKE '%Series D%'
Q: "AI investors in America"
A: SELECT DISTINCT f.id FROM funds f
LEFT JOIN fund_sectors fs ON f.id = fs.fund_id
LEFT JOIN sectors sec ON fs.sector_id = sec.id
WHERE (sec.name LIKE '%AI%' OR sec.name LIKE '%Artificial Intelligence%' OR sec.name LIKE '%Machine Learning%' OR sec.name LIKE '%ML%')
AND (f.geographic_focus LIKE '%America%' OR f.geographic_focus LIKE '%US%' OR f.geographic_focus LIKE '%United States%' OR f.geographic_focus LIKE '%USA%')
Q: "Healthcare investors"
A: SELECT DISTINCT f.id FROM funds f
LEFT JOIN fund_sectors fs ON f.id = fs.fund_id
LEFT JOIN sectors sec ON fs.sector_id = sec.id
WHERE sec.name LIKE '%Healthcare%' OR sec.name LIKE '%Health%' OR sec.name LIKE '%Medical%' OR sec.name LIKE '%Biotech%' OR sec.name LIKE '%Pharma%'
IMPORTANT: Use LEFT JOIN so funds without sector/stage tags can still match. Include synonym terms with OR for better recall.
Return ONLY the SQL query, no explanations or markdown.""",
),
("user", "{question}"),
]
self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)
# Update system message to specifically request only investor IDs
system_message_updated = (
prompt_template.format(dialect="SQLite", top_k=5)
+ "\n\nIMPORTANT: You must ONLY return the investor IDs (id field) that match the user's criteria. "
+ "Do NOT return any other information, explanations, or data. "
+ "Your response should be ONLY a comma-separated list of numbers representing the investor IDs. "
+ "Example format: 1, 5, 12, 23"
)
self.agent = create_react_agent(
model=self.llm,
tools=self.toolkit.get_tools(),
prompt=system_message_updated,
)
def _get_cache_key(self, question: str) -> str:
"""Generate cache key from normalized question."""
return hashlib.md5(question.lower().strip().encode()).hexdigest()
def process_query(self, question: str) -> InvestorList:
"""Process a query using the LLM and return investor data."""
# Let the LLM handle all database interactions and filtering to get IDs
response = self.agent.invoke(
{"messages": [("user", question)]},
)
async def process_query(
self, question: str, project_id: Optional[int] = None
) -> PaginatedResponse[InvestmentResponse]:
"""Async wrapper for process_query. Runs blocking work in a thread to avoid
blocking the event loop.
"""
return await asyncio.to_thread(self._process_query_sync, question, project_id)
# Extract the actual message content
ai_response = (
response["messages"][-1].content if response.get("messages") else ""
)
def _process_query_sync(
self, question: str, project_id: Optional[int] = None
) -> PaginatedResponse[InvestmentResponse]:
"""Synchronous implementation of process_query. This is run in a thread by
the async wrapper above.
"""
cache_key = self._get_cache_key(question)
# Extract investor IDs from the AI response
investor_ids = self._extract_investor_ids_from_response(ai_response)
# Check cache first
if cache_key in self.query_cache:
sql_query = self.query_cache[cache_key]
logger.info(f"Using cached SQL: {sql_query}")
else:
# Generate SQL query
messages = self.sql_prompt.format_messages(question=question)
response = self.llm.invoke(messages)
sql_query = response.content.strip()
# Fetch full investor data using the IDs
return self._fetch_investors_by_ids(investor_ids)
# Clean up SQL (remove markdown code blocks if present)
sql_query = sql_query.replace("```sql", "").replace("```", "").strip()
def _extract_investor_ids_from_response(self, ai_response: str) -> List[int]:
"""Extract investor IDs from AI response."""
import re
# Cache the query
self.query_cache[cache_key] = sql_query
logger.info(f"Generated SQL: {sql_query}")
# Execute query to get fund IDs
db_session = next(get_db())
investor_ids = []
try:
result = db_session.execute(text(sql_query))
fund_ids = [row[0] for row in result.fetchall()]
logger.info(
f"Found {len(fund_ids)} fund IDs: {fund_ids[:10]}{'...' if len(fund_ids) > 10 else ''}"
)
# Try multiple patterns to extract IDs from the response
# Pattern 1: Simple numbers (assuming they are IDs)
numbers = re.findall(r"\b\d+\b", ai_response)
investor_ids = [int(num) for num in numbers]
# Pattern 2: If response contains explicit ID references
id_matches = re.findall(r"\bid[:\s]*(\d+)", ai_response.lower())
if id_matches:
investor_ids = [int(id_str) for id_str in id_matches]
return self._fetch_funds_by_ids(fund_ids, project_id)
except Exception as e:
logger.error(f"SQL execution error: {e}")
logger.error(f"Failed SQL: {sql_query}")
# Return empty result
return PaginatedResponse(
items=[], total=0, page=1, page_size=10, total_pages=0
)
finally:
db_session.close()
print(f"Error extracting IDs from response: {e}")
return []
def _fetch_funds_by_ids(
self, fund_ids: List[int], project_id: Optional[int] = None
) -> PaginatedResponse[InvestmentResponse]:
"""Fetch funds with all their relationships from the database using fund IDs.
Constructs response similar to read_investors but starting from funds.
return investor_ids
Args:
fund_ids: List of fund IDs to fetch
project_id: Optional project ID for compatibility scoring
"""
if not fund_ids:
return PaginatedResponse(
items=[],
total=0,
page=1,
page_size=len(fund_ids) if fund_ids else 10,
total_pages=0,
)
def _fetch_investors_by_ids(self, investor_ids: List[int]) -> InvestorList:
"""Fetch investors with all their relationships from the database using IDs."""
if not investor_ids:
return InvestorList(investors=[])
# Get database session
db_session = next(get_db())
try:
# Load project if project_id provided
project = None
if project_id is not None:
project = (
db_session.query(ProjectTable)
.options(selectinload(ProjectTable.sector))
.filter(ProjectTable.id == project_id)
.first()
)
# Query funds with all necessary relationships loaded
funds = (
db_session.query(FundTable)
# Build query with all relationships loaded
query = (
db_session.query(InvestorTable)
.options(
selectinload(FundTable.investor).selectinload(
InvestorTable.portfolio_companies
),
selectinload(FundTable.investor).selectinload(
InvestorTable.team_members
),
selectinload(FundTable.investor).selectinload(
InvestorTable.sectors
),
selectinload(FundTable.investment_stages),
selectinload(FundTable.sectors),
selectinload(InvestorTable.portfolio_companies),
selectinload(InvestorTable.team_members),
selectinload(InvestorTable.sectors),
)
.filter(FundTable.id.in_(fund_ids))
.all()
.filter(InvestorTable.id.in_(investor_ids))
)
# Transform to InvestmentResponse format (one row per fund)
investment_responses = []
for fund in funds:
investor = fund.investor
investors = query.all()
# Calculate compatibility score if project provided
compatibility_score = 1.0
if project is not None:
compatibility_score = calculate_project_investor_compatibility(
project=project, investor=investor, use_funds=True
)
# Get top 3 portfolio companies (id and name only)
portfolio_companies = [
CompanyMinimal(id=company.id, name=company.name)
for company in investor.portfolio_companies[:3]
]
# Get stage focus as comma-separated string
stage_focus = (
", ".join([stage.name for stage in fund.investment_stages])
if fund.investment_stages
else None
# Transform to InvestorData format
investor_data_list = []
for investor in investors:
investor_data = InvestorData(
investor=investor,
portfolio_companies=investor.portfolio_companies,
team_members=investor.team_members,
sectors=investor.sectors,
)
investor_data_list.append(investor_data)
# Get top 3 sectors from fund (id and name only) - sorted alphabetically
fund_sectors = [
SectorMinimal(id=sector.id, name=sector.name)
for sector in sorted(fund.sectors[:3] if fund.sectors else [], key=lambda s: s.name)
]
investment_response = InvestmentResponse(
id=investor.id,
name=f"{investor.name} - {fund.fund_name}"
if fund.fund_name
else investor.name,
aum=investor.aum,
check_size_lower=fund.check_size_lower,
check_size_upper=fund.check_size_upper,
geographic_focus=fund.geographic_focus,
stage_focus=stage_focus,
portfolio_companies=portfolio_companies,
sectors=fund_sectors,
compatibility_score=compatibility_score,
)
investment_responses.append(investment_response)
total_count = len(investment_responses)
total_pages = 1 if total_count > 0 else 0
return PaginatedResponse(
items=investment_responses,
total=total_count,
page=1,
page_size=total_count,
total_pages=total_pages,
)
return InvestorList(investors=investor_data_list)
finally:
db_session.close()
-340
View File
@@ -1,340 +0,0 @@
from pathlib import Path
from typing import Any, Dict, List, Optional
# Import database models and compatibility score service
from db.models import InvestorTable, ProjectTable
from jinja2 import Environment, FileSystemLoader
from playwright.async_api import async_playwright
from services.compatibility_score import calculate_project_investor_compatibility
class ReportGenerator:
"""Service for generating PDF reports from HTML templates"""
def __init__(self):
# Set up Jinja2 environment
template_dir = Path(__file__).parent.parent / "templates"
self.env = Environment(loader=FileSystemLoader(str(template_dir)))
async def generate_investor_report(
self,
investor_data: Dict[str, Any],
project_data: Optional[Dict[str, Any]] = None,
investor_model: Optional[InvestorTable] = None,
project_model: Optional[ProjectTable] = None,
) -> bytes:
"""
Generate a PDF report for an investor profile.
Args:
investor_data: Dictionary containing investor information
project_data: Optional dictionary containing project information for compatibility analysis
investor_model: Optional database model for investor (used for compatibility scoring)
project_model: Optional database model for project (used for compatibility scoring)
Returns:
bytes: PDF file content
"""
# Prepare template context
context = self._prepare_context(
investor_data, project_data, investor_model, project_model
)
# Render HTML from template
template = self.env.get_template("report.html")
html_content = template.render(**context)
# Convert HTML to PDF using Playwright
pdf_bytes = await self._html_to_pdf(html_content)
return pdf_bytes
def _prepare_context(
self,
investor_data: Dict[str, Any],
project_data: Optional[Dict[str, Any]] = None,
investor_model: Optional[InvestorTable] = None,
project_model: Optional[ProjectTable] = None,
) -> Dict[str, Any]:
"""Prepare the context dictionary for template rendering"""
context = {
"investor": investor_data,
"project": project_data,
"compatibility_score": 0,
"match_criteria": [],
"recommendation": None,
}
# If project data is provided, calculate compatibility
if project_data:
# Use the compatibility_score service if models are provided
if investor_model and project_model:
# Calculate using the standardized compatibility score service
# Returns score between 0 and 1, convert to percentage (0-100)
score_decimal = calculate_project_investor_compatibility(
project=project_model, investor=investor_model, use_funds=True
)
context["compatibility_score"] = int(score_decimal * 100)
else:
# Fallback to old calculation method if models not provided
context["compatibility_score"] = self._calculate_compatibility_score(
investor_data, project_data
)
context["match_criteria"] = self._generate_match_criteria(
investor_data, project_data
)
context["recommendation"] = self._generate_recommendation(
context["compatibility_score"], context["match_criteria"]
)
return context
def _calculate_compatibility_score(
self, investor_data: Dict[str, Any], project_data: Dict[str, Any]
) -> int:
"""Calculate overall compatibility score between investor and project"""
score = 0
weights = {
"sector": 30,
"stage": 30,
"geography": 20,
"check_size": 20,
}
# Aggregate data from all funds
all_sectors = set(investor_data.get("sectors", []))
all_stages = set()
all_geographies = []
check_ranges = []
for fund in investor_data.get("funds", []):
all_sectors.update(fund.get("sectors", []))
all_stages.update(fund.get("investment_stages", []))
if fund.get("geographic_focus"):
all_geographies.append(fund["geographic_focus"])
if fund.get("check_size_lower") and fund.get("check_size_upper"):
check_ranges.append(
{
"lower": fund["check_size_lower"],
"upper": fund["check_size_upper"],
}
)
# Sector match
project_sectors = set(project_data.get("sectors", []))
if all_sectors and project_sectors:
if all_sectors & project_sectors:
score += weights["sector"]
# Stage match - case insensitive comparison
project_stage = project_data.get("stage")
if project_stage and all_stages:
# Normalize stage names for comparison (case-insensitive)
normalized_stages = {
stage.lower().replace("_", " ") for stage in all_stages
}
project_stage_normalized = project_stage.lower().replace("_", " ")
if project_stage_normalized in normalized_stages:
score += weights["stage"]
# Geography match - check if any fund matches
project_geo = (project_data.get("location") or "").lower()
geo_match = False
if all_geographies:
for geo in all_geographies:
if geo:
geo_lower = geo.lower()
# Match if investor geography is "global" or if there's a location overlap
if "global" in geo_lower or "worldwide" in geo_lower:
geo_match = True
break
if project_geo and (
geo_lower in project_geo or project_geo in geo_lower
):
geo_match = True
break
if geo_match:
score += weights["geography"]
# Check size match - check if any fund's range matches
project_valuation = project_data.get("valuation", 0)
check_match = False
if project_valuation and check_ranges:
for check_range in check_ranges:
if check_range["lower"] <= project_valuation <= check_range["upper"]:
check_match = True
break
if check_match:
score += weights["check_size"]
return min(score, 100)
def _generate_match_criteria(
self, investor_data: Dict[str, Any], project_data: Dict[str, Any]
) -> List[Dict[str, str]]:
"""Generate detailed match criteria table"""
criteria = []
# Aggregate data from all funds
all_sectors = set(investor_data.get("sectors", []))
all_stages = set()
all_geographies = []
check_ranges = []
for fund in investor_data.get("funds", []):
all_sectors.update(fund.get("sectors", []))
all_stages.update(fund.get("investment_stages", []))
if fund.get("geographic_focus"):
all_geographies.append(fund["geographic_focus"])
if fund.get("check_size_lower") and fund.get("check_size_upper"):
check_ranges.append(
{
"lower": fund["check_size_lower"],
"upper": fund["check_size_upper"],
"fund_name": fund.get("fund_name", "Unnamed Fund"),
}
)
# Sector criterion
project_sectors = project_data.get("sectors", [])
sector_match = "Perfect" if all_sectors & set(project_sectors) else "Mismatch"
criteria.append(
{
"name": "Sector",
"requirement": ", ".join(project_sectors) if project_sectors else "N/A",
"evidence": ", ".join(list(all_sectors)[:3]) if all_sectors else "N/A",
"match": sector_match,
"weight": "30%",
}
)
# Stage criterion - case insensitive comparison
project_stage = project_data.get("stage", "N/A")
stage_match = "Mismatch"
if project_stage != "N/A" and all_stages:
# Normalize stage names for comparison
normalized_stages = {
stage.lower().replace("_", " ") for stage in all_stages
}
project_stage_normalized = project_stage.lower().replace("_", " ")
stage_match = (
"Perfect"
if project_stage_normalized in normalized_stages
else "Mismatch"
)
elif project_stage == "N/A":
stage_match = "N/A"
criteria.append(
{
"name": "Stage",
"requirement": str(project_stage),
"evidence": ", ".join(all_stages) if all_stages else "N/A",
"match": stage_match,
"weight": "30%",
}
)
# Geography criterion
project_geo = project_data.get("location") or "N/A"
investor_geo_display = ", ".join(all_geographies) if all_geographies else "N/A"
# Safe comparison handling None values and "Global" matches
geo_match = "Mismatch"
if project_geo != "N/A" and all_geographies:
for geo in all_geographies:
if geo:
geo_lower = geo.lower()
# Match if investor geography is "global" or if there's a location overlap
if "global" in geo_lower or "worldwide" in geo_lower:
geo_match = "Perfect"
break
if (
geo_lower in project_geo.lower()
or project_geo.lower() in geo_lower
):
geo_match = "Strong"
break
elif not all_geographies and project_geo == "N/A":
geo_match = "N/A"
criteria.append(
{
"name": "Geography",
"requirement": project_geo,
"evidence": investor_geo_display,
"match": geo_match,
"weight": "20%",
}
)
# Check Size criterion
project_val = project_data.get("valuation", 0)
# Build evidence string from all fund ranges
check_evidence = "N/A"
if check_ranges:
evidence_parts = []
for cr in check_ranges[:3]: # Show up to 3 funds
range_str = (
f"{cr['lower'] / 1000000:.0f}M - €{cr['upper'] / 1000000:.0f}M"
)
if cr["fund_name"]:
evidence_parts.append(f"{cr['fund_name']}: {range_str}")
else:
evidence_parts.append(range_str)
check_evidence = "; ".join(evidence_parts)
# Check if project valuation matches any fund
check_match = "N/A"
if project_val > 0 and check_ranges:
match_found = any(
cr["lower"] <= project_val <= cr["upper"] for cr in check_ranges
)
check_match = "Perfect" if match_found else "Mismatch"
criteria.append(
{
"name": "Check Size",
"requirement": f"{project_val / 1000000:.0f}M"
if project_val
else "N/A",
"evidence": check_evidence,
"match": check_match,
"weight": "20%",
}
)
return criteria
def _generate_recommendation(
self, score: int, criteria: List[Dict[str, str]]
) -> str:
"""Generate recommendation text based on score and criteria"""
if score >= 85:
return "High Priority. A strong target due to exceptional alignment on the most heavily-weighted criteria: Sector and Stage. The strong geographic fit further solidifies this recommendation."
elif score >= 70:
return "Medium Priority. Good alignment on key criteria with some areas of strong fit. The geographic fit in the target region supports this recommendation."
else:
return "Low Priority. Limited alignment on key investment criteria. Consider for future evaluation if circumstances change."
async def _html_to_pdf(self, html_content: str) -> bytes:
"""Convert HTML content to PDF using Playwright"""
async with async_playwright() as p:
browser = await p.chromium.launch()
page = await browser.new_page()
# Set content and wait for any dynamic content to load
await page.set_content(html_content, wait_until="networkidle")
# Generate PDF with proper settings
pdf_bytes = await page.pdf(
format="A4",
print_background=True,
margin={"top": "0", "right": "0", "bottom": "0", "left": "0"},
)
await browser.close()
return pdf_bytes
-329
View File
@@ -1,329 +0,0 @@
<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<title>Investor Profile Report</title>
<script src="https://cdn.tailwindcss.com"></script>
<style>
@page {
size: A4;
margin: 0;
}
html,
body {
margin: 0;
padding: 0;
height: 100%;
background: white;
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI",
Roboto, sans-serif;
}
/* Each page is exactly one A4 sheet */
.page {
width: 210mm;
height: 297mm;
position: relative;
background: white;
overflow: hidden;
}
/* Adds a break between pages (for print/PDF) */
.page-with-break {
page-break-after: always;
}
/* Inner content wrapper for consistent padding */
.page-content {
box-sizing: border-box;
padding: 48px; /* equivalent to Tailwind p-12 */
height: 100%;
display: flex;
flex-direction: column;
}
.tag {
display: inline-block;
padding: 4px 12px;
background: #f3f4f6;
border-radius: 4px;
font-size: 12px;
margin: 4px;
}
/* Ensure the footer text stays inside page bounds */
.page-footer {
position: absolute;
bottom: 48px;
right: 48px;
font-size: 10px;
color: #9ca3af; /* Tailwind gray-400 */
}
</style>
</head>
<body>
<!-- Page 1 -->
<div class="page page-with-break">
<div class="page-content">
<div class="flex justify-between items-start mb-8">
<div>
<p class="text-sm text-gray-600 mb-2">Investor Profile</p>
<h1 class="text-4xl font-bold text-gray-900">
{{ investor.name }}
</h1>
</div>
<a
href="{{ investor.website }}"
target="_blank"
class="bg-gray-200 text-gray-700 px-4 py-2 rounded text-sm no-underline"
>Visit Website →</a
>
</div>
<div class="grid grid-cols-2 gap-8 flex-grow">
<!-- Left Column -->
<div>
<div class="mb-4">
<h2 class="text-sm font-bold text-gray-900 uppercase mb-4">
Investor Description
</h2>
<p class="text-sm text-gray-700 leading-relaxed">
{{ investor.description or 'No description available.' }}
</p>
</div>
<div class="mb-4">
<h2 class="text-sm font-bold text-gray-900 uppercase mb-4">
Portfolio Highlights
</h2>
<div class="flex flex-wrap gap-2">
{% if investor.portfolio_highlights %}
{% for company in investor.portfolio_highlights[:5] %}
<span class="tag">{{ company }}</span>
{% endfor %}
{% else %}
<p class="text-sm text-gray-500">
No portfolio highlights available
</p>
{% endif %}
</div>
</div>
<div class="mb-4">
<h2 class="text-sm font-bold text-gray-900 uppercase mb-4">
Senior Leadership
</h2>
{% if investor.team_members %}
{% for member in investor.team_members[:2] %}
<div class="mb-3">
<p class="text-sm font-semibold text-gray-900">
{{ member.name }}
</p>
<p class="text-sm text-gray-600">
{{ member.role or member.title or 'Team Member' }}
</p>
{% if member.email %}
<p class="text-xs text-blue-600">
{{ member.email }}
</p>
{% endif %}
</div>
{% endfor %}
{% else %}
<p class="text-sm text-gray-500">No team information available</p>
{% endif %}
</div>
</div>
<!-- Right Column -->
<div class="bg-gray-50 p-6 rounded-lg">
<h2 class="text-sm font-bold text-gray-900 uppercase mb-4">
Key Data
</h2>
<div class="space-y-3 text-sm">
<div>
<p class="text-xs text-gray-600">Headquarters:</p>
<p class="font-semibold text-gray-900">
{{ investor.headquarters or 'N/A' }}
</p>
</div>
<div>
<p class="text-xs text-gray-600">Sectors:</p>
<p class="font-semibold text-gray-900">
{% if investor.sectors %}
{{ investor.sectors | join(', ') }}
{% else %}
N/A
{% endif %}
</p>
</div>
<div>
<p class="text-xs text-gray-600">AUM (EUR million):</p>
<p class="font-semibold text-gray-900">
{% if investor.aum %}
€{{ '{:,.0f}'.format(investor.aum / 1000000) }}M
{% else %}
N/A
{% endif %}
</p>
</div>
<div>
<p class="text-xs text-gray-600 mb-1">Number of Funds:</p>
<p class="font-semibold text-gray-900">
{{ investor.funds | length if investor.funds else 'N/A' }}
</p>
</div>
</div>
<div class="mt-4">
<h3 class="text-xs font-bold text-gray-900 uppercase mb-2">
Fund Details
</h3>
{% if investor.funds %}
{% for fund in investor.funds %}
<div class="mb-3 pb-3 border-b border-gray-200">
<p class="text-sm font-semibold text-gray-900 mb-1">
{{ fund.fund_name or 'Fund ' + loop.index|string }}
</p>
<div class="text-xs text-gray-700 space-y-1">
{% if fund.fund_size %}
<p>Fund Size: €{{ '{:,.0f}'.format(fund.fund_size / 1000000) }}M</p>
{% endif %}
{% if fund.check_size_lower and fund.check_size_upper %}
<p>Check Size: €{{ '{:,.0f}'.format(fund.check_size_lower / 1000000) }}M - €{{ '{:,.0f}'.format(fund.check_size_upper / 1000000) }}M</p>
{% endif %}
{% if fund.geographic_focus %}
<p>Geography: {{ fund.geographic_focus }}</p>
{% endif %}
{% if fund.investment_stages %}
<p>Stages: {{ fund.investment_stages | join(', ') }}</p>
{% endif %}
{% if fund.sectors %}
<p>Sectors: {{ fund.sectors[:3] | join(', ') }}</p>
{% endif %}
</div>
</div>
{% endfor %}
{% else %}
<p class="text-xs text-gray-500">No fund information available</p>
{% endif %}
</div>
</div>
</div>
<div class="page-footer">Page 1</div>
</div>
</div>
<!-- Page 2 -->
{% if project %}
<div class="page">
<div class="page-content">
<h1 class="text-3xl font-bold text-gray-900 mb-8">
{{ investor.name }}: Mandate Match Analysis
</h1>
<!-- Overall Match Circle -->
<div class="flex justify-center mb-12">
<div class="text-center">
<p class="text-sm font-bold text-gray-700 uppercase mb-4">
Overall Mandate Match
</p>
<div
class="w-48 h-48 rounded-full border-8 border-green-400 flex items-center justify-center bg-green-50 mx-auto"
>
<span class="text-5xl font-bold text-green-600"
>{{ compatibility_score }}%</span
>
</div>
</div>
</div>
<!-- Mandate Alignment Analysis Table -->
<div class="mb-12">
<h2 class="text-xl font-bold text-gray-900 mb-6">
Mandate Alignment Analysis
</h2>
<table class="w-full border-collapse">
<thead>
<tr class="border-b-2 border-gray-300">
<th
class="text-left py-3 px-4 text-sm font-bold text-gray-700"
>
Criterion
</th>
<th
class="text-left py-3 px-4 text-sm font-bold text-gray-700"
>
Mandate Requirement
</th>
<th
class="text-left py-3 px-4 text-sm font-bold text-gray-700"
>
Investor Evidence (from Database)
</th>
<th
class="text-left py-3 px-4 text-sm font-bold text-gray-700"
>
Match Score
</th>
<th
class="text-left py-3 px-4 text-sm font-bold text-gray-700"
>
Weight
</th>
</tr>
</thead>
<tbody>
{% for criterion in match_criteria %}
<tr class="border-b border-gray-200">
<td class="py-4 px-4 text-sm text-gray-900">
{{ criterion.name }}
</td>
<td class="py-4 px-4 text-sm text-gray-700">
{{ criterion.requirement }}
</td>
<td class="py-4 px-4 text-sm text-gray-700">
{{ criterion.evidence }}
</td>
<td class="py-4 px-4 text-sm">
<span
class="{% if criterion.match == 'Perfect' %}text-green-600{% elif criterion.match == 'Strong' %}text-blue-600{% else %}text-yellow-600{% endif %} font-semibold"
>
{{ criterion.match }}
</span>
</td>
<td class="py-4 px-4 text-sm text-gray-700">
{{ criterion.weight }}
</td>
</tr>
{% endfor %}
</tbody>
</table>
</div>
<!-- Final Recommendation -->
<div class="bg-blue-50 border-l-4 border-blue-500 p-6 rounded">
<h3 class="text-lg font-bold text-gray-900 mb-3">
Final Recommendation & Rationale
</h3>
<p class="text-sm text-gray-700 leading-relaxed">
{{ recommendation or "High Priority. A strong target due to
exceptional alignment on the most heavily-weighted criteria:
Sector and Stage. The strong geographic fit in the DACH
region further solidifies this recommendation." }}
</p>
</div>
<div class="absolute bottom-12 right-12 text-xs text-gray-400">
Page 2
</div>
</div>
{% endif %}
</body>
</html>
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
Binary file not shown.
-315
View File
@@ -1,315 +0,0 @@
import logging
import re
import unicodedata
import pandas as pd
from models import CompanyTable, InvestorTable, SectorTable, engine, init_database
from sqlalchemy.orm import sessionmaker
# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Import the schema
init_database()
# ===================== Ingesting Original Data =====================#
def parse_investor_names(investor_names_str):
"""Parse comma-separated investor names and return a list"""
if pd.isna(investor_names_str) or investor_names_str == "":
return []
# Split by comma and clean whitespace
# investors = [name.strip() for name in str(investor_names_str).split(",")]
investors = [
clean_name(name.strip()) for name in str(investor_names_str).split(",")
]
return [investor for investor in investors if investor]
def parse_industries(industries_str):
"""Parse comma-separated industries and return a list"""
if pd.isna(industries_str) or industries_str == "":
return []
# Split by comma and clean whitespace
industries = [industry.strip() for industry in str(industries_str).split(",")]
return [industry for industry in industries if industry]
def clean_special_characters(text):
"""Clean special characters from text, converting to ASCII equivalents"""
if not text:
return text
# First remove ellipses and other problematic patterns
text = str(text).replace("...", "").replace("..", "")
# Normalize unicode characters to their closest ASCII equivalents
normalized = unicodedata.normalize("NFKD", text)
# Remove accents and convert to ASCII
ascii_text = normalized.encode("ascii", "ignore").decode("ascii")
# Remove any remaining non-alphanumeric characters except spaces, hyphens, and periods
cleaned = re.sub(r"[^a-zA-Z0-9\s\-\.]", "", ascii_text)
# Clean up multiple spaces
cleaned = re.sub(r"\s+", " ", cleaned).strip()
return cleaned
def clean_string(value):
"""Clean string values, converting empty/null/nan/0 to None and removing special characters"""
if (
pd.isna(value)
or value == ""
or str(value).lower() in ["nan", "null", "none", "0", "0.0"]
):
return None
# First clean special characters
cleaned = clean_special_characters(str(value).strip())
# Check if result is just "0" after cleaning
if cleaned in ["0", "0.0", "null", "nan", "none"]:
return None
return cleaned if cleaned else None
def clean_name(value):
"""Clean names (companies, investors) with special character handling"""
if (
pd.isna(value)
or value == ""
or str(value).lower() in ["nan", "null", "none", "0", "0.0"]
):
return None
# Clean special characters but be more permissive for names
text = str(value).strip()
# First remove ellipses and other problematic patterns
# text = text.replace("...", "").replace("..", "")
# Normalize unicode characters
normalized = unicodedata.normalize("NFKD", text)
# Convert to ASCII but keep more characters for business names
ascii_text = normalized.encode("ascii", "ignore").decode("ascii")
# Allow alphanumeric, spaces, hyphens, periods, parentheses, and ampersands
cleaned = re.sub(r"[^a-zA-Z0-9\s\-\.\(\)&]", "", ascii_text)
# Clean up multiple spaces
cleaned = re.sub(r"\s+", " ", cleaned).strip()
# Remove any trailing or leading periods
cleaned = cleaned.strip(".")
cleaned = cleaned.replace("..", "").replace("...", "")
# Check if result is just "0" after cleaning
if cleaned in ["0", "0.0", "null", "nan", "none"]:
return None
return cleaned if cleaned else None
def clean_integer(value):
"""Clean integer values, converting empty/null/nan/0 to None"""
if pd.isna(value) or str(value).lower() in ["nan", "null", "none", "", "0", "0.0"]:
return None
try:
cleaned_val = int(float(value))
return cleaned_val if cleaned_val > 0 else None
except (ValueError, TypeError):
return None
def parse_website(website_str: str):
try:
_, end = website_str.split(":")
if end == "0":
return None
return "https:" + end
except Exception:
return None
def ingest_data():
# Create database engine and session
Session = sessionmaker(bind=engine)
session = Session()
# Load CSV files
print("Loading CSV files...")
companies_df = pd.read_csv("companies.csv")
investors_df = pd.read_csv("investors.csv")
print(f"📊 Companies CSV: {len(companies_df)} rows")
print(f"📊 Investors CSV: {len(investors_df)} rows")
# Step 1: Ingest Investors
print("\n🔄 Step 1: Ingesting Investors...")
investors_processed = 0
for index, row in investors_df.iterrows():
try:
investor_name = clean_name(row.get("Filtered investor names", ""))
if investor_name:
# Check if investor already exists
existing_investor = (
session.query(InvestorTable).filter_by(name=investor_name).first()
)
if not existing_investor:
investor = InvestorTable(
name=investor_name,
description=clean_string(row.get("Business model", "")),
headquarters=clean_string(row.get("HQ", "")),
website=parse_website(str(row.get("Website", "")).strip()),
number_of_investments=clean_integer(
row.get("Number of investments")
),
)
session.add(investor)
investors_processed += 1
if investors_processed % 1000 == 0:
session.commit()
print(f" Committed {investors_processed} investors")
except Exception as e:
logger.error(f"Error processing investor {index}: {e}")
continue
session.commit()
print(f"✅ Investors completed: {investors_processed} processed")
# Step 2: Ingest Companies and Rounds
print("\n🔄 Step 2: Ingesting Companies and Sectors...")
companies_processed = 0
sectors_created = set()
for index, row in companies_df.iterrows():
try:
# Process company
company_name = clean_name(row.get("Organization Name", ""))
if not company_name:
continue
# Check if company already exists
existing_company = (
session.query(CompanyTable).filter_by(name=company_name).first()
)
if existing_company:
company = existing_company
else:
# Create company
company = CompanyTable(
name=company_name,
description=clean_string(row.get("Organization Description", "")),
location=clean_string(row.get("Organization Location", "")),
industry=clean_string(row.get("Organization Industries", "")),
website=clean_string(row.get("Organization Website", "")),
)
session.add(company)
session.flush() # Get the company ID
companies_processed += 1
# Process investor relationships
investor_names_str = row.get("Investor Names", "")
if pd.notna(investor_names_str) and investor_names_str:
investor_names = parse_investor_names(investor_names_str)
for investor_name in investor_names:
# Find investor in database
investor = (
session.query(InvestorTable)
.filter_by(name=investor_name.strip())
.first()
)
if investor:
# Add investor-company relationship
if company not in investor.portfolio_companies:
investor.portfolio_companies.append(company)
else:
print("This company has an investor not in DB:", investor_name)
# Process sectors/industries
industries_str = row.get("Organization Industries", "")
if pd.notna(industries_str) and industries_str:
industries = parse_industries(industries_str)
for industry_name in industries:
industry_name = industry_name.strip()
if industry_name:
# Check if sector exists
sector = (
session.query(SectorTable)
.filter_by(name=industry_name)
.first()
)
if not sector:
sector = SectorTable(name=industry_name)
session.add(sector)
session.flush()
sectors_created.add(industry_name)
# Add company-sector relationship
if sector not in company.sectors:
company.sectors.append(sector)
# Commit every 100 companies
if companies_processed % 100 == 0 and companies_processed > 0:
session.commit()
print(f" Processed {companies_processed} companies...")
except Exception as e:
logger.error(f"Error processing company {index}: {e}")
session.rollback()
continue
# Step 3: Link investors to sectors based on portfolio companies
print("\n🔄 Step 3: Linking Investors to Sectors...")
investors_linked_to_sectors = 0
all_investors = session.query(InvestorTable).all()
for investor in all_investors:
sectors = set()
for company in investor.portfolio_companies:
for sector in company.sectors:
sectors.add(sector)
# Add sectors to investor if not already present
for sector in sectors:
if sector not in investor.sectors:
investor.sectors.append(sector)
if sectors:
investors_linked_to_sectors += 1
session.commit()
print(f"✅ Linked {investors_linked_to_sectors} investors to sectors")
# Final commit
session.commit()
# Final counts
final_investors = session.query(InvestorTable).count()
final_companies = session.query(CompanyTable).count()
final_sectors = session.query(SectorTable).count()
print("\n🎉 Ingestion Complete!")
print(f" Investors: {final_investors}")
print(f" Companies: {final_companies}")
print(f" Sectors: {final_sectors}")
session.close()
if __name__ == "__main__":
ingest_data()
# print(clean_name("A... Energi"))
# print(clean_name("B.. Tech"))
# print(clean_name("A... Energi"))
-381
View File
@@ -1,381 +0,0 @@
import enum
from typing import Annotated
from fastapi import Depends
from sqlalchemy import (
Column,
DateTime,
ForeignKey,
Integer,
String,
Table,
Text,
create_engine,
func,
)
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import Session, declarative_mixin, relationship, sessionmaker
from sqlalchemy.types import JSON, Enum
Base = declarative_base()
# Database configuration
# DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./investors.db")
# Create engine
engine = create_engine("sqlite:///./investors.db", echo=False)
# Create session factory
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
def get_db():
db = SessionLocal()
try:
yield db
finally:
db.close()
db_dependency = Annotated[Session, Depends(get_db)]
def init_database():
"""Initialize the database by creating all tables"""
Base.metadata.create_all(bind=engine)
def get_session_sync() -> Session:
"""Get a database session for synchronous operations"""
return SessionLocal()
def get_db_session():
"""Get a database session for direct use."""
return SessionLocal()
@declarative_mixin
class TimestampMixin:
created_at = Column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
class InvestmentStage(enum.Enum):
SEED = "SEED"
SERIES_A = "SERIES_A"
SERIES_B = "SERIES_B"
SERIES_C = "SERIES_C"
GROWTH = "GROWTH"
LATE_STAGE = "LATE_STAGE"
# Association table for many-to-many relationship between investors and companies
investor_company_association = Table(
"investor_companies",
Base.metadata,
Column("investor_id", Integer, ForeignKey("investors.id")),
Column("company_id", Integer, ForeignKey("companies.id")),
)
# Association table for investor-sector many-to-many
investor_sector_association = Table(
"investor_sectors",
Base.metadata,
Column("investor_id", Integer, ForeignKey("investors.id")),
Column("sector_id", Integer, ForeignKey("sectors.id")),
)
company_sector_association = Table(
"company_sector",
Base.metadata,
Column("company_id", Integer, ForeignKey("companies.id")),
Column("sector_id", Integer, ForeignKey("sectors.id")),
)
project_sector_association = Table(
"project_sector",
Base.metadata,
Column("project_id", Integer, ForeignKey("projects.id")),
Column("sector_id", Integer, ForeignKey("sectors.id")),
)
project_investor_association = Table(
"project_investors",
Base.metadata,
Column("project_id", Integer, ForeignKey("projects.id")),
Column("investor_id", Integer, ForeignKey("investors.id")),
)
project_company_association = Table(
"project_companies",
Base.metadata,
Column("project_id", Integer, ForeignKey("projects.id")),
Column("company_id", Integer, ForeignKey("companies.id")),
)
# Association table for investor-stage many-to-many
investor_stage_association = Table(
"investor_stages",
Base.metadata,
Column("investor_id", Integer, ForeignKey("investors.id")),
Column("stage_id", Integer, ForeignKey("investment_stages.id")),
)
# Association table for fund-stage many-to-many
fund_investment_stages_association = Table(
"fund_investment_stages",
Base.metadata,
Column("fund_id", Integer, ForeignKey("funds.id")),
Column("stage_id", Integer, ForeignKey("investment_stages.id")),
)
# Association table for fund-sector many-to-many
fund_sectors_association = Table(
"fund_sectors",
Base.metadata,
Column("fund_id", Integer, ForeignKey("funds.id")),
Column("sector_id", Integer, ForeignKey("sectors.id")),
)
class InvestorTable(Base, TimestampMixin):
__tablename__ = "investors"
id = Column(Integer, primary_key=True, index=True)
name = Column(String, nullable=False)
description = Column(Text, nullable=True)
# Basic investor info
website = Column(String, nullable=True)
headquarters = Column(String, nullable=True)
# AUM fields
aum = Column(Integer, nullable=True) # Store as integer for numerical filtering
aum_as_of_date = Column(String, nullable=True)
aum_source_url = Column(String, nullable=True)
# Check size (deprecated in favor of fund-level data, but keeping for backward compatibility)
check_size_lower = Column(Integer, nullable=True)
check_size_upper = Column(Integer, nullable=True)
# Geographic focus (deprecated in favor of fund-level, but keeping for backward compatibility)
geographic_focus = Column(String, nullable=True)
# Investment thesis and portfolio
investment_thesis = Column(JSON, nullable=True) # Array of thesis statements
portfolio_highlights = Column(
JSON, nullable=True
) # Array of portfolio company names
linked_documents = Column(JSON, nullable=True) # Array of document URLs
# Research metadata
researcher_notes = Column(Text, nullable=True)
missing_important_fields = Column(
JSON, nullable=True
) # Array of missing field names
sources = Column(JSON, nullable=True) # JSON object with source URLs
# Portfolio info
number_of_investments = Column(Integer, nullable=True)
# Relationships
team_members = relationship(
"InvestorMember", back_populates="investor", cascade="all, delete-orphan"
)
funds = relationship(
"FundTable", back_populates="investor", cascade="all, delete-orphan"
)
# Many-to-many relationship with investment stages
investment_stages = relationship(
"InvestmentStageTable",
secondary=investor_stage_association,
back_populates="investors",
)
# Relationship to portfolio companies
portfolio_companies = relationship(
"CompanyTable",
secondary=investor_company_association,
back_populates="investors",
)
sectors = relationship(
"SectorTable",
secondary=investor_sector_association,
back_populates="investors",
)
projects = relationship(
"ProjectTable",
secondary=project_investor_association,
back_populates="investors",
)
class InvestorMember(Base, TimestampMixin):
__tablename__ = "investor_members"
id = Column(Integer, primary_key=True, index=True)
name = Column(String, nullable=False)
role = Column(String, nullable=True)
title = Column(String, nullable=True) # Alternative to role
email = Column(String, nullable=True)
source_url = Column(String, nullable=True) # URL where member info was found
investor_id = Column(Integer, ForeignKey("investors.id"))
investor = relationship("InvestorTable", back_populates="team_members")
class FundTable(Base, TimestampMixin):
__tablename__ = "funds"
id = Column(Integer, primary_key=True, index=True)
investor_id = Column(Integer, ForeignKey("investors.id"), nullable=False)
# Fund details
fund_name = Column(String, nullable=True)
fund_size = Column(
Integer, nullable=True
) # Store as integer for numerical filtering
fund_size_source_url = Column(String, nullable=True)
# Check size range (parsed from estimated_investment_size by LLM)
check_size_lower = Column(Integer, nullable=True)
check_size_upper = Column(Integer, nullable=True)
source_url = Column(String, nullable=True)
source_provider = Column(String, nullable=True) # e.g., "Perplexity"
# Geographic focus as simple string
geographic_focus = Column(String, nullable=True)
# Relationships
investor = relationship("InvestorTable", back_populates="funds")
investment_stages = relationship(
"InvestmentStageTable",
secondary=fund_investment_stages_association,
back_populates="funds",
)
sectors = relationship(
"SectorTable",
secondary=fund_sectors_association,
back_populates="funds",
)
class InvestmentStageTable(Base, TimestampMixin):
__tablename__ = "investment_stages"
id = Column(Integer, primary_key=True, index=True)
name = Column(String, nullable=False, unique=True)
# Relationships
investors = relationship(
"InvestorTable",
secondary=investor_stage_association,
back_populates="investment_stages",
)
funds = relationship(
"FundTable",
secondary=fund_investment_stages_association,
back_populates="investment_stages",
)
class CompanyTable(Base, TimestampMixin):
__tablename__ = "companies"
id = Column(Integer, primary_key=True, index=True)
name = Column(String, nullable=False)
industry = Column(String, nullable=True)
location = Column(String, nullable=True)
description = Column(String, nullable=True)
founded_year = Column(Integer, nullable=True)
website = Column(String, nullable=True)
members = relationship(
"CompanyMember", back_populates="company", cascade="all, delete-orphan"
)
# Relationship back to investors
investors = relationship(
"InvestorTable",
secondary=investor_company_association,
back_populates="portfolio_companies",
)
sectors = relationship(
"SectorTable", secondary=company_sector_association, back_populates="companies"
)
projects = relationship(
"ProjectTable",
secondary=project_company_association,
back_populates="companies",
)
class CompanyMember(Base, TimestampMixin):
__tablename__ = "company_members"
id = Column(Integer, primary_key=True)
name = Column(String)
linkedin = Column(String, nullable=True)
role = Column(String, nullable=True)
company_id = Column(Integer, ForeignKey("companies.id"), nullable=False)
company = relationship("CompanyTable", back_populates="members")
class SectorTable(Base, TimestampMixin):
__tablename__ = "sectors"
id = Column(Integer, primary_key=True, index=True)
name = Column(String, nullable=False)
# Relationships
investors = relationship(
"InvestorTable",
secondary=investor_sector_association,
back_populates="sectors",
)
companies = relationship(
"CompanyTable", secondary=company_sector_association, back_populates="sectors"
)
projects = relationship(
"ProjectTable", secondary=project_sector_association, back_populates="sector"
)
funds = relationship(
"FundTable",
secondary=fund_sectors_association,
back_populates="sectors",
)
class ProjectTable(Base, TimestampMixin):
__tablename__ = "projects"
id = Column(Integer, primary_key=True, index=True)
name = Column(String, nullable=False)
valuation = Column(Integer, nullable=True)
stage = Column(Enum(InvestmentStage), nullable=True)
location = Column(String, nullable=True)
description = Column(Text, nullable=True)
start_date = Column(DateTime, nullable=True)
end_date = Column(DateTime, nullable=True)
sector = relationship(
"SectorTable", secondary=project_sector_association, back_populates="projects"
)
investors = relationship(
"InvestorTable",
secondary=project_investor_association,
back_populates="projects",
)
companies = relationship(
"CompanyTable", secondary=project_company_association, back_populates="projects"
)
BIN
View File
Binary file not shown.
-117
View File
@@ -1,117 +0,0 @@
"""
Migration: Add fields from feedback fixes
Date: 2025-01-07
Adds the following fields:
- projects.is_archived (INTEGER, default 0)
- companies.product_service (TEXT, nullable)
- companies.clients (TEXT, nullable - stored as JSON string)
- investor_members.linkedin (VARCHAR, nullable)
"""
import sys
from pathlib import Path
# Add parent directory to path to import app modules
sys.path.insert(0, str(Path(__file__).parent.parent))
from sqlalchemy import text
from app.db.db import engine
def check_column_exists(conn, table_name, column_name):
"""Check if a column exists in a table"""
result = conn.execute(text(f"PRAGMA table_info({table_name})"))
columns = [row[1] for row in result]
return column_name in columns
def upgrade():
"""Add new columns to tables"""
print("Running migration: Add feedback fixes fields")
print("=" * 60)
with engine.begin() as conn: # Use begin() for transaction management
# 1. Add is_archived to projects table
print("\n1. Adding 'is_archived' column to projects table...")
if check_column_exists(conn, "projects", "is_archived"):
print(" ✓ Column 'is_archived' already exists. Skipping.")
else:
conn.execute(
text(
"ALTER TABLE projects ADD COLUMN is_archived INTEGER DEFAULT 0 NOT NULL"
)
)
# Set default value for existing rows
conn.execute(
text("UPDATE projects SET is_archived = 0 WHERE is_archived IS NULL")
)
print(" ✓ Successfully added 'is_archived' column to projects table")
# 2. Add product_service to companies table
print("\n2. Adding 'product_service' column to companies table...")
if check_column_exists(conn, "companies", "product_service"):
print(" ✓ Column 'product_service' already exists. Skipping.")
else:
conn.execute(text("ALTER TABLE companies ADD COLUMN product_service TEXT"))
print(" ✓ Successfully added 'product_service' column to companies table")
# 3. Add clients to companies table
print("\n3. Adding 'clients' column to companies table...")
if check_column_exists(conn, "companies", "clients"):
print(" ✓ Column 'clients' already exists. Skipping.")
else:
conn.execute(text("ALTER TABLE companies ADD COLUMN clients TEXT"))
print(" ✓ Successfully added 'clients' column to companies table")
# 4. Add linkedin to investor_members table
print("\n4. Adding 'linkedin' column to investor_members table...")
if check_column_exists(conn, "investor_members", "linkedin"):
print(" ✓ Column 'linkedin' already exists. Skipping.")
else:
conn.execute(
text("ALTER TABLE investor_members ADD COLUMN linkedin VARCHAR")
)
print(" ✓ Successfully added 'linkedin' column to investor_members table")
print("\n" + "=" * 60)
print("Migration completed successfully!")
def downgrade():
"""Remove added columns from tables"""
print("Running downgrade: Remove feedback fixes fields")
print("=" * 60)
# Note: SQLite doesn't support DROP COLUMN directly
print("\nWarning: SQLite doesn't support DROP COLUMN directly.")
print("To remove these columns, you would need to:")
print("1. Create new tables without the columns")
print("2. Copy data from old tables to new tables")
print("3. Drop old tables and rename new tables")
print("\nColumns to remove:")
print(" - projects.is_archived")
print(" - companies.product_service")
print(" - companies.clients")
print(" - investor_members.linkedin")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Run database migration")
parser.add_argument(
"direction",
choices=["upgrade", "downgrade"],
default="upgrade",
nargs="?",
help="Migration direction (default: upgrade)",
)
args = parser.parse_args()
if args.direction == "upgrade":
upgrade()
else:
downgrade()
-67
View File
@@ -1,67 +0,0 @@
"""
Migration: Add industry column to projects table
Date: 2025-10-23
"""
import os
import sys
from pathlib import Path
# Add parent directory to path to import app modules
sys.path.insert(0, str(Path(__file__).parent.parent))
from sqlalchemy import create_engine, text
from app.db.db import DATABASE_URL, engine
def upgrade():
"""Add industry column to projects table"""
print("Running migration: Add industry column to projects table")
with engine.connect() as conn:
# Check if column already exists
result = conn.execute(text("PRAGMA table_info(projects)"))
columns = [row[1] for row in result]
if 'industry' in columns:
print("Column 'industry' already exists in projects table. Skipping migration.")
return
# Add the industry column
conn.execute(text("ALTER TABLE projects ADD COLUMN industry VARCHAR"))
conn.commit()
print("Successfully added 'industry' column to projects table")
def downgrade():
"""Remove industry column from projects table"""
print("Running downgrade: Remove industry column from projects table")
# Note: SQLite doesn't support DROP COLUMN directly
# This is a simplified version - in production you'd need to recreate the table
print("Warning: SQLite doesn't support DROP COLUMN.")
print("To remove the column, you would need to:")
print("1. Create a new table without the industry column")
print("2. Copy data from old table to new table")
print("3. Drop old table and rename new table")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Run database migration")
parser.add_argument(
"direction",
choices=["upgrade", "downgrade"],
default="upgrade",
nargs="?",
help="Migration direction (default: upgrade)"
)
args = parser.parse_args()
if args.direction == "upgrade":
upgrade()
else:
downgrade()
-40
View File
@@ -1,40 +1,26 @@
aiofiles==24.1.0
aiohappyeyeballs==2.6.1
aiohttp==3.12.15
aiosignal==1.4.0
aiosqlite==0.21.0
alphashape==1.3.1
annotated-types==0.7.0
anyio==4.10.0
attrs==25.3.0
backoff==2.2.1
bcrypt==4.3.0
beautifulsoup4==4.14.2
brotli==1.1.0
build==1.3.0
cachetools==5.5.2
certifi==2025.8.3
cffi==2.0.0
chardet==5.2.0
charset-normalizer==3.4.3
chromadb==1.0.20
click==8.2.1
click-log==0.4.0
coloredlogs==15.0.1
crawl4ai==0.7.4
cryptography==46.0.1
dataclasses-json==0.6.7
ddgs==9.5.2
distro==1.9.0
dnspython==2.7.0
durationpy==0.10
email-validator==2.3.0
fake-http-header==0.3.5
fake-useragent==2.2.0
fastapi==0.116.1
fastapi-cli==0.0.8
fastapi-cloud-cli==0.1.5
fastuuid==0.13.5
filelock==3.19.1
flatbuffers==25.2.10
frozenlist==1.7.0
@@ -44,24 +30,19 @@ googleapis-common-protos==1.70.0
greenlet==3.2.4
grpcio==1.74.0
h11==0.16.0
h2==4.3.0
hf-xet==1.1.8
hpack==4.1.0
httpcore==1.0.9
httptools==0.6.4
httpx==0.28.1
httpx-sse==0.4.1
huggingface-hub==0.34.4
humanfriendly==10.0
humanize==4.13.0
hyperframe==6.1.0
idna==3.10
importlib-metadata==8.7.0
importlib-resources==6.5.2
itsdangerous==2.2.0
jinja2==3.1.6
jiter==0.10.0
joblib==1.5.2
jsonpatch==1.33
jsonpointer==3.0.0
jsonschema==4.25.1
@@ -77,9 +58,6 @@ langgraph-checkpoint==2.1.1
langgraph-prebuilt==0.6.4
langgraph-sdk==0.2.4
langsmith==0.4.20
lark==1.3.0
litellm==1.77.5
lxml==5.4.0
markdown-it-py==4.0.0
markupsafe==3.0.2
marshmallow==3.26.1
@@ -88,8 +66,6 @@ mmh3==5.2.0
mpmath==1.3.0
multidict==6.6.4
mypy-extensions==1.1.0
networkx==3.5
nltk==3.9.1
numpy==2.3.2
oauthlib==3.3.1
onnxruntime==1.22.1
@@ -105,26 +81,18 @@ ormsgpack==1.10.0
overrides==7.7.0
packaging==25.0
pandas==2.3.2
patchright==1.55.2
pillow==11.3.0
pip==25.2
playwright==1.55.0
posthog==5.4.0
primp==0.15.0
propcache==0.3.2
protobuf==6.32.0
psutil==7.1.0
pyasn1==0.6.1
pyasn1-modules==0.4.2
pybase64==1.4.2
pycparser==2.23
pydantic==2.11.7
pydantic-core==2.33.2
pydantic-extra-types==2.10.5
pydantic-settings==2.10.1
pyee==13.0.0
pygments==2.19.2
pyopenssl==25.3.0
pypika==0.48.9
pyproject-hooks==1.2.0
python-dateutil==2.9.0.post0
@@ -132,7 +100,6 @@ python-dotenv==1.1.1
python-multipart==0.0.20
pytz==2025.2
pyyaml==6.0.2
rank-bm25==0.2.2
referencing==0.36.2
regex==2025.7.34
requests==2.32.5
@@ -143,24 +110,17 @@ rich-toolkit==0.15.0
rignore==0.6.4
rpds-py==0.27.1
rsa==4.9.1
rtree==1.4.1
scipy==1.16.2
sentry-sdk==2.35.1
shapely==2.1.2
shellingham==1.5.4
six==1.17.0
sniffio==1.3.1
snowballstemmer==2.2.0
soupsieve==2.8
sqlalchemy==2.0.43
starlette==0.47.3
sympy==1.14.0
tenacity==9.1.2
tf-playwright-stealth==1.2.0
tiktoken==0.11.0
tokenizers==0.21.4
tqdm==4.67.1
trimesh==4.8.3
typer==0.16.1
typing-extensions==4.15.0
typing-inspect==0.9.0
-68
View File
@@ -1,68 +0,0 @@
#!/bin/bash
# Server management script for app/main.py
# Usage: ./server_manager.sh start|stop|restart
PID_FILE="server.pid"
LOG_FILE="server.log"
start() {
if [ -f "$PID_FILE" ] && kill -0 $(cat "$PID_FILE") 2>/dev/null; then
echo "Server is already running (PID: $(cat "$PID_FILE"))"
return 1
fi
echo "Starting server..."
nohup uv run app/main.py > "$LOG_FILE" 2>&1 &
echo $! > "$PID_FILE"
echo "Server started (PID: $(cat "$PID_FILE"))"
}
stop() {
if [ ! -f "$PID_FILE" ]; then
echo "Server is not running (no PID file found)"
return 1
fi
PID=$(cat "$PID_FILE")
if ! kill -0 "$PID" 2>/dev/null; then
echo "Server is not running (PID $PID not found)"
rm -f "$PID_FILE"
return 1
fi
echo "Stopping server (PID: $PID)..."
kill "$PID"
# Wait for process to stop
for i in {1..10}; do
if ! kill -0 "$PID" 2>/dev/null; then
break
fi
sleep 1
done
if kill -0 "$PID" 2>/dev/null; then
echo "Force killing server..."
kill -9 "$PID"
fi
rm -f "$PID_FILE"
echo "Server stopped"
}
restart() {
stop
sleep 2
start
}
case "$1" in
start)
start
;;
stop)
stop
;;
restart)
restart
;;
*)
echo "Usage: $0 {start|stop|restart}"
exit 1
;;
esac
-310
View File
@@ -1,310 +0,0 @@
#!/usr/bin/env python3
"""
Update Investor Members LinkedIn Profiles Script
This script finds and updates LinkedIn profile URLs for investor members in the database.
Uses crawl4ai to efficiently scrape team pages and extract LinkedIn URLs.
Usage:
python update_linkedin_profiles.py [--test] [--limit N] [--skip-existing]
Options:
--test Test mode: process only 10 records and don't update database
--limit N Process only N records (default: all)
--skip-existing Skip members that already have LinkedIn URLs
--start-from N Start from record N (for resuming)
"""
import argparse
import asyncio
import json
import os
import sys
from datetime import datetime
# Add app to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "app"))
from db.db import get_db_session
from db.models import InvestorMember, InvestorTable
from linkedin_scraper import LinkedInProfileScraper, format_linkedin_url
def progress_callback(current, total, result):
"""Print progress updates"""
percent = (current / total) * 100
status = "" if result["linkedin_url"] else ""
print(f"[{current}/{total} - {percent:.1f}%] {status} {result['member_name']}")
if result["linkedin_url"]:
print(
f"{result['linkedin_url']} (confidence: {result['confidence']}%, method: {result['method']})"
)
def create_db_callback(test_mode=False):
"""
Create a callback function that saves LinkedIn profiles to the database immediately.
This allows stopping and resuming without losing progress.
"""
saved_count = {"count": 0} # Use dict to allow modification in closure
def db_callback(member_id: int, linkedin_url: str) -> bool:
"""Save LinkedIn URL to database immediately"""
if test_mode:
print(f" [TEST] Would save to DB: member {member_id}")
saved_count["count"] += 1
return True
try:
db = get_db_session()
member = db.query(InvestorMember).filter_by(id=member_id).first()
if member:
member.linkedin = format_linkedin_url(linkedin_url)
db.commit()
saved_count["count"] += 1
return True
except Exception as e:
print(f" ⚠️ DB Error for member {member_id}: {e}")
try:
db.rollback()
except Exception:
pass
return False
finally:
try:
db.close()
except Exception:
pass
return False
return db_callback, saved_count
def update_database(members_data, test_mode=False):
"""Update database with found LinkedIn profiles"""
db = get_db_session()
try:
updated_count = 0
for data in members_data:
if data["linkedin_url"] and data["member_id"]:
if not test_mode:
member = (
db.query(InvestorMember).filter_by(id=data["member_id"]).first()
)
if member:
member.linkedin = format_linkedin_url(data["linkedin_url"])
updated_count += 1
else:
print(
f" [TEST MODE] Would update member {data['member_id']}: {data['linkedin_url']}"
)
updated_count += 1
if not test_mode:
db.commit()
print(f"\n✓ Successfully updated {updated_count} records in database")
else:
print(f"\n[TEST MODE] Would have updated {updated_count} records")
return updated_count
except Exception as e:
db.rollback()
print(f"\n✗ Error updating database: {e}")
raise
finally:
db.close()
def save_results(results, filename="linkedin_scraping_results.json"):
"""Save results to JSON file for backup/analysis"""
output = {
"timestamp": datetime.now().isoformat(),
"total_processed": len(results),
"found_count": sum(1 for r in results if r["linkedin_url"]),
"results": results,
}
with open(filename, "w") as f:
json.dump(output, f, indent=2)
print(f"\n✓ Results saved to {filename}")
def print_summary(results):
"""Print summary statistics"""
total = len(results)
found = sum(1 for r in results if r["linkedin_url"])
not_found = total - found
# Count by method
methods = {}
for r in results:
if r["linkedin_url"]:
method = r["method"]
methods[method] = methods.get(method, 0) + 1
# Average confidence for found profiles
avg_confidence = (
sum(r["confidence"] for r in results if r["linkedin_url"]) / found
if found > 0
else 0
)
print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)
print(f"Total processed: {total}")
print(f"LinkedIn found: {found} ({found / total * 100:.1f}%)")
print(f"Not found: {not_found} ({not_found / total * 100:.1f}%)")
print(f"\nAverage confidence: {avg_confidence:.1f}%")
print("\nMethods used:")
for method, count in sorted(methods.items(), key=lambda x: x[1], reverse=True):
print(f" {method:20s} {count:5d} ({count / found * 100:.1f}%)")
print("=" * 60)
def main():
parser = argparse.ArgumentParser(
description="Update LinkedIn profiles for investor members"
)
parser.add_argument(
"--test",
action="store_true",
help="Test mode: process only 10 records without updating database",
)
parser.add_argument("--limit", type=int, help="Limit number of records to process")
parser.add_argument(
"--skip-existing",
action="store_true",
help="Skip members that already have LinkedIn URLs",
)
parser.add_argument(
"--start-from",
type=int,
default=0,
help="Start from record N (for resuming interrupted runs)",
)
parser.add_argument(
"--rate-limit",
type=float,
default=0.5,
help="Delay between URL crawls in seconds (default: 0.5)",
)
args = parser.parse_args()
# Test mode overrides limit
if args.test and not args.limit:
args.limit = 10
print("=" * 60)
print("LinkedIn Profile Scraper for Investor Members (crawl4ai)")
print("=" * 60)
if args.test:
print("\n⚠️ TEST MODE - No database changes will be made")
# Initialize database and scraper
db = get_db_session()
try:
# Build query
query = db.query(InvestorMember, InvestorTable).join(
InvestorTable, InvestorMember.investor_id == InvestorTable.id
)
# Filter existing if requested
if args.skip_existing:
query = query.filter(
(InvestorMember.linkedin.is_(None)) | (InvestorMember.linkedin == "")
)
print("\n✓ Filtering to members without LinkedIn profiles")
# Get total count
total_available = query.count()
print(f"\n✓ Found {total_available} members to process")
# Apply offset and limit
if args.start_from > 0:
query = query.offset(args.start_from)
print(f"✓ Starting from record {args.start_from}")
if args.limit:
query = query.limit(args.limit)
print(f"✓ Processing {args.limit} records")
# Fetch members
members_data = []
for member, investor in query.all():
members_data.append(
{
"id": member.id,
"name": member.name,
"company": investor.name,
"role": member.role,
"source_url": member.source_url,
}
)
if not members_data:
print("\n⚠️ No members to process")
return
# Count unique source URLs
unique_urls = len(set(m["source_url"] for m in members_data if m["source_url"]))
with_urls = sum(1 for m in members_data if m["source_url"])
print(f"\n✓ Loaded {len(members_data)} members")
print(
f"{with_urls} members have source URLs ({unique_urls} unique pages to crawl)"
)
print(f"{len(members_data) - with_urls} members without source URLs")
print(f"✓ Rate limit: {args.rate_limit}s between page crawls")
print("\nStarting LinkedIn profile search using crawl4ai...\n")
finally:
db.close()
# Initialize scraper
scraper = LinkedInProfileScraper(rate_limit_delay=args.rate_limit, use_cache=True)
print("️ Using crawl4ai to scrape team pages and extract LinkedIn URLs")
print(
"️ Profiles are saved to database IMMEDIATELY when found - safe to stop anytime!\n"
)
# Create database callback for real-time saving
db_callback, saved_count = create_db_callback(test_mode=args.test)
# Process members asynchronously with real-time DB saving
results = asyncio.run(
scraper.batch_find_profiles(
members_data, progress_callback=progress_callback, db_callback=db_callback
)
)
# Print summary
print_summary(results)
# Save results
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
results_file = f"linkedin_results_{timestamp}.json"
save_results(results, results_file)
# Show database update summary
if not args.test:
print(
f"\n✓ Database updated in real-time: {saved_count['count']} profiles saved"
)
else:
print(
f"\n[TEST MODE] Would have saved {saved_count['count']} profiles to database"
)
print("\n✓ Done! You can resume anytime with --skip-existing")
if __name__ == "__main__":
main()