Compare commits
6 Commits
84e3c7b72a
...
cefe89bb67
| Author | SHA1 | Date | |
|---|---|---|---|
| cefe89bb67 | |||
| 58722f1102 | |||
| be6fde9ba2 | |||
| 37e1ad01c4 | |||
| faf92a3b47 | |||
| 26a1197db0 |
Binary file not shown.
Binary file not shown.
+1
-1
@@ -14,7 +14,7 @@ Base = declarative_base()
|
||||
# Get absolute path to the preprocessor database
|
||||
# APP_DIR = Path(__file__).parent.parent
|
||||
# PREPROCESSOR_DB = APP_DIR.parent / "preprocessor" / "version_two.db"
|
||||
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./version_two.db")
|
||||
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./investors.db")
|
||||
|
||||
# Create engine
|
||||
engine = create_engine(DATABASE_URL, echo=False)
|
||||
|
||||
+7
-2
@@ -6,7 +6,7 @@ from dotenv import load_dotenv
|
||||
from fastapi import FastAPI, File, Form, UploadFile
|
||||
from pydantic import BaseModel
|
||||
from routers import companies, investors, projects
|
||||
from schemas.router_schemas import InvestorList
|
||||
from schemas.router_schemas import InvestmentResponse, PaginatedResponse
|
||||
from services.llm_parser import InvestorProcessor
|
||||
from services.querying import QueryProcessor
|
||||
|
||||
@@ -84,11 +84,16 @@ async def parse_csv(
|
||||
return results
|
||||
|
||||
|
||||
@app.post("/query", response_model=InvestorList, tags=["Querying"])
|
||||
@app.post(
|
||||
"/query", response_model=PaginatedResponse[InvestmentResponse], tags=["Querying"]
|
||||
)
|
||||
async def query_investors(request: QueryRequest):
|
||||
"""
|
||||
Query investors using natural language.
|
||||
|
||||
Returns fund-level matches (one row per fund) with investor details.
|
||||
This ensures only relevant funds are included in the response.
|
||||
|
||||
Supports queries like:
|
||||
- "Show me seed stage investors"
|
||||
- "Find fintech investors in Silicon Valley"
|
||||
|
||||
Binary file not shown.
Binary file not shown.
+53
-14
@@ -1,10 +1,10 @@
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from db.db import get_db
|
||||
from db.models import CompanyTable, InvestorTable
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from schemas.router_schemas import CompanyData
|
||||
from schemas.router_schemas import CompanyData, PaginatedResponse
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
router = APIRouter(tags=["Company Routes"])
|
||||
@@ -29,20 +29,34 @@ class CompanyUpdate(BaseModel):
|
||||
website: Optional[str] = None
|
||||
|
||||
|
||||
@router.get("/companies", response_model=List[CompanyData])
|
||||
def read_companies(db: Session = Depends(get_db)):
|
||||
"""Get all companies with their investor relationships"""
|
||||
@router.get("/companies", response_model=PaginatedResponse[CompanyData])
|
||||
def read_companies(
|
||||
page: int = Query(1, ge=1, description="Page number (starts at 1)"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get all companies with their investor relationships (paginated)"""
|
||||
# Calculate offset
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# Get total count
|
||||
total_count = (
|
||||
db.query(CompanyTable)
|
||||
.filter(CompanyTable.name.isnot(None), CompanyTable.description.isnot(None))
|
||||
.count()
|
||||
)
|
||||
|
||||
# Get paginated results
|
||||
companies = (
|
||||
db.query(CompanyTable)
|
||||
.filter(
|
||||
CompanyTable.name.isnot(None),
|
||||
CompanyTable.description.isnot(None)
|
||||
)
|
||||
.filter(CompanyTable.name.isnot(None), CompanyTable.description.isnot(None))
|
||||
.options(
|
||||
selectinload(CompanyTable.investors),
|
||||
selectinload(CompanyTable.members),
|
||||
selectinload(CompanyTable.sectors),
|
||||
)
|
||||
.offset(offset)
|
||||
.limit(page_size)
|
||||
.all()
|
||||
)
|
||||
|
||||
@@ -57,10 +71,19 @@ def read_companies(db: Session = Depends(get_db)):
|
||||
)
|
||||
company_data_list.append(company_data)
|
||||
|
||||
return company_data_list
|
||||
# Calculate total pages
|
||||
total_pages = (total_count + page_size - 1) // page_size
|
||||
|
||||
return PaginatedResponse(
|
||||
items=company_data_list,
|
||||
total=total_count,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/companies/filter", response_model=List[CompanyData])
|
||||
@router.get("/companies/filter", response_model=PaginatedResponse[CompanyData])
|
||||
def filter_companies(
|
||||
industry: Optional[str] = Query(
|
||||
None, description="Filter by industry (partial match)"
|
||||
@@ -76,9 +99,11 @@ def filter_companies(
|
||||
investor_name: Optional[str] = Query(
|
||||
None, description="Filter by investor name (partial match)"
|
||||
),
|
||||
page: int = Query(1, ge=1, description="Page number (starts at 1)"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Filter companies based on various criteria"""
|
||||
"""Filter companies based on various criteria (paginated)"""
|
||||
|
||||
# Start with base query
|
||||
query = db.query(CompanyTable).options(
|
||||
@@ -112,7 +137,12 @@ def filter_companies(
|
||||
InvestorTable.name.ilike(f"%{investor_name}%")
|
||||
)
|
||||
|
||||
companies = query.all()
|
||||
# Get total count before pagination
|
||||
total_count = query.count()
|
||||
|
||||
# Calculate offset and apply pagination
|
||||
offset = (page - 1) * page_size
|
||||
companies = query.offset(offset).limit(page_size).all()
|
||||
|
||||
# Transform to CompanyData format
|
||||
company_data_list = []
|
||||
@@ -125,7 +155,16 @@ def filter_companies(
|
||||
)
|
||||
company_data_list.append(company_data)
|
||||
|
||||
return company_data_list
|
||||
# Calculate total pages
|
||||
total_pages = (total_count + page_size - 1) // page_size
|
||||
|
||||
return PaginatedResponse(
|
||||
items=company_data_list,
|
||||
total=total_count,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/companies/{company_id}", response_model=CompanyData)
|
||||
|
||||
+287
-277
@@ -1,13 +1,16 @@
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from db.db import get_db
|
||||
from db.models import InvestorTable, SectorTable
|
||||
from db.models import FundTable, InvestorTable, SectorTable
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from schemas.router_schemas import (
|
||||
CompanyMinimal,
|
||||
InvestmentResponse,
|
||||
InvestmentStage,
|
||||
InvestorData,
|
||||
InvestorFundData,
|
||||
PaginatedResponse,
|
||||
SectorMinimal,
|
||||
)
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
@@ -18,32 +21,45 @@ router = APIRouter(tags=["Investor Routes"])
|
||||
class InvestorCreate(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
website: Optional[str] = None
|
||||
headquarters: Optional[str] = None
|
||||
aum: int
|
||||
check_size_lower: int
|
||||
check_size_upper: int
|
||||
geographic_focus: str
|
||||
stage_focus: InvestmentStage
|
||||
number_of_investments: int = 0
|
||||
|
||||
|
||||
class InvestorUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
website: Optional[str] = None
|
||||
headquarters: Optional[str] = None
|
||||
aum: Optional[int] = None
|
||||
check_size_lower: Optional[int] = None
|
||||
check_size_upper: Optional[int] = None
|
||||
geographic_focus: Optional[str] = None
|
||||
stage_focus: Optional[InvestmentStage] = None
|
||||
number_of_investments: Optional[int] = None
|
||||
|
||||
|
||||
@router.get("/investors", response_model=List[InvestorFundData])
|
||||
def read_investors(db: Session = Depends(get_db)):
|
||||
"""Get all investors with their funds as separate entries
|
||||
@router.get("/investors", response_model=PaginatedResponse[InvestmentResponse])
|
||||
def read_investors(
|
||||
page: int = Query(1, ge=1, description="Page number (starts at 1)"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get all investors with their funds as separate entries (paginated)
|
||||
|
||||
Each investor-fund combination is returned as a separate row.
|
||||
An investor with 3 funds will appear as 3 entries.
|
||||
"""
|
||||
# Calculate offset
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# Get total count
|
||||
total_count = db.query(InvestorTable).count()
|
||||
|
||||
# Get paginated results
|
||||
investors = (
|
||||
db.query(InvestorTable)
|
||||
.options(
|
||||
@@ -52,80 +68,80 @@ def read_investors(db: Session = Depends(get_db)):
|
||||
selectinload(InvestorTable.sectors),
|
||||
selectinload(InvestorTable.funds),
|
||||
)
|
||||
.offset(offset)
|
||||
.limit(page_size)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Transform to InvestorFundData format (one row per investor-fund combination)
|
||||
investor_fund_list = []
|
||||
# Transform to InvestmentResponse format (one row per investor-fund combination)
|
||||
investment_responses = []
|
||||
for investor in investors:
|
||||
# Get top 3 portfolio companies (id and name only)
|
||||
portfolio_companies = [
|
||||
CompanyMinimal(id=company.id, name=company.name)
|
||||
for company in investor.portfolio_companies[:3]
|
||||
]
|
||||
|
||||
# If investor has funds, create one entry per fund
|
||||
if investor.funds:
|
||||
for fund in investor.funds:
|
||||
investor_fund_data = InvestorFundData(
|
||||
# Investor fields
|
||||
investor_id=investor.id,
|
||||
investor_name=investor.name,
|
||||
investor_description=investor.description,
|
||||
investor_website=investor.website,
|
||||
investor_headquarters=investor.headquarters,
|
||||
# Get stage focus as comma-separated string
|
||||
stage_focus = (
|
||||
", ".join([stage.name for stage in fund.investment_stages])
|
||||
if fund.investment_stages
|
||||
else None
|
||||
)
|
||||
|
||||
# Get top 3 sectors from fund (id and name only)
|
||||
fund_sectors = [
|
||||
SectorMinimal(id=sector.id, name=sector.name)
|
||||
for sector in (fund.sectors[:3] if fund.sectors else [])
|
||||
]
|
||||
|
||||
investment_response = InvestmentResponse(
|
||||
id=investor.id,
|
||||
name=f"{investor.name} - {fund.fund_name}"
|
||||
if fund.fund_name
|
||||
else investor.name,
|
||||
aum=investor.aum,
|
||||
aum_as_of_date=investor.aum_as_of_date,
|
||||
aum_source_url=investor.aum_source_url,
|
||||
investment_thesis=investor.investment_thesis,
|
||||
portfolio_highlights=investor.portfolio_highlights,
|
||||
number_of_investments=investor.number_of_investments,
|
||||
# Fund fields
|
||||
fund_id=fund.id,
|
||||
fund_name=fund.fund_name,
|
||||
fund_size=fund.fund_size,
|
||||
fund_size_source_url=fund.fund_size_source_url,
|
||||
check_size_lower=fund.check_size_lower,
|
||||
check_size_upper=fund.check_size_upper,
|
||||
geographic_focus=fund.geographic_focus,
|
||||
fund_investment_stages=fund.investment_stages, # Now a relationship
|
||||
fund_sectors=fund.sectors, # Now a relationship
|
||||
# Related data (same for all funds of this investor)
|
||||
portfolio_companies=investor.portfolio_companies,
|
||||
team_members=investor.team_members,
|
||||
sectors=investor.sectors,
|
||||
stage_focus=stage_focus,
|
||||
portfolio_companies=portfolio_companies,
|
||||
sectors=fund_sectors,
|
||||
compatibility_score=1.0,
|
||||
)
|
||||
investor_fund_list.append(investor_fund_data)
|
||||
investment_responses.append(investment_response)
|
||||
else:
|
||||
# If no funds, create one entry with null fund fields
|
||||
investor_fund_data = InvestorFundData(
|
||||
# Investor fields
|
||||
investor_id=investor.id,
|
||||
investor_name=investor.name,
|
||||
investor_description=investor.description,
|
||||
investor_website=investor.website,
|
||||
investor_headquarters=investor.headquarters,
|
||||
investment_response = InvestmentResponse(
|
||||
id=investor.id,
|
||||
name=investor.name,
|
||||
aum=investor.aum,
|
||||
aum_as_of_date=investor.aum_as_of_date,
|
||||
aum_source_url=investor.aum_source_url,
|
||||
investment_thesis=investor.investment_thesis,
|
||||
portfolio_highlights=investor.portfolio_highlights,
|
||||
number_of_investments=investor.number_of_investments,
|
||||
# Fund fields (null)
|
||||
fund_id=None,
|
||||
fund_name=None,
|
||||
fund_size=None,
|
||||
fund_size_source_url=None,
|
||||
check_size_lower=None,
|
||||
check_size_upper=None,
|
||||
geographic_focus=None,
|
||||
fund_investment_stages=None,
|
||||
fund_sectors=None,
|
||||
# Related data
|
||||
portfolio_companies=investor.portfolio_companies,
|
||||
team_members=investor.team_members,
|
||||
sectors=investor.sectors,
|
||||
stage_focus=None,
|
||||
portfolio_companies=portfolio_companies,
|
||||
sectors=[],
|
||||
compatibility_score=1.0,
|
||||
)
|
||||
investor_fund_list.append(investor_fund_data)
|
||||
investment_responses.append(investment_response)
|
||||
|
||||
return investor_fund_list
|
||||
# Calculate total pages
|
||||
total_pages = (total_count + page_size - 1) // page_size
|
||||
|
||||
return PaginatedResponse(
|
||||
items=investment_responses,
|
||||
total=total_count,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/investors/filter", response_model=List[InvestorFundData])
|
||||
@router.get("/investors/filter", response_model=PaginatedResponse[InvestmentResponse])
|
||||
def filter_investors(
|
||||
stage: Optional[InvestmentStage] = Query(
|
||||
None, description="Filter by investment stage"
|
||||
@@ -138,117 +154,109 @@ def filter_investors(
|
||||
sector: Optional[str] = Query(None, description="Sector name (partial match)"),
|
||||
min_aum: Optional[int] = Query(None, description="Minimum AUM"),
|
||||
max_aum: Optional[int] = Query(None, description="Maximum AUM"),
|
||||
page: int = Query(1, ge=1, description="Page number (starts at 1)"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Filter investors based on various criteria
|
||||
"""Filter investors based on various criteria (paginated)
|
||||
|
||||
Returns investor-fund combinations as separate rows.
|
||||
An investor with 3 funds will appear as 3 entries.
|
||||
Queries the funds table to find matching funds.
|
||||
"""
|
||||
|
||||
# Start with base query
|
||||
query = db.query(InvestorTable).options(
|
||||
selectinload(InvestorTable.portfolio_companies),
|
||||
selectinload(InvestorTable.team_members),
|
||||
selectinload(InvestorTable.sectors),
|
||||
selectinload(InvestorTable.funds),
|
||||
# Start with base query on funds table
|
||||
query = db.query(FundTable).options(
|
||||
selectinload(FundTable.investor).selectinload(
|
||||
InvestorTable.portfolio_companies
|
||||
),
|
||||
selectinload(FundTable.investor).selectinload(InvestorTable.team_members),
|
||||
selectinload(FundTable.investor).selectinload(InvestorTable.sectors),
|
||||
selectinload(FundTable.investment_stages),
|
||||
selectinload(FundTable.sectors),
|
||||
)
|
||||
|
||||
# Apply filters
|
||||
if stage:
|
||||
query = query.filter(InvestorTable.stage_focus == stage)
|
||||
|
||||
# Apply filters at fund level
|
||||
if min_check_size is not None:
|
||||
query = query.filter(InvestorTable.check_size_lower >= min_check_size)
|
||||
query = query.filter(FundTable.check_size_lower >= min_check_size)
|
||||
|
||||
if max_check_size is not None:
|
||||
query = query.filter(InvestorTable.check_size_upper <= max_check_size)
|
||||
query = query.filter(FundTable.check_size_upper <= max_check_size)
|
||||
|
||||
if geography:
|
||||
query = query.filter(InvestorTable.geographic_focus.ilike(f"%{geography}%"))
|
||||
query = query.filter(FundTable.geographic_focus.ilike(f"%{geography}%"))
|
||||
|
||||
# Apply filters at investor level (through relationship)
|
||||
if min_aum is not None:
|
||||
query = query.filter(InvestorTable.aum >= min_aum)
|
||||
query = query.join(FundTable.investor).filter(InvestorTable.aum >= min_aum)
|
||||
|
||||
if max_aum is not None:
|
||||
if min_aum is None: # Only join if not already joined
|
||||
query = query.join(FundTable.investor)
|
||||
query = query.filter(InvestorTable.aum <= max_aum)
|
||||
|
||||
# Filter by sector if provided
|
||||
# Filter by sector if provided (at fund level)
|
||||
if sector:
|
||||
query = query.join(InvestorTable.sectors).filter(
|
||||
query = query.join(FundTable.sectors).filter(
|
||||
SectorTable.name.ilike(f"%{sector}%")
|
||||
)
|
||||
|
||||
investors = query.all()
|
||||
# Get total count before pagination
|
||||
total_count = query.count()
|
||||
|
||||
# Transform to InvestorFundData format (one row per investor-fund combination)
|
||||
investor_fund_list = []
|
||||
for investor in investors:
|
||||
# If investor has funds, create one entry per fund
|
||||
if investor.funds:
|
||||
for fund in investor.funds:
|
||||
investor_fund_data = InvestorFundData(
|
||||
# Investor fields
|
||||
investor_id=investor.id,
|
||||
investor_name=investor.name,
|
||||
investor_description=investor.description,
|
||||
investor_website=investor.website,
|
||||
investor_headquarters=investor.headquarters,
|
||||
aum=investor.aum,
|
||||
aum_as_of_date=investor.aum_as_of_date,
|
||||
aum_source_url=investor.aum_source_url,
|
||||
investment_thesis=investor.investment_thesis,
|
||||
portfolio_highlights=investor.portfolio_highlights,
|
||||
number_of_investments=investor.number_of_investments,
|
||||
# Fund fields
|
||||
fund_id=fund.id,
|
||||
fund_name=fund.fund_name,
|
||||
fund_size=fund.fund_size,
|
||||
fund_size_source_url=fund.fund_size_source_url,
|
||||
check_size_lower=fund.check_size_lower,
|
||||
check_size_upper=fund.check_size_upper,
|
||||
geographic_focus=fund.geographic_focus,
|
||||
fund_investment_stages=fund.investment_stages, # Now a relationship
|
||||
fund_sectors=fund.sectors, # Now a relationship
|
||||
# Related data
|
||||
portfolio_companies=investor.portfolio_companies,
|
||||
team_members=investor.team_members,
|
||||
sectors=investor.sectors,
|
||||
)
|
||||
investor_fund_list.append(investor_fund_data)
|
||||
else:
|
||||
# If no funds, create one entry with null fund fields
|
||||
investor_fund_data = InvestorFundData(
|
||||
# Investor fields
|
||||
investor_id=investor.id,
|
||||
investor_name=investor.name,
|
||||
investor_description=investor.description,
|
||||
investor_website=investor.website,
|
||||
investor_headquarters=investor.headquarters,
|
||||
aum=investor.aum,
|
||||
aum_as_of_date=investor.aum_as_of_date,
|
||||
aum_source_url=investor.aum_source_url,
|
||||
investment_thesis=investor.investment_thesis,
|
||||
portfolio_highlights=investor.portfolio_highlights,
|
||||
number_of_investments=investor.number_of_investments,
|
||||
# Fund fields (null)
|
||||
fund_id=None,
|
||||
fund_name=None,
|
||||
fund_size=None,
|
||||
fund_size_source_url=None,
|
||||
check_size_lower=None,
|
||||
check_size_upper=None,
|
||||
geographic_focus=None,
|
||||
fund_investment_stages=None,
|
||||
fund_sectors=None,
|
||||
# Related data
|
||||
portfolio_companies=investor.portfolio_companies,
|
||||
team_members=investor.team_members,
|
||||
sectors=investor.sectors,
|
||||
)
|
||||
investor_fund_list.append(investor_fund_data)
|
||||
# Calculate offset and apply pagination
|
||||
offset = (page - 1) * page_size
|
||||
funds = query.offset(offset).limit(page_size).all()
|
||||
|
||||
return investor_fund_list
|
||||
# Transform to InvestmentResponse format (one row per fund)
|
||||
investment_responses = []
|
||||
for fund in funds:
|
||||
investor = fund.investor
|
||||
|
||||
# Get top 3 portfolio companies (id and name only)
|
||||
portfolio_companies = [
|
||||
CompanyMinimal(id=company.id, name=company.name)
|
||||
for company in investor.portfolio_companies[:3]
|
||||
]
|
||||
|
||||
# Get stage focus as comma-separated string
|
||||
stage_focus = (
|
||||
", ".join([stage.name for stage in fund.investment_stages])
|
||||
if fund.investment_stages
|
||||
else None
|
||||
)
|
||||
|
||||
# Get top 3 sectors from fund (id and name only)
|
||||
fund_sectors = [
|
||||
SectorMinimal(id=sector.id, name=sector.name)
|
||||
for sector in (fund.sectors[:3] if fund.sectors else [])
|
||||
]
|
||||
|
||||
investment_response = InvestmentResponse(
|
||||
id=investor.id,
|
||||
name=f"{investor.name} - {fund.fund_name}"
|
||||
if fund.fund_name
|
||||
else investor.name,
|
||||
aum=investor.aum,
|
||||
check_size_lower=fund.check_size_lower,
|
||||
check_size_upper=fund.check_size_upper,
|
||||
geographic_focus=fund.geographic_focus,
|
||||
stage_focus=stage_focus,
|
||||
portfolio_companies=portfolio_companies,
|
||||
sectors=fund_sectors,
|
||||
compatibility_score=1.0,
|
||||
)
|
||||
investment_responses.append(investment_response)
|
||||
|
||||
# Calculate total pages
|
||||
total_pages = (total_count + page_size - 1) // page_size
|
||||
|
||||
return PaginatedResponse(
|
||||
items=investment_responses,
|
||||
total=total_count,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/investors/{investor_id}", response_model=InvestorData)
|
||||
@@ -365,25 +373,32 @@ def delete_investor(investor_id: int, db: Session = Depends(get_db)):
|
||||
return {"message": "Investor deleted successfully"}
|
||||
|
||||
|
||||
@router.get("/investors/{investor_id}/similar", response_model=List[InvestorFundData])
|
||||
@router.get(
|
||||
"/investors/{investor_id}/similar",
|
||||
response_model=PaginatedResponse[InvestmentResponse],
|
||||
)
|
||||
def find_similar_investors(
|
||||
investor_id: int,
|
||||
limit: int = Query(10, description="Maximum number of similar investors to return"),
|
||||
page: int = Query(1, ge=1, description="Page number (starts at 1)"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Find investors similar to a given investor based on characteristics
|
||||
"""Find investors similar to a given investor based on characteristics (paginated)
|
||||
|
||||
Returns investor-fund combinations as separate rows.
|
||||
Queries the funds table to find matching funds.
|
||||
"""
|
||||
|
||||
# Get the target investor
|
||||
# Get the target investor to get their funds for comparison
|
||||
target_investor = (
|
||||
db.query(InvestorTable)
|
||||
.options(
|
||||
selectinload(InvestorTable.portfolio_companies),
|
||||
selectinload(InvestorTable.team_members),
|
||||
selectinload(InvestorTable.sectors),
|
||||
selectinload(InvestorTable.funds),
|
||||
selectinload(InvestorTable.funds).selectinload(FundTable.investment_stages),
|
||||
selectinload(InvestorTable.funds).selectinload(FundTable.sectors),
|
||||
)
|
||||
.filter(InvestorTable.id == investor_id)
|
||||
.first()
|
||||
@@ -392,154 +407,149 @@ def find_similar_investors(
|
||||
if not target_investor:
|
||||
raise HTTPException(status_code=404, detail="Investor not found")
|
||||
|
||||
# Get target investor's sector IDs for comparison
|
||||
target_sector_ids = {sector.id for sector in target_investor.sectors}
|
||||
# Get target investor's sector IDs for comparison (from their funds)
|
||||
target_sector_ids = set()
|
||||
target_stage_ids = set()
|
||||
target_check_ranges = []
|
||||
target_geographies = []
|
||||
|
||||
# Query all other investors with their relationships
|
||||
candidates = (
|
||||
db.query(InvestorTable)
|
||||
for fund in target_investor.funds:
|
||||
if fund.sectors:
|
||||
target_sector_ids.update({sector.id for sector in fund.sectors})
|
||||
if fund.investment_stages:
|
||||
target_stage_ids.update({stage.id for stage in fund.investment_stages})
|
||||
if fund.check_size_lower and fund.check_size_upper:
|
||||
target_check_ranges.append((fund.check_size_lower, fund.check_size_upper))
|
||||
if fund.geographic_focus:
|
||||
target_geographies.append(fund.geographic_focus.lower())
|
||||
|
||||
# Query all funds from other investors
|
||||
candidate_funds = (
|
||||
db.query(FundTable)
|
||||
.options(
|
||||
selectinload(InvestorTable.portfolio_companies),
|
||||
selectinload(InvestorTable.team_members),
|
||||
selectinload(InvestorTable.sectors),
|
||||
selectinload(InvestorTable.funds),
|
||||
selectinload(FundTable.investor).selectinload(
|
||||
InvestorTable.portfolio_companies
|
||||
),
|
||||
selectinload(FundTable.investor).selectinload(InvestorTable.team_members),
|
||||
selectinload(FundTable.investor).selectinload(InvestorTable.sectors),
|
||||
selectinload(FundTable.investment_stages),
|
||||
selectinload(FundTable.sectors),
|
||||
)
|
||||
.join(FundTable.investor)
|
||||
.filter(InvestorTable.id != investor_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Calculate similarity scores
|
||||
scored_investors = []
|
||||
for candidate in candidates:
|
||||
# Calculate similarity scores for each fund
|
||||
scored_funds = []
|
||||
for fund in candidate_funds:
|
||||
score = 0
|
||||
|
||||
# Stage focus match (30 points)
|
||||
if candidate.stage_focus == target_investor.stage_focus:
|
||||
score += 30
|
||||
|
||||
# Geographic focus match (20 points for exact, 10 for partial)
|
||||
if candidate.geographic_focus and target_investor.geographic_focus:
|
||||
if (
|
||||
candidate.geographic_focus.lower()
|
||||
== target_investor.geographic_focus.lower()
|
||||
):
|
||||
score += 20
|
||||
elif (
|
||||
candidate.geographic_focus.lower()
|
||||
in target_investor.geographic_focus.lower()
|
||||
or target_investor.geographic_focus.lower()
|
||||
in candidate.geographic_focus.lower()
|
||||
):
|
||||
score += 10
|
||||
if fund.geographic_focus and target_geographies:
|
||||
fund_geo_lower = fund.geographic_focus.lower()
|
||||
for target_geo in target_geographies:
|
||||
if fund_geo_lower == target_geo:
|
||||
score += 20
|
||||
break
|
||||
elif fund_geo_lower in target_geo or target_geo in fund_geo_lower:
|
||||
score += 10
|
||||
break
|
||||
|
||||
# Check size overlap (20 points max)
|
||||
if (
|
||||
candidate.check_size_lower
|
||||
and candidate.check_size_upper
|
||||
and target_investor.check_size_lower
|
||||
and target_investor.check_size_upper
|
||||
):
|
||||
# Calculate overlap percentage
|
||||
overlap_start = max(
|
||||
candidate.check_size_lower, target_investor.check_size_lower
|
||||
)
|
||||
overlap_end = min(
|
||||
candidate.check_size_upper, target_investor.check_size_upper
|
||||
)
|
||||
if overlap_end > overlap_start:
|
||||
overlap = overlap_end - overlap_start
|
||||
target_range = (
|
||||
target_investor.check_size_upper - target_investor.check_size_lower
|
||||
)
|
||||
overlap_ratio = overlap / target_range if target_range > 0 else 0
|
||||
score += int(20 * overlap_ratio)
|
||||
if fund.check_size_lower and fund.check_size_upper and target_check_ranges:
|
||||
max_overlap_score = 0
|
||||
for target_lower, target_upper in target_check_ranges:
|
||||
overlap_start = max(fund.check_size_lower, target_lower)
|
||||
overlap_end = min(fund.check_size_upper, target_upper)
|
||||
if overlap_end > overlap_start:
|
||||
overlap = overlap_end - overlap_start
|
||||
target_range = target_upper - target_lower
|
||||
overlap_ratio = overlap / target_range if target_range > 0 else 0
|
||||
max_overlap_score = max(max_overlap_score, int(20 * overlap_ratio))
|
||||
score += max_overlap_score
|
||||
|
||||
# AUM similarity (15 points max)
|
||||
if candidate.aum and target_investor.aum:
|
||||
aum_diff = abs(candidate.aum - target_investor.aum)
|
||||
max_aum = max(candidate.aum, target_investor.aum)
|
||||
if fund.investor.aum and target_investor.aum:
|
||||
aum_diff = abs(fund.investor.aum - target_investor.aum)
|
||||
max_aum = max(fund.investor.aum, target_investor.aum)
|
||||
similarity_ratio = 1 - (aum_diff / max_aum) if max_aum > 0 else 0
|
||||
score += int(15 * similarity_ratio)
|
||||
|
||||
# Sector overlap (30 points max)
|
||||
candidate_sector_ids = {sector.id for sector in candidate.sectors}
|
||||
if target_sector_ids and candidate_sector_ids:
|
||||
common_sectors = target_sector_ids.intersection(candidate_sector_ids)
|
||||
if fund.sectors and target_sector_ids:
|
||||
fund_sector_ids = {sector.id for sector in fund.sectors}
|
||||
common_sectors = target_sector_ids.intersection(fund_sector_ids)
|
||||
overlap_ratio = len(common_sectors) / len(target_sector_ids)
|
||||
score += int(30 * overlap_ratio)
|
||||
|
||||
if score > 0: # Only include investors with some similarity
|
||||
scored_investors.append((score, candidate))
|
||||
# Investment stage match (15 points max)
|
||||
if fund.investment_stages and target_stage_ids:
|
||||
fund_stage_ids = {stage.id for stage in fund.investment_stages}
|
||||
common_stages = target_stage_ids.intersection(fund_stage_ids)
|
||||
overlap_ratio = len(common_stages) / len(target_stage_ids)
|
||||
score += int(15 * overlap_ratio)
|
||||
|
||||
# Sort by score (descending) and take top N
|
||||
scored_investors.sort(key=lambda x: x[0], reverse=True)
|
||||
similar_investors = [inv for score, inv in scored_investors[:limit]]
|
||||
if score > 0: # Only include funds with some similarity
|
||||
scored_funds.append((score, fund))
|
||||
|
||||
# Transform to InvestorFundData format (one row per investor-fund combination)
|
||||
investor_fund_list = []
|
||||
for investor in similar_investors:
|
||||
# If investor has funds, create one entry per fund
|
||||
if investor.funds:
|
||||
for fund in investor.funds:
|
||||
investor_fund_data = InvestorFundData(
|
||||
# Investor fields
|
||||
investor_id=investor.id,
|
||||
investor_name=investor.name,
|
||||
investor_description=investor.description,
|
||||
investor_website=investor.website,
|
||||
investor_headquarters=investor.headquarters,
|
||||
aum=investor.aum,
|
||||
aum_as_of_date=investor.aum_as_of_date,
|
||||
aum_source_url=investor.aum_source_url,
|
||||
investment_thesis=investor.investment_thesis,
|
||||
portfolio_highlights=investor.portfolio_highlights,
|
||||
number_of_investments=investor.number_of_investments,
|
||||
# Fund fields
|
||||
fund_id=fund.id,
|
||||
fund_name=fund.fund_name,
|
||||
fund_size=fund.fund_size,
|
||||
fund_size_source_url=fund.fund_size_source_url,
|
||||
check_size_lower=fund.check_size_lower,
|
||||
check_size_upper=fund.check_size_upper,
|
||||
geographic_focus=fund.geographic_focus,
|
||||
fund_investment_stages=fund.investment_stages, # Now a relationship
|
||||
fund_sectors=fund.sectors, # Now a relationship
|
||||
# Related data
|
||||
portfolio_companies=investor.portfolio_companies,
|
||||
team_members=investor.team_members,
|
||||
sectors=investor.sectors,
|
||||
)
|
||||
investor_fund_list.append(investor_fund_data)
|
||||
else:
|
||||
# If no funds, create one entry with null fund fields
|
||||
investor_fund_data = InvestorFundData(
|
||||
# Investor fields
|
||||
investor_id=investor.id,
|
||||
investor_name=investor.name,
|
||||
investor_description=investor.description,
|
||||
investor_website=investor.website,
|
||||
investor_headquarters=investor.headquarters,
|
||||
aum=investor.aum,
|
||||
aum_as_of_date=investor.aum_as_of_date,
|
||||
aum_source_url=investor.aum_source_url,
|
||||
investment_thesis=investor.investment_thesis,
|
||||
portfolio_highlights=investor.portfolio_highlights,
|
||||
number_of_investments=investor.number_of_investments,
|
||||
# Fund fields (null)
|
||||
fund_id=None,
|
||||
fund_name=None,
|
||||
fund_size=None,
|
||||
fund_size_source_url=None,
|
||||
check_size_lower=None,
|
||||
check_size_upper=None,
|
||||
geographic_focus=None,
|
||||
fund_investment_stages=None,
|
||||
fund_sectors=None,
|
||||
# Related data
|
||||
portfolio_companies=investor.portfolio_companies,
|
||||
team_members=investor.team_members,
|
||||
sectors=investor.sectors,
|
||||
)
|
||||
investor_fund_list.append(investor_fund_data)
|
||||
# Sort by score (descending) and take top N based on limit
|
||||
scored_funds.sort(key=lambda x: x[0], reverse=True)
|
||||
top_similar = scored_funds[:limit]
|
||||
|
||||
return investor_fund_list
|
||||
# Apply pagination to the top similar funds
|
||||
total_count = len(top_similar)
|
||||
offset = (page - 1) * page_size
|
||||
paginated_similar = top_similar[offset : offset + page_size]
|
||||
similar_funds = [fund for score, fund in paginated_similar]
|
||||
|
||||
# Transform to InvestmentResponse format (one row per fund)
|
||||
investment_responses = []
|
||||
for fund in similar_funds:
|
||||
investor = fund.investor
|
||||
|
||||
# Get top 3 portfolio companies (id and name only)
|
||||
portfolio_companies = [
|
||||
CompanyMinimal(id=company.id, name=company.name)
|
||||
for company in investor.portfolio_companies[:3]
|
||||
]
|
||||
|
||||
# Get stage focus as comma-separated string
|
||||
stage_focus = (
|
||||
", ".join([stage.name for stage in fund.investment_stages])
|
||||
if fund.investment_stages
|
||||
else None
|
||||
)
|
||||
|
||||
# Get top 3 sectors from fund (id and name only)
|
||||
fund_sectors = [
|
||||
SectorMinimal(id=sector.id, name=sector.name)
|
||||
for sector in (fund.sectors[:3] if fund.sectors else [])
|
||||
]
|
||||
|
||||
investment_response = InvestmentResponse(
|
||||
id=investor.id,
|
||||
name=f"{investor.name} - {fund.fund_name}"
|
||||
if fund.fund_name
|
||||
else investor.name,
|
||||
aum=investor.aum,
|
||||
check_size_lower=fund.check_size_lower,
|
||||
check_size_upper=fund.check_size_upper,
|
||||
geographic_focus=fund.geographic_focus,
|
||||
stage_focus=stage_focus,
|
||||
portfolio_companies=portfolio_companies,
|
||||
sectors=fund_sectors,
|
||||
compatibility_score=1.0,
|
||||
)
|
||||
investment_responses.append(investment_response)
|
||||
|
||||
# Calculate total pages
|
||||
total_pages = (total_count + page_size - 1) // page_size
|
||||
|
||||
return PaginatedResponse(
|
||||
items=investment_responses,
|
||||
total=total_count,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
|
||||
+47
-8
@@ -14,14 +14,26 @@ from schemas.project_schemas import (
|
||||
ProjectData,
|
||||
ProjectUpdate,
|
||||
)
|
||||
from schemas.router_schemas import PaginatedResponse
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
router = APIRouter(tags=["Project Routes"])
|
||||
|
||||
|
||||
@router.get("/projects", response_model=List[ProjectData])
|
||||
def read_projects(db: Session = Depends(get_db)):
|
||||
"""Get all projects with their related data"""
|
||||
@router.get("/projects", response_model=PaginatedResponse[ProjectData])
|
||||
def read_projects(
|
||||
page: int = Query(1, ge=1, description="Page number (starts at 1)"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get all projects with their related data (paginated)"""
|
||||
# Calculate offset
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# Get total count
|
||||
total_count = db.query(ProjectTable).count()
|
||||
|
||||
# Get paginated results
|
||||
projects = (
|
||||
db.query(ProjectTable)
|
||||
.options(
|
||||
@@ -29,6 +41,8 @@ def read_projects(db: Session = Depends(get_db)):
|
||||
selectinload(ProjectTable.investors),
|
||||
selectinload(ProjectTable.companies),
|
||||
)
|
||||
.offset(offset)
|
||||
.limit(page_size)
|
||||
.all()
|
||||
)
|
||||
|
||||
@@ -43,7 +57,16 @@ def read_projects(db: Session = Depends(get_db)):
|
||||
)
|
||||
project_data_list.append(project_data)
|
||||
|
||||
return project_data_list
|
||||
# Calculate total pages
|
||||
total_pages = (total_count + page_size - 1) // page_size
|
||||
|
||||
return PaginatedResponse(
|
||||
items=project_data_list,
|
||||
total=total_count,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/projects/{project_id}", response_model=ProjectData)
|
||||
@@ -151,7 +174,7 @@ def delete_project(project_id: int, db: Session = Depends(get_db)):
|
||||
return {"message": "Project deleted successfully"}
|
||||
|
||||
|
||||
@router.get("/projects/filter", response_model=List[ProjectData])
|
||||
@router.get("/projects/filter", response_model=PaginatedResponse[ProjectData])
|
||||
def filter_projects(
|
||||
stage: Optional[InvestmentStage] = Query(
|
||||
None, description="Filter by project stage"
|
||||
@@ -166,9 +189,11 @@ def filter_projects(
|
||||
company_name: Optional[str] = Query(
|
||||
None, description="Company name (partial match)"
|
||||
),
|
||||
page: int = Query(1, ge=1, description="Page number (starts at 1)"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Filter projects based on various criteria"""
|
||||
"""Filter projects based on various criteria (paginated)"""
|
||||
|
||||
# Start with base query
|
||||
query = db.query(ProjectTable).options(
|
||||
@@ -205,7 +230,12 @@ def filter_projects(
|
||||
CompanyTable.name.ilike(f"%{company_name}%")
|
||||
)
|
||||
|
||||
projects = query.all()
|
||||
# Get total count before pagination
|
||||
total_count = query.count()
|
||||
|
||||
# Calculate offset and apply pagination
|
||||
offset = (page - 1) * page_size
|
||||
projects = query.offset(offset).limit(page_size).all()
|
||||
|
||||
# Transform to ProjectData format
|
||||
project_data_list = []
|
||||
@@ -218,7 +248,16 @@ def filter_projects(
|
||||
)
|
||||
project_data_list.append(project_data)
|
||||
|
||||
return project_data_list
|
||||
# Calculate total pages
|
||||
total_pages = (total_count + page_size - 1) // page_size
|
||||
|
||||
return PaginatedResponse(
|
||||
items=project_data_list,
|
||||
total=total_count,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
|
||||
|
||||
# Association management routes
|
||||
|
||||
Binary file not shown.
@@ -1,9 +1,12 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
from typing import Any, Generic, List, Optional, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Generic type for pagination
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class InvestmentStage(str, Enum):
|
||||
SEED = "SEED"
|
||||
@@ -89,11 +92,20 @@ class InvestorSchema(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
description: Optional[str]
|
||||
website: Optional[str] = None
|
||||
headquarters: Optional[str] = None
|
||||
aum: int | None
|
||||
aum_as_of_date: str | None = None
|
||||
aum_source_url: str | None = None
|
||||
check_size_lower: int | None
|
||||
check_size_upper: int | None
|
||||
geographic_focus: str | None
|
||||
stage_focus: InvestmentStage
|
||||
investment_thesis: Any = (
|
||||
None # Flexible JSON field - can be list, dict, or list of dicts
|
||||
)
|
||||
portfolio_highlights: Any = (
|
||||
None # Flexible JSON field - can be list, dict, or list of dicts
|
||||
)
|
||||
number_of_investments: int | None
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
@@ -131,8 +143,8 @@ class InvestorFundData(BaseModel):
|
||||
aum: int | None
|
||||
aum_as_of_date: str | None
|
||||
aum_source_url: str | None
|
||||
investment_thesis: List[str] | None
|
||||
portfolio_highlights: List[str] | None
|
||||
investment_thesis: Any = None # Flexible JSON field
|
||||
portfolio_highlights: Any = None # Flexible JSON field
|
||||
number_of_investments: int | None
|
||||
|
||||
# Fund fields
|
||||
@@ -156,12 +168,29 @@ class InvestorFundData(BaseModel):
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class InvestorMinimal(BaseModel):
|
||||
"""Minimal investor info with just id and name"""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class CompanySchemaMinimal(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
industry: str | None
|
||||
location: str | None
|
||||
founded_year: Optional[int]
|
||||
website: Optional[str]
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class CompanyData(BaseModel): # Renamed from CompaniesData for consistency
|
||||
company: CompanySchema
|
||||
sectors: List[SectorSchema]
|
||||
members: List[CompanyMemberSchema]
|
||||
investors: List[InvestorSchema]
|
||||
company: CompanySchemaMinimal
|
||||
investors: List[InvestorMinimal]
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
@@ -175,3 +204,59 @@ class InvestorFundList(BaseModel):
|
||||
"""List of investor-fund combinations"""
|
||||
|
||||
investor_funds: List[InvestorFundData]
|
||||
|
||||
|
||||
class CompanyMinimal(BaseModel):
|
||||
"""Minimal company info with just id and name"""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class SectorMinimal(BaseModel):
|
||||
"""Minimal sector info with just id and name"""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class InvestmentResponse(BaseModel):
|
||||
"""Simplified investment response schema
|
||||
|
||||
One row per investor-fund combination with streamlined data
|
||||
"""
|
||||
|
||||
id: int # Investor ID
|
||||
name: (
|
||||
str # Combination of investor name and fund name (e.g., "Investor A - Fund A")
|
||||
)
|
||||
aum: int | None # From investor
|
||||
check_size_lower: int | None # From fund
|
||||
check_size_upper: int | None # From fund
|
||||
geographic_focus: str | None # From fund
|
||||
stage_focus: str | None # Comma-separated stages from fund
|
||||
portfolio_companies: List[CompanyMinimal] # Top 3 companies from investor
|
||||
sectors: List[SectorMinimal] # Top 3 sectors from fund
|
||||
compatibility_score: float # 0 to 1 (default 1 for now)
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class PaginatedResponse(BaseModel, Generic[T]):
|
||||
"""Generic paginated response schema"""
|
||||
|
||||
items: List[T]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
total_pages: int
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
Binary file not shown.
Binary file not shown.
+180
-181
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
@@ -265,37 +266,20 @@ Return the lower and upper bounds in USD."""
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Process company profile from CSV data.
|
||||
Manually extracts fields without using LLM.
|
||||
Only extracts founded_year and key_executives - rest is in base database.
|
||||
"""
|
||||
profile = self.parse_json_profile(profile_json)
|
||||
if not profile:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Extract basic info
|
||||
# Only extract founded_year and key_executives
|
||||
company_data = {
|
||||
"name": name.strip() if name else None,
|
||||
"website": website.strip() if website else None,
|
||||
"description": profile.get("companyDescription"),
|
||||
"location": profile.get("geographicFocus"),
|
||||
"industry": profile.get("sectorDescription"),
|
||||
"founded_year": None, # Not typically in the company JSON
|
||||
"founded_year": None,
|
||||
"key_executives": [],
|
||||
"client_categories": profile.get("clientCategories", []),
|
||||
"product_description": profile.get("productDescription"),
|
||||
"linked_documents": profile.get("linkedDocuments", []),
|
||||
"researcher_notes": profile.get("researcherNotes"),
|
||||
"missing_important_fields": profile.get("missingImportantFields", []),
|
||||
"sources": profile.get("sources", {}),
|
||||
"investor_names": [],
|
||||
}
|
||||
|
||||
# Parse investor names from the Investor column
|
||||
if investor_names and pd.notna(investor_names):
|
||||
# Split by comma and clean
|
||||
investors = [inv.strip() for inv in str(investor_names).split(",")]
|
||||
company_data["investor_names"] = [inv for inv in investors if inv]
|
||||
|
||||
# Process key executives/leadership
|
||||
key_executives = profile.get("keyExecutives", [])
|
||||
if not key_executives:
|
||||
@@ -313,7 +297,7 @@ Return the lower and upper bounds in USD."""
|
||||
)
|
||||
|
||||
# Try to extract founding year from description
|
||||
description = company_data.get("description", "")
|
||||
description = profile.get("companyDescription", "")
|
||||
if description:
|
||||
# Look for patterns like "founded in 2020", "Gegründet 2020", "founded 2020"
|
||||
year_patterns = [
|
||||
@@ -344,41 +328,28 @@ Return the lower and upper bounds in USD."""
|
||||
def _save_parsed_company_to_db(
|
||||
self, db: Session, company_data: dict
|
||||
) -> Optional[CompanyTable]:
|
||||
"""Save manually parsed company data to database"""
|
||||
"""Save manually parsed company data to database - only updates founded_year and key_executives"""
|
||||
try:
|
||||
# Check if company already exists
|
||||
# Check if company already exists (should exist in base database)
|
||||
existing_company = (
|
||||
db.query(CompanyTable).filter_by(name=company_data["name"]).first()
|
||||
)
|
||||
|
||||
if existing_company:
|
||||
# Update existing company
|
||||
# Update only founded_year on existing company
|
||||
company = existing_company
|
||||
company.website = company_data.get("website") or company.website
|
||||
company.location = company_data.get("location") or company.location
|
||||
company.description = (
|
||||
company_data.get("description") or company.description
|
||||
)
|
||||
company.industry = company_data.get("industry") or company.industry
|
||||
if company_data.get("founded_year"):
|
||||
company.founded_year = company_data["founded_year"]
|
||||
else:
|
||||
# Create new company
|
||||
company = CompanyTable(
|
||||
name=company_data["name"],
|
||||
website=company_data.get("website"),
|
||||
location=company_data.get("location"),
|
||||
description=company_data.get("description"),
|
||||
industry=company_data.get("industry"),
|
||||
founded_year=company_data.get("founded_year"),
|
||||
# Company should already be in base database, but if not found, skip
|
||||
print(
|
||||
f"⚠️ Company '{company_data['name']}' not found in base database - skipping"
|
||||
)
|
||||
db.add(company)
|
||||
db.flush()
|
||||
return None
|
||||
|
||||
# Add/update company members (key executives)
|
||||
# First, remove existing members if updating
|
||||
if existing_company:
|
||||
db.query(CompanyMember).filter_by(company_id=company.id).delete()
|
||||
db.query(CompanyMember).filter_by(company_id=company.id).delete()
|
||||
|
||||
for exec_data in company_data.get("key_executives", []):
|
||||
member = CompanyMember(
|
||||
@@ -391,19 +362,6 @@ Return the lower and upper bounds in USD."""
|
||||
)
|
||||
db.add(member)
|
||||
|
||||
# Link to investors if provided
|
||||
for investor_name in company_data.get("investor_names", []):
|
||||
# Find investor in database
|
||||
investor = (
|
||||
db.query(InvestorTable)
|
||||
.filter_by(name=investor_name.strip())
|
||||
.first()
|
||||
)
|
||||
if investor:
|
||||
# Add company to investor's portfolio if not already there
|
||||
if company not in investor.portfolio_companies:
|
||||
investor.portfolio_companies.append(company)
|
||||
|
||||
return company
|
||||
|
||||
except Exception as e:
|
||||
@@ -692,10 +650,55 @@ Return the lower and upper bounds in USD."""
|
||||
print(f"Error processing row {row_idx + 1}: {e}")
|
||||
return None
|
||||
|
||||
async def parse_investors(self, df: pd.DataFrame, save_to_db: bool = True):
|
||||
async def _process_single_investor(
|
||||
self, idx: int, row: pd.Series, total_rows: int
|
||||
) -> Optional[dict]:
|
||||
"""Process a single investor row"""
|
||||
try:
|
||||
name = row.get("Name", "").strip() if pd.notna(row.get("Name")) else None
|
||||
website = (
|
||||
row.get("Website", "").strip() if pd.notna(row.get("Website")) else None
|
||||
)
|
||||
profile_json = (
|
||||
row.get("Final Investor Profile", "")
|
||||
if pd.notna(row.get("Final Investor Profile"))
|
||||
else None
|
||||
)
|
||||
|
||||
if not name or not profile_json:
|
||||
print(f"⚠️ Row {idx + 1}: Skipping - missing name or profile")
|
||||
return None
|
||||
|
||||
print(f"📊 Processing {idx + 1}/{total_rows}: {name}")
|
||||
|
||||
# Process the investor profile
|
||||
investor_data = await self.process_investor_profile(
|
||||
name, website, profile_json
|
||||
)
|
||||
|
||||
if investor_data:
|
||||
print(f" ✓ {name} parsed successfully")
|
||||
return investor_data
|
||||
else:
|
||||
print(f" ⚠️ {name} failed to process")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error processing row {idx + 1}: {e}")
|
||||
return None
|
||||
|
||||
async def parse_investors(
|
||||
self, df: pd.DataFrame, save_to_db: bool = True, batch_size: int = 10
|
||||
):
|
||||
"""
|
||||
Parse investors from DataFrame using manual JSON parsing and LLM for currency conversion.
|
||||
Processes multiple investors concurrently for better performance.
|
||||
Expected CSV columns: Name, Website, Final Investor Profile, Final Profile sourcing
|
||||
|
||||
Args:
|
||||
df: DataFrame with investor data
|
||||
save_to_db: Whether to save to database
|
||||
batch_size: Number of investors to process concurrently (default: 10)
|
||||
"""
|
||||
results = []
|
||||
db = None
|
||||
@@ -704,50 +707,31 @@ Return the lower and upper bounds in USD."""
|
||||
|
||||
try:
|
||||
total_rows = len(df)
|
||||
print(f"\n🚀 Starting to process {total_rows} investors...")
|
||||
print(
|
||||
f"\n🚀 Starting to process {total_rows} investors with batch size {batch_size}..."
|
||||
)
|
||||
|
||||
for idx, row in df.iterrows():
|
||||
try:
|
||||
name = (
|
||||
row.get("Name", "").strip()
|
||||
if pd.notna(row.get("Name"))
|
||||
else None
|
||||
)
|
||||
website = (
|
||||
row.get("Website", "").strip()
|
||||
if pd.notna(row.get("Website"))
|
||||
else None
|
||||
)
|
||||
profile_json = (
|
||||
row.get("Final Investor Profile", "")
|
||||
if pd.notna(row.get("Final Investor Profile"))
|
||||
else None
|
||||
)
|
||||
# Process in batches
|
||||
for batch_start in range(0, total_rows, batch_size):
|
||||
batch_end = min(batch_start + batch_size, total_rows)
|
||||
print(
|
||||
f"\n🔄 Processing batch {batch_start + 1}-{batch_end} of {total_rows}..."
|
||||
)
|
||||
|
||||
if not name or not profile_json:
|
||||
print(f"⚠️ Row {idx + 1}: Skipping - missing name or profile")
|
||||
continue
|
||||
# Create tasks for concurrent processing
|
||||
tasks = []
|
||||
for idx in range(batch_start, batch_end):
|
||||
row = df.iloc[idx]
|
||||
task = self._process_single_investor(idx, row, total_rows)
|
||||
tasks.append(task)
|
||||
|
||||
print(f"\n📊 Processing {idx + 1}/{total_rows}: {name}")
|
||||
# Process batch concurrently
|
||||
batch_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Process the investor profile
|
||||
investor_data = await self.process_investor_profile(
|
||||
name, website, profile_json
|
||||
)
|
||||
|
||||
if investor_data:
|
||||
# Filter out None results and exceptions, then save to database
|
||||
for investor_data in batch_results:
|
||||
if investor_data and not isinstance(investor_data, Exception):
|
||||
results.append(investor_data)
|
||||
print(" ✓ Parsed successfully")
|
||||
print(f" - HQ: {investor_data.get('headquarters')}")
|
||||
print(
|
||||
f" - AUM: ${investor_data.get('aum'):,}"
|
||||
if investor_data.get("aum")
|
||||
else " - AUM: Not Available"
|
||||
)
|
||||
print(f" - Funds: {len(investor_data.get('funds', []))}")
|
||||
print(
|
||||
f" - Team: {len(investor_data.get('team_members', []))}"
|
||||
)
|
||||
|
||||
# Save to database
|
||||
if save_to_db and db:
|
||||
@@ -756,33 +740,29 @@ Return the lower and upper bounds in USD."""
|
||||
db, investor_data
|
||||
)
|
||||
if saved_investor:
|
||||
db.commit()
|
||||
print(
|
||||
f" ✅ Saved to database (ID: {saved_investor.id})"
|
||||
f" ✅ Saved {investor_data['name']} to database (ID: {saved_investor.id})"
|
||||
)
|
||||
else:
|
||||
print(" ❌ Failed to save to database")
|
||||
print(
|
||||
f" ❌ Failed to save {investor_data['name']} to database"
|
||||
)
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
print(f" ❌ Database error: {e}")
|
||||
else:
|
||||
print(" ⚠️ Failed to process profile")
|
||||
print(
|
||||
f" ❌ Database error for {investor_data['name']}: {e}"
|
||||
)
|
||||
elif isinstance(investor_data, Exception):
|
||||
print(f" ❌ Exception occurred: {investor_data}")
|
||||
|
||||
# Commit every 10 investors to avoid memory issues
|
||||
if save_to_db and db and (idx + 1) % 10 == 0:
|
||||
# Commit batch to database
|
||||
if save_to_db and db:
|
||||
try:
|
||||
db.commit()
|
||||
print(f"\n💾 Committed batch at row {idx + 1}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error processing row {idx + 1}: {e}")
|
||||
if db:
|
||||
print(f"💾 Committed batch {batch_start + 1}-{batch_end}")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
continue
|
||||
|
||||
# Final commit
|
||||
if save_to_db and db:
|
||||
db.commit()
|
||||
print("\n✅ Final commit completed")
|
||||
print(f"❌ Failed to commit batch: {e}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Fatal error in parse_investors: {e}")
|
||||
@@ -795,10 +775,60 @@ Return the lower and upper bounds in USD."""
|
||||
print(f"\n🎉 Completed! Processed {len(results)}/{total_rows} investors")
|
||||
return results
|
||||
|
||||
async def parse_companies(self, df: pd.DataFrame, save_to_db: bool = True):
|
||||
async def _process_single_company(
|
||||
self, idx: int, row: pd.Series, total_rows: int
|
||||
) -> Optional[dict]:
|
||||
"""Process a single company row"""
|
||||
try:
|
||||
name = row.get("Name", "").strip() if pd.notna(row.get("Name")) else None
|
||||
website = (
|
||||
row.get("Website", "").strip() if pd.notna(row.get("Website")) else None
|
||||
)
|
||||
investor_names = (
|
||||
row.get("Investor", "").strip()
|
||||
if pd.notna(row.get("Investor"))
|
||||
else None
|
||||
)
|
||||
profile_json = (
|
||||
row.get("Final Investor Profile", "")
|
||||
if pd.notna(row.get("Final Investor Profile"))
|
||||
else None
|
||||
)
|
||||
|
||||
if not name or not profile_json:
|
||||
print(f"⚠️ Row {idx + 1}: Skipping - missing name or profile")
|
||||
return None
|
||||
|
||||
print(f"📊 Processing {idx + 1}/{total_rows}: {name}")
|
||||
|
||||
# Process the company profile
|
||||
company_data = await self.process_company_profile(
|
||||
name, website, profile_json, investor_names
|
||||
)
|
||||
|
||||
if company_data:
|
||||
print(f" ✓ {name} parsed successfully")
|
||||
return company_data
|
||||
else:
|
||||
print(f" ⚠️ {name} failed to process")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error processing row {idx + 1}: {e}")
|
||||
return None
|
||||
|
||||
async def parse_companies(
|
||||
self, df: pd.DataFrame, save_to_db: bool = True, batch_size: int = 10
|
||||
):
|
||||
"""
|
||||
Parse companies from DataFrame using manual JSON parsing.
|
||||
Processes multiple companies concurrently for better performance.
|
||||
Expected CSV columns: Name, Website, Investor, Final Investor Profile (actually company profile)
|
||||
|
||||
Args:
|
||||
df: DataFrame with company data
|
||||
save_to_db: Whether to save to database
|
||||
batch_size: Number of companies to process concurrently (default: 10)
|
||||
"""
|
||||
results = []
|
||||
db = None
|
||||
@@ -807,58 +837,31 @@ Return the lower and upper bounds in USD."""
|
||||
|
||||
try:
|
||||
total_rows = len(df)
|
||||
print(f"\n🚀 Starting to process {total_rows} companies...")
|
||||
print(
|
||||
f"\n🚀 Starting to process {total_rows} companies with batch size {batch_size}..."
|
||||
)
|
||||
|
||||
for idx, row in df.iterrows():
|
||||
try:
|
||||
name = (
|
||||
row.get("Name", "").strip()
|
||||
if pd.notna(row.get("Name"))
|
||||
else None
|
||||
)
|
||||
website = (
|
||||
row.get("Website", "").strip()
|
||||
if pd.notna(row.get("Website"))
|
||||
else None
|
||||
)
|
||||
investor_names = (
|
||||
row.get("Investor", "").strip()
|
||||
if pd.notna(row.get("Investor"))
|
||||
else None
|
||||
)
|
||||
profile_json = (
|
||||
row.get("Final Investor Profile", "")
|
||||
if pd.notna(row.get("Final Investor Profile"))
|
||||
else None
|
||||
)
|
||||
# Process in batches
|
||||
for batch_start in range(0, total_rows, batch_size):
|
||||
batch_end = min(batch_start + batch_size, total_rows)
|
||||
print(
|
||||
f"\n🔄 Processing batch {batch_start + 1}-{batch_end} of {total_rows}..."
|
||||
)
|
||||
|
||||
if not name or not profile_json:
|
||||
print(f"⚠️ Row {idx + 1}: Skipping - missing name or profile")
|
||||
continue
|
||||
# Create tasks for concurrent processing
|
||||
tasks = []
|
||||
for idx in range(batch_start, batch_end):
|
||||
row = df.iloc[idx]
|
||||
task = self._process_single_company(idx, row, total_rows)
|
||||
tasks.append(task)
|
||||
|
||||
print(f"\n📊 Processing {idx + 1}/{total_rows}: {name}")
|
||||
# Process batch concurrently
|
||||
batch_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Process the company profile
|
||||
company_data = await self.process_company_profile(
|
||||
name, website, profile_json, investor_names
|
||||
)
|
||||
|
||||
if company_data:
|
||||
# Filter out None results and exceptions, then save to database
|
||||
for company_data in batch_results:
|
||||
if company_data and not isinstance(company_data, Exception):
|
||||
results.append(company_data)
|
||||
print(" ✓ Parsed successfully")
|
||||
print(f" - Location: {company_data.get('location')}")
|
||||
print(f" - Industry: {company_data.get('industry')}")
|
||||
print(
|
||||
f" - Founded: {company_data.get('founded_year')}"
|
||||
if company_data.get("founded_year")
|
||||
else " - Founded: Unknown"
|
||||
)
|
||||
print(
|
||||
f" - Executives: {len(company_data.get('key_executives', []))}"
|
||||
)
|
||||
print(
|
||||
f" - Investors: {len(company_data.get('investor_names', []))}"
|
||||
)
|
||||
|
||||
# Save to database
|
||||
if save_to_db and db:
|
||||
@@ -867,33 +870,29 @@ Return the lower and upper bounds in USD."""
|
||||
db, company_data
|
||||
)
|
||||
if saved_company:
|
||||
db.commit()
|
||||
print(
|
||||
f" ✅ Saved to database (ID: {saved_company.id})"
|
||||
f" ✅ Saved {company_data['name']} to database (ID: {saved_company.id})"
|
||||
)
|
||||
else:
|
||||
print(" ❌ Failed to save to database")
|
||||
print(
|
||||
f" ❌ Failed to save {company_data['name']} to database"
|
||||
)
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
print(f" ❌ Database error: {e}")
|
||||
else:
|
||||
print(" ⚠️ Failed to process profile")
|
||||
print(
|
||||
f" ❌ Database error for {company_data['name']}: {e}"
|
||||
)
|
||||
elif isinstance(company_data, Exception):
|
||||
print(f" ❌ Exception occurred: {company_data}")
|
||||
|
||||
# Commit every 10 companies to avoid memory issues
|
||||
if save_to_db and db and (idx + 1) % 10 == 0:
|
||||
# Commit batch to database
|
||||
if save_to_db and db:
|
||||
try:
|
||||
db.commit()
|
||||
print(f"\n💾 Committed batch at row {idx + 1}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error processing row {idx + 1}: {e}")
|
||||
if db:
|
||||
print(f"💾 Committed batch {batch_start + 1}-{batch_end}")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
continue
|
||||
|
||||
# Final commit
|
||||
if save_to_db and db:
|
||||
db.commit()
|
||||
print("\n✅ Final commit completed")
|
||||
print(f"❌ Failed to commit batch: {e}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Fatal error in parse_companies: {e}")
|
||||
|
||||
+100
-43
@@ -2,13 +2,18 @@ import os
|
||||
from typing import List
|
||||
|
||||
from db.db import DATABASE_URL, get_db
|
||||
from db.models import InvestorTable
|
||||
from db.models import FundTable, InvestorTable
|
||||
from langchain import hub
|
||||
from langchain_community.agent_toolkits import SQLDatabaseToolkit
|
||||
from langchain_community.utilities import SQLDatabase
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
from schemas.py_schemas import InvestorData, InvestorList
|
||||
from schemas.router_schemas import (
|
||||
CompanyMinimal,
|
||||
InvestmentResponse,
|
||||
PaginatedResponse,
|
||||
SectorMinimal,
|
||||
)
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
# Connect to SQLite
|
||||
@@ -21,16 +26,16 @@ class QueryProcessor:
|
||||
self.llm = ChatOpenAI(
|
||||
api_key=os.getenv("OPENROUTER_API_KEY"),
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="openai/gpt-4o-mini",
|
||||
model="x-ai/grok-4-fast",
|
||||
temperature=0,
|
||||
)
|
||||
self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)
|
||||
# Update system message to specifically request only investor IDs
|
||||
# Update system message to specifically request only fund IDs
|
||||
system_message_updated = (
|
||||
prompt_template.format(dialect="SQLite", top_k=5)
|
||||
+ "\n\nIMPORTANT: You must ONLY return the investor IDs (id field) that match the user's criteria. "
|
||||
+ "\n\nIMPORTANT: You must ONLY return the fund IDs (id field from the funds table) that match the user's criteria. "
|
||||
+ "Do NOT return any other information, explanations, or data. "
|
||||
+ "Your response should be ONLY a comma-separated list of numbers representing the investor IDs. "
|
||||
+ "Your response should be ONLY a comma-separated list of numbers representing the fund IDs. "
|
||||
+ "Example format: 1, 5, 12, 23"
|
||||
)
|
||||
self.agent = create_react_agent(
|
||||
@@ -39,9 +44,9 @@ class QueryProcessor:
|
||||
prompt=system_message_updated,
|
||||
)
|
||||
|
||||
def process_query(self, question: str) -> InvestorList:
|
||||
"""Process a query using the LLM and return investor data."""
|
||||
# Let the LLM handle all database interactions and filtering to get IDs
|
||||
def process_query(self, question: str) -> PaginatedResponse[InvestmentResponse]:
|
||||
"""Process a query using the LLM and return investment response data."""
|
||||
# Let the LLM handle all database interactions and filtering to get fund IDs
|
||||
response = self.agent.invoke(
|
||||
{"messages": [("user", question)]},
|
||||
)
|
||||
@@ -51,70 +56,122 @@ class QueryProcessor:
|
||||
response["messages"][-1].content if response.get("messages") else ""
|
||||
)
|
||||
|
||||
# Extract investor IDs from the AI response
|
||||
investor_ids = self._extract_investor_ids_from_response(ai_response)
|
||||
# Extract fund IDs from the AI response
|
||||
fund_ids = self._extract_fund_ids_from_response(ai_response)
|
||||
|
||||
# Fetch full investor data using the IDs
|
||||
return self._fetch_investors_by_ids(investor_ids)
|
||||
# Fetch full fund data with investor relationships using the IDs
|
||||
return self._fetch_funds_by_ids(fund_ids)
|
||||
|
||||
def _extract_investor_ids_from_response(self, ai_response: str) -> List[int]:
|
||||
"""Extract investor IDs from AI response."""
|
||||
def _extract_fund_ids_from_response(self, ai_response: str) -> List[int]:
|
||||
"""Extract fund IDs from AI response."""
|
||||
import re
|
||||
|
||||
investor_ids = []
|
||||
fund_ids = []
|
||||
try:
|
||||
# Try multiple patterns to extract IDs from the response
|
||||
# Pattern 1: Simple numbers (assuming they are IDs)
|
||||
numbers = re.findall(r"\b\d+\b", ai_response)
|
||||
investor_ids = [int(num) for num in numbers]
|
||||
fund_ids = [int(num) for num in numbers]
|
||||
|
||||
# Pattern 2: If response contains explicit ID references
|
||||
id_matches = re.findall(r"\bid[:\s]*(\d+)", ai_response.lower())
|
||||
if id_matches:
|
||||
investor_ids = [int(id_str) for id_str in id_matches]
|
||||
fund_ids = [int(id_str) for id_str in id_matches]
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error extracting IDs from response: {e}")
|
||||
return []
|
||||
|
||||
return investor_ids
|
||||
return fund_ids
|
||||
|
||||
def _fetch_investors_by_ids(self, investor_ids: List[int]) -> InvestorList:
|
||||
"""Fetch investors with all their relationships from the database using IDs."""
|
||||
if not investor_ids:
|
||||
return InvestorList(investors=[])
|
||||
def _fetch_funds_by_ids(
|
||||
self, fund_ids: List[int]
|
||||
) -> PaginatedResponse[InvestmentResponse]:
|
||||
"""Fetch funds with all their relationships from the database using fund IDs.
|
||||
Constructs response similar to read_investors but starting from funds."""
|
||||
if not fund_ids:
|
||||
return PaginatedResponse(
|
||||
items=[],
|
||||
total=0,
|
||||
page=1,
|
||||
page_size=len(fund_ids) if fund_ids else 10,
|
||||
total_pages=0,
|
||||
)
|
||||
|
||||
# Get database session
|
||||
db_session = next(get_db())
|
||||
|
||||
try:
|
||||
# Build query with all relationships loaded
|
||||
query = (
|
||||
db_session.query(InvestorTable)
|
||||
# Query funds with all necessary relationships loaded
|
||||
funds = (
|
||||
db_session.query(FundTable)
|
||||
.options(
|
||||
selectinload(InvestorTable.portfolio_companies),
|
||||
selectinload(InvestorTable.team_members),
|
||||
selectinload(InvestorTable.sectors),
|
||||
selectinload(InvestorTable.funds),
|
||||
selectinload(FundTable.investor).selectinload(
|
||||
InvestorTable.portfolio_companies
|
||||
),
|
||||
selectinload(FundTable.investor).selectinload(
|
||||
InvestorTable.team_members
|
||||
),
|
||||
selectinload(FundTable.investor).selectinload(
|
||||
InvestorTable.sectors
|
||||
),
|
||||
selectinload(FundTable.investment_stages),
|
||||
selectinload(FundTable.sectors),
|
||||
)
|
||||
.filter(InvestorTable.id.in_(investor_ids))
|
||||
.filter(FundTable.id.in_(fund_ids))
|
||||
.all()
|
||||
)
|
||||
|
||||
investors = query.all()
|
||||
# Transform to InvestmentResponse format (one row per fund)
|
||||
investment_responses = []
|
||||
for fund in funds:
|
||||
investor = fund.investor
|
||||
|
||||
# Transform to InvestorData format
|
||||
investor_data_list = []
|
||||
for investor in investors:
|
||||
investor_data = InvestorData(
|
||||
investor=investor,
|
||||
portfolio_companies=investor.portfolio_companies,
|
||||
team_members=investor.team_members,
|
||||
sectors=investor.sectors,
|
||||
funds=investor.funds,
|
||||
# Get top 3 portfolio companies (id and name only)
|
||||
portfolio_companies = [
|
||||
CompanyMinimal(id=company.id, name=company.name)
|
||||
for company in investor.portfolio_companies[:3]
|
||||
]
|
||||
|
||||
# Get stage focus as comma-separated string
|
||||
stage_focus = (
|
||||
", ".join([stage.name for stage in fund.investment_stages])
|
||||
if fund.investment_stages
|
||||
else None
|
||||
)
|
||||
investor_data_list.append(investor_data)
|
||||
|
||||
return InvestorList(investors=investor_data_list)
|
||||
# Get top 3 sectors from fund (id and name only)
|
||||
fund_sectors = [
|
||||
SectorMinimal(id=sector.id, name=sector.name)
|
||||
for sector in (fund.sectors[:3] if fund.sectors else [])
|
||||
]
|
||||
|
||||
investment_response = InvestmentResponse(
|
||||
id=investor.id,
|
||||
name=f"{investor.name} - {fund.fund_name}"
|
||||
if fund.fund_name
|
||||
else investor.name,
|
||||
aum=investor.aum,
|
||||
check_size_lower=fund.check_size_lower,
|
||||
check_size_upper=fund.check_size_upper,
|
||||
geographic_focus=fund.geographic_focus,
|
||||
stage_focus=stage_focus,
|
||||
portfolio_companies=portfolio_companies,
|
||||
sectors=fund_sectors,
|
||||
compatibility_score=1.0,
|
||||
)
|
||||
investment_responses.append(investment_response)
|
||||
|
||||
total_count = len(investment_responses)
|
||||
total_pages = 1 if total_count > 0 else 0
|
||||
|
||||
return PaginatedResponse(
|
||||
items=investment_responses,
|
||||
total=total_count,
|
||||
page=1,
|
||||
page_size=total_count,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
|
||||
finally:
|
||||
db_session.close()
|
||||
|
||||
Binary file not shown.
Reference in New Issue
Block a user