22 Commits

Author SHA1 Message Date
bolade 25d83f24b7 completed investors linkedin 2025-11-28 07:19:58 +01:00
michael 3bc8a24c0c feat: Add LinkedIn URL support for investor synchronization and update schemas 2025-11-28 06:18:04 +00:00
bolade 495f8a0ff6 added linkedin profiles 2025-11-27 16:44:22 +01:00
michael 100e0b2b0c made improvements 2025-11-26 08:04:11 +00:00
bolade b92feaa13a refactor: Clean up migration script and improve readability by removing unnecessary imports and formatting 2025-11-11 20:28:20 +01:00
bolade 215fec2895 made corrections based on feedback 2025-11-11 20:27:55 +01:00
bolade 5e83734acf feat: Enhance data models and sorting logic for investors and projects 2025-11-11 13:10:28 +01:00
michael 0e4763bf4f updated db 2025-11-11 12:07:01 +00:00
michael 8a25e892ad Merge branch 'version_three' of http://23.29.118.76:3000/bolade/Anton_wireframe into version_three 2025-10-28 23:31:13 +00:00
bolade 6b9fd86ab7 refactor: Improve report generation logic and adjust scoring weights 2025-10-29 00:27:39 +01:00
michael db2addb835 Merge branch 'version_three' of http://23.29.118.76:3000/bolade/Anton_wireframe into version_three 2025-10-28 22:16:29 +00:00
michael 7048847a42 db update 2025-10-28 22:16:06 +00:00
bolade 45e1f099b8 fixed insight 2025-10-28 23:14:57 +01:00
bolade e19c8f96eb feat: Add server management script with start, stop, and restart functionality 2025-10-28 22:03:32 +01:00
bolade 3ab2592c22 Added logging to main 2025-10-28 21:34:35 +01:00
michael f63672bdac added db 2025-10-28 20:13:49 +00:00
michael c53455cc06 feat: Enhance compatibility scoring and report generation with new methods and models 2025-10-28 20:13:45 +00:00
bolade 02c8bb816f made querying async 2025-10-28 21:09:47 +01:00
bolade bb03f6ade4 fixed querying 2025-10-28 20:54:15 +01:00
bolade ff0010019e feat: Implement company querying functionality with natural language processing and logging 2025-10-27 20:13:24 +01:00
michael 1ac755b2d7 feat: Add industry column to ProjectTable and update related schemas and query filters 2025-10-23 12:52:52 +00:00
bolade 483c2cc114 feat: Update investor report generation and HTML template to include fund details and improve data handling 2025-10-21 10:48:58 +01:00
33 changed files with 2833 additions and 395 deletions
+5 -1
View File
@@ -13,4 +13,8 @@
*.cypython
nohup.out
nohup.out
server.log
server.pid
Binary file not shown.
Binary file not shown.
Binary file not shown.
-1
View File
@@ -1,5 +1,4 @@
import os
from pathlib import Path
from typing import Annotated
from fastapi import Depends
+5
View File
@@ -162,6 +162,7 @@ class InvestorMember(Base, TimestampMixin):
role = Column(String, nullable=True)
title = Column(String, nullable=True) # Alternative to role
email = Column(String, nullable=True)
linkedin = Column(String, nullable=True) # LinkedIn profile URL
source_url = Column(String, nullable=True) # URL where member info was found
investor_id = Column(Integer, ForeignKey("investors.id"))
@@ -215,6 +216,8 @@ class CompanyTable(Base, TimestampMixin):
description = Column(String, nullable=True)
founded_year = Column(Integer, nullable=True)
website = Column(String, nullable=True)
product_service = Column(Text, nullable=True) # Product/service description
clients = Column(JSON, nullable=True) # List of client names or client information
members = relationship(
"CompanyMember", back_populates="company", cascade="all, delete-orphan"
@@ -296,9 +299,11 @@ class ProjectTable(Base, TimestampMixin):
stage = Column(Enum(InvestmentStage), nullable=True)
location = Column(String, nullable=True)
industry = Column(String, nullable=True)
description = Column(Text, nullable=True)
start_date = Column(DateTime, nullable=True)
end_date = Column(DateTime, nullable=True)
is_archived = Column(Integer, default=0, nullable=False) # 0 = active, 1 = archived
sector = relationship(
"SectorTable", secondary=project_sector_association, back_populates="projects"
+730
View File
@@ -0,0 +1,730 @@
"""
LinkedIn Profile Scraper for Investor Members
This module uses crawl4ai to scrape team pages and find LinkedIn profiles.
Strategies:
1. Crawl the source_url (team pages) to extract LinkedIn profile links
2. Use LLM-powered web search to find LinkedIn profiles by name
Key advantages of crawl4ai:
- Handles JavaScript-rendered pages
- Better at extracting content from modern websites
- More reliable than simple requests
"""
import asyncio
import logging
import os
import re
from typing import Dict, List, Optional
from crawl4ai import AsyncWebCrawler
from ddgs import DDGS
from dotenv import load_dotenv
from langchain_openai import ChatOpenAI
# Setup logging
logging.basicConfig(
level=logging.INFO, format="%(asctime)s [%(levelname)s] %(name)s: %(message)s"
)
logger = logging.getLogger("linkedin_scraper")
load_dotenv()
OPENROUTER_API_KEY = os.getenv("OPENROUTER_API_KEY")
class LinkedInProfileScraper:
"""
LinkedIn profile finder using crawl4ai and LLM-powered web search.
Strategies:
1. Crawl source URLs (team pages) to extract LinkedIn links
2. Use LLM-powered web search to find profiles by name
"""
def __init__(
self,
rate_limit_delay: float = 0.5,
use_cache: bool = True,
use_llm_search: bool = True,
):
"""
Initialize the scraper
Args:
rate_limit_delay: Delay between requests in seconds
use_cache: Whether to cache crawled pages
use_llm_search: Whether to use LLM-powered web search as fallback
"""
self.rate_limit_delay = rate_limit_delay
self.use_cache = use_cache
self.use_llm_search = use_llm_search and OPENROUTER_API_KEY
self.page_cache: Dict[str, str] = {} # Cache crawled pages by URL
self.html_cache: Dict[str, str] = {} # Cache HTML separately
self.profile_cache: Dict[str, Dict] = {} # Cache results by member
# Initialize LLM agent if API key available
if self.use_llm_search:
self._init_llm_agent()
else:
self.llm = None
self.agent = None
self.ddg_search = None
logger.info("LLM search disabled (no OPENROUTER_API_KEY)")
def _init_llm_agent(self):
"""Initialize LLM agent for web search"""
try:
self.llm = ChatOpenAI(
api_key=OPENROUTER_API_KEY,
base_url="https://openrouter.ai/api/v1",
model="x-ai/grok-4.1-fast:free",
temperature=0,
)
self.ddg_search = DDGS()
logger.info("LLM search agent initialized")
except Exception as e:
logger.error(f"Failed to initialize LLM agent: {e}")
self.llm = None
self.ddg_search = None
def web_search(self, query: str) -> List[Dict]:
"""Tool to search the web using DuckDuckGo"""
if not self.ddg_search:
return []
try:
results = list(self.ddg_search.text(query, max_results=10))
return results
except Exception as e:
logger.error(f"Web search error: {e}")
return []
async def crawl_page(self, url: str) -> Optional[str]:
"""
Crawl a webpage and return its content.
Args:
url: URL to crawl
Returns:
Page content as markdown/text, or None if failed
"""
if not url:
return None
# Check cache first
if self.use_cache and url in self.page_cache:
logger.debug(f"Using cached page for {url}")
return self.page_cache[url]
try:
logger.info(f"Crawling: {url}")
async with AsyncWebCrawler() as crawler:
result = await crawler.arun(url)
if result and result.markdown:
content = result.markdown
# Also get HTML for better link extraction
html_content = result.html if hasattr(result, "html") else ""
# Cache the results
if self.use_cache:
self.page_cache[url] = content
self.html_cache[url] = html_content
return content
except Exception as e:
logger.error(f"Error crawling {url}: {e}")
return None
def extract_linkedin_urls_from_content(self, content: str) -> List[Dict[str, str]]:
"""
Extract all LinkedIn profile URLs from content (HTML or markdown).
Returns:
List of dicts with 'url', 'context', and 'username'
"""
linkedin_links = []
# Pattern for LinkedIn profile URLs (handles country-specific domains)
linkedin_pattern = (
r"https?://(?:www\.)?(?:[a-z]{2,3}\.)?linkedin\.com/in/([a-zA-Z0-9_-]+)/?"
)
# Find all LinkedIn URLs
matches = list(re.finditer(linkedin_pattern, content, re.IGNORECASE))
for match in matches:
url = match.group(0).rstrip("/")
# Normalize URL
url = self._normalize_linkedin_url(url)
# Get surrounding context (200 chars before and after)
start = max(0, match.start() - 200)
end = min(len(content), match.end() + 200)
context = content[start:end]
# Clean up context (remove HTML tags for readability)
context = re.sub(r"<[^>]+>", " ", context)
context = " ".join(context.split()) # Normalize whitespace
linkedin_links.append(
{"url": url, "context": context, "username": match.group(1)}
)
# Remove duplicates while preserving order
seen_urls = set()
unique_links = []
for link in linkedin_links:
if link["url"] not in seen_urls:
seen_urls.add(link["url"])
unique_links.append(link)
return unique_links
def _normalize_linkedin_url(self, url: str) -> str:
"""Normalize LinkedIn URL to standard format"""
# Remove trailing slashes
url = url.rstrip("/")
# Convert country-specific to www
url = re.sub(
r"https?://[a-z]{2,3}\.linkedin\.com", "https://www.linkedin.com", url
)
# Ensure https
if url.startswith("http://"):
url = url.replace("http://", "https://")
return url
def _name_matches_context(self, name: str, context: str) -> float:
"""
Check if a person's name appears in the context around a LinkedIn URL.
Returns:
Confidence score 0-100
"""
if not name or not context:
return 0
context_lower = context.lower()
name_lower = name.lower()
# Split name into parts (handle multiple spaces, titles like "Dr.", etc.)
name_parts = [p for p in name_lower.replace(".", " ").split() if len(p) > 1]
# Check for full name match
if name_lower in context_lower:
return 95
# Check for name parts in context
matches = sum(
1 for part in name_parts if part in context_lower and len(part) > 2
)
if len(name_parts) > 0:
if matches == len(name_parts):
return 90 # All name parts found
elif matches >= 2:
return 75 # At least 2 parts found (first + last typically)
elif matches == 1 and len(name_parts) <= 2:
return 50 # Only one part found but name is short
elif matches == 1:
return 35 # Only one part found
return 0
def _name_matches_username(self, name: str, username: str) -> float:
"""
Check if LinkedIn username contains parts of the name.
Returns:
Confidence score 0-100
"""
if not name or not username:
return 0
name_lower = name.lower()
username_lower = username.lower().replace("-", " ").replace("_", " ")
name_parts = [p for p in name_lower.replace(".", " ").split() if len(p) > 2]
matches = sum(1 for part in name_parts if part in username_lower)
if len(name_parts) > 0:
if matches == len(name_parts) and len(name_parts) >= 2:
return 85 # Full name in username
elif matches >= 2:
return 70 # Multiple parts match
elif matches == 1:
return 35 # Only one part matches
return 0
async def find_linkedin_from_source(
self, name: str, source_url: str, role: Optional[str] = None
) -> Dict:
"""
Find LinkedIn profile by crawling the source URL (team page).
Args:
name: Person's name
source_url: URL of the team/about page
role: Person's role (for additional context matching)
Returns:
Dict with linkedin_url, confidence, method, notes
"""
if not source_url:
return {
"linkedin_url": None,
"confidence": 0,
"method": "source_crawl",
"notes": "No source URL provided",
}
# Crawl the page
content = await self.crawl_page(source_url)
if not content:
return {
"linkedin_url": None,
"confidence": 0,
"method": "source_crawl",
"notes": f"Failed to crawl {source_url}",
}
# Get HTML for better link extraction
html = self.html_cache.get(source_url, content)
# Extract all LinkedIn URLs from both HTML and markdown
linkedin_links = self.extract_linkedin_urls_from_content(html)
if not linkedin_links:
linkedin_links = self.extract_linkedin_urls_from_content(content)
if not linkedin_links:
return {
"linkedin_url": None,
"confidence": 0,
"method": "source_crawl",
"notes": f"No LinkedIn URLs found on {source_url}",
}
# Score each LinkedIn URL based on name matching
best_match = None
best_score = 0
for link in linkedin_links:
# Score based on context matching
context_score = self._name_matches_context(name, link["context"])
# Score based on username matching
username_score = self._name_matches_username(name, link["username"])
# Also check if role appears in context
role_bonus = 0
if role and role.lower() in link["context"].lower():
role_bonus = 10
# Combined score (take best of context or username, plus role bonus)
total_score = max(context_score, username_score) + role_bonus
logger.debug(
f" {name} -> {link['url']}: context={context_score}, username={username_score}, role={role_bonus}, total={total_score}"
)
if total_score > best_score:
best_score = total_score
best_match = link
if best_match and best_score >= 30: # Minimum threshold
return {
"linkedin_url": best_match["url"],
"confidence": min(best_score, 100),
"method": "source_crawl",
"notes": f"Found on {source_url}",
}
return {
"linkedin_url": None,
"confidence": 0,
"method": "source_crawl",
"notes": f'No matching LinkedIn profile found for "{name}" on {source_url}',
}
async def find_linkedin_via_search(
self, name: str, company: str, role: Optional[str] = None
) -> Dict:
"""
Find LinkedIn profile using web search.
Args:
name: Person's name
company: Company/investor name
role: Person's role (optional)
Returns:
Dict with linkedin_url, confidence, method, notes
"""
if not self.ddg_search:
return {
"linkedin_url": None,
"confidence": 0,
"method": "web_search",
"notes": "Web search not available",
}
try:
# Build search query - search for LinkedIn profile
query = f"{name} {company} site:linkedin.com/in"
if role:
query = f"{name} {role} {company} site:linkedin.com/in"
logger.debug(f"Searching: {query}")
results = self.web_search(query)
if results:
# Look for LinkedIn profile URLs in results
linkedin_pattern = r"https?://(?:www\.)?(?:[a-z]{2,3}\.)?linkedin\.com/in/([a-zA-Z0-9_-]+)"
for result in results:
url = result.get("href") or result.get("link") or ""
title = result.get("title", "").lower()
body = result.get("body", "").lower()
match = re.search(linkedin_pattern, url, re.IGNORECASE)
if match:
linkedin_url = self._normalize_linkedin_url(match.group(0))
username = match.group(1)
# Score based on name matching in title/body and username
context = f"{title} {body}"
context_score = self._name_matches_context(name, context)
username_score = self._name_matches_username(name, username)
total_score = max(context_score, username_score)
if total_score >= 30:
return {
"linkedin_url": linkedin_url,
"confidence": min(
total_score, 90
), # Cap at 90 for search results
"method": "web_search",
"notes": "Found via web search",
}
return {
"linkedin_url": None,
"confidence": 0,
"method": "web_search",
"notes": "No matching profile found in search results",
}
except Exception as e:
logger.error(f"Web search error for {name}: {e}")
return {
"linkedin_url": None,
"confidence": 0,
"method": "web_search",
"notes": f"Search error: {str(e)}",
}
async def find_linkedin_profile(
self,
name: str,
company: str,
role: Optional[str] = None,
source_url: Optional[str] = None,
) -> Dict:
"""
Find LinkedIn profile for a person.
Primary strategy: Crawl source URL to find LinkedIn links.
Args:
name: Person's name
company: Company/investor name
role: Person's role/title (optional)
source_url: URL where person info was found (optional)
Returns:
Dict with:
- linkedin_url: Found LinkedIn URL or None
- confidence: Confidence score (0-100)
- method: Method used to find the profile
- notes: Additional information
"""
cache_key = f"{name}|{company}"
# Check cache
if self.use_cache and cache_key in self.profile_cache:
logger.debug(f"Using cached result for {name}")
return self.profile_cache[cache_key]
result = {"linkedin_url": None, "confidence": 0, "method": "none", "notes": ""}
# Primary strategy: Crawl source URL
if source_url:
result = await self.find_linkedin_from_source(name, source_url, role)
if result["linkedin_url"]:
if self.use_cache:
self.profile_cache[cache_key] = result
return result
# Fallback strategy: Web search (if enabled and no result from source crawl)
if self.use_llm_search and not result.get("linkedin_url"):
search_result = await self.find_linkedin_via_search(name, company, role)
if search_result["linkedin_url"]:
if self.use_cache:
self.profile_cache[cache_key] = search_result
return search_result
# If no source URL or no match found
if not result["linkedin_url"]:
result = {
"linkedin_url": None,
"confidence": 0,
"method": "none",
"notes": "No source URL available"
if not source_url
else result.get("notes", "Not found"),
}
if self.use_cache:
self.profile_cache[cache_key] = result
return result
async def batch_find_profiles(
self, members: List[Dict], progress_callback=None, db_callback=None
) -> List[Dict]:
"""
Find LinkedIn profiles for multiple members efficiently.
Groups members by source_url to minimize crawling the same page multiple times.
Args:
members: List of dicts with 'name', 'company', 'role', 'source_url', 'id'
progress_callback: Optional callback function(current, total, result)
db_callback: Optional callback to save to database immediately when profile found
Signature: db_callback(member_id, linkedin_url) -> bool
Returns:
List of results for each member
"""
results = []
total = len(members)
# Group members by source_url for efficient crawling
url_groups: Dict[str, List[Dict]] = {}
no_url_members = []
for member in members:
url = member.get("source_url")
if url:
if url not in url_groups:
url_groups[url] = []
url_groups[url].append(member)
else:
no_url_members.append(member)
logger.info(
f"Processing {len(url_groups)} unique source URLs for {total} members"
)
logger.info(f"Members with source URLs: {total - len(no_url_members)}")
logger.info(f"Members without source URLs: {len(no_url_members)}")
if self.use_llm_search:
logger.info("Web search fallback: ENABLED")
else:
logger.info("Web search fallback: DISABLED")
processed = 0
# Process members grouped by URL (efficient - one crawl per page)
for url, group_members in url_groups.items():
# Crawl the page once
content = await self.crawl_page(url)
html = self.html_cache.get(url, content or "")
# Extract all LinkedIn URLs from this page
linkedin_links = []
if content:
linkedin_links = self.extract_linkedin_urls_from_content(html)
if not linkedin_links:
linkedin_links = self.extract_linkedin_urls_from_content(content)
# Match each member in this group
for member in group_members:
processed += 1
result = None
found_linkedin = False
if linkedin_links:
# Find best matching LinkedIn for this member
best_match = None
best_score = 0
for link in linkedin_links:
context_score = self._name_matches_context(
member["name"], link["context"]
)
username_score = self._name_matches_username(
member["name"], link["username"]
)
role_bonus = (
10
if member.get("role")
and member["role"].lower() in link["context"].lower()
else 0
)
total_score = max(context_score, username_score) + role_bonus
if total_score > best_score:
best_score = total_score
best_match = link
if best_match and best_score >= 30:
result = {
"linkedin_url": best_match["url"],
"confidence": min(best_score, 100),
"method": "source_crawl",
"notes": f"Found on {url}",
"member_id": member.get("id"),
"member_name": member["name"],
}
found_linkedin = True
# Save to database immediately if callback provided
if db_callback and member.get("id"):
db_callback(member["id"], best_match["url"])
# If no result from source crawl, try web search IMMEDIATELY
if not found_linkedin and self.use_llm_search:
search_result = await self.find_linkedin_via_search(
member["name"], member["company"], member.get("role")
)
if search_result["linkedin_url"]:
result = {
"linkedin_url": search_result["linkedin_url"],
"confidence": search_result["confidence"],
"method": "web_search",
"notes": search_result.get("notes", "Found via web search"),
"member_id": member.get("id"),
"member_name": member["name"],
}
found_linkedin = True
# Save to database immediately
if db_callback and member.get("id"):
db_callback(member["id"], search_result["linkedin_url"])
# If still no result, record as not found
if not found_linkedin:
result = {
"linkedin_url": None,
"confidence": 0,
"method": "source_crawl" if content else "none",
"notes": f"No match on {url}"
if linkedin_links
else (
f"No LinkedIn URLs on {url}"
if content
else f"Failed to crawl {url}"
),
"member_id": member.get("id"),
"member_name": member["name"],
}
results.append(result)
if progress_callback:
progress_callback(processed, total, result)
# Small delay between different URLs
await asyncio.sleep(self.rate_limit_delay)
# Process members without source URLs - do web search immediately for each
for member in no_url_members:
processed += 1
result = None
# Try web search immediately
if self.use_llm_search:
search_result = await self.find_linkedin_via_search(
member["name"], member["company"], member.get("role")
)
if search_result["linkedin_url"]:
result = {
"linkedin_url": search_result["linkedin_url"],
"confidence": search_result["confidence"],
"method": "web_search",
"notes": search_result.get("notes", "Found via web search"),
"member_id": member.get("id"),
"member_name": member["name"],
}
# Save to database immediately
if db_callback and member.get("id"):
db_callback(member["id"], search_result["linkedin_url"])
# If no result from search
if not result:
result = {
"linkedin_url": None,
"confidence": 0,
"method": "web_search" if self.use_llm_search else "none",
"notes": "No LinkedIn profile found"
if self.use_llm_search
else "No source URL available",
"member_id": member.get("id"),
"member_name": member["name"],
}
results.append(result)
if progress_callback:
progress_callback(processed, total, result)
# Rate limit between searches
await asyncio.sleep(self.rate_limit_delay)
return results
def format_linkedin_url(url: str) -> str:
"""Normalize LinkedIn URL format"""
if not url:
return url
# Remove trailing slashes
url = url.rstrip("/")
# Ensure https and normalize to www
url = re.sub(r"https?://[a-z]{2,3}\.linkedin\.com", "https://www.linkedin.com", url)
if url.startswith("http://"):
url = url.replace("http://", "https://")
return url
# Async wrapper for sync contexts
def run_batch_scraper(
members: List[Dict], rate_limit: float = 0.5, progress_callback=None
) -> List[Dict]:
"""
Synchronous wrapper for batch_find_profiles.
Args:
members: List of member dicts
rate_limit: Delay between URL crawls
progress_callback: Optional progress callback
Returns:
List of results
"""
scraper = LinkedInProfileScraper(rate_limit_delay=rate_limit)
return asyncio.run(scraper.batch_find_profiles(members, progress_callback))
+56 -18
View File
@@ -1,11 +1,14 @@
import io
import logging
import pandas as pd
from db.db import Base, db_dependency, engine
from dotenv import load_dotenv
from fastapi import FastAPI, File, Form, UploadFile
from fastapi import FastAPI, File, Form, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from routers import (
addition,
companies,
folk_crm,
insight_route,
@@ -13,7 +16,8 @@ from routers import (
projects,
report_route,
)
from schemas.router_schemas import InvestmentResponse, PaginatedResponse
from schemas.router_schemas import CompanyData, InvestmentResponse, PaginatedResponse
from services.company_querying import CompanyQueryProcessor
from services.llm_parser import InvestorProcessor
from services.querying import QueryProcessor
@@ -25,10 +29,21 @@ def init_database():
Base.metadata.create_all(bind=engine)
logger = logging.getLogger(__name__)
init_database()
app = FastAPI()
# Add CORS middleware to allow frontend requests
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # In production, replace with specific origins
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Request models
class QueryRequest(BaseModel):
@@ -42,6 +57,17 @@ class QueryRequest(BaseModel):
}
class CompanyQueryRequest(BaseModel):
question: str
class Config:
json_schema_extra = {
"example": {
"question": "Find me companies in the fintech sector located in San Francisco."
}
}
@app.get("/")
def health():
return {"Hello": "World"}
@@ -61,16 +87,18 @@ async def parse_csv(
- Handles AUM, fund sizes, and check sizes as integers
**For companies:**
- Expected columns: Name, Website, Investor, Final Investor Profile (company profile)
- Expected columns: Name, Website, Perplexity Gap Output (or Final Investor Profile)
- 100% manual JSON parsing - no LLM needed
- Extracts company details, executives, investors, and client categories
- Automatically links companies to investors in database
- **Only extracts:** founded_year and key_executives
- **Only updates companies already in the database** (syncs with existing records)
- Skips companies not found in the database
**Benefits:**
- Fast processing (5-10s per record)
- Low cost (minimal or no LLM usage)
- Accurate data extraction
- Automatic database persistence
- Safe: won't create duplicate companies
"""
# Read uploaded CSV with pandas
content = await file.read()
@@ -95,21 +123,30 @@ async def parse_csv(
"/query", response_model=PaginatedResponse[InvestmentResponse], tags=["Querying"]
)
async def query_investors(request: QueryRequest):
"""
Query investors using natural language.
"""Query investors/funds using natural language"""
try:
processor = QueryProcessor()
result = await processor.process_query(request.question)
logger.info(f"Query completed successfully with {result.total} results")
return result
except Exception as e:
logger.error(f"Error in query_investors: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
Returns fund-level matches (one row per fund) with investor details.
This ensures only relevant funds are included in the response.
Supports queries like:
- "Show me seed stage investors"
- "Find fintech investors in Silicon Valley"
- "Growth stage investors with $5M+ check sizes"
- "Healthcare investors in Europe"
"""
processor = QueryProcessor()
results = processor.process_query(request.question)
return results
@app.post(
"/query-companies", response_model=PaginatedResponse[CompanyData], tags=["Querying"]
)
async def query_companies(request: CompanyQueryRequest):
"""Query companies using natural language"""
try:
processor = CompanyQueryProcessor()
result = await processor.process_query(request.question)
logger.info(f"Company query completed successfully with {result.total} results")
return result
except Exception as e:
logger.error(f"Error in query_companies: {e}", exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
app.include_router(investors.router)
@@ -118,6 +155,7 @@ app.include_router(projects.router)
app.include_router(folk_crm.router)
app.include_router(insight_route.router)
app.include_router(report_route.router)
app.include_router(addition.router)
if __name__ == "__main__":
import uvicorn
Binary file not shown.
+370
View File
@@ -0,0 +1,370 @@
from typing import Optional
from db.db import get_db
from db.models import FundTable, InvestorTable, SectorTable
from fastapi import APIRouter, Depends
from pydantic import BaseModel
from sqlalchemy.orm import Session
router = APIRouter(tags=["Additional Routes"])
# Response schemas
class SectorsResponse(BaseModel):
sectors: list[str]
total: int
class CountryInfo(BaseModel):
name: str
class ContinentInfo(BaseModel):
name: str
countries: list[str]
class GeographyResponse(BaseModel):
continents: list[ContinentInfo]
total_continents: int
total_countries: int
# Mapping of countries to continents
COUNTRY_TO_CONTINENT = {
# Africa
"Algeria": "Africa",
"Angola": "Africa",
"Benin": "Africa",
"Botswana": "Africa",
"Burkina Faso": "Africa",
"Burundi": "Africa",
"Cameroon": "Africa",
"Cape Verde": "Africa",
"Central African Republic": "Africa",
"Chad": "Africa",
"Comoros": "Africa",
"Congo": "Africa",
"Democratic Republic of the Congo": "Africa",
"Djibouti": "Africa",
"Egypt": "Africa",
"Equatorial Guinea": "Africa",
"Eritrea": "Africa",
"Eswatini": "Africa",
"Ethiopia": "Africa",
"Gabon": "Africa",
"Gambia": "Africa",
"Ghana": "Africa",
"Guinea": "Africa",
"Guinea-Bissau": "Africa",
"Ivory Coast": "Africa",
"Kenya": "Africa",
"Lesotho": "Africa",
"Liberia": "Africa",
"Libya": "Africa",
"Madagascar": "Africa",
"Malawi": "Africa",
"Mali": "Africa",
"Mauritania": "Africa",
"Mauritius": "Africa",
"Morocco": "Africa",
"Mozambique": "Africa",
"Namibia": "Africa",
"Niger": "Africa",
"Nigeria": "Africa",
"Rwanda": "Africa",
"Sao Tome and Principe": "Africa",
"Senegal": "Africa",
"Seychelles": "Africa",
"Sierra Leone": "Africa",
"Somalia": "Africa",
"South Africa": "Africa",
"South Sudan": "Africa",
"Sudan": "Africa",
"Tanzania": "Africa",
"Togo": "Africa",
"Tunisia": "Africa",
"Uganda": "Africa",
"Zambia": "Africa",
"Zimbabwe": "Africa",
# Asia
"Afghanistan": "Asia",
"Armenia": "Asia",
"Azerbaijan": "Asia",
"Bahrain": "Asia",
"Bangladesh": "Asia",
"Bhutan": "Asia",
"Brunei": "Asia",
"Cambodia": "Asia",
"China": "Asia",
"Cyprus": "Asia",
"Georgia": "Asia",
"Hong Kong": "Asia",
"India": "Asia",
"Indonesia": "Asia",
"Iran": "Asia",
"Iraq": "Asia",
"Israel": "Asia",
"Japan": "Asia",
"Jordan": "Asia",
"Kazakhstan": "Asia",
"Kuwait": "Asia",
"Kyrgyzstan": "Asia",
"Laos": "Asia",
"Lebanon": "Asia",
"Malaysia": "Asia",
"Maldives": "Asia",
"Mongolia": "Asia",
"Myanmar": "Asia",
"Nepal": "Asia",
"North Korea": "Asia",
"Oman": "Asia",
"Pakistan": "Asia",
"Palestine": "Asia",
"Philippines": "Asia",
"Qatar": "Asia",
"Saudi Arabia": "Asia",
"Singapore": "Asia",
"South Korea": "Asia",
"Sri Lanka": "Asia",
"Syria": "Asia",
"Taiwan": "Asia",
"Tajikistan": "Asia",
"Thailand": "Asia",
"Timor-Leste": "Asia",
"Turkey": "Asia",
"Turkmenistan": "Asia",
"United Arab Emirates": "Asia",
"UAE": "Asia",
"Uzbekistan": "Asia",
"Vietnam": "Asia",
"Yemen": "Asia",
# Europe
"Albania": "Europe",
"Andorra": "Europe",
"Austria": "Europe",
"Belarus": "Europe",
"Belgium": "Europe",
"Bosnia and Herzegovina": "Europe",
"Bulgaria": "Europe",
"Croatia": "Europe",
"Czech Republic": "Europe",
"Czechia": "Europe",
"Denmark": "Europe",
"Estonia": "Europe",
"Finland": "Europe",
"France": "Europe",
"Germany": "Europe",
"Greece": "Europe",
"Hungary": "Europe",
"Iceland": "Europe",
"Ireland": "Europe",
"Italy": "Europe",
"Kosovo": "Europe",
"Latvia": "Europe",
"Liechtenstein": "Europe",
"Lithuania": "Europe",
"Luxembourg": "Europe",
"Malta": "Europe",
"Moldova": "Europe",
"Monaco": "Europe",
"Montenegro": "Europe",
"Netherlands": "Europe",
"North Macedonia": "Europe",
"Norway": "Europe",
"Poland": "Europe",
"Portugal": "Europe",
"Romania": "Europe",
"Russia": "Europe",
"San Marino": "Europe",
"Serbia": "Europe",
"Slovakia": "Europe",
"Slovenia": "Europe",
"Spain": "Europe",
"Sweden": "Europe",
"Switzerland": "Europe",
"Ukraine": "Europe",
"United Kingdom": "Europe",
"UK": "Europe",
"Vatican City": "Europe",
# North America
"Antigua and Barbuda": "North America",
"Bahamas": "North America",
"Barbados": "North America",
"Belize": "North America",
"Canada": "North America",
"Costa Rica": "North America",
"Cuba": "North America",
"Dominica": "North America",
"Dominican Republic": "North America",
"El Salvador": "North America",
"Grenada": "North America",
"Guatemala": "North America",
"Haiti": "North America",
"Honduras": "North America",
"Jamaica": "North America",
"Mexico": "North America",
"Nicaragua": "North America",
"Panama": "North America",
"Saint Kitts and Nevis": "North America",
"Saint Lucia": "North America",
"Saint Vincent and the Grenadines": "North America",
"Trinidad and Tobago": "North America",
"United States": "North America",
"USA": "North America",
"US": "North America",
# South America
"Argentina": "South America",
"Bolivia": "South America",
"Brazil": "South America",
"Chile": "South America",
"Colombia": "South America",
"Ecuador": "South America",
"Guyana": "South America",
"Paraguay": "South America",
"Peru": "South America",
"Suriname": "South America",
"Uruguay": "South America",
"Venezuela": "South America",
# Oceania
"Australia": "Oceania",
"Fiji": "Oceania",
"Kiribati": "Oceania",
"Marshall Islands": "Oceania",
"Micronesia": "Oceania",
"Nauru": "Oceania",
"New Zealand": "Oceania",
"Palau": "Oceania",
"Papua New Guinea": "Oceania",
"Samoa": "Oceania",
"Solomon Islands": "Oceania",
"Tonga": "Oceania",
"Tuvalu": "Oceania",
"Vanuatu": "Oceania",
}
# Valid continent names for direct matching
VALID_CONTINENTS = {
"Africa",
"Asia",
"Europe",
"North America",
"South America",
"Oceania",
"Antarctica",
}
def extract_countries_from_geographic_focus(geographic_focus: str) -> set[str]:
"""
Extract country names from a geographic_focus string.
Handles comma-separated values, slashes, and various formats.
"""
if not geographic_focus:
return set()
countries = set()
# Split by common delimiters
parts = geographic_focus.replace("/", ",").replace(";", ",").split(",")
for part in parts:
cleaned = part.strip()
if cleaned:
# Check if it's a known country
if cleaned in COUNTRY_TO_CONTINENT:
countries.add(cleaned)
# Check for partial matches (e.g., "United States of America" -> "United States")
else:
for country in COUNTRY_TO_CONTINENT.keys():
if country.lower() in cleaned.lower() or cleaned.lower() in country.lower():
countries.add(country)
break
return countries
def organize_geography(geographic_data: list[str]) -> dict[str, set[str]]:
"""
Organize geographic data into continents and their countries.
Returns a dict with continent names as keys and sets of countries as values.
"""
continent_countries: dict[str, set[str]] = {}
for geo_focus in geographic_data:
if not geo_focus:
continue
# Extract countries from the geographic focus string
countries = extract_countries_from_geographic_focus(geo_focus)
for country in countries:
continent = COUNTRY_TO_CONTINENT.get(country)
if continent:
if continent not in continent_countries:
continent_countries[continent] = set()
continent_countries[continent].add(country)
# Also check if the geographic focus itself is a continent
cleaned_geo = geo_focus.strip()
if cleaned_geo in VALID_CONTINENTS:
if cleaned_geo not in continent_countries:
continent_countries[cleaned_geo] = set()
return continent_countries
@router.get("/sectors", response_model=SectorsResponse)
def get_unique_sectors(db: Session = Depends(get_db)):
"""
Get all unique sectors from the database.
Returns a list of sector names sorted alphabetically.
"""
sectors = db.query(SectorTable.name).distinct().order_by(SectorTable.name).all()
sector_names = [s[0] for s in sectors if s[0]]
return SectorsResponse(sectors=sector_names, total=len(sector_names))
@router.get("/geography", response_model=GeographyResponse)
def get_arranged_geography(db: Session = Depends(get_db)):
"""
Get all unique geographic locations arranged by continent and countries.
Extracts geography from both investors and funds tables.
Returns continents with their associated countries.
"""
# Collect all geographic focus data from investors
investor_geo = (
db.query(InvestorTable.geographic_focus)
.filter(InvestorTable.geographic_focus.isnot(None))
.distinct()
.all()
)
# Collect all geographic focus data from funds
fund_geo = (
db.query(FundTable.geographic_focus)
.filter(FundTable.geographic_focus.isnot(None))
.distinct()
.all()
)
# Combine all geographic data
all_geo_data = [g[0] for g in investor_geo] + [g[0] for g in fund_geo]
# Organize into continents and countries
continent_countries = organize_geography(all_geo_data)
# Build response
continents = []
total_countries = 0
for continent_name in sorted(continent_countries.keys()):
countries = sorted(continent_countries[continent_name])
total_countries += len(countries)
continents.append(ContinentInfo(name=continent_name, countries=countries))
return GeographyResponse(
continents=continents,
total_continents=len(continents),
total_countries=total_countries,
)
+14 -4
View File
@@ -63,11 +63,13 @@ def read_companies(
# Transform CompanyTable objects to CompanyData format
company_data_list = []
for company in companies:
# Sort sectors alphabetically
sorted_sectors = sorted(company.sectors, key=lambda s: s.name) if company.sectors else []
company_data = CompanyData(
company=company,
investors=company.investors,
members=company.members,
sectors=company.sectors,
sectors=sorted_sectors,
)
company_data_list.append(company_data)
@@ -147,11 +149,13 @@ def filter_companies(
# Transform to CompanyData format
company_data_list = []
for company in companies:
# Sort sectors alphabetically
sorted_sectors = sorted(company.sectors, key=lambda s: s.name) if company.sectors else []
company_data = CompanyData(
company=company,
investors=company.investors,
members=company.members,
sectors=company.sectors,
sectors=sorted_sectors,
)
company_data_list.append(company_data)
@@ -184,12 +188,15 @@ def read_company(company_id: int, db: Session = Depends(get_db)):
if not company:
raise HTTPException(status_code=404, detail="Company not found")
# Sort sectors alphabetically
sorted_sectors = sorted(company.sectors, key=lambda s: s.name) if company.sectors else []
# Transform to CompanyData format
return CompanyData(
company=company,
investors=company.investors,
members=company.members,
sectors=company.sectors,
sectors=sorted_sectors,
)
@@ -250,12 +257,15 @@ def update_company(
.first()
)
# Sort sectors alphabetically
sorted_sectors = sorted(company_with_relations.sectors, key=lambda s: s.name) if company_with_relations.sectors else []
# Transform to CompanyData format
return CompanyData(
company=company_with_relations,
investors=company_with_relations.investors,
members=company_with_relations.members,
sectors=company_with_relations.sectors,
sectors=sorted_sectors,
)
+15 -1
View File
@@ -1,15 +1,21 @@
import os
from typing import List
from db.db import get_db
from db.models import InvestorTable
from fastapi import APIRouter, Depends, HTTPException
from pydantic import BaseModel
from services.crm import folk
from services.crm import FolkAPI
from sqlalchemy.orm import Session, selectinload
router = APIRouter(prefix="/folk", tags=["Folk CRM"])
def get_folk_client():
"""Get Folk API client with loaded environment variables"""
return FolkAPI(api_key=os.environ.get("FOLK_API_KEY", ""))
class GroupResponse(BaseModel):
id: str
name: str
@@ -44,6 +50,7 @@ def get_folk_groups():
to sync investors to Folk.
"""
try:
folk = get_folk_client()
groups_data = folk.get_groups()
items = groups_data.get("data", {}).get("items", [])
@@ -71,6 +78,7 @@ def sync_investors_to_folk(
Returns:
Summary of sync operation including successes and errors
"""
folk = get_folk_client()
# Fetch investors with their team members
investors = (
db.query(InvestorTable)
@@ -128,6 +136,11 @@ def sync_investors_to_folk(
if hasattr(member, "source_url") and member.source_url:
urls_list = [member.source_url]
# Get LinkedIn URL if available
linkedin_url = None
if hasattr(member, "linkedin") and member.linkedin:
linkedin_url = member.linkedin
# Build job title from title or role
job_title = None
if hasattr(member, "title") and member.title:
@@ -141,6 +154,7 @@ def sync_investors_to_folk(
email=member.email,
company_id=company_id,
group_id=request.group_id,
linkedin_url=linkedin_url,
urls=urls_list,
jobTitle=job_title,
)
+87 -31
View File
@@ -12,7 +12,10 @@ from schemas.router_schemas import (
PaginatedResponse,
SectorMinimal,
)
from services.compatibility_score import calculate_project_investor_compatibility
from services.compatibility_score import (
_calculate_project_fund_compatibility,
_calculate_project_investor_direct_compatibility,
)
from sqlalchemy.orm import Session, selectinload
router = APIRouter(tags=["Investor Routes"])
@@ -77,31 +80,46 @@ def read_investors(
if not project:
raise HTTPException(status_code=404, detail="Project not found")
# Get paginated results
investors = (
db.query(InvestorTable)
.options(
selectinload(InvestorTable.portfolio_companies),
selectinload(InvestorTable.team_members),
selectinload(InvestorTable.sectors),
selectinload(InvestorTable.funds).selectinload(FundTable.investment_stages),
selectinload(InvestorTable.funds).selectinload(FundTable.sectors),
# When project_id is provided, we need to get all investors first to sort by compatibility score
# Otherwise, we can paginate at the database level
if project is not None:
# Get all investors (we'll sort by compatibility score, then paginate)
all_investors = (
db.query(InvestorTable)
.options(
selectinload(InvestorTable.portfolio_companies),
selectinload(InvestorTable.team_members),
selectinload(InvestorTable.sectors),
selectinload(InvestorTable.funds).selectinload(
FundTable.investment_stages
),
selectinload(InvestorTable.funds).selectinload(FundTable.sectors),
)
.all()
)
# We'll paginate after sorting by compatibility score
investors = all_investors
else:
# Get paginated results (no sorting needed)
investors = (
db.query(InvestorTable)
.options(
selectinload(InvestorTable.portfolio_companies),
selectinload(InvestorTable.team_members),
selectinload(InvestorTable.sectors),
selectinload(InvestorTable.funds).selectinload(
FundTable.investment_stages
),
selectinload(InvestorTable.funds).selectinload(FundTable.sectors),
)
.offset(offset)
.limit(page_size)
.all()
)
.offset(offset)
.limit(page_size)
.all()
)
# Transform to InvestmentResponse format (one row per investor-fund combination)
investment_responses = []
for investor in investors:
# Calculate compatibility score if project provided
compatibility_score = 1.0
if project is not None:
compatibility_score = calculate_project_investor_compatibility(
project=project, investor=investor, use_funds=True
)
# Get top 3 portfolio companies (id and name only)
portfolio_companies = [
CompanyMinimal(id=company.id, name=company.name)
@@ -111,6 +129,13 @@ def read_investors(
# If investor has funds, create one entry per fund
if investor.funds:
for fund in investor.funds:
# Calculate compatibility score for this specific fund
compatibility_score = 1.0
if project is not None:
compatibility_score = _calculate_project_fund_compatibility(
project=project, fund=fund
)
# Get stage focus as comma-separated string
stage_focus = (
", ".join([stage.name for stage in fund.investment_stages])
@@ -118,10 +143,12 @@ def read_investors(
else None
)
# Get top 3 sectors from fund (id and name only)
# Get top 3 sectors from fund (id and name only) - sorted alphabetically
fund_sectors = [
SectorMinimal(id=sector.id, name=sector.name)
for sector in (fund.sectors[:3] if fund.sectors else [])
for sector in sorted(
fund.sectors[:3] if fund.sectors else [], key=lambda s: s.name
)
]
investment_response = InvestmentResponse(
@@ -141,6 +168,13 @@ def read_investors(
investment_responses.append(investment_response)
else:
# If no funds, create one entry with null fund fields
# Calculate compatibility using investor-level data
compatibility_score = 1.0
if project is not None:
compatibility_score = _calculate_project_investor_direct_compatibility(
project=project, investor=investor
)
investment_response = InvestmentResponse(
id=investor.id,
name=investor.name,
@@ -155,6 +189,12 @@ def read_investors(
)
investment_responses.append(investment_response)
# Sort by compatibility score (descending) when project_id is provided
if project is not None:
investment_responses.sort(key=lambda x: x.compatibility_score, reverse=True)
# Apply pagination after sorting
investment_responses = investment_responses[offset : offset + page_size]
# Calculate total pages
total_pages = (total_count + page_size - 1) // page_size
@@ -246,20 +286,27 @@ def filter_investors(
# Get total count before pagination
total_count = query.count()
# Calculate offset and apply pagination
offset = (page - 1) * page_size
funds = query.offset(offset).limit(page_size).all()
# When project_id is provided, we need to get all funds first to sort by compatibility score
# Otherwise, we can paginate at the database level
if project is not None:
# Get all funds (we'll sort by compatibility score, then paginate)
all_funds = query.all()
funds = all_funds
else:
# Calculate offset and apply pagination (no sorting needed)
offset = (page - 1) * page_size
funds = query.offset(offset).limit(page_size).all()
# Transform to InvestmentResponse format (one row per fund)
investment_responses = []
for fund in funds:
investor = fund.investor
# Calculate compatibility score if project provided
# Calculate compatibility score for this specific fund
compatibility_score = 1.0
if project is not None:
compatibility_score = calculate_project_investor_compatibility(
project=project, investor=investor, use_funds=True
compatibility_score = _calculate_project_fund_compatibility(
project=project, fund=fund
)
# Get top 3 portfolio companies (id and name only)
@@ -275,10 +322,12 @@ def filter_investors(
else None
)
# Get top 3 sectors from fund (id and name only)
# Get top 3 sectors from fund (id and name only) - sorted alphabetically
fund_sectors = [
SectorMinimal(id=sector.id, name=sector.name)
for sector in (fund.sectors[:3] if fund.sectors else [])
for sector in sorted(
fund.sectors[:3] if fund.sectors else [], key=lambda s: s.name
)
]
investment_response = InvestmentResponse(
@@ -297,6 +346,13 @@ def filter_investors(
)
investment_responses.append(investment_response)
# Sort by compatibility score (descending) when project_id is provided
if project is not None:
investment_responses.sort(key=lambda x: x.compatibility_score, reverse=True)
# Apply pagination after sorting
offset = (page - 1) * page_size
investment_responses = investment_responses[offset : offset + page_size]
# Calculate total pages
total_pages = (total_count + page_size - 1) // page_size
+100 -5
View File
@@ -24,19 +24,29 @@ router = APIRouter(tags=["Project Routes"])
def read_projects(
page: int = Query(1, ge=1, description="Page number (starts at 1)"),
page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
include_archived: bool = Query(False, description="Include archived projects"),
db: Session = Depends(get_db),
):
"""Get all projects with their related data (paginated)"""
"""Get all projects with their related data (paginated)
By default, archived projects are excluded. Set include_archived=True to include them.
"""
# Calculate offset
offset = (page - 1) * page_size
# Start with base query
query = db.query(ProjectTable)
# Filter out archived projects by default
if not include_archived:
query = query.filter(ProjectTable.is_archived == 0)
# Get total count
total_count = db.query(ProjectTable).count()
total_count = query.count()
# Get paginated results
projects = (
db.query(ProjectTable)
.options(
query.options(
selectinload(ProjectTable.sector),
selectinload(ProjectTable.investors),
selectinload(ProjectTable.companies),
@@ -162,7 +172,7 @@ def update_project(
@router.delete("/projects/{project_id}")
def delete_project(project_id: int, db: Session = Depends(get_db)):
"""Delete a project"""
"""Delete a project permanently"""
db_project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
if not db_project:
@@ -174,6 +184,87 @@ def delete_project(project_id: int, db: Session = Depends(get_db)):
return {"message": "Project deleted successfully"}
@router.post("/projects/{project_id}/archive")
def archive_project(project_id: int, db: Session = Depends(get_db)):
"""Archive a project (soft delete)"""
db_project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
if not db_project:
raise HTTPException(status_code=404, detail="Project not found")
db_project.is_archived = 1
db.commit()
db.refresh(db_project)
return {"message": "Project archived successfully", "project_id": project_id}
@router.post("/projects/{project_id}/unarchive")
def unarchive_project(project_id: int, db: Session = Depends(get_db)):
"""Unarchive a project (restore from archive)"""
db_project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
if not db_project:
raise HTTPException(status_code=404, detail="Project not found")
db_project.is_archived = 0
db.commit()
db.refresh(db_project)
return {"message": "Project unarchived successfully", "project_id": project_id}
@router.get("/projects/archived", response_model=PaginatedResponse[ProjectData])
def read_archived_projects(
page: int = Query(1, ge=1, description="Page number (starts at 1)"),
page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
db: Session = Depends(get_db),
):
"""Get all archived projects (paginated)"""
# Calculate offset
offset = (page - 1) * page_size
# Query only archived projects
query = db.query(ProjectTable).filter(ProjectTable.is_archived == 1)
# Get total count
total_count = query.count()
# Get paginated results
projects = (
query.options(
selectinload(ProjectTable.sector),
selectinload(ProjectTable.investors),
selectinload(ProjectTable.companies),
)
.offset(offset)
.limit(page_size)
.all()
)
# Transform ProjectTable objects to ProjectData format
project_data_list = []
for project in projects:
project_data = ProjectData(
project=project,
sector=project.sector,
investors=project.investors,
companies=project.companies,
)
project_data_list.append(project_data)
# Calculate total pages
total_pages = (total_count + page_size - 1) // page_size
return PaginatedResponse(
items=project_data_list,
total=total_count,
page=page,
page_size=page_size,
total_pages=total_pages,
)
@router.get("/projects/filter", response_model=PaginatedResponse[ProjectData])
def filter_projects(
stage: Optional[InvestmentStage] = Query(
@@ -182,6 +273,7 @@ def filter_projects(
min_valuation: Optional[int] = Query(None, description="Minimum valuation"),
max_valuation: Optional[int] = Query(None, description="Maximum valuation"),
location: Optional[str] = Query(None, description="Location (partial match)"),
industry: Optional[str] = Query(None, description="Industry (partial match)"),
sector: Optional[str] = Query(None, description="Sector name (partial match)"),
investor_name: Optional[str] = Query(
None, description="Investor name (partial match)"
@@ -215,6 +307,9 @@ def filter_projects(
if location:
query = query.filter(ProjectTable.location.ilike(f"%{location}%"))
if industry:
query = query.filter(ProjectTable.industry.ilike(f"%{industry}%"))
if sector:
query = query.join(ProjectTable.sector).filter(
SectorTable.name.ilike(f"%{sector}%")
+13 -16
View File
@@ -52,7 +52,6 @@ async def generate_investor_report(
"website": investor.website,
"headquarters": investor.headquarters,
"aum": investor.aum,
"geographic_focus": investor.geographic_focus,
"portfolio_highlights": investor.portfolio_highlights or [],
"investment_thesis": investor.investment_thesis or [],
"sectors": [sector.name for sector in investor.sectors],
@@ -65,24 +64,22 @@ async def generate_investor_report(
}
for member in investor.team_members
],
"check_size_lower": None,
"check_size_upper": None,
"investment_stages": [],
"funds": [],
}
# Get check sizes and stages from funds
# Get all funds with their data
if investor.funds:
# Use the first fund's data or aggregate
fund = investor.funds[0]
investor_data["check_size_lower"] = fund.check_size_lower
investor_data["check_size_upper"] = fund.check_size_upper
# Aggregate all investment stages from all funds
stages = set()
for fund in investor.funds:
for stage in fund.investment_stages:
stages.add(stage.name)
investor_data["investment_stages"] = list(stages)
fund_data = {
"fund_name": fund.fund_name,
"fund_size": fund.fund_size,
"check_size_lower": fund.check_size_lower,
"check_size_upper": fund.check_size_upper,
"geographic_focus": fund.geographic_focus,
"investment_stages": [stage.name for stage in fund.investment_stages],
"sectors": [sector.name for sector in fund.sectors],
}
investor_data["funds"].append(fund_data)
# Fetch project data if project_id is provided
project_data = None
@@ -109,7 +106,7 @@ async def generate_investor_report(
# Generate PDF report
report_generator = ReportGenerator()
pdf_bytes = await report_generator.generate_investor_report(
investor_data, project_data
investor_data, project_data, investor_model=investor, project_model=project
)
# Return PDF as downloadable file
+3
View File
@@ -60,6 +60,7 @@ class ProjectSchema(BaseModel):
valuation: int | None
stage: InvestmentStage | None
location: str | None
industry: str | None
description: Optional[str]
start_date: Optional[datetime]
end_date: Optional[datetime]
@@ -75,6 +76,7 @@ class ProjectCreate(BaseModel):
valuation: Optional[int] = None
stage: Optional[InvestmentStage] = None
location: Optional[str] = None
industry: Optional[str] = None
description: Optional[str] = None
start_date: Optional[datetime] = None
end_date: Optional[datetime] = None
@@ -85,6 +87,7 @@ class ProjectUpdate(BaseModel):
valuation: Optional[int] = None
stage: Optional[InvestmentStage] = None
location: Optional[str] = None
industry: Optional[str] = None
description: Optional[str] = None
start_date: Optional[datetime] = None
end_date: Optional[datetime] = None
+6
View File
@@ -38,6 +38,7 @@ class InvestorMemberSchema(BaseModel):
name: str
role: str | None
email: str | None
linkedin: str | None
class Config:
from_attributes = True
@@ -168,6 +169,7 @@ class InvestorFundData(BaseModel):
class Config:
from_attributes = True
class InvestorMinimal(BaseModel):
"""Minimal investor info with just id and name"""
@@ -177,6 +179,7 @@ class InvestorMinimal(BaseModel):
class Config:
from_attributes = True
class CompanySchemaMinimal(BaseModel):
id: int
name: str
@@ -188,9 +191,12 @@ class CompanySchemaMinimal(BaseModel):
class Config:
from_attributes = True
class CompanyData(BaseModel): # Renamed from CompaniesData for consistency
company: CompanySchemaMinimal
investors: List[InvestorMinimal]
members: List[CompanyMemberSchema] = []
sectors: List[SectorSchema] = []
class Config:
from_attributes = True
Binary file not shown.
Binary file not shown.
+228
View File
@@ -0,0 +1,228 @@
import asyncio
import hashlib
import logging
import os
from typing import List
from db.db import get_db
from db.models import CompanyTable
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from schemas.router_schemas import CompanyData, PaginatedResponse
from sqlalchemy import text
from sqlalchemy.orm import selectinload
logger = logging.getLogger(__name__)
class CompanyQueryProcessor:
def __init__(self):
self.llm = ChatOpenAI(
api_key=os.getenv("OPENROUTER_API_KEY"),
base_url="https://openrouter.ai/api/v1",
model="openai/gpt-4o-mini",
temperature=0,
)
# Query cache for performance
self.query_cache = {}
# SQL generation prompt
self.sql_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"""You are a SQL expert. Generate a SQLite query to find company IDs based on user requirements.
Database Schema:
- companies: id, name, industry, location, description, founded_year, website
- company_sector: company_id, sector_id
- sectors: id, name
- investor_companies: investor_id, company_id
- investors: id, name, aum
- team_members: id, company_id, name, title
IMPORTANT RULES:
1. ALWAYS return ONLY company IDs (companies.id) - use SELECT DISTINCT c.id
2. For industry: Check BOTH industry field AND sectors table with synonyms
- Use LEFT JOIN for sectors so companies without sector tags still match
- Include related terms: 'Fintech' → c.industry LIKE '%Fintech%' OR c.industry LIKE '%Finance%' OR sec.name LIKE '%Fintech%' OR sec.name LIKE '%Financial%'
- 'AI' → c.industry LIKE '%AI%' OR c.industry LIKE '%Artificial Intelligence%' OR c.industry LIKE '%Machine Learning%' OR sec.name LIKE '%AI%' OR sec.name LIKE '%ML%'
3. For location: Be FLEXIBLE with variations and abbreviations
- 'San Francisco' → c.location LIKE '%San Francisco%' OR c.location LIKE '%SF%' OR c.location LIKE '%Bay Area%'
- 'New York' → c.location LIKE '%New York%' OR c.location LIKE '%NYC%' OR c.location LIKE '%NY%'
- 'Europe' → c.location LIKE '%Europe%' OR c.location LIKE '%UK%' OR c.location LIKE '%London%' OR c.location LIKE '%Berlin%' OR c.location LIKE '%Paris%'
4. For sectors: Use LEFT JOIN and include multiple synonyms
- 'Healthcare' → sec.name LIKE '%Healthcare%' OR sec.name LIKE '%Health%' OR sec.name LIKE '%Medical%' OR sec.name LIKE '%Biotech%' OR c.industry LIKE '%Health%'
5. For founding year filters (include NULL to be inclusive):
- "founded after 2020" → WHERE (founded_year >= 2020 OR founded_year IS NULL)
- "founded before 2018" → WHERE (founded_year <= 2018 OR founded_year IS NULL)
- "founded in 2020" → WHERE founded_year = 2020
6. For investor-related queries: Use JOIN investor_companies
7. Use LEFT JOIN for sectors so companies without tags still match
8. Use DISTINCT to avoid duplicates from joins
9. Be INCLUSIVE - use OR conditions with synonyms and variations
10. Return a single, complete SELECT query
Example Queries:
Q: "Fintech companies founded in 2020"
A: SELECT DISTINCT c.id FROM companies c
LEFT JOIN company_sector cs ON c.id = cs.company_id
LEFT JOIN sectors sec ON cs.sector_id = sec.id
WHERE (c.industry LIKE '%Fintech%' OR c.industry LIKE '%Finance%' OR c.industry LIKE '%Financial%' OR sec.name LIKE '%Fintech%' OR sec.name LIKE '%Financial Services%')
AND c.founded_year = 2020
Q: "AI companies in San Francisco"
A: SELECT DISTINCT c.id FROM companies c
LEFT JOIN company_sector cs ON c.id = cs.company_id
LEFT JOIN sectors sec ON cs.sector_id = sec.id
WHERE (c.industry LIKE '%AI%' OR c.industry LIKE '%Artificial Intelligence%' OR c.industry LIKE '%Machine Learning%' OR sec.name LIKE '%AI%' OR sec.name LIKE '%Machine Learning%' OR sec.name LIKE '%ML%')
AND (c.location LIKE '%San Francisco%' OR c.location LIKE '%SF%' OR c.location LIKE '%Bay Area%')
Q: "Healthcare companies"
A: SELECT DISTINCT c.id FROM companies c
LEFT JOIN company_sector cs ON c.id = cs.company_id
LEFT JOIN sectors sec ON cs.sector_id = sec.id
WHERE c.industry LIKE '%Healthcare%' OR c.industry LIKE '%Health%' OR c.industry LIKE '%Medical%' OR sec.name LIKE '%Healthcare%' OR sec.name LIKE '%Medical%' OR sec.name LIKE '%Biotech%' OR sec.name LIKE '%Pharma%'
Q: "Companies funded by Sequoia"
A: SELECT DISTINCT c.id FROM companies c
JOIN investor_companies ic ON c.id = ic.company_id
JOIN investors i ON ic.investor_id = i.id
WHERE i.name LIKE '%Sequoia%'
Q: "European startups founded after 2019"
A: SELECT DISTINCT c.id FROM companies c
WHERE (c.location LIKE '%Europe%' OR c.location LIKE '%UK%' OR c.location LIKE '%London%' OR c.location LIKE '%Germany%' OR c.location LIKE '%Berlin%' OR c.location LIKE '%France%' OR c.location LIKE '%Paris%')
AND (c.founded_year > 2019 OR c.founded_year IS NULL)
Q: "SaaS companies"
A: SELECT DISTINCT c.id FROM companies c
LEFT JOIN company_sector cs ON c.id = cs.company_id
LEFT JOIN sectors sec ON cs.sector_id = sec.id
WHERE c.industry LIKE '%SaaS%' OR c.industry LIKE '%Software%' OR c.industry LIKE '%Cloud%' OR sec.name LIKE '%SaaS%' OR sec.name LIKE '%Software%'
IMPORTANT:
- Use LEFT JOIN so companies without sector tags still match via industry field
- Use OR conditions with related keywords/synonyms to cast a wider net
- Include NULL checks for optional filters to avoid excluding companies with missing data
Return ONLY the SQL query, no explanations or markdown.""",
),
("user", "{question}"),
]
)
def _get_cache_key(self, question: str) -> str:
"""Generate cache key from normalized question."""
return hashlib.md5(question.lower().strip().encode()).hexdigest()
# synchronous helper is provided below as `_process_query_sync` and an
# async wrapper `process_query` runs it in a thread. This keeps the
# FastAPI event loop non-blocking while reusing the existing sync code.
async def process_query(self, question: str) -> PaginatedResponse[CompanyData]:
"""Async wrapper for process_query. Runs blocking work in a thread to avoid
blocking the event loop.
"""
return await asyncio.to_thread(self._process_query_sync, question)
def _process_query_sync(self, question: str) -> PaginatedResponse[CompanyData]:
"""Synchronous implementation of process_query. This is run in a thread by
the async wrapper above.
"""
cache_key = self._get_cache_key(question)
# Check cache first
if cache_key in self.query_cache:
sql_query = self.query_cache[cache_key]
logger.info(f"Using cached SQL: {sql_query}")
else:
# Generate SQL query
messages = self.sql_prompt.format_messages(question=question)
response = self.llm.invoke(messages)
sql_query = response.content.strip()
# Clean up SQL (remove markdown code blocks if present)
sql_query = sql_query.replace("```sql", "").replace("```", "").strip()
# Cache the query
self.query_cache[cache_key] = sql_query
logger.info(f"Generated SQL: {sql_query}")
# Execute query to get company IDs
db_session = next(get_db())
try:
result = db_session.execute(text(sql_query))
company_ids = [row[0] for row in result.fetchall()]
logger.info(
f"Found {len(company_ids)} company IDs: {company_ids[:10]}{'...' if len(company_ids) > 10 else ''}"
)
return self._fetch_companies_by_ids(company_ids)
except Exception as e:
logger.error(f"SQL execution error: {e}")
logger.error(f"Failed SQL: {sql_query}")
# Return empty result
return PaginatedResponse(
items=[], total=0, page=1, page_size=10, total_pages=0
)
finally:
db_session.close()
def _fetch_companies_by_ids(
self, company_ids: List[int]
) -> PaginatedResponse[CompanyData]:
"""Fetch companies with all their relationships from the database using company IDs.
Args:
company_ids: List of company IDs to fetch
"""
if not company_ids:
return PaginatedResponse(
items=[],
total=0,
page=1,
page_size=10,
total_pages=0,
)
# Get database session
db_session = next(get_db())
try:
# Query companies with all necessary relationships loaded
companies = (
db_session.query(CompanyTable)
.options(
selectinload(CompanyTable.investors),
selectinload(CompanyTable.members),
selectinload(CompanyTable.sectors),
)
.filter(CompanyTable.id.in_(company_ids))
.all()
)
# Transform to CompanyData format
company_data_list = []
for company in companies:
company_data = CompanyData(
company=company,
investors=company.investors,
members=company.members,
sectors=company.sectors,
)
company_data_list.append(company_data)
total_count = len(company_data_list)
total_pages = 1 if total_count > 0 else 0
return PaginatedResponse(
items=company_data_list,
total=total_count,
page=1,
page_size=total_count,
total_pages=total_pages,
)
finally:
db_session.close()
+173 -34
View File
@@ -6,6 +6,7 @@ The scoring system evaluates multiple dimensions to determine how well a project
matches with an investor's investment criteria.
"""
from difflib import SequenceMatcher
from typing import List, Optional, Tuple
from db.models import FundTable, InvestorTable, ProjectTable
@@ -99,12 +100,16 @@ def _calculate_project_fund_compatibility(
else str(project.stage)
)
if project_stage_name in fund_stage_names:
# Normalize both for case-insensitive comparison
project_stage_normalized = project_stage_name.upper().strip()
fund_stages_normalized = {name.upper().strip() for name in fund_stage_names}
if project_stage_normalized in fund_stages_normalized:
stage_score = 30
else:
# Partial credit for adjacent stages
stage_score = _calculate_stage_proximity(
project_stage_name, fund_stage_names
project_stage_normalized, fund_stages_normalized
)
total_score += stage_score
@@ -112,22 +117,53 @@ def _calculate_project_fund_compatibility(
# 2. Sector Overlap (30 points)
sector_score = 0
if project.sector and fund.sectors:
project_sector_ids = {sector.id for sector in project.sector}
fund_sector_ids = {sector.id for sector in fund.sectors}
project_sectors = [s for s in project.sector if hasattr(s, "name")]
fund_sectors = [s for s in fund.sectors if hasattr(s, "name")]
if project_sector_ids and fund_sector_ids:
common_sectors = project_sector_ids.intersection(fund_sector_ids)
# Score based on what percentage of project sectors are covered by fund
overlap_ratio = len(common_sectors) / len(project_sector_ids)
sector_score = int(30 * overlap_ratio)
if project_sectors and fund_sectors:
# Use fuzzy matching to account for similar but not identical sector names
match_count = 0
total_matches = 0
for proj_sector in project_sectors:
best_match_score = 0
proj_name = proj_sector.name.lower().strip()
for fund_sector in fund_sectors:
fund_name = fund_sector.name.lower().strip()
# Exact match
if proj_name == fund_name:
best_match_score = 1.0
break
# Fuzzy match using sequence matcher
similarity = SequenceMatcher(None, proj_name, fund_name).ratio()
# Also check if one contains the other (substring match)
if proj_name in fund_name or fund_name in proj_name:
similarity = max(similarity, 0.8)
best_match_score = max(best_match_score, similarity)
# Count matches with threshold
# Perfect match (1.0), strong match (>0.75), partial match (>0.6)
if best_match_score >= 0.6:
total_matches += best_match_score
match_count += 1
if match_count > 0:
# Calculate overlap ratio based on fuzzy matches
overlap_ratio = total_matches / len(project_sectors)
sector_score = int(30 * overlap_ratio)
total_score += sector_score
# 3. Geographic Match (20 points)
geo_score = 0
if project.location and fund.geographic_focus:
project_location_lower = project.location.lower()
fund_geo_lower = (fund.geographic_focus or "").lower()
project_location_lower = project.location.lower().strip()
fund_geo_lower = (fund.geographic_focus or "").lower().strip()
# Exact match
if project_location_lower == fund_geo_lower:
@@ -137,10 +173,11 @@ def _calculate_project_fund_compatibility(
project_location_lower in fund_geo_lower
or fund_geo_lower in project_location_lower
):
geo_score = 10
# Check for common geographic terms
geo_score = 15
# Check for common geographic terms or regional overlap (continent/country matching)
elif _check_geographic_overlap(project_location_lower, fund_geo_lower):
geo_score = 5
# Give higher score for continent/country matches (e.g., Germany -> Europe)
geo_score = 18
total_score += geo_score
@@ -209,13 +246,44 @@ def _calculate_project_investor_direct_compatibility(
# 2. Sector Overlap (30 points)
sector_score = 0
if project.sector and investor.sectors:
project_sector_ids = {sector.id for sector in project.sector}
investor_sector_ids = {sector.id for sector in investor.sectors}
project_sectors = [s for s in project.sector if hasattr(s, "name")]
investor_sectors = [s for s in investor.sectors if hasattr(s, "name")]
if project_sector_ids and investor_sector_ids:
common_sectors = project_sector_ids.intersection(investor_sector_ids)
overlap_ratio = len(common_sectors) / len(project_sector_ids)
sector_score = int(30 * overlap_ratio)
if project_sectors and investor_sectors:
# Use fuzzy matching to account for similar but not identical sector names
match_count = 0
total_matches = 0
for proj_sector in project_sectors:
best_match_score = 0
proj_name = proj_sector.name.lower().strip()
for inv_sector in investor_sectors:
inv_name = inv_sector.name.lower().strip()
# Exact match
if proj_name == inv_name:
best_match_score = 1.0
break
# Fuzzy match using sequence matcher
similarity = SequenceMatcher(None, proj_name, inv_name).ratio()
# Also check if one contains the other (substring match)
if proj_name in inv_name or inv_name in proj_name:
similarity = max(similarity, 0.8)
best_match_score = max(best_match_score, similarity)
# Count matches with threshold
if best_match_score >= 0.6:
total_matches += best_match_score
match_count += 1
if match_count > 0:
# Calculate overlap ratio based on fuzzy matches
overlap_ratio = total_matches / len(project_sectors)
sector_score = int(30 * overlap_ratio)
total_score += sector_score
@@ -231,9 +299,10 @@ def _calculate_project_investor_direct_compatibility(
project_location_lower in investor_geo_lower
or investor_geo_lower in project_location_lower
):
geo_score = 10
geo_score = 15
elif _check_geographic_overlap(project_location_lower, investor_geo_lower):
geo_score = 5
# Give higher score for continent/country matches (e.g., Germany -> Europe)
geo_score = 18
total_score += geo_score
@@ -278,8 +347,11 @@ def _calculate_stage_proximity(project_stage: str, fund_stages: set) -> int:
"""
stage_order = ["SEED", "SERIES_A", "SERIES_B", "SERIES_C", "GROWTH", "LATE_STAGE"]
# Normalize project stage for comparison
project_stage_normalized = project_stage.upper().strip()
try:
project_idx = stage_order.index(project_stage)
project_idx = stage_order.index(project_stage_normalized)
except ValueError:
return 0
@@ -290,8 +362,10 @@ def _calculate_stage_proximity(project_stage: str, fund_stages: set) -> int:
if project_idx < len(stage_order) - 1:
adjacent_stages.append(stage_order[project_idx + 1])
# Normalize fund stages and check for matches
for stage in fund_stages:
if stage in adjacent_stages:
stage_normalized = stage.upper().strip()
if stage_normalized in adjacent_stages:
return 15 # Half credit for adjacent stage
return 0
@@ -305,25 +379,90 @@ def _check_geographic_overlap(location1: str, location2: str) -> bool:
- "San Francisco, CA" and "California" -> True
- "New York" and "USA" -> True (if both contain USA/US)
- "London, UK" and "United Kingdom" -> True
- "Germany" and "Europe" -> True
"""
# Common geographic groupings
# Normalize inputs
loc1 = location1.lower().strip()
loc2 = location2.lower().strip()
# Common geographic groupings with broader regional mappings
geo_groups = [
["usa", "us", "united states", "america"],
["uk", "united kingdom", "britain"],
["california", "ca"],
["new york", "ny"],
# North America
["usa", "us", "united states", "america", "u.s.", "u.s.a"],
["canada", "canadian"],
["mexico", "mexican"],
# Europe and countries
[
"europe",
"european",
"eu",
"germany",
"france",
"uk",
"united kingdom",
"britain",
"spain",
"italy",
"netherlands",
"belgium",
"sweden",
"denmark",
"norway",
"finland",
"poland",
"portugal",
"austria",
"switzerland",
"ireland",
"greece",
"czech",
"romania",
],
# UK specific
["uk", "united kingdom", "britain", "england", "scotland", "wales", "london"],
# US states
["california", "ca", "san francisco", "los angeles", "silicon valley"],
["new york", "ny", "nyc"],
["texas", "tx"],
["europe", "eu"],
["asia", "asian"],
["africa", "african"],
["massachusetts", "ma", "boston"],
["washington", "seattle"],
# Asia
[
"asia",
"asian",
"china",
"japan",
"korea",
"singapore",
"hong kong",
"india",
"indonesia",
"thailand",
"vietnam",
"malaysia",
"philippines",
],
# Middle East
["middle east", "israel", "uae", "dubai", "saudi arabia"],
# Latin America
["latin america", "brazil", "argentina", "chile", "colombia", "mexico"],
# Africa
["africa", "african", "south africa", "nigeria", "kenya", "egypt"],
# Oceania
["australia", "australian", "new zealand"],
]
# Check if both locations match any group
for group in geo_groups:
found_in_1 = any(term in location1 for term in group)
found_in_2 = any(term in location2 for term in group)
found_in_1 = any(term in loc1 for term in group)
found_in_2 = any(term in loc2 for term in group)
if found_in_1 and found_in_2:
return True
# Check for direct substring match (one contains the other)
if loc1 in loc2 or loc2 in loc1:
return True
return False
+14 -69
View File
@@ -1,14 +1,24 @@
import logging
import os
import sys
import requests
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
handlers=[logging.StreamHandler()],
)
logger = logging.getLogger(__name__)
class FolkAPI:
BASE_URL = "https://api.folk.app/v1"
def __init__(self, api_key: str):
api_key = os.environ.get("FOLK_API_KEY", api_key)
self.headers = {"Authorization": f"Bearer {api_key}"}
logger.info(f"FolkAPI initialized with API key: {api_key[:4]}***")
def get_groups(self):
"""Fetch all groups from Folk."""
@@ -109,6 +119,7 @@ class FolkAPI:
email: str = None,
company_id: str = None,
group_id: str = None,
linkedin_url: str = None,
companies=None,
emails=None,
phones=None,
@@ -174,7 +185,9 @@ class FolkAPI:
addresses_list = _to_list(addresses)
if addresses_list:
data["addresses"] = addresses_list
urls_list = _to_list(urls)
urls_list = _to_list(urls) or []
if linkedin_url:
urls_list.append(linkedin_url)
if urls_list:
data["urls"] = urls_list
@@ -190,71 +203,3 @@ class FolkAPI:
response.raise_for_status()
return response.json()
# Prefer getting the API key from the environment. If not set, fall back to the
# existing (hard-coded) key so behavior is unchanged for now.
DEFAULT_API_KEY = "FOLKfIGXuv74ML9EAajxyiUR39ePaNrZ"
api_key = os.environ.get("FOLK_API_KEY", DEFAULT_API_KEY)
folk = FolkAPI(api_key=api_key)
def example_flow():
# Step 1: Get groups
groups = folk.get_groups()
print(groups)
# Safely dig into the returned structure. The API returns groups under
# groups['data']['items'] (not groups['data'][0]). Handle missing/empty.
items = groups.get("data", {}).get("items", [])
if not items:
print("No groups returned by Folk API.")
sys.exit(1)
# Choose the first group as an example
group_id = items[0].get("id")
if not group_id:
print("No id found for the first group item.")
sys.exit(1)
# Step 2: Choose a group_id and create a company
company = folk.create_company(
name="2050 Investment Partners",
group_id=group_id,
website="https://2050.com",
linkedin_url="https://linkedin.com/company/2050-investments",
)
# Step 3: Add a person to the same group or company
person = folk.create_person(
first_name="John",
last_name="Doe",
email="john@2050.com",
company_id=company.get("data", {}).get("id"),
group_id=group_id,
)
print("Created company:", company)
print("Created person:", person)
if __name__ == "__main__":
try:
example_flow()
except requests.HTTPError as e:
# Try to include response body for easier debugging if available
resp = getattr(e, "response", None)
if resp is not None:
try:
body = resp.text
except Exception:
body = "<unreadable response body>"
print("HTTP error while talking to Folk API:", e)
print("Response status:", resp.status_code)
print("Response body:", body)
else:
print("HTTP error while talking to Folk API:", e)
sys.exit(1)
except Exception as e: # pragma: no cover - top-level safety
print("Unexpected error:", e)
sys.exit(1)
+9 -5
View File
@@ -49,7 +49,7 @@ class QueryProcessor:
"""Tool to search the web using google, provide the relevant query to get the information"""
logger.info(f"\nWeb Search Tool Called with query: {query}")
if query:
result = self.ddg_search.text(query, max_results=10, backend="google")
result = self.ddg_search.text(query, max_results=10)
return result
return "No query provided."
@@ -87,11 +87,15 @@ class QueryProcessor:
context_parts.append(f"Location: {investor_headquarters}")
if investor_description:
context_parts.append(f"Description: {investor_description}")
if investment_thesis:
thesis_str = ", ".join(investment_thesis[:3]) # Limit to first 3
if investment_thesis and isinstance(investment_thesis, list):
thesis_str = ", ".join(
str(item) for item in investment_thesis[:3]
) # Limit to first 3
context_parts.append(f"Investment Focus: {thesis_str}")
if portfolio_highlights:
portfolio_str = ", ".join(portfolio_highlights[:5]) # Limit to first 5
if portfolio_highlights and isinstance(portfolio_highlights, list):
portfolio_str = ", ".join(
str(item) for item in portfolio_highlights[:5]
) # Limit to first 5
context_parts.append(f"Notable Portfolio Companies: {portfolio_str}")
context = "\n".join(context_parts)
+97 -25
View File
@@ -145,16 +145,74 @@ Return the lower and upper bounds in USD."""
"""
Manually parse the JSON profile from the CSV.
Returns a cleaned dictionary with the investor profile data.
Handles JSON wrapped in markdown code blocks (```json ... ```).
Handles trailing quotes and extra data after JSON.
"""
if not json_str or pd.isna(json_str):
return None
try:
# Clean the JSON string
cleaned_json = json_str.strip()
# Check if it's plain text (no JSON structure)
if not cleaned_json.startswith(("{", "```", "'")):
print(" ⚠️ No JSON structure found - skipping")
return None
# Remove markdown code block markers if present
if cleaned_json.startswith("```"):
# Remove opening marker (```json or ```Json or ```)
lines = cleaned_json.split("\n")
if lines[0].startswith("```"):
lines = lines[1:] # Remove first line
# Remove closing marker (``` or ```')
if lines and lines[-1].strip() in ("```", "```'", '```"'):
lines = lines[:-1] # Remove last line
cleaned_json = "\n".join(lines).strip()
# Remove trailing quotes that might be left over
if cleaned_json.endswith(("'", '"')):
cleaned_json = cleaned_json[:-1].strip()
# Try to find JSON boundaries if there's extra data
# Look for the first { and the last }
start_idx = cleaned_json.find("{")
if start_idx == -1:
print(" ⚠️ No opening brace found - not valid JSON")
return None
# Find the matching closing brace
# We need to count braces to find the actual end
brace_count = 0
end_idx = -1
for i in range(start_idx, len(cleaned_json)):
if cleaned_json[i] == "{":
brace_count += 1
elif cleaned_json[i] == "}":
brace_count -= 1
if brace_count == 0:
end_idx = i + 1
break
if end_idx == -1:
print(" ⚠️ No matching closing brace found")
return None
# Extract just the JSON part
cleaned_json = cleaned_json[start_idx:end_idx]
# Parse JSON string
profile = json.loads(json_str)
profile = json.loads(cleaned_json)
return profile
except json.JSONDecodeError as e:
print(f"Error parsing JSON: {e}")
print(f" ❌ JSON parsing error: {e}")
# Print first 200 chars for debugging
preview = json_str[:200] if len(json_str) > 200 else json_str
print(f" Preview: {preview}...")
return None
except Exception as e:
print(f" ❌ Unexpected error: {e}")
return None
async def process_investor_profile(
@@ -338,34 +396,45 @@ Return the lower and upper bounds in USD."""
if existing_company:
# Update only founded_year on existing company
company = existing_company
updated_fields = []
if company_data.get("founded_year"):
company.founded_year = company_data["founded_year"]
updated_fields.append(
f"founded_year: {company_data['founded_year']}"
)
# Add/update company members (key executives)
# First, remove existing members if updating
db.query(CompanyMember).filter_by(company_id=company.id).delete()
exec_count = 0
for exec_data in company_data.get("key_executives", []):
member = CompanyMember(
name=exec_data.get("name"),
role=exec_data.get("title"),
linkedin=exec_data.get(
"source_url"
), # Store source URL in linkedin field
company_id=company.id,
)
db.add(member)
exec_count += 1
if exec_count > 0:
updated_fields.append(f"{exec_count} executives")
if updated_fields:
print(f" 📝 Updated: {', '.join(updated_fields)}")
return company
else:
# Company should already be in base database, but if not found, skip
print(
f"⚠️ Company '{company_data['name']}' not found in base database - skipping"
)
# Company not found in base database, skip
print(" ⚠️ Not in database - skipping")
return None
# Add/update company members (key executives)
# First, remove existing members if updating
db.query(CompanyMember).filter_by(company_id=company.id).delete()
for exec_data in company_data.get("key_executives", []):
member = CompanyMember(
name=exec_data.get("name"),
role=exec_data.get("title"),
linkedin=exec_data.get(
"source_url"
), # Store source URL in linkedin field
company_id=company.id,
)
db.add(member)
return company
except Exception as e:
print(f"Error saving company to database: {e}")
print(f"Error saving: {e}")
db.rollback()
return None
@@ -789,8 +858,11 @@ Return the lower and upper bounds in USD."""
if pd.notna(row.get("Investor"))
else None
)
# Try both column names for flexibility
profile_json = (
row.get("Final Investor Profile", "")
row.get("Perplexity Gap Output", "")
if pd.notna(row.get("Perplexity Gap Output"))
else row.get("Final Investor Profile", "")
if pd.notna(row.get("Final Investor Profile"))
else None
)
+144 -71
View File
@@ -1,29 +1,25 @@
import json
import asyncio
import hashlib
import logging
import os
from typing import List, Optional
from db.db import DATABASE_URL, get_db
from db.db import get_db
from db.models import FundTable, InvestorTable, ProjectTable
from langchain import hub
from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent
from schemas.router_schemas import (
CompanyMinimal,
InvestmentResponse,
PaginatedResponse,
SectorMinimal,
)
from sqlalchemy import text
from sqlalchemy.orm import selectinload
from services.compatibility_score import calculate_project_investor_compatibility
logger = logging.getLogger(__name__)
# Connect to SQLite
prompt_template = hub.pull("langchain-ai/sql-agent-system-prompt")
db = SQLDatabase.from_uri(DATABASE_URL)
class QueryProcessor:
@@ -34,78 +30,155 @@ class QueryProcessor:
model="openai/gpt-4o-mini",
temperature=0,
)
self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)
# Update system message to specifically request only fund IDs
system_message_updated = (
prompt_template.format(dialect="SQLite", top_k=5)
+ "\n\n=== IMPORTANT TERMINOLOGY ==="
+ "\n- When users say 'investors' or 'find me investors', they mean FUNDS"
+ "\n- Always query the 'funds' table for investment opportunities"
+ "\n- The 'investors' table is for parent company information only"
+ "\n- Relationship: investors (1) -> (many) funds"
+ "\n\n=== YOUR TASK ==="
+ "\nReturn ONLY fund IDs (funds.id) that match the user's criteria."
+ "\nFormat: comma-separated numbers only (e.g., 1, 5, 12, 23)"
+ "\nNo explanations, no other data."
+ "\n\n=== QUERY GUIDELINES ==="
+ "\n1. For geographic searches: use funds.geographic_focus"
+ "\n2. For sector searches: JOIN with fund_sectors table"
+ "\n3. For stage searches: JOIN with fund_investment_stages table"
+ "\n4. If no results: respond with 'NO_RESULTS'"
+ "\n5. Never repeat the same failed query"
)
self.agent = create_react_agent(
model=self.llm,
tools=self.toolkit.get_tools(),
prompt=system_message_updated,
# Query cache for performance
self.query_cache = {}
# SQL generation prompt
self.sql_prompt = ChatPromptTemplate.from_messages(
[
(
"system",
"""You are a SQL expert. Generate a SQLite query to find fund IDs based on user requirements.
Database Schema:
- funds: id, fund_name, investor_id, check_size_lower, check_size_upper, geographic_focus
- fund_sectors: fund_id, sector_id
- fund_investment_stages: fund_id, stage_id
- sectors: id, name
- investment_stages: id, name
- investors: id, name, aum
IMPORTANT RULES:
1. ALWAYS return ONLY fund IDs (funds.id) - use SELECT DISTINCT f.id
2. For geography: Be FLEXIBLE - use OR with variations and partial matches
- 'Europe' → WHERE geographic_focus LIKE '%Europe%' OR geographic_focus LIKE '%European%'
- 'America' → WHERE geographic_focus LIKE '%America%' OR geographic_focus LIKE '%US%' OR geographic_focus LIKE '%United States%'
- 'Asia' → WHERE geographic_focus LIKE '%Asia%' OR geographic_focus LIKE '%Asian%'
- If no geography specified, DON'T filter by geography
3. For stages: Use LEFT JOIN and LIKE for flexible matching with synonyms
- 'Seed' → s.name LIKE '%Seed%' OR s.name LIKE '%Pre-Seed%' OR s.name LIKE '%Early%'
- 'Series A' → s.name LIKE '%Series A%' OR s.name LIKE '%A%'
- 'Growth' → s.name LIKE '%Growth%' OR s.name LIKE '%Late%' OR s.name LIKE '%Expansion%'
- If stage not specified, include ALL funds
4. For sectors: Use LEFT JOIN and include related terms with OR
- 'Fintech' → sec.name LIKE '%Fintech%' OR sec.name LIKE '%Finance%' OR sec.name LIKE '%Financial%'
- 'AI' → sec.name LIKE '%AI%' OR sec.name LIKE '%Artificial Intelligence%' OR sec.name LIKE '%Machine Learning%' OR sec.name LIKE '%ML%'
- 'Healthcare' → sec.name LIKE '%Healthcare%' OR sec.name LIKE '%Health%' OR sec.name LIKE '%Medical%' OR sec.name LIKE '%Biotech%'
5. For check size filters (be flexible with ranges):
- "under X" → WHERE (check_size_upper <= X OR check_size_upper IS NULL)
- "over X" → WHERE (check_size_lower >= X OR check_size_lower IS NULL)
- "between X and Y" → WHERE check_size_lower >= X AND check_size_upper <= Y
6. Use LEFT JOIN for stages and sectors so funds without tags still match
7. Use DISTINCT to avoid duplicates from joins
8. Be INCLUSIVE - use OR conditions to cast a wider net
9. If query is very simple (e.g., just "seed stage"), don't add unnecessary filters
10. Return a single, complete SELECT query
Example Queries:
Q: "Seed stage investors in Europe"
A: SELECT DISTINCT f.id FROM funds f
LEFT JOIN fund_investment_stages fis ON f.id = fis.fund_id
LEFT JOIN investment_stages s ON fis.stage_id = s.id
WHERE (s.name LIKE '%Seed%' OR s.name LIKE '%Pre-Seed%' OR s.name LIKE '%Early%' OR s.id IS NULL)
AND (f.geographic_focus LIKE '%Europe%' OR f.geographic_focus LIKE '%European%')
Q: "Fintech investors with check size under 5 million"
A: SELECT DISTINCT f.id FROM funds f
LEFT JOIN fund_sectors fs ON f.id = fs.fund_id
LEFT JOIN sectors sec ON fs.sector_id = sec.id
WHERE (sec.name LIKE '%Fintech%' OR sec.name LIKE '%Finance%' OR sec.name LIKE '%Financial%' OR sec.id IS NULL)
AND (f.check_size_upper <= 5000000 OR f.check_size_upper IS NULL)
Q: "Seed stage investors"
A: SELECT DISTINCT f.id FROM funds f
LEFT JOIN fund_investment_stages fis ON f.id = fis.fund_id
LEFT JOIN investment_stages s ON fis.stage_id = s.id
WHERE s.name LIKE '%Seed%' OR s.name LIKE '%Pre-Seed%' OR s.name LIKE '%Early%'
Q: "Growth stage investors"
A: SELECT DISTINCT f.id FROM funds f
LEFT JOIN fund_investment_stages fis ON f.id = fis.fund_id
LEFT JOIN investment_stages s ON fis.stage_id = s.id
WHERE s.name LIKE '%Growth%' OR s.name LIKE '%Late%' OR s.name LIKE '%Expansion%' OR s.name LIKE '%Series C%' OR s.name LIKE '%Series D%'
Q: "AI investors in America"
A: SELECT DISTINCT f.id FROM funds f
LEFT JOIN fund_sectors fs ON f.id = fs.fund_id
LEFT JOIN sectors sec ON fs.sector_id = sec.id
WHERE (sec.name LIKE '%AI%' OR sec.name LIKE '%Artificial Intelligence%' OR sec.name LIKE '%Machine Learning%' OR sec.name LIKE '%ML%')
AND (f.geographic_focus LIKE '%America%' OR f.geographic_focus LIKE '%US%' OR f.geographic_focus LIKE '%United States%' OR f.geographic_focus LIKE '%USA%')
Q: "Healthcare investors"
A: SELECT DISTINCT f.id FROM funds f
LEFT JOIN fund_sectors fs ON f.id = fs.fund_id
LEFT JOIN sectors sec ON fs.sector_id = sec.id
WHERE sec.name LIKE '%Healthcare%' OR sec.name LIKE '%Health%' OR sec.name LIKE '%Medical%' OR sec.name LIKE '%Biotech%' OR sec.name LIKE '%Pharma%'
IMPORTANT: Use LEFT JOIN so funds without sector/stage tags can still match. Include synonym terms with OR for better recall.
Return ONLY the SQL query, no explanations or markdown.""",
),
("user", "{question}"),
]
)
def process_query(
def _get_cache_key(self, question: str) -> str:
"""Generate cache key from normalized question."""
return hashlib.md5(question.lower().strip().encode()).hexdigest()
async def process_query(
self, question: str, project_id: Optional[int] = None
) -> PaginatedResponse[InvestmentResponse]:
"""Process a query using the LLM and return investment response data.
Args:
question: The natural language query to process
project_id: Optional project ID for compatibility scoring
"""Async wrapper for process_query. Runs blocking work in a thread to avoid
blocking the event loop.
"""
# Let the LLM handle all database interactions and filtering to get fund IDs
response = self.agent.invoke(
{"messages": [("user", question)]},
config={"recursion_limit": 50},
)
return await asyncio.to_thread(self._process_query_sync, question, project_id)
# Extract the actual message content
logger.info(f"{response}")
final_message_content = response["messages"][-1].content
logger.info(f"AI Response: \n{final_message_content}")
# Extract fund IDs from the AI response
fund_ids = self._extract_fund_ids_from_response(final_message_content)
def _process_query_sync(
self, question: str, project_id: Optional[int] = None
) -> PaginatedResponse[InvestmentResponse]:
"""Synchronous implementation of process_query. This is run in a thread by
the async wrapper above.
"""
cache_key = self._get_cache_key(question)
# Fetch full fund data with investor relationships using the IDs
return self._fetch_funds_by_ids(fund_ids, project_id)
# Check cache first
if cache_key in self.query_cache:
sql_query = self.query_cache[cache_key]
logger.info(f"Using cached SQL: {sql_query}")
else:
# Generate SQL query
messages = self.sql_prompt.format_messages(question=question)
response = self.llm.invoke(messages)
sql_query = response.content.strip()
def _extract_fund_ids_from_response(self, ai_response: str) -> List[int]:
"""Extract fund IDs from AI response."""
import re
# Clean up SQL (remove markdown code blocks if present)
sql_query = sql_query.replace("```sql", "").replace("```", "").strip()
fund_ids = []
# Cache the query
self.query_cache[cache_key] = sql_query
logger.info(f"Generated SQL: {sql_query}")
# Execute query to get fund IDs
db_session = next(get_db())
try:
# Try multiple patterns to extract IDs from the response
# Pattern 1: Simple numbers (assuming they are IDs)
numbers = re.findall(r"\b\d+\b", ai_response)
fund_ids = [int(num) for num in numbers]
# Pattern 2: If response contains explicit ID references
id_matches = re.findall(r"\bid[:\s]*(\d+)", ai_response.lower())
if id_matches:
fund_ids = [int(id_str) for id_str in id_matches]
result = db_session.execute(text(sql_query))
fund_ids = [row[0] for row in result.fetchall()]
logger.info(
f"Found {len(fund_ids)} fund IDs: {fund_ids[:10]}{'...' if len(fund_ids) > 10 else ''}"
)
return self._fetch_funds_by_ids(fund_ids, project_id)
except Exception as e:
print(f"Error extracting IDs from response: {e}")
return []
return fund_ids
logger.error(f"SQL execution error: {e}")
logger.error(f"Failed SQL: {sql_query}")
# Return empty result
return PaginatedResponse(
items=[], total=0, page=1, page_size=10, total_pages=0
)
finally:
db_session.close()
def _fetch_funds_by_ids(
self, fund_ids: List[int], project_id: Optional[int] = None
@@ -185,10 +258,10 @@ class QueryProcessor:
else None
)
# Get top 3 sectors from fund (id and name only)
# Get top 3 sectors from fund (id and name only) - sorted alphabetically
fund_sectors = [
SectorMinimal(id=sector.id, name=sector.name)
for sector in (fund.sectors[:3] if fund.sectors else [])
for sector in sorted(fund.sectors[:3] if fund.sectors else [], key=lambda s: s.name)
]
investment_response = InvestmentResponse(
+164 -83
View File
@@ -1,9 +1,13 @@
from pathlib import Path
from typing import Any, Dict, List, Optional
# Import database models and compatibility score service
from db.models import InvestorTable, ProjectTable
from jinja2 import Environment, FileSystemLoader
from playwright.async_api import async_playwright
from services.compatibility_score import calculate_project_investor_compatibility
class ReportGenerator:
"""Service for generating PDF reports from HTML templates"""
@@ -17,6 +21,8 @@ class ReportGenerator:
self,
investor_data: Dict[str, Any],
project_data: Optional[Dict[str, Any]] = None,
investor_model: Optional[InvestorTable] = None,
project_model: Optional[ProjectTable] = None,
) -> bytes:
"""
Generate a PDF report for an investor profile.
@@ -24,12 +30,16 @@ class ReportGenerator:
Args:
investor_data: Dictionary containing investor information
project_data: Optional dictionary containing project information for compatibility analysis
investor_model: Optional database model for investor (used for compatibility scoring)
project_model: Optional database model for project (used for compatibility scoring)
Returns:
bytes: PDF file content
"""
# Prepare template context
context = self._prepare_context(investor_data, project_data)
context = self._prepare_context(
investor_data, project_data, investor_model, project_model
)
# Render HTML from template
template = self.env.get_template("report.html")
@@ -43,6 +53,8 @@ class ReportGenerator:
self,
investor_data: Dict[str, Any],
project_data: Optional[Dict[str, Any]] = None,
investor_model: Optional[InvestorTable] = None,
project_model: Optional[ProjectTable] = None,
) -> Dict[str, Any]:
"""Prepare the context dictionary for template rendering"""
context = {
@@ -55,9 +67,20 @@ class ReportGenerator:
# If project data is provided, calculate compatibility
if project_data:
context["compatibility_score"] = self._calculate_compatibility_score(
investor_data, project_data
)
# Use the compatibility_score service if models are provided
if investor_model and project_model:
# Calculate using the standardized compatibility score service
# Returns score between 0 and 1, convert to percentage (0-100)
score_decimal = calculate_project_investor_compatibility(
project=project_model, investor=investor_model, use_funds=True
)
context["compatibility_score"] = int(score_decimal * 100)
else:
# Fallback to old calculation method if models not provided
context["compatibility_score"] = self._calculate_compatibility_score(
investor_data, project_data
)
context["match_criteria"] = self._generate_match_criteria(
investor_data, project_data
)
@@ -76,43 +99,75 @@ class ReportGenerator:
"sector": 30,
"stage": 30,
"geography": 20,
"check_size": 15,
"thesis": 5,
"check_size": 20,
}
# Aggregate data from all funds
all_sectors = set(investor_data.get("sectors", []))
all_stages = set()
all_geographies = []
check_ranges = []
for fund in investor_data.get("funds", []):
all_sectors.update(fund.get("sectors", []))
all_stages.update(fund.get("investment_stages", []))
if fund.get("geographic_focus"):
all_geographies.append(fund["geographic_focus"])
if fund.get("check_size_lower") and fund.get("check_size_upper"):
check_ranges.append(
{
"lower": fund["check_size_lower"],
"upper": fund["check_size_upper"],
}
)
# Sector match
investor_sectors = set(investor_data.get("sectors", []))
project_sectors = set(project_data.get("sectors", []))
if investor_sectors and project_sectors:
if investor_sectors & project_sectors:
if all_sectors and project_sectors:
if all_sectors & project_sectors:
score += weights["sector"]
# Stage match
investor_stages = set(investor_data.get("investment_stages", []))
# Stage match - case insensitive comparison
project_stage = project_data.get("stage")
if project_stage and project_stage in investor_stages:
score += weights["stage"]
if project_stage and all_stages:
# Normalize stage names for comparison (case-insensitive)
normalized_stages = {
stage.lower().replace("_", " ") for stage in all_stages
}
project_stage_normalized = project_stage.lower().replace("_", " ")
if project_stage_normalized in normalized_stages:
score += weights["stage"]
# Geography match
investor_geo = (investor_data.get("geographic_focus") or "").lower()
# Geography match - check if any fund matches
project_geo = (project_data.get("location") or "").lower()
if investor_geo and project_geo and investor_geo in project_geo:
geo_match = False
if all_geographies:
for geo in all_geographies:
if geo:
geo_lower = geo.lower()
# Match if investor geography is "global" or if there's a location overlap
if "global" in geo_lower or "worldwide" in geo_lower:
geo_match = True
break
if project_geo and (
geo_lower in project_geo or project_geo in geo_lower
):
geo_match = True
break
if geo_match:
score += weights["geography"]
# Check size match
# Check size match - check if any fund's range matches
project_valuation = project_data.get("valuation", 0)
check_lower = investor_data.get("check_size_lower") or 0
check_upper = investor_data.get("check_size_upper") or float("inf")
if (
check_lower
and check_upper
and check_lower <= project_valuation <= check_upper
):
check_match = False
if project_valuation and check_ranges:
for check_range in check_ranges:
if check_range["lower"] <= project_valuation <= check_range["upper"]:
check_match = True
break
if check_match:
score += weights["check_size"]
# Thesis alignment (simplified)
score += weights["thesis"]
return min(score, 100)
def _generate_match_criteria(
@@ -121,86 +176,124 @@ class ReportGenerator:
"""Generate detailed match criteria table"""
criteria = []
# Aggregate data from all funds
all_sectors = set(investor_data.get("sectors", []))
all_stages = set()
all_geographies = []
check_ranges = []
for fund in investor_data.get("funds", []):
all_sectors.update(fund.get("sectors", []))
all_stages.update(fund.get("investment_stages", []))
if fund.get("geographic_focus"):
all_geographies.append(fund["geographic_focus"])
if fund.get("check_size_lower") and fund.get("check_size_upper"):
check_ranges.append(
{
"lower": fund["check_size_lower"],
"upper": fund["check_size_upper"],
"fund_name": fund.get("fund_name", "Unnamed Fund"),
}
)
# Sector criterion
investor_sectors = investor_data.get("sectors", [])
project_sectors = project_data.get("sectors", [])
sector_match = (
"Perfect" if set(investor_sectors) & set(project_sectors) else "Mismatch"
)
sector_match = "Perfect" if all_sectors & set(project_sectors) else "Mismatch"
criteria.append(
{
"name": "Sector",
"requirement": "Cybersecurity, B2B SaaS" if project_sectors else "N/A",
"evidence": ", ".join(investor_sectors[:3])
if investor_sectors
else "N/A",
"requirement": ", ".join(project_sectors) if project_sectors else "N/A",
"evidence": ", ".join(list(all_sectors)[:3]) if all_sectors else "N/A",
"match": sector_match,
"weight": "30%",
}
)
# Stage criterion
investor_stages = investor_data.get("investment_stages", [])
# Stage criterion - case insensitive comparison
project_stage = project_data.get("stage", "N/A")
stage_match = "Perfect" if project_stage in investor_stages else "Mismatch"
stage_match = "Mismatch"
if project_stage != "N/A" and all_stages:
# Normalize stage names for comparison
normalized_stages = {
stage.lower().replace("_", " ") for stage in all_stages
}
project_stage_normalized = project_stage.lower().replace("_", " ")
stage_match = (
"Perfect"
if project_stage_normalized in normalized_stages
else "Mismatch"
)
elif project_stage == "N/A":
stage_match = "N/A"
criteria.append(
{
"name": "Stage",
"requirement": str(project_stage),
"evidence": ", ".join(investor_stages) if investor_stages else "N/A",
"evidence": ", ".join(all_stages) if all_stages else "N/A",
"match": stage_match,
"weight": "30%",
}
)
# Geography criterion
investor_geo = investor_data.get("geographic_focus") or "N/A"
project_geo = project_data.get("location") or "N/A"
investor_geo_display = ", ".join(all_geographies) if all_geographies else "N/A"
# Safe comparison handling None values and "Global" matches
geo_match = "Mismatch"
if project_geo != "N/A" and all_geographies:
for geo in all_geographies:
if geo:
geo_lower = geo.lower()
# Match if investor geography is "global" or if there's a location overlap
if "global" in geo_lower or "worldwide" in geo_lower:
geo_match = "Perfect"
break
if (
geo_lower in project_geo.lower()
or project_geo.lower() in geo_lower
):
geo_match = "Strong"
break
elif not all_geographies and project_geo == "N/A":
geo_match = "N/A"
# Safe comparison handling None values
if investor_geo == "N/A" or project_geo == "N/A":
geo_match = (
"N/A" if investor_geo == "N/A" and project_geo == "N/A" else "Mismatch"
)
else:
investor_geo_lower = investor_geo.lower()
project_geo_lower = project_geo.lower()
geo_match = (
"Strong"
if investor_geo_lower in project_geo_lower
or project_geo_lower in investor_geo_lower
else "Mismatch"
)
criteria.append(
{
"name": "Geography",
"requirement": project_geo,
"evidence": investor_geo,
"evidence": investor_geo_display,
"match": geo_match,
"weight": "20%",
}
)
# Check Size criterion
check_lower = investor_data.get("check_size_lower") or 0
check_upper = investor_data.get("check_size_upper") or 0
project_val = project_data.get("valuation", 0)
# Build evidence string from all fund ranges
check_evidence = "N/A"
if check_lower and check_upper:
check_evidence = (
f"{check_lower / 1000000:.0f}M - €{check_upper / 1000000:.0f}M"
)
elif check_lower:
check_evidence = f"{check_lower / 1000000:.0f}M+"
if check_ranges:
evidence_parts = []
for cr in check_ranges[:3]: # Show up to 3 funds
range_str = (
f"{cr['lower'] / 1000000:.0f}M - €{cr['upper'] / 1000000:.0f}M"
)
if cr["fund_name"]:
evidence_parts.append(f"{cr['fund_name']}: {range_str}")
else:
evidence_parts.append(range_str)
check_evidence = "; ".join(evidence_parts)
# Check if project valuation matches any fund
check_match = "N/A"
if project_val > 0 and check_ranges:
match_found = any(
cr["lower"] <= project_val <= cr["upper"] for cr in check_ranges
)
check_match = "Perfect" if match_found else "Mismatch"
check_match = (
"Perfect"
if check_lower and check_upper and check_lower <= project_val <= check_upper
else "Strong"
if project_val > 0
else "N/A"
)
criteria.append(
{
"name": "Check Size",
@@ -209,19 +302,7 @@ class ReportGenerator:
else "N/A",
"evidence": check_evidence,
"match": check_match,
"weight": "15%",
}
)
# Thesis criterion
thesis = investor_data.get("investment_thesis", [])
criteria.append(
{
"name": "Thesis",
"requirement": "Founder-led, ESG focus",
"evidence": ", ".join(thesis[:2]) if thesis else "Entrepreneur-led",
"match": "Strong",
"weight": "5%",
"weight": "20%",
}
)
+38 -31
View File
@@ -161,13 +161,6 @@
</p>
</div>
<div>
<p class="text-xs text-gray-600">DACH Region:</p>
<p class="font-semibold text-gray-900">
{{ investor.geographic_focus or 'N/A' }}
</p>
</div>
<div>
<p class="text-xs text-gray-600">AUM (EUR million):</p>
<p class="font-semibold text-gray-900">
@@ -179,33 +172,47 @@
</p>
</div>
<div class="mb-4">
<p class="text-xs text-gray-600 mb-1">
Investment Stage:
</p>
<p class="text-sm font-semibold text-gray-900">
{% if investor.investment_stages %} {{
investor.investment_stages | join(', ') }} {% else
%} N/A {% endif %}
</p>
</div>
<div class="mb-4">
<p class="text-xs text-gray-600 mb-1">
Est. Investment Size:
</p>
<p class="text-sm font-semibold text-gray-900">
{% if investor.check_size_lower and
investor.check_size_upper %} €{{
'{:,.0f}'.format(investor.check_size_lower /
1000000) }}M - €{{
'{:,.0f}'.format(investor.check_size_upper /
1000000) }}M {% elif investor.check_size_lower %}
€{{ '{:,.0f}'.format(investor.check_size_lower /
1000000) }}M+ {% else %} N/A {% endif %}
<div>
<p class="text-xs text-gray-600 mb-1">Number of Funds:</p>
<p class="font-semibold text-gray-900">
{{ investor.funds | length if investor.funds else 'N/A' }}
</p>
</div>
</div>
<div class="mt-4">
<h3 class="text-xs font-bold text-gray-900 uppercase mb-2">
Fund Details
</h3>
{% if investor.funds %}
{% for fund in investor.funds %}
<div class="mb-3 pb-3 border-b border-gray-200">
<p class="text-sm font-semibold text-gray-900 mb-1">
{{ fund.fund_name or 'Fund ' + loop.index|string }}
</p>
<div class="text-xs text-gray-700 space-y-1">
{% if fund.fund_size %}
<p>Fund Size: €{{ '{:,.0f}'.format(fund.fund_size / 1000000) }}M</p>
{% endif %}
{% if fund.check_size_lower and fund.check_size_upper %}
<p>Check Size: €{{ '{:,.0f}'.format(fund.check_size_lower / 1000000) }}M - €{{ '{:,.0f}'.format(fund.check_size_upper / 1000000) }}M</p>
{% endif %}
{% if fund.geographic_focus %}
<p>Geography: {{ fund.geographic_focus }}</p>
{% endif %}
{% if fund.investment_stages %}
<p>Stages: {{ fund.investment_stages | join(', ') }}</p>
{% endif %}
{% if fund.sectors %}
<p>Sectors: {{ fund.sectors[:3] | join(', ') }}</p>
{% endif %}
</div>
</div>
{% endfor %}
{% else %}
<p class="text-xs text-gray-500">No fund information available</p>
{% endif %}
</div>
</div>
</div>
BIN
View File
Binary file not shown.
+117
View File
@@ -0,0 +1,117 @@
"""
Migration: Add fields from feedback fixes
Date: 2025-01-07
Adds the following fields:
- projects.is_archived (INTEGER, default 0)
- companies.product_service (TEXT, nullable)
- companies.clients (TEXT, nullable - stored as JSON string)
- investor_members.linkedin (VARCHAR, nullable)
"""
import sys
from pathlib import Path
# Add parent directory to path to import app modules
sys.path.insert(0, str(Path(__file__).parent.parent))
from sqlalchemy import text
from app.db.db import engine
def check_column_exists(conn, table_name, column_name):
"""Check if a column exists in a table"""
result = conn.execute(text(f"PRAGMA table_info({table_name})"))
columns = [row[1] for row in result]
return column_name in columns
def upgrade():
"""Add new columns to tables"""
print("Running migration: Add feedback fixes fields")
print("=" * 60)
with engine.begin() as conn: # Use begin() for transaction management
# 1. Add is_archived to projects table
print("\n1. Adding 'is_archived' column to projects table...")
if check_column_exists(conn, "projects", "is_archived"):
print(" ✓ Column 'is_archived' already exists. Skipping.")
else:
conn.execute(
text(
"ALTER TABLE projects ADD COLUMN is_archived INTEGER DEFAULT 0 NOT NULL"
)
)
# Set default value for existing rows
conn.execute(
text("UPDATE projects SET is_archived = 0 WHERE is_archived IS NULL")
)
print(" ✓ Successfully added 'is_archived' column to projects table")
# 2. Add product_service to companies table
print("\n2. Adding 'product_service' column to companies table...")
if check_column_exists(conn, "companies", "product_service"):
print(" ✓ Column 'product_service' already exists. Skipping.")
else:
conn.execute(text("ALTER TABLE companies ADD COLUMN product_service TEXT"))
print(" ✓ Successfully added 'product_service' column to companies table")
# 3. Add clients to companies table
print("\n3. Adding 'clients' column to companies table...")
if check_column_exists(conn, "companies", "clients"):
print(" ✓ Column 'clients' already exists. Skipping.")
else:
conn.execute(text("ALTER TABLE companies ADD COLUMN clients TEXT"))
print(" ✓ Successfully added 'clients' column to companies table")
# 4. Add linkedin to investor_members table
print("\n4. Adding 'linkedin' column to investor_members table...")
if check_column_exists(conn, "investor_members", "linkedin"):
print(" ✓ Column 'linkedin' already exists. Skipping.")
else:
conn.execute(
text("ALTER TABLE investor_members ADD COLUMN linkedin VARCHAR")
)
print(" ✓ Successfully added 'linkedin' column to investor_members table")
print("\n" + "=" * 60)
print("Migration completed successfully!")
def downgrade():
"""Remove added columns from tables"""
print("Running downgrade: Remove feedback fixes fields")
print("=" * 60)
# Note: SQLite doesn't support DROP COLUMN directly
print("\nWarning: SQLite doesn't support DROP COLUMN directly.")
print("To remove these columns, you would need to:")
print("1. Create new tables without the columns")
print("2. Copy data from old tables to new tables")
print("3. Drop old tables and rename new tables")
print("\nColumns to remove:")
print(" - projects.is_archived")
print(" - companies.product_service")
print(" - companies.clients")
print(" - investor_members.linkedin")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Run database migration")
parser.add_argument(
"direction",
choices=["upgrade", "downgrade"],
default="upgrade",
nargs="?",
help="Migration direction (default: upgrade)",
)
args = parser.parse_args()
if args.direction == "upgrade":
upgrade()
else:
downgrade()
+67
View File
@@ -0,0 +1,67 @@
"""
Migration: Add industry column to projects table
Date: 2025-10-23
"""
import os
import sys
from pathlib import Path
# Add parent directory to path to import app modules
sys.path.insert(0, str(Path(__file__).parent.parent))
from sqlalchemy import create_engine, text
from app.db.db import DATABASE_URL, engine
def upgrade():
"""Add industry column to projects table"""
print("Running migration: Add industry column to projects table")
with engine.connect() as conn:
# Check if column already exists
result = conn.execute(text("PRAGMA table_info(projects)"))
columns = [row[1] for row in result]
if 'industry' in columns:
print("Column 'industry' already exists in projects table. Skipping migration.")
return
# Add the industry column
conn.execute(text("ALTER TABLE projects ADD COLUMN industry VARCHAR"))
conn.commit()
print("Successfully added 'industry' column to projects table")
def downgrade():
"""Remove industry column from projects table"""
print("Running downgrade: Remove industry column from projects table")
# Note: SQLite doesn't support DROP COLUMN directly
# This is a simplified version - in production you'd need to recreate the table
print("Warning: SQLite doesn't support DROP COLUMN.")
print("To remove the column, you would need to:")
print("1. Create a new table without the industry column")
print("2. Copy data from old table to new table")
print("3. Drop old table and rename new table")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Run database migration")
parser.add_argument(
"direction",
choices=["upgrade", "downgrade"],
default="upgrade",
nargs="?",
help="Migration direction (default: upgrade)"
)
args = parser.parse_args()
if args.direction == "upgrade":
upgrade()
else:
downgrade()
+68
View File
@@ -0,0 +1,68 @@
#!/bin/bash
# Server management script for app/main.py
# Usage: ./server_manager.sh start|stop|restart
PID_FILE="server.pid"
LOG_FILE="server.log"
start() {
if [ -f "$PID_FILE" ] && kill -0 $(cat "$PID_FILE") 2>/dev/null; then
echo "Server is already running (PID: $(cat "$PID_FILE"))"
return 1
fi
echo "Starting server..."
nohup uv run app/main.py > "$LOG_FILE" 2>&1 &
echo $! > "$PID_FILE"
echo "Server started (PID: $(cat "$PID_FILE"))"
}
stop() {
if [ ! -f "$PID_FILE" ]; then
echo "Server is not running (no PID file found)"
return 1
fi
PID=$(cat "$PID_FILE")
if ! kill -0 "$PID" 2>/dev/null; then
echo "Server is not running (PID $PID not found)"
rm -f "$PID_FILE"
return 1
fi
echo "Stopping server (PID: $PID)..."
kill "$PID"
# Wait for process to stop
for i in {1..10}; do
if ! kill -0 "$PID" 2>/dev/null; then
break
fi
sleep 1
done
if kill -0 "$PID" 2>/dev/null; then
echo "Force killing server..."
kill -9 "$PID"
fi
rm -f "$PID_FILE"
echo "Server stopped"
}
restart() {
stop
sleep 2
start
}
case "$1" in
start)
start
;;
stop)
stop
;;
restart)
restart
;;
*)
echo "Usage: $0 {start|stop|restart}"
exit 1
;;
esac
+310
View File
@@ -0,0 +1,310 @@
#!/usr/bin/env python3
"""
Update Investor Members LinkedIn Profiles Script
This script finds and updates LinkedIn profile URLs for investor members in the database.
Uses crawl4ai to efficiently scrape team pages and extract LinkedIn URLs.
Usage:
python update_linkedin_profiles.py [--test] [--limit N] [--skip-existing]
Options:
--test Test mode: process only 10 records and don't update database
--limit N Process only N records (default: all)
--skip-existing Skip members that already have LinkedIn URLs
--start-from N Start from record N (for resuming)
"""
import argparse
import asyncio
import json
import os
import sys
from datetime import datetime
# Add app to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "app"))
from db.db import get_db_session
from db.models import InvestorMember, InvestorTable
from linkedin_scraper import LinkedInProfileScraper, format_linkedin_url
def progress_callback(current, total, result):
"""Print progress updates"""
percent = (current / total) * 100
status = "" if result["linkedin_url"] else ""
print(f"[{current}/{total} - {percent:.1f}%] {status} {result['member_name']}")
if result["linkedin_url"]:
print(
f"{result['linkedin_url']} (confidence: {result['confidence']}%, method: {result['method']})"
)
def create_db_callback(test_mode=False):
"""
Create a callback function that saves LinkedIn profiles to the database immediately.
This allows stopping and resuming without losing progress.
"""
saved_count = {"count": 0} # Use dict to allow modification in closure
def db_callback(member_id: int, linkedin_url: str) -> bool:
"""Save LinkedIn URL to database immediately"""
if test_mode:
print(f" [TEST] Would save to DB: member {member_id}")
saved_count["count"] += 1
return True
try:
db = get_db_session()
member = db.query(InvestorMember).filter_by(id=member_id).first()
if member:
member.linkedin = format_linkedin_url(linkedin_url)
db.commit()
saved_count["count"] += 1
return True
except Exception as e:
print(f" ⚠️ DB Error for member {member_id}: {e}")
try:
db.rollback()
except Exception:
pass
return False
finally:
try:
db.close()
except Exception:
pass
return False
return db_callback, saved_count
def update_database(members_data, test_mode=False):
"""Update database with found LinkedIn profiles"""
db = get_db_session()
try:
updated_count = 0
for data in members_data:
if data["linkedin_url"] and data["member_id"]:
if not test_mode:
member = (
db.query(InvestorMember).filter_by(id=data["member_id"]).first()
)
if member:
member.linkedin = format_linkedin_url(data["linkedin_url"])
updated_count += 1
else:
print(
f" [TEST MODE] Would update member {data['member_id']}: {data['linkedin_url']}"
)
updated_count += 1
if not test_mode:
db.commit()
print(f"\n✓ Successfully updated {updated_count} records in database")
else:
print(f"\n[TEST MODE] Would have updated {updated_count} records")
return updated_count
except Exception as e:
db.rollback()
print(f"\n✗ Error updating database: {e}")
raise
finally:
db.close()
def save_results(results, filename="linkedin_scraping_results.json"):
"""Save results to JSON file for backup/analysis"""
output = {
"timestamp": datetime.now().isoformat(),
"total_processed": len(results),
"found_count": sum(1 for r in results if r["linkedin_url"]),
"results": results,
}
with open(filename, "w") as f:
json.dump(output, f, indent=2)
print(f"\n✓ Results saved to {filename}")
def print_summary(results):
"""Print summary statistics"""
total = len(results)
found = sum(1 for r in results if r["linkedin_url"])
not_found = total - found
# Count by method
methods = {}
for r in results:
if r["linkedin_url"]:
method = r["method"]
methods[method] = methods.get(method, 0) + 1
# Average confidence for found profiles
avg_confidence = (
sum(r["confidence"] for r in results if r["linkedin_url"]) / found
if found > 0
else 0
)
print("\n" + "=" * 60)
print("SUMMARY")
print("=" * 60)
print(f"Total processed: {total}")
print(f"LinkedIn found: {found} ({found / total * 100:.1f}%)")
print(f"Not found: {not_found} ({not_found / total * 100:.1f}%)")
print(f"\nAverage confidence: {avg_confidence:.1f}%")
print("\nMethods used:")
for method, count in sorted(methods.items(), key=lambda x: x[1], reverse=True):
print(f" {method:20s} {count:5d} ({count / found * 100:.1f}%)")
print("=" * 60)
def main():
parser = argparse.ArgumentParser(
description="Update LinkedIn profiles for investor members"
)
parser.add_argument(
"--test",
action="store_true",
help="Test mode: process only 10 records without updating database",
)
parser.add_argument("--limit", type=int, help="Limit number of records to process")
parser.add_argument(
"--skip-existing",
action="store_true",
help="Skip members that already have LinkedIn URLs",
)
parser.add_argument(
"--start-from",
type=int,
default=0,
help="Start from record N (for resuming interrupted runs)",
)
parser.add_argument(
"--rate-limit",
type=float,
default=0.5,
help="Delay between URL crawls in seconds (default: 0.5)",
)
args = parser.parse_args()
# Test mode overrides limit
if args.test and not args.limit:
args.limit = 10
print("=" * 60)
print("LinkedIn Profile Scraper for Investor Members (crawl4ai)")
print("=" * 60)
if args.test:
print("\n⚠️ TEST MODE - No database changes will be made")
# Initialize database and scraper
db = get_db_session()
try:
# Build query
query = db.query(InvestorMember, InvestorTable).join(
InvestorTable, InvestorMember.investor_id == InvestorTable.id
)
# Filter existing if requested
if args.skip_existing:
query = query.filter(
(InvestorMember.linkedin.is_(None)) | (InvestorMember.linkedin == "")
)
print("\n✓ Filtering to members without LinkedIn profiles")
# Get total count
total_available = query.count()
print(f"\n✓ Found {total_available} members to process")
# Apply offset and limit
if args.start_from > 0:
query = query.offset(args.start_from)
print(f"✓ Starting from record {args.start_from}")
if args.limit:
query = query.limit(args.limit)
print(f"✓ Processing {args.limit} records")
# Fetch members
members_data = []
for member, investor in query.all():
members_data.append(
{
"id": member.id,
"name": member.name,
"company": investor.name,
"role": member.role,
"source_url": member.source_url,
}
)
if not members_data:
print("\n⚠️ No members to process")
return
# Count unique source URLs
unique_urls = len(set(m["source_url"] for m in members_data if m["source_url"]))
with_urls = sum(1 for m in members_data if m["source_url"])
print(f"\n✓ Loaded {len(members_data)} members")
print(
f"{with_urls} members have source URLs ({unique_urls} unique pages to crawl)"
)
print(f"{len(members_data) - with_urls} members without source URLs")
print(f"✓ Rate limit: {args.rate_limit}s between page crawls")
print("\nStarting LinkedIn profile search using crawl4ai...\n")
finally:
db.close()
# Initialize scraper
scraper = LinkedInProfileScraper(rate_limit_delay=args.rate_limit, use_cache=True)
print("️ Using crawl4ai to scrape team pages and extract LinkedIn URLs")
print(
"️ Profiles are saved to database IMMEDIATELY when found - safe to stop anytime!\n"
)
# Create database callback for real-time saving
db_callback, saved_count = create_db_callback(test_mode=args.test)
# Process members asynchronously with real-time DB saving
results = asyncio.run(
scraper.batch_find_profiles(
members_data, progress_callback=progress_callback, db_callback=db_callback
)
)
# Print summary
print_summary(results)
# Save results
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
results_file = f"linkedin_results_{timestamp}.json"
save_results(results, results_file)
# Show database update summary
if not args.test:
print(
f"\n✓ Database updated in real-time: {saved_count['count']} profiles saved"
)
else:
print(
f"\n[TEST MODE] Would have saved {saved_count['count']} profiles to database"
)
print("\n✓ Done! You can resume anytime with --skip-existing")
if __name__ == "__main__":
main()