Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c5c94936f3 |
+2
-5
@@ -10,11 +10,8 @@
|
||||
|
||||
*__pycache__
|
||||
|
||||
/*.db
|
||||
|
||||
*.cypython
|
||||
|
||||
nohup.out
|
||||
|
||||
server.log
|
||||
|
||||
server.pid
|
||||
/preprocessor
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -9,10 +9,6 @@ from sqlalchemy.orm import Session, sessionmaker
|
||||
Base = declarative_base()
|
||||
|
||||
# Database configuration
|
||||
# Use the preprocessor's database for consistency
|
||||
# Get absolute path to the preprocessor database
|
||||
# APP_DIR = Path(__file__).parent.parent
|
||||
# PREPROCESSOR_DB = APP_DIR.parent / "preprocessor" / "version_two.db"
|
||||
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./investors.db")
|
||||
|
||||
# Create engine
|
||||
@@ -42,7 +38,6 @@ def get_session_sync() -> Session:
|
||||
"""Get a database session for synchronous operations"""
|
||||
return SessionLocal()
|
||||
|
||||
|
||||
def get_db_session():
|
||||
"""Get a database session for direct use."""
|
||||
return SessionLocal()
|
||||
|
||||
+10
-143
@@ -2,7 +2,7 @@ import enum
|
||||
|
||||
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, Table, Text, func
|
||||
from sqlalchemy.orm import declarative_mixin, relationship
|
||||
from sqlalchemy.types import JSON, Enum
|
||||
from sqlalchemy.types import Enum
|
||||
|
||||
from db.db import Base
|
||||
|
||||
@@ -70,22 +70,6 @@ project_company_association = Table(
|
||||
Column("company_id", Integer, ForeignKey("companies.id")),
|
||||
)
|
||||
|
||||
# Association table for fund-stage many-to-many
|
||||
fund_investment_stages_association = Table(
|
||||
"fund_investment_stages",
|
||||
Base.metadata,
|
||||
Column("fund_id", Integer, ForeignKey("funds.id")),
|
||||
Column("stage_id", Integer, ForeignKey("investment_stages.id")),
|
||||
)
|
||||
|
||||
# Association table for fund-sector many-to-many
|
||||
fund_sectors_association = Table(
|
||||
"fund_sectors",
|
||||
Base.metadata,
|
||||
Column("fund_id", Integer, ForeignKey("funds.id")),
|
||||
Column("sector_id", Integer, ForeignKey("sectors.id")),
|
||||
)
|
||||
|
||||
|
||||
class InvestorTable(Base, TimestampMixin):
|
||||
__tablename__ = "investors"
|
||||
@@ -93,47 +77,14 @@ class InvestorTable(Base, TimestampMixin):
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, nullable=False)
|
||||
description = Column(Text, nullable=True)
|
||||
|
||||
# Basic investor info
|
||||
website = Column(String, nullable=True)
|
||||
headquarters = Column(String, nullable=True)
|
||||
|
||||
# AUM fields
|
||||
aum = Column(Integer, nullable=True) # Store as integer for numerical filtering
|
||||
aum_as_of_date = Column(String, nullable=True)
|
||||
aum_source_url = Column(String, nullable=True)
|
||||
|
||||
# Check size (deprecated in favor of fund-level data, but keeping for backward compatibility)
|
||||
check_size_lower = Column(Integer, nullable=True)
|
||||
check_size_upper = Column(Integer, nullable=True)
|
||||
|
||||
# Geographic focus (deprecated in favor of fund-level, but keeping for backward compatibility)
|
||||
aum = Column(Integer, nullable=True) # Assets Under Management
|
||||
check_size_lower = Column(Integer, nullable=True) # Lower bound
|
||||
check_size_upper = Column(Integer, nullable=True) # Upper bound
|
||||
geographic_focus = Column(String, nullable=True)
|
||||
|
||||
# Investment thesis and portfolio
|
||||
investment_thesis = Column(JSON, nullable=True) # Array of thesis statements
|
||||
portfolio_highlights = Column(
|
||||
JSON, nullable=True
|
||||
) # Array of portfolio company names
|
||||
linked_documents = Column(JSON, nullable=True) # Array of document URLs
|
||||
|
||||
# Research metadata
|
||||
researcher_notes = Column(Text, nullable=True)
|
||||
missing_important_fields = Column(
|
||||
JSON, nullable=True
|
||||
) # Array of missing field names
|
||||
sources = Column(JSON, nullable=True) # JSON object with source URLs
|
||||
|
||||
# Portfolio info
|
||||
stage_focus = Column(Enum(InvestmentStage), nullable=True)
|
||||
number_of_investments = Column(Integer, default=0, nullable=True)
|
||||
|
||||
# Relationships
|
||||
team_members = relationship(
|
||||
"InvestorMember", back_populates="investor", cascade="all, delete-orphan"
|
||||
)
|
||||
funds = relationship(
|
||||
"FundTable", back_populates="investor", cascade="all, delete-orphan"
|
||||
)
|
||||
team_members = relationship("InvestorMember", back_populates="investor")
|
||||
|
||||
# Relationship to portfolio companies
|
||||
portfolio_companies = relationship(
|
||||
@@ -160,52 +111,12 @@ class InvestorMember(Base, TimestampMixin):
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, nullable=False)
|
||||
role = Column(String, nullable=True)
|
||||
title = Column(String, nullable=True) # Alternative to role
|
||||
email = Column(String, nullable=True)
|
||||
linkedin = Column(String, nullable=True) # LinkedIn profile URL
|
||||
source_url = Column(String, nullable=True) # URL where member info was found
|
||||
|
||||
investor_id = Column(Integer, ForeignKey("investors.id"))
|
||||
investor = relationship("InvestorTable", back_populates="team_members")
|
||||
|
||||
|
||||
class FundTable(Base, TimestampMixin):
|
||||
__tablename__ = "funds"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
investor_id = Column(Integer, ForeignKey("investors.id"), nullable=False)
|
||||
|
||||
# Fund details
|
||||
fund_name = Column(String, nullable=True)
|
||||
fund_size = Column(
|
||||
Integer, nullable=True
|
||||
) # Store as integer for numerical filtering
|
||||
fund_size_source_url = Column(String, nullable=True)
|
||||
|
||||
# Check size range (parsed from estimated_investment_size by LLM)
|
||||
check_size_lower = Column(Integer, nullable=True)
|
||||
check_size_upper = Column(Integer, nullable=True)
|
||||
|
||||
source_url = Column(String, nullable=True)
|
||||
source_provider = Column(String, nullable=True) # e.g., "Perplexity"
|
||||
|
||||
# Geographic focus as simple string
|
||||
geographic_focus = Column(String, nullable=True)
|
||||
|
||||
# Relationships
|
||||
investor = relationship("InvestorTable", back_populates="funds")
|
||||
investment_stages = relationship(
|
||||
"InvestmentStageTable",
|
||||
secondary=fund_investment_stages_association,
|
||||
back_populates="funds",
|
||||
)
|
||||
sectors = relationship(
|
||||
"SectorTable",
|
||||
secondary=fund_sectors_association,
|
||||
back_populates="funds",
|
||||
)
|
||||
|
||||
|
||||
class CompanyTable(Base, TimestampMixin):
|
||||
__tablename__ = "companies"
|
||||
|
||||
@@ -216,12 +127,8 @@ class CompanyTable(Base, TimestampMixin):
|
||||
description = Column(String, nullable=True)
|
||||
founded_year = Column(Integer, nullable=True)
|
||||
website = Column(String, nullable=True)
|
||||
product_service = Column(Text, nullable=True) # Product/service description
|
||||
clients = Column(JSON, nullable=True) # List of client names or client information
|
||||
|
||||
members = relationship(
|
||||
"CompanyMember", back_populates="company", cascade="all, delete-orphan"
|
||||
)
|
||||
members = relationship("CompanyMember", back_populates="company")
|
||||
# Relationship back to investors
|
||||
investors = relationship(
|
||||
"InvestorTable",
|
||||
@@ -251,43 +158,26 @@ class CompanyMember(Base, TimestampMixin):
|
||||
company = relationship("CompanyTable", back_populates="members")
|
||||
|
||||
|
||||
class InvestmentStageTable(Base, TimestampMixin):
|
||||
__tablename__ = "investment_stages"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, nullable=False, unique=True)
|
||||
|
||||
# Relationships
|
||||
funds = relationship(
|
||||
"FundTable",
|
||||
secondary=fund_investment_stages_association,
|
||||
back_populates="investment_stages",
|
||||
)
|
||||
|
||||
|
||||
class SectorTable(Base, TimestampMixin):
|
||||
__tablename__ = "sectors"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, nullable=False)
|
||||
|
||||
# Relationships
|
||||
# Add relationship back to investors
|
||||
investors = relationship(
|
||||
"InvestorTable",
|
||||
secondary=investor_sector_association,
|
||||
back_populates="sectors",
|
||||
)
|
||||
|
||||
companies = relationship(
|
||||
"CompanyTable", secondary=company_sector_association, back_populates="sectors"
|
||||
)
|
||||
|
||||
projects = relationship(
|
||||
"ProjectTable", secondary=project_sector_association, back_populates="sector"
|
||||
)
|
||||
funds = relationship(
|
||||
"FundTable",
|
||||
secondary=fund_sectors_association,
|
||||
back_populates="sectors",
|
||||
)
|
||||
|
||||
|
||||
class ProjectTable(Base, TimestampMixin):
|
||||
@@ -299,11 +189,9 @@ class ProjectTable(Base, TimestampMixin):
|
||||
|
||||
stage = Column(Enum(InvestmentStage), nullable=True)
|
||||
location = Column(String, nullable=True)
|
||||
industry = Column(String, nullable=True)
|
||||
description = Column(Text, nullable=True)
|
||||
start_date = Column(DateTime, nullable=True)
|
||||
end_date = Column(DateTime, nullable=True)
|
||||
is_archived = Column(Integer, default=0, nullable=False) # 0 = active, 1 = archived
|
||||
|
||||
sector = relationship(
|
||||
"SectorTable", secondary=project_sector_association, back_populates="projects"
|
||||
@@ -316,24 +204,3 @@ class ProjectTable(Base, TimestampMixin):
|
||||
companies = relationship(
|
||||
"CompanyTable", secondary=project_company_association, back_populates="projects"
|
||||
)
|
||||
|
||||
|
||||
class InvestorInsightCache(Base, TimestampMixin):
|
||||
__tablename__ = "investor_insight_cache"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
investor_id = Column(
|
||||
Integer, ForeignKey("investors.id"), nullable=False, unique=True
|
||||
)
|
||||
|
||||
# Cached insights
|
||||
investment_pattern_analysis = Column(Text, nullable=False)
|
||||
market_position = Column(Text, nullable=False)
|
||||
|
||||
# Cache management
|
||||
last_refreshed = Column(
|
||||
DateTime(timezone=True), server_default=func.now(), nullable=False
|
||||
)
|
||||
|
||||
# Relationship to investor
|
||||
investor = relationship("InvestorTable")
|
||||
|
||||
@@ -1,730 +0,0 @@
|
||||
"""
|
||||
LinkedIn Profile Scraper for Investor Members
|
||||
|
||||
This module uses crawl4ai to scrape team pages and find LinkedIn profiles.
|
||||
Strategies:
|
||||
1. Crawl the source_url (team pages) to extract LinkedIn profile links
|
||||
2. Use LLM-powered web search to find LinkedIn profiles by name
|
||||
|
||||
Key advantages of crawl4ai:
|
||||
- Handles JavaScript-rendered pages
|
||||
- Better at extracting content from modern websites
|
||||
- More reliable than simple requests
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from crawl4ai import AsyncWebCrawler
|
||||
from ddgs import DDGS
|
||||
from dotenv import load_dotenv
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
# Setup logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s"
|
||||
)
|
||||
logger = logging.getLogger("linkedin_scraper")
|
||||
|
||||
load_dotenv()
|
||||
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
|
||||
|
||||
|
||||
class LinkedInProfileScraper:
|
||||
"""
|
||||
LinkedIn profile finder using crawl4ai and LLM-powered web search.
|
||||
|
||||
Strategies:
|
||||
1. Crawl source URLs (team pages) to extract LinkedIn links
|
||||
2. Use LLM-powered web search to find profiles by name
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rate_limit_delay: float = 0.5,
|
||||
use_cache: bool = True,
|
||||
use_llm_search: bool = True,
|
||||
):
|
||||
"""
|
||||
Initialize the scraper
|
||||
|
||||
Args:
|
||||
rate_limit_delay: Delay between requests in seconds
|
||||
use_cache: Whether to cache crawled pages
|
||||
use_llm_search: Whether to use LLM-powered web search as fallback
|
||||
"""
|
||||
self.rate_limit_delay = rate_limit_delay
|
||||
self.use_cache = use_cache
|
||||
self.use_llm_search = use_llm_search and OPENROUTER_API_KEY
|
||||
self.page_cache: Dict[str, str] = {} # Cache crawled pages by URL
|
||||
self.html_cache: Dict[str, str] = {} # Cache HTML separately
|
||||
self.profile_cache: Dict[str, Dict] = {} # Cache results by member
|
||||
|
||||
# Initialize LLM agent if API key available
|
||||
if self.use_llm_search:
|
||||
self._init_llm_agent()
|
||||
else:
|
||||
self.llm = None
|
||||
self.agent = None
|
||||
self.ddg_search = None
|
||||
logger.info("LLM search disabled (no OPENROUTER_API_KEY)")
|
||||
|
||||
def _init_llm_agent(self):
|
||||
"""Initialize LLM agent for web search"""
|
||||
try:
|
||||
self.llm = ChatOpenAI(
|
||||
api_key=OPENROUTER_API_KEY,
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="x-ai/grok-4.1-fast:free",
|
||||
temperature=0,
|
||||
)
|
||||
self.ddg_search = DDGS()
|
||||
logger.info("LLM search agent initialized")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to initialize LLM agent: {e}")
|
||||
self.llm = None
|
||||
self.ddg_search = None
|
||||
|
||||
def web_search(self, query: str) -> List[Dict]:
|
||||
"""Tool to search the web using DuckDuckGo"""
|
||||
if not self.ddg_search:
|
||||
return []
|
||||
try:
|
||||
results = list(self.ddg_search.text(query, max_results=10))
|
||||
return results
|
||||
except Exception as e:
|
||||
logger.error(f"Web search error: {e}")
|
||||
return []
|
||||
|
||||
async def crawl_page(self, url: str) -> Optional[str]:
|
||||
"""
|
||||
Crawl a webpage and return its content.
|
||||
|
||||
Args:
|
||||
url: URL to crawl
|
||||
|
||||
Returns:
|
||||
Page content as markdown/text, or None if failed
|
||||
"""
|
||||
if not url:
|
||||
return None
|
||||
|
||||
# Check cache first
|
||||
if self.use_cache and url in self.page_cache:
|
||||
logger.debug(f"Using cached page for {url}")
|
||||
return self.page_cache[url]
|
||||
|
||||
try:
|
||||
logger.info(f"Crawling: {url}")
|
||||
async with AsyncWebCrawler() as crawler:
|
||||
result = await crawler.arun(url)
|
||||
|
||||
if result and result.markdown:
|
||||
content = result.markdown
|
||||
# Also get HTML for better link extraction
|
||||
html_content = result.html if hasattr(result, "html") else ""
|
||||
|
||||
# Cache the results
|
||||
if self.use_cache:
|
||||
self.page_cache[url] = content
|
||||
self.html_cache[url] = html_content
|
||||
|
||||
return content
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error crawling {url}: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def extract_linkedin_urls_from_content(self, content: str) -> List[Dict[str, str]]:
|
||||
"""
|
||||
Extract all LinkedIn profile URLs from content (HTML or markdown).
|
||||
|
||||
Returns:
|
||||
List of dicts with 'url', 'context', and 'username'
|
||||
"""
|
||||
linkedin_links = []
|
||||
|
||||
# Pattern for LinkedIn profile URLs (handles country-specific domains)
|
||||
linkedin_pattern = (
|
||||
r"https?://(?:www\.)?(?:[a-z]{2,3}\.)?linkedin\.com/in/([a-zA-Z0-9_-]+)/?"
|
||||
)
|
||||
|
||||
# Find all LinkedIn URLs
|
||||
matches = list(re.finditer(linkedin_pattern, content, re.IGNORECASE))
|
||||
|
||||
for match in matches:
|
||||
url = match.group(0).rstrip("/")
|
||||
# Normalize URL
|
||||
url = self._normalize_linkedin_url(url)
|
||||
|
||||
# Get surrounding context (200 chars before and after)
|
||||
start = max(0, match.start() - 200)
|
||||
end = min(len(content), match.end() + 200)
|
||||
context = content[start:end]
|
||||
|
||||
# Clean up context (remove HTML tags for readability)
|
||||
context = re.sub(r"<[^>]+>", " ", context)
|
||||
context = " ".join(context.split()) # Normalize whitespace
|
||||
|
||||
linkedin_links.append(
|
||||
{"url": url, "context": context, "username": match.group(1)}
|
||||
)
|
||||
|
||||
# Remove duplicates while preserving order
|
||||
seen_urls = set()
|
||||
unique_links = []
|
||||
for link in linkedin_links:
|
||||
if link["url"] not in seen_urls:
|
||||
seen_urls.add(link["url"])
|
||||
unique_links.append(link)
|
||||
|
||||
return unique_links
|
||||
|
||||
def _normalize_linkedin_url(self, url: str) -> str:
|
||||
"""Normalize LinkedIn URL to standard format"""
|
||||
# Remove trailing slashes
|
||||
url = url.rstrip("/")
|
||||
|
||||
# Convert country-specific to www
|
||||
url = re.sub(
|
||||
r"https?://[a-z]{2,3}\.linkedin\.com", "https://www.linkedin.com", url
|
||||
)
|
||||
|
||||
# Ensure https
|
||||
if url.startswith("http://"):
|
||||
url = url.replace("http://", "https://")
|
||||
|
||||
return url
|
||||
|
||||
def _name_matches_context(self, name: str, context: str) -> float:
|
||||
"""
|
||||
Check if a person's name appears in the context around a LinkedIn URL.
|
||||
|
||||
Returns:
|
||||
Confidence score 0-100
|
||||
"""
|
||||
if not name or not context:
|
||||
return 0
|
||||
|
||||
context_lower = context.lower()
|
||||
name_lower = name.lower()
|
||||
|
||||
# Split name into parts (handle multiple spaces, titles like "Dr.", etc.)
|
||||
name_parts = [p for p in name_lower.replace(".", " ").split() if len(p) > 1]
|
||||
|
||||
# Check for full name match
|
||||
if name_lower in context_lower:
|
||||
return 95
|
||||
|
||||
# Check for name parts in context
|
||||
matches = sum(
|
||||
1 for part in name_parts if part in context_lower and len(part) > 2
|
||||
)
|
||||
|
||||
if len(name_parts) > 0:
|
||||
if matches == len(name_parts):
|
||||
return 90 # All name parts found
|
||||
elif matches >= 2:
|
||||
return 75 # At least 2 parts found (first + last typically)
|
||||
elif matches == 1 and len(name_parts) <= 2:
|
||||
return 50 # Only one part found but name is short
|
||||
elif matches == 1:
|
||||
return 35 # Only one part found
|
||||
|
||||
return 0
|
||||
|
||||
def _name_matches_username(self, name: str, username: str) -> float:
|
||||
"""
|
||||
Check if LinkedIn username contains parts of the name.
|
||||
|
||||
Returns:
|
||||
Confidence score 0-100
|
||||
"""
|
||||
if not name or not username:
|
||||
return 0
|
||||
|
||||
name_lower = name.lower()
|
||||
username_lower = username.lower().replace("-", " ").replace("_", " ")
|
||||
|
||||
name_parts = [p for p in name_lower.replace(".", " ").split() if len(p) > 2]
|
||||
|
||||
matches = sum(1 for part in name_parts if part in username_lower)
|
||||
|
||||
if len(name_parts) > 0:
|
||||
if matches == len(name_parts) and len(name_parts) >= 2:
|
||||
return 85 # Full name in username
|
||||
elif matches >= 2:
|
||||
return 70 # Multiple parts match
|
||||
elif matches == 1:
|
||||
return 35 # Only one part matches
|
||||
|
||||
return 0
|
||||
|
||||
async def find_linkedin_from_source(
|
||||
self, name: str, source_url: str, role: Optional[str] = None
|
||||
) -> Dict:
|
||||
"""
|
||||
Find LinkedIn profile by crawling the source URL (team page).
|
||||
|
||||
Args:
|
||||
name: Person's name
|
||||
source_url: URL of the team/about page
|
||||
role: Person's role (for additional context matching)
|
||||
|
||||
Returns:
|
||||
Dict with linkedin_url, confidence, method, notes
|
||||
"""
|
||||
if not source_url:
|
||||
return {
|
||||
"linkedin_url": None,
|
||||
"confidence": 0,
|
||||
"method": "source_crawl",
|
||||
"notes": "No source URL provided",
|
||||
}
|
||||
|
||||
# Crawl the page
|
||||
content = await self.crawl_page(source_url)
|
||||
|
||||
if not content:
|
||||
return {
|
||||
"linkedin_url": None,
|
||||
"confidence": 0,
|
||||
"method": "source_crawl",
|
||||
"notes": f"Failed to crawl {source_url}",
|
||||
}
|
||||
|
||||
# Get HTML for better link extraction
|
||||
html = self.html_cache.get(source_url, content)
|
||||
|
||||
# Extract all LinkedIn URLs from both HTML and markdown
|
||||
linkedin_links = self.extract_linkedin_urls_from_content(html)
|
||||
if not linkedin_links:
|
||||
linkedin_links = self.extract_linkedin_urls_from_content(content)
|
||||
|
||||
if not linkedin_links:
|
||||
return {
|
||||
"linkedin_url": None,
|
||||
"confidence": 0,
|
||||
"method": "source_crawl",
|
||||
"notes": f"No LinkedIn URLs found on {source_url}",
|
||||
}
|
||||
|
||||
# Score each LinkedIn URL based on name matching
|
||||
best_match = None
|
||||
best_score = 0
|
||||
|
||||
for link in linkedin_links:
|
||||
# Score based on context matching
|
||||
context_score = self._name_matches_context(name, link["context"])
|
||||
|
||||
# Score based on username matching
|
||||
username_score = self._name_matches_username(name, link["username"])
|
||||
|
||||
# Also check if role appears in context
|
||||
role_bonus = 0
|
||||
if role and role.lower() in link["context"].lower():
|
||||
role_bonus = 10
|
||||
|
||||
# Combined score (take best of context or username, plus role bonus)
|
||||
total_score = max(context_score, username_score) + role_bonus
|
||||
|
||||
logger.debug(
|
||||
f" {name} -> {link['url']}: context={context_score}, username={username_score}, role={role_bonus}, total={total_score}"
|
||||
)
|
||||
|
||||
if total_score > best_score:
|
||||
best_score = total_score
|
||||
best_match = link
|
||||
|
||||
if best_match and best_score >= 30: # Minimum threshold
|
||||
return {
|
||||
"linkedin_url": best_match["url"],
|
||||
"confidence": min(best_score, 100),
|
||||
"method": "source_crawl",
|
||||
"notes": f"Found on {source_url}",
|
||||
}
|
||||
|
||||
return {
|
||||
"linkedin_url": None,
|
||||
"confidence": 0,
|
||||
"method": "source_crawl",
|
||||
"notes": f'No matching LinkedIn profile found for "{name}" on {source_url}',
|
||||
}
|
||||
|
||||
async def find_linkedin_via_search(
|
||||
self, name: str, company: str, role: Optional[str] = None
|
||||
) -> Dict:
|
||||
"""
|
||||
Find LinkedIn profile using web search.
|
||||
|
||||
Args:
|
||||
name: Person's name
|
||||
company: Company/investor name
|
||||
role: Person's role (optional)
|
||||
|
||||
Returns:
|
||||
Dict with linkedin_url, confidence, method, notes
|
||||
"""
|
||||
if not self.ddg_search:
|
||||
return {
|
||||
"linkedin_url": None,
|
||||
"confidence": 0,
|
||||
"method": "web_search",
|
||||
"notes": "Web search not available",
|
||||
}
|
||||
|
||||
try:
|
||||
# Build search query - search for LinkedIn profile
|
||||
query = f"{name} {company} site:linkedin.com/in"
|
||||
if role:
|
||||
query = f"{name} {role} {company} site:linkedin.com/in"
|
||||
|
||||
logger.debug(f"Searching: {query}")
|
||||
results = self.web_search(query)
|
||||
|
||||
if results:
|
||||
# Look for LinkedIn profile URLs in results
|
||||
linkedin_pattern = r"https?://(?:www\.)?(?:[a-z]{2,3}\.)?linkedin\.com/in/([a-zA-Z0-9_-]+)"
|
||||
|
||||
for result in results:
|
||||
url = result.get("href") or result.get("link") or ""
|
||||
title = result.get("title", "").lower()
|
||||
body = result.get("body", "").lower()
|
||||
|
||||
match = re.search(linkedin_pattern, url, re.IGNORECASE)
|
||||
if match:
|
||||
linkedin_url = self._normalize_linkedin_url(match.group(0))
|
||||
username = match.group(1)
|
||||
|
||||
# Score based on name matching in title/body and username
|
||||
context = f"{title} {body}"
|
||||
context_score = self._name_matches_context(name, context)
|
||||
username_score = self._name_matches_username(name, username)
|
||||
|
||||
total_score = max(context_score, username_score)
|
||||
|
||||
if total_score >= 30:
|
||||
return {
|
||||
"linkedin_url": linkedin_url,
|
||||
"confidence": min(
|
||||
total_score, 90
|
||||
), # Cap at 90 for search results
|
||||
"method": "web_search",
|
||||
"notes": "Found via web search",
|
||||
}
|
||||
|
||||
return {
|
||||
"linkedin_url": None,
|
||||
"confidence": 0,
|
||||
"method": "web_search",
|
||||
"notes": "No matching profile found in search results",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Web search error for {name}: {e}")
|
||||
return {
|
||||
"linkedin_url": None,
|
||||
"confidence": 0,
|
||||
"method": "web_search",
|
||||
"notes": f"Search error: {str(e)}",
|
||||
}
|
||||
|
||||
async def find_linkedin_profile(
|
||||
self,
|
||||
name: str,
|
||||
company: str,
|
||||
role: Optional[str] = None,
|
||||
source_url: Optional[str] = None,
|
||||
) -> Dict:
|
||||
"""
|
||||
Find LinkedIn profile for a person.
|
||||
|
||||
Primary strategy: Crawl source URL to find LinkedIn links.
|
||||
|
||||
Args:
|
||||
name: Person's name
|
||||
company: Company/investor name
|
||||
role: Person's role/title (optional)
|
||||
source_url: URL where person info was found (optional)
|
||||
|
||||
Returns:
|
||||
Dict with:
|
||||
- linkedin_url: Found LinkedIn URL or None
|
||||
- confidence: Confidence score (0-100)
|
||||
- method: Method used to find the profile
|
||||
- notes: Additional information
|
||||
"""
|
||||
cache_key = f"{name}|{company}"
|
||||
|
||||
# Check cache
|
||||
if self.use_cache and cache_key in self.profile_cache:
|
||||
logger.debug(f"Using cached result for {name}")
|
||||
return self.profile_cache[cache_key]
|
||||
|
||||
result = {"linkedin_url": None, "confidence": 0, "method": "none", "notes": ""}
|
||||
|
||||
# Primary strategy: Crawl source URL
|
||||
if source_url:
|
||||
result = await self.find_linkedin_from_source(name, source_url, role)
|
||||
|
||||
if result["linkedin_url"]:
|
||||
if self.use_cache:
|
||||
self.profile_cache[cache_key] = result
|
||||
return result
|
||||
|
||||
# Fallback strategy: Web search (if enabled and no result from source crawl)
|
||||
if self.use_llm_search and not result.get("linkedin_url"):
|
||||
search_result = await self.find_linkedin_via_search(name, company, role)
|
||||
if search_result["linkedin_url"]:
|
||||
if self.use_cache:
|
||||
self.profile_cache[cache_key] = search_result
|
||||
return search_result
|
||||
|
||||
# If no source URL or no match found
|
||||
if not result["linkedin_url"]:
|
||||
result = {
|
||||
"linkedin_url": None,
|
||||
"confidence": 0,
|
||||
"method": "none",
|
||||
"notes": "No source URL available"
|
||||
if not source_url
|
||||
else result.get("notes", "Not found"),
|
||||
}
|
||||
|
||||
if self.use_cache:
|
||||
self.profile_cache[cache_key] = result
|
||||
|
||||
return result
|
||||
|
||||
async def batch_find_profiles(
|
||||
self, members: List[Dict], progress_callback=None, db_callback=None
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Find LinkedIn profiles for multiple members efficiently.
|
||||
|
||||
Groups members by source_url to minimize crawling the same page multiple times.
|
||||
|
||||
Args:
|
||||
members: List of dicts with 'name', 'company', 'role', 'source_url', 'id'
|
||||
progress_callback: Optional callback function(current, total, result)
|
||||
db_callback: Optional callback to save to database immediately when profile found
|
||||
Signature: db_callback(member_id, linkedin_url) -> bool
|
||||
|
||||
Returns:
|
||||
List of results for each member
|
||||
"""
|
||||
results = []
|
||||
total = len(members)
|
||||
|
||||
# Group members by source_url for efficient crawling
|
||||
url_groups: Dict[str, List[Dict]] = {}
|
||||
no_url_members = []
|
||||
|
||||
for member in members:
|
||||
url = member.get("source_url")
|
||||
if url:
|
||||
if url not in url_groups:
|
||||
url_groups[url] = []
|
||||
url_groups[url].append(member)
|
||||
else:
|
||||
no_url_members.append(member)
|
||||
|
||||
logger.info(
|
||||
f"Processing {len(url_groups)} unique source URLs for {total} members"
|
||||
)
|
||||
logger.info(f"Members with source URLs: {total - len(no_url_members)}")
|
||||
logger.info(f"Members without source URLs: {len(no_url_members)}")
|
||||
if self.use_llm_search:
|
||||
logger.info("Web search fallback: ENABLED")
|
||||
else:
|
||||
logger.info("Web search fallback: DISABLED")
|
||||
|
||||
processed = 0
|
||||
|
||||
# Process members grouped by URL (efficient - one crawl per page)
|
||||
for url, group_members in url_groups.items():
|
||||
# Crawl the page once
|
||||
content = await self.crawl_page(url)
|
||||
html = self.html_cache.get(url, content or "")
|
||||
|
||||
# Extract all LinkedIn URLs from this page
|
||||
linkedin_links = []
|
||||
if content:
|
||||
linkedin_links = self.extract_linkedin_urls_from_content(html)
|
||||
if not linkedin_links:
|
||||
linkedin_links = self.extract_linkedin_urls_from_content(content)
|
||||
|
||||
# Match each member in this group
|
||||
for member in group_members:
|
||||
processed += 1
|
||||
result = None
|
||||
found_linkedin = False
|
||||
|
||||
if linkedin_links:
|
||||
# Find best matching LinkedIn for this member
|
||||
best_match = None
|
||||
best_score = 0
|
||||
|
||||
for link in linkedin_links:
|
||||
context_score = self._name_matches_context(
|
||||
member["name"], link["context"]
|
||||
)
|
||||
username_score = self._name_matches_username(
|
||||
member["name"], link["username"]
|
||||
)
|
||||
role_bonus = (
|
||||
10
|
||||
if member.get("role")
|
||||
and member["role"].lower() in link["context"].lower()
|
||||
else 0
|
||||
)
|
||||
total_score = max(context_score, username_score) + role_bonus
|
||||
|
||||
if total_score > best_score:
|
||||
best_score = total_score
|
||||
best_match = link
|
||||
|
||||
if best_match and best_score >= 30:
|
||||
result = {
|
||||
"linkedin_url": best_match["url"],
|
||||
"confidence": min(best_score, 100),
|
||||
"method": "source_crawl",
|
||||
"notes": f"Found on {url}",
|
||||
"member_id": member.get("id"),
|
||||
"member_name": member["name"],
|
||||
}
|
||||
found_linkedin = True
|
||||
# Save to database immediately if callback provided
|
||||
if db_callback and member.get("id"):
|
||||
db_callback(member["id"], best_match["url"])
|
||||
|
||||
# If no result from source crawl, try web search IMMEDIATELY
|
||||
if not found_linkedin and self.use_llm_search:
|
||||
search_result = await self.find_linkedin_via_search(
|
||||
member["name"], member["company"], member.get("role")
|
||||
)
|
||||
|
||||
if search_result["linkedin_url"]:
|
||||
result = {
|
||||
"linkedin_url": search_result["linkedin_url"],
|
||||
"confidence": search_result["confidence"],
|
||||
"method": "web_search",
|
||||
"notes": search_result.get("notes", "Found via web search"),
|
||||
"member_id": member.get("id"),
|
||||
"member_name": member["name"],
|
||||
}
|
||||
found_linkedin = True
|
||||
# Save to database immediately
|
||||
if db_callback and member.get("id"):
|
||||
db_callback(member["id"], search_result["linkedin_url"])
|
||||
|
||||
# If still no result, record as not found
|
||||
if not found_linkedin:
|
||||
result = {
|
||||
"linkedin_url": None,
|
||||
"confidence": 0,
|
||||
"method": "source_crawl" if content else "none",
|
||||
"notes": f"No match on {url}"
|
||||
if linkedin_links
|
||||
else (
|
||||
f"No LinkedIn URLs on {url}"
|
||||
if content
|
||||
else f"Failed to crawl {url}"
|
||||
),
|
||||
"member_id": member.get("id"),
|
||||
"member_name": member["name"],
|
||||
}
|
||||
|
||||
results.append(result)
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(processed, total, result)
|
||||
|
||||
# Small delay between different URLs
|
||||
await asyncio.sleep(self.rate_limit_delay)
|
||||
|
||||
# Process members without source URLs - do web search immediately for each
|
||||
for member in no_url_members:
|
||||
processed += 1
|
||||
result = None
|
||||
|
||||
# Try web search immediately
|
||||
if self.use_llm_search:
|
||||
search_result = await self.find_linkedin_via_search(
|
||||
member["name"], member["company"], member.get("role")
|
||||
)
|
||||
|
||||
if search_result["linkedin_url"]:
|
||||
result = {
|
||||
"linkedin_url": search_result["linkedin_url"],
|
||||
"confidence": search_result["confidence"],
|
||||
"method": "web_search",
|
||||
"notes": search_result.get("notes", "Found via web search"),
|
||||
"member_id": member.get("id"),
|
||||
"member_name": member["name"],
|
||||
}
|
||||
# Save to database immediately
|
||||
if db_callback and member.get("id"):
|
||||
db_callback(member["id"], search_result["linkedin_url"])
|
||||
|
||||
# If no result from search
|
||||
if not result:
|
||||
result = {
|
||||
"linkedin_url": None,
|
||||
"confidence": 0,
|
||||
"method": "web_search" if self.use_llm_search else "none",
|
||||
"notes": "No LinkedIn profile found"
|
||||
if self.use_llm_search
|
||||
else "No source URL available",
|
||||
"member_id": member.get("id"),
|
||||
"member_name": member["name"],
|
||||
}
|
||||
|
||||
results.append(result)
|
||||
|
||||
if progress_callback:
|
||||
progress_callback(processed, total, result)
|
||||
|
||||
# Rate limit between searches
|
||||
await asyncio.sleep(self.rate_limit_delay)
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def format_linkedin_url(url: str) -> str:
|
||||
"""Normalize LinkedIn URL format"""
|
||||
if not url:
|
||||
return url
|
||||
|
||||
# Remove trailing slashes
|
||||
url = url.rstrip("/")
|
||||
|
||||
# Ensure https and normalize to www
|
||||
url = re.sub(r"https?://[a-z]{2,3}\.linkedin\.com", "https://www.linkedin.com", url)
|
||||
if url.startswith("http://"):
|
||||
url = url.replace("http://", "https://")
|
||||
|
||||
return url
|
||||
|
||||
|
||||
# Async wrapper for sync contexts
|
||||
def run_batch_scraper(
|
||||
members: List[Dict], rate_limit: float = 0.5, progress_callback=None
|
||||
) -> List[Dict]:
|
||||
"""
|
||||
Synchronous wrapper for batch_find_profiles.
|
||||
|
||||
Args:
|
||||
members: List of member dicts
|
||||
rate_limit: Delay between URL crawls
|
||||
progress_callback: Optional progress callback
|
||||
|
||||
Returns:
|
||||
List of results
|
||||
"""
|
||||
scraper = LinkedInProfileScraper(rate_limit_delay=rate_limit)
|
||||
return asyncio.run(scraper.batch_find_profiles(members, progress_callback))
|
||||
+21
-98
@@ -1,23 +1,12 @@
|
||||
import io
|
||||
import logging
|
||||
|
||||
import pandas as pd
|
||||
from db.db import Base, db_dependency, engine
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi import FastAPI, File, Form, UploadFile
|
||||
from pydantic import BaseModel
|
||||
from routers import (
|
||||
addition,
|
||||
companies,
|
||||
folk_crm,
|
||||
insight_route,
|
||||
investors,
|
||||
projects,
|
||||
report_route,
|
||||
)
|
||||
from schemas.router_schemas import CompanyData, InvestmentResponse, PaginatedResponse
|
||||
from services.company_querying import CompanyQueryProcessor
|
||||
from routers import companies, investors, projects
|
||||
from schemas.router_schemas import InvestorList
|
||||
from services.llm_parser import InvestorProcessor
|
||||
from services.querying import QueryProcessor
|
||||
|
||||
@@ -29,21 +18,10 @@ def init_database():
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
init_database()
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
# Add CORS middleware to allow frontend requests
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"], # In production, replace with specific origins
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
# Request models
|
||||
class QueryRequest(BaseModel):
|
||||
@@ -57,17 +35,6 @@ class QueryRequest(BaseModel):
|
||||
}
|
||||
|
||||
|
||||
class CompanyQueryRequest(BaseModel):
|
||||
question: str
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"question": "Find me companies in the fintech sector located in San Francisco."
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@app.get("/")
|
||||
def health():
|
||||
return {"Hello": "World"}
|
||||
@@ -77,29 +44,6 @@ def health():
|
||||
async def parse_csv(
|
||||
db: db_dependency, file: UploadFile = File(...), is_investor: int = Form(...)
|
||||
):
|
||||
"""
|
||||
Parse and import CSV data into the database.
|
||||
|
||||
**For investors:**
|
||||
- Expected columns: Name, Website, Final Investor Profile, Final Profile sourcing
|
||||
- Manually parses JSON profiles for efficiency
|
||||
- Uses LLM only for currency conversion to USD
|
||||
- Handles AUM, fund sizes, and check sizes as integers
|
||||
|
||||
**For companies:**
|
||||
- Expected columns: Name, Website, Perplexity Gap Output (or Final Investor Profile)
|
||||
- 100% manual JSON parsing - no LLM needed
|
||||
- **Only extracts:** founded_year and key_executives
|
||||
- **Only updates companies already in the database** (syncs with existing records)
|
||||
- Skips companies not found in the database
|
||||
|
||||
**Benefits:**
|
||||
- Fast processing (5-10s per record)
|
||||
- Low cost (minimal or no LLM usage)
|
||||
- Accurate data extraction
|
||||
- Automatic database persistence
|
||||
- Safe: won't create duplicate companies
|
||||
"""
|
||||
# Read uploaded CSV with pandas
|
||||
content = await file.read()
|
||||
df = pd.read_csv(io.StringIO(content.decode("utf-8")))
|
||||
@@ -108,56 +52,35 @@ async def parse_csv(
|
||||
processor = InvestorProcessor()
|
||||
|
||||
if is_investor == 1:
|
||||
# Manual parser with LLM currency conversion
|
||||
results = await processor.parse_investors(df, save_to_db=True)
|
||||
# Results are already dicts from the new parser
|
||||
return results
|
||||
results = await processor.parse_investors(df)
|
||||
else:
|
||||
# Manual parser for companies (no LLM needed)
|
||||
results = await processor.parse_companies(df, save_to_db=True)
|
||||
# Results are already dicts from the new parser
|
||||
return results
|
||||
results = await processor.parse_companies(df)
|
||||
|
||||
# Convert Pydantic objects to dictionaries
|
||||
return [r.model_dump() for r in results]
|
||||
|
||||
|
||||
@app.post(
|
||||
"/query", response_model=PaginatedResponse[InvestmentResponse], tags=["Querying"]
|
||||
)
|
||||
@app.post("/query", response_model=InvestorList, tags=["Querying"])
|
||||
async def query_investors(request: QueryRequest):
|
||||
"""Query investors/funds using natural language"""
|
||||
try:
|
||||
"""
|
||||
Query investors using natural language.
|
||||
|
||||
Supports queries like:
|
||||
- "Show me seed stage investors"
|
||||
- "Find fintech investors in Silicon Valley"
|
||||
- "Growth stage investors with $5M+ check sizes"
|
||||
- "Healthcare investors in Europe"
|
||||
"""
|
||||
processor = QueryProcessor()
|
||||
result = await processor.process_query(request.question)
|
||||
logger.info(f"Query completed successfully with {result.total} results")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error in query_investors: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.post(
|
||||
"/query-companies", response_model=PaginatedResponse[CompanyData], tags=["Querying"]
|
||||
)
|
||||
async def query_companies(request: CompanyQueryRequest):
|
||||
"""Query companies using natural language"""
|
||||
try:
|
||||
processor = CompanyQueryProcessor()
|
||||
result = await processor.process_query(request.question)
|
||||
logger.info(f"Company query completed successfully with {result.total} results")
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.error(f"Error in query_companies: {e}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
results = processor.process_query(request.question)
|
||||
return results
|
||||
|
||||
|
||||
app.include_router(investors.router)
|
||||
app.include_router(companies.router)
|
||||
app.include_router(projects.router)
|
||||
app.include_router(folk_crm.router)
|
||||
app.include_router(insight_route.router)
|
||||
app.include_router(report_route.router)
|
||||
app.include_router(addition.router)
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
|
||||
uvicorn.run(app="main:app", host="0.0.0.0", port=8585)
|
||||
uvicorn.run(app="main:app", host="0.0.0.0", port=8585, reload=True)
|
||||
|
||||
Binary file not shown.
Binary file not shown.
@@ -1,370 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
from db.db import get_db
|
||||
from db.models import FundTable, InvestorTable, SectorTable
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
router = APIRouter(tags=["Additional Routes"])
|
||||
|
||||
|
||||
# Response schemas
|
||||
class SectorsResponse(BaseModel):
|
||||
sectors: list[str]
|
||||
total: int
|
||||
|
||||
|
||||
class CountryInfo(BaseModel):
|
||||
name: str
|
||||
|
||||
|
||||
class ContinentInfo(BaseModel):
|
||||
name: str
|
||||
countries: list[str]
|
||||
|
||||
|
||||
class GeographyResponse(BaseModel):
|
||||
continents: list[ContinentInfo]
|
||||
total_continents: int
|
||||
total_countries: int
|
||||
|
||||
|
||||
# Mapping of countries to continents
|
||||
COUNTRY_TO_CONTINENT = {
|
||||
# Africa
|
||||
"Algeria": "Africa",
|
||||
"Angola": "Africa",
|
||||
"Benin": "Africa",
|
||||
"Botswana": "Africa",
|
||||
"Burkina Faso": "Africa",
|
||||
"Burundi": "Africa",
|
||||
"Cameroon": "Africa",
|
||||
"Cape Verde": "Africa",
|
||||
"Central African Republic": "Africa",
|
||||
"Chad": "Africa",
|
||||
"Comoros": "Africa",
|
||||
"Congo": "Africa",
|
||||
"Democratic Republic of the Congo": "Africa",
|
||||
"Djibouti": "Africa",
|
||||
"Egypt": "Africa",
|
||||
"Equatorial Guinea": "Africa",
|
||||
"Eritrea": "Africa",
|
||||
"Eswatini": "Africa",
|
||||
"Ethiopia": "Africa",
|
||||
"Gabon": "Africa",
|
||||
"Gambia": "Africa",
|
||||
"Ghana": "Africa",
|
||||
"Guinea": "Africa",
|
||||
"Guinea-Bissau": "Africa",
|
||||
"Ivory Coast": "Africa",
|
||||
"Kenya": "Africa",
|
||||
"Lesotho": "Africa",
|
||||
"Liberia": "Africa",
|
||||
"Libya": "Africa",
|
||||
"Madagascar": "Africa",
|
||||
"Malawi": "Africa",
|
||||
"Mali": "Africa",
|
||||
"Mauritania": "Africa",
|
||||
"Mauritius": "Africa",
|
||||
"Morocco": "Africa",
|
||||
"Mozambique": "Africa",
|
||||
"Namibia": "Africa",
|
||||
"Niger": "Africa",
|
||||
"Nigeria": "Africa",
|
||||
"Rwanda": "Africa",
|
||||
"Sao Tome and Principe": "Africa",
|
||||
"Senegal": "Africa",
|
||||
"Seychelles": "Africa",
|
||||
"Sierra Leone": "Africa",
|
||||
"Somalia": "Africa",
|
||||
"South Africa": "Africa",
|
||||
"South Sudan": "Africa",
|
||||
"Sudan": "Africa",
|
||||
"Tanzania": "Africa",
|
||||
"Togo": "Africa",
|
||||
"Tunisia": "Africa",
|
||||
"Uganda": "Africa",
|
||||
"Zambia": "Africa",
|
||||
"Zimbabwe": "Africa",
|
||||
# Asia
|
||||
"Afghanistan": "Asia",
|
||||
"Armenia": "Asia",
|
||||
"Azerbaijan": "Asia",
|
||||
"Bahrain": "Asia",
|
||||
"Bangladesh": "Asia",
|
||||
"Bhutan": "Asia",
|
||||
"Brunei": "Asia",
|
||||
"Cambodia": "Asia",
|
||||
"China": "Asia",
|
||||
"Cyprus": "Asia",
|
||||
"Georgia": "Asia",
|
||||
"Hong Kong": "Asia",
|
||||
"India": "Asia",
|
||||
"Indonesia": "Asia",
|
||||
"Iran": "Asia",
|
||||
"Iraq": "Asia",
|
||||
"Israel": "Asia",
|
||||
"Japan": "Asia",
|
||||
"Jordan": "Asia",
|
||||
"Kazakhstan": "Asia",
|
||||
"Kuwait": "Asia",
|
||||
"Kyrgyzstan": "Asia",
|
||||
"Laos": "Asia",
|
||||
"Lebanon": "Asia",
|
||||
"Malaysia": "Asia",
|
||||
"Maldives": "Asia",
|
||||
"Mongolia": "Asia",
|
||||
"Myanmar": "Asia",
|
||||
"Nepal": "Asia",
|
||||
"North Korea": "Asia",
|
||||
"Oman": "Asia",
|
||||
"Pakistan": "Asia",
|
||||
"Palestine": "Asia",
|
||||
"Philippines": "Asia",
|
||||
"Qatar": "Asia",
|
||||
"Saudi Arabia": "Asia",
|
||||
"Singapore": "Asia",
|
||||
"South Korea": "Asia",
|
||||
"Sri Lanka": "Asia",
|
||||
"Syria": "Asia",
|
||||
"Taiwan": "Asia",
|
||||
"Tajikistan": "Asia",
|
||||
"Thailand": "Asia",
|
||||
"Timor-Leste": "Asia",
|
||||
"Turkey": "Asia",
|
||||
"Turkmenistan": "Asia",
|
||||
"United Arab Emirates": "Asia",
|
||||
"UAE": "Asia",
|
||||
"Uzbekistan": "Asia",
|
||||
"Vietnam": "Asia",
|
||||
"Yemen": "Asia",
|
||||
# Europe
|
||||
"Albania": "Europe",
|
||||
"Andorra": "Europe",
|
||||
"Austria": "Europe",
|
||||
"Belarus": "Europe",
|
||||
"Belgium": "Europe",
|
||||
"Bosnia and Herzegovina": "Europe",
|
||||
"Bulgaria": "Europe",
|
||||
"Croatia": "Europe",
|
||||
"Czech Republic": "Europe",
|
||||
"Czechia": "Europe",
|
||||
"Denmark": "Europe",
|
||||
"Estonia": "Europe",
|
||||
"Finland": "Europe",
|
||||
"France": "Europe",
|
||||
"Germany": "Europe",
|
||||
"Greece": "Europe",
|
||||
"Hungary": "Europe",
|
||||
"Iceland": "Europe",
|
||||
"Ireland": "Europe",
|
||||
"Italy": "Europe",
|
||||
"Kosovo": "Europe",
|
||||
"Latvia": "Europe",
|
||||
"Liechtenstein": "Europe",
|
||||
"Lithuania": "Europe",
|
||||
"Luxembourg": "Europe",
|
||||
"Malta": "Europe",
|
||||
"Moldova": "Europe",
|
||||
"Monaco": "Europe",
|
||||
"Montenegro": "Europe",
|
||||
"Netherlands": "Europe",
|
||||
"North Macedonia": "Europe",
|
||||
"Norway": "Europe",
|
||||
"Poland": "Europe",
|
||||
"Portugal": "Europe",
|
||||
"Romania": "Europe",
|
||||
"Russia": "Europe",
|
||||
"San Marino": "Europe",
|
||||
"Serbia": "Europe",
|
||||
"Slovakia": "Europe",
|
||||
"Slovenia": "Europe",
|
||||
"Spain": "Europe",
|
||||
"Sweden": "Europe",
|
||||
"Switzerland": "Europe",
|
||||
"Ukraine": "Europe",
|
||||
"United Kingdom": "Europe",
|
||||
"UK": "Europe",
|
||||
"Vatican City": "Europe",
|
||||
# North America
|
||||
"Antigua and Barbuda": "North America",
|
||||
"Bahamas": "North America",
|
||||
"Barbados": "North America",
|
||||
"Belize": "North America",
|
||||
"Canada": "North America",
|
||||
"Costa Rica": "North America",
|
||||
"Cuba": "North America",
|
||||
"Dominica": "North America",
|
||||
"Dominican Republic": "North America",
|
||||
"El Salvador": "North America",
|
||||
"Grenada": "North America",
|
||||
"Guatemala": "North America",
|
||||
"Haiti": "North America",
|
||||
"Honduras": "North America",
|
||||
"Jamaica": "North America",
|
||||
"Mexico": "North America",
|
||||
"Nicaragua": "North America",
|
||||
"Panama": "North America",
|
||||
"Saint Kitts and Nevis": "North America",
|
||||
"Saint Lucia": "North America",
|
||||
"Saint Vincent and the Grenadines": "North America",
|
||||
"Trinidad and Tobago": "North America",
|
||||
"United States": "North America",
|
||||
"USA": "North America",
|
||||
"US": "North America",
|
||||
# South America
|
||||
"Argentina": "South America",
|
||||
"Bolivia": "South America",
|
||||
"Brazil": "South America",
|
||||
"Chile": "South America",
|
||||
"Colombia": "South America",
|
||||
"Ecuador": "South America",
|
||||
"Guyana": "South America",
|
||||
"Paraguay": "South America",
|
||||
"Peru": "South America",
|
||||
"Suriname": "South America",
|
||||
"Uruguay": "South America",
|
||||
"Venezuela": "South America",
|
||||
# Oceania
|
||||
"Australia": "Oceania",
|
||||
"Fiji": "Oceania",
|
||||
"Kiribati": "Oceania",
|
||||
"Marshall Islands": "Oceania",
|
||||
"Micronesia": "Oceania",
|
||||
"Nauru": "Oceania",
|
||||
"New Zealand": "Oceania",
|
||||
"Palau": "Oceania",
|
||||
"Papua New Guinea": "Oceania",
|
||||
"Samoa": "Oceania",
|
||||
"Solomon Islands": "Oceania",
|
||||
"Tonga": "Oceania",
|
||||
"Tuvalu": "Oceania",
|
||||
"Vanuatu": "Oceania",
|
||||
}
|
||||
|
||||
# Valid continent names for direct matching
|
||||
VALID_CONTINENTS = {
|
||||
"Africa",
|
||||
"Asia",
|
||||
"Europe",
|
||||
"North America",
|
||||
"South America",
|
||||
"Oceania",
|
||||
"Antarctica",
|
||||
}
|
||||
|
||||
|
||||
def extract_countries_from_geographic_focus(geographic_focus: str) -> set[str]:
|
||||
"""
|
||||
Extract country names from a geographic_focus string.
|
||||
Handles comma-separated values, slashes, and various formats.
|
||||
"""
|
||||
if not geographic_focus:
|
||||
return set()
|
||||
|
||||
countries = set()
|
||||
# Split by common delimiters
|
||||
parts = geographic_focus.replace("/", ",").replace(";", ",").split(",")
|
||||
|
||||
for part in parts:
|
||||
cleaned = part.strip()
|
||||
if cleaned:
|
||||
# Check if it's a known country
|
||||
if cleaned in COUNTRY_TO_CONTINENT:
|
||||
countries.add(cleaned)
|
||||
# Check for partial matches (e.g., "United States of America" -> "United States")
|
||||
else:
|
||||
for country in COUNTRY_TO_CONTINENT.keys():
|
||||
if country.lower() in cleaned.lower() or cleaned.lower() in country.lower():
|
||||
countries.add(country)
|
||||
break
|
||||
|
||||
return countries
|
||||
|
||||
|
||||
def organize_geography(geographic_data: list[str]) -> dict[str, set[str]]:
|
||||
"""
|
||||
Organize geographic data into continents and their countries.
|
||||
Returns a dict with continent names as keys and sets of countries as values.
|
||||
"""
|
||||
continent_countries: dict[str, set[str]] = {}
|
||||
|
||||
for geo_focus in geographic_data:
|
||||
if not geo_focus:
|
||||
continue
|
||||
|
||||
# Extract countries from the geographic focus string
|
||||
countries = extract_countries_from_geographic_focus(geo_focus)
|
||||
|
||||
for country in countries:
|
||||
continent = COUNTRY_TO_CONTINENT.get(country)
|
||||
if continent:
|
||||
if continent not in continent_countries:
|
||||
continent_countries[continent] = set()
|
||||
continent_countries[continent].add(country)
|
||||
|
||||
# Also check if the geographic focus itself is a continent
|
||||
cleaned_geo = geo_focus.strip()
|
||||
if cleaned_geo in VALID_CONTINENTS:
|
||||
if cleaned_geo not in continent_countries:
|
||||
continent_countries[cleaned_geo] = set()
|
||||
|
||||
return continent_countries
|
||||
|
||||
|
||||
@router.get("/sectors", response_model=SectorsResponse)
|
||||
def get_unique_sectors(db: Session = Depends(get_db)):
|
||||
"""
|
||||
Get all unique sectors from the database.
|
||||
Returns a list of sector names sorted alphabetically.
|
||||
"""
|
||||
sectors = db.query(SectorTable.name).distinct().order_by(SectorTable.name).all()
|
||||
sector_names = [s[0] for s in sectors if s[0]]
|
||||
|
||||
return SectorsResponse(sectors=sector_names, total=len(sector_names))
|
||||
|
||||
|
||||
@router.get("/geography", response_model=GeographyResponse)
|
||||
def get_arranged_geography(db: Session = Depends(get_db)):
|
||||
"""
|
||||
Get all unique geographic locations arranged by continent and countries.
|
||||
Extracts geography from both investors and funds tables.
|
||||
Returns continents with their associated countries.
|
||||
"""
|
||||
# Collect all geographic focus data from investors
|
||||
investor_geo = (
|
||||
db.query(InvestorTable.geographic_focus)
|
||||
.filter(InvestorTable.geographic_focus.isnot(None))
|
||||
.distinct()
|
||||
.all()
|
||||
)
|
||||
|
||||
# Collect all geographic focus data from funds
|
||||
fund_geo = (
|
||||
db.query(FundTable.geographic_focus)
|
||||
.filter(FundTable.geographic_focus.isnot(None))
|
||||
.distinct()
|
||||
.all()
|
||||
)
|
||||
|
||||
# Combine all geographic data
|
||||
all_geo_data = [g[0] for g in investor_geo] + [g[0] for g in fund_geo]
|
||||
|
||||
# Organize into continents and countries
|
||||
continent_countries = organize_geography(all_geo_data)
|
||||
|
||||
# Build response
|
||||
continents = []
|
||||
total_countries = 0
|
||||
|
||||
for continent_name in sorted(continent_countries.keys()):
|
||||
countries = sorted(continent_countries[continent_name])
|
||||
total_countries += len(countries)
|
||||
continents.append(ContinentInfo(name=continent_name, countries=countries))
|
||||
|
||||
return GeographyResponse(
|
||||
continents=continents,
|
||||
total_continents=len(continents),
|
||||
total_countries=total_countries,
|
||||
)
|
||||
+18
-67
@@ -1,10 +1,10 @@
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from db.db import get_db
|
||||
from db.models import CompanyTable, InvestorTable
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from schemas.router_schemas import CompanyData, PaginatedResponse
|
||||
from schemas.router_schemas import CompanyData
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
router = APIRouter(tags=["Company Routes"])
|
||||
@@ -29,63 +29,38 @@ class CompanyUpdate(BaseModel):
|
||||
website: Optional[str] = None
|
||||
|
||||
|
||||
@router.get("/companies", response_model=PaginatedResponse[CompanyData])
|
||||
def read_companies(
|
||||
page: int = Query(1, ge=1, description="Page number (starts at 1)"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get all companies with their investor relationships (paginated)"""
|
||||
# Calculate offset
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# Get total count
|
||||
total_count = (
|
||||
db.query(CompanyTable)
|
||||
.filter(CompanyTable.name.isnot(None), CompanyTable.description.isnot(None))
|
||||
.count()
|
||||
)
|
||||
|
||||
# Get paginated results
|
||||
@router.get("/companies", response_model=List[CompanyData])
|
||||
def read_companies(db: Session = Depends(get_db)):
|
||||
"""Get all companies with their investor relationships"""
|
||||
companies = (
|
||||
db.query(CompanyTable)
|
||||
.filter(CompanyTable.name.isnot(None), CompanyTable.description.isnot(None))
|
||||
.filter(
|
||||
CompanyTable.name.isnot(None),
|
||||
CompanyTable.description.isnot(None)
|
||||
)
|
||||
.options(
|
||||
selectinload(CompanyTable.investors),
|
||||
selectinload(CompanyTable.members),
|
||||
selectinload(CompanyTable.sectors),
|
||||
)
|
||||
.offset(offset)
|
||||
.limit(page_size)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Transform CompanyTable objects to CompanyData format
|
||||
company_data_list = []
|
||||
for company in companies:
|
||||
# Sort sectors alphabetically
|
||||
sorted_sectors = sorted(company.sectors, key=lambda s: s.name) if company.sectors else []
|
||||
company_data = CompanyData(
|
||||
company=company,
|
||||
investors=company.investors,
|
||||
members=company.members,
|
||||
sectors=sorted_sectors,
|
||||
sectors=company.sectors,
|
||||
)
|
||||
company_data_list.append(company_data)
|
||||
|
||||
# Calculate total pages
|
||||
total_pages = (total_count + page_size - 1) // page_size
|
||||
|
||||
return PaginatedResponse(
|
||||
items=company_data_list,
|
||||
total=total_count,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
return company_data_list
|
||||
|
||||
|
||||
@router.get("/companies/filter", response_model=PaginatedResponse[CompanyData])
|
||||
@router.get("/companies/filter", response_model=List[CompanyData])
|
||||
def filter_companies(
|
||||
industry: Optional[str] = Query(
|
||||
None, description="Filter by industry (partial match)"
|
||||
@@ -101,11 +76,9 @@ def filter_companies(
|
||||
investor_name: Optional[str] = Query(
|
||||
None, description="Filter by investor name (partial match)"
|
||||
),
|
||||
page: int = Query(1, ge=1, description="Page number (starts at 1)"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Filter companies based on various criteria (paginated)"""
|
||||
"""Filter companies based on various criteria"""
|
||||
|
||||
# Start with base query
|
||||
query = db.query(CompanyTable).options(
|
||||
@@ -139,36 +112,20 @@ def filter_companies(
|
||||
InvestorTable.name.ilike(f"%{investor_name}%")
|
||||
)
|
||||
|
||||
# Get total count before pagination
|
||||
total_count = query.count()
|
||||
|
||||
# Calculate offset and apply pagination
|
||||
offset = (page - 1) * page_size
|
||||
companies = query.offset(offset).limit(page_size).all()
|
||||
companies = query.all()
|
||||
|
||||
# Transform to CompanyData format
|
||||
company_data_list = []
|
||||
for company in companies:
|
||||
# Sort sectors alphabetically
|
||||
sorted_sectors = sorted(company.sectors, key=lambda s: s.name) if company.sectors else []
|
||||
company_data = CompanyData(
|
||||
company=company,
|
||||
investors=company.investors,
|
||||
members=company.members,
|
||||
sectors=sorted_sectors,
|
||||
sectors=company.sectors,
|
||||
)
|
||||
company_data_list.append(company_data)
|
||||
|
||||
# Calculate total pages
|
||||
total_pages = (total_count + page_size - 1) // page_size
|
||||
|
||||
return PaginatedResponse(
|
||||
items=company_data_list,
|
||||
total=total_count,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
return company_data_list
|
||||
|
||||
|
||||
@router.get("/companies/{company_id}", response_model=CompanyData)
|
||||
@@ -188,15 +145,12 @@ def read_company(company_id: int, db: Session = Depends(get_db)):
|
||||
if not company:
|
||||
raise HTTPException(status_code=404, detail="Company not found")
|
||||
|
||||
# Sort sectors alphabetically
|
||||
sorted_sectors = sorted(company.sectors, key=lambda s: s.name) if company.sectors else []
|
||||
|
||||
# Transform to CompanyData format
|
||||
return CompanyData(
|
||||
company=company,
|
||||
investors=company.investors,
|
||||
members=company.members,
|
||||
sectors=sorted_sectors,
|
||||
sectors=company.sectors,
|
||||
)
|
||||
|
||||
|
||||
@@ -257,15 +211,12 @@ def update_company(
|
||||
.first()
|
||||
)
|
||||
|
||||
# Sort sectors alphabetically
|
||||
sorted_sectors = sorted(company_with_relations.sectors, key=lambda s: s.name) if company_with_relations.sectors else []
|
||||
|
||||
# Transform to CompanyData format
|
||||
return CompanyData(
|
||||
company=company_with_relations,
|
||||
investors=company_with_relations.investors,
|
||||
members=company_with_relations.members,
|
||||
sectors=sorted_sectors,
|
||||
sectors=company_with_relations.sectors,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,204 +0,0 @@
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
from db.db import get_db
|
||||
from db.models import InvestorTable
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from services.crm import FolkAPI
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
router = APIRouter(prefix="/folk", tags=["Folk CRM"])
|
||||
|
||||
|
||||
def get_folk_client():
|
||||
"""Get Folk API client with loaded environment variables"""
|
||||
return FolkAPI(api_key=os.environ.get("FOLK_API_KEY", ""))
|
||||
|
||||
|
||||
class GroupResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
|
||||
|
||||
class SyncInvestorsRequest(BaseModel):
|
||||
investor_ids: List[int]
|
||||
group_id: str
|
||||
|
||||
|
||||
class SyncResult(BaseModel):
|
||||
investor_id: int
|
||||
investor_name: str
|
||||
company_id: str
|
||||
company_name: str
|
||||
team_members_synced: int
|
||||
person_ids: List[str]
|
||||
|
||||
|
||||
class SyncInvestorsResponse(BaseModel):
|
||||
success: bool
|
||||
synced_count: int
|
||||
results: List[SyncResult]
|
||||
errors: List[dict]
|
||||
|
||||
|
||||
@router.get("/groups", response_model=List[GroupResponse])
|
||||
def get_folk_groups():
|
||||
"""Get all groups from Folk CRM.
|
||||
|
||||
Returns a list of groups with their id and name that can be used
|
||||
to sync investors to Folk.
|
||||
"""
|
||||
try:
|
||||
folk = get_folk_client()
|
||||
groups_data = folk.get_groups()
|
||||
items = groups_data.get("data", {}).get("items", [])
|
||||
|
||||
return [GroupResponse(id=item["id"], name=item["name"]) for item in items]
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to fetch groups from Folk: {str(e)}"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/sync-investors", response_model=SyncInvestorsResponse)
|
||||
def sync_investors_to_folk(
|
||||
request: SyncInvestorsRequest, db: Session = Depends(get_db)
|
||||
):
|
||||
"""Sync investors to Folk CRM as companies with their team members as people.
|
||||
|
||||
Takes a list of investor IDs and a Folk group ID, then:
|
||||
1. Creates each investor as a company in the specified Folk group
|
||||
2. Creates each team member as a person linked to that company
|
||||
|
||||
Args:
|
||||
investor_ids: List of investor IDs from the database
|
||||
group_id: Folk group ID where investors should be added
|
||||
|
||||
Returns:
|
||||
Summary of sync operation including successes and errors
|
||||
"""
|
||||
folk = get_folk_client()
|
||||
# Fetch investors with their team members
|
||||
investors = (
|
||||
db.query(InvestorTable)
|
||||
.options(
|
||||
selectinload(InvestorTable.team_members),
|
||||
selectinload(InvestorTable.sectors),
|
||||
)
|
||||
.filter(InvestorTable.id.in_(request.investor_ids))
|
||||
.all()
|
||||
)
|
||||
|
||||
if not investors:
|
||||
raise HTTPException(
|
||||
status_code=404, detail="No investors found with the provided IDs"
|
||||
)
|
||||
|
||||
results = []
|
||||
errors = []
|
||||
|
||||
for investor in investors:
|
||||
try:
|
||||
# Create company in Folk
|
||||
company_data = folk.create_company(
|
||||
name=investor.name,
|
||||
group_id=request.group_id,
|
||||
website=investor.website,
|
||||
description=investor.description,
|
||||
addresses=[investor.headquarters] if investor.headquarters else None,
|
||||
)
|
||||
|
||||
company_id = company_data.get("data", {}).get("id")
|
||||
if not company_id:
|
||||
errors.append(
|
||||
{
|
||||
"investor_id": investor.id,
|
||||
"investor_name": investor.name,
|
||||
"error": "No company ID returned from Folk API",
|
||||
}
|
||||
)
|
||||
continue
|
||||
|
||||
# Create team members as people
|
||||
person_ids = []
|
||||
team_members_synced = 0
|
||||
|
||||
for member in investor.team_members:
|
||||
try:
|
||||
# Extract first name and last name from full name
|
||||
name_parts = member.name.split(maxsplit=1)
|
||||
first_name = name_parts[0] if name_parts else member.name
|
||||
last_name = name_parts[1] if len(name_parts) > 1 else ""
|
||||
|
||||
# Build URLs list from source_url if available
|
||||
urls_list = None
|
||||
if hasattr(member, "source_url") and member.source_url:
|
||||
urls_list = [member.source_url]
|
||||
|
||||
# Get LinkedIn URL if available
|
||||
linkedin_url = None
|
||||
if hasattr(member, "linkedin") and member.linkedin:
|
||||
linkedin_url = member.linkedin
|
||||
|
||||
# Build job title from title or role
|
||||
job_title = None
|
||||
if hasattr(member, "title") and member.title:
|
||||
job_title = member.title
|
||||
elif hasattr(member, "role") and member.role:
|
||||
job_title = member.role
|
||||
|
||||
person_data = folk.create_person(
|
||||
first_name=first_name,
|
||||
last_name=last_name,
|
||||
email=member.email,
|
||||
company_id=company_id,
|
||||
group_id=request.group_id,
|
||||
linkedin_url=linkedin_url,
|
||||
urls=urls_list,
|
||||
jobTitle=job_title,
|
||||
)
|
||||
|
||||
person_id = person_data.get("data", {}).get("id")
|
||||
if person_id:
|
||||
person_ids.append(person_id)
|
||||
team_members_synced += 1
|
||||
except Exception as person_error:
|
||||
# Log person creation error but continue with other members
|
||||
errors.append(
|
||||
{
|
||||
"investor_id": investor.id,
|
||||
"investor_name": investor.name,
|
||||
"team_member_name": member.name,
|
||||
"error": f"Failed to create person: {str(person_error)}",
|
||||
}
|
||||
)
|
||||
|
||||
results.append(
|
||||
SyncResult(
|
||||
investor_id=investor.id,
|
||||
investor_name=investor.name,
|
||||
company_id=company_id,
|
||||
company_name=company_data.get("data", {}).get(
|
||||
"name", investor.name
|
||||
),
|
||||
team_members_synced=team_members_synced,
|
||||
person_ids=person_ids,
|
||||
)
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
errors.append(
|
||||
{
|
||||
"investor_id": investor.id,
|
||||
"investor_name": investor.name,
|
||||
"error": str(e),
|
||||
}
|
||||
)
|
||||
|
||||
return SyncInvestorsResponse(
|
||||
success=len(results) > 0,
|
||||
synced_count=len(results),
|
||||
results=results,
|
||||
errors=errors,
|
||||
)
|
||||
@@ -1,122 +0,0 @@
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
from db.db import get_db
|
||||
from db.models import InvestorInsightCache, InvestorTable, ProjectTable
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from schemas.insight_schema import InsightResponse
|
||||
from services.compatibility_score import (
|
||||
calculate_project_investor_compatibility,
|
||||
generate_compatibility_explanation,
|
||||
)
|
||||
from services.insight import QueryProcessor
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get(
|
||||
"/insights/{investor_id}", response_model=InsightResponse, tags=["Insights"]
|
||||
)
|
||||
async def get_insights(
|
||||
investor_id: int, project_id: Optional[int] = None, db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Get investor insights including investment pattern analysis, market position,
|
||||
and optionally compatibility score with a project.
|
||||
|
||||
Args:
|
||||
investor_id: The ID of the investor to analyze
|
||||
project_id: Optional project ID to calculate compatibility score
|
||||
|
||||
Returns:
|
||||
InsightResponse with investment_pattern_analysis, market_position,
|
||||
and compatibility_score (if project_id provided)
|
||||
"""
|
||||
# Get investor from database
|
||||
investor = db.query(InvestorTable).filter(InvestorTable.id == investor_id).first()
|
||||
if not investor:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Investor with id {investor_id} not found"
|
||||
)
|
||||
|
||||
# Check if we have cached insights
|
||||
cached_insights = (
|
||||
db.query(InvestorInsightCache)
|
||||
.filter(InvestorInsightCache.investor_id == investor_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
# Determine if cache needs refresh (older than 1 month)
|
||||
needs_refresh = True
|
||||
if cached_insights:
|
||||
# Calculate if cache is older than 1 month
|
||||
cache_age = (
|
||||
datetime.now(cached_insights.last_refreshed.tzinfo)
|
||||
- cached_insights.last_refreshed
|
||||
)
|
||||
needs_refresh = cache_age > timedelta(days=30)
|
||||
|
||||
# Fetch new insights if needed
|
||||
if needs_refresh:
|
||||
# Initialize the query processor for insights
|
||||
query_processor = QueryProcessor()
|
||||
|
||||
# Get investment pattern analysis and market position using web search
|
||||
insights = await query_processor.get_investor_insights(
|
||||
investor_name=investor.name,
|
||||
investor_website=investor.website,
|
||||
investor_description=investor.description,
|
||||
investor_headquarters=investor.headquarters,
|
||||
investment_thesis=investor.investment_thesis,
|
||||
portfolio_highlights=investor.portfolio_highlights,
|
||||
)
|
||||
|
||||
# Update or create cache entry
|
||||
if cached_insights:
|
||||
# Update existing cache
|
||||
cached_insights.investment_pattern_analysis = insights[
|
||||
"investment_pattern_analysis"
|
||||
]
|
||||
cached_insights.market_position = insights["market_position"]
|
||||
cached_insights.last_refreshed = datetime.now(
|
||||
cached_insights.last_refreshed.tzinfo
|
||||
)
|
||||
else:
|
||||
# Create new cache entry
|
||||
cached_insights = InvestorInsightCache(
|
||||
investor_id=investor_id,
|
||||
investment_pattern_analysis=insights["investment_pattern_analysis"],
|
||||
market_position=insights["market_position"],
|
||||
)
|
||||
db.add(cached_insights)
|
||||
|
||||
db.commit()
|
||||
db.refresh(cached_insights)
|
||||
|
||||
# Calculate compatibility score if project_id is provided
|
||||
compatibility_score = None
|
||||
if project_id:
|
||||
project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
|
||||
if not project:
|
||||
raise HTTPException(
|
||||
status_code=404, detail=f"Project with id {project_id} not found"
|
||||
)
|
||||
|
||||
# Calculate the compatibility score
|
||||
score = calculate_project_investor_compatibility(
|
||||
project, investor, use_funds=True
|
||||
)
|
||||
|
||||
# Generate detailed explanation
|
||||
compatibility_score = generate_compatibility_explanation(
|
||||
project, investor, score, use_funds=True
|
||||
)
|
||||
else:
|
||||
compatibility_score = "Select a project to see compatibility analysis"
|
||||
|
||||
return InsightResponse(
|
||||
investment_pattern_analysis=cached_insights.investment_pattern_analysis,
|
||||
market_position=cached_insights.market_position,
|
||||
compatibility_score=compatibility_score,
|
||||
)
|
||||
+98
-478
@@ -1,21 +1,11 @@
|
||||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from db.db import get_db
|
||||
from db.models import FundTable, InvestorTable, ProjectTable, SectorTable
|
||||
from db.models import InvestorTable, SectorTable
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from schemas.router_schemas import (
|
||||
CompanyMinimal,
|
||||
InvestmentResponse,
|
||||
InvestmentStage,
|
||||
InvestorData,
|
||||
PaginatedResponse,
|
||||
SectorMinimal,
|
||||
)
|
||||
from services.compatibility_score import (
|
||||
_calculate_project_fund_compatibility,
|
||||
_calculate_project_investor_direct_compatibility,
|
||||
)
|
||||
from schemas.router_schemas import InvestmentStage, InvestorData
|
||||
from services.querying import QueryProcessor
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
router = APIRouter(tags=["Investor Routes"])
|
||||
@@ -25,189 +15,53 @@ router = APIRouter(tags=["Investor Routes"])
|
||||
class InvestorCreate(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
website: Optional[str] = None
|
||||
headquarters: Optional[str] = None
|
||||
aum: int
|
||||
check_size_lower: int
|
||||
check_size_upper: int
|
||||
geographic_focus: str
|
||||
stage_focus: InvestmentStage
|
||||
number_of_investments: int = 0
|
||||
|
||||
|
||||
class InvestorUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
website: Optional[str] = None
|
||||
headquarters: Optional[str] = None
|
||||
aum: Optional[int] = None
|
||||
check_size_lower: Optional[int] = None
|
||||
check_size_upper: Optional[int] = None
|
||||
geographic_focus: Optional[str] = None
|
||||
stage_focus: Optional[InvestmentStage] = None
|
||||
number_of_investments: Optional[int] = None
|
||||
|
||||
|
||||
@router.get("/investors", response_model=PaginatedResponse[InvestmentResponse])
|
||||
def read_investors(
|
||||
page: int = Query(1, ge=1, description="Page number (starts at 1)"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
|
||||
project_id: Optional[int] = Query(
|
||||
None, description="Optional project ID for compatibility scoring"
|
||||
),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get all investors with their funds as separate entries (paginated)
|
||||
|
||||
Each investor-fund combination is returned as a separate row.
|
||||
An investor with 3 funds will appear as 3 entries.
|
||||
|
||||
If project_id is provided, calculates compatibility scores for each investor.
|
||||
"""
|
||||
# Calculate offset
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# Get total count
|
||||
total_count = db.query(InvestorTable).count()
|
||||
|
||||
# Load project if project_id provided
|
||||
project = None
|
||||
if project_id is not None:
|
||||
project = (
|
||||
db.query(ProjectTable)
|
||||
.options(selectinload(ProjectTable.sector))
|
||||
.filter(ProjectTable.id == project_id)
|
||||
.first()
|
||||
)
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
# When project_id is provided, we need to get all investors first to sort by compatibility score
|
||||
# Otherwise, we can paginate at the database level
|
||||
if project is not None:
|
||||
# Get all investors (we'll sort by compatibility score, then paginate)
|
||||
all_investors = (
|
||||
db.query(InvestorTable)
|
||||
.options(
|
||||
selectinload(InvestorTable.portfolio_companies),
|
||||
selectinload(InvestorTable.team_members),
|
||||
selectinload(InvestorTable.sectors),
|
||||
selectinload(InvestorTable.funds).selectinload(
|
||||
FundTable.investment_stages
|
||||
),
|
||||
selectinload(InvestorTable.funds).selectinload(FundTable.sectors),
|
||||
)
|
||||
.all()
|
||||
)
|
||||
# We'll paginate after sorting by compatibility score
|
||||
investors = all_investors
|
||||
else:
|
||||
# Get paginated results (no sorting needed)
|
||||
@router.get("/investors", response_model=List[InvestorData])
|
||||
def read_investors(db: Session = Depends(get_db)):
|
||||
"""Get all investors with their related data"""
|
||||
investors = (
|
||||
db.query(InvestorTable)
|
||||
.options(
|
||||
selectinload(InvestorTable.portfolio_companies),
|
||||
selectinload(InvestorTable.team_members),
|
||||
selectinload(InvestorTable.sectors),
|
||||
selectinload(InvestorTable.funds).selectinload(
|
||||
FundTable.investment_stages
|
||||
),
|
||||
selectinload(InvestorTable.funds).selectinload(FundTable.sectors),
|
||||
)
|
||||
.offset(offset)
|
||||
.limit(page_size)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Transform to InvestmentResponse format (one row per investor-fund combination)
|
||||
investment_responses = []
|
||||
# Transform InvestorTable objects to InvestorData format
|
||||
investor_data_list = []
|
||||
for investor in investors:
|
||||
# Get top 3 portfolio companies (id and name only)
|
||||
portfolio_companies = [
|
||||
CompanyMinimal(id=company.id, name=company.name)
|
||||
for company in investor.portfolio_companies[:3]
|
||||
]
|
||||
|
||||
# If investor has funds, create one entry per fund
|
||||
if investor.funds:
|
||||
for fund in investor.funds:
|
||||
# Calculate compatibility score for this specific fund
|
||||
compatibility_score = 1.0
|
||||
if project is not None:
|
||||
compatibility_score = _calculate_project_fund_compatibility(
|
||||
project=project, fund=fund
|
||||
investor_data = InvestorData(
|
||||
investor=investor, # This maps to InvestorSchema
|
||||
portfolio_companies=investor.portfolio_companies,
|
||||
team_members=investor.team_members,
|
||||
sectors=investor.sectors,
|
||||
)
|
||||
investor_data_list.append(investor_data)
|
||||
|
||||
# Get stage focus as comma-separated string
|
||||
stage_focus = (
|
||||
", ".join([stage.name for stage in fund.investment_stages])
|
||||
if fund.investment_stages
|
||||
else None
|
||||
)
|
||||
|
||||
# Get top 3 sectors from fund (id and name only) - sorted alphabetically
|
||||
fund_sectors = [
|
||||
SectorMinimal(id=sector.id, name=sector.name)
|
||||
for sector in sorted(
|
||||
fund.sectors[:3] if fund.sectors else [], key=lambda s: s.name
|
||||
)
|
||||
]
|
||||
|
||||
investment_response = InvestmentResponse(
|
||||
id=investor.id,
|
||||
name=f"{investor.name} - {fund.fund_name}"
|
||||
if fund.fund_name
|
||||
else investor.name,
|
||||
aum=investor.aum,
|
||||
check_size_lower=fund.check_size_lower,
|
||||
check_size_upper=fund.check_size_upper,
|
||||
geographic_focus=fund.geographic_focus,
|
||||
stage_focus=stage_focus,
|
||||
portfolio_companies=portfolio_companies,
|
||||
sectors=fund_sectors,
|
||||
compatibility_score=compatibility_score,
|
||||
)
|
||||
investment_responses.append(investment_response)
|
||||
else:
|
||||
# If no funds, create one entry with null fund fields
|
||||
# Calculate compatibility using investor-level data
|
||||
compatibility_score = 1.0
|
||||
if project is not None:
|
||||
compatibility_score = _calculate_project_investor_direct_compatibility(
|
||||
project=project, investor=investor
|
||||
)
|
||||
|
||||
investment_response = InvestmentResponse(
|
||||
id=investor.id,
|
||||
name=investor.name,
|
||||
aum=investor.aum,
|
||||
check_size_lower=None,
|
||||
check_size_upper=None,
|
||||
geographic_focus=None,
|
||||
stage_focus=None,
|
||||
portfolio_companies=portfolio_companies,
|
||||
sectors=[],
|
||||
compatibility_score=compatibility_score,
|
||||
)
|
||||
investment_responses.append(investment_response)
|
||||
|
||||
# Sort by compatibility score (descending) when project_id is provided
|
||||
if project is not None:
|
||||
investment_responses.sort(key=lambda x: x.compatibility_score, reverse=True)
|
||||
# Apply pagination after sorting
|
||||
investment_responses = investment_responses[offset : offset + page_size]
|
||||
|
||||
# Calculate total pages
|
||||
total_pages = (total_count + page_size - 1) // page_size
|
||||
|
||||
return PaginatedResponse(
|
||||
items=investment_responses,
|
||||
total=total_count,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
return investor_data_list
|
||||
|
||||
|
||||
@router.get("/investors/filter", response_model=PaginatedResponse[InvestmentResponse])
|
||||
@router.get("/investors/filter", response_model=List[InvestorData])
|
||||
def filter_investors(
|
||||
stage: Optional[InvestmentStage] = Query(
|
||||
None, description="Filter by investment stage"
|
||||
@@ -220,161 +74,67 @@ def filter_investors(
|
||||
sector: Optional[str] = Query(None, description="Sector name (partial match)"),
|
||||
min_aum: Optional[int] = Query(None, description="Minimum AUM"),
|
||||
max_aum: Optional[int] = Query(None, description="Maximum AUM"),
|
||||
page: int = Query(1, ge=1, description="Page number (starts at 1)"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
|
||||
project_id: Optional[int] = Query(
|
||||
None, description="Optional project ID for compatibility scoring"
|
||||
),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Filter investors based on various criteria (paginated)
|
||||
"""Filter investors based on various criteria"""
|
||||
|
||||
Returns investor-fund combinations as separate rows.
|
||||
Queries the funds table to find matching funds.
|
||||
|
||||
If project_id is provided, calculates compatibility scores for each investor.
|
||||
"""
|
||||
|
||||
# Load project if project_id provided
|
||||
project = None
|
||||
if project_id is not None:
|
||||
project = (
|
||||
db.query(ProjectTable)
|
||||
.options(selectinload(ProjectTable.sector))
|
||||
.filter(ProjectTable.id == project_id)
|
||||
.first()
|
||||
)
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
# Start with base query on funds table
|
||||
query = db.query(FundTable).options(
|
||||
selectinload(FundTable.investor).selectinload(
|
||||
InvestorTable.portfolio_companies
|
||||
),
|
||||
selectinload(FundTable.investor).selectinload(InvestorTable.team_members),
|
||||
selectinload(FundTable.investor).selectinload(InvestorTable.sectors),
|
||||
selectinload(FundTable.investment_stages),
|
||||
selectinload(FundTable.sectors),
|
||||
# Start with base query
|
||||
query = db.query(InvestorTable).options(
|
||||
selectinload(InvestorTable.portfolio_companies),
|
||||
selectinload(InvestorTable.team_members),
|
||||
selectinload(InvestorTable.sectors),
|
||||
)
|
||||
|
||||
# Apply filters at fund level
|
||||
# Apply filters
|
||||
if stage:
|
||||
query = query.filter(InvestorTable.stage_focus == stage)
|
||||
|
||||
if min_check_size is not None:
|
||||
query = query.filter(FundTable.check_size_lower >= min_check_size)
|
||||
query = query.filter(InvestorTable.check_size_lower >= min_check_size)
|
||||
|
||||
if max_check_size is not None:
|
||||
query = query.filter(FundTable.check_size_upper <= max_check_size)
|
||||
query = query.filter(InvestorTable.check_size_upper <= max_check_size)
|
||||
|
||||
if geography:
|
||||
query = query.filter(FundTable.geographic_focus.ilike(f"%{geography}%"))
|
||||
query = query.filter(InvestorTable.geographic_focus.ilike(f"%{geography}%"))
|
||||
|
||||
# Apply filters at investor level (through relationship)
|
||||
if min_aum is not None:
|
||||
query = query.join(FundTable.investor).filter(InvestorTable.aum >= min_aum)
|
||||
query = query.filter(InvestorTable.aum >= min_aum)
|
||||
|
||||
if max_aum is not None:
|
||||
if min_aum is None: # Only join if not already joined
|
||||
query = query.join(FundTable.investor)
|
||||
query = query.filter(InvestorTable.aum <= max_aum)
|
||||
|
||||
# Filter by sector if provided (at fund level)
|
||||
# Filter by sector if provided
|
||||
if sector:
|
||||
query = query.join(FundTable.sectors).filter(
|
||||
query = query.join(InvestorTable.sectors).filter(
|
||||
SectorTable.name.ilike(f"%{sector}%")
|
||||
)
|
||||
|
||||
# Get total count before pagination
|
||||
total_count = query.count()
|
||||
investors = query.all()
|
||||
|
||||
# When project_id is provided, we need to get all funds first to sort by compatibility score
|
||||
# Otherwise, we can paginate at the database level
|
||||
if project is not None:
|
||||
# Get all funds (we'll sort by compatibility score, then paginate)
|
||||
all_funds = query.all()
|
||||
funds = all_funds
|
||||
else:
|
||||
# Calculate offset and apply pagination (no sorting needed)
|
||||
offset = (page - 1) * page_size
|
||||
funds = query.offset(offset).limit(page_size).all()
|
||||
|
||||
# Transform to InvestmentResponse format (one row per fund)
|
||||
investment_responses = []
|
||||
for fund in funds:
|
||||
investor = fund.investor
|
||||
|
||||
# Calculate compatibility score for this specific fund
|
||||
compatibility_score = 1.0
|
||||
if project is not None:
|
||||
compatibility_score = _calculate_project_fund_compatibility(
|
||||
project=project, fund=fund
|
||||
# Transform to InvestorData format
|
||||
investor_data_list = []
|
||||
for investor in investors:
|
||||
investor_data = InvestorData(
|
||||
investor=investor,
|
||||
portfolio_companies=investor.portfolio_companies,
|
||||
team_members=investor.team_members,
|
||||
sectors=investor.sectors,
|
||||
)
|
||||
investor_data_list.append(investor_data)
|
||||
|
||||
# Get top 3 portfolio companies (id and name only)
|
||||
portfolio_companies = [
|
||||
CompanyMinimal(id=company.id, name=company.name)
|
||||
for company in investor.portfolio_companies[:3]
|
||||
]
|
||||
|
||||
# Get stage focus as comma-separated string
|
||||
stage_focus = (
|
||||
", ".join([stage.name for stage in fund.investment_stages])
|
||||
if fund.investment_stages
|
||||
else None
|
||||
)
|
||||
|
||||
# Get top 3 sectors from fund (id and name only) - sorted alphabetically
|
||||
fund_sectors = [
|
||||
SectorMinimal(id=sector.id, name=sector.name)
|
||||
for sector in sorted(
|
||||
fund.sectors[:3] if fund.sectors else [], key=lambda s: s.name
|
||||
)
|
||||
]
|
||||
|
||||
investment_response = InvestmentResponse(
|
||||
id=investor.id,
|
||||
name=f"{investor.name} - {fund.fund_name}"
|
||||
if fund.fund_name
|
||||
else investor.name,
|
||||
aum=investor.aum,
|
||||
check_size_lower=fund.check_size_lower,
|
||||
check_size_upper=fund.check_size_upper,
|
||||
geographic_focus=fund.geographic_focus,
|
||||
stage_focus=stage_focus,
|
||||
portfolio_companies=portfolio_companies,
|
||||
sectors=fund_sectors,
|
||||
compatibility_score=compatibility_score,
|
||||
)
|
||||
investment_responses.append(investment_response)
|
||||
|
||||
# Sort by compatibility score (descending) when project_id is provided
|
||||
if project is not None:
|
||||
investment_responses.sort(key=lambda x: x.compatibility_score, reverse=True)
|
||||
# Apply pagination after sorting
|
||||
offset = (page - 1) * page_size
|
||||
investment_responses = investment_responses[offset : offset + page_size]
|
||||
|
||||
# Calculate total pages
|
||||
total_pages = (total_count + page_size - 1) // page_size
|
||||
|
||||
return PaginatedResponse(
|
||||
items=investment_responses,
|
||||
total=total_count,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
return investor_data_list
|
||||
|
||||
|
||||
@router.get("/investors/{investor_id}", response_model=InvestorData)
|
||||
def read_investor(investor_id: int, db: Session = Depends(get_db)):
|
||||
"""Get a specific investor by ID with all their funds"""
|
||||
"""Get a specific investor by ID"""
|
||||
investor = (
|
||||
db.query(InvestorTable)
|
||||
.options(
|
||||
selectinload(InvestorTable.portfolio_companies),
|
||||
selectinload(InvestorTable.team_members),
|
||||
selectinload(InvestorTable.sectors),
|
||||
selectinload(InvestorTable.funds),
|
||||
)
|
||||
.filter(InvestorTable.id == investor_id)
|
||||
.first()
|
||||
@@ -383,13 +143,12 @@ def read_investor(investor_id: int, db: Session = Depends(get_db)):
|
||||
if not investor:
|
||||
raise HTTPException(status_code=404, detail="Investor not found")
|
||||
|
||||
# Transform to InvestorData format (includes funds array)
|
||||
# Transform to InvestorData format
|
||||
return InvestorData(
|
||||
investor=investor,
|
||||
portfolio_companies=investor.portfolio_companies,
|
||||
team_members=investor.team_members,
|
||||
sectors=investor.sectors,
|
||||
funds=investor.funds,
|
||||
)
|
||||
|
||||
|
||||
@@ -408,7 +167,6 @@ def create_investor(investor: InvestorCreate, db: Session = Depends(get_db)):
|
||||
selectinload(InvestorTable.portfolio_companies),
|
||||
selectinload(InvestorTable.team_members),
|
||||
selectinload(InvestorTable.sectors),
|
||||
selectinload(InvestorTable.funds),
|
||||
)
|
||||
.filter(InvestorTable.id == db_investor.id)
|
||||
.first()
|
||||
@@ -420,91 +178,24 @@ def create_investor(investor: InvestorCreate, db: Session = Depends(get_db)):
|
||||
portfolio_companies=investor_with_relations.portfolio_companies,
|
||||
team_members=investor_with_relations.team_members,
|
||||
sectors=investor_with_relations.sectors,
|
||||
funds=investor_with_relations.funds,
|
||||
)
|
||||
|
||||
|
||||
@router.put("/investors/{investor_id}", response_model=InvestorData)
|
||||
def update_investor(
|
||||
investor_id: int, investor: InvestorUpdate, db: Session = Depends(get_db)
|
||||
):
|
||||
"""Update an existing investor"""
|
||||
db_investor = (
|
||||
db.query(InvestorTable).filter(InvestorTable.id == investor_id).first()
|
||||
)
|
||||
if not db_investor:
|
||||
raise HTTPException(status_code=404, detail="Investor not found")
|
||||
|
||||
update_data = investor.dict(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(db_investor, field, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_investor)
|
||||
|
||||
# Reload with relationships
|
||||
investor_with_relations = (
|
||||
db.query(InvestorTable)
|
||||
.options(
|
||||
selectinload(InvestorTable.portfolio_companies),
|
||||
selectinload(InvestorTable.team_members),
|
||||
selectinload(InvestorTable.sectors),
|
||||
selectinload(InvestorTable.funds),
|
||||
)
|
||||
.filter(InvestorTable.id == investor_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
# Transform to InvestorData format
|
||||
return InvestorData(
|
||||
investor=investor_with_relations,
|
||||
portfolio_companies=investor_with_relations.portfolio_companies,
|
||||
team_members=investor_with_relations.team_members,
|
||||
sectors=investor_with_relations.sectors,
|
||||
funds=investor_with_relations.funds,
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/investors/{investor_id}")
|
||||
def delete_investor(investor_id: int, db: Session = Depends(get_db)):
|
||||
"""Delete an investor"""
|
||||
db_investor = (
|
||||
db.query(InvestorTable).filter(InvestorTable.id == investor_id).first()
|
||||
)
|
||||
if not db_investor:
|
||||
raise HTTPException(status_code=404, detail="Investor not found")
|
||||
|
||||
db.delete(db_investor)
|
||||
db.commit()
|
||||
return {"message": "Investor deleted successfully"}
|
||||
|
||||
|
||||
@router.get(
|
||||
"/investors/{investor_id}/similar",
|
||||
response_model=PaginatedResponse[InvestmentResponse],
|
||||
)
|
||||
@router.get("/investors/{investor_id}/similar", response_model=List[InvestorData])
|
||||
def find_similar_investors(
|
||||
investor_id: int,
|
||||
limit: int = Query(10, description="Maximum number of similar investors to return"),
|
||||
page: int = Query(1, ge=1, description="Page number (starts at 1)"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
|
||||
db: Session = Depends(get_db),
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""Find investors similar to a given investor based on characteristics (paginated)
|
||||
"""Find investors similar to a given investor based on characteristics"""
|
||||
|
||||
Returns investor-fund combinations as separate rows.
|
||||
Queries the funds table to find matching funds.
|
||||
"""
|
||||
|
||||
# Get the target investor to get their funds for comparison
|
||||
# Get the target investor
|
||||
target_investor = (
|
||||
db.query(InvestorTable)
|
||||
.options(
|
||||
selectinload(InvestorTable.portfolio_companies),
|
||||
selectinload(InvestorTable.team_members),
|
||||
selectinload(InvestorTable.sectors),
|
||||
selectinload(InvestorTable.funds).selectinload(FundTable.investment_stages),
|
||||
selectinload(InvestorTable.funds).selectinload(FundTable.sectors),
|
||||
)
|
||||
.filter(InvestorTable.id == investor_id)
|
||||
.first()
|
||||
@@ -513,149 +204,78 @@ def find_similar_investors(
|
||||
if not target_investor:
|
||||
raise HTTPException(status_code=404, detail="Investor not found")
|
||||
|
||||
# Get target investor's sector IDs for comparison (from their funds)
|
||||
target_sector_ids = set()
|
||||
target_stage_ids = set()
|
||||
target_check_ranges = []
|
||||
target_geographies = []
|
||||
# Get target investor's sector IDs for comparison
|
||||
target_sector_ids = {sector.id for sector in target_investor.sectors}
|
||||
|
||||
for fund in target_investor.funds:
|
||||
if fund.sectors:
|
||||
target_sector_ids.update({sector.id for sector in fund.sectors})
|
||||
if fund.investment_stages:
|
||||
target_stage_ids.update({stage.id for stage in fund.investment_stages})
|
||||
if fund.check_size_lower and fund.check_size_upper:
|
||||
target_check_ranges.append((fund.check_size_lower, fund.check_size_upper))
|
||||
if fund.geographic_focus:
|
||||
target_geographies.append(fund.geographic_focus.lower())
|
||||
|
||||
# Query all funds from other investors
|
||||
candidate_funds = (
|
||||
db.query(FundTable)
|
||||
# Query all other investors with their relationships
|
||||
candidates = (
|
||||
db.query(InvestorTable)
|
||||
.options(
|
||||
selectinload(FundTable.investor).selectinload(
|
||||
InvestorTable.portfolio_companies
|
||||
),
|
||||
selectinload(FundTable.investor).selectinload(InvestorTable.team_members),
|
||||
selectinload(FundTable.investor).selectinload(InvestorTable.sectors),
|
||||
selectinload(FundTable.investment_stages),
|
||||
selectinload(FundTable.sectors),
|
||||
selectinload(InvestorTable.portfolio_companies),
|
||||
selectinload(InvestorTable.team_members),
|
||||
selectinload(InvestorTable.sectors),
|
||||
)
|
||||
.join(FundTable.investor)
|
||||
.filter(InvestorTable.id != investor_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Calculate similarity scores for each fund
|
||||
scored_funds = []
|
||||
for fund in candidate_funds:
|
||||
# Calculate similarity scores
|
||||
scored_investors = []
|
||||
for candidate in candidates:
|
||||
score = 0
|
||||
|
||||
# Stage focus match (30 points)
|
||||
if candidate.stage_focus == target_investor.stage_focus:
|
||||
score += 30
|
||||
|
||||
# Geographic focus match (20 points for exact, 10 for partial)
|
||||
if fund.geographic_focus and target_geographies:
|
||||
fund_geo_lower = fund.geographic_focus.lower()
|
||||
for target_geo in target_geographies:
|
||||
if fund_geo_lower == target_geo:
|
||||
if candidate.geographic_focus and target_investor.geographic_focus:
|
||||
if candidate.geographic_focus.lower() == target_investor.geographic_focus.lower():
|
||||
score += 20
|
||||
break
|
||||
elif fund_geo_lower in target_geo or target_geo in fund_geo_lower:
|
||||
elif (candidate.geographic_focus.lower() in target_investor.geographic_focus.lower() or
|
||||
target_investor.geographic_focus.lower() in candidate.geographic_focus.lower()):
|
||||
score += 10
|
||||
break
|
||||
|
||||
# Check size overlap (20 points max)
|
||||
if fund.check_size_lower and fund.check_size_upper and target_check_ranges:
|
||||
max_overlap_score = 0
|
||||
for target_lower, target_upper in target_check_ranges:
|
||||
overlap_start = max(fund.check_size_lower, target_lower)
|
||||
overlap_end = min(fund.check_size_upper, target_upper)
|
||||
if (candidate.check_size_lower and candidate.check_size_upper and
|
||||
target_investor.check_size_lower and target_investor.check_size_upper):
|
||||
# Calculate overlap percentage
|
||||
overlap_start = max(candidate.check_size_lower, target_investor.check_size_lower)
|
||||
overlap_end = min(candidate.check_size_upper, target_investor.check_size_upper)
|
||||
if overlap_end > overlap_start:
|
||||
overlap = overlap_end - overlap_start
|
||||
target_range = target_upper - target_lower
|
||||
target_range = target_investor.check_size_upper - target_investor.check_size_lower
|
||||
overlap_ratio = overlap / target_range if target_range > 0 else 0
|
||||
max_overlap_score = max(max_overlap_score, int(20 * overlap_ratio))
|
||||
score += max_overlap_score
|
||||
score += int(20 * overlap_ratio)
|
||||
|
||||
# AUM similarity (15 points max)
|
||||
if fund.investor.aum and target_investor.aum:
|
||||
aum_diff = abs(fund.investor.aum - target_investor.aum)
|
||||
max_aum = max(fund.investor.aum, target_investor.aum)
|
||||
if candidate.aum and target_investor.aum:
|
||||
aum_diff = abs(candidate.aum - target_investor.aum)
|
||||
max_aum = max(candidate.aum, target_investor.aum)
|
||||
similarity_ratio = 1 - (aum_diff / max_aum) if max_aum > 0 else 0
|
||||
score += int(15 * similarity_ratio)
|
||||
|
||||
# Sector overlap (30 points max)
|
||||
if fund.sectors and target_sector_ids:
|
||||
fund_sector_ids = {sector.id for sector in fund.sectors}
|
||||
common_sectors = target_sector_ids.intersection(fund_sector_ids)
|
||||
candidate_sector_ids = {sector.id for sector in candidate.sectors}
|
||||
if target_sector_ids and candidate_sector_ids:
|
||||
common_sectors = target_sector_ids.intersection(candidate_sector_ids)
|
||||
overlap_ratio = len(common_sectors) / len(target_sector_ids)
|
||||
score += int(30 * overlap_ratio)
|
||||
|
||||
# Investment stage match (15 points max)
|
||||
if fund.investment_stages and target_stage_ids:
|
||||
fund_stage_ids = {stage.id for stage in fund.investment_stages}
|
||||
common_stages = target_stage_ids.intersection(fund_stage_ids)
|
||||
overlap_ratio = len(common_stages) / len(target_stage_ids)
|
||||
score += int(15 * overlap_ratio)
|
||||
if score > 0: # Only include investors with some similarity
|
||||
scored_investors.append((score, candidate))
|
||||
|
||||
if score > 0: # Only include funds with some similarity
|
||||
scored_funds.append((score, fund))
|
||||
# Sort by score (descending) and take top N
|
||||
scored_investors.sort(key=lambda x: x[0], reverse=True)
|
||||
similar_investors = [inv for score, inv in scored_investors[:limit]]
|
||||
|
||||
# Sort by score (descending) and take top N based on limit
|
||||
scored_funds.sort(key=lambda x: x[0], reverse=True)
|
||||
top_similar = scored_funds[:limit]
|
||||
|
||||
# Apply pagination to the top similar funds
|
||||
total_count = len(top_similar)
|
||||
offset = (page - 1) * page_size
|
||||
paginated_similar = top_similar[offset : offset + page_size]
|
||||
similar_funds = [fund for score, fund in paginated_similar]
|
||||
|
||||
# Transform to InvestmentResponse format (one row per fund)
|
||||
investment_responses = []
|
||||
for fund in similar_funds:
|
||||
investor = fund.investor
|
||||
|
||||
# Get top 3 portfolio companies (id and name only)
|
||||
portfolio_companies = [
|
||||
CompanyMinimal(id=company.id, name=company.name)
|
||||
for company in investor.portfolio_companies[:3]
|
||||
# Transform to InvestorData format
|
||||
return [
|
||||
InvestorData(
|
||||
investor=inv,
|
||||
portfolio_companies=inv.portfolio_companies,
|
||||
team_members=inv.team_members,
|
||||
sectors=inv.sectors,
|
||||
)
|
||||
for inv in similar_investors
|
||||
]
|
||||
|
||||
# Get stage focus as comma-separated string
|
||||
stage_focus = (
|
||||
", ".join([stage.name for stage in fund.investment_stages])
|
||||
if fund.investment_stages
|
||||
else None
|
||||
)
|
||||
|
||||
# Get top 3 sectors from fund (id and name only)
|
||||
fund_sectors = [
|
||||
SectorMinimal(id=sector.id, name=sector.name)
|
||||
for sector in (fund.sectors[:3] if fund.sectors else [])
|
||||
]
|
||||
|
||||
investment_response = InvestmentResponse(
|
||||
id=investor.id,
|
||||
name=f"{investor.name} - {fund.fund_name}"
|
||||
if fund.fund_name
|
||||
else investor.name,
|
||||
aum=investor.aum,
|
||||
check_size_lower=fund.check_size_lower,
|
||||
check_size_upper=fund.check_size_upper,
|
||||
geographic_focus=fund.geographic_focus,
|
||||
stage_focus=stage_focus,
|
||||
portfolio_companies=portfolio_companies,
|
||||
sectors=fund_sectors,
|
||||
compatibility_score=1.0,
|
||||
)
|
||||
investment_responses.append(investment_response)
|
||||
|
||||
# Calculate total pages
|
||||
total_pages = (total_count + page_size - 1) // page_size
|
||||
|
||||
return PaginatedResponse(
|
||||
items=investment_responses,
|
||||
total=total_count,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
|
||||
+11
-145
@@ -14,45 +14,21 @@ from schemas.project_schemas import (
|
||||
ProjectData,
|
||||
ProjectUpdate,
|
||||
)
|
||||
from schemas.router_schemas import PaginatedResponse
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
router = APIRouter(tags=["Project Routes"])
|
||||
|
||||
|
||||
@router.get("/projects", response_model=PaginatedResponse[ProjectData])
|
||||
def read_projects(
|
||||
page: int = Query(1, ge=1, description="Page number (starts at 1)"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
|
||||
include_archived: bool = Query(False, description="Include archived projects"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get all projects with their related data (paginated)
|
||||
|
||||
By default, archived projects are excluded. Set include_archived=True to include them.
|
||||
"""
|
||||
# Calculate offset
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# Start with base query
|
||||
query = db.query(ProjectTable)
|
||||
|
||||
# Filter out archived projects by default
|
||||
if not include_archived:
|
||||
query = query.filter(ProjectTable.is_archived == 0)
|
||||
|
||||
# Get total count
|
||||
total_count = query.count()
|
||||
|
||||
# Get paginated results
|
||||
@router.get("/projects", response_model=List[ProjectData])
|
||||
def read_projects(db: Session = Depends(get_db)):
|
||||
"""Get all projects with their related data"""
|
||||
projects = (
|
||||
query.options(
|
||||
db.query(ProjectTable)
|
||||
.options(
|
||||
selectinload(ProjectTable.sector),
|
||||
selectinload(ProjectTable.investors),
|
||||
selectinload(ProjectTable.companies),
|
||||
)
|
||||
.offset(offset)
|
||||
.limit(page_size)
|
||||
.all()
|
||||
)
|
||||
|
||||
@@ -67,16 +43,7 @@ def read_projects(
|
||||
)
|
||||
project_data_list.append(project_data)
|
||||
|
||||
# Calculate total pages
|
||||
total_pages = (total_count + page_size - 1) // page_size
|
||||
|
||||
return PaginatedResponse(
|
||||
items=project_data_list,
|
||||
total=total_count,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
return project_data_list
|
||||
|
||||
|
||||
@router.get("/projects/{project_id}", response_model=ProjectData)
|
||||
@@ -172,7 +139,7 @@ def update_project(
|
||||
|
||||
@router.delete("/projects/{project_id}")
|
||||
def delete_project(project_id: int, db: Session = Depends(get_db)):
|
||||
"""Delete a project permanently"""
|
||||
"""Delete a project"""
|
||||
db_project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
|
||||
|
||||
if not db_project:
|
||||
@@ -184,88 +151,7 @@ def delete_project(project_id: int, db: Session = Depends(get_db)):
|
||||
return {"message": "Project deleted successfully"}
|
||||
|
||||
|
||||
@router.post("/projects/{project_id}/archive")
|
||||
def archive_project(project_id: int, db: Session = Depends(get_db)):
|
||||
"""Archive a project (soft delete)"""
|
||||
db_project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
|
||||
|
||||
if not db_project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
db_project.is_archived = 1
|
||||
db.commit()
|
||||
db.refresh(db_project)
|
||||
|
||||
return {"message": "Project archived successfully", "project_id": project_id}
|
||||
|
||||
|
||||
@router.post("/projects/{project_id}/unarchive")
|
||||
def unarchive_project(project_id: int, db: Session = Depends(get_db)):
|
||||
"""Unarchive a project (restore from archive)"""
|
||||
db_project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
|
||||
|
||||
if not db_project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
db_project.is_archived = 0
|
||||
db.commit()
|
||||
db.refresh(db_project)
|
||||
|
||||
return {"message": "Project unarchived successfully", "project_id": project_id}
|
||||
|
||||
|
||||
@router.get("/projects/archived", response_model=PaginatedResponse[ProjectData])
|
||||
def read_archived_projects(
|
||||
page: int = Query(1, ge=1, description="Page number (starts at 1)"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get all archived projects (paginated)"""
|
||||
# Calculate offset
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# Query only archived projects
|
||||
query = db.query(ProjectTable).filter(ProjectTable.is_archived == 1)
|
||||
|
||||
# Get total count
|
||||
total_count = query.count()
|
||||
|
||||
# Get paginated results
|
||||
projects = (
|
||||
query.options(
|
||||
selectinload(ProjectTable.sector),
|
||||
selectinload(ProjectTable.investors),
|
||||
selectinload(ProjectTable.companies),
|
||||
)
|
||||
.offset(offset)
|
||||
.limit(page_size)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Transform ProjectTable objects to ProjectData format
|
||||
project_data_list = []
|
||||
for project in projects:
|
||||
project_data = ProjectData(
|
||||
project=project,
|
||||
sector=project.sector,
|
||||
investors=project.investors,
|
||||
companies=project.companies,
|
||||
)
|
||||
project_data_list.append(project_data)
|
||||
|
||||
# Calculate total pages
|
||||
total_pages = (total_count + page_size - 1) // page_size
|
||||
|
||||
return PaginatedResponse(
|
||||
items=project_data_list,
|
||||
total=total_count,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/projects/filter", response_model=PaginatedResponse[ProjectData])
|
||||
@router.get("/projects/filter", response_model=List[ProjectData])
|
||||
def filter_projects(
|
||||
stage: Optional[InvestmentStage] = Query(
|
||||
None, description="Filter by project stage"
|
||||
@@ -273,7 +159,6 @@ def filter_projects(
|
||||
min_valuation: Optional[int] = Query(None, description="Minimum valuation"),
|
||||
max_valuation: Optional[int] = Query(None, description="Maximum valuation"),
|
||||
location: Optional[str] = Query(None, description="Location (partial match)"),
|
||||
industry: Optional[str] = Query(None, description="Industry (partial match)"),
|
||||
sector: Optional[str] = Query(None, description="Sector name (partial match)"),
|
||||
investor_name: Optional[str] = Query(
|
||||
None, description="Investor name (partial match)"
|
||||
@@ -281,11 +166,9 @@ def filter_projects(
|
||||
company_name: Optional[str] = Query(
|
||||
None, description="Company name (partial match)"
|
||||
),
|
||||
page: int = Query(1, ge=1, description="Page number (starts at 1)"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Filter projects based on various criteria (paginated)"""
|
||||
"""Filter projects based on various criteria"""
|
||||
|
||||
# Start with base query
|
||||
query = db.query(ProjectTable).options(
|
||||
@@ -307,9 +190,6 @@ def filter_projects(
|
||||
if location:
|
||||
query = query.filter(ProjectTable.location.ilike(f"%{location}%"))
|
||||
|
||||
if industry:
|
||||
query = query.filter(ProjectTable.industry.ilike(f"%{industry}%"))
|
||||
|
||||
if sector:
|
||||
query = query.join(ProjectTable.sector).filter(
|
||||
SectorTable.name.ilike(f"%{sector}%")
|
||||
@@ -325,12 +205,7 @@ def filter_projects(
|
||||
CompanyTable.name.ilike(f"%{company_name}%")
|
||||
)
|
||||
|
||||
# Get total count before pagination
|
||||
total_count = query.count()
|
||||
|
||||
# Calculate offset and apply pagination
|
||||
offset = (page - 1) * page_size
|
||||
projects = query.offset(offset).limit(page_size).all()
|
||||
projects = query.all()
|
||||
|
||||
# Transform to ProjectData format
|
||||
project_data_list = []
|
||||
@@ -343,16 +218,7 @@ def filter_projects(
|
||||
)
|
||||
project_data_list.append(project_data)
|
||||
|
||||
# Calculate total pages
|
||||
total_pages = (total_count + page_size - 1) // page_size
|
||||
|
||||
return PaginatedResponse(
|
||||
items=project_data_list,
|
||||
total=total_count,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
return project_data_list
|
||||
|
||||
|
||||
# Association management routes
|
||||
|
||||
@@ -1,118 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
from db.db import get_db
|
||||
from db.models import FundTable, InvestorTable, ProjectTable
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from fastapi.responses import Response
|
||||
from services.report_gen import ReportGenerator
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
router = APIRouter(tags=["Report Generation"])
|
||||
|
||||
|
||||
@router.get("/report/investor/{investor_id}")
|
||||
async def generate_investor_report(
|
||||
investor_id: int,
|
||||
project_id: Optional[int] = Query(
|
||||
None, description="Optional project ID for compatibility analysis"
|
||||
),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Generate a PDF report for an investor profile.
|
||||
|
||||
Args:
|
||||
investor_id: The ID of the investor to generate a report for
|
||||
project_id: Optional project ID to include mandate match analysis
|
||||
|
||||
Returns:
|
||||
PDF file as a downloadable response
|
||||
"""
|
||||
# Fetch investor data with all relationships
|
||||
investor = (
|
||||
db.query(InvestorTable)
|
||||
.options(
|
||||
selectinload(InvestorTable.portfolio_companies),
|
||||
selectinload(InvestorTable.team_members),
|
||||
selectinload(InvestorTable.sectors),
|
||||
selectinload(InvestorTable.funds).selectinload(FundTable.investment_stages),
|
||||
selectinload(InvestorTable.funds).selectinload(FundTable.sectors),
|
||||
)
|
||||
.filter(InvestorTable.id == investor_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not investor:
|
||||
raise HTTPException(status_code=404, detail="Investor not found")
|
||||
|
||||
# Prepare investor data dictionary
|
||||
investor_data = {
|
||||
"name": investor.name,
|
||||
"description": investor.description,
|
||||
"website": investor.website,
|
||||
"headquarters": investor.headquarters,
|
||||
"aum": investor.aum,
|
||||
"portfolio_highlights": investor.portfolio_highlights or [],
|
||||
"investment_thesis": investor.investment_thesis or [],
|
||||
"sectors": [sector.name for sector in investor.sectors],
|
||||
"team_members": [
|
||||
{
|
||||
"name": member.name,
|
||||
"role": member.role,
|
||||
"title": member.title,
|
||||
"email": member.email,
|
||||
}
|
||||
for member in investor.team_members
|
||||
],
|
||||
"funds": [],
|
||||
}
|
||||
|
||||
# Get all funds with their data
|
||||
if investor.funds:
|
||||
for fund in investor.funds:
|
||||
fund_data = {
|
||||
"fund_name": fund.fund_name,
|
||||
"fund_size": fund.fund_size,
|
||||
"check_size_lower": fund.check_size_lower,
|
||||
"check_size_upper": fund.check_size_upper,
|
||||
"geographic_focus": fund.geographic_focus,
|
||||
"investment_stages": [stage.name for stage in fund.investment_stages],
|
||||
"sectors": [sector.name for sector in fund.sectors],
|
||||
}
|
||||
investor_data["funds"].append(fund_data)
|
||||
|
||||
# Fetch project data if project_id is provided
|
||||
project_data = None
|
||||
if project_id:
|
||||
project = (
|
||||
db.query(ProjectTable)
|
||||
.options(selectinload(ProjectTable.sector))
|
||||
.filter(ProjectTable.id == project_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not project:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
project_data = {
|
||||
"name": project.name,
|
||||
"description": project.description,
|
||||
"location": project.location,
|
||||
"valuation": project.valuation,
|
||||
"stage": project.stage.name if project.stage else None,
|
||||
"sectors": [sector.name for sector in project.sector],
|
||||
}
|
||||
|
||||
# Generate PDF report
|
||||
report_generator = ReportGenerator()
|
||||
pdf_bytes = await report_generator.generate_investor_report(
|
||||
investor_data, project_data, investor_model=investor, project_model=project
|
||||
)
|
||||
|
||||
# Return PDF as downloadable file
|
||||
filename = f"{investor.name.replace(' ', '_')}_Report.pdf"
|
||||
return Response(
|
||||
content=pdf_bytes,
|
||||
media_type="application/pdf",
|
||||
headers={"Content-Disposition": f'attachment; filename="{filename}"'},
|
||||
)
|
||||
Binary file not shown.
Binary file not shown.
@@ -1,18 +0,0 @@
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class InsightResponse(BaseModel):
|
||||
investment_pattern_analysis: str
|
||||
market_position: str
|
||||
compatibility_score: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
json_schema_extra = {
|
||||
"example": {
|
||||
"investment_pattern_analysis": "Sequoia has been increasingly active in AI/ML startups (43% increase in last 18 months). Their average investment size has grown 23% year-over-year, indicating confidence in larger rounds. Peak activity in Q2-Q3, suggesting seasonal investment patterns.",
|
||||
"market_position": "Top 3 most active VC in enterprise software deals. Strong presence in unicorn companies (47 portfolio unicorns). Consistently leads or co-leads rounds, indicating decision-making influence.",
|
||||
"compatibility_score": "0.85",
|
||||
}
|
||||
}
|
||||
@@ -30,7 +30,7 @@ class InvestorSchema(BaseModel):
|
||||
check_size_lower: int | None
|
||||
check_size_upper: int | None
|
||||
geographic_focus: str | None
|
||||
stage_focus: Optional[InvestmentStage] = None
|
||||
stage_focus: InvestmentStage
|
||||
number_of_investments: int | None
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
@@ -60,7 +60,6 @@ class ProjectSchema(BaseModel):
|
||||
valuation: int | None
|
||||
stage: InvestmentStage | None
|
||||
location: str | None
|
||||
industry: str | None
|
||||
description: Optional[str]
|
||||
start_date: Optional[datetime]
|
||||
end_date: Optional[datetime]
|
||||
@@ -76,7 +75,6 @@ class ProjectCreate(BaseModel):
|
||||
valuation: Optional[int] = None
|
||||
stage: Optional[InvestmentStage] = None
|
||||
location: Optional[str] = None
|
||||
industry: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
start_date: Optional[datetime] = None
|
||||
end_date: Optional[datetime] = None
|
||||
@@ -87,7 +85,6 @@ class ProjectUpdate(BaseModel):
|
||||
valuation: Optional[int] = None
|
||||
stage: Optional[InvestmentStage] = None
|
||||
location: Optional[str] = None
|
||||
industry: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
start_date: Optional[datetime] = None
|
||||
end_date: Optional[datetime] = None
|
||||
|
||||
@@ -258,6 +258,10 @@ class InvestorSchema(BaseModel):
|
||||
default=None,
|
||||
description="Geographic investment focus. Do not return any special characters, Just locations separated by commas. Leave empty if not clearly identifiable.",
|
||||
)
|
||||
stage_focus: InvestmentStage = Field(
|
||||
default=InvestmentStage.SEED,
|
||||
description="Investment stage focus. Use SEED as default if uncertain.",
|
||||
)
|
||||
number_of_investments: Optional[int] = Field(
|
||||
default=None,
|
||||
ge=0,
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any, Generic, List, Optional, TypeVar
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Generic type for pagination
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class InvestmentStage(str, Enum):
|
||||
SEED = "SEED"
|
||||
@@ -25,39 +22,11 @@ class SectorSchema(BaseModel):
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class InvestmentStageSchema(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class InvestorMemberSchema(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
role: str | None
|
||||
email: str | None
|
||||
linkedin: str | None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class FundSchema(BaseModel):
|
||||
id: int
|
||||
fund_name: str | None
|
||||
fund_size: int | None # Changed to int for numerical filtering
|
||||
fund_size_source_url: str | None
|
||||
check_size_lower: int | None # NEW: Lower bound of check size range
|
||||
check_size_upper: int | None # NEW: Upper bound of check size range
|
||||
source_url: str | None
|
||||
source_provider: str | None
|
||||
geographic_focus: str | None # Changed from List[str] to string
|
||||
investment_stages: List[InvestmentStageSchema] | None # Changed to relationship
|
||||
sectors: List[SectorSchema] | None # Changed to relationship
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
@@ -93,20 +62,11 @@ class InvestorSchema(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
description: Optional[str]
|
||||
website: Optional[str] = None
|
||||
headquarters: Optional[str] = None
|
||||
aum: int | None
|
||||
aum_as_of_date: str | None = None
|
||||
aum_source_url: str | None = None
|
||||
check_size_lower: int | None
|
||||
check_size_upper: int | None
|
||||
geographic_focus: str | None
|
||||
investment_thesis: Any = (
|
||||
None # Flexible JSON field - can be list, dict, or list of dicts
|
||||
)
|
||||
portfolio_highlights: Any = (
|
||||
None # Flexible JSON field - can be list, dict, or list of dicts
|
||||
)
|
||||
stage_focus: InvestmentStage
|
||||
number_of_investments: int | None
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
@@ -116,87 +76,22 @@ class InvestorSchema(BaseModel):
|
||||
|
||||
|
||||
class InvestorData(BaseModel):
|
||||
"""Comprehensive investor data schema - used for individual investor requests"""
|
||||
"""Comprehensive investor data schema for LLM processing"""
|
||||
|
||||
investor: InvestorSchema
|
||||
portfolio_companies: List[CompanySchema]
|
||||
team_members: List[InvestorMemberSchema]
|
||||
sectors: List[SectorSchema]
|
||||
funds: List[FundSchema]
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class InvestorFundData(BaseModel):
|
||||
"""Investor-Fund combined data - used for list/filter requests
|
||||
|
||||
Each row represents one investor-fund combination.
|
||||
An investor with 3 funds will appear as 3 separate entries.
|
||||
"""
|
||||
|
||||
# Investor fields
|
||||
investor_id: int
|
||||
investor_name: str
|
||||
investor_description: Optional[str]
|
||||
investor_website: Optional[str]
|
||||
investor_headquarters: Optional[str]
|
||||
aum: int | None
|
||||
aum_as_of_date: str | None
|
||||
aum_source_url: str | None
|
||||
investment_thesis: Any = None # Flexible JSON field
|
||||
portfolio_highlights: Any = None # Flexible JSON field
|
||||
number_of_investments: int | None
|
||||
|
||||
# Fund fields
|
||||
fund_id: int | None
|
||||
fund_name: str | None
|
||||
fund_size: int | None # Changed to int for numerical filtering
|
||||
fund_size_source_url: str | None
|
||||
check_size_lower: int | None # NEW: Lower bound of check size range
|
||||
check_size_upper: int | None # NEW: Upper bound of check size range
|
||||
geographic_focus: str | None # Changed from List[str] to string
|
||||
fund_investment_stages: (
|
||||
List[InvestmentStageSchema] | None
|
||||
) # Changed to relationship
|
||||
fund_sectors: List[SectorSchema] | None # Changed to relationship
|
||||
|
||||
# Related data
|
||||
portfolio_companies: List[CompanySchema]
|
||||
team_members: List[InvestorMemberSchema]
|
||||
sectors: List[SectorSchema]
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class InvestorMinimal(BaseModel):
|
||||
"""Minimal investor info with just id and name"""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class CompanySchemaMinimal(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
industry: str | None
|
||||
location: str | None
|
||||
founded_year: Optional[int]
|
||||
website: Optional[str]
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class CompanyData(BaseModel): # Renamed from CompaniesData for consistency
|
||||
company: CompanySchemaMinimal
|
||||
investors: List[InvestorMinimal]
|
||||
members: List[CompanyMemberSchema] = []
|
||||
sectors: List[SectorSchema] = []
|
||||
company: CompanySchema
|
||||
sectors: List[SectorSchema]
|
||||
members: List[CompanyMemberSchema]
|
||||
investors: List[InvestorSchema]
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
@@ -204,65 +99,3 @@ class CompanyData(BaseModel): # Renamed from CompaniesData for consistency
|
||||
|
||||
class InvestorList(BaseModel):
|
||||
investors: List[InvestorData]
|
||||
|
||||
|
||||
class InvestorFundList(BaseModel):
|
||||
"""List of investor-fund combinations"""
|
||||
|
||||
investor_funds: List[InvestorFundData]
|
||||
|
||||
|
||||
class CompanyMinimal(BaseModel):
|
||||
"""Minimal company info with just id and name"""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class SectorMinimal(BaseModel):
|
||||
"""Minimal sector info with just id and name"""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class InvestmentResponse(BaseModel):
|
||||
"""Simplified investment response schema
|
||||
|
||||
One row per investor-fund combination with streamlined data
|
||||
"""
|
||||
|
||||
id: int # Investor ID
|
||||
name: (
|
||||
str # Combination of investor name and fund name (e.g., "Investor A - Fund A")
|
||||
)
|
||||
aum: int | None # From investor
|
||||
check_size_lower: int | None # From fund
|
||||
check_size_upper: int | None # From fund
|
||||
geographic_focus: str | None # From fund
|
||||
stage_focus: str | None # Comma-separated stages from fund
|
||||
portfolio_companies: List[CompanyMinimal] # Top 3 companies from investor
|
||||
sectors: List[SectorMinimal] # Top 3 sectors from fund
|
||||
compatibility_score: float # 0 to 1 (default 1 for now)
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class PaginatedResponse(BaseModel, Generic[T]):
|
||||
"""Generic paginated response schema"""
|
||||
|
||||
items: List[T]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
total_pages: int
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
Binary file not shown.
Binary file not shown.
@@ -1,228 +0,0 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
from db.db import get_db
|
||||
from db.models import CompanyTable
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from langchain_openai import ChatOpenAI
|
||||
from schemas.router_schemas import CompanyData, PaginatedResponse
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class CompanyQueryProcessor:
|
||||
def __init__(self):
|
||||
self.llm = ChatOpenAI(
|
||||
api_key=os.getenv("OPENROUTER_API_KEY"),
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="openai/gpt-4o-mini",
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
# Query cache for performance
|
||||
self.query_cache = {}
|
||||
|
||||
# SQL generation prompt
|
||||
self.sql_prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
(
|
||||
"system",
|
||||
"""You are a SQL expert. Generate a SQLite query to find company IDs based on user requirements.
|
||||
|
||||
Database Schema:
|
||||
- companies: id, name, industry, location, description, founded_year, website
|
||||
- company_sector: company_id, sector_id
|
||||
- sectors: id, name
|
||||
- investor_companies: investor_id, company_id
|
||||
- investors: id, name, aum
|
||||
- team_members: id, company_id, name, title
|
||||
|
||||
IMPORTANT RULES:
|
||||
1. ALWAYS return ONLY company IDs (companies.id) - use SELECT DISTINCT c.id
|
||||
2. For industry: Check BOTH industry field AND sectors table with synonyms
|
||||
- Use LEFT JOIN for sectors so companies without sector tags still match
|
||||
- Include related terms: 'Fintech' → c.industry LIKE '%Fintech%' OR c.industry LIKE '%Finance%' OR sec.name LIKE '%Fintech%' OR sec.name LIKE '%Financial%'
|
||||
- 'AI' → c.industry LIKE '%AI%' OR c.industry LIKE '%Artificial Intelligence%' OR c.industry LIKE '%Machine Learning%' OR sec.name LIKE '%AI%' OR sec.name LIKE '%ML%'
|
||||
3. For location: Be FLEXIBLE with variations and abbreviations
|
||||
- 'San Francisco' → c.location LIKE '%San Francisco%' OR c.location LIKE '%SF%' OR c.location LIKE '%Bay Area%'
|
||||
- 'New York' → c.location LIKE '%New York%' OR c.location LIKE '%NYC%' OR c.location LIKE '%NY%'
|
||||
- 'Europe' → c.location LIKE '%Europe%' OR c.location LIKE '%UK%' OR c.location LIKE '%London%' OR c.location LIKE '%Berlin%' OR c.location LIKE '%Paris%'
|
||||
4. For sectors: Use LEFT JOIN and include multiple synonyms
|
||||
- 'Healthcare' → sec.name LIKE '%Healthcare%' OR sec.name LIKE '%Health%' OR sec.name LIKE '%Medical%' OR sec.name LIKE '%Biotech%' OR c.industry LIKE '%Health%'
|
||||
5. For founding year filters (include NULL to be inclusive):
|
||||
- "founded after 2020" → WHERE (founded_year >= 2020 OR founded_year IS NULL)
|
||||
- "founded before 2018" → WHERE (founded_year <= 2018 OR founded_year IS NULL)
|
||||
- "founded in 2020" → WHERE founded_year = 2020
|
||||
6. For investor-related queries: Use JOIN investor_companies
|
||||
7. Use LEFT JOIN for sectors so companies without tags still match
|
||||
8. Use DISTINCT to avoid duplicates from joins
|
||||
9. Be INCLUSIVE - use OR conditions with synonyms and variations
|
||||
10. Return a single, complete SELECT query
|
||||
|
||||
Example Queries:
|
||||
Q: "Fintech companies founded in 2020"
|
||||
A: SELECT DISTINCT c.id FROM companies c
|
||||
LEFT JOIN company_sector cs ON c.id = cs.company_id
|
||||
LEFT JOIN sectors sec ON cs.sector_id = sec.id
|
||||
WHERE (c.industry LIKE '%Fintech%' OR c.industry LIKE '%Finance%' OR c.industry LIKE '%Financial%' OR sec.name LIKE '%Fintech%' OR sec.name LIKE '%Financial Services%')
|
||||
AND c.founded_year = 2020
|
||||
|
||||
Q: "AI companies in San Francisco"
|
||||
A: SELECT DISTINCT c.id FROM companies c
|
||||
LEFT JOIN company_sector cs ON c.id = cs.company_id
|
||||
LEFT JOIN sectors sec ON cs.sector_id = sec.id
|
||||
WHERE (c.industry LIKE '%AI%' OR c.industry LIKE '%Artificial Intelligence%' OR c.industry LIKE '%Machine Learning%' OR sec.name LIKE '%AI%' OR sec.name LIKE '%Machine Learning%' OR sec.name LIKE '%ML%')
|
||||
AND (c.location LIKE '%San Francisco%' OR c.location LIKE '%SF%' OR c.location LIKE '%Bay Area%')
|
||||
|
||||
Q: "Healthcare companies"
|
||||
A: SELECT DISTINCT c.id FROM companies c
|
||||
LEFT JOIN company_sector cs ON c.id = cs.company_id
|
||||
LEFT JOIN sectors sec ON cs.sector_id = sec.id
|
||||
WHERE c.industry LIKE '%Healthcare%' OR c.industry LIKE '%Health%' OR c.industry LIKE '%Medical%' OR sec.name LIKE '%Healthcare%' OR sec.name LIKE '%Medical%' OR sec.name LIKE '%Biotech%' OR sec.name LIKE '%Pharma%'
|
||||
|
||||
Q: "Companies funded by Sequoia"
|
||||
A: SELECT DISTINCT c.id FROM companies c
|
||||
JOIN investor_companies ic ON c.id = ic.company_id
|
||||
JOIN investors i ON ic.investor_id = i.id
|
||||
WHERE i.name LIKE '%Sequoia%'
|
||||
|
||||
Q: "European startups founded after 2019"
|
||||
A: SELECT DISTINCT c.id FROM companies c
|
||||
WHERE (c.location LIKE '%Europe%' OR c.location LIKE '%UK%' OR c.location LIKE '%London%' OR c.location LIKE '%Germany%' OR c.location LIKE '%Berlin%' OR c.location LIKE '%France%' OR c.location LIKE '%Paris%')
|
||||
AND (c.founded_year > 2019 OR c.founded_year IS NULL)
|
||||
|
||||
Q: "SaaS companies"
|
||||
A: SELECT DISTINCT c.id FROM companies c
|
||||
LEFT JOIN company_sector cs ON c.id = cs.company_id
|
||||
LEFT JOIN sectors sec ON cs.sector_id = sec.id
|
||||
WHERE c.industry LIKE '%SaaS%' OR c.industry LIKE '%Software%' OR c.industry LIKE '%Cloud%' OR sec.name LIKE '%SaaS%' OR sec.name LIKE '%Software%'
|
||||
|
||||
IMPORTANT:
|
||||
- Use LEFT JOIN so companies without sector tags still match via industry field
|
||||
- Use OR conditions with related keywords/synonyms to cast a wider net
|
||||
- Include NULL checks for optional filters to avoid excluding companies with missing data
|
||||
|
||||
Return ONLY the SQL query, no explanations or markdown.""",
|
||||
),
|
||||
("user", "{question}"),
|
||||
]
|
||||
)
|
||||
|
||||
def _get_cache_key(self, question: str) -> str:
|
||||
"""Generate cache key from normalized question."""
|
||||
return hashlib.md5(question.lower().strip().encode()).hexdigest()
|
||||
|
||||
# synchronous helper is provided below as `_process_query_sync` and an
|
||||
# async wrapper `process_query` runs it in a thread. This keeps the
|
||||
# FastAPI event loop non-blocking while reusing the existing sync code.
|
||||
async def process_query(self, question: str) -> PaginatedResponse[CompanyData]:
|
||||
"""Async wrapper for process_query. Runs blocking work in a thread to avoid
|
||||
blocking the event loop.
|
||||
"""
|
||||
return await asyncio.to_thread(self._process_query_sync, question)
|
||||
|
||||
def _process_query_sync(self, question: str) -> PaginatedResponse[CompanyData]:
|
||||
"""Synchronous implementation of process_query. This is run in a thread by
|
||||
the async wrapper above.
|
||||
"""
|
||||
cache_key = self._get_cache_key(question)
|
||||
|
||||
# Check cache first
|
||||
if cache_key in self.query_cache:
|
||||
sql_query = self.query_cache[cache_key]
|
||||
logger.info(f"Using cached SQL: {sql_query}")
|
||||
else:
|
||||
# Generate SQL query
|
||||
messages = self.sql_prompt.format_messages(question=question)
|
||||
response = self.llm.invoke(messages)
|
||||
sql_query = response.content.strip()
|
||||
|
||||
# Clean up SQL (remove markdown code blocks if present)
|
||||
sql_query = sql_query.replace("```sql", "").replace("```", "").strip()
|
||||
|
||||
# Cache the query
|
||||
self.query_cache[cache_key] = sql_query
|
||||
logger.info(f"Generated SQL: {sql_query}")
|
||||
|
||||
# Execute query to get company IDs
|
||||
db_session = next(get_db())
|
||||
try:
|
||||
result = db_session.execute(text(sql_query))
|
||||
company_ids = [row[0] for row in result.fetchall()]
|
||||
logger.info(
|
||||
f"Found {len(company_ids)} company IDs: {company_ids[:10]}{'...' if len(company_ids) > 10 else ''}"
|
||||
)
|
||||
|
||||
return self._fetch_companies_by_ids(company_ids)
|
||||
except Exception as e:
|
||||
logger.error(f"SQL execution error: {e}")
|
||||
logger.error(f"Failed SQL: {sql_query}")
|
||||
# Return empty result
|
||||
return PaginatedResponse(
|
||||
items=[], total=0, page=1, page_size=10, total_pages=0
|
||||
)
|
||||
finally:
|
||||
db_session.close()
|
||||
|
||||
def _fetch_companies_by_ids(
|
||||
self, company_ids: List[int]
|
||||
) -> PaginatedResponse[CompanyData]:
|
||||
"""Fetch companies with all their relationships from the database using company IDs.
|
||||
|
||||
Args:
|
||||
company_ids: List of company IDs to fetch
|
||||
"""
|
||||
if not company_ids:
|
||||
return PaginatedResponse(
|
||||
items=[],
|
||||
total=0,
|
||||
page=1,
|
||||
page_size=10,
|
||||
total_pages=0,
|
||||
)
|
||||
|
||||
# Get database session
|
||||
db_session = next(get_db())
|
||||
|
||||
try:
|
||||
# Query companies with all necessary relationships loaded
|
||||
companies = (
|
||||
db_session.query(CompanyTable)
|
||||
.options(
|
||||
selectinload(CompanyTable.investors),
|
||||
selectinload(CompanyTable.members),
|
||||
selectinload(CompanyTable.sectors),
|
||||
)
|
||||
.filter(CompanyTable.id.in_(company_ids))
|
||||
.all()
|
||||
)
|
||||
|
||||
# Transform to CompanyData format
|
||||
company_data_list = []
|
||||
for company in companies:
|
||||
company_data = CompanyData(
|
||||
company=company,
|
||||
investors=company.investors,
|
||||
members=company.members,
|
||||
sectors=company.sectors,
|
||||
)
|
||||
company_data_list.append(company_data)
|
||||
|
||||
total_count = len(company_data_list)
|
||||
total_pages = 1 if total_count > 0 else 0
|
||||
|
||||
return PaginatedResponse(
|
||||
items=company_data_list,
|
||||
total=total_count,
|
||||
page=1,
|
||||
page_size=total_count,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
|
||||
finally:
|
||||
db_session.close()
|
||||
@@ -1,785 +0,0 @@
|
||||
"""
|
||||
Compatibility Score Service
|
||||
|
||||
This module calculates compatibility scores between projects and investors.
|
||||
The scoring system evaluates multiple dimensions to determine how well a project
|
||||
matches with an investor's investment criteria.
|
||||
"""
|
||||
|
||||
from difflib import SequenceMatcher
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from db.models import FundTable, InvestorTable, ProjectTable
|
||||
|
||||
|
||||
def calculate_project_investor_compatibility(
|
||||
project: ProjectTable, investor: InvestorTable, use_funds: bool = True
|
||||
) -> float:
|
||||
"""
|
||||
Calculate compatibility score between a project and an investor.
|
||||
|
||||
Args:
|
||||
project: The project to evaluate
|
||||
investor: The investor to compare against
|
||||
use_funds: If True, evaluates against investor's funds. If False, uses investor-level data.
|
||||
|
||||
Returns:
|
||||
A score between 0 and 1, where 1 is perfect match
|
||||
|
||||
Scoring breakdown (out of 100 points):
|
||||
- Investment Stage Match: 30 points
|
||||
- Sector Overlap: 30 points
|
||||
- Geographic Match: 20 points
|
||||
- Valuation/Check Size Fit: 20 points
|
||||
"""
|
||||
if use_funds and investor.funds:
|
||||
# Calculate score for each fund and return the highest
|
||||
max_score = 0.0
|
||||
for fund in investor.funds:
|
||||
fund_score = _calculate_project_fund_compatibility(project, fund)
|
||||
max_score = max(max_score, fund_score)
|
||||
return max_score
|
||||
else:
|
||||
# Use investor-level data (fallback)
|
||||
return _calculate_project_investor_direct_compatibility(project, investor)
|
||||
|
||||
|
||||
def calculate_project_investors_compatibility(
|
||||
project: ProjectTable, investors: List[InvestorTable], use_funds: bool = True
|
||||
) -> List[Tuple[InvestorTable, float]]:
|
||||
"""
|
||||
Calculate compatibility scores between a project and multiple investors.
|
||||
|
||||
Args:
|
||||
project: The project to evaluate
|
||||
investors: List of investors to compare against
|
||||
use_funds: If True, evaluates against investors' funds. If False, uses investor-level data.
|
||||
|
||||
Returns:
|
||||
List of tuples (investor, score) sorted by score descending
|
||||
"""
|
||||
scored_investors = []
|
||||
|
||||
for investor in investors:
|
||||
score = calculate_project_investor_compatibility(project, investor, use_funds)
|
||||
scored_investors.append((investor, score))
|
||||
|
||||
# Sort by score descending
|
||||
scored_investors.sort(key=lambda x: x[1], reverse=True)
|
||||
|
||||
return scored_investors
|
||||
|
||||
|
||||
def _calculate_project_fund_compatibility(
|
||||
project: ProjectTable, fund: FundTable
|
||||
) -> float:
|
||||
"""
|
||||
Calculate compatibility score between a project and a specific fund.
|
||||
|
||||
Scoring breakdown:
|
||||
- Investment Stage Match: 30 points (all or nothing if stage exists)
|
||||
- Sector Overlap: 30 points (proportional to overlap)
|
||||
- Geographic Match: 20 points (exact=20, partial=10, none=0)
|
||||
- Valuation/Check Size Fit: 20 points (proportional to fit)
|
||||
|
||||
Returns:
|
||||
A score between 0 and 1
|
||||
"""
|
||||
total_score = 0
|
||||
max_score = 100
|
||||
|
||||
# 1. Investment Stage Match (30 points)
|
||||
stage_score = 0
|
||||
if project.stage and fund.investment_stages:
|
||||
# Check if project stage matches any of the fund's investment stages
|
||||
fund_stage_names = {stage.name for stage in fund.investment_stages}
|
||||
# Convert project.stage enum to string for comparison
|
||||
project_stage_name = (
|
||||
project.stage.value
|
||||
if hasattr(project.stage, "value")
|
||||
else str(project.stage)
|
||||
)
|
||||
|
||||
# Normalize both for case-insensitive comparison
|
||||
project_stage_normalized = project_stage_name.upper().strip()
|
||||
fund_stages_normalized = {name.upper().strip() for name in fund_stage_names}
|
||||
|
||||
if project_stage_normalized in fund_stages_normalized:
|
||||
stage_score = 30
|
||||
else:
|
||||
# Partial credit for adjacent stages
|
||||
stage_score = _calculate_stage_proximity(
|
||||
project_stage_normalized, fund_stages_normalized
|
||||
)
|
||||
|
||||
total_score += stage_score
|
||||
|
||||
# 2. Sector Overlap (30 points)
|
||||
sector_score = 0
|
||||
if project.sector and fund.sectors:
|
||||
project_sectors = [s for s in project.sector if hasattr(s, "name")]
|
||||
fund_sectors = [s for s in fund.sectors if hasattr(s, "name")]
|
||||
|
||||
if project_sectors and fund_sectors:
|
||||
# Use fuzzy matching to account for similar but not identical sector names
|
||||
match_count = 0
|
||||
total_matches = 0
|
||||
|
||||
for proj_sector in project_sectors:
|
||||
best_match_score = 0
|
||||
proj_name = proj_sector.name.lower().strip()
|
||||
|
||||
for fund_sector in fund_sectors:
|
||||
fund_name = fund_sector.name.lower().strip()
|
||||
|
||||
# Exact match
|
||||
if proj_name == fund_name:
|
||||
best_match_score = 1.0
|
||||
break
|
||||
|
||||
# Fuzzy match using sequence matcher
|
||||
similarity = SequenceMatcher(None, proj_name, fund_name).ratio()
|
||||
|
||||
# Also check if one contains the other (substring match)
|
||||
if proj_name in fund_name or fund_name in proj_name:
|
||||
similarity = max(similarity, 0.8)
|
||||
|
||||
best_match_score = max(best_match_score, similarity)
|
||||
|
||||
# Count matches with threshold
|
||||
# Perfect match (1.0), strong match (>0.75), partial match (>0.6)
|
||||
if best_match_score >= 0.6:
|
||||
total_matches += best_match_score
|
||||
match_count += 1
|
||||
|
||||
if match_count > 0:
|
||||
# Calculate overlap ratio based on fuzzy matches
|
||||
overlap_ratio = total_matches / len(project_sectors)
|
||||
sector_score = int(30 * overlap_ratio)
|
||||
|
||||
total_score += sector_score
|
||||
|
||||
# 3. Geographic Match (20 points)
|
||||
geo_score = 0
|
||||
if project.location and fund.geographic_focus:
|
||||
project_location_lower = project.location.lower().strip()
|
||||
fund_geo_lower = (fund.geographic_focus or "").lower().strip()
|
||||
|
||||
# Exact match
|
||||
if project_location_lower == fund_geo_lower:
|
||||
geo_score = 20
|
||||
# Partial match (one contains the other)
|
||||
elif (
|
||||
project_location_lower in fund_geo_lower
|
||||
or fund_geo_lower in project_location_lower
|
||||
):
|
||||
geo_score = 15
|
||||
# Check for common geographic terms or regional overlap (continent/country matching)
|
||||
elif _check_geographic_overlap(project_location_lower, fund_geo_lower):
|
||||
# Give higher score for continent/country matches (e.g., Germany -> Europe)
|
||||
geo_score = 18
|
||||
|
||||
total_score += geo_score
|
||||
|
||||
# 4. Valuation/Check Size Fit (20 points)
|
||||
valuation_score = 0
|
||||
if project.valuation and fund.check_size_lower and fund.check_size_upper:
|
||||
# Check if project valuation falls within or near the check size range
|
||||
# Typically, check size is a fraction of valuation (e.g., 10-20%)
|
||||
# We'll assume check size represents potential investment amount
|
||||
|
||||
if fund.check_size_lower <= project.valuation <= fund.check_size_upper:
|
||||
# Valuation is within the check size range (might be too small)
|
||||
valuation_score = 10
|
||||
else:
|
||||
# Check if the check size is reasonable for this valuation
|
||||
# Typical investment is 10-30% of valuation
|
||||
reasonable_valuation_min = fund.check_size_lower * 3 # Investing ~33%
|
||||
reasonable_valuation_max = fund.check_size_upper * 10 # Investing ~10%
|
||||
|
||||
if (
|
||||
reasonable_valuation_min
|
||||
<= project.valuation
|
||||
<= reasonable_valuation_max
|
||||
):
|
||||
# Perfect fit
|
||||
valuation_score = 20
|
||||
elif project.valuation < reasonable_valuation_min:
|
||||
# Project might be too small
|
||||
ratio = (
|
||||
project.valuation / reasonable_valuation_min
|
||||
if reasonable_valuation_min > 0
|
||||
else 0
|
||||
)
|
||||
valuation_score = int(10 * ratio)
|
||||
else:
|
||||
# Project might be too large
|
||||
ratio = (
|
||||
reasonable_valuation_max / project.valuation
|
||||
if project.valuation > 0
|
||||
else 0
|
||||
)
|
||||
valuation_score = int(10 * ratio)
|
||||
|
||||
total_score += valuation_score
|
||||
|
||||
# Convert to 0-1 scale
|
||||
return total_score / max_score
|
||||
|
||||
|
||||
def _calculate_project_investor_direct_compatibility(
|
||||
project: ProjectTable, investor: InvestorTable
|
||||
) -> float:
|
||||
"""
|
||||
Calculate compatibility using investor-level data (fallback when no funds available).
|
||||
|
||||
Uses the same scoring system but with investor-level attributes.
|
||||
"""
|
||||
total_score = 0
|
||||
max_score = 100
|
||||
|
||||
# 1. Investment Stage - Skip this since investors don't have a direct stage field
|
||||
# We could add 30 points to other categories, but for consistency, we'll leave it as 0
|
||||
stage_score = 0
|
||||
total_score += stage_score
|
||||
|
||||
# 2. Sector Overlap (30 points)
|
||||
sector_score = 0
|
||||
if project.sector and investor.sectors:
|
||||
project_sectors = [s for s in project.sector if hasattr(s, "name")]
|
||||
investor_sectors = [s for s in investor.sectors if hasattr(s, "name")]
|
||||
|
||||
if project_sectors and investor_sectors:
|
||||
# Use fuzzy matching to account for similar but not identical sector names
|
||||
match_count = 0
|
||||
total_matches = 0
|
||||
|
||||
for proj_sector in project_sectors:
|
||||
best_match_score = 0
|
||||
proj_name = proj_sector.name.lower().strip()
|
||||
|
||||
for inv_sector in investor_sectors:
|
||||
inv_name = inv_sector.name.lower().strip()
|
||||
|
||||
# Exact match
|
||||
if proj_name == inv_name:
|
||||
best_match_score = 1.0
|
||||
break
|
||||
|
||||
# Fuzzy match using sequence matcher
|
||||
similarity = SequenceMatcher(None, proj_name, inv_name).ratio()
|
||||
|
||||
# Also check if one contains the other (substring match)
|
||||
if proj_name in inv_name or inv_name in proj_name:
|
||||
similarity = max(similarity, 0.8)
|
||||
|
||||
best_match_score = max(best_match_score, similarity)
|
||||
|
||||
# Count matches with threshold
|
||||
if best_match_score >= 0.6:
|
||||
total_matches += best_match_score
|
||||
match_count += 1
|
||||
|
||||
if match_count > 0:
|
||||
# Calculate overlap ratio based on fuzzy matches
|
||||
overlap_ratio = total_matches / len(project_sectors)
|
||||
sector_score = int(30 * overlap_ratio)
|
||||
|
||||
total_score += sector_score
|
||||
|
||||
# 3. Geographic Match (20 points)
|
||||
geo_score = 0
|
||||
if project.location and investor.geographic_focus:
|
||||
project_location_lower = project.location.lower()
|
||||
investor_geo_lower = (investor.geographic_focus or "").lower()
|
||||
|
||||
if project_location_lower == investor_geo_lower:
|
||||
geo_score = 20
|
||||
elif (
|
||||
project_location_lower in investor_geo_lower
|
||||
or investor_geo_lower in project_location_lower
|
||||
):
|
||||
geo_score = 15
|
||||
elif _check_geographic_overlap(project_location_lower, investor_geo_lower):
|
||||
# Give higher score for continent/country matches (e.g., Germany -> Europe)
|
||||
geo_score = 18
|
||||
|
||||
total_score += geo_score
|
||||
|
||||
# 4. Valuation/Check Size Fit (20 points)
|
||||
valuation_score = 0
|
||||
if project.valuation and investor.check_size_lower and investor.check_size_upper:
|
||||
reasonable_valuation_min = investor.check_size_lower * 3
|
||||
reasonable_valuation_max = investor.check_size_upper * 10
|
||||
|
||||
if reasonable_valuation_min <= project.valuation <= reasonable_valuation_max:
|
||||
valuation_score = 20
|
||||
elif project.valuation < reasonable_valuation_min:
|
||||
ratio = (
|
||||
project.valuation / reasonable_valuation_min
|
||||
if reasonable_valuation_min > 0
|
||||
else 0
|
||||
)
|
||||
valuation_score = int(10 * ratio)
|
||||
else:
|
||||
ratio = (
|
||||
reasonable_valuation_max / project.valuation
|
||||
if project.valuation > 0
|
||||
else 0
|
||||
)
|
||||
valuation_score = int(10 * ratio)
|
||||
|
||||
total_score += valuation_score
|
||||
|
||||
# Convert to 0-1 scale
|
||||
return total_score / max_score
|
||||
|
||||
|
||||
def _calculate_stage_proximity(project_stage: str, fund_stages: set) -> int:
|
||||
"""
|
||||
Calculate proximity score between project stage and fund stages.
|
||||
Awards partial credit for adjacent investment stages.
|
||||
|
||||
Stage progression: SEED -> SERIES_A -> SERIES_B -> SERIES_C -> GROWTH -> LATE_STAGE
|
||||
|
||||
Returns:
|
||||
Score from 0-15 (half credit for adjacent stages)
|
||||
"""
|
||||
stage_order = ["SEED", "SERIES_A", "SERIES_B", "SERIES_C", "GROWTH", "LATE_STAGE"]
|
||||
|
||||
# Normalize project stage for comparison
|
||||
project_stage_normalized = project_stage.upper().strip()
|
||||
|
||||
try:
|
||||
project_idx = stage_order.index(project_stage_normalized)
|
||||
except ValueError:
|
||||
return 0
|
||||
|
||||
# Check for adjacent stages
|
||||
adjacent_stages = []
|
||||
if project_idx > 0:
|
||||
adjacent_stages.append(stage_order[project_idx - 1])
|
||||
if project_idx < len(stage_order) - 1:
|
||||
adjacent_stages.append(stage_order[project_idx + 1])
|
||||
|
||||
# Normalize fund stages and check for matches
|
||||
for stage in fund_stages:
|
||||
stage_normalized = stage.upper().strip()
|
||||
if stage_normalized in adjacent_stages:
|
||||
return 15 # Half credit for adjacent stage
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def _check_geographic_overlap(location1: str, location2: str) -> bool:
|
||||
"""
|
||||
Check for common geographic terms between two locations.
|
||||
|
||||
Examples:
|
||||
- "San Francisco, CA" and "California" -> True
|
||||
- "New York" and "USA" -> True (if both contain USA/US)
|
||||
- "London, UK" and "United Kingdom" -> True
|
||||
- "Germany" and "Europe" -> True
|
||||
"""
|
||||
# Normalize inputs
|
||||
loc1 = location1.lower().strip()
|
||||
loc2 = location2.lower().strip()
|
||||
|
||||
# Common geographic groupings with broader regional mappings
|
||||
geo_groups = [
|
||||
# North America
|
||||
["usa", "us", "united states", "america", "u.s.", "u.s.a"],
|
||||
["canada", "canadian"],
|
||||
["mexico", "mexican"],
|
||||
# Europe and countries
|
||||
[
|
||||
"europe",
|
||||
"european",
|
||||
"eu",
|
||||
"germany",
|
||||
"france",
|
||||
"uk",
|
||||
"united kingdom",
|
||||
"britain",
|
||||
"spain",
|
||||
"italy",
|
||||
"netherlands",
|
||||
"belgium",
|
||||
"sweden",
|
||||
"denmark",
|
||||
"norway",
|
||||
"finland",
|
||||
"poland",
|
||||
"portugal",
|
||||
"austria",
|
||||
"switzerland",
|
||||
"ireland",
|
||||
"greece",
|
||||
"czech",
|
||||
"romania",
|
||||
],
|
||||
# UK specific
|
||||
["uk", "united kingdom", "britain", "england", "scotland", "wales", "london"],
|
||||
# US states
|
||||
["california", "ca", "san francisco", "los angeles", "silicon valley"],
|
||||
["new york", "ny", "nyc"],
|
||||
["texas", "tx"],
|
||||
["massachusetts", "ma", "boston"],
|
||||
["washington", "seattle"],
|
||||
# Asia
|
||||
[
|
||||
"asia",
|
||||
"asian",
|
||||
"china",
|
||||
"japan",
|
||||
"korea",
|
||||
"singapore",
|
||||
"hong kong",
|
||||
"india",
|
||||
"indonesia",
|
||||
"thailand",
|
||||
"vietnam",
|
||||
"malaysia",
|
||||
"philippines",
|
||||
],
|
||||
# Middle East
|
||||
["middle east", "israel", "uae", "dubai", "saudi arabia"],
|
||||
# Latin America
|
||||
["latin america", "brazil", "argentina", "chile", "colombia", "mexico"],
|
||||
# Africa
|
||||
["africa", "african", "south africa", "nigeria", "kenya", "egypt"],
|
||||
# Oceania
|
||||
["australia", "australian", "new zealand"],
|
||||
]
|
||||
|
||||
# Check if both locations match any group
|
||||
for group in geo_groups:
|
||||
found_in_1 = any(term in loc1 for term in group)
|
||||
found_in_2 = any(term in loc2 for term in group)
|
||||
if found_in_1 and found_in_2:
|
||||
return True
|
||||
|
||||
# Check for direct substring match (one contains the other)
|
||||
if loc1 in loc2 or loc2 in loc1:
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def get_top_compatible_investors(
|
||||
project: ProjectTable,
|
||||
investors: List[InvestorTable],
|
||||
limit: int = 10,
|
||||
min_score: float = 0.0,
|
||||
use_funds: bool = True,
|
||||
) -> List[Tuple[InvestorTable, float]]:
|
||||
"""
|
||||
Get the top N most compatible investors for a project.
|
||||
|
||||
Args:
|
||||
project: The project to find investors for
|
||||
investors: List of all available investors
|
||||
limit: Maximum number of investors to return
|
||||
min_score: Minimum compatibility score threshold (0-1)
|
||||
use_funds: If True, evaluates against investors' funds
|
||||
|
||||
Returns:
|
||||
List of tuples (investor, score) sorted by score descending,
|
||||
limited to 'limit' items and filtered by min_score
|
||||
"""
|
||||
scored_investors = calculate_project_investors_compatibility(
|
||||
project, investors, use_funds
|
||||
)
|
||||
|
||||
# Filter by minimum score
|
||||
filtered_investors = [
|
||||
(investor, score) for investor, score in scored_investors if score >= min_score
|
||||
]
|
||||
|
||||
# Return top N
|
||||
return filtered_investors[:limit]
|
||||
|
||||
|
||||
def get_compatibility_score_breakdown(
|
||||
project: ProjectTable, investor: InvestorTable, fund: Optional[FundTable] = None
|
||||
) -> dict:
|
||||
"""
|
||||
Get a detailed breakdown of the compatibility score components.
|
||||
|
||||
Useful for debugging or showing users why a particular score was calculated.
|
||||
|
||||
Returns:
|
||||
Dictionary with score components and explanations
|
||||
"""
|
||||
if fund:
|
||||
total_score = 0
|
||||
|
||||
# Stage score
|
||||
stage_score = 0
|
||||
stage_match = False
|
||||
if project.stage and fund.investment_stages:
|
||||
fund_stage_names = {stage.name for stage in fund.investment_stages}
|
||||
project_stage_name = (
|
||||
project.stage.value
|
||||
if hasattr(project.stage, "value")
|
||||
else str(project.stage)
|
||||
)
|
||||
if project_stage_name in fund_stage_names:
|
||||
stage_score = 30
|
||||
stage_match = True
|
||||
else:
|
||||
stage_score = _calculate_stage_proximity(
|
||||
project_stage_name, fund_stage_names
|
||||
)
|
||||
|
||||
# Sector score
|
||||
sector_score = 0
|
||||
matching_sectors = []
|
||||
if project.sector and fund.sectors:
|
||||
project_sector_ids = {sector.id for sector in project.sector}
|
||||
fund_sector_ids = {sector.id for sector in fund.sectors}
|
||||
if project_sector_ids and fund_sector_ids:
|
||||
common_sectors = project_sector_ids.intersection(fund_sector_ids)
|
||||
matching_sectors = [
|
||||
s.name for s in fund.sectors if s.id in common_sectors
|
||||
]
|
||||
overlap_ratio = len(common_sectors) / len(project_sector_ids)
|
||||
sector_score = int(30 * overlap_ratio)
|
||||
|
||||
# Geographic score
|
||||
geo_score = 0
|
||||
geo_match_type = "none"
|
||||
if project.location and fund.geographic_focus:
|
||||
project_location_lower = project.location.lower()
|
||||
fund_geo_lower = fund.geographic_focus.lower()
|
||||
if project_location_lower == fund_geo_lower:
|
||||
geo_score = 20
|
||||
geo_match_type = "exact"
|
||||
elif (
|
||||
project_location_lower in fund_geo_lower
|
||||
or fund_geo_lower in project_location_lower
|
||||
):
|
||||
geo_score = 10
|
||||
geo_match_type = "partial"
|
||||
elif _check_geographic_overlap(project_location_lower, fund_geo_lower):
|
||||
geo_score = 5
|
||||
geo_match_type = "regional"
|
||||
|
||||
# Valuation score
|
||||
valuation_score = 0
|
||||
valuation_fit = "unknown"
|
||||
if project.valuation and fund.check_size_lower and fund.check_size_upper:
|
||||
reasonable_valuation_min = fund.check_size_lower * 3
|
||||
reasonable_valuation_max = fund.check_size_upper * 10
|
||||
if (
|
||||
reasonable_valuation_min
|
||||
<= project.valuation
|
||||
<= reasonable_valuation_max
|
||||
):
|
||||
valuation_score = 20
|
||||
valuation_fit = "perfect"
|
||||
elif project.valuation < reasonable_valuation_min:
|
||||
ratio = (
|
||||
project.valuation / reasonable_valuation_min
|
||||
if reasonable_valuation_min > 0
|
||||
else 0
|
||||
)
|
||||
valuation_score = int(10 * ratio)
|
||||
valuation_fit = "too_small"
|
||||
else:
|
||||
ratio = (
|
||||
reasonable_valuation_max / project.valuation
|
||||
if project.valuation > 0
|
||||
else 0
|
||||
)
|
||||
valuation_score = int(10 * ratio)
|
||||
valuation_fit = "too_large"
|
||||
|
||||
total_score = stage_score + sector_score + geo_score + valuation_score
|
||||
|
||||
return {
|
||||
"total_score": total_score / 100,
|
||||
"breakdown": {
|
||||
"stage": {
|
||||
"score": stage_score,
|
||||
"max_score": 30,
|
||||
"match": stage_match,
|
||||
"project_stage": project.stage.value if project.stage else None,
|
||||
"fund_stages": [s.name for s in fund.investment_stages]
|
||||
if fund.investment_stages
|
||||
else [],
|
||||
},
|
||||
"sector": {
|
||||
"score": sector_score,
|
||||
"max_score": 30,
|
||||
"matching_sectors": matching_sectors,
|
||||
"project_sectors": [s.name for s in project.sector]
|
||||
if project.sector
|
||||
else [],
|
||||
"fund_sectors": [s.name for s in fund.sectors]
|
||||
if fund.sectors
|
||||
else [],
|
||||
},
|
||||
"geography": {
|
||||
"score": geo_score,
|
||||
"max_score": 20,
|
||||
"match_type": geo_match_type,
|
||||
"project_location": project.location,
|
||||
"fund_geography": fund.geographic_focus,
|
||||
},
|
||||
"valuation": {
|
||||
"score": valuation_score,
|
||||
"max_score": 20,
|
||||
"fit": valuation_fit,
|
||||
"project_valuation": project.valuation,
|
||||
"fund_check_size_range": f"{fund.check_size_lower}-{fund.check_size_upper}"
|
||||
if fund.check_size_lower
|
||||
else None,
|
||||
},
|
||||
},
|
||||
}
|
||||
else:
|
||||
# Investor-level breakdown (simplified)
|
||||
return {
|
||||
"total_score": _calculate_project_investor_direct_compatibility(
|
||||
project, investor
|
||||
),
|
||||
"note": "Using investor-level data (no specific fund selected)",
|
||||
}
|
||||
|
||||
|
||||
def generate_compatibility_explanation(
|
||||
project: ProjectTable, investor: InvestorTable, score: float, use_funds: bool = True
|
||||
) -> str:
|
||||
"""
|
||||
Generate a detailed, natural language explanation of the compatibility score.
|
||||
|
||||
Args:
|
||||
project: The project being evaluated
|
||||
investor: The investor being compared against
|
||||
score: The calculated compatibility score (0-1)
|
||||
use_funds: Whether fund-level data was used
|
||||
|
||||
Returns:
|
||||
A formatted string with the compatibility score and detailed explanation
|
||||
"""
|
||||
score_percentage = int(score * 100)
|
||||
|
||||
# Determine match quality
|
||||
if score_percentage >= 80:
|
||||
match_level = "Excellent match"
|
||||
elif score_percentage >= 65:
|
||||
match_level = "Strong match"
|
||||
elif score_percentage >= 50:
|
||||
match_level = "Good match"
|
||||
elif score_percentage >= 35:
|
||||
match_level = "Moderate match"
|
||||
else:
|
||||
match_level = "Limited match"
|
||||
|
||||
# Collect alignment factors
|
||||
alignment_factors = []
|
||||
recommendations = []
|
||||
|
||||
# Get the best matching fund if using funds
|
||||
best_fund = None
|
||||
if use_funds and investor.funds:
|
||||
best_score = 0
|
||||
for fund in investor.funds:
|
||||
fund_score = _calculate_project_fund_compatibility(project, fund)
|
||||
if fund_score > best_score:
|
||||
best_score = fund_score
|
||||
best_fund = fund
|
||||
|
||||
# Analyze sector alignment
|
||||
if project.sector:
|
||||
project_sectors = [s.name for s in project.sector if hasattr(s, "name")]
|
||||
|
||||
if best_fund and best_fund.sectors:
|
||||
fund_sectors = {s.name for s in best_fund.sectors if hasattr(s, "name")}
|
||||
common_sectors = set(project_sectors) & fund_sectors
|
||||
|
||||
if common_sectors:
|
||||
sectors_str = ", ".join(list(common_sectors)[:2])
|
||||
alignment_factors.append(f"{sectors_str} sector focus")
|
||||
elif project_sectors:
|
||||
recommendations.append(
|
||||
f"Consider emphasizing any {project_sectors[0]} industry connections"
|
||||
)
|
||||
elif investor.sectors:
|
||||
investor_sectors = {s.name for s in investor.sectors if hasattr(s, "name")}
|
||||
common_sectors = set(project_sectors) & investor_sectors
|
||||
|
||||
if common_sectors:
|
||||
sectors_str = ", ".join(list(common_sectors)[:2])
|
||||
alignment_factors.append(f"{sectors_str} sector focus")
|
||||
|
||||
# Analyze stage alignment
|
||||
if project.stage:
|
||||
stage_name = (
|
||||
project.stage.value
|
||||
if hasattr(project.stage, "value")
|
||||
else str(project.stage)
|
||||
)
|
||||
stage_display = stage_name.replace("_", " ").title()
|
||||
|
||||
if best_fund and best_fund.investment_stages:
|
||||
fund_stage_names = {
|
||||
s.name for s in best_fund.investment_stages if hasattr(s, "name")
|
||||
}
|
||||
if stage_name in fund_stage_names:
|
||||
alignment_factors.append(f"{stage_display} stage")
|
||||
else:
|
||||
recommendations.append(
|
||||
"Investor typically focuses on different stages; highlight your traction and growth metrics"
|
||||
)
|
||||
|
||||
if not best_fund:
|
||||
alignment_factors.append(f"{stage_display} stage")
|
||||
|
||||
# Analyze geographic alignment
|
||||
if project.location:
|
||||
if best_fund and best_fund.geographic_focus:
|
||||
if (
|
||||
project.location.lower() in best_fund.geographic_focus.lower()
|
||||
or best_fund.geographic_focus.lower() in project.location.lower()
|
||||
):
|
||||
alignment_factors.append(f"{project.location} presence")
|
||||
elif investor.headquarters:
|
||||
if (
|
||||
project.location.lower() in investor.headquarters.lower()
|
||||
or investor.headquarters.lower() in project.location.lower()
|
||||
):
|
||||
alignment_factors.append(f"{project.location} market presence")
|
||||
|
||||
# Analyze valuation/check size fit
|
||||
if project.valuation:
|
||||
if best_fund and best_fund.check_size_lower and best_fund.check_size_upper:
|
||||
reasonable_min = best_fund.check_size_lower * 3
|
||||
reasonable_max = best_fund.check_size_upper * 10
|
||||
|
||||
if reasonable_min <= project.valuation <= reasonable_max:
|
||||
alignment_factors.append("appropriate funding stage")
|
||||
elif project.valuation < reasonable_min:
|
||||
recommendations.append(
|
||||
"You may be early for this investor; consider approaching at a later stage"
|
||||
)
|
||||
else:
|
||||
recommendations.append(
|
||||
"Consider highlighting your growth trajectory and market opportunity"
|
||||
)
|
||||
|
||||
# Build the explanation
|
||||
explanation_parts = [f"Based on your startup profile: {score_percentage}% match"]
|
||||
|
||||
if alignment_factors:
|
||||
alignment_text = ", ".join(alignment_factors)
|
||||
explanation_parts.append(f"{match_level}: {alignment_text}.")
|
||||
else:
|
||||
explanation_parts.append(f"{match_level}.")
|
||||
|
||||
if recommendations:
|
||||
rec_text = recommendations[0] # Show the most important recommendation
|
||||
explanation_parts.append(rec_text + ".")
|
||||
|
||||
return " ".join(explanation_parts)
|
||||
@@ -1,205 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
|
||||
import requests
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
handlers=[logging.StreamHandler()],
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FolkAPI:
|
||||
BASE_URL = "https://api.folk.app/v1"
|
||||
|
||||
def __init__(self, api_key: str):
|
||||
api_key = os.environ.get("FOLK_API_KEY", api_key)
|
||||
self.headers = {"Authorization": f"Bearer {api_key}"}
|
||||
logger.info(f"FolkAPI initialized with API key: {api_key[:4]}***")
|
||||
|
||||
def get_groups(self):
|
||||
"""Fetch all groups from Folk."""
|
||||
url = f"{self.BASE_URL}/groups"
|
||||
response = requests.get(url, headers=self.headers)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def create_company(
|
||||
self,
|
||||
name: str,
|
||||
group_id: str = None,
|
||||
website: str = None,
|
||||
linkedin_url: str = None,
|
||||
description: str = None,
|
||||
emails=None,
|
||||
phones=None,
|
||||
addresses=None,
|
||||
urls=None,
|
||||
custom_field_values=None,
|
||||
groups=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Create a company (investor) in a specific group.
|
||||
|
||||
This method builds a payload matching Folk's Create Company API:
|
||||
https://developer.folk.app/api-reference/companies/create-a-company
|
||||
|
||||
It keeps backward compatibility with the previous `group_id`,
|
||||
`website` and `linkedin_url` arguments.
|
||||
"""
|
||||
url = f"{self.BASE_URL}/companies"
|
||||
|
||||
# Build the top-level payload expected by Folk
|
||||
data = {"name": name}
|
||||
if description:
|
||||
data["description"] = description
|
||||
|
||||
# Groups: prefer explicit `groups`, else fall back to `group_id`
|
||||
if groups:
|
||||
# Accept either list of ids or list of dicts
|
||||
formatted = []
|
||||
for g in groups:
|
||||
if isinstance(g, dict) and g.get("id"):
|
||||
formatted.append({"id": g["id"]})
|
||||
else:
|
||||
formatted.append({"id": str(g)})
|
||||
data["groups"] = formatted
|
||||
elif group_id:
|
||||
data["groups"] = [{"id": group_id}]
|
||||
|
||||
# Helper to normalize single or multiple inputs into lists
|
||||
def _to_list(val):
|
||||
if val is None:
|
||||
return None
|
||||
if isinstance(val, (list, tuple)):
|
||||
return [v for v in val if v is not None]
|
||||
return [val]
|
||||
|
||||
# URLs: include website and linkedin_url if provided and merge with urls
|
||||
urls_list = _to_list(urls) or []
|
||||
if website:
|
||||
urls_list.append(website)
|
||||
if linkedin_url:
|
||||
urls_list.append(linkedin_url)
|
||||
if urls_list:
|
||||
data["urls"] = urls_list
|
||||
|
||||
# Emails/phones/addresses
|
||||
emails_list = _to_list(emails)
|
||||
if emails_list:
|
||||
data["emails"] = emails_list
|
||||
phones_list = _to_list(phones)
|
||||
if phones_list:
|
||||
data["phones"] = phones_list
|
||||
addresses_list = _to_list(addresses)
|
||||
if addresses_list:
|
||||
data["addresses"] = addresses_list
|
||||
|
||||
# Custom field values follow the API's structure
|
||||
if custom_field_values:
|
||||
data["customFieldValues"] = custom_field_values
|
||||
|
||||
# Allow passing any additional top-level fields via kwargs (careful)
|
||||
for k, v in kwargs.items():
|
||||
# don't overwrite keys we explicitly set
|
||||
if k not in data:
|
||||
data[k] = v
|
||||
|
||||
response = requests.post(url, headers=self.headers, json=data)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
def create_person(
|
||||
self,
|
||||
first_name: str,
|
||||
last_name: str,
|
||||
email: str = None,
|
||||
company_id: str = None,
|
||||
group_id: str = None,
|
||||
linkedin_url: str = None,
|
||||
companies=None,
|
||||
emails=None,
|
||||
phones=None,
|
||||
addresses=None,
|
||||
urls=None,
|
||||
custom_field_values=None,
|
||||
groups=None,
|
||||
**kwargs,
|
||||
):
|
||||
"""Create a person in the workspace.
|
||||
|
||||
Builds payload matching Folk's Create Person API: use camelCase
|
||||
keys (firstName, lastName, groups, companies, emails, etc.).
|
||||
Keeps backward compatibility with `company_id` and `group_id`.
|
||||
"""
|
||||
url = f"{self.BASE_URL}/people"
|
||||
|
||||
data = {"firstName": first_name, "lastName": last_name}
|
||||
|
||||
# Groups: explicit `groups` preferred, else fallback to `group_id`
|
||||
if groups:
|
||||
formatted = []
|
||||
for g in groups:
|
||||
if isinstance(g, dict) and g.get("id"):
|
||||
formatted.append({"id": g["id"]})
|
||||
else:
|
||||
formatted.append({"id": str(g)})
|
||||
data["groups"] = formatted
|
||||
elif group_id:
|
||||
data["groups"] = [{"id": group_id}]
|
||||
|
||||
# Companies: keep backward compatibility with company_id
|
||||
if companies:
|
||||
formatted = []
|
||||
for c in companies:
|
||||
if isinstance(c, dict):
|
||||
formatted.append(c)
|
||||
elif isinstance(c, str):
|
||||
# treat as id
|
||||
formatted.append({"id": c})
|
||||
if formatted:
|
||||
data["companies"] = formatted
|
||||
elif company_id:
|
||||
data["companies"] = [{"id": company_id}]
|
||||
|
||||
# Helper to normalize into lists
|
||||
def _to_list(val):
|
||||
if val is None:
|
||||
return None
|
||||
if isinstance(val, (list, tuple)):
|
||||
return [v for v in val if v is not None]
|
||||
return [val]
|
||||
|
||||
emails_list = _to_list(emails) or []
|
||||
if email:
|
||||
emails_list.insert(0, email)
|
||||
if emails_list:
|
||||
data["emails"] = emails_list
|
||||
|
||||
phones_list = _to_list(phones)
|
||||
if phones_list:
|
||||
data["phones"] = phones_list
|
||||
addresses_list = _to_list(addresses)
|
||||
if addresses_list:
|
||||
data["addresses"] = addresses_list
|
||||
urls_list = _to_list(urls) or []
|
||||
if linkedin_url:
|
||||
urls_list.append(linkedin_url)
|
||||
if urls_list:
|
||||
data["urls"] = urls_list
|
||||
|
||||
if custom_field_values:
|
||||
data["customFieldValues"] = custom_field_values
|
||||
|
||||
# Allow passthrough of other top-level fields in kwargs
|
||||
for k, v in kwargs.items():
|
||||
if k not in data:
|
||||
data[k] = v
|
||||
|
||||
response = requests.post(url, headers=self.headers, json=data)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
@@ -1,181 +0,0 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
|
||||
from crawl4ai import AsyncWebCrawler
|
||||
from ddgs import DDGS
|
||||
from dotenv import load_dotenv
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
from schemas.insight_schema import InsightResponse
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s"
|
||||
)
|
||||
logger = logging.getLogger("web_search_agent")
|
||||
|
||||
load_dotenv()
|
||||
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
|
||||
|
||||
if not OPENROUTER_API_KEY:
|
||||
logger.warning("OPENROUTER_API_KEY not set. LLM calls will fail if invoked.")
|
||||
|
||||
|
||||
class QueryProcessor:
|
||||
def __init__(self):
|
||||
self.llm = ChatOpenAI(
|
||||
api_key=OPENROUTER_API_KEY,
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="openai/gpt-4o-mini",
|
||||
temperature=0,
|
||||
)
|
||||
self.agent = create_react_agent(
|
||||
model=self.llm,
|
||||
tools=[self.web_search],
|
||||
response_format=InsightResponse,
|
||||
)
|
||||
|
||||
self.ddg_search = DDGS()
|
||||
|
||||
async def crawl(self, url: str):
|
||||
"""Tool to search the web using a web crawler. given the url"""
|
||||
|
||||
logger.info(f"\nCrawl tool called with url: {url}")
|
||||
async with AsyncWebCrawler() as crawler:
|
||||
results = await crawler.arun(url)
|
||||
return results.markdown
|
||||
|
||||
def web_search(self, query: str):
|
||||
"""Tool to search the web using google, provide the relevant query to get the information"""
|
||||
logger.info(f"\nWeb Search Tool Called with query: {query}")
|
||||
if query:
|
||||
result = self.ddg_search.text(query, max_results=10)
|
||||
return result
|
||||
return "No query provided."
|
||||
|
||||
async def get_investor_insights(
|
||||
self,
|
||||
investor_name: str,
|
||||
investor_website: str = None,
|
||||
investor_description: str = None,
|
||||
investor_headquarters: str = None,
|
||||
investment_thesis: list = None,
|
||||
portfolio_highlights: list = None,
|
||||
) -> dict:
|
||||
"""
|
||||
Get investment pattern analysis and market position for an investor.
|
||||
|
||||
Args:
|
||||
investor_name: Name of the investor/VC firm
|
||||
investor_website: Website URL of the investor
|
||||
investor_description: Description of the investor
|
||||
investor_headquarters: Headquarters location
|
||||
investment_thesis: List of investment thesis statements
|
||||
portfolio_highlights: List of notable portfolio companies
|
||||
|
||||
Returns:
|
||||
Dictionary with investment_pattern_analysis and market_position
|
||||
"""
|
||||
logger.info(f"Getting insights for investor: {investor_name}")
|
||||
|
||||
# Build context information
|
||||
context_parts = [f'Investment Firm: "{investor_name}"']
|
||||
|
||||
if investor_website:
|
||||
context_parts.append(f"Website: {investor_website}")
|
||||
if investor_headquarters:
|
||||
context_parts.append(f"Location: {investor_headquarters}")
|
||||
if investor_description:
|
||||
context_parts.append(f"Description: {investor_description}")
|
||||
if investment_thesis and isinstance(investment_thesis, list):
|
||||
thesis_str = ", ".join(
|
||||
str(item) for item in investment_thesis[:3]
|
||||
) # Limit to first 3
|
||||
context_parts.append(f"Investment Focus: {thesis_str}")
|
||||
if portfolio_highlights and isinstance(portfolio_highlights, list):
|
||||
portfolio_str = ", ".join(
|
||||
str(item) for item in portfolio_highlights[:5]
|
||||
) # Limit to first 5
|
||||
context_parts.append(f"Notable Portfolio Companies: {portfolio_str}")
|
||||
|
||||
context = "\n".join(context_parts)
|
||||
|
||||
prompt = f"""
|
||||
Research and analyze the following investment firm:
|
||||
|
||||
{context}
|
||||
|
||||
CRITICAL INSTRUCTIONS:
|
||||
- You MUST provide concrete, data-driven insights with specific numbers and percentages
|
||||
- Use the web_search tool to find recent news, press releases, and investment databases (Crunchbase, PitchBook, etc.)
|
||||
- If you cannot find sufficient data after searching, make reasonable inferences based on available information
|
||||
- DO NOT state that data is unavailable or ambiguous - provide the best analysis possible with what you find
|
||||
- Focus on ACTIONABLE insights, not disclaimers
|
||||
- Only call the tool twice at most, be strategic in your searches
|
||||
- Summarize your findings concisely and clearly
|
||||
|
||||
Provide insights in the InsightResponse schema format:
|
||||
|
||||
1. investment_pattern_analysis (MAX 3 SENTENCES):
|
||||
- Recent investment activity and trends in the last 12-18 months
|
||||
- Investment size ranges, deal frequency, and sector preferences
|
||||
- Notable patterns (e.g., "increased AI investments by 40%", "average check size $5-10M")
|
||||
- If specific numbers aren't available, provide reasonable estimates based on portfolio and market position
|
||||
|
||||
2. market_position (MAX 3 SENTENCES):
|
||||
- Standing in the venture capital market
|
||||
- Activity level in specific sectors and notable unicorn investments
|
||||
- Deal leadership roles (lead vs co-lead) and market influence
|
||||
- Regional or global market presence and competitive positioning
|
||||
|
||||
Use the web_search tool strategically. Search for:
|
||||
- "{investor_name}" recent investments 2024 2025
|
||||
- "{investor_name}" portfolio Crunchbase
|
||||
- "{investor_name}" funding rounds news
|
||||
- Specific portfolio companies if mentioned above
|
||||
"""
|
||||
|
||||
try:
|
||||
result = await self.agent.ainvoke({"messages": [("user", prompt)]})
|
||||
# The agent with response_format=InsightResponse returns structured output
|
||||
logger.info(f"Raw agent result keys: {result.keys()}")
|
||||
|
||||
# Check if structured_response exists and is an InsightResponse object
|
||||
if "structured_response" in result:
|
||||
structured = result["structured_response"]
|
||||
logger.info(f"Structured response type: {type(structured)}")
|
||||
|
||||
# If it's already an InsightResponse object, convert to dict
|
||||
if isinstance(structured, InsightResponse):
|
||||
return structured.model_dump()
|
||||
# If it's already a dict, return it
|
||||
elif isinstance(structured, dict):
|
||||
return structured
|
||||
|
||||
# Fallback: shouldn't reach here, but handle it gracefully
|
||||
logger.warning("No structured_response found in result, using fallback")
|
||||
return {
|
||||
"investment_pattern_analysis": "Unable to retrieve investment pattern analysis at this time.",
|
||||
"market_position": "Unable to retrieve market position at this time.",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error getting insights for {investor_name}: {e}")
|
||||
logger.exception("Full exception details:")
|
||||
return {
|
||||
"investment_pattern_analysis": "Unable to retrieve investment pattern analysis at this time.",
|
||||
"market_position": "Unable to retrieve market position at this time.",
|
||||
}
|
||||
|
||||
|
||||
async def main():
|
||||
qp = QueryProcessor()
|
||||
result = await qp.agent.ainvoke(
|
||||
{"messages": [("user", "Can you tell me about 3T Finance investment company")]}
|
||||
)
|
||||
final_message = result["messages"][-1].content
|
||||
print(final_message)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
+83
-751
@@ -1,7 +1,5 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
import pandas as pd
|
||||
@@ -9,35 +7,15 @@ from db.db import get_db_session
|
||||
from db.models import (
|
||||
CompanyMember,
|
||||
CompanyTable,
|
||||
FundTable,
|
||||
InvestmentStageTable,
|
||||
InvestorMember,
|
||||
InvestorTable,
|
||||
SectorTable,
|
||||
)
|
||||
from langchain_openai import ChatOpenAI
|
||||
from pydantic import BaseModel
|
||||
from schemas.py_schemas import CompanyData, InvestorData
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
class CurrencyConversion(BaseModel):
|
||||
"""Schema for LLM currency conversion responses"""
|
||||
|
||||
amount_usd: int = 0
|
||||
confidence: str = "high" # high, medium, low
|
||||
notes: str = ""
|
||||
|
||||
|
||||
class CheckSizeRange(BaseModel):
|
||||
"""Schema for LLM check size range parsing from estimated investment size"""
|
||||
|
||||
lower_bound_usd: int = 0
|
||||
upper_bound_usd: int = 0
|
||||
confidence: str = "high" # high, medium, low
|
||||
notes: str = ""
|
||||
|
||||
|
||||
class InvestorProcessor:
|
||||
def __init__(self):
|
||||
self.llm = ChatOpenAI(
|
||||
@@ -47,534 +25,9 @@ class InvestorProcessor:
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
# Structured LLMs for specific parsing tasks
|
||||
self.currency_converter_llm = self.llm.with_structured_output(
|
||||
CurrencyConversion
|
||||
)
|
||||
self.check_size_parser_llm = self.llm.with_structured_output(CheckSizeRange)
|
||||
|
||||
# Keep legacy structured LLMs for backward compatibility
|
||||
self.investor_structured_llm = self.llm.with_structured_output(InvestorData)
|
||||
self.company_structured_llm = self.llm.with_structured_output(CompanyData)
|
||||
|
||||
async def convert_to_usd(self, amount_str: str) -> Optional[int]:
|
||||
"""
|
||||
Use LLM to convert currency amounts to USD integers.
|
||||
Handles formats like:
|
||||
- "EUR 850,000,000"
|
||||
- "$5M"
|
||||
- "GBP 10-20 million"
|
||||
- "Approximately EUR 100 million"
|
||||
"""
|
||||
if not amount_str or amount_str == "Not Available" or amount_str == "0":
|
||||
return None
|
||||
|
||||
try:
|
||||
prompt = f"""Convert this amount to USD as an integer (whole number, no decimals).
|
||||
If it's a range, use the midpoint. If already in USD, just extract the number.
|
||||
Remove all commas and convert millions/billions to actual numbers.
|
||||
|
||||
Amount: {amount_str}
|
||||
|
||||
Examples:
|
||||
- "EUR 850,000,000" -> 935000000 (assuming EUR to USD rate ~1.10)
|
||||
- "$5M" -> 5000000
|
||||
- "GBP 10-20 million" -> 18000000 (midpoint 15M * 1.20 rate)
|
||||
- "Approximately EUR 100 million" -> 110000000
|
||||
|
||||
Return only the USD integer amount with current exchange rates."""
|
||||
|
||||
result = await self.currency_converter_llm.ainvoke(prompt)
|
||||
return result.amount_usd if result.amount_usd > 0 else None
|
||||
except Exception as e:
|
||||
print(f"Error converting currency '{amount_str}': {e}")
|
||||
return None
|
||||
|
||||
async def parse_check_size_range(
|
||||
self, estimated_investment_str: str
|
||||
) -> tuple[Optional[int], Optional[int]]:
|
||||
"""
|
||||
Use LLM to parse check size range from estimated investment size string.
|
||||
Returns tuple of (lower_bound_usd, upper_bound_usd).
|
||||
|
||||
Handles formats like:
|
||||
- "EUR 1,000 to 2,000"
|
||||
- "$100K-$500K"
|
||||
- "Between $1M and $5M"
|
||||
- "Up to EUR 10 million"
|
||||
- "$2M typical"
|
||||
"""
|
||||
if (
|
||||
not estimated_investment_str
|
||||
or estimated_investment_str == "Not Available"
|
||||
or estimated_investment_str == "0"
|
||||
):
|
||||
return None, None
|
||||
|
||||
try:
|
||||
prompt = f"""Parse this check size/investment range into lower and upper bounds in USD as integers.
|
||||
|
||||
Input: {estimated_investment_str}
|
||||
|
||||
Instructions:
|
||||
- If it's a range (e.g., "EUR 1M to 5M"), extract both bounds
|
||||
- If it's a single amount (e.g., "$2M typical"), use it as both lower and upper
|
||||
- If it says "up to X", use 0 as lower and X as upper
|
||||
- Convert all currencies to USD using current exchange rates
|
||||
- Return integers (whole numbers, no decimals)
|
||||
|
||||
Examples:
|
||||
- "EUR 1,000 to 2,000" -> lower: 1100, upper: 2200
|
||||
- "$100K-$500K" -> lower: 100000, upper: 500000
|
||||
- "Between $1M and $5M" -> lower: 1000000, upper: 5000000
|
||||
- "Up to EUR 10 million" -> lower: 0, upper: 11000000
|
||||
- "$2M typical" -> lower: 2000000, upper: 2000000
|
||||
- "GBP 500K-2M" -> lower: 600000, upper: 2400000
|
||||
|
||||
Return the lower and upper bounds in USD."""
|
||||
|
||||
result = await self.check_size_parser_llm.ainvoke(prompt)
|
||||
lower = result.lower_bound_usd if result.lower_bound_usd > 0 else None
|
||||
upper = result.upper_bound_usd if result.upper_bound_usd > 0 else None
|
||||
return lower, upper
|
||||
except Exception as e:
|
||||
print(f"Error parsing check size range '{estimated_investment_str}': {e}")
|
||||
return None, None
|
||||
|
||||
def parse_json_profile(self, json_str: str) -> Optional[dict]:
|
||||
"""
|
||||
Manually parse the JSON profile from the CSV.
|
||||
Returns a cleaned dictionary with the investor profile data.
|
||||
Handles JSON wrapped in markdown code blocks (```json ... ```).
|
||||
Handles trailing quotes and extra data after JSON.
|
||||
"""
|
||||
if not json_str or pd.isna(json_str):
|
||||
return None
|
||||
|
||||
try:
|
||||
# Clean the JSON string
|
||||
cleaned_json = json_str.strip()
|
||||
|
||||
# Check if it's plain text (no JSON structure)
|
||||
if not cleaned_json.startswith(("{", "```", "'")):
|
||||
print(" ⚠️ No JSON structure found - skipping")
|
||||
return None
|
||||
|
||||
# Remove markdown code block markers if present
|
||||
if cleaned_json.startswith("```"):
|
||||
# Remove opening marker (```json or ```Json or ```)
|
||||
lines = cleaned_json.split("\n")
|
||||
if lines[0].startswith("```"):
|
||||
lines = lines[1:] # Remove first line
|
||||
# Remove closing marker (``` or ```')
|
||||
if lines and lines[-1].strip() in ("```", "```'", '```"'):
|
||||
lines = lines[:-1] # Remove last line
|
||||
cleaned_json = "\n".join(lines).strip()
|
||||
|
||||
# Remove trailing quotes that might be left over
|
||||
if cleaned_json.endswith(("'", '"')):
|
||||
cleaned_json = cleaned_json[:-1].strip()
|
||||
|
||||
# Try to find JSON boundaries if there's extra data
|
||||
# Look for the first { and the last }
|
||||
start_idx = cleaned_json.find("{")
|
||||
if start_idx == -1:
|
||||
print(" ⚠️ No opening brace found - not valid JSON")
|
||||
return None
|
||||
|
||||
# Find the matching closing brace
|
||||
# We need to count braces to find the actual end
|
||||
brace_count = 0
|
||||
end_idx = -1
|
||||
for i in range(start_idx, len(cleaned_json)):
|
||||
if cleaned_json[i] == "{":
|
||||
brace_count += 1
|
||||
elif cleaned_json[i] == "}":
|
||||
brace_count -= 1
|
||||
if brace_count == 0:
|
||||
end_idx = i + 1
|
||||
break
|
||||
|
||||
if end_idx == -1:
|
||||
print(" ⚠️ No matching closing brace found")
|
||||
return None
|
||||
|
||||
# Extract just the JSON part
|
||||
cleaned_json = cleaned_json[start_idx:end_idx]
|
||||
|
||||
# Parse JSON string
|
||||
profile = json.loads(cleaned_json)
|
||||
return profile
|
||||
except json.JSONDecodeError as e:
|
||||
print(f" ❌ JSON parsing error: {e}")
|
||||
# Print first 200 chars for debugging
|
||||
preview = json_str[:200] if len(json_str) > 200 else json_str
|
||||
print(f" Preview: {preview}...")
|
||||
return None
|
||||
except Exception as e:
|
||||
print(f" ❌ Unexpected error: {e}")
|
||||
return None
|
||||
|
||||
async def process_investor_profile(
|
||||
self, name: str, website: str, profile_json: str
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Process investor profile from CSV data.
|
||||
Manually extracts fields and uses LLM only for currency conversion.
|
||||
"""
|
||||
profile = self.parse_json_profile(profile_json)
|
||||
if not profile:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Extract basic info
|
||||
investor_data = {
|
||||
"name": name.strip() if name else None,
|
||||
"website": website.strip() if website else None,
|
||||
"headquarters": profile.get("headquarters"),
|
||||
"description": profile.get("investorDescription"),
|
||||
"aum": None,
|
||||
"aum_as_of_date": None,
|
||||
"aum_source_url": None,
|
||||
"investment_thesis": profile.get("investmentThesisFocus", []),
|
||||
"portfolio_highlights": profile.get("portfolioHighlights", []),
|
||||
"linked_documents": profile.get("linkedDocuments", []),
|
||||
"researcher_notes": profile.get("researcherNotes"),
|
||||
"missing_important_fields": profile.get("missingImportantFields", []),
|
||||
"sources": profile.get("sources", {}),
|
||||
"team_members": [],
|
||||
"funds": [],
|
||||
}
|
||||
|
||||
# Process AUM
|
||||
aum_data = profile.get("overallAssetsUnderManagement", {})
|
||||
if aum_data and isinstance(aum_data, dict):
|
||||
aum_amount = aum_data.get("aumAmount")
|
||||
if aum_amount and aum_amount != "Not Available":
|
||||
# Convert AUM to USD integer
|
||||
aum_usd = await self.convert_to_usd(aum_amount)
|
||||
investor_data["aum"] = aum_usd
|
||||
investor_data["aum_as_of_date"] = aum_data.get("asOfDate")
|
||||
investor_data["aum_source_url"] = aum_data.get("sourceUrl")
|
||||
|
||||
# Process senior leadership
|
||||
senior_leadership = profile.get("seniorLeadership", [])
|
||||
for member in senior_leadership:
|
||||
if isinstance(member, dict) and member.get("name"):
|
||||
investor_data["team_members"].append(
|
||||
{
|
||||
"name": member.get("name"),
|
||||
"title": member.get("title"),
|
||||
"role": member.get("title"), # Use title as role
|
||||
"email": None,
|
||||
"source_url": member.get("sourceUrl"),
|
||||
}
|
||||
)
|
||||
|
||||
# Process funds
|
||||
funds = profile.get("funds", [])
|
||||
for fund in funds:
|
||||
if isinstance(fund, dict):
|
||||
fund_data = {
|
||||
"fund_name": fund.get("fundName"),
|
||||
"fund_size": None,
|
||||
"fund_size_source_url": fund.get("fundSizeSourceUrl"),
|
||||
"check_size_lower": None,
|
||||
"check_size_upper": None,
|
||||
"source_url": fund.get("sourceUrl"),
|
||||
"source_provider": fund.get("sourceProvider"),
|
||||
"geographic_focus": None, # Will be converted to string
|
||||
"investment_stage_names": fund.get("investmentStageFocus", []),
|
||||
"sector_names": fund.get("sectorFocus", []),
|
||||
}
|
||||
|
||||
# Convert geographic focus from array to comma-separated string
|
||||
geo_focus = fund.get("geographicFocus", [])
|
||||
if geo_focus and isinstance(geo_focus, list):
|
||||
fund_data["geographic_focus"] = ", ".join(geo_focus)
|
||||
|
||||
# Convert fund size to USD integer
|
||||
fund_size_str = fund.get("fundSize")
|
||||
if fund_size_str and fund_size_str != "Not Available":
|
||||
fund_size_usd = await self.convert_to_usd(fund_size_str)
|
||||
if fund_size_usd:
|
||||
fund_data["fund_size"] = fund_size_usd # Store as integer
|
||||
|
||||
# Parse check size range from estimated investment size
|
||||
est_size_str = fund.get("estimatedInvestmentSize")
|
||||
if est_size_str and est_size_str != "Not Available":
|
||||
check_lower, check_upper = await self.parse_check_size_range(
|
||||
est_size_str
|
||||
)
|
||||
if check_lower is not None:
|
||||
fund_data["check_size_lower"] = check_lower
|
||||
if check_upper is not None:
|
||||
fund_data["check_size_upper"] = check_upper
|
||||
|
||||
investor_data["funds"].append(fund_data)
|
||||
|
||||
return investor_data
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing investor profile for {name}: {e}")
|
||||
return None
|
||||
|
||||
async def process_company_profile(
|
||||
self, name: str, website: str, profile_json: str, investor_names: str = None
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Process company profile from CSV data.
|
||||
Only extracts founded_year and key_executives - rest is in base database.
|
||||
"""
|
||||
profile = self.parse_json_profile(profile_json)
|
||||
if not profile:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Only extract founded_year and key_executives
|
||||
company_data = {
|
||||
"name": name.strip() if name else None,
|
||||
"founded_year": None,
|
||||
"key_executives": [],
|
||||
}
|
||||
|
||||
# Process key executives/leadership
|
||||
key_executives = profile.get("keyExecutives", [])
|
||||
if not key_executives:
|
||||
# Try alternative field names
|
||||
key_executives = profile.get("seniorLeadership", [])
|
||||
|
||||
for exec_member in key_executives:
|
||||
if isinstance(exec_member, dict) and exec_member.get("name"):
|
||||
company_data["key_executives"].append(
|
||||
{
|
||||
"name": exec_member.get("name"),
|
||||
"title": exec_member.get("title"),
|
||||
"source_url": exec_member.get("sourceUrl"),
|
||||
}
|
||||
)
|
||||
|
||||
# Try to extract founding year from description
|
||||
description = profile.get("companyDescription", "")
|
||||
if description:
|
||||
# Look for patterns like "founded in 2020", "Gegründet 2020", "founded 2020"
|
||||
year_patterns = [
|
||||
r"founded in (\d{4})",
|
||||
r"founded (\d{4})",
|
||||
r"Gegründet (\d{4})",
|
||||
r"established in (\d{4})",
|
||||
r"since (\d{4})",
|
||||
r"\((\d{4})\)", # Year in parentheses
|
||||
]
|
||||
for pattern in year_patterns:
|
||||
match = re.search(pattern, description, re.IGNORECASE)
|
||||
if match:
|
||||
try:
|
||||
year = int(match.group(1))
|
||||
if 1900 <= year <= 2025: # Sanity check
|
||||
company_data["founded_year"] = year
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return company_data
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing company profile for {name}: {e}")
|
||||
return None
|
||||
|
||||
def _save_parsed_company_to_db(
|
||||
self, db: Session, company_data: dict
|
||||
) -> Optional[CompanyTable]:
|
||||
"""Save manually parsed company data to database - only updates founded_year and key_executives"""
|
||||
try:
|
||||
# Check if company already exists (should exist in base database)
|
||||
existing_company = (
|
||||
db.query(CompanyTable).filter_by(name=company_data["name"]).first()
|
||||
)
|
||||
|
||||
if existing_company:
|
||||
# Update only founded_year on existing company
|
||||
company = existing_company
|
||||
updated_fields = []
|
||||
|
||||
if company_data.get("founded_year"):
|
||||
company.founded_year = company_data["founded_year"]
|
||||
updated_fields.append(
|
||||
f"founded_year: {company_data['founded_year']}"
|
||||
)
|
||||
|
||||
# Add/update company members (key executives)
|
||||
# First, remove existing members if updating
|
||||
db.query(CompanyMember).filter_by(company_id=company.id).delete()
|
||||
|
||||
exec_count = 0
|
||||
for exec_data in company_data.get("key_executives", []):
|
||||
member = CompanyMember(
|
||||
name=exec_data.get("name"),
|
||||
role=exec_data.get("title"),
|
||||
linkedin=exec_data.get(
|
||||
"source_url"
|
||||
), # Store source URL in linkedin field
|
||||
company_id=company.id,
|
||||
)
|
||||
db.add(member)
|
||||
exec_count += 1
|
||||
|
||||
if exec_count > 0:
|
||||
updated_fields.append(f"{exec_count} executives")
|
||||
|
||||
if updated_fields:
|
||||
print(f" 📝 Updated: {', '.join(updated_fields)}")
|
||||
|
||||
return company
|
||||
else:
|
||||
# Company not found in base database, skip
|
||||
print(" ⚠️ Not in database - skipping")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
print(f" ❌ Error saving: {e}")
|
||||
db.rollback()
|
||||
return None
|
||||
|
||||
def _save_parsed_investor_to_db(
|
||||
self, db: Session, investor_data: dict
|
||||
) -> Optional[InvestorTable]:
|
||||
"""Save manually parsed investor data to database"""
|
||||
try:
|
||||
# Check if investor already exists
|
||||
existing_investor = (
|
||||
db.query(InvestorTable).filter_by(name=investor_data["name"]).first()
|
||||
)
|
||||
|
||||
if existing_investor:
|
||||
# Update existing investor
|
||||
investor = existing_investor
|
||||
investor.website = investor_data.get("website") or investor.website
|
||||
investor.headquarters = (
|
||||
investor_data.get("headquarters") or investor.headquarters
|
||||
)
|
||||
investor.description = (
|
||||
investor_data.get("description") or investor.description
|
||||
)
|
||||
investor.aum = investor_data.get("aum") or investor.aum
|
||||
investor.aum_as_of_date = (
|
||||
investor_data.get("aum_as_of_date") or investor.aum_as_of_date
|
||||
)
|
||||
investor.aum_source_url = (
|
||||
investor_data.get("aum_source_url") or investor.aum_source_url
|
||||
)
|
||||
investor.investment_thesis = (
|
||||
investor_data.get("investment_thesis") or investor.investment_thesis
|
||||
)
|
||||
investor.portfolio_highlights = (
|
||||
investor_data.get("portfolio_highlights")
|
||||
or investor.portfolio_highlights
|
||||
)
|
||||
investor.linked_documents = (
|
||||
investor_data.get("linked_documents") or investor.linked_documents
|
||||
)
|
||||
investor.researcher_notes = (
|
||||
investor_data.get("researcher_notes") or investor.researcher_notes
|
||||
)
|
||||
investor.missing_important_fields = (
|
||||
investor_data.get("missing_important_fields")
|
||||
or investor.missing_important_fields
|
||||
)
|
||||
investor.sources = investor_data.get("sources") or investor.sources
|
||||
else:
|
||||
# Create new investor
|
||||
investor = InvestorTable(
|
||||
name=investor_data["name"],
|
||||
website=investor_data.get("website"),
|
||||
headquarters=investor_data.get("headquarters"),
|
||||
description=investor_data.get("description"),
|
||||
aum=investor_data.get("aum"),
|
||||
aum_as_of_date=investor_data.get("aum_as_of_date"),
|
||||
aum_source_url=investor_data.get("aum_source_url"),
|
||||
investment_thesis=investor_data.get("investment_thesis"),
|
||||
portfolio_highlights=investor_data.get("portfolio_highlights"),
|
||||
linked_documents=investor_data.get("linked_documents"),
|
||||
researcher_notes=investor_data.get("researcher_notes"),
|
||||
missing_important_fields=investor_data.get(
|
||||
"missing_important_fields"
|
||||
),
|
||||
sources=investor_data.get("sources"),
|
||||
)
|
||||
db.add(investor)
|
||||
db.flush()
|
||||
|
||||
# Add/update team members
|
||||
# First, remove existing team members if updating
|
||||
if existing_investor:
|
||||
db.query(InvestorMember).filter_by(investor_id=investor.id).delete()
|
||||
|
||||
for member_data in investor_data.get("team_members", []):
|
||||
member = InvestorMember(
|
||||
name=member_data.get("name"),
|
||||
role=member_data.get("role"),
|
||||
title=member_data.get("title"),
|
||||
email=member_data.get("email"),
|
||||
source_url=member_data.get("source_url"),
|
||||
investor_id=investor.id,
|
||||
)
|
||||
db.add(member)
|
||||
|
||||
# Add/update funds
|
||||
# First, remove existing funds if updating
|
||||
if existing_investor:
|
||||
db.query(FundTable).filter_by(investor_id=investor.id).delete()
|
||||
|
||||
for fund_data in investor_data.get("funds", []):
|
||||
fund = FundTable(
|
||||
investor_id=investor.id,
|
||||
fund_name=fund_data.get("fund_name"),
|
||||
fund_size=fund_data.get("fund_size"), # Now an integer
|
||||
fund_size_source_url=fund_data.get("fund_size_source_url"),
|
||||
check_size_lower=fund_data.get("check_size_lower"),
|
||||
check_size_upper=fund_data.get("check_size_upper"),
|
||||
source_url=fund_data.get("source_url"),
|
||||
source_provider=fund_data.get("source_provider"),
|
||||
geographic_focus=fund_data.get("geographic_focus"), # Now a string
|
||||
)
|
||||
db.add(fund)
|
||||
db.flush() # Get the fund ID
|
||||
|
||||
# Add investment stages (many-to-many)
|
||||
for stage_name in fund_data.get("investment_stage_names", []):
|
||||
stage = self._get_or_create_investment_stage(db, stage_name)
|
||||
fund.investment_stages.append(stage)
|
||||
|
||||
# Add sectors (many-to-many)
|
||||
for sector_name in fund_data.get("sector_names", []):
|
||||
sector = self._get_or_create_sector(db, sector_name)
|
||||
fund.sectors.append(sector)
|
||||
|
||||
return investor
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error saving investor to database: {e}")
|
||||
db.rollback()
|
||||
return None
|
||||
|
||||
def _get_or_create_investment_stage(
|
||||
self, db: Session, stage_name: str
|
||||
) -> InvestmentStageTable:
|
||||
"""Get existing investment stage or create new one"""
|
||||
from db.models import InvestmentStageTable
|
||||
|
||||
stage = (
|
||||
db.query(InvestmentStageTable)
|
||||
.filter(InvestmentStageTable.name == stage_name)
|
||||
.first()
|
||||
)
|
||||
if not stage:
|
||||
stage = InvestmentStageTable(name=stage_name)
|
||||
db.add(stage)
|
||||
db.flush() # Get the ID without committing
|
||||
return stage
|
||||
|
||||
def _get_or_create_sector(self, db: Session, sector_name: str) -> SectorTable:
|
||||
"""Get existing sector or create new one"""
|
||||
sector = db.query(SectorTable).filter(SectorTable.name == sector_name).first()
|
||||
@@ -596,6 +49,7 @@ Return the lower and upper bounds in USD."""
|
||||
check_size_lower=investor_data.investor.check_size_lower,
|
||||
check_size_upper=investor_data.investor.check_size_upper,
|
||||
geographic_focus=investor_data.investor.geographic_focus,
|
||||
stage_focus=investor_data.investor.stage_focus,
|
||||
number_of_investments=investor_data.investor.number_of_investments,
|
||||
)
|
||||
db.add(investor)
|
||||
@@ -719,263 +173,141 @@ Return the lower and upper bounds in USD."""
|
||||
print(f"Error processing row {row_idx + 1}: {e}")
|
||||
return None
|
||||
|
||||
async def _process_single_investor(
|
||||
self, idx: int, row: pd.Series, total_rows: int
|
||||
) -> Optional[dict]:
|
||||
"""Process a single investor row"""
|
||||
try:
|
||||
name = row.get("Name", "").strip() if pd.notna(row.get("Name")) else None
|
||||
website = (
|
||||
row.get("Website", "").strip() if pd.notna(row.get("Website")) else None
|
||||
)
|
||||
profile_json = (
|
||||
row.get("Final Investor Profile", "")
|
||||
if pd.notna(row.get("Final Investor Profile"))
|
||||
else None
|
||||
)
|
||||
|
||||
if not name or not profile_json:
|
||||
print(f"⚠️ Row {idx + 1}: Skipping - missing name or profile")
|
||||
return None
|
||||
|
||||
print(f"📊 Processing {idx + 1}/{total_rows}: {name}")
|
||||
|
||||
# Process the investor profile
|
||||
investor_data = await self.process_investor_profile(
|
||||
name, website, profile_json
|
||||
)
|
||||
|
||||
if investor_data:
|
||||
print(f" ✓ {name} parsed successfully")
|
||||
return investor_data
|
||||
else:
|
||||
print(f" ⚠️ {name} failed to process")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error processing row {idx + 1}: {e}")
|
||||
return None
|
||||
|
||||
async def parse_investors(
|
||||
self, df: pd.DataFrame, save_to_db: bool = True, batch_size: int = 10
|
||||
):
|
||||
"""
|
||||
Parse investors from DataFrame using manual JSON parsing and LLM for currency conversion.
|
||||
Processes multiple investors concurrently for better performance.
|
||||
Expected CSV columns: Name, Website, Final Investor Profile, Final Profile sourcing
|
||||
|
||||
Args:
|
||||
df: DataFrame with investor data
|
||||
save_to_db: Whether to save to database
|
||||
batch_size: Number of investors to process concurrently (default: 10)
|
||||
"""
|
||||
results = []
|
||||
async def parse_investors(self, df, save_to_db: bool = True):
|
||||
"""Parse investors from DataFrame and optionally save to database"""
|
||||
investors = []
|
||||
df = df[20:]
|
||||
db = None
|
||||
if save_to_db:
|
||||
db = get_db_session()
|
||||
|
||||
try:
|
||||
total_rows = len(df)
|
||||
print(
|
||||
f"\n🚀 Starting to process {total_rows} investors with batch size {batch_size}..."
|
||||
)
|
||||
# Process rows in batches asynchronously
|
||||
batch_size = 20 # Adjust batch size as needed
|
||||
rows = [(idx, row) for idx, row in df.iterrows()]
|
||||
|
||||
# Process in batches
|
||||
for batch_start in range(0, total_rows, batch_size):
|
||||
batch_end = min(batch_start + batch_size, total_rows)
|
||||
print(
|
||||
f"\n🔄 Processing batch {batch_start + 1}-{batch_end} of {total_rows}..."
|
||||
)
|
||||
for i in range(0, len(rows), batch_size):
|
||||
batch = rows[i : i + batch_size]
|
||||
|
||||
# Create tasks for concurrent processing
|
||||
tasks = []
|
||||
for idx in range(batch_start, batch_end):
|
||||
row = df.iloc[idx]
|
||||
task = self._process_single_investor(idx, row, total_rows)
|
||||
tasks.append(task)
|
||||
# Process batch asynchronously
|
||||
tasks = [
|
||||
self._process_row(row, idx, is_investor=True) for idx, row in batch
|
||||
]
|
||||
|
||||
# Process batch concurrently
|
||||
batch_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Filter out None results and exceptions, then save to database
|
||||
for investor_data in batch_results:
|
||||
if investor_data and not isinstance(investor_data, Exception):
|
||||
results.append(investor_data)
|
||||
# Handle results from batch
|
||||
for (idx, row), result in zip(batch, batch_results):
|
||||
if isinstance(result, Exception):
|
||||
print(f"Error processing row {idx}: {result}")
|
||||
if db:
|
||||
db.rollback()
|
||||
continue
|
||||
|
||||
# Save to database
|
||||
if result:
|
||||
# Convert dict to InvestorData if needed
|
||||
if isinstance(result, dict):
|
||||
investor_data = InvestorData(**result)
|
||||
else:
|
||||
investor_data = result
|
||||
|
||||
investors.append(investor_data)
|
||||
|
||||
# Save to database if requested
|
||||
if save_to_db and db:
|
||||
try:
|
||||
saved_investor = self._save_parsed_investor_to_db(
|
||||
saved_investor = self._save_investor_to_db(
|
||||
db, investor_data
|
||||
)
|
||||
if saved_investor:
|
||||
print(
|
||||
f" ✅ Saved {investor_data['name']} to database (ID: {saved_investor.id})"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f" ❌ Failed to save {investor_data['name']} to database"
|
||||
)
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
print(
|
||||
f" ❌ Database error for {investor_data['name']}: {e}"
|
||||
)
|
||||
elif isinstance(investor_data, Exception):
|
||||
print(f" ❌ Exception occurred: {investor_data}")
|
||||
|
||||
# Commit batch to database
|
||||
if save_to_db and db:
|
||||
try:
|
||||
db.commit()
|
||||
print(f"💾 Committed batch {batch_start + 1}-{batch_end}")
|
||||
print(
|
||||
f"✅ Saved investor '{saved_investor.name}' to database"
|
||||
)
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
print(f"❌ Failed to commit batch: {e}")
|
||||
print(f"❌ Failed to save investor to database: {e}")
|
||||
|
||||
print(
|
||||
f"Completed batch {i // batch_size + 1} of {(len(rows) + batch_size - 1) // batch_size}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Fatal error in parse_investors: {e}")
|
||||
print(f"Error in batch processing: {e}")
|
||||
if db:
|
||||
db.rollback()
|
||||
finally:
|
||||
if db:
|
||||
db.close()
|
||||
|
||||
print(f"\n🎉 Completed! Processed {len(results)}/{total_rows} investors")
|
||||
return results
|
||||
return investors
|
||||
|
||||
async def _process_single_company(
|
||||
self, idx: int, row: pd.Series, total_rows: int
|
||||
) -> Optional[dict]:
|
||||
"""Process a single company row"""
|
||||
try:
|
||||
name = row.get("Name", "").strip() if pd.notna(row.get("Name")) else None
|
||||
website = (
|
||||
row.get("Website", "").strip() if pd.notna(row.get("Website")) else None
|
||||
)
|
||||
investor_names = (
|
||||
row.get("Investor", "").strip()
|
||||
if pd.notna(row.get("Investor"))
|
||||
else None
|
||||
)
|
||||
# Try both column names for flexibility
|
||||
profile_json = (
|
||||
row.get("Perplexity Gap Output", "")
|
||||
if pd.notna(row.get("Perplexity Gap Output"))
|
||||
else row.get("Final Investor Profile", "")
|
||||
if pd.notna(row.get("Final Investor Profile"))
|
||||
else None
|
||||
)
|
||||
|
||||
if not name or not profile_json:
|
||||
print(f"⚠️ Row {idx + 1}: Skipping - missing name or profile")
|
||||
return None
|
||||
|
||||
print(f"📊 Processing {idx + 1}/{total_rows}: {name}")
|
||||
|
||||
# Process the company profile
|
||||
company_data = await self.process_company_profile(
|
||||
name, website, profile_json, investor_names
|
||||
)
|
||||
|
||||
if company_data:
|
||||
print(f" ✓ {name} parsed successfully")
|
||||
return company_data
|
||||
else:
|
||||
print(f" ⚠️ {name} failed to process")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error processing row {idx + 1}: {e}")
|
||||
return None
|
||||
|
||||
async def parse_companies(
|
||||
self, df: pd.DataFrame, save_to_db: bool = True, batch_size: int = 10
|
||||
):
|
||||
"""
|
||||
Parse companies from DataFrame using manual JSON parsing.
|
||||
Processes multiple companies concurrently for better performance.
|
||||
Expected CSV columns: Name, Website, Investor, Final Investor Profile (actually company profile)
|
||||
|
||||
Args:
|
||||
df: DataFrame with company data
|
||||
save_to_db: Whether to save to database
|
||||
batch_size: Number of companies to process concurrently (default: 10)
|
||||
"""
|
||||
results = []
|
||||
async def parse_companies(self, df, save_to_db: bool = True):
|
||||
"""Parse companies from DataFrame and optionally save to database"""
|
||||
companies = []
|
||||
df = df[20:]
|
||||
db = None
|
||||
if save_to_db:
|
||||
db = get_db_session()
|
||||
|
||||
try:
|
||||
total_rows = len(df)
|
||||
print(
|
||||
f"\n🚀 Starting to process {total_rows} companies with batch size {batch_size}..."
|
||||
)
|
||||
# Process rows in batches asynchronously
|
||||
batch_size = 20 # Adjust batch size as needed
|
||||
rows = [(idx, row) for idx, row in df.iterrows()]
|
||||
|
||||
# Process in batches
|
||||
for batch_start in range(0, total_rows, batch_size):
|
||||
batch_end = min(batch_start + batch_size, total_rows)
|
||||
print(
|
||||
f"\n🔄 Processing batch {batch_start + 1}-{batch_end} of {total_rows}..."
|
||||
)
|
||||
for i in range(0, len(rows), batch_size):
|
||||
batch = rows[i : i + batch_size]
|
||||
|
||||
# Create tasks for concurrent processing
|
||||
tasks = []
|
||||
for idx in range(batch_start, batch_end):
|
||||
row = df.iloc[idx]
|
||||
task = self._process_single_company(idx, row, total_rows)
|
||||
tasks.append(task)
|
||||
# Process batch asynchronously
|
||||
tasks = [
|
||||
self._process_row(row, idx, is_investor=False) for idx, row in batch
|
||||
]
|
||||
|
||||
# Process batch concurrently
|
||||
batch_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Filter out None results and exceptions, then save to database
|
||||
for company_data in batch_results:
|
||||
if company_data and not isinstance(company_data, Exception):
|
||||
results.append(company_data)
|
||||
# Handle results from batch
|
||||
for (idx, row), result in zip(batch, batch_results):
|
||||
if isinstance(result, Exception):
|
||||
print(f"Error processing row {idx}: {result}")
|
||||
if db:
|
||||
db.rollback()
|
||||
continue
|
||||
|
||||
# Save to database
|
||||
if result:
|
||||
# Convert dict to CompanyData if needed
|
||||
if isinstance(result, dict):
|
||||
company_data = CompanyData(**result)
|
||||
else:
|
||||
company_data = result
|
||||
|
||||
companies.append(company_data)
|
||||
|
||||
# Save to database if requested
|
||||
if save_to_db and db:
|
||||
try:
|
||||
saved_company = self._save_parsed_company_to_db(
|
||||
saved_company = self._save_company_to_db(
|
||||
db, company_data
|
||||
)
|
||||
if saved_company:
|
||||
print(
|
||||
f" ✅ Saved {company_data['name']} to database (ID: {saved_company.id})"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f" ❌ Failed to save {company_data['name']} to database"
|
||||
)
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
print(
|
||||
f" ❌ Database error for {company_data['name']}: {e}"
|
||||
)
|
||||
elif isinstance(company_data, Exception):
|
||||
print(f" ❌ Exception occurred: {company_data}")
|
||||
|
||||
# Commit batch to database
|
||||
if save_to_db and db:
|
||||
try:
|
||||
db.commit()
|
||||
print(f"💾 Committed batch {batch_start + 1}-{batch_end}")
|
||||
print(
|
||||
f"✅ Saved company '{saved_company.name}' to database"
|
||||
)
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
print(f"❌ Failed to commit batch: {e}")
|
||||
print(f"❌ Failed to save company to database: {e}")
|
||||
|
||||
print(
|
||||
f"Completed batch {i // batch_size + 1} of {(len(rows) + batch_size - 1) // batch_size}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Fatal error in parse_companies: {e}")
|
||||
print(f"Error processing row {idx}: {e}")
|
||||
if db:
|
||||
db.rollback()
|
||||
finally:
|
||||
if db:
|
||||
db.close()
|
||||
|
||||
print(f"\n🎉 Completed! Processed {len(results)}/{total_rows} companies")
|
||||
return results
|
||||
return companies
|
||||
|
||||
|
||||
# async def main():
|
||||
|
||||
+76
-253
@@ -1,25 +1,19 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
from typing import List, Optional
|
||||
from typing import List
|
||||
|
||||
from db.db import get_db
|
||||
from db.models import FundTable, InvestorTable, ProjectTable
|
||||
from langchain_core.prompts import ChatPromptTemplate
|
||||
from db.db import DATABASE_URL, get_db
|
||||
from db.models import InvestorTable
|
||||
from langchain import hub
|
||||
from langchain_community.agent_toolkits import SQLDatabaseToolkit
|
||||
from langchain_community.utilities import SQLDatabase
|
||||
from langchain_openai import ChatOpenAI
|
||||
from schemas.router_schemas import (
|
||||
CompanyMinimal,
|
||||
InvestmentResponse,
|
||||
PaginatedResponse,
|
||||
SectorMinimal,
|
||||
)
|
||||
from sqlalchemy import text
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
from schemas.py_schemas import InvestorData, InvestorList
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from services.compatibility_score import calculate_project_investor_compatibility
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
# Connect to SQLite
|
||||
prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")
|
||||
db = SQLDatabase.from_uri(DATABASE_URL)
|
||||
|
||||
|
||||
class QueryProcessor:
|
||||
@@ -30,266 +24,95 @@ class QueryProcessor:
|
||||
model="openai/gpt-4o-mini",
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
# Query cache for performance
|
||||
self.query_cache = {}
|
||||
|
||||
# SQL generation prompt
|
||||
self.sql_prompt = ChatPromptTemplate.from_messages(
|
||||
[
|
||||
(
|
||||
"system",
|
||||
"""You are a SQL expert. Generate a SQLite query to find fund IDs based on user requirements.
|
||||
|
||||
Database Schema:
|
||||
- funds: id, fund_name, investor_id, check_size_lower, check_size_upper, geographic_focus
|
||||
- fund_sectors: fund_id, sector_id
|
||||
- fund_investment_stages: fund_id, stage_id
|
||||
- sectors: id, name
|
||||
- investment_stages: id, name
|
||||
- investors: id, name, aum
|
||||
|
||||
IMPORTANT RULES:
|
||||
1. ALWAYS return ONLY fund IDs (funds.id) - use SELECT DISTINCT f.id
|
||||
2. For geography: Be FLEXIBLE - use OR with variations and partial matches
|
||||
- 'Europe' → WHERE geographic_focus LIKE '%Europe%' OR geographic_focus LIKE '%European%'
|
||||
- 'America' → WHERE geographic_focus LIKE '%America%' OR geographic_focus LIKE '%US%' OR geographic_focus LIKE '%United States%'
|
||||
- 'Asia' → WHERE geographic_focus LIKE '%Asia%' OR geographic_focus LIKE '%Asian%'
|
||||
- If no geography specified, DON'T filter by geography
|
||||
3. For stages: Use LEFT JOIN and LIKE for flexible matching with synonyms
|
||||
- 'Seed' → s.name LIKE '%Seed%' OR s.name LIKE '%Pre-Seed%' OR s.name LIKE '%Early%'
|
||||
- 'Series A' → s.name LIKE '%Series A%' OR s.name LIKE '%A%'
|
||||
- 'Growth' → s.name LIKE '%Growth%' OR s.name LIKE '%Late%' OR s.name LIKE '%Expansion%'
|
||||
- If stage not specified, include ALL funds
|
||||
4. For sectors: Use LEFT JOIN and include related terms with OR
|
||||
- 'Fintech' → sec.name LIKE '%Fintech%' OR sec.name LIKE '%Finance%' OR sec.name LIKE '%Financial%'
|
||||
- 'AI' → sec.name LIKE '%AI%' OR sec.name LIKE '%Artificial Intelligence%' OR sec.name LIKE '%Machine Learning%' OR sec.name LIKE '%ML%'
|
||||
- 'Healthcare' → sec.name LIKE '%Healthcare%' OR sec.name LIKE '%Health%' OR sec.name LIKE '%Medical%' OR sec.name LIKE '%Biotech%'
|
||||
5. For check size filters (be flexible with ranges):
|
||||
- "under X" → WHERE (check_size_upper <= X OR check_size_upper IS NULL)
|
||||
- "over X" → WHERE (check_size_lower >= X OR check_size_lower IS NULL)
|
||||
- "between X and Y" → WHERE check_size_lower >= X AND check_size_upper <= Y
|
||||
6. Use LEFT JOIN for stages and sectors so funds without tags still match
|
||||
7. Use DISTINCT to avoid duplicates from joins
|
||||
8. Be INCLUSIVE - use OR conditions to cast a wider net
|
||||
9. If query is very simple (e.g., just "seed stage"), don't add unnecessary filters
|
||||
10. Return a single, complete SELECT query
|
||||
|
||||
Example Queries:
|
||||
Q: "Seed stage investors in Europe"
|
||||
A: SELECT DISTINCT f.id FROM funds f
|
||||
LEFT JOIN fund_investment_stages fis ON f.id = fis.fund_id
|
||||
LEFT JOIN investment_stages s ON fis.stage_id = s.id
|
||||
WHERE (s.name LIKE '%Seed%' OR s.name LIKE '%Pre-Seed%' OR s.name LIKE '%Early%' OR s.id IS NULL)
|
||||
AND (f.geographic_focus LIKE '%Europe%' OR f.geographic_focus LIKE '%European%')
|
||||
|
||||
Q: "Fintech investors with check size under 5 million"
|
||||
A: SELECT DISTINCT f.id FROM funds f
|
||||
LEFT JOIN fund_sectors fs ON f.id = fs.fund_id
|
||||
LEFT JOIN sectors sec ON fs.sector_id = sec.id
|
||||
WHERE (sec.name LIKE '%Fintech%' OR sec.name LIKE '%Finance%' OR sec.name LIKE '%Financial%' OR sec.id IS NULL)
|
||||
AND (f.check_size_upper <= 5000000 OR f.check_size_upper IS NULL)
|
||||
|
||||
Q: "Seed stage investors"
|
||||
A: SELECT DISTINCT f.id FROM funds f
|
||||
LEFT JOIN fund_investment_stages fis ON f.id = fis.fund_id
|
||||
LEFT JOIN investment_stages s ON fis.stage_id = s.id
|
||||
WHERE s.name LIKE '%Seed%' OR s.name LIKE '%Pre-Seed%' OR s.name LIKE '%Early%'
|
||||
|
||||
Q: "Growth stage investors"
|
||||
A: SELECT DISTINCT f.id FROM funds f
|
||||
LEFT JOIN fund_investment_stages fis ON f.id = fis.fund_id
|
||||
LEFT JOIN investment_stages s ON fis.stage_id = s.id
|
||||
WHERE s.name LIKE '%Growth%' OR s.name LIKE '%Late%' OR s.name LIKE '%Expansion%' OR s.name LIKE '%Series C%' OR s.name LIKE '%Series D%'
|
||||
|
||||
Q: "AI investors in America"
|
||||
A: SELECT DISTINCT f.id FROM funds f
|
||||
LEFT JOIN fund_sectors fs ON f.id = fs.fund_id
|
||||
LEFT JOIN sectors sec ON fs.sector_id = sec.id
|
||||
WHERE (sec.name LIKE '%AI%' OR sec.name LIKE '%Artificial Intelligence%' OR sec.name LIKE '%Machine Learning%' OR sec.name LIKE '%ML%')
|
||||
AND (f.geographic_focus LIKE '%America%' OR f.geographic_focus LIKE '%US%' OR f.geographic_focus LIKE '%United States%' OR f.geographic_focus LIKE '%USA%')
|
||||
|
||||
Q: "Healthcare investors"
|
||||
A: SELECT DISTINCT f.id FROM funds f
|
||||
LEFT JOIN fund_sectors fs ON f.id = fs.fund_id
|
||||
LEFT JOIN sectors sec ON fs.sector_id = sec.id
|
||||
WHERE sec.name LIKE '%Healthcare%' OR sec.name LIKE '%Health%' OR sec.name LIKE '%Medical%' OR sec.name LIKE '%Biotech%' OR sec.name LIKE '%Pharma%'
|
||||
|
||||
IMPORTANT: Use LEFT JOIN so funds without sector/stage tags can still match. Include synonym terms with OR for better recall.
|
||||
|
||||
Return ONLY the SQL query, no explanations or markdown.""",
|
||||
),
|
||||
("user", "{question}"),
|
||||
]
|
||||
self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)
|
||||
# Update system message to specifically request only investor IDs
|
||||
system_message_updated = (
|
||||
prompt_template.format(dialect="SQLite", top_k=5)
|
||||
+ "\n\nIMPORTANT: You must ONLY return the investor IDs (id field) that match the user's criteria. "
|
||||
+ "Do NOT return any other information, explanations, or data. "
|
||||
+ "Your response should be ONLY a comma-separated list of numbers representing the investor IDs. "
|
||||
+ "Example format: 1, 5, 12, 23"
|
||||
)
|
||||
self.agent = create_react_agent(
|
||||
model=self.llm,
|
||||
tools=self.toolkit.get_tools(),
|
||||
prompt=system_message_updated,
|
||||
)
|
||||
|
||||
def _get_cache_key(self, question: str) -> str:
|
||||
"""Generate cache key from normalized question."""
|
||||
return hashlib.md5(question.lower().strip().encode()).hexdigest()
|
||||
def process_query(self, question: str) -> InvestorList:
|
||||
"""Process a query using the LLM and return investor data."""
|
||||
# Let the LLM handle all database interactions and filtering to get IDs
|
||||
response = self.agent.invoke(
|
||||
{"messages": [("user", question)]},
|
||||
)
|
||||
|
||||
async def process_query(
|
||||
self, question: str, project_id: Optional[int] = None
|
||||
) -> PaginatedResponse[InvestmentResponse]:
|
||||
"""Async wrapper for process_query. Runs blocking work in a thread to avoid
|
||||
blocking the event loop.
|
||||
"""
|
||||
return await asyncio.to_thread(self._process_query_sync, question, project_id)
|
||||
# Extract the actual message content
|
||||
ai_response = (
|
||||
response["messages"][-1].content if response.get("messages") else ""
|
||||
)
|
||||
|
||||
def _process_query_sync(
|
||||
self, question: str, project_id: Optional[int] = None
|
||||
) -> PaginatedResponse[InvestmentResponse]:
|
||||
"""Synchronous implementation of process_query. This is run in a thread by
|
||||
the async wrapper above.
|
||||
"""
|
||||
cache_key = self._get_cache_key(question)
|
||||
# Extract investor IDs from the AI response
|
||||
investor_ids = self._extract_investor_ids_from_response(ai_response)
|
||||
|
||||
# Check cache first
|
||||
if cache_key in self.query_cache:
|
||||
sql_query = self.query_cache[cache_key]
|
||||
logger.info(f"Using cached SQL: {sql_query}")
|
||||
else:
|
||||
# Generate SQL query
|
||||
messages = self.sql_prompt.format_messages(question=question)
|
||||
response = self.llm.invoke(messages)
|
||||
sql_query = response.content.strip()
|
||||
# Fetch full investor data using the IDs
|
||||
return self._fetch_investors_by_ids(investor_ids)
|
||||
|
||||
# Clean up SQL (remove markdown code blocks if present)
|
||||
sql_query = sql_query.replace("```sql", "").replace("```", "").strip()
|
||||
def _extract_investor_ids_from_response(self, ai_response: str) -> List[int]:
|
||||
"""Extract investor IDs from AI response."""
|
||||
import re
|
||||
|
||||
# Cache the query
|
||||
self.query_cache[cache_key] = sql_query
|
||||
logger.info(f"Generated SQL: {sql_query}")
|
||||
|
||||
# Execute query to get fund IDs
|
||||
db_session = next(get_db())
|
||||
investor_ids = []
|
||||
try:
|
||||
result = db_session.execute(text(sql_query))
|
||||
fund_ids = [row[0] for row in result.fetchall()]
|
||||
logger.info(
|
||||
f"Found {len(fund_ids)} fund IDs: {fund_ids[:10]}{'...' if len(fund_ids) > 10 else ''}"
|
||||
)
|
||||
# Try multiple patterns to extract IDs from the response
|
||||
# Pattern 1: Simple numbers (assuming they are IDs)
|
||||
numbers = re.findall(r"\b\d+\b", ai_response)
|
||||
investor_ids = [int(num) for num in numbers]
|
||||
|
||||
# Pattern 2: If response contains explicit ID references
|
||||
id_matches = re.findall(r"\bid[:\s]*(\d+)", ai_response.lower())
|
||||
if id_matches:
|
||||
investor_ids = [int(id_str) for id_str in id_matches]
|
||||
|
||||
return self._fetch_funds_by_ids(fund_ids, project_id)
|
||||
except Exception as e:
|
||||
logger.error(f"SQL execution error: {e}")
|
||||
logger.error(f"Failed SQL: {sql_query}")
|
||||
# Return empty result
|
||||
return PaginatedResponse(
|
||||
items=[], total=0, page=1, page_size=10, total_pages=0
|
||||
)
|
||||
finally:
|
||||
db_session.close()
|
||||
print(f"Error extracting IDs from response: {e}")
|
||||
return []
|
||||
|
||||
def _fetch_funds_by_ids(
|
||||
self, fund_ids: List[int], project_id: Optional[int] = None
|
||||
) -> PaginatedResponse[InvestmentResponse]:
|
||||
"""Fetch funds with all their relationships from the database using fund IDs.
|
||||
Constructs response similar to read_investors but starting from funds.
|
||||
return investor_ids
|
||||
|
||||
Args:
|
||||
fund_ids: List of fund IDs to fetch
|
||||
project_id: Optional project ID for compatibility scoring
|
||||
"""
|
||||
if not fund_ids:
|
||||
return PaginatedResponse(
|
||||
items=[],
|
||||
total=0,
|
||||
page=1,
|
||||
page_size=len(fund_ids) if fund_ids else 10,
|
||||
total_pages=0,
|
||||
)
|
||||
def _fetch_investors_by_ids(self, investor_ids: List[int]) -> InvestorList:
|
||||
"""Fetch investors with all their relationships from the database using IDs."""
|
||||
if not investor_ids:
|
||||
return InvestorList(investors=[])
|
||||
|
||||
# Get database session
|
||||
db_session = next(get_db())
|
||||
|
||||
try:
|
||||
# Load project if project_id provided
|
||||
project = None
|
||||
if project_id is not None:
|
||||
project = (
|
||||
db_session.query(ProjectTable)
|
||||
.options(selectinload(ProjectTable.sector))
|
||||
.filter(ProjectTable.id == project_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
# Query funds with all necessary relationships loaded
|
||||
funds = (
|
||||
db_session.query(FundTable)
|
||||
# Build query with all relationships loaded
|
||||
query = (
|
||||
db_session.query(InvestorTable)
|
||||
.options(
|
||||
selectinload(FundTable.investor).selectinload(
|
||||
InvestorTable.portfolio_companies
|
||||
),
|
||||
selectinload(FundTable.investor).selectinload(
|
||||
InvestorTable.team_members
|
||||
),
|
||||
selectinload(FundTable.investor).selectinload(
|
||||
InvestorTable.sectors
|
||||
),
|
||||
selectinload(FundTable.investment_stages),
|
||||
selectinload(FundTable.sectors),
|
||||
selectinload(InvestorTable.portfolio_companies),
|
||||
selectinload(InvestorTable.team_members),
|
||||
selectinload(InvestorTable.sectors),
|
||||
)
|
||||
.filter(FundTable.id.in_(fund_ids))
|
||||
.all()
|
||||
.filter(InvestorTable.id.in_(investor_ids))
|
||||
)
|
||||
|
||||
# Transform to InvestmentResponse format (one row per fund)
|
||||
investment_responses = []
|
||||
for fund in funds:
|
||||
investor = fund.investor
|
||||
investors = query.all()
|
||||
|
||||
# Calculate compatibility score if project provided
|
||||
compatibility_score = 1.0
|
||||
if project is not None:
|
||||
compatibility_score = calculate_project_investor_compatibility(
|
||||
project=project, investor=investor, use_funds=True
|
||||
# Transform to InvestorData format
|
||||
investor_data_list = []
|
||||
for investor in investors:
|
||||
investor_data = InvestorData(
|
||||
investor=investor,
|
||||
portfolio_companies=investor.portfolio_companies,
|
||||
team_members=investor.team_members,
|
||||
sectors=investor.sectors,
|
||||
)
|
||||
investor_data_list.append(investor_data)
|
||||
|
||||
# Get top 3 portfolio companies (id and name only)
|
||||
portfolio_companies = [
|
||||
CompanyMinimal(id=company.id, name=company.name)
|
||||
for company in investor.portfolio_companies[:3]
|
||||
]
|
||||
|
||||
# Get stage focus as comma-separated string
|
||||
stage_focus = (
|
||||
", ".join([stage.name for stage in fund.investment_stages])
|
||||
if fund.investment_stages
|
||||
else None
|
||||
)
|
||||
|
||||
# Get top 3 sectors from fund (id and name only) - sorted alphabetically
|
||||
fund_sectors = [
|
||||
SectorMinimal(id=sector.id, name=sector.name)
|
||||
for sector in sorted(fund.sectors[:3] if fund.sectors else [], key=lambda s: s.name)
|
||||
]
|
||||
|
||||
investment_response = InvestmentResponse(
|
||||
id=investor.id,
|
||||
name=f"{investor.name} - {fund.fund_name}"
|
||||
if fund.fund_name
|
||||
else investor.name,
|
||||
aum=investor.aum,
|
||||
check_size_lower=fund.check_size_lower,
|
||||
check_size_upper=fund.check_size_upper,
|
||||
geographic_focus=fund.geographic_focus,
|
||||
stage_focus=stage_focus,
|
||||
portfolio_companies=portfolio_companies,
|
||||
sectors=fund_sectors,
|
||||
compatibility_score=compatibility_score,
|
||||
)
|
||||
investment_responses.append(investment_response)
|
||||
|
||||
total_count = len(investment_responses)
|
||||
total_pages = 1 if total_count > 0 else 0
|
||||
|
||||
return PaginatedResponse(
|
||||
items=investment_responses,
|
||||
total=total_count,
|
||||
page=1,
|
||||
page_size=total_count,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
return InvestorList(investors=investor_data_list)
|
||||
|
||||
finally:
|
||||
db_session.close()
|
||||
|
||||
@@ -1,340 +0,0 @@
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
# Import database models and compatibility score service
|
||||
from db.models import InvestorTable, ProjectTable
|
||||
from jinja2 import Environment, FileSystemLoader
|
||||
from playwright.async_api import async_playwright
|
||||
|
||||
from services.compatibility_score import calculate_project_investor_compatibility
|
||||
|
||||
|
||||
class ReportGenerator:
|
||||
"""Service for generating PDF reports from HTML templates"""
|
||||
|
||||
def __init__(self):
|
||||
# Set up Jinja2 environment
|
||||
template_dir = Path(__file__).parent.parent / "templates"
|
||||
self.env = Environment(loader=FileSystemLoader(str(template_dir)))
|
||||
|
||||
async def generate_investor_report(
|
||||
self,
|
||||
investor_data: Dict[str, Any],
|
||||
project_data: Optional[Dict[str, Any]] = None,
|
||||
investor_model: Optional[InvestorTable] = None,
|
||||
project_model: Optional[ProjectTable] = None,
|
||||
) -> bytes:
|
||||
"""
|
||||
Generate a PDF report for an investor profile.
|
||||
|
||||
Args:
|
||||
investor_data: Dictionary containing investor information
|
||||
project_data: Optional dictionary containing project information for compatibility analysis
|
||||
investor_model: Optional database model for investor (used for compatibility scoring)
|
||||
project_model: Optional database model for project (used for compatibility scoring)
|
||||
|
||||
Returns:
|
||||
bytes: PDF file content
|
||||
"""
|
||||
# Prepare template context
|
||||
context = self._prepare_context(
|
||||
investor_data, project_data, investor_model, project_model
|
||||
)
|
||||
|
||||
# Render HTML from template
|
||||
template = self.env.get_template("report.html")
|
||||
html_content = template.render(**context)
|
||||
# Convert HTML to PDF using Playwright
|
||||
pdf_bytes = await self._html_to_pdf(html_content)
|
||||
|
||||
return pdf_bytes
|
||||
|
||||
def _prepare_context(
|
||||
self,
|
||||
investor_data: Dict[str, Any],
|
||||
project_data: Optional[Dict[str, Any]] = None,
|
||||
investor_model: Optional[InvestorTable] = None,
|
||||
project_model: Optional[ProjectTable] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Prepare the context dictionary for template rendering"""
|
||||
context = {
|
||||
"investor": investor_data,
|
||||
"project": project_data,
|
||||
"compatibility_score": 0,
|
||||
"match_criteria": [],
|
||||
"recommendation": None,
|
||||
}
|
||||
|
||||
# If project data is provided, calculate compatibility
|
||||
if project_data:
|
||||
# Use the compatibility_score service if models are provided
|
||||
if investor_model and project_model:
|
||||
# Calculate using the standardized compatibility score service
|
||||
# Returns score between 0 and 1, convert to percentage (0-100)
|
||||
score_decimal = calculate_project_investor_compatibility(
|
||||
project=project_model, investor=investor_model, use_funds=True
|
||||
)
|
||||
context["compatibility_score"] = int(score_decimal * 100)
|
||||
else:
|
||||
# Fallback to old calculation method if models not provided
|
||||
context["compatibility_score"] = self._calculate_compatibility_score(
|
||||
investor_data, project_data
|
||||
)
|
||||
|
||||
context["match_criteria"] = self._generate_match_criteria(
|
||||
investor_data, project_data
|
||||
)
|
||||
context["recommendation"] = self._generate_recommendation(
|
||||
context["compatibility_score"], context["match_criteria"]
|
||||
)
|
||||
|
||||
return context
|
||||
|
||||
def _calculate_compatibility_score(
|
||||
self, investor_data: Dict[str, Any], project_data: Dict[str, Any]
|
||||
) -> int:
|
||||
"""Calculate overall compatibility score between investor and project"""
|
||||
score = 0
|
||||
weights = {
|
||||
"sector": 30,
|
||||
"stage": 30,
|
||||
"geography": 20,
|
||||
"check_size": 20,
|
||||
}
|
||||
|
||||
# Aggregate data from all funds
|
||||
all_sectors = set(investor_data.get("sectors", []))
|
||||
all_stages = set()
|
||||
all_geographies = []
|
||||
check_ranges = []
|
||||
|
||||
for fund in investor_data.get("funds", []):
|
||||
all_sectors.update(fund.get("sectors", []))
|
||||
all_stages.update(fund.get("investment_stages", []))
|
||||
if fund.get("geographic_focus"):
|
||||
all_geographies.append(fund["geographic_focus"])
|
||||
if fund.get("check_size_lower") and fund.get("check_size_upper"):
|
||||
check_ranges.append(
|
||||
{
|
||||
"lower": fund["check_size_lower"],
|
||||
"upper": fund["check_size_upper"],
|
||||
}
|
||||
)
|
||||
|
||||
# Sector match
|
||||
project_sectors = set(project_data.get("sectors", []))
|
||||
if all_sectors and project_sectors:
|
||||
if all_sectors & project_sectors:
|
||||
score += weights["sector"]
|
||||
|
||||
# Stage match - case insensitive comparison
|
||||
project_stage = project_data.get("stage")
|
||||
if project_stage and all_stages:
|
||||
# Normalize stage names for comparison (case-insensitive)
|
||||
normalized_stages = {
|
||||
stage.lower().replace("_", " ") for stage in all_stages
|
||||
}
|
||||
project_stage_normalized = project_stage.lower().replace("_", " ")
|
||||
if project_stage_normalized in normalized_stages:
|
||||
score += weights["stage"]
|
||||
|
||||
# Geography match - check if any fund matches
|
||||
project_geo = (project_data.get("location") or "").lower()
|
||||
geo_match = False
|
||||
if all_geographies:
|
||||
for geo in all_geographies:
|
||||
if geo:
|
||||
geo_lower = geo.lower()
|
||||
# Match if investor geography is "global" or if there's a location overlap
|
||||
if "global" in geo_lower or "worldwide" in geo_lower:
|
||||
geo_match = True
|
||||
break
|
||||
if project_geo and (
|
||||
geo_lower in project_geo or project_geo in geo_lower
|
||||
):
|
||||
geo_match = True
|
||||
break
|
||||
if geo_match:
|
||||
score += weights["geography"]
|
||||
|
||||
# Check size match - check if any fund's range matches
|
||||
project_valuation = project_data.get("valuation", 0)
|
||||
check_match = False
|
||||
if project_valuation and check_ranges:
|
||||
for check_range in check_ranges:
|
||||
if check_range["lower"] <= project_valuation <= check_range["upper"]:
|
||||
check_match = True
|
||||
break
|
||||
if check_match:
|
||||
score += weights["check_size"]
|
||||
|
||||
return min(score, 100)
|
||||
|
||||
def _generate_match_criteria(
|
||||
self, investor_data: Dict[str, Any], project_data: Dict[str, Any]
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Generate detailed match criteria table"""
|
||||
criteria = []
|
||||
|
||||
# Aggregate data from all funds
|
||||
all_sectors = set(investor_data.get("sectors", []))
|
||||
all_stages = set()
|
||||
all_geographies = []
|
||||
check_ranges = []
|
||||
|
||||
for fund in investor_data.get("funds", []):
|
||||
all_sectors.update(fund.get("sectors", []))
|
||||
all_stages.update(fund.get("investment_stages", []))
|
||||
if fund.get("geographic_focus"):
|
||||
all_geographies.append(fund["geographic_focus"])
|
||||
if fund.get("check_size_lower") and fund.get("check_size_upper"):
|
||||
check_ranges.append(
|
||||
{
|
||||
"lower": fund["check_size_lower"],
|
||||
"upper": fund["check_size_upper"],
|
||||
"fund_name": fund.get("fund_name", "Unnamed Fund"),
|
||||
}
|
||||
)
|
||||
|
||||
# Sector criterion
|
||||
project_sectors = project_data.get("sectors", [])
|
||||
sector_match = "Perfect" if all_sectors & set(project_sectors) else "Mismatch"
|
||||
criteria.append(
|
||||
{
|
||||
"name": "Sector",
|
||||
"requirement": ", ".join(project_sectors) if project_sectors else "N/A",
|
||||
"evidence": ", ".join(list(all_sectors)[:3]) if all_sectors else "N/A",
|
||||
"match": sector_match,
|
||||
"weight": "30%",
|
||||
}
|
||||
)
|
||||
|
||||
# Stage criterion - case insensitive comparison
|
||||
project_stage = project_data.get("stage", "N/A")
|
||||
stage_match = "Mismatch"
|
||||
if project_stage != "N/A" and all_stages:
|
||||
# Normalize stage names for comparison
|
||||
normalized_stages = {
|
||||
stage.lower().replace("_", " ") for stage in all_stages
|
||||
}
|
||||
project_stage_normalized = project_stage.lower().replace("_", " ")
|
||||
stage_match = (
|
||||
"Perfect"
|
||||
if project_stage_normalized in normalized_stages
|
||||
else "Mismatch"
|
||||
)
|
||||
elif project_stage == "N/A":
|
||||
stage_match = "N/A"
|
||||
|
||||
criteria.append(
|
||||
{
|
||||
"name": "Stage",
|
||||
"requirement": str(project_stage),
|
||||
"evidence": ", ".join(all_stages) if all_stages else "N/A",
|
||||
"match": stage_match,
|
||||
"weight": "30%",
|
||||
}
|
||||
)
|
||||
|
||||
# Geography criterion
|
||||
project_geo = project_data.get("location") or "N/A"
|
||||
investor_geo_display = ", ".join(all_geographies) if all_geographies else "N/A"
|
||||
|
||||
# Safe comparison handling None values and "Global" matches
|
||||
geo_match = "Mismatch"
|
||||
if project_geo != "N/A" and all_geographies:
|
||||
for geo in all_geographies:
|
||||
if geo:
|
||||
geo_lower = geo.lower()
|
||||
# Match if investor geography is "global" or if there's a location overlap
|
||||
if "global" in geo_lower or "worldwide" in geo_lower:
|
||||
geo_match = "Perfect"
|
||||
break
|
||||
if (
|
||||
geo_lower in project_geo.lower()
|
||||
or project_geo.lower() in geo_lower
|
||||
):
|
||||
geo_match = "Strong"
|
||||
break
|
||||
elif not all_geographies and project_geo == "N/A":
|
||||
geo_match = "N/A"
|
||||
|
||||
criteria.append(
|
||||
{
|
||||
"name": "Geography",
|
||||
"requirement": project_geo,
|
||||
"evidence": investor_geo_display,
|
||||
"match": geo_match,
|
||||
"weight": "20%",
|
||||
}
|
||||
)
|
||||
|
||||
# Check Size criterion
|
||||
project_val = project_data.get("valuation", 0)
|
||||
|
||||
# Build evidence string from all fund ranges
|
||||
check_evidence = "N/A"
|
||||
if check_ranges:
|
||||
evidence_parts = []
|
||||
for cr in check_ranges[:3]: # Show up to 3 funds
|
||||
range_str = (
|
||||
f"€{cr['lower'] / 1000000:.0f}M - €{cr['upper'] / 1000000:.0f}M"
|
||||
)
|
||||
if cr["fund_name"]:
|
||||
evidence_parts.append(f"{cr['fund_name']}: {range_str}")
|
||||
else:
|
||||
evidence_parts.append(range_str)
|
||||
check_evidence = "; ".join(evidence_parts)
|
||||
|
||||
# Check if project valuation matches any fund
|
||||
check_match = "N/A"
|
||||
if project_val > 0 and check_ranges:
|
||||
match_found = any(
|
||||
cr["lower"] <= project_val <= cr["upper"] for cr in check_ranges
|
||||
)
|
||||
check_match = "Perfect" if match_found else "Mismatch"
|
||||
|
||||
criteria.append(
|
||||
{
|
||||
"name": "Check Size",
|
||||
"requirement": f"€{project_val / 1000000:.0f}M"
|
||||
if project_val
|
||||
else "N/A",
|
||||
"evidence": check_evidence,
|
||||
"match": check_match,
|
||||
"weight": "20%",
|
||||
}
|
||||
)
|
||||
|
||||
return criteria
|
||||
|
||||
def _generate_recommendation(
|
||||
self, score: int, criteria: List[Dict[str, str]]
|
||||
) -> str:
|
||||
"""Generate recommendation text based on score and criteria"""
|
||||
if score >= 85:
|
||||
return "High Priority. A strong target due to exceptional alignment on the most heavily-weighted criteria: Sector and Stage. The strong geographic fit further solidifies this recommendation."
|
||||
elif score >= 70:
|
||||
return "Medium Priority. Good alignment on key criteria with some areas of strong fit. The geographic fit in the target region supports this recommendation."
|
||||
else:
|
||||
return "Low Priority. Limited alignment on key investment criteria. Consider for future evaluation if circumstances change."
|
||||
|
||||
async def _html_to_pdf(self, html_content: str) -> bytes:
|
||||
"""Convert HTML content to PDF using Playwright"""
|
||||
async with async_playwright() as p:
|
||||
browser = await p.chromium.launch()
|
||||
page = await browser.new_page()
|
||||
|
||||
# Set content and wait for any dynamic content to load
|
||||
await page.set_content(html_content, wait_until="networkidle")
|
||||
|
||||
# Generate PDF with proper settings
|
||||
pdf_bytes = await page.pdf(
|
||||
format="A4",
|
||||
print_background=True,
|
||||
margin={"top": "0", "right": "0", "bottom": "0", "left": "0"},
|
||||
)
|
||||
|
||||
await browser.close()
|
||||
|
||||
return pdf_bytes
|
||||
@@ -1,329 +0,0 @@
|
||||
<!DOCTYPE html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<title>Investor Profile Report</title>
|
||||
<script src="https://cdn.tailwindcss.com"></script>
|
||||
<style>
|
||||
@page {
|
||||
size: A4;
|
||||
margin: 0;
|
||||
}
|
||||
|
||||
html,
|
||||
body {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
height: 100%;
|
||||
background: white;
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI",
|
||||
Roboto, sans-serif;
|
||||
}
|
||||
|
||||
/* Each page is exactly one A4 sheet */
|
||||
.page {
|
||||
width: 210mm;
|
||||
height: 297mm;
|
||||
position: relative;
|
||||
background: white;
|
||||
overflow: hidden;
|
||||
}
|
||||
|
||||
/* Adds a break between pages (for print/PDF) */
|
||||
.page-with-break {
|
||||
page-break-after: always;
|
||||
}
|
||||
|
||||
/* Inner content wrapper for consistent padding */
|
||||
.page-content {
|
||||
box-sizing: border-box;
|
||||
padding: 48px; /* equivalent to Tailwind p-12 */
|
||||
height: 100%;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
}
|
||||
|
||||
.tag {
|
||||
display: inline-block;
|
||||
padding: 4px 12px;
|
||||
background: #f3f4f6;
|
||||
border-radius: 4px;
|
||||
font-size: 12px;
|
||||
margin: 4px;
|
||||
}
|
||||
|
||||
/* Ensure the footer text stays inside page bounds */
|
||||
.page-footer {
|
||||
position: absolute;
|
||||
bottom: 48px;
|
||||
right: 48px;
|
||||
font-size: 10px;
|
||||
color: #9ca3af; /* Tailwind gray-400 */
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<!-- Page 1 -->
|
||||
<div class="page page-with-break">
|
||||
<div class="page-content">
|
||||
<div class="flex justify-between items-start mb-8">
|
||||
<div>
|
||||
<p class="text-sm text-gray-600 mb-2">Investor Profile</p>
|
||||
<h1 class="text-4xl font-bold text-gray-900">
|
||||
{{ investor.name }}
|
||||
</h1>
|
||||
</div>
|
||||
<a
|
||||
href="{{ investor.website }}"
|
||||
target="_blank"
|
||||
class="bg-gray-200 text-gray-700 px-4 py-2 rounded text-sm no-underline"
|
||||
>Visit Website →</a
|
||||
>
|
||||
</div>
|
||||
|
||||
<div class="grid grid-cols-2 gap-8 flex-grow">
|
||||
<!-- Left Column -->
|
||||
<div>
|
||||
<div class="mb-4">
|
||||
<h2 class="text-sm font-bold text-gray-900 uppercase mb-4">
|
||||
Investor Description
|
||||
</h2>
|
||||
<p class="text-sm text-gray-700 leading-relaxed">
|
||||
{{ investor.description or 'No description available.' }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div class="mb-4">
|
||||
<h2 class="text-sm font-bold text-gray-900 uppercase mb-4">
|
||||
Portfolio Highlights
|
||||
</h2>
|
||||
<div class="flex flex-wrap gap-2">
|
||||
{% if investor.portfolio_highlights %}
|
||||
{% for company in investor.portfolio_highlights[:5] %}
|
||||
<span class="tag">{{ company }}</span>
|
||||
{% endfor %}
|
||||
{% else %}
|
||||
<p class="text-sm text-gray-500">
|
||||
No portfolio highlights available
|
||||
</p>
|
||||
{% endif %}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="mb-4">
|
||||
<h2 class="text-sm font-bold text-gray-900 uppercase mb-4">
|
||||
Senior Leadership
|
||||
</h2>
|
||||
{% if investor.team_members %}
|
||||
{% for member in investor.team_members[:2] %}
|
||||
<div class="mb-3">
|
||||
<p class="text-sm font-semibold text-gray-900">
|
||||
{{ member.name }}
|
||||
</p>
|
||||
<p class="text-sm text-gray-600">
|
||||
{{ member.role or member.title or 'Team Member' }}
|
||||
</p>
|
||||
{% if member.email %}
|
||||
<p class="text-xs text-blue-600">
|
||||
{{ member.email }}
|
||||
</p>
|
||||
{% endif %}
|
||||
</div>
|
||||
{% endfor %}
|
||||
{% else %}
|
||||
<p class="text-sm text-gray-500">No team information available</p>
|
||||
{% endif %}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Right Column -->
|
||||
<div class="bg-gray-50 p-6 rounded-lg">
|
||||
<h2 class="text-sm font-bold text-gray-900 uppercase mb-4">
|
||||
Key Data
|
||||
</h2>
|
||||
<div class="space-y-3 text-sm">
|
||||
<div>
|
||||
<p class="text-xs text-gray-600">Headquarters:</p>
|
||||
<p class="font-semibold text-gray-900">
|
||||
{{ investor.headquarters or 'N/A' }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<p class="text-xs text-gray-600">Sectors:</p>
|
||||
<p class="font-semibold text-gray-900">
|
||||
{% if investor.sectors %}
|
||||
{{ investor.sectors | join(', ') }}
|
||||
{% else %}
|
||||
N/A
|
||||
{% endif %}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<p class="text-xs text-gray-600">AUM (EUR million):</p>
|
||||
<p class="font-semibold text-gray-900">
|
||||
{% if investor.aum %}
|
||||
€{{ '{:,.0f}'.format(investor.aum / 1000000) }}M
|
||||
{% else %}
|
||||
N/A
|
||||
{% endif %}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<p class="text-xs text-gray-600 mb-1">Number of Funds:</p>
|
||||
<p class="font-semibold text-gray-900">
|
||||
{{ investor.funds | length if investor.funds else 'N/A' }}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="mt-4">
|
||||
<h3 class="text-xs font-bold text-gray-900 uppercase mb-2">
|
||||
Fund Details
|
||||
</h3>
|
||||
{% if investor.funds %}
|
||||
{% for fund in investor.funds %}
|
||||
<div class="mb-3 pb-3 border-b border-gray-200">
|
||||
<p class="text-sm font-semibold text-gray-900 mb-1">
|
||||
{{ fund.fund_name or 'Fund ' + loop.index|string }}
|
||||
</p>
|
||||
<div class="text-xs text-gray-700 space-y-1">
|
||||
{% if fund.fund_size %}
|
||||
<p>Fund Size: €{{ '{:,.0f}'.format(fund.fund_size / 1000000) }}M</p>
|
||||
{% endif %}
|
||||
{% if fund.check_size_lower and fund.check_size_upper %}
|
||||
<p>Check Size: €{{ '{:,.0f}'.format(fund.check_size_lower / 1000000) }}M - €{{ '{:,.0f}'.format(fund.check_size_upper / 1000000) }}M</p>
|
||||
{% endif %}
|
||||
{% if fund.geographic_focus %}
|
||||
<p>Geography: {{ fund.geographic_focus }}</p>
|
||||
{% endif %}
|
||||
{% if fund.investment_stages %}
|
||||
<p>Stages: {{ fund.investment_stages | join(', ') }}</p>
|
||||
{% endif %}
|
||||
{% if fund.sectors %}
|
||||
<p>Sectors: {{ fund.sectors[:3] | join(', ') }}</p>
|
||||
{% endif %}
|
||||
</div>
|
||||
</div>
|
||||
{% endfor %}
|
||||
{% else %}
|
||||
<p class="text-xs text-gray-500">No fund information available</p>
|
||||
{% endif %}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="page-footer">Page 1</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Page 2 -->
|
||||
{% if project %}
|
||||
<div class="page">
|
||||
<div class="page-content">
|
||||
<h1 class="text-3xl font-bold text-gray-900 mb-8">
|
||||
{{ investor.name }}: Mandate Match Analysis
|
||||
</h1>
|
||||
|
||||
<!-- Overall Match Circle -->
|
||||
<div class="flex justify-center mb-12">
|
||||
<div class="text-center">
|
||||
<p class="text-sm font-bold text-gray-700 uppercase mb-4">
|
||||
Overall Mandate Match
|
||||
</p>
|
||||
<div
|
||||
class="w-48 h-48 rounded-full border-8 border-green-400 flex items-center justify-center bg-green-50 mx-auto"
|
||||
>
|
||||
<span class="text-5xl font-bold text-green-600"
|
||||
>{{ compatibility_score }}%</span
|
||||
>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<!-- Mandate Alignment Analysis Table -->
|
||||
<div class="mb-12">
|
||||
<h2 class="text-xl font-bold text-gray-900 mb-6">
|
||||
Mandate Alignment Analysis
|
||||
</h2>
|
||||
<table class="w-full border-collapse">
|
||||
<thead>
|
||||
<tr class="border-b-2 border-gray-300">
|
||||
<th
|
||||
class="text-left py-3 px-4 text-sm font-bold text-gray-700"
|
||||
>
|
||||
Criterion
|
||||
</th>
|
||||
<th
|
||||
class="text-left py-3 px-4 text-sm font-bold text-gray-700"
|
||||
>
|
||||
Mandate Requirement
|
||||
</th>
|
||||
<th
|
||||
class="text-left py-3 px-4 text-sm font-bold text-gray-700"
|
||||
>
|
||||
Investor Evidence (from Database)
|
||||
</th>
|
||||
<th
|
||||
class="text-left py-3 px-4 text-sm font-bold text-gray-700"
|
||||
>
|
||||
Match Score
|
||||
</th>
|
||||
<th
|
||||
class="text-left py-3 px-4 text-sm font-bold text-gray-700"
|
||||
>
|
||||
Weight
|
||||
</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody>
|
||||
{% for criterion in match_criteria %}
|
||||
<tr class="border-b border-gray-200">
|
||||
<td class="py-4 px-4 text-sm text-gray-900">
|
||||
{{ criterion.name }}
|
||||
</td>
|
||||
<td class="py-4 px-4 text-sm text-gray-700">
|
||||
{{ criterion.requirement }}
|
||||
</td>
|
||||
<td class="py-4 px-4 text-sm text-gray-700">
|
||||
{{ criterion.evidence }}
|
||||
</td>
|
||||
<td class="py-4 px-4 text-sm">
|
||||
<span
|
||||
class="{% if criterion.match == 'Perfect' %}text-green-600{% elif criterion.match == 'Strong' %}text-blue-600{% else %}text-yellow-600{% endif %} font-semibold"
|
||||
>
|
||||
{{ criterion.match }}
|
||||
</span>
|
||||
</td>
|
||||
<td class="py-4 px-4 text-sm text-gray-700">
|
||||
{{ criterion.weight }}
|
||||
</td>
|
||||
</tr>
|
||||
{% endfor %}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
|
||||
<!-- Final Recommendation -->
|
||||
<div class="bg-blue-50 border-l-4 border-blue-500 p-6 rounded">
|
||||
<h3 class="text-lg font-bold text-gray-900 mb-3">
|
||||
Final Recommendation & Rationale
|
||||
</h3>
|
||||
<p class="text-sm text-gray-700 leading-relaxed">
|
||||
{{ recommendation or "High Priority. A strong target due to
|
||||
exceptional alignment on the most heavily-weighted criteria:
|
||||
Sector and Stage. The strong geographic fit in the DACH
|
||||
region further solidifies this recommendation." }}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<div class="absolute bottom-12 right-12 text-xs text-gray-400">
|
||||
Page 2
|
||||
</div>
|
||||
</div>
|
||||
{% endif %}
|
||||
</body>
|
||||
</html>
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Binary file not shown.
@@ -1,315 +0,0 @@
|
||||
import logging
|
||||
import re
|
||||
import unicodedata
|
||||
|
||||
import pandas as pd
|
||||
from models import CompanyTable, InvestorTable, SectorTable, engine, init_database
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Import the schema
|
||||
init_database()
|
||||
|
||||
|
||||
# ===================== Ingesting Original Data =====================#
|
||||
def parse_investor_names(investor_names_str):
|
||||
"""Parse comma-separated investor names and return a list"""
|
||||
if pd.isna(investor_names_str) or investor_names_str == "":
|
||||
return []
|
||||
|
||||
# Split by comma and clean whitespace
|
||||
# investors = [name.strip() for name in str(investor_names_str).split(",")]
|
||||
investors = [
|
||||
clean_name(name.strip()) for name in str(investor_names_str).split(",")
|
||||
]
|
||||
return [investor for investor in investors if investor]
|
||||
|
||||
|
||||
def parse_industries(industries_str):
|
||||
"""Parse comma-separated industries and return a list"""
|
||||
if pd.isna(industries_str) or industries_str == "":
|
||||
return []
|
||||
|
||||
# Split by comma and clean whitespace
|
||||
industries = [industry.strip() for industry in str(industries_str).split(",")]
|
||||
return [industry for industry in industries if industry]
|
||||
|
||||
|
||||
def clean_special_characters(text):
|
||||
"""Clean special characters from text, converting to ASCII equivalents"""
|
||||
if not text:
|
||||
return text
|
||||
|
||||
# First remove ellipses and other problematic patterns
|
||||
text = str(text).replace("...", "").replace("..", "")
|
||||
|
||||
# Normalize unicode characters to their closest ASCII equivalents
|
||||
normalized = unicodedata.normalize("NFKD", text)
|
||||
|
||||
# Remove accents and convert to ASCII
|
||||
ascii_text = normalized.encode("ascii", "ignore").decode("ascii")
|
||||
|
||||
# Remove any remaining non-alphanumeric characters except spaces, hyphens, and periods
|
||||
cleaned = re.sub(r"[^a-zA-Z0-9\s\-\.]", "", ascii_text)
|
||||
|
||||
# Clean up multiple spaces
|
||||
cleaned = re.sub(r"\s+", " ", cleaned).strip()
|
||||
|
||||
return cleaned
|
||||
|
||||
|
||||
def clean_string(value):
|
||||
"""Clean string values, converting empty/null/nan/0 to None and removing special characters"""
|
||||
if (
|
||||
pd.isna(value)
|
||||
or value == ""
|
||||
or str(value).lower() in ["nan", "null", "none", "0", "0.0"]
|
||||
):
|
||||
return None
|
||||
|
||||
# First clean special characters
|
||||
cleaned = clean_special_characters(str(value).strip())
|
||||
|
||||
# Check if result is just "0" after cleaning
|
||||
if cleaned in ["0", "0.0", "null", "nan", "none"]:
|
||||
return None
|
||||
|
||||
return cleaned if cleaned else None
|
||||
|
||||
|
||||
def clean_name(value):
|
||||
"""Clean names (companies, investors) with special character handling"""
|
||||
if (
|
||||
pd.isna(value)
|
||||
or value == ""
|
||||
or str(value).lower() in ["nan", "null", "none", "0", "0.0"]
|
||||
):
|
||||
return None
|
||||
|
||||
# Clean special characters but be more permissive for names
|
||||
text = str(value).strip()
|
||||
# First remove ellipses and other problematic patterns
|
||||
# text = text.replace("...", "").replace("..", "")
|
||||
|
||||
# Normalize unicode characters
|
||||
normalized = unicodedata.normalize("NFKD", text)
|
||||
|
||||
# Convert to ASCII but keep more characters for business names
|
||||
ascii_text = normalized.encode("ascii", "ignore").decode("ascii")
|
||||
|
||||
# Allow alphanumeric, spaces, hyphens, periods, parentheses, and ampersands
|
||||
cleaned = re.sub(r"[^a-zA-Z0-9\s\-\.\(\)&]", "", ascii_text)
|
||||
|
||||
# Clean up multiple spaces
|
||||
cleaned = re.sub(r"\s+", " ", cleaned).strip()
|
||||
|
||||
# Remove any trailing or leading periods
|
||||
cleaned = cleaned.strip(".")
|
||||
|
||||
cleaned = cleaned.replace("..", "").replace("...", "")
|
||||
# Check if result is just "0" after cleaning
|
||||
if cleaned in ["0", "0.0", "null", "nan", "none"]:
|
||||
return None
|
||||
|
||||
return cleaned if cleaned else None
|
||||
|
||||
|
||||
def clean_integer(value):
|
||||
"""Clean integer values, converting empty/null/nan/0 to None"""
|
||||
if pd.isna(value) or str(value).lower() in ["nan", "null", "none", "", "0", "0.0"]:
|
||||
return None
|
||||
try:
|
||||
cleaned_val = int(float(value))
|
||||
return cleaned_val if cleaned_val > 0 else None
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
|
||||
def parse_website(website_str: str):
|
||||
try:
|
||||
_, end = website_str.split(":")
|
||||
|
||||
if end == "0":
|
||||
return None
|
||||
return "https:" + end
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def ingest_data():
|
||||
# Create database engine and session
|
||||
Session = sessionmaker(bind=engine)
|
||||
session = Session()
|
||||
|
||||
# Load CSV files
|
||||
print("Loading CSV files...")
|
||||
companies_df = pd.read_csv("companies.csv")
|
||||
investors_df = pd.read_csv("investors.csv")
|
||||
|
||||
print(f"📊 Companies CSV: {len(companies_df)} rows")
|
||||
print(f"📊 Investors CSV: {len(investors_df)} rows")
|
||||
|
||||
# Step 1: Ingest Investors
|
||||
print("\n🔄 Step 1: Ingesting Investors...")
|
||||
investors_processed = 0
|
||||
|
||||
for index, row in investors_df.iterrows():
|
||||
try:
|
||||
investor_name = clean_name(row.get("Filtered investor names", ""))
|
||||
|
||||
if investor_name:
|
||||
# Check if investor already exists
|
||||
existing_investor = (
|
||||
session.query(InvestorTable).filter_by(name=investor_name).first()
|
||||
)
|
||||
if not existing_investor:
|
||||
investor = InvestorTable(
|
||||
name=investor_name,
|
||||
description=clean_string(row.get("Business model", "")),
|
||||
headquarters=clean_string(row.get("HQ", "")),
|
||||
website=parse_website(str(row.get("Website", "")).strip()),
|
||||
number_of_investments=clean_integer(
|
||||
row.get("Number of investments")
|
||||
),
|
||||
)
|
||||
session.add(investor)
|
||||
investors_processed += 1
|
||||
|
||||
if investors_processed % 1000 == 0:
|
||||
session.commit()
|
||||
print(f" Committed {investors_processed} investors")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing investor {index}: {e}")
|
||||
continue
|
||||
|
||||
session.commit()
|
||||
print(f"✅ Investors completed: {investors_processed} processed")
|
||||
|
||||
# Step 2: Ingest Companies and Rounds
|
||||
print("\n🔄 Step 2: Ingesting Companies and Sectors...")
|
||||
companies_processed = 0
|
||||
sectors_created = set()
|
||||
|
||||
for index, row in companies_df.iterrows():
|
||||
try:
|
||||
# Process company
|
||||
company_name = clean_name(row.get("Organization Name", ""))
|
||||
if not company_name:
|
||||
continue
|
||||
|
||||
# Check if company already exists
|
||||
existing_company = (
|
||||
session.query(CompanyTable).filter_by(name=company_name).first()
|
||||
)
|
||||
if existing_company:
|
||||
company = existing_company
|
||||
else:
|
||||
# Create company
|
||||
company = CompanyTable(
|
||||
name=company_name,
|
||||
description=clean_string(row.get("Organization Description", "")),
|
||||
location=clean_string(row.get("Organization Location", "")),
|
||||
industry=clean_string(row.get("Organization Industries", "")),
|
||||
website=clean_string(row.get("Organization Website", "")),
|
||||
)
|
||||
session.add(company)
|
||||
session.flush() # Get the company ID
|
||||
companies_processed += 1
|
||||
|
||||
# Process investor relationships
|
||||
investor_names_str = row.get("Investor Names", "")
|
||||
if pd.notna(investor_names_str) and investor_names_str:
|
||||
investor_names = parse_investor_names(investor_names_str)
|
||||
|
||||
for investor_name in investor_names:
|
||||
# Find investor in database
|
||||
investor = (
|
||||
session.query(InvestorTable)
|
||||
.filter_by(name=investor_name.strip())
|
||||
.first()
|
||||
)
|
||||
|
||||
if investor:
|
||||
# Add investor-company relationship
|
||||
if company not in investor.portfolio_companies:
|
||||
investor.portfolio_companies.append(company)
|
||||
else:
|
||||
print("This company has an investor not in DB:", investor_name)
|
||||
|
||||
# Process sectors/industries
|
||||
industries_str = row.get("Organization Industries", "")
|
||||
if pd.notna(industries_str) and industries_str:
|
||||
industries = parse_industries(industries_str)
|
||||
|
||||
for industry_name in industries:
|
||||
industry_name = industry_name.strip()
|
||||
if industry_name:
|
||||
# Check if sector exists
|
||||
sector = (
|
||||
session.query(SectorTable)
|
||||
.filter_by(name=industry_name)
|
||||
.first()
|
||||
)
|
||||
if not sector:
|
||||
sector = SectorTable(name=industry_name)
|
||||
session.add(sector)
|
||||
session.flush()
|
||||
sectors_created.add(industry_name)
|
||||
|
||||
# Add company-sector relationship
|
||||
if sector not in company.sectors:
|
||||
company.sectors.append(sector)
|
||||
|
||||
# Commit every 100 companies
|
||||
if companies_processed % 100 == 0 and companies_processed > 0:
|
||||
session.commit()
|
||||
print(f" Processed {companies_processed} companies...")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing company {index}: {e}")
|
||||
session.rollback()
|
||||
continue
|
||||
|
||||
# Step 3: Link investors to sectors based on portfolio companies
|
||||
print("\n🔄 Step 3: Linking Investors to Sectors...")
|
||||
investors_linked_to_sectors = 0
|
||||
all_investors = session.query(InvestorTable).all()
|
||||
for investor in all_investors:
|
||||
sectors = set()
|
||||
for company in investor.portfolio_companies:
|
||||
for sector in company.sectors:
|
||||
sectors.add(sector)
|
||||
# Add sectors to investor if not already present
|
||||
for sector in sectors:
|
||||
if sector not in investor.sectors:
|
||||
investor.sectors.append(sector)
|
||||
if sectors:
|
||||
investors_linked_to_sectors += 1
|
||||
session.commit()
|
||||
print(f"✅ Linked {investors_linked_to_sectors} investors to sectors")
|
||||
|
||||
# Final commit
|
||||
session.commit()
|
||||
|
||||
# Final counts
|
||||
final_investors = session.query(InvestorTable).count()
|
||||
final_companies = session.query(CompanyTable).count()
|
||||
final_sectors = session.query(SectorTable).count()
|
||||
|
||||
print("\n🎉 Ingestion Complete!")
|
||||
print(f" Investors: {final_investors}")
|
||||
print(f" Companies: {final_companies}")
|
||||
print(f" Sectors: {final_sectors}")
|
||||
|
||||
session.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ingest_data()
|
||||
# print(clean_name("A... Energi"))
|
||||
# print(clean_name("B.. Tech"))
|
||||
# print(clean_name("A... Energi"))
|
||||
@@ -1,381 +0,0 @@
|
||||
import enum
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
String,
|
||||
Table,
|
||||
Text,
|
||||
create_engine,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import Session, declarative_mixin, relationship, sessionmaker
|
||||
from sqlalchemy.types import JSON, Enum
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
# Database configuration
|
||||
# DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./investors.db")
|
||||
|
||||
# Create engine
|
||||
engine = create_engine("sqlite:///./investors.db", echo=False)
|
||||
|
||||
# Create session factory
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
db_dependency = Annotated[Session, Depends(get_db)]
|
||||
|
||||
|
||||
def init_database():
|
||||
"""Initialize the database by creating all tables"""
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
|
||||
def get_session_sync() -> Session:
|
||||
"""Get a database session for synchronous operations"""
|
||||
return SessionLocal()
|
||||
|
||||
|
||||
def get_db_session():
|
||||
"""Get a database session for direct use."""
|
||||
return SessionLocal()
|
||||
|
||||
|
||||
@declarative_mixin
|
||||
class TimestampMixin:
|
||||
created_at = Column(
|
||||
DateTime(timezone=True), server_default=func.now(), nullable=False
|
||||
)
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
|
||||
|
||||
class InvestmentStage(enum.Enum):
|
||||
SEED = "SEED"
|
||||
SERIES_A = "SERIES_A"
|
||||
SERIES_B = "SERIES_B"
|
||||
SERIES_C = "SERIES_C"
|
||||
GROWTH = "GROWTH"
|
||||
LATE_STAGE = "LATE_STAGE"
|
||||
|
||||
|
||||
# Association table for many-to-many relationship between investors and companies
|
||||
investor_company_association = Table(
|
||||
"investor_companies",
|
||||
Base.metadata,
|
||||
Column("investor_id", Integer, ForeignKey("investors.id")),
|
||||
Column("company_id", Integer, ForeignKey("companies.id")),
|
||||
)
|
||||
|
||||
|
||||
# Association table for investor-sector many-to-many
|
||||
investor_sector_association = Table(
|
||||
"investor_sectors",
|
||||
Base.metadata,
|
||||
Column("investor_id", Integer, ForeignKey("investors.id")),
|
||||
Column("sector_id", Integer, ForeignKey("sectors.id")),
|
||||
)
|
||||
|
||||
|
||||
company_sector_association = Table(
|
||||
"company_sector",
|
||||
Base.metadata,
|
||||
Column("company_id", Integer, ForeignKey("companies.id")),
|
||||
Column("sector_id", Integer, ForeignKey("sectors.id")),
|
||||
)
|
||||
|
||||
project_sector_association = Table(
|
||||
"project_sector",
|
||||
Base.metadata,
|
||||
Column("project_id", Integer, ForeignKey("projects.id")),
|
||||
Column("sector_id", Integer, ForeignKey("sectors.id")),
|
||||
)
|
||||
|
||||
project_investor_association = Table(
|
||||
"project_investors",
|
||||
Base.metadata,
|
||||
Column("project_id", Integer, ForeignKey("projects.id")),
|
||||
Column("investor_id", Integer, ForeignKey("investors.id")),
|
||||
)
|
||||
|
||||
project_company_association = Table(
|
||||
"project_companies",
|
||||
Base.metadata,
|
||||
Column("project_id", Integer, ForeignKey("projects.id")),
|
||||
Column("company_id", Integer, ForeignKey("companies.id")),
|
||||
)
|
||||
|
||||
# Association table for investor-stage many-to-many
|
||||
investor_stage_association = Table(
|
||||
"investor_stages",
|
||||
Base.metadata,
|
||||
Column("investor_id", Integer, ForeignKey("investors.id")),
|
||||
Column("stage_id", Integer, ForeignKey("investment_stages.id")),
|
||||
)
|
||||
|
||||
# Association table for fund-stage many-to-many
|
||||
fund_investment_stages_association = Table(
|
||||
"fund_investment_stages",
|
||||
Base.metadata,
|
||||
Column("fund_id", Integer, ForeignKey("funds.id")),
|
||||
Column("stage_id", Integer, ForeignKey("investment_stages.id")),
|
||||
)
|
||||
|
||||
# Association table for fund-sector many-to-many
|
||||
fund_sectors_association = Table(
|
||||
"fund_sectors",
|
||||
Base.metadata,
|
||||
Column("fund_id", Integer, ForeignKey("funds.id")),
|
||||
Column("sector_id", Integer, ForeignKey("sectors.id")),
|
||||
)
|
||||
|
||||
|
||||
class InvestorTable(Base, TimestampMixin):
|
||||
__tablename__ = "investors"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, nullable=False)
|
||||
description = Column(Text, nullable=True)
|
||||
|
||||
# Basic investor info
|
||||
website = Column(String, nullable=True)
|
||||
headquarters = Column(String, nullable=True)
|
||||
|
||||
# AUM fields
|
||||
aum = Column(Integer, nullable=True) # Store as integer for numerical filtering
|
||||
aum_as_of_date = Column(String, nullable=True)
|
||||
aum_source_url = Column(String, nullable=True)
|
||||
|
||||
# Check size (deprecated in favor of fund-level data, but keeping for backward compatibility)
|
||||
check_size_lower = Column(Integer, nullable=True)
|
||||
check_size_upper = Column(Integer, nullable=True)
|
||||
|
||||
# Geographic focus (deprecated in favor of fund-level, but keeping for backward compatibility)
|
||||
geographic_focus = Column(String, nullable=True)
|
||||
|
||||
# Investment thesis and portfolio
|
||||
investment_thesis = Column(JSON, nullable=True) # Array of thesis statements
|
||||
portfolio_highlights = Column(
|
||||
JSON, nullable=True
|
||||
) # Array of portfolio company names
|
||||
linked_documents = Column(JSON, nullable=True) # Array of document URLs
|
||||
|
||||
# Research metadata
|
||||
researcher_notes = Column(Text, nullable=True)
|
||||
missing_important_fields = Column(
|
||||
JSON, nullable=True
|
||||
) # Array of missing field names
|
||||
sources = Column(JSON, nullable=True) # JSON object with source URLs
|
||||
|
||||
# Portfolio info
|
||||
number_of_investments = Column(Integer, nullable=True)
|
||||
|
||||
# Relationships
|
||||
team_members = relationship(
|
||||
"InvestorMember", back_populates="investor", cascade="all, delete-orphan"
|
||||
)
|
||||
funds = relationship(
|
||||
"FundTable", back_populates="investor", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
# Many-to-many relationship with investment stages
|
||||
investment_stages = relationship(
|
||||
"InvestmentStageTable",
|
||||
secondary=investor_stage_association,
|
||||
back_populates="investors",
|
||||
)
|
||||
|
||||
# Relationship to portfolio companies
|
||||
portfolio_companies = relationship(
|
||||
"CompanyTable",
|
||||
secondary=investor_company_association,
|
||||
back_populates="investors",
|
||||
)
|
||||
|
||||
sectors = relationship(
|
||||
"SectorTable",
|
||||
secondary=investor_sector_association,
|
||||
back_populates="investors",
|
||||
)
|
||||
|
||||
projects = relationship(
|
||||
"ProjectTable",
|
||||
secondary=project_investor_association,
|
||||
back_populates="investors",
|
||||
)
|
||||
|
||||
|
||||
class InvestorMember(Base, TimestampMixin):
|
||||
__tablename__ = "investor_members"
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, nullable=False)
|
||||
role = Column(String, nullable=True)
|
||||
title = Column(String, nullable=True) # Alternative to role
|
||||
email = Column(String, nullable=True)
|
||||
source_url = Column(String, nullable=True) # URL where member info was found
|
||||
|
||||
investor_id = Column(Integer, ForeignKey("investors.id"))
|
||||
investor = relationship("InvestorTable", back_populates="team_members")
|
||||
|
||||
|
||||
class FundTable(Base, TimestampMixin):
|
||||
__tablename__ = "funds"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
investor_id = Column(Integer, ForeignKey("investors.id"), nullable=False)
|
||||
|
||||
# Fund details
|
||||
fund_name = Column(String, nullable=True)
|
||||
fund_size = Column(
|
||||
Integer, nullable=True
|
||||
) # Store as integer for numerical filtering
|
||||
fund_size_source_url = Column(String, nullable=True)
|
||||
|
||||
# Check size range (parsed from estimated_investment_size by LLM)
|
||||
check_size_lower = Column(Integer, nullable=True)
|
||||
check_size_upper = Column(Integer, nullable=True)
|
||||
|
||||
source_url = Column(String, nullable=True)
|
||||
source_provider = Column(String, nullable=True) # e.g., "Perplexity"
|
||||
|
||||
# Geographic focus as simple string
|
||||
geographic_focus = Column(String, nullable=True)
|
||||
|
||||
# Relationships
|
||||
investor = relationship("InvestorTable", back_populates="funds")
|
||||
investment_stages = relationship(
|
||||
"InvestmentStageTable",
|
||||
secondary=fund_investment_stages_association,
|
||||
back_populates="funds",
|
||||
)
|
||||
sectors = relationship(
|
||||
"SectorTable",
|
||||
secondary=fund_sectors_association,
|
||||
back_populates="funds",
|
||||
)
|
||||
|
||||
|
||||
class InvestmentStageTable(Base, TimestampMixin):
|
||||
__tablename__ = "investment_stages"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, nullable=False, unique=True)
|
||||
|
||||
# Relationships
|
||||
investors = relationship(
|
||||
"InvestorTable",
|
||||
secondary=investor_stage_association,
|
||||
back_populates="investment_stages",
|
||||
)
|
||||
funds = relationship(
|
||||
"FundTable",
|
||||
secondary=fund_investment_stages_association,
|
||||
back_populates="investment_stages",
|
||||
)
|
||||
|
||||
|
||||
class CompanyTable(Base, TimestampMixin):
|
||||
__tablename__ = "companies"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, nullable=False)
|
||||
industry = Column(String, nullable=True)
|
||||
location = Column(String, nullable=True)
|
||||
description = Column(String, nullable=True)
|
||||
founded_year = Column(Integer, nullable=True)
|
||||
website = Column(String, nullable=True)
|
||||
|
||||
members = relationship(
|
||||
"CompanyMember", back_populates="company", cascade="all, delete-orphan"
|
||||
)
|
||||
# Relationship back to investors
|
||||
investors = relationship(
|
||||
"InvestorTable",
|
||||
secondary=investor_company_association,
|
||||
back_populates="portfolio_companies",
|
||||
)
|
||||
|
||||
sectors = relationship(
|
||||
"SectorTable", secondary=company_sector_association, back_populates="companies"
|
||||
)
|
||||
|
||||
projects = relationship(
|
||||
"ProjectTable",
|
||||
secondary=project_company_association,
|
||||
back_populates="companies",
|
||||
)
|
||||
|
||||
|
||||
class CompanyMember(Base, TimestampMixin):
|
||||
__tablename__ = "company_members"
|
||||
id = Column(Integer, primary_key=True)
|
||||
name = Column(String)
|
||||
linkedin = Column(String, nullable=True)
|
||||
role = Column(String, nullable=True)
|
||||
company_id = Column(Integer, ForeignKey("companies.id"), nullable=False)
|
||||
|
||||
company = relationship("CompanyTable", back_populates="members")
|
||||
|
||||
|
||||
class SectorTable(Base, TimestampMixin):
|
||||
__tablename__ = "sectors"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, nullable=False)
|
||||
|
||||
# Relationships
|
||||
investors = relationship(
|
||||
"InvestorTable",
|
||||
secondary=investor_sector_association,
|
||||
back_populates="sectors",
|
||||
)
|
||||
companies = relationship(
|
||||
"CompanyTable", secondary=company_sector_association, back_populates="sectors"
|
||||
)
|
||||
projects = relationship(
|
||||
"ProjectTable", secondary=project_sector_association, back_populates="sector"
|
||||
)
|
||||
funds = relationship(
|
||||
"FundTable",
|
||||
secondary=fund_sectors_association,
|
||||
back_populates="sectors",
|
||||
)
|
||||
|
||||
|
||||
class ProjectTable(Base, TimestampMixin):
|
||||
__tablename__ = "projects"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, nullable=False)
|
||||
valuation = Column(Integer, nullable=True)
|
||||
|
||||
stage = Column(Enum(InvestmentStage), nullable=True)
|
||||
location = Column(String, nullable=True)
|
||||
description = Column(Text, nullable=True)
|
||||
start_date = Column(DateTime, nullable=True)
|
||||
end_date = Column(DateTime, nullable=True)
|
||||
|
||||
sector = relationship(
|
||||
"SectorTable", secondary=project_sector_association, back_populates="projects"
|
||||
)
|
||||
investors = relationship(
|
||||
"InvestorTable",
|
||||
secondary=project_investor_association,
|
||||
back_populates="projects",
|
||||
)
|
||||
companies = relationship(
|
||||
"CompanyTable", secondary=project_company_association, back_populates="projects"
|
||||
)
|
||||
Binary file not shown.
@@ -1,117 +0,0 @@
|
||||
"""
|
||||
Migration: Add fields from feedback fixes
|
||||
Date: 2025-01-07
|
||||
|
||||
Adds the following fields:
|
||||
- projects.is_archived (INTEGER, default 0)
|
||||
- companies.product_service (TEXT, nullable)
|
||||
- companies.clients (TEXT, nullable - stored as JSON string)
|
||||
- investor_members.linkedin (VARCHAR, nullable)
|
||||
"""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path to import app modules
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from app.db.db import engine
|
||||
|
||||
|
||||
def check_column_exists(conn, table_name, column_name):
|
||||
"""Check if a column exists in a table"""
|
||||
result = conn.execute(text(f"PRAGMA table_info({table_name})"))
|
||||
columns = [row[1] for row in result]
|
||||
return column_name in columns
|
||||
|
||||
|
||||
def upgrade():
|
||||
"""Add new columns to tables"""
|
||||
print("Running migration: Add feedback fixes fields")
|
||||
print("=" * 60)
|
||||
|
||||
with engine.begin() as conn: # Use begin() for transaction management
|
||||
# 1. Add is_archived to projects table
|
||||
print("\n1. Adding 'is_archived' column to projects table...")
|
||||
if check_column_exists(conn, "projects", "is_archived"):
|
||||
print(" ✓ Column 'is_archived' already exists. Skipping.")
|
||||
else:
|
||||
conn.execute(
|
||||
text(
|
||||
"ALTER TABLE projects ADD COLUMN is_archived INTEGER DEFAULT 0 NOT NULL"
|
||||
)
|
||||
)
|
||||
# Set default value for existing rows
|
||||
conn.execute(
|
||||
text("UPDATE projects SET is_archived = 0 WHERE is_archived IS NULL")
|
||||
)
|
||||
print(" ✓ Successfully added 'is_archived' column to projects table")
|
||||
|
||||
# 2. Add product_service to companies table
|
||||
print("\n2. Adding 'product_service' column to companies table...")
|
||||
if check_column_exists(conn, "companies", "product_service"):
|
||||
print(" ✓ Column 'product_service' already exists. Skipping.")
|
||||
else:
|
||||
conn.execute(text("ALTER TABLE companies ADD COLUMN product_service TEXT"))
|
||||
print(" ✓ Successfully added 'product_service' column to companies table")
|
||||
|
||||
# 3. Add clients to companies table
|
||||
print("\n3. Adding 'clients' column to companies table...")
|
||||
if check_column_exists(conn, "companies", "clients"):
|
||||
print(" ✓ Column 'clients' already exists. Skipping.")
|
||||
else:
|
||||
conn.execute(text("ALTER TABLE companies ADD COLUMN clients TEXT"))
|
||||
print(" ✓ Successfully added 'clients' column to companies table")
|
||||
|
||||
# 4. Add linkedin to investor_members table
|
||||
print("\n4. Adding 'linkedin' column to investor_members table...")
|
||||
if check_column_exists(conn, "investor_members", "linkedin"):
|
||||
print(" ✓ Column 'linkedin' already exists. Skipping.")
|
||||
else:
|
||||
conn.execute(
|
||||
text("ALTER TABLE investor_members ADD COLUMN linkedin VARCHAR")
|
||||
)
|
||||
print(" ✓ Successfully added 'linkedin' column to investor_members table")
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("Migration completed successfully!")
|
||||
|
||||
|
||||
def downgrade():
|
||||
"""Remove added columns from tables"""
|
||||
print("Running downgrade: Remove feedback fixes fields")
|
||||
print("=" * 60)
|
||||
|
||||
# Note: SQLite doesn't support DROP COLUMN directly
|
||||
print("\nWarning: SQLite doesn't support DROP COLUMN directly.")
|
||||
print("To remove these columns, you would need to:")
|
||||
print("1. Create new tables without the columns")
|
||||
print("2. Copy data from old tables to new tables")
|
||||
print("3. Drop old tables and rename new tables")
|
||||
print("\nColumns to remove:")
|
||||
print(" - projects.is_archived")
|
||||
print(" - companies.product_service")
|
||||
print(" - companies.clients")
|
||||
print(" - investor_members.linkedin")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Run database migration")
|
||||
parser.add_argument(
|
||||
"direction",
|
||||
choices=["upgrade", "downgrade"],
|
||||
default="upgrade",
|
||||
nargs="?",
|
||||
help="Migration direction (default: upgrade)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.direction == "upgrade":
|
||||
upgrade()
|
||||
else:
|
||||
downgrade()
|
||||
@@ -1,67 +0,0 @@
|
||||
"""
|
||||
Migration: Add industry column to projects table
|
||||
Date: 2025-10-23
|
||||
"""
|
||||
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Add parent directory to path to import app modules
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from sqlalchemy import create_engine, text
|
||||
from app.db.db import DATABASE_URL, engine
|
||||
|
||||
|
||||
def upgrade():
|
||||
"""Add industry column to projects table"""
|
||||
print("Running migration: Add industry column to projects table")
|
||||
|
||||
with engine.connect() as conn:
|
||||
# Check if column already exists
|
||||
result = conn.execute(text("PRAGMA table_info(projects)"))
|
||||
columns = [row[1] for row in result]
|
||||
|
||||
if 'industry' in columns:
|
||||
print("Column 'industry' already exists in projects table. Skipping migration.")
|
||||
return
|
||||
|
||||
# Add the industry column
|
||||
conn.execute(text("ALTER TABLE projects ADD COLUMN industry VARCHAR"))
|
||||
conn.commit()
|
||||
|
||||
print("Successfully added 'industry' column to projects table")
|
||||
|
||||
|
||||
def downgrade():
|
||||
"""Remove industry column from projects table"""
|
||||
print("Running downgrade: Remove industry column from projects table")
|
||||
|
||||
# Note: SQLite doesn't support DROP COLUMN directly
|
||||
# This is a simplified version - in production you'd need to recreate the table
|
||||
print("Warning: SQLite doesn't support DROP COLUMN.")
|
||||
print("To remove the column, you would need to:")
|
||||
print("1. Create a new table without the industry column")
|
||||
print("2. Copy data from old table to new table")
|
||||
print("3. Drop old table and rename new table")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="Run database migration")
|
||||
parser.add_argument(
|
||||
"direction",
|
||||
choices=["upgrade", "downgrade"],
|
||||
default="upgrade",
|
||||
nargs="?",
|
||||
help="Migration direction (default: upgrade)"
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.direction == "upgrade":
|
||||
upgrade()
|
||||
else:
|
||||
downgrade()
|
||||
@@ -1,40 +1,26 @@
|
||||
aiofiles==24.1.0
|
||||
aiohappyeyeballs==2.6.1
|
||||
aiohttp==3.12.15
|
||||
aiosignal==1.4.0
|
||||
aiosqlite==0.21.0
|
||||
alphashape==1.3.1
|
||||
annotated-types==0.7.0
|
||||
anyio==4.10.0
|
||||
attrs==25.3.0
|
||||
backoff==2.2.1
|
||||
bcrypt==4.3.0
|
||||
beautifulsoup4==4.14.2
|
||||
brotli==1.1.0
|
||||
build==1.3.0
|
||||
cachetools==5.5.2
|
||||
certifi==2025.8.3
|
||||
cffi==2.0.0
|
||||
chardet==5.2.0
|
||||
charset-normalizer==3.4.3
|
||||
chromadb==1.0.20
|
||||
click==8.2.1
|
||||
click-log==0.4.0
|
||||
coloredlogs==15.0.1
|
||||
crawl4ai==0.7.4
|
||||
cryptography==46.0.1
|
||||
dataclasses-json==0.6.7
|
||||
ddgs==9.5.2
|
||||
distro==1.9.0
|
||||
dnspython==2.7.0
|
||||
durationpy==0.10
|
||||
email-validator==2.3.0
|
||||
fake-http-header==0.3.5
|
||||
fake-useragent==2.2.0
|
||||
fastapi==0.116.1
|
||||
fastapi-cli==0.0.8
|
||||
fastapi-cloud-cli==0.1.5
|
||||
fastuuid==0.13.5
|
||||
filelock==3.19.1
|
||||
flatbuffers==25.2.10
|
||||
frozenlist==1.7.0
|
||||
@@ -44,24 +30,19 @@ googleapis-common-protos==1.70.0
|
||||
greenlet==3.2.4
|
||||
grpcio==1.74.0
|
||||
h11==0.16.0
|
||||
h2==4.3.0
|
||||
hf-xet==1.1.8
|
||||
hpack==4.1.0
|
||||
httpcore==1.0.9
|
||||
httptools==0.6.4
|
||||
httpx==0.28.1
|
||||
httpx-sse==0.4.1
|
||||
huggingface-hub==0.34.4
|
||||
humanfriendly==10.0
|
||||
humanize==4.13.0
|
||||
hyperframe==6.1.0
|
||||
idna==3.10
|
||||
importlib-metadata==8.7.0
|
||||
importlib-resources==6.5.2
|
||||
itsdangerous==2.2.0
|
||||
jinja2==3.1.6
|
||||
jiter==0.10.0
|
||||
joblib==1.5.2
|
||||
jsonpatch==1.33
|
||||
jsonpointer==3.0.0
|
||||
jsonschema==4.25.1
|
||||
@@ -77,9 +58,6 @@ langgraph-checkpoint==2.1.1
|
||||
langgraph-prebuilt==0.6.4
|
||||
langgraph-sdk==0.2.4
|
||||
langsmith==0.4.20
|
||||
lark==1.3.0
|
||||
litellm==1.77.5
|
||||
lxml==5.4.0
|
||||
markdown-it-py==4.0.0
|
||||
markupsafe==3.0.2
|
||||
marshmallow==3.26.1
|
||||
@@ -88,8 +66,6 @@ mmh3==5.2.0
|
||||
mpmath==1.3.0
|
||||
multidict==6.6.4
|
||||
mypy-extensions==1.1.0
|
||||
networkx==3.5
|
||||
nltk==3.9.1
|
||||
numpy==2.3.2
|
||||
oauthlib==3.3.1
|
||||
onnxruntime==1.22.1
|
||||
@@ -105,26 +81,18 @@ ormsgpack==1.10.0
|
||||
overrides==7.7.0
|
||||
packaging==25.0
|
||||
pandas==2.3.2
|
||||
patchright==1.55.2
|
||||
pillow==11.3.0
|
||||
pip==25.2
|
||||
playwright==1.55.0
|
||||
posthog==5.4.0
|
||||
primp==0.15.0
|
||||
propcache==0.3.2
|
||||
protobuf==6.32.0
|
||||
psutil==7.1.0
|
||||
pyasn1==0.6.1
|
||||
pyasn1-modules==0.4.2
|
||||
pybase64==1.4.2
|
||||
pycparser==2.23
|
||||
pydantic==2.11.7
|
||||
pydantic-core==2.33.2
|
||||
pydantic-extra-types==2.10.5
|
||||
pydantic-settings==2.10.1
|
||||
pyee==13.0.0
|
||||
pygments==2.19.2
|
||||
pyopenssl==25.3.0
|
||||
pypika==0.48.9
|
||||
pyproject-hooks==1.2.0
|
||||
python-dateutil==2.9.0.post0
|
||||
@@ -132,7 +100,6 @@ python-dotenv==1.1.1
|
||||
python-multipart==0.0.20
|
||||
pytz==2025.2
|
||||
pyyaml==6.0.2
|
||||
rank-bm25==0.2.2
|
||||
referencing==0.36.2
|
||||
regex==2025.7.34
|
||||
requests==2.32.5
|
||||
@@ -143,24 +110,17 @@ rich-toolkit==0.15.0
|
||||
rignore==0.6.4
|
||||
rpds-py==0.27.1
|
||||
rsa==4.9.1
|
||||
rtree==1.4.1
|
||||
scipy==1.16.2
|
||||
sentry-sdk==2.35.1
|
||||
shapely==2.1.2
|
||||
shellingham==1.5.4
|
||||
six==1.17.0
|
||||
sniffio==1.3.1
|
||||
snowballstemmer==2.2.0
|
||||
soupsieve==2.8
|
||||
sqlalchemy==2.0.43
|
||||
starlette==0.47.3
|
||||
sympy==1.14.0
|
||||
tenacity==9.1.2
|
||||
tf-playwright-stealth==1.2.0
|
||||
tiktoken==0.11.0
|
||||
tokenizers==0.21.4
|
||||
tqdm==4.67.1
|
||||
trimesh==4.8.3
|
||||
typer==0.16.1
|
||||
typing-extensions==4.15.0
|
||||
typing-inspect==0.9.0
|
||||
|
||||
@@ -1,68 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Server management script for app/main.py
|
||||
# Usage: ./server_manager.sh start|stop|restart
|
||||
|
||||
PID_FILE="server.pid"
|
||||
LOG_FILE="server.log"
|
||||
|
||||
start() {
|
||||
if [ -f "$PID_FILE" ] && kill -0 $(cat "$PID_FILE") 2>/dev/null; then
|
||||
echo "Server is already running (PID: $(cat "$PID_FILE"))"
|
||||
return 1
|
||||
fi
|
||||
echo "Starting server..."
|
||||
nohup uv run app/main.py > "$LOG_FILE" 2>&1 &
|
||||
echo $! > "$PID_FILE"
|
||||
echo "Server started (PID: $(cat "$PID_FILE"))"
|
||||
}
|
||||
|
||||
stop() {
|
||||
if [ ! -f "$PID_FILE" ]; then
|
||||
echo "Server is not running (no PID file found)"
|
||||
return 1
|
||||
fi
|
||||
PID=$(cat "$PID_FILE")
|
||||
if ! kill -0 "$PID" 2>/dev/null; then
|
||||
echo "Server is not running (PID $PID not found)"
|
||||
rm -f "$PID_FILE"
|
||||
return 1
|
||||
fi
|
||||
echo "Stopping server (PID: $PID)..."
|
||||
kill "$PID"
|
||||
# Wait for process to stop
|
||||
for i in {1..10}; do
|
||||
if ! kill -0 "$PID" 2>/dev/null; then
|
||||
break
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
if kill -0 "$PID" 2>/dev/null; then
|
||||
echo "Force killing server..."
|
||||
kill -9 "$PID"
|
||||
fi
|
||||
rm -f "$PID_FILE"
|
||||
echo "Server stopped"
|
||||
}
|
||||
|
||||
restart() {
|
||||
stop
|
||||
sleep 2
|
||||
start
|
||||
}
|
||||
|
||||
case "$1" in
|
||||
start)
|
||||
start
|
||||
;;
|
||||
stop)
|
||||
stop
|
||||
;;
|
||||
restart)
|
||||
restart
|
||||
;;
|
||||
*)
|
||||
echo "Usage: $0 {start|stop|restart}"
|
||||
exit 1
|
||||
;;
|
||||
esac
|
||||
@@ -1,310 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Update Investor Members LinkedIn Profiles Script
|
||||
|
||||
This script finds and updates LinkedIn profile URLs for investor members in the database.
|
||||
Uses crawl4ai to efficiently scrape team pages and extract LinkedIn URLs.
|
||||
|
||||
Usage:
|
||||
python update_linkedin_profiles.py [--test] [--limit N] [--skip-existing]
|
||||
|
||||
Options:
|
||||
--test Test mode: process only 10 records and don't update database
|
||||
--limit N Process only N records (default: all)
|
||||
--skip-existing Skip members that already have LinkedIn URLs
|
||||
--start-from N Start from record N (for resuming)
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from datetime import datetime
|
||||
|
||||
# Add app to path
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "app"))
|
||||
|
||||
from db.db import get_db_session
|
||||
from db.models import InvestorMember, InvestorTable
|
||||
from linkedin_scraper import LinkedInProfileScraper, format_linkedin_url
|
||||
|
||||
|
||||
def progress_callback(current, total, result):
|
||||
"""Print progress updates"""
|
||||
percent = (current / total) * 100
|
||||
status = "✓" if result["linkedin_url"] else "✗"
|
||||
print(f"[{current}/{total} - {percent:.1f}%] {status} {result['member_name']}")
|
||||
if result["linkedin_url"]:
|
||||
print(
|
||||
f" → {result['linkedin_url']} (confidence: {result['confidence']}%, method: {result['method']})"
|
||||
)
|
||||
|
||||
|
||||
def create_db_callback(test_mode=False):
|
||||
"""
|
||||
Create a callback function that saves LinkedIn profiles to the database immediately.
|
||||
This allows stopping and resuming without losing progress.
|
||||
"""
|
||||
saved_count = {"count": 0} # Use dict to allow modification in closure
|
||||
|
||||
def db_callback(member_id: int, linkedin_url: str) -> bool:
|
||||
"""Save LinkedIn URL to database immediately"""
|
||||
if test_mode:
|
||||
print(f" [TEST] Would save to DB: member {member_id}")
|
||||
saved_count["count"] += 1
|
||||
return True
|
||||
|
||||
try:
|
||||
db = get_db_session()
|
||||
member = db.query(InvestorMember).filter_by(id=member_id).first()
|
||||
if member:
|
||||
member.linkedin = format_linkedin_url(linkedin_url)
|
||||
db.commit()
|
||||
saved_count["count"] += 1
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f" ⚠️ DB Error for member {member_id}: {e}")
|
||||
try:
|
||||
db.rollback()
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
finally:
|
||||
try:
|
||||
db.close()
|
||||
except Exception:
|
||||
pass
|
||||
return False
|
||||
|
||||
return db_callback, saved_count
|
||||
|
||||
|
||||
def update_database(members_data, test_mode=False):
|
||||
"""Update database with found LinkedIn profiles"""
|
||||
db = get_db_session()
|
||||
|
||||
try:
|
||||
updated_count = 0
|
||||
for data in members_data:
|
||||
if data["linkedin_url"] and data["member_id"]:
|
||||
if not test_mode:
|
||||
member = (
|
||||
db.query(InvestorMember).filter_by(id=data["member_id"]).first()
|
||||
)
|
||||
if member:
|
||||
member.linkedin = format_linkedin_url(data["linkedin_url"])
|
||||
updated_count += 1
|
||||
else:
|
||||
print(
|
||||
f" [TEST MODE] Would update member {data['member_id']}: {data['linkedin_url']}"
|
||||
)
|
||||
updated_count += 1
|
||||
|
||||
if not test_mode:
|
||||
db.commit()
|
||||
print(f"\n✓ Successfully updated {updated_count} records in database")
|
||||
else:
|
||||
print(f"\n[TEST MODE] Would have updated {updated_count} records")
|
||||
|
||||
return updated_count
|
||||
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
print(f"\n✗ Error updating database: {e}")
|
||||
raise
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
def save_results(results, filename="linkedin_scraping_results.json"):
|
||||
"""Save results to JSON file for backup/analysis"""
|
||||
output = {
|
||||
"timestamp": datetime.now().isoformat(),
|
||||
"total_processed": len(results),
|
||||
"found_count": sum(1 for r in results if r["linkedin_url"]),
|
||||
"results": results,
|
||||
}
|
||||
|
||||
with open(filename, "w") as f:
|
||||
json.dump(output, f, indent=2)
|
||||
|
||||
print(f"\n✓ Results saved to {filename}")
|
||||
|
||||
|
||||
def print_summary(results):
|
||||
"""Print summary statistics"""
|
||||
total = len(results)
|
||||
found = sum(1 for r in results if r["linkedin_url"])
|
||||
not_found = total - found
|
||||
|
||||
# Count by method
|
||||
methods = {}
|
||||
for r in results:
|
||||
if r["linkedin_url"]:
|
||||
method = r["method"]
|
||||
methods[method] = methods.get(method, 0) + 1
|
||||
|
||||
# Average confidence for found profiles
|
||||
avg_confidence = (
|
||||
sum(r["confidence"] for r in results if r["linkedin_url"]) / found
|
||||
if found > 0
|
||||
else 0
|
||||
)
|
||||
|
||||
print("\n" + "=" * 60)
|
||||
print("SUMMARY")
|
||||
print("=" * 60)
|
||||
print(f"Total processed: {total}")
|
||||
print(f"LinkedIn found: {found} ({found / total * 100:.1f}%)")
|
||||
print(f"Not found: {not_found} ({not_found / total * 100:.1f}%)")
|
||||
print(f"\nAverage confidence: {avg_confidence:.1f}%")
|
||||
print("\nMethods used:")
|
||||
for method, count in sorted(methods.items(), key=lambda x: x[1], reverse=True):
|
||||
print(f" {method:20s} {count:5d} ({count / found * 100:.1f}%)")
|
||||
print("=" * 60)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Update LinkedIn profiles for investor members"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--test",
|
||||
action="store_true",
|
||||
help="Test mode: process only 10 records without updating database",
|
||||
)
|
||||
parser.add_argument("--limit", type=int, help="Limit number of records to process")
|
||||
parser.add_argument(
|
||||
"--skip-existing",
|
||||
action="store_true",
|
||||
help="Skip members that already have LinkedIn URLs",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--start-from",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Start from record N (for resuming interrupted runs)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--rate-limit",
|
||||
type=float,
|
||||
default=0.5,
|
||||
help="Delay between URL crawls in seconds (default: 0.5)",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Test mode overrides limit
|
||||
if args.test and not args.limit:
|
||||
args.limit = 10
|
||||
|
||||
print("=" * 60)
|
||||
print("LinkedIn Profile Scraper for Investor Members (crawl4ai)")
|
||||
print("=" * 60)
|
||||
|
||||
if args.test:
|
||||
print("\n⚠️ TEST MODE - No database changes will be made")
|
||||
|
||||
# Initialize database and scraper
|
||||
db = get_db_session()
|
||||
|
||||
try:
|
||||
# Build query
|
||||
query = db.query(InvestorMember, InvestorTable).join(
|
||||
InvestorTable, InvestorMember.investor_id == InvestorTable.id
|
||||
)
|
||||
|
||||
# Filter existing if requested
|
||||
if args.skip_existing:
|
||||
query = query.filter(
|
||||
(InvestorMember.linkedin.is_(None)) | (InvestorMember.linkedin == "")
|
||||
)
|
||||
print("\n✓ Filtering to members without LinkedIn profiles")
|
||||
|
||||
# Get total count
|
||||
total_available = query.count()
|
||||
print(f"\n✓ Found {total_available} members to process")
|
||||
|
||||
# Apply offset and limit
|
||||
if args.start_from > 0:
|
||||
query = query.offset(args.start_from)
|
||||
print(f"✓ Starting from record {args.start_from}")
|
||||
|
||||
if args.limit:
|
||||
query = query.limit(args.limit)
|
||||
print(f"✓ Processing {args.limit} records")
|
||||
|
||||
# Fetch members
|
||||
members_data = []
|
||||
for member, investor in query.all():
|
||||
members_data.append(
|
||||
{
|
||||
"id": member.id,
|
||||
"name": member.name,
|
||||
"company": investor.name,
|
||||
"role": member.role,
|
||||
"source_url": member.source_url,
|
||||
}
|
||||
)
|
||||
|
||||
if not members_data:
|
||||
print("\n⚠️ No members to process")
|
||||
return
|
||||
|
||||
# Count unique source URLs
|
||||
unique_urls = len(set(m["source_url"] for m in members_data if m["source_url"]))
|
||||
with_urls = sum(1 for m in members_data if m["source_url"])
|
||||
|
||||
print(f"\n✓ Loaded {len(members_data)} members")
|
||||
print(
|
||||
f"✓ {with_urls} members have source URLs ({unique_urls} unique pages to crawl)"
|
||||
)
|
||||
print(f"✓ {len(members_data) - with_urls} members without source URLs")
|
||||
print(f"✓ Rate limit: {args.rate_limit}s between page crawls")
|
||||
print("\nStarting LinkedIn profile search using crawl4ai...\n")
|
||||
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
# Initialize scraper
|
||||
scraper = LinkedInProfileScraper(rate_limit_delay=args.rate_limit, use_cache=True)
|
||||
|
||||
print("ℹ️ Using crawl4ai to scrape team pages and extract LinkedIn URLs")
|
||||
print(
|
||||
"ℹ️ Profiles are saved to database IMMEDIATELY when found - safe to stop anytime!\n"
|
||||
)
|
||||
|
||||
# Create database callback for real-time saving
|
||||
db_callback, saved_count = create_db_callback(test_mode=args.test)
|
||||
|
||||
# Process members asynchronously with real-time DB saving
|
||||
results = asyncio.run(
|
||||
scraper.batch_find_profiles(
|
||||
members_data, progress_callback=progress_callback, db_callback=db_callback
|
||||
)
|
||||
)
|
||||
|
||||
# Print summary
|
||||
print_summary(results)
|
||||
|
||||
# Save results
|
||||
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
|
||||
results_file = f"linkedin_results_{timestamp}.json"
|
||||
save_results(results, results_file)
|
||||
|
||||
# Show database update summary
|
||||
if not args.test:
|
||||
print(
|
||||
f"\n✓ Database updated in real-time: {saved_count['count']} profiles saved"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"\n[TEST MODE] Would have saved {saved_count['count']} profiles to database"
|
||||
)
|
||||
|
||||
print("\n✓ Done! You can resume anytime with --skip-existing")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user