Compare commits

6 Commits

20 changed files with 768 additions and 534 deletions
Binary file not shown.
Binary file not shown.
+1 -1
View File
@@ -14,7 +14,7 @@ Base = declarative_base()
# Get absolute path to the preprocessor database # Get absolute path to the preprocessor database
# APP_DIR = Path(__file__).parent.parent # APP_DIR = Path(__file__).parent.parent
# PREPROCESSOR_DB = APP_DIR.parent / "preprocessor" / "version_two.db" # PREPROCESSOR_DB = APP_DIR.parent / "preprocessor" / "version_two.db"
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./version_two.db") DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./investors.db")
# Create engine # Create engine
engine = create_engine(DATABASE_URL, echo=False) engine = create_engine(DATABASE_URL, echo=False)
+7 -2
View File
@@ -6,7 +6,7 @@ from dotenv import load_dotenv
from fastapi import FastAPI, File, Form, UploadFile from fastapi import FastAPI, File, Form, UploadFile
from pydantic import BaseModel from pydantic import BaseModel
from routers import companies, investors, projects from routers import companies, investors, projects
from schemas.router_schemas import InvestorList from schemas.router_schemas import InvestmentResponse, PaginatedResponse
from services.llm_parser import InvestorProcessor from services.llm_parser import InvestorProcessor
from services.querying import QueryProcessor from services.querying import QueryProcessor
@@ -84,11 +84,16 @@ async def parse_csv(
return results return results
@app.post("/query", response_model=InvestorList, tags=["Querying"]) @app.post(
"/query", response_model=PaginatedResponse[InvestmentResponse], tags=["Querying"]
)
async def query_investors(request: QueryRequest): async def query_investors(request: QueryRequest):
""" """
Query investors using natural language. Query investors using natural language.
Returns fund-level matches (one row per fund) with investor details.
This ensures only relevant funds are included in the response.
Supports queries like: Supports queries like:
- "Show me seed stage investors" - "Show me seed stage investors"
- "Find fintech investors in Silicon Valley" - "Find fintech investors in Silicon Valley"
Binary file not shown.
Binary file not shown.
+53 -14
View File
@@ -1,10 +1,10 @@
from typing import List, Optional from typing import Optional
from db.db import get_db from db.db import get_db
from db.models import CompanyTable, InvestorTable from db.models import CompanyTable, InvestorTable
from fastapi import APIRouter, Depends, HTTPException, Query from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel from pydantic import BaseModel
from schemas.router_schemas import CompanyData from schemas.router_schemas import CompanyData, PaginatedResponse
from sqlalchemy.orm import Session, selectinload from sqlalchemy.orm import Session, selectinload
router = APIRouter(tags=["Company Routes"]) router = APIRouter(tags=["Company Routes"])
@@ -29,20 +29,34 @@ class CompanyUpdate(BaseModel):
website: Optional[str] = None website: Optional[str] = None
@router.get("/companies", response_model=List[CompanyData]) @router.get("/companies", response_model=PaginatedResponse[CompanyData])
def read_companies(db: Session = Depends(get_db)): def read_companies(
"""Get all companies with their investor relationships""" page: int = Query(1, ge=1, description="Page number (starts at 1)"),
page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
db: Session = Depends(get_db),
):
"""Get all companies with their investor relationships (paginated)"""
# Calculate offset
offset = (page - 1) * page_size
# Get total count
total_count = (
db.query(CompanyTable)
.filter(CompanyTable.name.isnot(None), CompanyTable.description.isnot(None))
.count()
)
# Get paginated results
companies = ( companies = (
db.query(CompanyTable) db.query(CompanyTable)
.filter( .filter(CompanyTable.name.isnot(None), CompanyTable.description.isnot(None))
CompanyTable.name.isnot(None),
CompanyTable.description.isnot(None)
)
.options( .options(
selectinload(CompanyTable.investors), selectinload(CompanyTable.investors),
selectinload(CompanyTable.members), selectinload(CompanyTable.members),
selectinload(CompanyTable.sectors), selectinload(CompanyTable.sectors),
) )
.offset(offset)
.limit(page_size)
.all() .all()
) )
@@ -57,10 +71,19 @@ def read_companies(db: Session = Depends(get_db)):
) )
company_data_list.append(company_data) company_data_list.append(company_data)
return company_data_list # Calculate total pages
total_pages = (total_count + page_size - 1) // page_size
return PaginatedResponse(
items=company_data_list,
total=total_count,
page=page,
page_size=page_size,
total_pages=total_pages,
)
@router.get("/companies/filter", response_model=List[CompanyData]) @router.get("/companies/filter", response_model=PaginatedResponse[CompanyData])
def filter_companies( def filter_companies(
industry: Optional[str] = Query( industry: Optional[str] = Query(
None, description="Filter by industry (partial match)" None, description="Filter by industry (partial match)"
@@ -76,9 +99,11 @@ def filter_companies(
investor_name: Optional[str] = Query( investor_name: Optional[str] = Query(
None, description="Filter by investor name (partial match)" None, description="Filter by investor name (partial match)"
), ),
page: int = Query(1, ge=1, description="Page number (starts at 1)"),
page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
"""Filter companies based on various criteria""" """Filter companies based on various criteria (paginated)"""
# Start with base query # Start with base query
query = db.query(CompanyTable).options( query = db.query(CompanyTable).options(
@@ -112,7 +137,12 @@ def filter_companies(
InvestorTable.name.ilike(f"%{investor_name}%") InvestorTable.name.ilike(f"%{investor_name}%")
) )
companies = query.all() # Get total count before pagination
total_count = query.count()
# Calculate offset and apply pagination
offset = (page - 1) * page_size
companies = query.offset(offset).limit(page_size).all()
# Transform to CompanyData format # Transform to CompanyData format
company_data_list = [] company_data_list = []
@@ -125,7 +155,16 @@ def filter_companies(
) )
company_data_list.append(company_data) company_data_list.append(company_data)
return company_data_list # Calculate total pages
total_pages = (total_count + page_size - 1) // page_size
return PaginatedResponse(
items=company_data_list,
total=total_count,
page=page,
page_size=page_size,
total_pages=total_pages,
)
@router.get("/companies/{company_id}", response_model=CompanyData) @router.get("/companies/{company_id}", response_model=CompanyData)
+272 -262
View File
@@ -1,13 +1,16 @@
from typing import List, Optional from typing import Optional
from db.db import get_db from db.db import get_db
from db.models import InvestorTable, SectorTable from db.models import FundTable, InvestorTable, SectorTable
from fastapi import APIRouter, Depends, HTTPException, Query from fastapi import APIRouter, Depends, HTTPException, Query
from pydantic import BaseModel from pydantic import BaseModel
from schemas.router_schemas import ( from schemas.router_schemas import (
CompanyMinimal,
InvestmentResponse,
InvestmentStage, InvestmentStage,
InvestorData, InvestorData,
InvestorFundData, PaginatedResponse,
SectorMinimal,
) )
from sqlalchemy.orm import Session, selectinload from sqlalchemy.orm import Session, selectinload
@@ -18,32 +21,45 @@ router = APIRouter(tags=["Investor Routes"])
class InvestorCreate(BaseModel): class InvestorCreate(BaseModel):
name: str name: str
description: Optional[str] = None description: Optional[str] = None
website: Optional[str] = None
headquarters: Optional[str] = None
aum: int aum: int
check_size_lower: int check_size_lower: int
check_size_upper: int check_size_upper: int
geographic_focus: str geographic_focus: str
stage_focus: InvestmentStage
number_of_investments: int = 0 number_of_investments: int = 0
class InvestorUpdate(BaseModel): class InvestorUpdate(BaseModel):
name: Optional[str] = None name: Optional[str] = None
description: Optional[str] = None description: Optional[str] = None
website: Optional[str] = None
headquarters: Optional[str] = None
aum: Optional[int] = None aum: Optional[int] = None
check_size_lower: Optional[int] = None check_size_lower: Optional[int] = None
check_size_upper: Optional[int] = None check_size_upper: Optional[int] = None
geographic_focus: Optional[str] = None geographic_focus: Optional[str] = None
stage_focus: Optional[InvestmentStage] = None
number_of_investments: Optional[int] = None number_of_investments: Optional[int] = None
@router.get("/investors", response_model=List[InvestorFundData]) @router.get("/investors", response_model=PaginatedResponse[InvestmentResponse])
def read_investors(db: Session = Depends(get_db)): def read_investors(
"""Get all investors with their funds as separate entries page: int = Query(1, ge=1, description="Page number (starts at 1)"),
page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
db: Session = Depends(get_db),
):
"""Get all investors with their funds as separate entries (paginated)
Each investor-fund combination is returned as a separate row. Each investor-fund combination is returned as a separate row.
An investor with 3 funds will appear as 3 entries. An investor with 3 funds will appear as 3 entries.
""" """
# Calculate offset
offset = (page - 1) * page_size
# Get total count
total_count = db.query(InvestorTable).count()
# Get paginated results
investors = ( investors = (
db.query(InvestorTable) db.query(InvestorTable)
.options( .options(
@@ -52,80 +68,80 @@ def read_investors(db: Session = Depends(get_db)):
selectinload(InvestorTable.sectors), selectinload(InvestorTable.sectors),
selectinload(InvestorTable.funds), selectinload(InvestorTable.funds),
) )
.offset(offset)
.limit(page_size)
.all() .all()
) )
# Transform to InvestorFundData format (one row per investor-fund combination) # Transform to InvestmentResponse format (one row per investor-fund combination)
investor_fund_list = [] investment_responses = []
for investor in investors: for investor in investors:
# Get top 3 portfolio companies (id and name only)
portfolio_companies = [
CompanyMinimal(id=company.id, name=company.name)
for company in investor.portfolio_companies[:3]
]
# If investor has funds, create one entry per fund # If investor has funds, create one entry per fund
if investor.funds: if investor.funds:
for fund in investor.funds: for fund in investor.funds:
investor_fund_data = InvestorFundData( # Get stage focus as comma-separated string
# Investor fields stage_focus = (
investor_id=investor.id, ", ".join([stage.name for stage in fund.investment_stages])
investor_name=investor.name, if fund.investment_stages
investor_description=investor.description, else None
investor_website=investor.website, )
investor_headquarters=investor.headquarters,
# Get top 3 sectors from fund (id and name only)
fund_sectors = [
SectorMinimal(id=sector.id, name=sector.name)
for sector in (fund.sectors[:3] if fund.sectors else [])
]
investment_response = InvestmentResponse(
id=investor.id,
name=f"{investor.name} - {fund.fund_name}"
if fund.fund_name
else investor.name,
aum=investor.aum, aum=investor.aum,
aum_as_of_date=investor.aum_as_of_date,
aum_source_url=investor.aum_source_url,
investment_thesis=investor.investment_thesis,
portfolio_highlights=investor.portfolio_highlights,
number_of_investments=investor.number_of_investments,
# Fund fields
fund_id=fund.id,
fund_name=fund.fund_name,
fund_size=fund.fund_size,
fund_size_source_url=fund.fund_size_source_url,
check_size_lower=fund.check_size_lower, check_size_lower=fund.check_size_lower,
check_size_upper=fund.check_size_upper, check_size_upper=fund.check_size_upper,
geographic_focus=fund.geographic_focus, geographic_focus=fund.geographic_focus,
fund_investment_stages=fund.investment_stages, # Now a relationship stage_focus=stage_focus,
fund_sectors=fund.sectors, # Now a relationship portfolio_companies=portfolio_companies,
# Related data (same for all funds of this investor) sectors=fund_sectors,
portfolio_companies=investor.portfolio_companies, compatibility_score=1.0,
team_members=investor.team_members,
sectors=investor.sectors,
) )
investor_fund_list.append(investor_fund_data) investment_responses.append(investment_response)
else: else:
# If no funds, create one entry with null fund fields # If no funds, create one entry with null fund fields
investor_fund_data = InvestorFundData( investment_response = InvestmentResponse(
# Investor fields id=investor.id,
investor_id=investor.id, name=investor.name,
investor_name=investor.name,
investor_description=investor.description,
investor_website=investor.website,
investor_headquarters=investor.headquarters,
aum=investor.aum, aum=investor.aum,
aum_as_of_date=investor.aum_as_of_date,
aum_source_url=investor.aum_source_url,
investment_thesis=investor.investment_thesis,
portfolio_highlights=investor.portfolio_highlights,
number_of_investments=investor.number_of_investments,
# Fund fields (null)
fund_id=None,
fund_name=None,
fund_size=None,
fund_size_source_url=None,
check_size_lower=None, check_size_lower=None,
check_size_upper=None, check_size_upper=None,
geographic_focus=None, geographic_focus=None,
fund_investment_stages=None, stage_focus=None,
fund_sectors=None, portfolio_companies=portfolio_companies,
# Related data sectors=[],
portfolio_companies=investor.portfolio_companies, compatibility_score=1.0,
team_members=investor.team_members,
sectors=investor.sectors,
) )
investor_fund_list.append(investor_fund_data) investment_responses.append(investment_response)
return investor_fund_list # Calculate total pages
total_pages = (total_count + page_size - 1) // page_size
return PaginatedResponse(
items=investment_responses,
total=total_count,
page=page,
page_size=page_size,
total_pages=total_pages,
)
@router.get("/investors/filter", response_model=List[InvestorFundData]) @router.get("/investors/filter", response_model=PaginatedResponse[InvestmentResponse])
def filter_investors( def filter_investors(
stage: Optional[InvestmentStage] = Query( stage: Optional[InvestmentStage] = Query(
None, description="Filter by investment stage" None, description="Filter by investment stage"
@@ -138,117 +154,109 @@ def filter_investors(
sector: Optional[str] = Query(None, description="Sector name (partial match)"), sector: Optional[str] = Query(None, description="Sector name (partial match)"),
min_aum: Optional[int] = Query(None, description="Minimum AUM"), min_aum: Optional[int] = Query(None, description="Minimum AUM"),
max_aum: Optional[int] = Query(None, description="Maximum AUM"), max_aum: Optional[int] = Query(None, description="Maximum AUM"),
page: int = Query(1, ge=1, description="Page number (starts at 1)"),
page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
"""Filter investors based on various criteria """Filter investors based on various criteria (paginated)
Returns investor-fund combinations as separate rows. Returns investor-fund combinations as separate rows.
An investor with 3 funds will appear as 3 entries. Queries the funds table to find matching funds.
""" """
# Start with base query # Start with base query on funds table
query = db.query(InvestorTable).options( query = db.query(FundTable).options(
selectinload(InvestorTable.portfolio_companies), selectinload(FundTable.investor).selectinload(
selectinload(InvestorTable.team_members), InvestorTable.portfolio_companies
selectinload(InvestorTable.sectors), ),
selectinload(InvestorTable.funds), selectinload(FundTable.investor).selectinload(InvestorTable.team_members),
selectinload(FundTable.investor).selectinload(InvestorTable.sectors),
selectinload(FundTable.investment_stages),
selectinload(FundTable.sectors),
) )
# Apply filters # Apply filters at fund level
if stage:
query = query.filter(InvestorTable.stage_focus == stage)
if min_check_size is not None: if min_check_size is not None:
query = query.filter(InvestorTable.check_size_lower >= min_check_size) query = query.filter(FundTable.check_size_lower >= min_check_size)
if max_check_size is not None: if max_check_size is not None:
query = query.filter(InvestorTable.check_size_upper <= max_check_size) query = query.filter(FundTable.check_size_upper <= max_check_size)
if geography: if geography:
query = query.filter(InvestorTable.geographic_focus.ilike(f"%{geography}%")) query = query.filter(FundTable.geographic_focus.ilike(f"%{geography}%"))
# Apply filters at investor level (through relationship)
if min_aum is not None: if min_aum is not None:
query = query.filter(InvestorTable.aum >= min_aum) query = query.join(FundTable.investor).filter(InvestorTable.aum >= min_aum)
if max_aum is not None: if max_aum is not None:
if min_aum is None: # Only join if not already joined
query = query.join(FundTable.investor)
query = query.filter(InvestorTable.aum <= max_aum) query = query.filter(InvestorTable.aum <= max_aum)
# Filter by sector if provided # Filter by sector if provided (at fund level)
if sector: if sector:
query = query.join(InvestorTable.sectors).filter( query = query.join(FundTable.sectors).filter(
SectorTable.name.ilike(f"%{sector}%") SectorTable.name.ilike(f"%{sector}%")
) )
investors = query.all() # Get total count before pagination
total_count = query.count()
# Transform to InvestorFundData format (one row per investor-fund combination) # Calculate offset and apply pagination
investor_fund_list = [] offset = (page - 1) * page_size
for investor in investors: funds = query.offset(offset).limit(page_size).all()
# If investor has funds, create one entry per fund
if investor.funds: # Transform to InvestmentResponse format (one row per fund)
for fund in investor.funds: investment_responses = []
investor_fund_data = InvestorFundData( for fund in funds:
# Investor fields investor = fund.investor
investor_id=investor.id,
investor_name=investor.name, # Get top 3 portfolio companies (id and name only)
investor_description=investor.description, portfolio_companies = [
investor_website=investor.website, CompanyMinimal(id=company.id, name=company.name)
investor_headquarters=investor.headquarters, for company in investor.portfolio_companies[:3]
]
# Get stage focus as comma-separated string
stage_focus = (
", ".join([stage.name for stage in fund.investment_stages])
if fund.investment_stages
else None
)
# Get top 3 sectors from fund (id and name only)
fund_sectors = [
SectorMinimal(id=sector.id, name=sector.name)
for sector in (fund.sectors[:3] if fund.sectors else [])
]
investment_response = InvestmentResponse(
id=investor.id,
name=f"{investor.name} - {fund.fund_name}"
if fund.fund_name
else investor.name,
aum=investor.aum, aum=investor.aum,
aum_as_of_date=investor.aum_as_of_date,
aum_source_url=investor.aum_source_url,
investment_thesis=investor.investment_thesis,
portfolio_highlights=investor.portfolio_highlights,
number_of_investments=investor.number_of_investments,
# Fund fields
fund_id=fund.id,
fund_name=fund.fund_name,
fund_size=fund.fund_size,
fund_size_source_url=fund.fund_size_source_url,
check_size_lower=fund.check_size_lower, check_size_lower=fund.check_size_lower,
check_size_upper=fund.check_size_upper, check_size_upper=fund.check_size_upper,
geographic_focus=fund.geographic_focus, geographic_focus=fund.geographic_focus,
fund_investment_stages=fund.investment_stages, # Now a relationship stage_focus=stage_focus,
fund_sectors=fund.sectors, # Now a relationship portfolio_companies=portfolio_companies,
# Related data sectors=fund_sectors,
portfolio_companies=investor.portfolio_companies, compatibility_score=1.0,
team_members=investor.team_members,
sectors=investor.sectors,
) )
investor_fund_list.append(investor_fund_data) investment_responses.append(investment_response)
else:
# If no funds, create one entry with null fund fields
investor_fund_data = InvestorFundData(
# Investor fields
investor_id=investor.id,
investor_name=investor.name,
investor_description=investor.description,
investor_website=investor.website,
investor_headquarters=investor.headquarters,
aum=investor.aum,
aum_as_of_date=investor.aum_as_of_date,
aum_source_url=investor.aum_source_url,
investment_thesis=investor.investment_thesis,
portfolio_highlights=investor.portfolio_highlights,
number_of_investments=investor.number_of_investments,
# Fund fields (null)
fund_id=None,
fund_name=None,
fund_size=None,
fund_size_source_url=None,
check_size_lower=None,
check_size_upper=None,
geographic_focus=None,
fund_investment_stages=None,
fund_sectors=None,
# Related data
portfolio_companies=investor.portfolio_companies,
team_members=investor.team_members,
sectors=investor.sectors,
)
investor_fund_list.append(investor_fund_data)
return investor_fund_list # Calculate total pages
total_pages = (total_count + page_size - 1) // page_size
return PaginatedResponse(
items=investment_responses,
total=total_count,
page=page,
page_size=page_size,
total_pages=total_pages,
)
@router.get("/investors/{investor_id}", response_model=InvestorData) @router.get("/investors/{investor_id}", response_model=InvestorData)
@@ -365,25 +373,32 @@ def delete_investor(investor_id: int, db: Session = Depends(get_db)):
return {"message": "Investor deleted successfully"} return {"message": "Investor deleted successfully"}
@router.get("/investors/{investor_id}/similar", response_model=List[InvestorFundData]) @router.get(
"/investors/{investor_id}/similar",
response_model=PaginatedResponse[InvestmentResponse],
)
def find_similar_investors( def find_similar_investors(
investor_id: int, investor_id: int,
limit: int = Query(10, description="Maximum number of similar investors to return"), limit: int = Query(10, description="Maximum number of similar investors to return"),
page: int = Query(1, ge=1, description="Page number (starts at 1)"),
page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
"""Find investors similar to a given investor based on characteristics """Find investors similar to a given investor based on characteristics (paginated)
Returns investor-fund combinations as separate rows. Returns investor-fund combinations as separate rows.
Queries the funds table to find matching funds.
""" """
# Get the target investor # Get the target investor to get their funds for comparison
target_investor = ( target_investor = (
db.query(InvestorTable) db.query(InvestorTable)
.options( .options(
selectinload(InvestorTable.portfolio_companies), selectinload(InvestorTable.portfolio_companies),
selectinload(InvestorTable.team_members), selectinload(InvestorTable.team_members),
selectinload(InvestorTable.sectors), selectinload(InvestorTable.sectors),
selectinload(InvestorTable.funds), selectinload(InvestorTable.funds).selectinload(FundTable.investment_stages),
selectinload(InvestorTable.funds).selectinload(FundTable.sectors),
) )
.filter(InvestorTable.id == investor_id) .filter(InvestorTable.id == investor_id)
.first() .first()
@@ -392,154 +407,149 @@ def find_similar_investors(
if not target_investor: if not target_investor:
raise HTTPException(status_code=404, detail="Investor not found") raise HTTPException(status_code=404, detail="Investor not found")
# Get target investor's sector IDs for comparison # Get target investor's sector IDs for comparison (from their funds)
target_sector_ids = {sector.id for sector in target_investor.sectors} target_sector_ids = set()
target_stage_ids = set()
target_check_ranges = []
target_geographies = []
# Query all other investors with their relationships for fund in target_investor.funds:
candidates = ( if fund.sectors:
db.query(InvestorTable) target_sector_ids.update({sector.id for sector in fund.sectors})
if fund.investment_stages:
target_stage_ids.update({stage.id for stage in fund.investment_stages})
if fund.check_size_lower and fund.check_size_upper:
target_check_ranges.append((fund.check_size_lower, fund.check_size_upper))
if fund.geographic_focus:
target_geographies.append(fund.geographic_focus.lower())
# Query all funds from other investors
candidate_funds = (
db.query(FundTable)
.options( .options(
selectinload(InvestorTable.portfolio_companies), selectinload(FundTable.investor).selectinload(
selectinload(InvestorTable.team_members), InvestorTable.portfolio_companies
selectinload(InvestorTable.sectors), ),
selectinload(InvestorTable.funds), selectinload(FundTable.investor).selectinload(InvestorTable.team_members),
selectinload(FundTable.investor).selectinload(InvestorTable.sectors),
selectinload(FundTable.investment_stages),
selectinload(FundTable.sectors),
) )
.join(FundTable.investor)
.filter(InvestorTable.id != investor_id) .filter(InvestorTable.id != investor_id)
.all() .all()
) )
# Calculate similarity scores # Calculate similarity scores for each fund
scored_investors = [] scored_funds = []
for candidate in candidates: for fund in candidate_funds:
score = 0 score = 0
# Stage focus match (30 points)
if candidate.stage_focus == target_investor.stage_focus:
score += 30
# Geographic focus match (20 points for exact, 10 for partial) # Geographic focus match (20 points for exact, 10 for partial)
if candidate.geographic_focus and target_investor.geographic_focus: if fund.geographic_focus and target_geographies:
if ( fund_geo_lower = fund.geographic_focus.lower()
candidate.geographic_focus.lower() for target_geo in target_geographies:
== target_investor.geographic_focus.lower() if fund_geo_lower == target_geo:
):
score += 20 score += 20
elif ( break
candidate.geographic_focus.lower() elif fund_geo_lower in target_geo or target_geo in fund_geo_lower:
in target_investor.geographic_focus.lower()
or target_investor.geographic_focus.lower()
in candidate.geographic_focus.lower()
):
score += 10 score += 10
break
# Check size overlap (20 points max) # Check size overlap (20 points max)
if ( if fund.check_size_lower and fund.check_size_upper and target_check_ranges:
candidate.check_size_lower max_overlap_score = 0
and candidate.check_size_upper for target_lower, target_upper in target_check_ranges:
and target_investor.check_size_lower overlap_start = max(fund.check_size_lower, target_lower)
and target_investor.check_size_upper overlap_end = min(fund.check_size_upper, target_upper)
):
# Calculate overlap percentage
overlap_start = max(
candidate.check_size_lower, target_investor.check_size_lower
)
overlap_end = min(
candidate.check_size_upper, target_investor.check_size_upper
)
if overlap_end > overlap_start: if overlap_end > overlap_start:
overlap = overlap_end - overlap_start overlap = overlap_end - overlap_start
target_range = ( target_range = target_upper - target_lower
target_investor.check_size_upper - target_investor.check_size_lower
)
overlap_ratio = overlap / target_range if target_range > 0 else 0 overlap_ratio = overlap / target_range if target_range > 0 else 0
score += int(20 * overlap_ratio) max_overlap_score = max(max_overlap_score, int(20 * overlap_ratio))
score += max_overlap_score
# AUM similarity (15 points max) # AUM similarity (15 points max)
if candidate.aum and target_investor.aum: if fund.investor.aum and target_investor.aum:
aum_diff = abs(candidate.aum - target_investor.aum) aum_diff = abs(fund.investor.aum - target_investor.aum)
max_aum = max(candidate.aum, target_investor.aum) max_aum = max(fund.investor.aum, target_investor.aum)
similarity_ratio = 1 - (aum_diff / max_aum) if max_aum > 0 else 0 similarity_ratio = 1 - (aum_diff / max_aum) if max_aum > 0 else 0
score += int(15 * similarity_ratio) score += int(15 * similarity_ratio)
# Sector overlap (30 points max) # Sector overlap (30 points max)
candidate_sector_ids = {sector.id for sector in candidate.sectors} if fund.sectors and target_sector_ids:
if target_sector_ids and candidate_sector_ids: fund_sector_ids = {sector.id for sector in fund.sectors}
common_sectors = target_sector_ids.intersection(candidate_sector_ids) common_sectors = target_sector_ids.intersection(fund_sector_ids)
overlap_ratio = len(common_sectors) / len(target_sector_ids) overlap_ratio = len(common_sectors) / len(target_sector_ids)
score += int(30 * overlap_ratio) score += int(30 * overlap_ratio)
if score > 0: # Only include investors with some similarity # Investment stage match (15 points max)
scored_investors.append((score, candidate)) if fund.investment_stages and target_stage_ids:
fund_stage_ids = {stage.id for stage in fund.investment_stages}
common_stages = target_stage_ids.intersection(fund_stage_ids)
overlap_ratio = len(common_stages) / len(target_stage_ids)
score += int(15 * overlap_ratio)
# Sort by score (descending) and take top N if score > 0: # Only include funds with some similarity
scored_investors.sort(key=lambda x: x[0], reverse=True) scored_funds.append((score, fund))
similar_investors = [inv for score, inv in scored_investors[:limit]]
# Transform to InvestorFundData format (one row per investor-fund combination) # Sort by score (descending) and take top N based on limit
investor_fund_list = [] scored_funds.sort(key=lambda x: x[0], reverse=True)
for investor in similar_investors: top_similar = scored_funds[:limit]
# If investor has funds, create one entry per fund
if investor.funds: # Apply pagination to the top similar funds
for fund in investor.funds: total_count = len(top_similar)
investor_fund_data = InvestorFundData( offset = (page - 1) * page_size
# Investor fields paginated_similar = top_similar[offset : offset + page_size]
investor_id=investor.id, similar_funds = [fund for score, fund in paginated_similar]
investor_name=investor.name,
investor_description=investor.description, # Transform to InvestmentResponse format (one row per fund)
investor_website=investor.website, investment_responses = []
investor_headquarters=investor.headquarters, for fund in similar_funds:
investor = fund.investor
# Get top 3 portfolio companies (id and name only)
portfolio_companies = [
CompanyMinimal(id=company.id, name=company.name)
for company in investor.portfolio_companies[:3]
]
# Get stage focus as comma-separated string
stage_focus = (
", ".join([stage.name for stage in fund.investment_stages])
if fund.investment_stages
else None
)
# Get top 3 sectors from fund (id and name only)
fund_sectors = [
SectorMinimal(id=sector.id, name=sector.name)
for sector in (fund.sectors[:3] if fund.sectors else [])
]
investment_response = InvestmentResponse(
id=investor.id,
name=f"{investor.name} - {fund.fund_name}"
if fund.fund_name
else investor.name,
aum=investor.aum, aum=investor.aum,
aum_as_of_date=investor.aum_as_of_date,
aum_source_url=investor.aum_source_url,
investment_thesis=investor.investment_thesis,
portfolio_highlights=investor.portfolio_highlights,
number_of_investments=investor.number_of_investments,
# Fund fields
fund_id=fund.id,
fund_name=fund.fund_name,
fund_size=fund.fund_size,
fund_size_source_url=fund.fund_size_source_url,
check_size_lower=fund.check_size_lower, check_size_lower=fund.check_size_lower,
check_size_upper=fund.check_size_upper, check_size_upper=fund.check_size_upper,
geographic_focus=fund.geographic_focus, geographic_focus=fund.geographic_focus,
fund_investment_stages=fund.investment_stages, # Now a relationship stage_focus=stage_focus,
fund_sectors=fund.sectors, # Now a relationship portfolio_companies=portfolio_companies,
# Related data sectors=fund_sectors,
portfolio_companies=investor.portfolio_companies, compatibility_score=1.0,
team_members=investor.team_members,
sectors=investor.sectors,
) )
investor_fund_list.append(investor_fund_data) investment_responses.append(investment_response)
else:
# If no funds, create one entry with null fund fields
investor_fund_data = InvestorFundData(
# Investor fields
investor_id=investor.id,
investor_name=investor.name,
investor_description=investor.description,
investor_website=investor.website,
investor_headquarters=investor.headquarters,
aum=investor.aum,
aum_as_of_date=investor.aum_as_of_date,
aum_source_url=investor.aum_source_url,
investment_thesis=investor.investment_thesis,
portfolio_highlights=investor.portfolio_highlights,
number_of_investments=investor.number_of_investments,
# Fund fields (null)
fund_id=None,
fund_name=None,
fund_size=None,
fund_size_source_url=None,
check_size_lower=None,
check_size_upper=None,
geographic_focus=None,
fund_investment_stages=None,
fund_sectors=None,
# Related data
portfolio_companies=investor.portfolio_companies,
team_members=investor.team_members,
sectors=investor.sectors,
)
investor_fund_list.append(investor_fund_data)
return investor_fund_list # Calculate total pages
total_pages = (total_count + page_size - 1) // page_size
return PaginatedResponse(
items=investment_responses,
total=total_count,
page=page,
page_size=page_size,
total_pages=total_pages,
)
+47 -8
View File
@@ -14,14 +14,26 @@ from schemas.project_schemas import (
ProjectData, ProjectData,
ProjectUpdate, ProjectUpdate,
) )
from schemas.router_schemas import PaginatedResponse
from sqlalchemy.orm import Session, selectinload from sqlalchemy.orm import Session, selectinload
router = APIRouter(tags=["Project Routes"]) router = APIRouter(tags=["Project Routes"])
@router.get("/projects", response_model=List[ProjectData]) @router.get("/projects", response_model=PaginatedResponse[ProjectData])
def read_projects(db: Session = Depends(get_db)): def read_projects(
"""Get all projects with their related data""" page: int = Query(1, ge=1, description="Page number (starts at 1)"),
page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
db: Session = Depends(get_db),
):
"""Get all projects with their related data (paginated)"""
# Calculate offset
offset = (page - 1) * page_size
# Get total count
total_count = db.query(ProjectTable).count()
# Get paginated results
projects = ( projects = (
db.query(ProjectTable) db.query(ProjectTable)
.options( .options(
@@ -29,6 +41,8 @@ def read_projects(db: Session = Depends(get_db)):
selectinload(ProjectTable.investors), selectinload(ProjectTable.investors),
selectinload(ProjectTable.companies), selectinload(ProjectTable.companies),
) )
.offset(offset)
.limit(page_size)
.all() .all()
) )
@@ -43,7 +57,16 @@ def read_projects(db: Session = Depends(get_db)):
) )
project_data_list.append(project_data) project_data_list.append(project_data)
return project_data_list # Calculate total pages
total_pages = (total_count + page_size - 1) // page_size
return PaginatedResponse(
items=project_data_list,
total=total_count,
page=page,
page_size=page_size,
total_pages=total_pages,
)
@router.get("/projects/{project_id}", response_model=ProjectData) @router.get("/projects/{project_id}", response_model=ProjectData)
@@ -151,7 +174,7 @@ def delete_project(project_id: int, db: Session = Depends(get_db)):
return {"message": "Project deleted successfully"} return {"message": "Project deleted successfully"}
@router.get("/projects/filter", response_model=List[ProjectData]) @router.get("/projects/filter", response_model=PaginatedResponse[ProjectData])
def filter_projects( def filter_projects(
stage: Optional[InvestmentStage] = Query( stage: Optional[InvestmentStage] = Query(
None, description="Filter by project stage" None, description="Filter by project stage"
@@ -166,9 +189,11 @@ def filter_projects(
company_name: Optional[str] = Query( company_name: Optional[str] = Query(
None, description="Company name (partial match)" None, description="Company name (partial match)"
), ),
page: int = Query(1, ge=1, description="Page number (starts at 1)"),
page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
db: Session = Depends(get_db), db: Session = Depends(get_db),
): ):
"""Filter projects based on various criteria""" """Filter projects based on various criteria (paginated)"""
# Start with base query # Start with base query
query = db.query(ProjectTable).options( query = db.query(ProjectTable).options(
@@ -205,7 +230,12 @@ def filter_projects(
CompanyTable.name.ilike(f"%{company_name}%") CompanyTable.name.ilike(f"%{company_name}%")
) )
projects = query.all() # Get total count before pagination
total_count = query.count()
# Calculate offset and apply pagination
offset = (page - 1) * page_size
projects = query.offset(offset).limit(page_size).all()
# Transform to ProjectData format # Transform to ProjectData format
project_data_list = [] project_data_list = []
@@ -218,7 +248,16 @@ def filter_projects(
) )
project_data_list.append(project_data) project_data_list.append(project_data)
return project_data_list # Calculate total pages
total_pages = (total_count + page_size - 1) // page_size
return PaginatedResponse(
items=project_data_list,
total=total_count,
page=page,
page_size=page_size,
total_pages=total_pages,
)
# Association management routes # Association management routes
+93 -8
View File
@@ -1,9 +1,12 @@
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from typing import List, Optional from typing import Any, Generic, List, Optional, TypeVar
from pydantic import BaseModel from pydantic import BaseModel
# Generic type for pagination
T = TypeVar("T")
class InvestmentStage(str, Enum): class InvestmentStage(str, Enum):
SEED = "SEED" SEED = "SEED"
@@ -89,11 +92,20 @@ class InvestorSchema(BaseModel):
id: int id: int
name: str name: str
description: Optional[str] description: Optional[str]
website: Optional[str] = None
headquarters: Optional[str] = None
aum: int | None aum: int | None
aum_as_of_date: str | None = None
aum_source_url: str | None = None
check_size_lower: int | None check_size_lower: int | None
check_size_upper: int | None check_size_upper: int | None
geographic_focus: str | None geographic_focus: str | None
stage_focus: InvestmentStage investment_thesis: Any = (
None # Flexible JSON field - can be list, dict, or list of dicts
)
portfolio_highlights: Any = (
None # Flexible JSON field - can be list, dict, or list of dicts
)
number_of_investments: int | None number_of_investments: int | None
created_at: Optional[datetime] = None created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None updated_at: Optional[datetime] = None
@@ -131,8 +143,8 @@ class InvestorFundData(BaseModel):
aum: int | None aum: int | None
aum_as_of_date: str | None aum_as_of_date: str | None
aum_source_url: str | None aum_source_url: str | None
investment_thesis: List[str] | None investment_thesis: Any = None # Flexible JSON field
portfolio_highlights: List[str] | None portfolio_highlights: Any = None # Flexible JSON field
number_of_investments: int | None number_of_investments: int | None
# Fund fields # Fund fields
@@ -156,12 +168,29 @@ class InvestorFundData(BaseModel):
class Config: class Config:
from_attributes = True from_attributes = True
class InvestorMinimal(BaseModel):
"""Minimal investor info with just id and name"""
id: int
name: str
class Config:
from_attributes = True
class CompanySchemaMinimal(BaseModel):
id: int
name: str
industry: str | None
location: str | None
founded_year: Optional[int]
website: Optional[str]
class Config:
from_attributes = True
class CompanyData(BaseModel): # Renamed from CompaniesData for consistency class CompanyData(BaseModel): # Renamed from CompaniesData for consistency
company: CompanySchema company: CompanySchemaMinimal
sectors: List[SectorSchema] investors: List[InvestorMinimal]
members: List[CompanyMemberSchema]
investors: List[InvestorSchema]
class Config: class Config:
from_attributes = True from_attributes = True
@@ -175,3 +204,59 @@ class InvestorFundList(BaseModel):
"""List of investor-fund combinations""" """List of investor-fund combinations"""
investor_funds: List[InvestorFundData] investor_funds: List[InvestorFundData]
class CompanyMinimal(BaseModel):
"""Minimal company info with just id and name"""
id: int
name: str
class Config:
from_attributes = True
class SectorMinimal(BaseModel):
"""Minimal sector info with just id and name"""
id: int
name: str
class Config:
from_attributes = True
class InvestmentResponse(BaseModel):
"""Simplified investment response schema
One row per investor-fund combination with streamlined data
"""
id: int # Investor ID
name: (
str # Combination of investor name and fund name (e.g., "Investor A - Fund A")
)
aum: int | None # From investor
check_size_lower: int | None # From fund
check_size_upper: int | None # From fund
geographic_focus: str | None # From fund
stage_focus: str | None # Comma-separated stages from fund
portfolio_companies: List[CompanyMinimal] # Top 3 companies from investor
sectors: List[SectorMinimal] # Top 3 sectors from fund
compatibility_score: float # 0 to 1 (default 1 for now)
class Config:
from_attributes = True
class PaginatedResponse(BaseModel, Generic[T]):
"""Generic paginated response schema"""
items: List[T]
total: int
page: int
page_size: int
total_pages: int
class Config:
from_attributes = True
Binary file not shown.
Binary file not shown.
View File
View File
View File
+166 -167
View File
@@ -1,3 +1,4 @@
import asyncio
import json import json
import os import os
import re import re
@@ -265,37 +266,20 @@ Return the lower and upper bounds in USD."""
) -> Optional[dict]: ) -> Optional[dict]:
""" """
Process company profile from CSV data. Process company profile from CSV data.
Manually extracts fields without using LLM. Only extracts founded_year and key_executives - rest is in base database.
""" """
profile = self.parse_json_profile(profile_json) profile = self.parse_json_profile(profile_json)
if not profile: if not profile:
return None return None
try: try:
# Extract basic info # Only extract founded_year and key_executives
company_data = { company_data = {
"name": name.strip() if name else None, "name": name.strip() if name else None,
"website": website.strip() if website else None, "founded_year": None,
"description": profile.get("companyDescription"),
"location": profile.get("geographicFocus"),
"industry": profile.get("sectorDescription"),
"founded_year": None, # Not typically in the company JSON
"key_executives": [], "key_executives": [],
"client_categories": profile.get("clientCategories", []),
"product_description": profile.get("productDescription"),
"linked_documents": profile.get("linkedDocuments", []),
"researcher_notes": profile.get("researcherNotes"),
"missing_important_fields": profile.get("missingImportantFields", []),
"sources": profile.get("sources", {}),
"investor_names": [],
} }
# Parse investor names from the Investor column
if investor_names and pd.notna(investor_names):
# Split by comma and clean
investors = [inv.strip() for inv in str(investor_names).split(",")]
company_data["investor_names"] = [inv for inv in investors if inv]
# Process key executives/leadership # Process key executives/leadership
key_executives = profile.get("keyExecutives", []) key_executives = profile.get("keyExecutives", [])
if not key_executives: if not key_executives:
@@ -313,7 +297,7 @@ Return the lower and upper bounds in USD."""
) )
# Try to extract founding year from description # Try to extract founding year from description
description = company_data.get("description", "") description = profile.get("companyDescription", "")
if description: if description:
# Look for patterns like "founded in 2020", "Gegründet 2020", "founded 2020" # Look for patterns like "founded in 2020", "Gegründet 2020", "founded 2020"
year_patterns = [ year_patterns = [
@@ -344,40 +328,27 @@ Return the lower and upper bounds in USD."""
def _save_parsed_company_to_db( def _save_parsed_company_to_db(
self, db: Session, company_data: dict self, db: Session, company_data: dict
) -> Optional[CompanyTable]: ) -> Optional[CompanyTable]:
"""Save manually parsed company data to database""" """Save manually parsed company data to database - only updates founded_year and key_executives"""
try: try:
# Check if company already exists # Check if company already exists (should exist in base database)
existing_company = ( existing_company = (
db.query(CompanyTable).filter_by(name=company_data["name"]).first() db.query(CompanyTable).filter_by(name=company_data["name"]).first()
) )
if existing_company: if existing_company:
# Update existing company # Update only founded_year on existing company
company = existing_company company = existing_company
company.website = company_data.get("website") or company.website
company.location = company_data.get("location") or company.location
company.description = (
company_data.get("description") or company.description
)
company.industry = company_data.get("industry") or company.industry
if company_data.get("founded_year"): if company_data.get("founded_year"):
company.founded_year = company_data["founded_year"] company.founded_year = company_data["founded_year"]
else: else:
# Create new company # Company should already be in base database, but if not found, skip
company = CompanyTable( print(
name=company_data["name"], f"⚠️ Company '{company_data['name']}' not found in base database - skipping"
website=company_data.get("website"),
location=company_data.get("location"),
description=company_data.get("description"),
industry=company_data.get("industry"),
founded_year=company_data.get("founded_year"),
) )
db.add(company) return None
db.flush()
# Add/update company members (key executives) # Add/update company members (key executives)
# First, remove existing members if updating # First, remove existing members if updating
if existing_company:
db.query(CompanyMember).filter_by(company_id=company.id).delete() db.query(CompanyMember).filter_by(company_id=company.id).delete()
for exec_data in company_data.get("key_executives", []): for exec_data in company_data.get("key_executives", []):
@@ -391,19 +362,6 @@ Return the lower and upper bounds in USD."""
) )
db.add(member) db.add(member)
# Link to investors if provided
for investor_name in company_data.get("investor_names", []):
# Find investor in database
investor = (
db.query(InvestorTable)
.filter_by(name=investor_name.strip())
.first()
)
if investor:
# Add company to investor's portfolio if not already there
if company not in investor.portfolio_companies:
investor.portfolio_companies.append(company)
return company return company
except Exception as e: except Exception as e:
@@ -692,31 +650,14 @@ Return the lower and upper bounds in USD."""
print(f"Error processing row {row_idx + 1}: {e}") print(f"Error processing row {row_idx + 1}: {e}")
return None return None
async def parse_investors(self, df: pd.DataFrame, save_to_db: bool = True): async def _process_single_investor(
""" self, idx: int, row: pd.Series, total_rows: int
Parse investors from DataFrame using manual JSON parsing and LLM for currency conversion. ) -> Optional[dict]:
Expected CSV columns: Name, Website, Final Investor Profile, Final Profile sourcing """Process a single investor row"""
"""
results = []
db = None
if save_to_db:
db = get_db_session()
try: try:
total_rows = len(df) name = row.get("Name", "").strip() if pd.notna(row.get("Name")) else None
print(f"\n🚀 Starting to process {total_rows} investors...")
for idx, row in df.iterrows():
try:
name = (
row.get("Name", "").strip()
if pd.notna(row.get("Name"))
else None
)
website = ( website = (
row.get("Website", "").strip() row.get("Website", "").strip() if pd.notna(row.get("Website")) else None
if pd.notna(row.get("Website"))
else None
) )
profile_json = ( profile_json = (
row.get("Final Investor Profile", "") row.get("Final Investor Profile", "")
@@ -726,9 +667,9 @@ Return the lower and upper bounds in USD."""
if not name or not profile_json: if not name or not profile_json:
print(f"⚠️ Row {idx + 1}: Skipping - missing name or profile") print(f"⚠️ Row {idx + 1}: Skipping - missing name or profile")
continue return None
print(f"\n📊 Processing {idx + 1}/{total_rows}: {name}") print(f"📊 Processing {idx + 1}/{total_rows}: {name}")
# Process the investor profile # Process the investor profile
investor_data = await self.process_investor_profile( investor_data = await self.process_investor_profile(
@@ -736,18 +677,61 @@ Return the lower and upper bounds in USD."""
) )
if investor_data: if investor_data:
print(f"{name} parsed successfully")
return investor_data
else:
print(f" ⚠️ {name} failed to process")
return None
except Exception as e:
print(f"❌ Error processing row {idx + 1}: {e}")
return None
async def parse_investors(
self, df: pd.DataFrame, save_to_db: bool = True, batch_size: int = 10
):
"""
Parse investors from DataFrame using manual JSON parsing and LLM for currency conversion.
Processes multiple investors concurrently for better performance.
Expected CSV columns: Name, Website, Final Investor Profile, Final Profile sourcing
Args:
df: DataFrame with investor data
save_to_db: Whether to save to database
batch_size: Number of investors to process concurrently (default: 10)
"""
results = []
db = None
if save_to_db:
db = get_db_session()
try:
total_rows = len(df)
print(
f"\n🚀 Starting to process {total_rows} investors with batch size {batch_size}..."
)
# Process in batches
for batch_start in range(0, total_rows, batch_size):
batch_end = min(batch_start + batch_size, total_rows)
print(
f"\n🔄 Processing batch {batch_start + 1}-{batch_end} of {total_rows}..."
)
# Create tasks for concurrent processing
tasks = []
for idx in range(batch_start, batch_end):
row = df.iloc[idx]
task = self._process_single_investor(idx, row, total_rows)
tasks.append(task)
# Process batch concurrently
batch_results = await asyncio.gather(*tasks, return_exceptions=True)
# Filter out None results and exceptions, then save to database
for investor_data in batch_results:
if investor_data and not isinstance(investor_data, Exception):
results.append(investor_data) results.append(investor_data)
print(" ✓ Parsed successfully")
print(f" - HQ: {investor_data.get('headquarters')}")
print(
f" - AUM: ${investor_data.get('aum'):,}"
if investor_data.get("aum")
else " - AUM: Not Available"
)
print(f" - Funds: {len(investor_data.get('funds', []))}")
print(
f" - Team: {len(investor_data.get('team_members', []))}"
)
# Save to database # Save to database
if save_to_db and db: if save_to_db and db:
@@ -756,33 +740,29 @@ Return the lower and upper bounds in USD."""
db, investor_data db, investor_data
) )
if saved_investor: if saved_investor:
db.commit()
print( print(
f" ✅ Saved to database (ID: {saved_investor.id})" f" ✅ Saved {investor_data['name']} to database (ID: {saved_investor.id})"
) )
else: else:
print(" ❌ Failed to save to database") print(
f" ❌ Failed to save {investor_data['name']} to database"
)
except Exception as e: except Exception as e:
db.rollback() db.rollback()
print(f" ❌ Database error: {e}") print(
else: f" ❌ Database error for {investor_data['name']}: {e}"
print(" ⚠️ Failed to process profile") )
elif isinstance(investor_data, Exception):
print(f" ❌ Exception occurred: {investor_data}")
# Commit every 10 investors to avoid memory issues # Commit batch to database
if save_to_db and db and (idx + 1) % 10 == 0:
db.commit()
print(f"\n💾 Committed batch at row {idx + 1}")
except Exception as e:
print(f"❌ Error processing row {idx + 1}: {e}")
if db:
db.rollback()
continue
# Final commit
if save_to_db and db: if save_to_db and db:
try:
db.commit() db.commit()
print("\n✅ Final commit completed") print(f"💾 Committed batch {batch_start + 1}-{batch_end}")
except Exception as e:
db.rollback()
print(f"❌ Failed to commit batch: {e}")
except Exception as e: except Exception as e:
print(f"❌ Fatal error in parse_investors: {e}") print(f"❌ Fatal error in parse_investors: {e}")
@@ -795,31 +775,14 @@ Return the lower and upper bounds in USD."""
print(f"\n🎉 Completed! Processed {len(results)}/{total_rows} investors") print(f"\n🎉 Completed! Processed {len(results)}/{total_rows} investors")
return results return results
async def parse_companies(self, df: pd.DataFrame, save_to_db: bool = True): async def _process_single_company(
""" self, idx: int, row: pd.Series, total_rows: int
Parse companies from DataFrame using manual JSON parsing. ) -> Optional[dict]:
Expected CSV columns: Name, Website, Investor, Final Investor Profile (actually company profile) """Process a single company row"""
"""
results = []
db = None
if save_to_db:
db = get_db_session()
try: try:
total_rows = len(df) name = row.get("Name", "").strip() if pd.notna(row.get("Name")) else None
print(f"\n🚀 Starting to process {total_rows} companies...")
for idx, row in df.iterrows():
try:
name = (
row.get("Name", "").strip()
if pd.notna(row.get("Name"))
else None
)
website = ( website = (
row.get("Website", "").strip() row.get("Website", "").strip() if pd.notna(row.get("Website")) else None
if pd.notna(row.get("Website"))
else None
) )
investor_names = ( investor_names = (
row.get("Investor", "").strip() row.get("Investor", "").strip()
@@ -834,9 +797,9 @@ Return the lower and upper bounds in USD."""
if not name or not profile_json: if not name or not profile_json:
print(f"⚠️ Row {idx + 1}: Skipping - missing name or profile") print(f"⚠️ Row {idx + 1}: Skipping - missing name or profile")
continue return None
print(f"\n📊 Processing {idx + 1}/{total_rows}: {name}") print(f"📊 Processing {idx + 1}/{total_rows}: {name}")
# Process the company profile # Process the company profile
company_data = await self.process_company_profile( company_data = await self.process_company_profile(
@@ -844,21 +807,61 @@ Return the lower and upper bounds in USD."""
) )
if company_data: if company_data:
print(f"{name} parsed successfully")
return company_data
else:
print(f" ⚠️ {name} failed to process")
return None
except Exception as e:
print(f"❌ Error processing row {idx + 1}: {e}")
return None
async def parse_companies(
self, df: pd.DataFrame, save_to_db: bool = True, batch_size: int = 10
):
"""
Parse companies from DataFrame using manual JSON parsing.
Processes multiple companies concurrently for better performance.
Expected CSV columns: Name, Website, Investor, Final Investor Profile (actually company profile)
Args:
df: DataFrame with company data
save_to_db: Whether to save to database
batch_size: Number of companies to process concurrently (default: 10)
"""
results = []
db = None
if save_to_db:
db = get_db_session()
try:
total_rows = len(df)
print(
f"\n🚀 Starting to process {total_rows} companies with batch size {batch_size}..."
)
# Process in batches
for batch_start in range(0, total_rows, batch_size):
batch_end = min(batch_start + batch_size, total_rows)
print(
f"\n🔄 Processing batch {batch_start + 1}-{batch_end} of {total_rows}..."
)
# Create tasks for concurrent processing
tasks = []
for idx in range(batch_start, batch_end):
row = df.iloc[idx]
task = self._process_single_company(idx, row, total_rows)
tasks.append(task)
# Process batch concurrently
batch_results = await asyncio.gather(*tasks, return_exceptions=True)
# Filter out None results and exceptions, then save to database
for company_data in batch_results:
if company_data and not isinstance(company_data, Exception):
results.append(company_data) results.append(company_data)
print(" ✓ Parsed successfully")
print(f" - Location: {company_data.get('location')}")
print(f" - Industry: {company_data.get('industry')}")
print(
f" - Founded: {company_data.get('founded_year')}"
if company_data.get("founded_year")
else " - Founded: Unknown"
)
print(
f" - Executives: {len(company_data.get('key_executives', []))}"
)
print(
f" - Investors: {len(company_data.get('investor_names', []))}"
)
# Save to database # Save to database
if save_to_db and db: if save_to_db and db:
@@ -867,33 +870,29 @@ Return the lower and upper bounds in USD."""
db, company_data db, company_data
) )
if saved_company: if saved_company:
db.commit()
print( print(
f" ✅ Saved to database (ID: {saved_company.id})" f" ✅ Saved {company_data['name']} to database (ID: {saved_company.id})"
) )
else: else:
print(" ❌ Failed to save to database") print(
f" ❌ Failed to save {company_data['name']} to database"
)
except Exception as e: except Exception as e:
db.rollback() db.rollback()
print(f" ❌ Database error: {e}") print(
else: f" ❌ Database error for {company_data['name']}: {e}"
print(" ⚠️ Failed to process profile") )
elif isinstance(company_data, Exception):
print(f" ❌ Exception occurred: {company_data}")
# Commit every 10 companies to avoid memory issues # Commit batch to database
if save_to_db and db and (idx + 1) % 10 == 0:
db.commit()
print(f"\n💾 Committed batch at row {idx + 1}")
except Exception as e:
print(f"❌ Error processing row {idx + 1}: {e}")
if db:
db.rollback()
continue
# Final commit
if save_to_db and db: if save_to_db and db:
try:
db.commit() db.commit()
print("\n✅ Final commit completed") print(f"💾 Committed batch {batch_start + 1}-{batch_end}")
except Exception as e:
db.rollback()
print(f"❌ Failed to commit batch: {e}")
except Exception as e: except Exception as e:
print(f"❌ Fatal error in parse_companies: {e}") print(f"❌ Fatal error in parse_companies: {e}")
+100 -43
View File
@@ -2,13 +2,18 @@ import os
from typing import List from typing import List
from db.db import DATABASE_URL, get_db from db.db import DATABASE_URL, get_db
from db.models import InvestorTable from db.models import FundTable, InvestorTable
from langchain import hub from langchain import hub
from langchain_community.agent_toolkits import SQLDatabaseToolkit from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase from langchain_community.utilities import SQLDatabase
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent from langgraph.prebuilt import create_react_agent
from schemas.py_schemas import InvestorData, InvestorList from schemas.router_schemas import (
CompanyMinimal,
InvestmentResponse,
PaginatedResponse,
SectorMinimal,
)
from sqlalchemy.orm import selectinload from sqlalchemy.orm import selectinload
# Connect to SQLite # Connect to SQLite
@@ -21,16 +26,16 @@ class QueryProcessor:
self.llm = ChatOpenAI( self.llm = ChatOpenAI(
api_key=os.getenv("OPENROUTER_API_KEY"), api_key=os.getenv("OPENROUTER_API_KEY"),
base_url="https://openrouter.ai/api/v1", base_url="https://openrouter.ai/api/v1",
model="openai/gpt-4o-mini", model="x-ai/grok-4-fast",
temperature=0, temperature=0,
) )
self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm) self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)
# Update system message to specifically request only investor IDs # Update system message to specifically request only fund IDs
system_message_updated = ( system_message_updated = (
prompt_template.format(dialect="SQLite", top_k=5) prompt_template.format(dialect="SQLite", top_k=5)
+ "\n\nIMPORTANT: You must ONLY return the investor IDs (id field) that match the user's criteria. " + "\n\nIMPORTANT: You must ONLY return the fund IDs (id field from the funds table) that match the user's criteria. "
+ "Do NOT return any other information, explanations, or data. " + "Do NOT return any other information, explanations, or data. "
+ "Your response should be ONLY a comma-separated list of numbers representing the investor IDs. " + "Your response should be ONLY a comma-separated list of numbers representing the fund IDs. "
+ "Example format: 1, 5, 12, 23" + "Example format: 1, 5, 12, 23"
) )
self.agent = create_react_agent( self.agent = create_react_agent(
@@ -39,9 +44,9 @@ class QueryProcessor:
prompt=system_message_updated, prompt=system_message_updated,
) )
def process_query(self, question: str) -> InvestorList: def process_query(self, question: str) -> PaginatedResponse[InvestmentResponse]:
"""Process a query using the LLM and return investor data.""" """Process a query using the LLM and return investment response data."""
# Let the LLM handle all database interactions and filtering to get IDs # Let the LLM handle all database interactions and filtering to get fund IDs
response = self.agent.invoke( response = self.agent.invoke(
{"messages": [("user", question)]}, {"messages": [("user", question)]},
) )
@@ -51,70 +56,122 @@ class QueryProcessor:
response["messages"][-1].content if response.get("messages") else "" response["messages"][-1].content if response.get("messages") else ""
) )
# Extract investor IDs from the AI response # Extract fund IDs from the AI response
investor_ids = self._extract_investor_ids_from_response(ai_response) fund_ids = self._extract_fund_ids_from_response(ai_response)
# Fetch full investor data using the IDs # Fetch full fund data with investor relationships using the IDs
return self._fetch_investors_by_ids(investor_ids) return self._fetch_funds_by_ids(fund_ids)
def _extract_investor_ids_from_response(self, ai_response: str) -> List[int]: def _extract_fund_ids_from_response(self, ai_response: str) -> List[int]:
"""Extract investor IDs from AI response.""" """Extract fund IDs from AI response."""
import re import re
investor_ids = [] fund_ids = []
try: try:
# Try multiple patterns to extract IDs from the response # Try multiple patterns to extract IDs from the response
# Pattern 1: Simple numbers (assuming they are IDs) # Pattern 1: Simple numbers (assuming they are IDs)
numbers = re.findall(r"\b\d+\b", ai_response) numbers = re.findall(r"\b\d+\b", ai_response)
investor_ids = [int(num) for num in numbers] fund_ids = [int(num) for num in numbers]
# Pattern 2: If response contains explicit ID references # Pattern 2: If response contains explicit ID references
id_matches = re.findall(r"\bid[:\s]*(\d+)", ai_response.lower()) id_matches = re.findall(r"\bid[:\s]*(\d+)", ai_response.lower())
if id_matches: if id_matches:
investor_ids = [int(id_str) for id_str in id_matches] fund_ids = [int(id_str) for id_str in id_matches]
except Exception as e: except Exception as e:
print(f"Error extracting IDs from response: {e}") print(f"Error extracting IDs from response: {e}")
return [] return []
return investor_ids return fund_ids
def _fetch_investors_by_ids(self, investor_ids: List[int]) -> InvestorList: def _fetch_funds_by_ids(
"""Fetch investors with all their relationships from the database using IDs.""" self, fund_ids: List[int]
if not investor_ids: ) -> PaginatedResponse[InvestmentResponse]:
return InvestorList(investors=[]) """Fetch funds with all their relationships from the database using fund IDs.
Constructs response similar to read_investors but starting from funds."""
if not fund_ids:
return PaginatedResponse(
items=[],
total=0,
page=1,
page_size=len(fund_ids) if fund_ids else 10,
total_pages=0,
)
# Get database session # Get database session
db_session = next(get_db()) db_session = next(get_db())
try: try:
# Build query with all relationships loaded # Query funds with all necessary relationships loaded
query = ( funds = (
db_session.query(InvestorTable) db_session.query(FundTable)
.options( .options(
selectinload(InvestorTable.portfolio_companies), selectinload(FundTable.investor).selectinload(
selectinload(InvestorTable.team_members), InvestorTable.portfolio_companies
selectinload(InvestorTable.sectors), ),
selectinload(InvestorTable.funds), selectinload(FundTable.investor).selectinload(
InvestorTable.team_members
),
selectinload(FundTable.investor).selectinload(
InvestorTable.sectors
),
selectinload(FundTable.investment_stages),
selectinload(FundTable.sectors),
) )
.filter(InvestorTable.id.in_(investor_ids)) .filter(FundTable.id.in_(fund_ids))
.all()
) )
investors = query.all() # Transform to InvestmentResponse format (one row per fund)
investment_responses = []
for fund in funds:
investor = fund.investor
# Transform to InvestorData format # Get top 3 portfolio companies (id and name only)
investor_data_list = [] portfolio_companies = [
for investor in investors: CompanyMinimal(id=company.id, name=company.name)
investor_data = InvestorData( for company in investor.portfolio_companies[:3]
investor=investor, ]
portfolio_companies=investor.portfolio_companies,
team_members=investor.team_members, # Get stage focus as comma-separated string
sectors=investor.sectors, stage_focus = (
funds=investor.funds, ", ".join([stage.name for stage in fund.investment_stages])
if fund.investment_stages
else None
) )
investor_data_list.append(investor_data)
return InvestorList(investors=investor_data_list) # Get top 3 sectors from fund (id and name only)
fund_sectors = [
SectorMinimal(id=sector.id, name=sector.name)
for sector in (fund.sectors[:3] if fund.sectors else [])
]
investment_response = InvestmentResponse(
id=investor.id,
name=f"{investor.name} - {fund.fund_name}"
if fund.fund_name
else investor.name,
aum=investor.aum,
check_size_lower=fund.check_size_lower,
check_size_upper=fund.check_size_upper,
geographic_focus=fund.geographic_focus,
stage_focus=stage_focus,
portfolio_companies=portfolio_companies,
sectors=fund_sectors,
compatibility_score=1.0,
)
investment_responses.append(investment_response)
total_count = len(investment_responses)
total_pages = 1 if total_count > 0 else 0
return PaginatedResponse(
items=investment_responses,
total=total_count,
page=1,
page_size=total_count,
total_pages=total_pages,
)
finally: finally:
db_session.close() db_session.close()
View File
BIN
View File
Binary file not shown.