Compare commits
15 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| cefe89bb67 | |||
| 58722f1102 | |||
| be6fde9ba2 | |||
| 37e1ad01c4 | |||
| faf92a3b47 | |||
| 26a1197db0 | |||
| 84e3c7b72a | |||
| a9589e54f3 | |||
| d341cacb9a | |||
| c0fbbdd917 | |||
| 1f3f08e80d | |||
| cd7172ed9f | |||
| c199f5423a | |||
| a2b3ceedbe | |||
| 3842171549 |
+1
-1
@@ -10,7 +10,7 @@
|
||||
|
||||
*__pycache__
|
||||
|
||||
/*.db
|
||||
|
||||
*.cypython
|
||||
|
||||
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends
|
||||
@@ -9,6 +10,10 @@ from sqlalchemy.orm import Session, sessionmaker
|
||||
Base = declarative_base()
|
||||
|
||||
# Database configuration
|
||||
# Use the preprocessor's database for consistency
|
||||
# Get absolute path to the preprocessor database
|
||||
# APP_DIR = Path(__file__).parent.parent
|
||||
# PREPROCESSOR_DB = APP_DIR.parent / "preprocessor" / "version_two.db"
|
||||
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./investors.db")
|
||||
|
||||
# Create engine
|
||||
@@ -38,6 +43,7 @@ def get_session_sync() -> Session:
|
||||
"""Get a database session for synchronous operations"""
|
||||
return SessionLocal()
|
||||
|
||||
|
||||
def get_db_session():
|
||||
"""Get a database session for direct use."""
|
||||
return SessionLocal()
|
||||
|
||||
+117
-10
@@ -2,7 +2,7 @@ import enum
|
||||
|
||||
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, Table, Text, func
|
||||
from sqlalchemy.orm import declarative_mixin, relationship
|
||||
from sqlalchemy.types import Enum
|
||||
from sqlalchemy.types import JSON, Enum
|
||||
|
||||
from db.db import Base
|
||||
|
||||
@@ -70,6 +70,22 @@ project_company_association = Table(
|
||||
Column("company_id", Integer, ForeignKey("companies.id")),
|
||||
)
|
||||
|
||||
# Association table for fund-stage many-to-many
|
||||
fund_investment_stages_association = Table(
|
||||
"fund_investment_stages",
|
||||
Base.metadata,
|
||||
Column("fund_id", Integer, ForeignKey("funds.id")),
|
||||
Column("stage_id", Integer, ForeignKey("investment_stages.id")),
|
||||
)
|
||||
|
||||
# Association table for fund-sector many-to-many
|
||||
fund_sectors_association = Table(
|
||||
"fund_sectors",
|
||||
Base.metadata,
|
||||
Column("fund_id", Integer, ForeignKey("funds.id")),
|
||||
Column("sector_id", Integer, ForeignKey("sectors.id")),
|
||||
)
|
||||
|
||||
|
||||
class InvestorTable(Base, TimestampMixin):
|
||||
__tablename__ = "investors"
|
||||
@@ -77,14 +93,47 @@ class InvestorTable(Base, TimestampMixin):
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, nullable=False)
|
||||
description = Column(Text, nullable=True)
|
||||
aum = Column(Integer, nullable=True) # Assets Under Management
|
||||
check_size_lower = Column(Integer, nullable=True) # Lower bound
|
||||
check_size_upper = Column(Integer, nullable=True) # Upper bound
|
||||
|
||||
# Basic investor info
|
||||
website = Column(String, nullable=True)
|
||||
headquarters = Column(String, nullable=True)
|
||||
|
||||
# AUM fields
|
||||
aum = Column(Integer, nullable=True) # Store as integer for numerical filtering
|
||||
aum_as_of_date = Column(String, nullable=True)
|
||||
aum_source_url = Column(String, nullable=True)
|
||||
|
||||
# Check size (deprecated in favor of fund-level data, but keeping for backward compatibility)
|
||||
check_size_lower = Column(Integer, nullable=True)
|
||||
check_size_upper = Column(Integer, nullable=True)
|
||||
|
||||
# Geographic focus (deprecated in favor of fund-level, but keeping for backward compatibility)
|
||||
geographic_focus = Column(String, nullable=True)
|
||||
stage_focus = Column(Enum(InvestmentStage), nullable=True)
|
||||
|
||||
# Investment thesis and portfolio
|
||||
investment_thesis = Column(JSON, nullable=True) # Array of thesis statements
|
||||
portfolio_highlights = Column(
|
||||
JSON, nullable=True
|
||||
) # Array of portfolio company names
|
||||
linked_documents = Column(JSON, nullable=True) # Array of document URLs
|
||||
|
||||
# Research metadata
|
||||
researcher_notes = Column(Text, nullable=True)
|
||||
missing_important_fields = Column(
|
||||
JSON, nullable=True
|
||||
) # Array of missing field names
|
||||
sources = Column(JSON, nullable=True) # JSON object with source URLs
|
||||
|
||||
# Portfolio info
|
||||
number_of_investments = Column(Integer, default=0, nullable=True)
|
||||
|
||||
team_members = relationship("InvestorMember", back_populates="investor")
|
||||
# Relationships
|
||||
team_members = relationship(
|
||||
"InvestorMember", back_populates="investor", cascade="all, delete-orphan"
|
||||
)
|
||||
funds = relationship(
|
||||
"FundTable", back_populates="investor", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
# Relationship to portfolio companies
|
||||
portfolio_companies = relationship(
|
||||
@@ -111,12 +160,51 @@ class InvestorMember(Base, TimestampMixin):
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, nullable=False)
|
||||
role = Column(String, nullable=True)
|
||||
title = Column(String, nullable=True) # Alternative to role
|
||||
email = Column(String, nullable=True)
|
||||
source_url = Column(String, nullable=True) # URL where member info was found
|
||||
|
||||
investor_id = Column(Integer, ForeignKey("investors.id"))
|
||||
investor = relationship("InvestorTable", back_populates="team_members")
|
||||
|
||||
|
||||
class FundTable(Base, TimestampMixin):
|
||||
__tablename__ = "funds"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
investor_id = Column(Integer, ForeignKey("investors.id"), nullable=False)
|
||||
|
||||
# Fund details
|
||||
fund_name = Column(String, nullable=True)
|
||||
fund_size = Column(
|
||||
Integer, nullable=True
|
||||
) # Store as integer for numerical filtering
|
||||
fund_size_source_url = Column(String, nullable=True)
|
||||
|
||||
# Check size range (parsed from estimated_investment_size by LLM)
|
||||
check_size_lower = Column(Integer, nullable=True)
|
||||
check_size_upper = Column(Integer, nullable=True)
|
||||
|
||||
source_url = Column(String, nullable=True)
|
||||
source_provider = Column(String, nullable=True) # e.g., "Perplexity"
|
||||
|
||||
# Geographic focus as simple string
|
||||
geographic_focus = Column(String, nullable=True)
|
||||
|
||||
# Relationships
|
||||
investor = relationship("InvestorTable", back_populates="funds")
|
||||
investment_stages = relationship(
|
||||
"InvestmentStageTable",
|
||||
secondary=fund_investment_stages_association,
|
||||
back_populates="funds",
|
||||
)
|
||||
sectors = relationship(
|
||||
"SectorTable",
|
||||
secondary=fund_sectors_association,
|
||||
back_populates="funds",
|
||||
)
|
||||
|
||||
|
||||
class CompanyTable(Base, TimestampMixin):
|
||||
__tablename__ = "companies"
|
||||
|
||||
@@ -128,7 +216,9 @@ class CompanyTable(Base, TimestampMixin):
|
||||
founded_year = Column(Integer, nullable=True)
|
||||
website = Column(String, nullable=True)
|
||||
|
||||
members = relationship("CompanyMember", back_populates="company")
|
||||
members = relationship(
|
||||
"CompanyMember", back_populates="company", cascade="all, delete-orphan"
|
||||
)
|
||||
# Relationship back to investors
|
||||
investors = relationship(
|
||||
"InvestorTable",
|
||||
@@ -158,26 +248,43 @@ class CompanyMember(Base, TimestampMixin):
|
||||
company = relationship("CompanyTable", back_populates="members")
|
||||
|
||||
|
||||
class InvestmentStageTable(Base, TimestampMixin):
|
||||
__tablename__ = "investment_stages"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, nullable=False, unique=True)
|
||||
|
||||
# Relationships
|
||||
funds = relationship(
|
||||
"FundTable",
|
||||
secondary=fund_investment_stages_association,
|
||||
back_populates="investment_stages",
|
||||
)
|
||||
|
||||
|
||||
class SectorTable(Base, TimestampMixin):
|
||||
__tablename__ = "sectors"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, nullable=False)
|
||||
|
||||
# Add relationship back to investors
|
||||
# Relationships
|
||||
investors = relationship(
|
||||
"InvestorTable",
|
||||
secondary=investor_sector_association,
|
||||
back_populates="sectors",
|
||||
)
|
||||
|
||||
companies = relationship(
|
||||
"CompanyTable", secondary=company_sector_association, back_populates="sectors"
|
||||
)
|
||||
|
||||
projects = relationship(
|
||||
"ProjectTable", secondary=project_sector_association, back_populates="sector"
|
||||
)
|
||||
funds = relationship(
|
||||
"FundTable",
|
||||
secondary=fund_sectors_association,
|
||||
back_populates="sectors",
|
||||
)
|
||||
|
||||
|
||||
class ProjectTable(Base, TimestampMixin):
|
||||
|
||||
+36
-7
@@ -6,7 +6,7 @@ from dotenv import load_dotenv
|
||||
from fastapi import FastAPI, File, Form, UploadFile
|
||||
from pydantic import BaseModel
|
||||
from routers import companies, investors, projects
|
||||
from schemas.router_schemas import InvestorList
|
||||
from schemas.router_schemas import InvestmentResponse, PaginatedResponse
|
||||
from services.llm_parser import InvestorProcessor
|
||||
from services.querying import QueryProcessor
|
||||
|
||||
@@ -44,6 +44,27 @@ def health():
|
||||
async def parse_csv(
|
||||
db: db_dependency, file: UploadFile = File(...), is_investor: int = Form(...)
|
||||
):
|
||||
"""
|
||||
Parse and import CSV data into the database.
|
||||
|
||||
**For investors:**
|
||||
- Expected columns: Name, Website, Final Investor Profile, Final Profile sourcing
|
||||
- Manually parses JSON profiles for efficiency
|
||||
- Uses LLM only for currency conversion to USD
|
||||
- Handles AUM, fund sizes, and check sizes as integers
|
||||
|
||||
**For companies:**
|
||||
- Expected columns: Name, Website, Investor, Final Investor Profile (company profile)
|
||||
- 100% manual JSON parsing - no LLM needed
|
||||
- Extracts company details, executives, investors, and client categories
|
||||
- Automatically links companies to investors in database
|
||||
|
||||
**Benefits:**
|
||||
- Fast processing (5-10s per record)
|
||||
- Low cost (minimal or no LLM usage)
|
||||
- Accurate data extraction
|
||||
- Automatic database persistence
|
||||
"""
|
||||
# Read uploaded CSV with pandas
|
||||
content = await file.read()
|
||||
df = pd.read_csv(io.StringIO(content.decode("utf-8")))
|
||||
@@ -52,19 +73,27 @@ async def parse_csv(
|
||||
processor = InvestorProcessor()
|
||||
|
||||
if is_investor == 1:
|
||||
results = await processor.parse_investors(df)
|
||||
# Manual parser with LLM currency conversion
|
||||
results = await processor.parse_investors(df, save_to_db=True)
|
||||
# Results are already dicts from the new parser
|
||||
return results
|
||||
else:
|
||||
results = await processor.parse_companies(df)
|
||||
|
||||
# Convert Pydantic objects to dictionaries
|
||||
return [r.model_dump() for r in results]
|
||||
# Manual parser for companies (no LLM needed)
|
||||
results = await processor.parse_companies(df, save_to_db=True)
|
||||
# Results are already dicts from the new parser
|
||||
return results
|
||||
|
||||
|
||||
@app.post("/query", response_model=InvestorList, tags=["Querying"])
|
||||
@app.post(
|
||||
"/query", response_model=PaginatedResponse[InvestmentResponse], tags=["Querying"]
|
||||
)
|
||||
async def query_investors(request: QueryRequest):
|
||||
"""
|
||||
Query investors using natural language.
|
||||
|
||||
Returns fund-level matches (one row per fund) with investor details.
|
||||
This ensures only relevant funds are included in the response.
|
||||
|
||||
Supports queries like:
|
||||
- "Show me seed stage investors"
|
||||
- "Find fintech investors in Silicon Valley"
|
||||
|
||||
Binary file not shown.
Binary file not shown.
+53
-14
@@ -1,10 +1,10 @@
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from db.db import get_db
|
||||
from db.models import CompanyTable, InvestorTable
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from schemas.router_schemas import CompanyData
|
||||
from schemas.router_schemas import CompanyData, PaginatedResponse
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
router = APIRouter(tags=["Company Routes"])
|
||||
@@ -29,20 +29,34 @@ class CompanyUpdate(BaseModel):
|
||||
website: Optional[str] = None
|
||||
|
||||
|
||||
@router.get("/companies", response_model=List[CompanyData])
|
||||
def read_companies(db: Session = Depends(get_db)):
|
||||
"""Get all companies with their investor relationships"""
|
||||
@router.get("/companies", response_model=PaginatedResponse[CompanyData])
|
||||
def read_companies(
|
||||
page: int = Query(1, ge=1, description="Page number (starts at 1)"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get all companies with their investor relationships (paginated)"""
|
||||
# Calculate offset
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# Get total count
|
||||
total_count = (
|
||||
db.query(CompanyTable)
|
||||
.filter(CompanyTable.name.isnot(None), CompanyTable.description.isnot(None))
|
||||
.count()
|
||||
)
|
||||
|
||||
# Get paginated results
|
||||
companies = (
|
||||
db.query(CompanyTable)
|
||||
.filter(
|
||||
CompanyTable.name.isnot(None),
|
||||
CompanyTable.description.isnot(None)
|
||||
)
|
||||
.filter(CompanyTable.name.isnot(None), CompanyTable.description.isnot(None))
|
||||
.options(
|
||||
selectinload(CompanyTable.investors),
|
||||
selectinload(CompanyTable.members),
|
||||
selectinload(CompanyTable.sectors),
|
||||
)
|
||||
.offset(offset)
|
||||
.limit(page_size)
|
||||
.all()
|
||||
)
|
||||
|
||||
@@ -57,10 +71,19 @@ def read_companies(db: Session = Depends(get_db)):
|
||||
)
|
||||
company_data_list.append(company_data)
|
||||
|
||||
return company_data_list
|
||||
# Calculate total pages
|
||||
total_pages = (total_count + page_size - 1) // page_size
|
||||
|
||||
return PaginatedResponse(
|
||||
items=company_data_list,
|
||||
total=total_count,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/companies/filter", response_model=List[CompanyData])
|
||||
@router.get("/companies/filter", response_model=PaginatedResponse[CompanyData])
|
||||
def filter_companies(
|
||||
industry: Optional[str] = Query(
|
||||
None, description="Filter by industry (partial match)"
|
||||
@@ -76,9 +99,11 @@ def filter_companies(
|
||||
investor_name: Optional[str] = Query(
|
||||
None, description="Filter by investor name (partial match)"
|
||||
),
|
||||
page: int = Query(1, ge=1, description="Page number (starts at 1)"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Filter companies based on various criteria"""
|
||||
"""Filter companies based on various criteria (paginated)"""
|
||||
|
||||
# Start with base query
|
||||
query = db.query(CompanyTable).options(
|
||||
@@ -112,7 +137,12 @@ def filter_companies(
|
||||
InvestorTable.name.ilike(f"%{investor_name}%")
|
||||
)
|
||||
|
||||
companies = query.all()
|
||||
# Get total count before pagination
|
||||
total_count = query.count()
|
||||
|
||||
# Calculate offset and apply pagination
|
||||
offset = (page - 1) * page_size
|
||||
companies = query.offset(offset).limit(page_size).all()
|
||||
|
||||
# Transform to CompanyData format
|
||||
company_data_list = []
|
||||
@@ -125,7 +155,16 @@ def filter_companies(
|
||||
)
|
||||
company_data_list.append(company_data)
|
||||
|
||||
return company_data_list
|
||||
# Calculate total pages
|
||||
total_pages = (total_count + page_size - 1) // page_size
|
||||
|
||||
return PaginatedResponse(
|
||||
items=company_data_list,
|
||||
total=total_count,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/companies/{company_id}", response_model=CompanyData)
|
||||
|
||||
+350
-76
@@ -1,11 +1,17 @@
|
||||
from typing import List, Optional
|
||||
from typing import Optional
|
||||
|
||||
from db.db import get_db
|
||||
from db.models import InvestorTable, SectorTable
|
||||
from db.models import FundTable, InvestorTable, SectorTable
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from pydantic import BaseModel
|
||||
from schemas.router_schemas import InvestmentStage, InvestorData
|
||||
from services.querying import QueryProcessor
|
||||
from schemas.router_schemas import (
|
||||
CompanyMinimal,
|
||||
InvestmentResponse,
|
||||
InvestmentStage,
|
||||
InvestorData,
|
||||
PaginatedResponse,
|
||||
SectorMinimal,
|
||||
)
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
router = APIRouter(tags=["Investor Routes"])
|
||||
@@ -15,53 +21,127 @@ router = APIRouter(tags=["Investor Routes"])
|
||||
class InvestorCreate(BaseModel):
|
||||
name: str
|
||||
description: Optional[str] = None
|
||||
website: Optional[str] = None
|
||||
headquarters: Optional[str] = None
|
||||
aum: int
|
||||
check_size_lower: int
|
||||
check_size_upper: int
|
||||
geographic_focus: str
|
||||
stage_focus: InvestmentStage
|
||||
number_of_investments: int = 0
|
||||
|
||||
|
||||
class InvestorUpdate(BaseModel):
|
||||
name: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
website: Optional[str] = None
|
||||
headquarters: Optional[str] = None
|
||||
aum: Optional[int] = None
|
||||
check_size_lower: Optional[int] = None
|
||||
check_size_upper: Optional[int] = None
|
||||
geographic_focus: Optional[str] = None
|
||||
stage_focus: Optional[InvestmentStage] = None
|
||||
number_of_investments: Optional[int] = None
|
||||
|
||||
|
||||
@router.get("/investors", response_model=List[InvestorData])
|
||||
def read_investors(db: Session = Depends(get_db)):
|
||||
"""Get all investors with their related data"""
|
||||
@router.get("/investors", response_model=PaginatedResponse[InvestmentResponse])
|
||||
def read_investors(
|
||||
page: int = Query(1, ge=1, description="Page number (starts at 1)"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get all investors with their funds as separate entries (paginated)
|
||||
|
||||
Each investor-fund combination is returned as a separate row.
|
||||
An investor with 3 funds will appear as 3 entries.
|
||||
"""
|
||||
# Calculate offset
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# Get total count
|
||||
total_count = db.query(InvestorTable).count()
|
||||
|
||||
# Get paginated results
|
||||
investors = (
|
||||
db.query(InvestorTable)
|
||||
.options(
|
||||
selectinload(InvestorTable.portfolio_companies),
|
||||
selectinload(InvestorTable.team_members),
|
||||
selectinload(InvestorTable.sectors),
|
||||
selectinload(InvestorTable.funds),
|
||||
)
|
||||
.offset(offset)
|
||||
.limit(page_size)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Transform InvestorTable objects to InvestorData format
|
||||
investor_data_list = []
|
||||
# Transform to InvestmentResponse format (one row per investor-fund combination)
|
||||
investment_responses = []
|
||||
for investor in investors:
|
||||
investor_data = InvestorData(
|
||||
investor=investor, # This maps to InvestorSchema
|
||||
portfolio_companies=investor.portfolio_companies,
|
||||
team_members=investor.team_members,
|
||||
sectors=investor.sectors,
|
||||
)
|
||||
investor_data_list.append(investor_data)
|
||||
# Get top 3 portfolio companies (id and name only)
|
||||
portfolio_companies = [
|
||||
CompanyMinimal(id=company.id, name=company.name)
|
||||
for company in investor.portfolio_companies[:3]
|
||||
]
|
||||
|
||||
return investor_data_list
|
||||
# If investor has funds, create one entry per fund
|
||||
if investor.funds:
|
||||
for fund in investor.funds:
|
||||
# Get stage focus as comma-separated string
|
||||
stage_focus = (
|
||||
", ".join([stage.name for stage in fund.investment_stages])
|
||||
if fund.investment_stages
|
||||
else None
|
||||
)
|
||||
|
||||
# Get top 3 sectors from fund (id and name only)
|
||||
fund_sectors = [
|
||||
SectorMinimal(id=sector.id, name=sector.name)
|
||||
for sector in (fund.sectors[:3] if fund.sectors else [])
|
||||
]
|
||||
|
||||
investment_response = InvestmentResponse(
|
||||
id=investor.id,
|
||||
name=f"{investor.name} - {fund.fund_name}"
|
||||
if fund.fund_name
|
||||
else investor.name,
|
||||
aum=investor.aum,
|
||||
check_size_lower=fund.check_size_lower,
|
||||
check_size_upper=fund.check_size_upper,
|
||||
geographic_focus=fund.geographic_focus,
|
||||
stage_focus=stage_focus,
|
||||
portfolio_companies=portfolio_companies,
|
||||
sectors=fund_sectors,
|
||||
compatibility_score=1.0,
|
||||
)
|
||||
investment_responses.append(investment_response)
|
||||
else:
|
||||
# If no funds, create one entry with null fund fields
|
||||
investment_response = InvestmentResponse(
|
||||
id=investor.id,
|
||||
name=investor.name,
|
||||
aum=investor.aum,
|
||||
check_size_lower=None,
|
||||
check_size_upper=None,
|
||||
geographic_focus=None,
|
||||
stage_focus=None,
|
||||
portfolio_companies=portfolio_companies,
|
||||
sectors=[],
|
||||
compatibility_score=1.0,
|
||||
)
|
||||
investment_responses.append(investment_response)
|
||||
|
||||
# Calculate total pages
|
||||
total_pages = (total_count + page_size - 1) // page_size
|
||||
|
||||
return PaginatedResponse(
|
||||
items=investment_responses,
|
||||
total=total_count,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/investors/filter", response_model=List[InvestorData])
|
||||
@router.get("/investors/filter", response_model=PaginatedResponse[InvestmentResponse])
|
||||
def filter_investors(
|
||||
stage: Optional[InvestmentStage] = Query(
|
||||
None, description="Filter by investment stage"
|
||||
@@ -74,67 +154,121 @@ def filter_investors(
|
||||
sector: Optional[str] = Query(None, description="Sector name (partial match)"),
|
||||
min_aum: Optional[int] = Query(None, description="Minimum AUM"),
|
||||
max_aum: Optional[int] = Query(None, description="Maximum AUM"),
|
||||
page: int = Query(1, ge=1, description="Page number (starts at 1)"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Filter investors based on various criteria"""
|
||||
"""Filter investors based on various criteria (paginated)
|
||||
|
||||
# Start with base query
|
||||
query = db.query(InvestorTable).options(
|
||||
selectinload(InvestorTable.portfolio_companies),
|
||||
selectinload(InvestorTable.team_members),
|
||||
selectinload(InvestorTable.sectors),
|
||||
Returns investor-fund combinations as separate rows.
|
||||
Queries the funds table to find matching funds.
|
||||
"""
|
||||
|
||||
# Start with base query on funds table
|
||||
query = db.query(FundTable).options(
|
||||
selectinload(FundTable.investor).selectinload(
|
||||
InvestorTable.portfolio_companies
|
||||
),
|
||||
selectinload(FundTable.investor).selectinload(InvestorTable.team_members),
|
||||
selectinload(FundTable.investor).selectinload(InvestorTable.sectors),
|
||||
selectinload(FundTable.investment_stages),
|
||||
selectinload(FundTable.sectors),
|
||||
)
|
||||
|
||||
# Apply filters
|
||||
if stage:
|
||||
query = query.filter(InvestorTable.stage_focus == stage)
|
||||
|
||||
# Apply filters at fund level
|
||||
if min_check_size is not None:
|
||||
query = query.filter(InvestorTable.check_size_lower >= min_check_size)
|
||||
query = query.filter(FundTable.check_size_lower >= min_check_size)
|
||||
|
||||
if max_check_size is not None:
|
||||
query = query.filter(InvestorTable.check_size_upper <= max_check_size)
|
||||
query = query.filter(FundTable.check_size_upper <= max_check_size)
|
||||
|
||||
if geography:
|
||||
query = query.filter(InvestorTable.geographic_focus.ilike(f"%{geography}%"))
|
||||
query = query.filter(FundTable.geographic_focus.ilike(f"%{geography}%"))
|
||||
|
||||
# Apply filters at investor level (through relationship)
|
||||
if min_aum is not None:
|
||||
query = query.filter(InvestorTable.aum >= min_aum)
|
||||
query = query.join(FundTable.investor).filter(InvestorTable.aum >= min_aum)
|
||||
|
||||
if max_aum is not None:
|
||||
if min_aum is None: # Only join if not already joined
|
||||
query = query.join(FundTable.investor)
|
||||
query = query.filter(InvestorTable.aum <= max_aum)
|
||||
|
||||
# Filter by sector if provided
|
||||
# Filter by sector if provided (at fund level)
|
||||
if sector:
|
||||
query = query.join(InvestorTable.sectors).filter(
|
||||
query = query.join(FundTable.sectors).filter(
|
||||
SectorTable.name.ilike(f"%{sector}%")
|
||||
)
|
||||
|
||||
investors = query.all()
|
||||
# Get total count before pagination
|
||||
total_count = query.count()
|
||||
|
||||
# Transform to InvestorData format
|
||||
investor_data_list = []
|
||||
for investor in investors:
|
||||
investor_data = InvestorData(
|
||||
investor=investor,
|
||||
portfolio_companies=investor.portfolio_companies,
|
||||
team_members=investor.team_members,
|
||||
sectors=investor.sectors,
|
||||
# Calculate offset and apply pagination
|
||||
offset = (page - 1) * page_size
|
||||
funds = query.offset(offset).limit(page_size).all()
|
||||
|
||||
# Transform to InvestmentResponse format (one row per fund)
|
||||
investment_responses = []
|
||||
for fund in funds:
|
||||
investor = fund.investor
|
||||
|
||||
# Get top 3 portfolio companies (id and name only)
|
||||
portfolio_companies = [
|
||||
CompanyMinimal(id=company.id, name=company.name)
|
||||
for company in investor.portfolio_companies[:3]
|
||||
]
|
||||
|
||||
# Get stage focus as comma-separated string
|
||||
stage_focus = (
|
||||
", ".join([stage.name for stage in fund.investment_stages])
|
||||
if fund.investment_stages
|
||||
else None
|
||||
)
|
||||
investor_data_list.append(investor_data)
|
||||
|
||||
return investor_data_list
|
||||
# Get top 3 sectors from fund (id and name only)
|
||||
fund_sectors = [
|
||||
SectorMinimal(id=sector.id, name=sector.name)
|
||||
for sector in (fund.sectors[:3] if fund.sectors else [])
|
||||
]
|
||||
|
||||
investment_response = InvestmentResponse(
|
||||
id=investor.id,
|
||||
name=f"{investor.name} - {fund.fund_name}"
|
||||
if fund.fund_name
|
||||
else investor.name,
|
||||
aum=investor.aum,
|
||||
check_size_lower=fund.check_size_lower,
|
||||
check_size_upper=fund.check_size_upper,
|
||||
geographic_focus=fund.geographic_focus,
|
||||
stage_focus=stage_focus,
|
||||
portfolio_companies=portfolio_companies,
|
||||
sectors=fund_sectors,
|
||||
compatibility_score=1.0,
|
||||
)
|
||||
investment_responses.append(investment_response)
|
||||
|
||||
# Calculate total pages
|
||||
total_pages = (total_count + page_size - 1) // page_size
|
||||
|
||||
return PaginatedResponse(
|
||||
items=investment_responses,
|
||||
total=total_count,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/investors/{investor_id}", response_model=InvestorData)
|
||||
def read_investor(investor_id: int, db: Session = Depends(get_db)):
|
||||
"""Get a specific investor by ID"""
|
||||
"""Get a specific investor by ID with all their funds"""
|
||||
investor = (
|
||||
db.query(InvestorTable)
|
||||
.options(
|
||||
selectinload(InvestorTable.portfolio_companies),
|
||||
selectinload(InvestorTable.team_members),
|
||||
selectinload(InvestorTable.sectors),
|
||||
selectinload(InvestorTable.funds),
|
||||
)
|
||||
.filter(InvestorTable.id == investor_id)
|
||||
.first()
|
||||
@@ -143,12 +277,13 @@ def read_investor(investor_id: int, db: Session = Depends(get_db)):
|
||||
if not investor:
|
||||
raise HTTPException(status_code=404, detail="Investor not found")
|
||||
|
||||
# Transform to InvestorData format
|
||||
# Transform to InvestorData format (includes funds array)
|
||||
return InvestorData(
|
||||
investor=investor,
|
||||
portfolio_companies=investor.portfolio_companies,
|
||||
team_members=investor.team_members,
|
||||
sectors=investor.sectors,
|
||||
funds=investor.funds,
|
||||
)
|
||||
|
||||
|
||||
@@ -167,6 +302,7 @@ def create_investor(investor: InvestorCreate, db: Session = Depends(get_db)):
|
||||
selectinload(InvestorTable.portfolio_companies),
|
||||
selectinload(InvestorTable.team_members),
|
||||
selectinload(InvestorTable.sectors),
|
||||
selectinload(InvestorTable.funds),
|
||||
)
|
||||
.filter(InvestorTable.id == db_investor.id)
|
||||
.first()
|
||||
@@ -178,6 +314,7 @@ def create_investor(investor: InvestorCreate, db: Session = Depends(get_db)):
|
||||
portfolio_companies=investor_with_relations.portfolio_companies,
|
||||
team_members=investor_with_relations.team_members,
|
||||
sectors=investor_with_relations.sectors,
|
||||
funds=investor_with_relations.funds,
|
||||
)
|
||||
|
||||
|
||||
@@ -206,6 +343,7 @@ def update_investor(
|
||||
selectinload(InvestorTable.portfolio_companies),
|
||||
selectinload(InvestorTable.team_members),
|
||||
selectinload(InvestorTable.sectors),
|
||||
selectinload(InvestorTable.funds),
|
||||
)
|
||||
.filter(InvestorTable.id == investor_id)
|
||||
.first()
|
||||
@@ -217,6 +355,7 @@ def update_investor(
|
||||
portfolio_companies=investor_with_relations.portfolio_companies,
|
||||
team_members=investor_with_relations.team_members,
|
||||
sectors=investor_with_relations.sectors,
|
||||
funds=investor_with_relations.funds,
|
||||
)
|
||||
|
||||
|
||||
@@ -234,17 +373,32 @@ def delete_investor(investor_id: int, db: Session = Depends(get_db)):
|
||||
return {"message": "Investor deleted successfully"}
|
||||
|
||||
|
||||
@router.get("/investors/{investor_id}/similar", response_model=List[InvestorData])
|
||||
def find_similar_investors(investor_id: int, db: Session = Depends(get_db)):
|
||||
"""Find investors similar to a given investor using AI agent"""
|
||||
@router.get(
|
||||
"/investors/{investor_id}/similar",
|
||||
response_model=PaginatedResponse[InvestmentResponse],
|
||||
)
|
||||
def find_similar_investors(
|
||||
investor_id: int,
|
||||
limit: int = Query(10, description="Maximum number of similar investors to return"),
|
||||
page: int = Query(1, ge=1, description="Page number (starts at 1)"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Find investors similar to a given investor based on characteristics (paginated)
|
||||
|
||||
# First, get the target investor to build the AI query
|
||||
Returns investor-fund combinations as separate rows.
|
||||
Queries the funds table to find matching funds.
|
||||
"""
|
||||
|
||||
# Get the target investor to get their funds for comparison
|
||||
target_investor = (
|
||||
db.query(InvestorTable)
|
||||
.options(
|
||||
selectinload(InvestorTable.portfolio_companies),
|
||||
selectinload(InvestorTable.team_members),
|
||||
selectinload(InvestorTable.sectors),
|
||||
selectinload(InvestorTable.funds).selectinload(FundTable.investment_stages),
|
||||
selectinload(InvestorTable.funds).selectinload(FundTable.sectors),
|
||||
)
|
||||
.filter(InvestorTable.id == investor_id)
|
||||
.first()
|
||||
@@ -253,29 +407,149 @@ def find_similar_investors(investor_id: int, db: Session = Depends(get_db)):
|
||||
if not target_investor:
|
||||
raise HTTPException(status_code=404, detail="Investor not found")
|
||||
|
||||
# Build a descriptive query for the AI agent based on target investor characteristics
|
||||
target_sectors = [sector.name for sector in target_investor.sectors]
|
||||
sectors_text = ", ".join(target_sectors) if target_sectors else "any sector"
|
||||
# Get target investor's sector IDs for comparison (from their funds)
|
||||
target_sector_ids = set()
|
||||
target_stage_ids = set()
|
||||
target_check_ranges = []
|
||||
target_geographies = []
|
||||
|
||||
ai_query = f"""
|
||||
Find investors similar to investor ID {investor_id} with the following characteristics:
|
||||
- Stage focus: {target_investor.stage_focus.value if target_investor.stage_focus else "any stage"}
|
||||
- Geographic focus: {target_investor.geographic_focus or "any geography"}
|
||||
- Check size range: ${target_investor.check_size_lower or 0:,} to ${target_investor.check_size_upper or 0:,}
|
||||
- AUM (Assets Under Management): ${target_investor.aum or 0:,}
|
||||
- Sectors: {sectors_text}
|
||||
|
||||
Find investors with similar characteristics but exclude investor ID {investor_id}.
|
||||
Look for investors with:
|
||||
- Same or similar stage focus
|
||||
- Similar geographic regions
|
||||
- Overlapping check size ranges
|
||||
- Similar AUM levels (within a reasonable range)
|
||||
- Common sector interests
|
||||
"""
|
||||
for fund in target_investor.funds:
|
||||
if fund.sectors:
|
||||
target_sector_ids.update({sector.id for sector in fund.sectors})
|
||||
if fund.investment_stages:
|
||||
target_stage_ids.update({stage.id for stage in fund.investment_stages})
|
||||
if fund.check_size_lower and fund.check_size_upper:
|
||||
target_check_ranges.append((fund.check_size_lower, fund.check_size_upper))
|
||||
if fund.geographic_focus:
|
||||
target_geographies.append(fund.geographic_focus.lower())
|
||||
|
||||
# Use the AI agent to find similar investors
|
||||
query_processor = QueryProcessor()
|
||||
result = query_processor.process_query(ai_query)
|
||||
# Query all funds from other investors
|
||||
candidate_funds = (
|
||||
db.query(FundTable)
|
||||
.options(
|
||||
selectinload(FundTable.investor).selectinload(
|
||||
InvestorTable.portfolio_companies
|
||||
),
|
||||
selectinload(FundTable.investor).selectinload(InvestorTable.team_members),
|
||||
selectinload(FundTable.investor).selectinload(InvestorTable.sectors),
|
||||
selectinload(FundTable.investment_stages),
|
||||
selectinload(FundTable.sectors),
|
||||
)
|
||||
.join(FundTable.investor)
|
||||
.filter(InvestorTable.id != investor_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
return result.investors
|
||||
# Calculate similarity scores for each fund
|
||||
scored_funds = []
|
||||
for fund in candidate_funds:
|
||||
score = 0
|
||||
|
||||
# Geographic focus match (20 points for exact, 10 for partial)
|
||||
if fund.geographic_focus and target_geographies:
|
||||
fund_geo_lower = fund.geographic_focus.lower()
|
||||
for target_geo in target_geographies:
|
||||
if fund_geo_lower == target_geo:
|
||||
score += 20
|
||||
break
|
||||
elif fund_geo_lower in target_geo or target_geo in fund_geo_lower:
|
||||
score += 10
|
||||
break
|
||||
|
||||
# Check size overlap (20 points max)
|
||||
if fund.check_size_lower and fund.check_size_upper and target_check_ranges:
|
||||
max_overlap_score = 0
|
||||
for target_lower, target_upper in target_check_ranges:
|
||||
overlap_start = max(fund.check_size_lower, target_lower)
|
||||
overlap_end = min(fund.check_size_upper, target_upper)
|
||||
if overlap_end > overlap_start:
|
||||
overlap = overlap_end - overlap_start
|
||||
target_range = target_upper - target_lower
|
||||
overlap_ratio = overlap / target_range if target_range > 0 else 0
|
||||
max_overlap_score = max(max_overlap_score, int(20 * overlap_ratio))
|
||||
score += max_overlap_score
|
||||
|
||||
# AUM similarity (15 points max)
|
||||
if fund.investor.aum and target_investor.aum:
|
||||
aum_diff = abs(fund.investor.aum - target_investor.aum)
|
||||
max_aum = max(fund.investor.aum, target_investor.aum)
|
||||
similarity_ratio = 1 - (aum_diff / max_aum) if max_aum > 0 else 0
|
||||
score += int(15 * similarity_ratio)
|
||||
|
||||
# Sector overlap (30 points max)
|
||||
if fund.sectors and target_sector_ids:
|
||||
fund_sector_ids = {sector.id for sector in fund.sectors}
|
||||
common_sectors = target_sector_ids.intersection(fund_sector_ids)
|
||||
overlap_ratio = len(common_sectors) / len(target_sector_ids)
|
||||
score += int(30 * overlap_ratio)
|
||||
|
||||
# Investment stage match (15 points max)
|
||||
if fund.investment_stages and target_stage_ids:
|
||||
fund_stage_ids = {stage.id for stage in fund.investment_stages}
|
||||
common_stages = target_stage_ids.intersection(fund_stage_ids)
|
||||
overlap_ratio = len(common_stages) / len(target_stage_ids)
|
||||
score += int(15 * overlap_ratio)
|
||||
|
||||
if score > 0: # Only include funds with some similarity
|
||||
scored_funds.append((score, fund))
|
||||
|
||||
# Sort by score (descending) and take top N based on limit
|
||||
scored_funds.sort(key=lambda x: x[0], reverse=True)
|
||||
top_similar = scored_funds[:limit]
|
||||
|
||||
# Apply pagination to the top similar funds
|
||||
total_count = len(top_similar)
|
||||
offset = (page - 1) * page_size
|
||||
paginated_similar = top_similar[offset : offset + page_size]
|
||||
similar_funds = [fund for score, fund in paginated_similar]
|
||||
|
||||
# Transform to InvestmentResponse format (one row per fund)
|
||||
investment_responses = []
|
||||
for fund in similar_funds:
|
||||
investor = fund.investor
|
||||
|
||||
# Get top 3 portfolio companies (id and name only)
|
||||
portfolio_companies = [
|
||||
CompanyMinimal(id=company.id, name=company.name)
|
||||
for company in investor.portfolio_companies[:3]
|
||||
]
|
||||
|
||||
# Get stage focus as comma-separated string
|
||||
stage_focus = (
|
||||
", ".join([stage.name for stage in fund.investment_stages])
|
||||
if fund.investment_stages
|
||||
else None
|
||||
)
|
||||
|
||||
# Get top 3 sectors from fund (id and name only)
|
||||
fund_sectors = [
|
||||
SectorMinimal(id=sector.id, name=sector.name)
|
||||
for sector in (fund.sectors[:3] if fund.sectors else [])
|
||||
]
|
||||
|
||||
investment_response = InvestmentResponse(
|
||||
id=investor.id,
|
||||
name=f"{investor.name} - {fund.fund_name}"
|
||||
if fund.fund_name
|
||||
else investor.name,
|
||||
aum=investor.aum,
|
||||
check_size_lower=fund.check_size_lower,
|
||||
check_size_upper=fund.check_size_upper,
|
||||
geographic_focus=fund.geographic_focus,
|
||||
stage_focus=stage_focus,
|
||||
portfolio_companies=portfolio_companies,
|
||||
sectors=fund_sectors,
|
||||
compatibility_score=1.0,
|
||||
)
|
||||
investment_responses.append(investment_response)
|
||||
|
||||
# Calculate total pages
|
||||
total_pages = (total_count + page_size - 1) // page_size
|
||||
|
||||
return PaginatedResponse(
|
||||
items=investment_responses,
|
||||
total=total_count,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
|
||||
+47
-8
@@ -14,14 +14,26 @@ from schemas.project_schemas import (
|
||||
ProjectData,
|
||||
ProjectUpdate,
|
||||
)
|
||||
from schemas.router_schemas import PaginatedResponse
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
|
||||
router = APIRouter(tags=["Project Routes"])
|
||||
|
||||
|
||||
@router.get("/projects", response_model=List[ProjectData])
|
||||
def read_projects(db: Session = Depends(get_db)):
|
||||
"""Get all projects with their related data"""
|
||||
@router.get("/projects", response_model=PaginatedResponse[ProjectData])
|
||||
def read_projects(
|
||||
page: int = Query(1, ge=1, description="Page number (starts at 1)"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Get all projects with their related data (paginated)"""
|
||||
# Calculate offset
|
||||
offset = (page - 1) * page_size
|
||||
|
||||
# Get total count
|
||||
total_count = db.query(ProjectTable).count()
|
||||
|
||||
# Get paginated results
|
||||
projects = (
|
||||
db.query(ProjectTable)
|
||||
.options(
|
||||
@@ -29,6 +41,8 @@ def read_projects(db: Session = Depends(get_db)):
|
||||
selectinload(ProjectTable.investors),
|
||||
selectinload(ProjectTable.companies),
|
||||
)
|
||||
.offset(offset)
|
||||
.limit(page_size)
|
||||
.all()
|
||||
)
|
||||
|
||||
@@ -43,7 +57,16 @@ def read_projects(db: Session = Depends(get_db)):
|
||||
)
|
||||
project_data_list.append(project_data)
|
||||
|
||||
return project_data_list
|
||||
# Calculate total pages
|
||||
total_pages = (total_count + page_size - 1) // page_size
|
||||
|
||||
return PaginatedResponse(
|
||||
items=project_data_list,
|
||||
total=total_count,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/projects/{project_id}", response_model=ProjectData)
|
||||
@@ -151,7 +174,7 @@ def delete_project(project_id: int, db: Session = Depends(get_db)):
|
||||
return {"message": "Project deleted successfully"}
|
||||
|
||||
|
||||
@router.get("/projects/filter", response_model=List[ProjectData])
|
||||
@router.get("/projects/filter", response_model=PaginatedResponse[ProjectData])
|
||||
def filter_projects(
|
||||
stage: Optional[InvestmentStage] = Query(
|
||||
None, description="Filter by project stage"
|
||||
@@ -166,9 +189,11 @@ def filter_projects(
|
||||
company_name: Optional[str] = Query(
|
||||
None, description="Company name (partial match)"
|
||||
),
|
||||
page: int = Query(1, ge=1, description="Page number (starts at 1)"),
|
||||
page_size: int = Query(10, ge=1, le=100, description="Items per page (max 100)"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""Filter projects based on various criteria"""
|
||||
"""Filter projects based on various criteria (paginated)"""
|
||||
|
||||
# Start with base query
|
||||
query = db.query(ProjectTable).options(
|
||||
@@ -205,7 +230,12 @@ def filter_projects(
|
||||
CompanyTable.name.ilike(f"%{company_name}%")
|
||||
)
|
||||
|
||||
projects = query.all()
|
||||
# Get total count before pagination
|
||||
total_count = query.count()
|
||||
|
||||
# Calculate offset and apply pagination
|
||||
offset = (page - 1) * page_size
|
||||
projects = query.offset(offset).limit(page_size).all()
|
||||
|
||||
# Transform to ProjectData format
|
||||
project_data_list = []
|
||||
@@ -218,7 +248,16 @@ def filter_projects(
|
||||
)
|
||||
project_data_list.append(project_data)
|
||||
|
||||
return project_data_list
|
||||
# Calculate total pages
|
||||
total_pages = (total_count + page_size - 1) // page_size
|
||||
|
||||
return PaginatedResponse(
|
||||
items=project_data_list,
|
||||
total=total_count,
|
||||
page=page,
|
||||
page_size=page_size,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
|
||||
|
||||
# Association management routes
|
||||
|
||||
Binary file not shown.
Binary file not shown.
@@ -258,10 +258,6 @@ class InvestorSchema(BaseModel):
|
||||
default=None,
|
||||
description="Geographic investment focus. Do not return any special characters, Just locations separated by commas. Leave empty if not clearly identifiable.",
|
||||
)
|
||||
stage_focus: InvestmentStage = Field(
|
||||
default=InvestmentStage.SEED,
|
||||
description="Investment stage focus. Use SEED as default if uncertain.",
|
||||
)
|
||||
number_of_investments: Optional[int] = Field(
|
||||
default=None,
|
||||
ge=0,
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
from typing import Any, Generic, List, Optional, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
# Generic type for pagination
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class InvestmentStage(str, Enum):
|
||||
SEED = "SEED"
|
||||
@@ -22,6 +25,14 @@ class SectorSchema(BaseModel):
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class InvestmentStageSchema(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class InvestorMemberSchema(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
@@ -32,6 +43,25 @@ class InvestorMemberSchema(BaseModel):
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class FundSchema(BaseModel):
|
||||
id: int
|
||||
fund_name: str | None
|
||||
fund_size: int | None # Changed to int for numerical filtering
|
||||
fund_size_source_url: str | None
|
||||
check_size_lower: int | None # NEW: Lower bound of check size range
|
||||
check_size_upper: int | None # NEW: Upper bound of check size range
|
||||
source_url: str | None
|
||||
source_provider: str | None
|
||||
geographic_focus: str | None # Changed from List[str] to string
|
||||
investment_stages: List[InvestmentStageSchema] | None # Changed to relationship
|
||||
sectors: List[SectorSchema] | None # Changed to relationship
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class CompanyMemberSchema(BaseModel):
|
||||
id: int
|
||||
name: Optional[str]
|
||||
@@ -62,11 +92,20 @@ class InvestorSchema(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
description: Optional[str]
|
||||
website: Optional[str] = None
|
||||
headquarters: Optional[str] = None
|
||||
aum: int | None
|
||||
aum_as_of_date: str | None = None
|
||||
aum_source_url: str | None = None
|
||||
check_size_lower: int | None
|
||||
check_size_upper: int | None
|
||||
geographic_focus: str | None
|
||||
stage_focus: InvestmentStage
|
||||
investment_thesis: Any = (
|
||||
None # Flexible JSON field - can be list, dict, or list of dicts
|
||||
)
|
||||
portfolio_highlights: Any = (
|
||||
None # Flexible JSON field - can be list, dict, or list of dicts
|
||||
)
|
||||
number_of_investments: int | None
|
||||
created_at: Optional[datetime] = None
|
||||
updated_at: Optional[datetime] = None
|
||||
@@ -76,22 +115,82 @@ class InvestorSchema(BaseModel):
|
||||
|
||||
|
||||
class InvestorData(BaseModel):
|
||||
"""Comprehensive investor data schema for LLM processing"""
|
||||
"""Comprehensive investor data schema - used for individual investor requests"""
|
||||
|
||||
investor: InvestorSchema
|
||||
portfolio_companies: List[CompanySchema]
|
||||
team_members: List[InvestorMemberSchema]
|
||||
sectors: List[SectorSchema]
|
||||
funds: List[FundSchema]
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class InvestorFundData(BaseModel):
|
||||
"""Investor-Fund combined data - used for list/filter requests
|
||||
|
||||
Each row represents one investor-fund combination.
|
||||
An investor with 3 funds will appear as 3 separate entries.
|
||||
"""
|
||||
|
||||
# Investor fields
|
||||
investor_id: int
|
||||
investor_name: str
|
||||
investor_description: Optional[str]
|
||||
investor_website: Optional[str]
|
||||
investor_headquarters: Optional[str]
|
||||
aum: int | None
|
||||
aum_as_of_date: str | None
|
||||
aum_source_url: str | None
|
||||
investment_thesis: Any = None # Flexible JSON field
|
||||
portfolio_highlights: Any = None # Flexible JSON field
|
||||
number_of_investments: int | None
|
||||
|
||||
# Fund fields
|
||||
fund_id: int | None
|
||||
fund_name: str | None
|
||||
fund_size: int | None # Changed to int for numerical filtering
|
||||
fund_size_source_url: str | None
|
||||
check_size_lower: int | None # NEW: Lower bound of check size range
|
||||
check_size_upper: int | None # NEW: Upper bound of check size range
|
||||
geographic_focus: str | None # Changed from List[str] to string
|
||||
fund_investment_stages: (
|
||||
List[InvestmentStageSchema] | None
|
||||
) # Changed to relationship
|
||||
fund_sectors: List[SectorSchema] | None # Changed to relationship
|
||||
|
||||
# Related data
|
||||
portfolio_companies: List[CompanySchema]
|
||||
team_members: List[InvestorMemberSchema]
|
||||
sectors: List[SectorSchema]
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class InvestorMinimal(BaseModel):
|
||||
"""Minimal investor info with just id and name"""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class CompanySchemaMinimal(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
industry: str | None
|
||||
location: str | None
|
||||
founded_year: Optional[int]
|
||||
website: Optional[str]
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
class CompanyData(BaseModel): # Renamed from CompaniesData for consistency
|
||||
company: CompanySchema
|
||||
sectors: List[SectorSchema]
|
||||
members: List[CompanyMemberSchema]
|
||||
investors: List[InvestorSchema]
|
||||
company: CompanySchemaMinimal
|
||||
investors: List[InvestorMinimal]
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
@@ -99,3 +198,65 @@ class CompanyData(BaseModel): # Renamed from CompaniesData for consistency
|
||||
|
||||
class InvestorList(BaseModel):
|
||||
investors: List[InvestorData]
|
||||
|
||||
|
||||
class InvestorFundList(BaseModel):
|
||||
"""List of investor-fund combinations"""
|
||||
|
||||
investor_funds: List[InvestorFundData]
|
||||
|
||||
|
||||
class CompanyMinimal(BaseModel):
|
||||
"""Minimal company info with just id and name"""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class SectorMinimal(BaseModel):
|
||||
"""Minimal sector info with just id and name"""
|
||||
|
||||
id: int
|
||||
name: str
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class InvestmentResponse(BaseModel):
|
||||
"""Simplified investment response schema
|
||||
|
||||
One row per investor-fund combination with streamlined data
|
||||
"""
|
||||
|
||||
id: int # Investor ID
|
||||
name: (
|
||||
str # Combination of investor name and fund name (e.g., "Investor A - Fund A")
|
||||
)
|
||||
aum: int | None # From investor
|
||||
check_size_lower: int | None # From fund
|
||||
check_size_upper: int | None # From fund
|
||||
geographic_focus: str | None # From fund
|
||||
stage_focus: str | None # Comma-separated stages from fund
|
||||
portfolio_companies: List[CompanyMinimal] # Top 3 companies from investor
|
||||
sectors: List[SectorMinimal] # Top 3 sectors from fund
|
||||
compatibility_score: float # 0 to 1 (default 1 for now)
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class PaginatedResponse(BaseModel, Generic[T]):
|
||||
"""Generic paginated response schema"""
|
||||
|
||||
items: List[T]
|
||||
total: int
|
||||
page: int
|
||||
page_size: int
|
||||
total_pages: int
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
Binary file not shown.
Binary file not shown.
+689
-93
@@ -1,5 +1,7 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
from typing import Optional
|
||||
|
||||
import pandas as pd
|
||||
@@ -7,15 +9,35 @@ from db.db import get_db_session
|
||||
from db.models import (
|
||||
CompanyMember,
|
||||
CompanyTable,
|
||||
FundTable,
|
||||
InvestmentStageTable,
|
||||
InvestorMember,
|
||||
InvestorTable,
|
||||
SectorTable,
|
||||
)
|
||||
from langchain_openai import ChatOpenAI
|
||||
from pydantic import BaseModel
|
||||
from schemas.py_schemas import CompanyData, InvestorData
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
|
||||
class CurrencyConversion(BaseModel):
|
||||
"""Schema for LLM currency conversion responses"""
|
||||
|
||||
amount_usd: int = 0
|
||||
confidence: str = "high" # high, medium, low
|
||||
notes: str = ""
|
||||
|
||||
|
||||
class CheckSizeRange(BaseModel):
|
||||
"""Schema for LLM check size range parsing from estimated investment size"""
|
||||
|
||||
lower_bound_usd: int = 0
|
||||
upper_bound_usd: int = 0
|
||||
confidence: str = "high" # high, medium, low
|
||||
notes: str = ""
|
||||
|
||||
|
||||
class InvestorProcessor:
|
||||
def __init__(self):
|
||||
self.llm = ChatOpenAI(
|
||||
@@ -25,9 +47,465 @@ class InvestorProcessor:
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
# Structured LLMs for specific parsing tasks
|
||||
self.currency_converter_llm = self.llm.with_structured_output(
|
||||
CurrencyConversion
|
||||
)
|
||||
self.check_size_parser_llm = self.llm.with_structured_output(CheckSizeRange)
|
||||
|
||||
# Keep legacy structured LLMs for backward compatibility
|
||||
self.investor_structured_llm = self.llm.with_structured_output(InvestorData)
|
||||
self.company_structured_llm = self.llm.with_structured_output(CompanyData)
|
||||
|
||||
async def convert_to_usd(self, amount_str: str) -> Optional[int]:
|
||||
"""
|
||||
Use LLM to convert currency amounts to USD integers.
|
||||
Handles formats like:
|
||||
- "EUR 850,000,000"
|
||||
- "$5M"
|
||||
- "GBP 10-20 million"
|
||||
- "Approximately EUR 100 million"
|
||||
"""
|
||||
if not amount_str or amount_str == "Not Available" or amount_str == "0":
|
||||
return None
|
||||
|
||||
try:
|
||||
prompt = f"""Convert this amount to USD as an integer (whole number, no decimals).
|
||||
If it's a range, use the midpoint. If already in USD, just extract the number.
|
||||
Remove all commas and convert millions/billions to actual numbers.
|
||||
|
||||
Amount: {amount_str}
|
||||
|
||||
Examples:
|
||||
- "EUR 850,000,000" -> 935000000 (assuming EUR to USD rate ~1.10)
|
||||
- "$5M" -> 5000000
|
||||
- "GBP 10-20 million" -> 18000000 (midpoint 15M * 1.20 rate)
|
||||
- "Approximately EUR 100 million" -> 110000000
|
||||
|
||||
Return only the USD integer amount with current exchange rates."""
|
||||
|
||||
result = await self.currency_converter_llm.ainvoke(prompt)
|
||||
return result.amount_usd if result.amount_usd > 0 else None
|
||||
except Exception as e:
|
||||
print(f"Error converting currency '{amount_str}': {e}")
|
||||
return None
|
||||
|
||||
async def parse_check_size_range(
|
||||
self, estimated_investment_str: str
|
||||
) -> tuple[Optional[int], Optional[int]]:
|
||||
"""
|
||||
Use LLM to parse check size range from estimated investment size string.
|
||||
Returns tuple of (lower_bound_usd, upper_bound_usd).
|
||||
|
||||
Handles formats like:
|
||||
- "EUR 1,000 to 2,000"
|
||||
- "$100K-$500K"
|
||||
- "Between $1M and $5M"
|
||||
- "Up to EUR 10 million"
|
||||
- "$2M typical"
|
||||
"""
|
||||
if (
|
||||
not estimated_investment_str
|
||||
or estimated_investment_str == "Not Available"
|
||||
or estimated_investment_str == "0"
|
||||
):
|
||||
return None, None
|
||||
|
||||
try:
|
||||
prompt = f"""Parse this check size/investment range into lower and upper bounds in USD as integers.
|
||||
|
||||
Input: {estimated_investment_str}
|
||||
|
||||
Instructions:
|
||||
- If it's a range (e.g., "EUR 1M to 5M"), extract both bounds
|
||||
- If it's a single amount (e.g., "$2M typical"), use it as both lower and upper
|
||||
- If it says "up to X", use 0 as lower and X as upper
|
||||
- Convert all currencies to USD using current exchange rates
|
||||
- Return integers (whole numbers, no decimals)
|
||||
|
||||
Examples:
|
||||
- "EUR 1,000 to 2,000" -> lower: 1100, upper: 2200
|
||||
- "$100K-$500K" -> lower: 100000, upper: 500000
|
||||
- "Between $1M and $5M" -> lower: 1000000, upper: 5000000
|
||||
- "Up to EUR 10 million" -> lower: 0, upper: 11000000
|
||||
- "$2M typical" -> lower: 2000000, upper: 2000000
|
||||
- "GBP 500K-2M" -> lower: 600000, upper: 2400000
|
||||
|
||||
Return the lower and upper bounds in USD."""
|
||||
|
||||
result = await self.check_size_parser_llm.ainvoke(prompt)
|
||||
lower = result.lower_bound_usd if result.lower_bound_usd > 0 else None
|
||||
upper = result.upper_bound_usd if result.upper_bound_usd > 0 else None
|
||||
return lower, upper
|
||||
except Exception as e:
|
||||
print(f"Error parsing check size range '{estimated_investment_str}': {e}")
|
||||
return None, None
|
||||
|
||||
def parse_json_profile(self, json_str: str) -> Optional[dict]:
|
||||
"""
|
||||
Manually parse the JSON profile from the CSV.
|
||||
Returns a cleaned dictionary with the investor profile data.
|
||||
"""
|
||||
if not json_str or pd.isna(json_str):
|
||||
return None
|
||||
|
||||
try:
|
||||
# Parse JSON string
|
||||
profile = json.loads(json_str)
|
||||
return profile
|
||||
except json.JSONDecodeError as e:
|
||||
print(f"Error parsing JSON: {e}")
|
||||
return None
|
||||
|
||||
async def process_investor_profile(
|
||||
self, name: str, website: str, profile_json: str
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Process investor profile from CSV data.
|
||||
Manually extracts fields and uses LLM only for currency conversion.
|
||||
"""
|
||||
profile = self.parse_json_profile(profile_json)
|
||||
if not profile:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Extract basic info
|
||||
investor_data = {
|
||||
"name": name.strip() if name else None,
|
||||
"website": website.strip() if website else None,
|
||||
"headquarters": profile.get("headquarters"),
|
||||
"description": profile.get("investorDescription"),
|
||||
"aum": None,
|
||||
"aum_as_of_date": None,
|
||||
"aum_source_url": None,
|
||||
"investment_thesis": profile.get("investmentThesisFocus", []),
|
||||
"portfolio_highlights": profile.get("portfolioHighlights", []),
|
||||
"linked_documents": profile.get("linkedDocuments", []),
|
||||
"researcher_notes": profile.get("researcherNotes"),
|
||||
"missing_important_fields": profile.get("missingImportantFields", []),
|
||||
"sources": profile.get("sources", {}),
|
||||
"team_members": [],
|
||||
"funds": [],
|
||||
}
|
||||
|
||||
# Process AUM
|
||||
aum_data = profile.get("overallAssetsUnderManagement", {})
|
||||
if aum_data and isinstance(aum_data, dict):
|
||||
aum_amount = aum_data.get("aumAmount")
|
||||
if aum_amount and aum_amount != "Not Available":
|
||||
# Convert AUM to USD integer
|
||||
aum_usd = await self.convert_to_usd(aum_amount)
|
||||
investor_data["aum"] = aum_usd
|
||||
investor_data["aum_as_of_date"] = aum_data.get("asOfDate")
|
||||
investor_data["aum_source_url"] = aum_data.get("sourceUrl")
|
||||
|
||||
# Process senior leadership
|
||||
senior_leadership = profile.get("seniorLeadership", [])
|
||||
for member in senior_leadership:
|
||||
if isinstance(member, dict) and member.get("name"):
|
||||
investor_data["team_members"].append(
|
||||
{
|
||||
"name": member.get("name"),
|
||||
"title": member.get("title"),
|
||||
"role": member.get("title"), # Use title as role
|
||||
"email": None,
|
||||
"source_url": member.get("sourceUrl"),
|
||||
}
|
||||
)
|
||||
|
||||
# Process funds
|
||||
funds = profile.get("funds", [])
|
||||
for fund in funds:
|
||||
if isinstance(fund, dict):
|
||||
fund_data = {
|
||||
"fund_name": fund.get("fundName"),
|
||||
"fund_size": None,
|
||||
"fund_size_source_url": fund.get("fundSizeSourceUrl"),
|
||||
"check_size_lower": None,
|
||||
"check_size_upper": None,
|
||||
"source_url": fund.get("sourceUrl"),
|
||||
"source_provider": fund.get("sourceProvider"),
|
||||
"geographic_focus": None, # Will be converted to string
|
||||
"investment_stage_names": fund.get("investmentStageFocus", []),
|
||||
"sector_names": fund.get("sectorFocus", []),
|
||||
}
|
||||
|
||||
# Convert geographic focus from array to comma-separated string
|
||||
geo_focus = fund.get("geographicFocus", [])
|
||||
if geo_focus and isinstance(geo_focus, list):
|
||||
fund_data["geographic_focus"] = ", ".join(geo_focus)
|
||||
|
||||
# Convert fund size to USD integer
|
||||
fund_size_str = fund.get("fundSize")
|
||||
if fund_size_str and fund_size_str != "Not Available":
|
||||
fund_size_usd = await self.convert_to_usd(fund_size_str)
|
||||
if fund_size_usd:
|
||||
fund_data["fund_size"] = fund_size_usd # Store as integer
|
||||
|
||||
# Parse check size range from estimated investment size
|
||||
est_size_str = fund.get("estimatedInvestmentSize")
|
||||
if est_size_str and est_size_str != "Not Available":
|
||||
check_lower, check_upper = await self.parse_check_size_range(
|
||||
est_size_str
|
||||
)
|
||||
if check_lower is not None:
|
||||
fund_data["check_size_lower"] = check_lower
|
||||
if check_upper is not None:
|
||||
fund_data["check_size_upper"] = check_upper
|
||||
|
||||
investor_data["funds"].append(fund_data)
|
||||
|
||||
return investor_data
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing investor profile for {name}: {e}")
|
||||
return None
|
||||
|
||||
async def process_company_profile(
|
||||
self, name: str, website: str, profile_json: str, investor_names: str = None
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
Process company profile from CSV data.
|
||||
Only extracts founded_year and key_executives - rest is in base database.
|
||||
"""
|
||||
profile = self.parse_json_profile(profile_json)
|
||||
if not profile:
|
||||
return None
|
||||
|
||||
try:
|
||||
# Only extract founded_year and key_executives
|
||||
company_data = {
|
||||
"name": name.strip() if name else None,
|
||||
"founded_year": None,
|
||||
"key_executives": [],
|
||||
}
|
||||
|
||||
# Process key executives/leadership
|
||||
key_executives = profile.get("keyExecutives", [])
|
||||
if not key_executives:
|
||||
# Try alternative field names
|
||||
key_executives = profile.get("seniorLeadership", [])
|
||||
|
||||
for exec_member in key_executives:
|
||||
if isinstance(exec_member, dict) and exec_member.get("name"):
|
||||
company_data["key_executives"].append(
|
||||
{
|
||||
"name": exec_member.get("name"),
|
||||
"title": exec_member.get("title"),
|
||||
"source_url": exec_member.get("sourceUrl"),
|
||||
}
|
||||
)
|
||||
|
||||
# Try to extract founding year from description
|
||||
description = profile.get("companyDescription", "")
|
||||
if description:
|
||||
# Look for patterns like "founded in 2020", "Gegründet 2020", "founded 2020"
|
||||
year_patterns = [
|
||||
r"founded in (\d{4})",
|
||||
r"founded (\d{4})",
|
||||
r"Gegründet (\d{4})",
|
||||
r"established in (\d{4})",
|
||||
r"since (\d{4})",
|
||||
r"\((\d{4})\)", # Year in parentheses
|
||||
]
|
||||
for pattern in year_patterns:
|
||||
match = re.search(pattern, description, re.IGNORECASE)
|
||||
if match:
|
||||
try:
|
||||
year = int(match.group(1))
|
||||
if 1900 <= year <= 2025: # Sanity check
|
||||
company_data["founded_year"] = year
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return company_data
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing company profile for {name}: {e}")
|
||||
return None
|
||||
|
||||
def _save_parsed_company_to_db(
|
||||
self, db: Session, company_data: dict
|
||||
) -> Optional[CompanyTable]:
|
||||
"""Save manually parsed company data to database - only updates founded_year and key_executives"""
|
||||
try:
|
||||
# Check if company already exists (should exist in base database)
|
||||
existing_company = (
|
||||
db.query(CompanyTable).filter_by(name=company_data["name"]).first()
|
||||
)
|
||||
|
||||
if existing_company:
|
||||
# Update only founded_year on existing company
|
||||
company = existing_company
|
||||
if company_data.get("founded_year"):
|
||||
company.founded_year = company_data["founded_year"]
|
||||
else:
|
||||
# Company should already be in base database, but if not found, skip
|
||||
print(
|
||||
f"⚠️ Company '{company_data['name']}' not found in base database - skipping"
|
||||
)
|
||||
return None
|
||||
|
||||
# Add/update company members (key executives)
|
||||
# First, remove existing members if updating
|
||||
db.query(CompanyMember).filter_by(company_id=company.id).delete()
|
||||
|
||||
for exec_data in company_data.get("key_executives", []):
|
||||
member = CompanyMember(
|
||||
name=exec_data.get("name"),
|
||||
role=exec_data.get("title"),
|
||||
linkedin=exec_data.get(
|
||||
"source_url"
|
||||
), # Store source URL in linkedin field
|
||||
company_id=company.id,
|
||||
)
|
||||
db.add(member)
|
||||
|
||||
return company
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error saving company to database: {e}")
|
||||
db.rollback()
|
||||
return None
|
||||
|
||||
def _save_parsed_investor_to_db(
|
||||
self, db: Session, investor_data: dict
|
||||
) -> Optional[InvestorTable]:
|
||||
"""Save manually parsed investor data to database"""
|
||||
try:
|
||||
# Check if investor already exists
|
||||
existing_investor = (
|
||||
db.query(InvestorTable).filter_by(name=investor_data["name"]).first()
|
||||
)
|
||||
|
||||
if existing_investor:
|
||||
# Update existing investor
|
||||
investor = existing_investor
|
||||
investor.website = investor_data.get("website") or investor.website
|
||||
investor.headquarters = (
|
||||
investor_data.get("headquarters") or investor.headquarters
|
||||
)
|
||||
investor.description = (
|
||||
investor_data.get("description") or investor.description
|
||||
)
|
||||
investor.aum = investor_data.get("aum") or investor.aum
|
||||
investor.aum_as_of_date = (
|
||||
investor_data.get("aum_as_of_date") or investor.aum_as_of_date
|
||||
)
|
||||
investor.aum_source_url = (
|
||||
investor_data.get("aum_source_url") or investor.aum_source_url
|
||||
)
|
||||
investor.investment_thesis = (
|
||||
investor_data.get("investment_thesis") or investor.investment_thesis
|
||||
)
|
||||
investor.portfolio_highlights = (
|
||||
investor_data.get("portfolio_highlights")
|
||||
or investor.portfolio_highlights
|
||||
)
|
||||
investor.linked_documents = (
|
||||
investor_data.get("linked_documents") or investor.linked_documents
|
||||
)
|
||||
investor.researcher_notes = (
|
||||
investor_data.get("researcher_notes") or investor.researcher_notes
|
||||
)
|
||||
investor.missing_important_fields = (
|
||||
investor_data.get("missing_important_fields")
|
||||
or investor.missing_important_fields
|
||||
)
|
||||
investor.sources = investor_data.get("sources") or investor.sources
|
||||
else:
|
||||
# Create new investor
|
||||
investor = InvestorTable(
|
||||
name=investor_data["name"],
|
||||
website=investor_data.get("website"),
|
||||
headquarters=investor_data.get("headquarters"),
|
||||
description=investor_data.get("description"),
|
||||
aum=investor_data.get("aum"),
|
||||
aum_as_of_date=investor_data.get("aum_as_of_date"),
|
||||
aum_source_url=investor_data.get("aum_source_url"),
|
||||
investment_thesis=investor_data.get("investment_thesis"),
|
||||
portfolio_highlights=investor_data.get("portfolio_highlights"),
|
||||
linked_documents=investor_data.get("linked_documents"),
|
||||
researcher_notes=investor_data.get("researcher_notes"),
|
||||
missing_important_fields=investor_data.get(
|
||||
"missing_important_fields"
|
||||
),
|
||||
sources=investor_data.get("sources"),
|
||||
)
|
||||
db.add(investor)
|
||||
db.flush()
|
||||
|
||||
# Add/update team members
|
||||
# First, remove existing team members if updating
|
||||
if existing_investor:
|
||||
db.query(InvestorMember).filter_by(investor_id=investor.id).delete()
|
||||
|
||||
for member_data in investor_data.get("team_members", []):
|
||||
member = InvestorMember(
|
||||
name=member_data.get("name"),
|
||||
role=member_data.get("role"),
|
||||
title=member_data.get("title"),
|
||||
email=member_data.get("email"),
|
||||
source_url=member_data.get("source_url"),
|
||||
investor_id=investor.id,
|
||||
)
|
||||
db.add(member)
|
||||
|
||||
# Add/update funds
|
||||
# First, remove existing funds if updating
|
||||
if existing_investor:
|
||||
db.query(FundTable).filter_by(investor_id=investor.id).delete()
|
||||
|
||||
for fund_data in investor_data.get("funds", []):
|
||||
fund = FundTable(
|
||||
investor_id=investor.id,
|
||||
fund_name=fund_data.get("fund_name"),
|
||||
fund_size=fund_data.get("fund_size"), # Now an integer
|
||||
fund_size_source_url=fund_data.get("fund_size_source_url"),
|
||||
check_size_lower=fund_data.get("check_size_lower"),
|
||||
check_size_upper=fund_data.get("check_size_upper"),
|
||||
source_url=fund_data.get("source_url"),
|
||||
source_provider=fund_data.get("source_provider"),
|
||||
geographic_focus=fund_data.get("geographic_focus"), # Now a string
|
||||
)
|
||||
db.add(fund)
|
||||
db.flush() # Get the fund ID
|
||||
|
||||
# Add investment stages (many-to-many)
|
||||
for stage_name in fund_data.get("investment_stage_names", []):
|
||||
stage = self._get_or_create_investment_stage(db, stage_name)
|
||||
fund.investment_stages.append(stage)
|
||||
|
||||
# Add sectors (many-to-many)
|
||||
for sector_name in fund_data.get("sector_names", []):
|
||||
sector = self._get_or_create_sector(db, sector_name)
|
||||
fund.sectors.append(sector)
|
||||
|
||||
return investor
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error saving investor to database: {e}")
|
||||
db.rollback()
|
||||
return None
|
||||
|
||||
def _get_or_create_investment_stage(
|
||||
self, db: Session, stage_name: str
|
||||
) -> InvestmentStageTable:
|
||||
"""Get existing investment stage or create new one"""
|
||||
from db.models import InvestmentStageTable
|
||||
|
||||
stage = (
|
||||
db.query(InvestmentStageTable)
|
||||
.filter(InvestmentStageTable.name == stage_name)
|
||||
.first()
|
||||
)
|
||||
if not stage:
|
||||
stage = InvestmentStageTable(name=stage_name)
|
||||
db.add(stage)
|
||||
db.flush() # Get the ID without committing
|
||||
return stage
|
||||
|
||||
def _get_or_create_sector(self, db: Session, sector_name: str) -> SectorTable:
|
||||
"""Get existing sector or create new one"""
|
||||
sector = db.query(SectorTable).filter(SectorTable.name == sector_name).first()
|
||||
@@ -49,7 +527,6 @@ class InvestorProcessor:
|
||||
check_size_lower=investor_data.investor.check_size_lower,
|
||||
check_size_upper=investor_data.investor.check_size_upper,
|
||||
geographic_focus=investor_data.investor.geographic_focus,
|
||||
stage_focus=investor_data.investor.stage_focus,
|
||||
number_of_investments=investor_data.investor.number_of_investments,
|
||||
)
|
||||
db.add(investor)
|
||||
@@ -173,141 +650,260 @@ class InvestorProcessor:
|
||||
print(f"Error processing row {row_idx + 1}: {e}")
|
||||
return None
|
||||
|
||||
async def parse_investors(self, df, save_to_db: bool = True):
|
||||
"""Parse investors from DataFrame and optionally save to database"""
|
||||
investors = []
|
||||
df = df[20:]
|
||||
async def _process_single_investor(
|
||||
self, idx: int, row: pd.Series, total_rows: int
|
||||
) -> Optional[dict]:
|
||||
"""Process a single investor row"""
|
||||
try:
|
||||
name = row.get("Name", "").strip() if pd.notna(row.get("Name")) else None
|
||||
website = (
|
||||
row.get("Website", "").strip() if pd.notna(row.get("Website")) else None
|
||||
)
|
||||
profile_json = (
|
||||
row.get("Final Investor Profile", "")
|
||||
if pd.notna(row.get("Final Investor Profile"))
|
||||
else None
|
||||
)
|
||||
|
||||
if not name or not profile_json:
|
||||
print(f"⚠️ Row {idx + 1}: Skipping - missing name or profile")
|
||||
return None
|
||||
|
||||
print(f"📊 Processing {idx + 1}/{total_rows}: {name}")
|
||||
|
||||
# Process the investor profile
|
||||
investor_data = await self.process_investor_profile(
|
||||
name, website, profile_json
|
||||
)
|
||||
|
||||
if investor_data:
|
||||
print(f" ✓ {name} parsed successfully")
|
||||
return investor_data
|
||||
else:
|
||||
print(f" ⚠️ {name} failed to process")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error processing row {idx + 1}: {e}")
|
||||
return None
|
||||
|
||||
async def parse_investors(
|
||||
self, df: pd.DataFrame, save_to_db: bool = True, batch_size: int = 10
|
||||
):
|
||||
"""
|
||||
Parse investors from DataFrame using manual JSON parsing and LLM for currency conversion.
|
||||
Processes multiple investors concurrently for better performance.
|
||||
Expected CSV columns: Name, Website, Final Investor Profile, Final Profile sourcing
|
||||
|
||||
Args:
|
||||
df: DataFrame with investor data
|
||||
save_to_db: Whether to save to database
|
||||
batch_size: Number of investors to process concurrently (default: 10)
|
||||
"""
|
||||
results = []
|
||||
db = None
|
||||
if save_to_db:
|
||||
db = get_db_session()
|
||||
|
||||
try:
|
||||
# Process rows in batches asynchronously
|
||||
batch_size = 20 # Adjust batch size as needed
|
||||
rows = [(idx, row) for idx, row in df.iterrows()]
|
||||
|
||||
for i in range(0, len(rows), batch_size):
|
||||
batch = rows[i : i + batch_size]
|
||||
|
||||
# Process batch asynchronously
|
||||
tasks = [
|
||||
self._process_row(row, idx, is_investor=True) for idx, row in batch
|
||||
]
|
||||
|
||||
batch_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Handle results from batch
|
||||
for (idx, row), result in zip(batch, batch_results):
|
||||
if isinstance(result, Exception):
|
||||
print(f"Error processing row {idx}: {result}")
|
||||
if db:
|
||||
db.rollback()
|
||||
continue
|
||||
|
||||
if result:
|
||||
# Convert dict to InvestorData if needed
|
||||
if isinstance(result, dict):
|
||||
investor_data = InvestorData(**result)
|
||||
else:
|
||||
investor_data = result
|
||||
|
||||
investors.append(investor_data)
|
||||
|
||||
# Save to database if requested
|
||||
if save_to_db and db:
|
||||
try:
|
||||
saved_investor = self._save_investor_to_db(
|
||||
db, investor_data
|
||||
)
|
||||
db.commit()
|
||||
print(
|
||||
f"✅ Saved investor '{saved_investor.name}' to database"
|
||||
)
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
print(f"❌ Failed to save investor to database: {e}")
|
||||
total_rows = len(df)
|
||||
print(
|
||||
f"\n🚀 Starting to process {total_rows} investors with batch size {batch_size}..."
|
||||
)
|
||||
|
||||
# Process in batches
|
||||
for batch_start in range(0, total_rows, batch_size):
|
||||
batch_end = min(batch_start + batch_size, total_rows)
|
||||
print(
|
||||
f"Completed batch {i // batch_size + 1} of {(len(rows) + batch_size - 1) // batch_size}"
|
||||
f"\n🔄 Processing batch {batch_start + 1}-{batch_end} of {total_rows}..."
|
||||
)
|
||||
|
||||
# Create tasks for concurrent processing
|
||||
tasks = []
|
||||
for idx in range(batch_start, batch_end):
|
||||
row = df.iloc[idx]
|
||||
task = self._process_single_investor(idx, row, total_rows)
|
||||
tasks.append(task)
|
||||
|
||||
# Process batch concurrently
|
||||
batch_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Filter out None results and exceptions, then save to database
|
||||
for investor_data in batch_results:
|
||||
if investor_data and not isinstance(investor_data, Exception):
|
||||
results.append(investor_data)
|
||||
|
||||
# Save to database
|
||||
if save_to_db and db:
|
||||
try:
|
||||
saved_investor = self._save_parsed_investor_to_db(
|
||||
db, investor_data
|
||||
)
|
||||
if saved_investor:
|
||||
print(
|
||||
f" ✅ Saved {investor_data['name']} to database (ID: {saved_investor.id})"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f" ❌ Failed to save {investor_data['name']} to database"
|
||||
)
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
print(
|
||||
f" ❌ Database error for {investor_data['name']}: {e}"
|
||||
)
|
||||
elif isinstance(investor_data, Exception):
|
||||
print(f" ❌ Exception occurred: {investor_data}")
|
||||
|
||||
# Commit batch to database
|
||||
if save_to_db and db:
|
||||
try:
|
||||
db.commit()
|
||||
print(f"💾 Committed batch {batch_start + 1}-{batch_end}")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
print(f"❌ Failed to commit batch: {e}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error in batch processing: {e}")
|
||||
print(f"❌ Fatal error in parse_investors: {e}")
|
||||
if db:
|
||||
db.rollback()
|
||||
finally:
|
||||
if db:
|
||||
db.close()
|
||||
|
||||
return investors
|
||||
print(f"\n🎉 Completed! Processed {len(results)}/{total_rows} investors")
|
||||
return results
|
||||
|
||||
async def parse_companies(self, df, save_to_db: bool = True):
|
||||
"""Parse companies from DataFrame and optionally save to database"""
|
||||
companies = []
|
||||
df = df[20:]
|
||||
async def _process_single_company(
|
||||
self, idx: int, row: pd.Series, total_rows: int
|
||||
) -> Optional[dict]:
|
||||
"""Process a single company row"""
|
||||
try:
|
||||
name = row.get("Name", "").strip() if pd.notna(row.get("Name")) else None
|
||||
website = (
|
||||
row.get("Website", "").strip() if pd.notna(row.get("Website")) else None
|
||||
)
|
||||
investor_names = (
|
||||
row.get("Investor", "").strip()
|
||||
if pd.notna(row.get("Investor"))
|
||||
else None
|
||||
)
|
||||
profile_json = (
|
||||
row.get("Final Investor Profile", "")
|
||||
if pd.notna(row.get("Final Investor Profile"))
|
||||
else None
|
||||
)
|
||||
|
||||
if not name or not profile_json:
|
||||
print(f"⚠️ Row {idx + 1}: Skipping - missing name or profile")
|
||||
return None
|
||||
|
||||
print(f"📊 Processing {idx + 1}/{total_rows}: {name}")
|
||||
|
||||
# Process the company profile
|
||||
company_data = await self.process_company_profile(
|
||||
name, website, profile_json, investor_names
|
||||
)
|
||||
|
||||
if company_data:
|
||||
print(f" ✓ {name} parsed successfully")
|
||||
return company_data
|
||||
else:
|
||||
print(f" ⚠️ {name} failed to process")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error processing row {idx + 1}: {e}")
|
||||
return None
|
||||
|
||||
async def parse_companies(
|
||||
self, df: pd.DataFrame, save_to_db: bool = True, batch_size: int = 10
|
||||
):
|
||||
"""
|
||||
Parse companies from DataFrame using manual JSON parsing.
|
||||
Processes multiple companies concurrently for better performance.
|
||||
Expected CSV columns: Name, Website, Investor, Final Investor Profile (actually company profile)
|
||||
|
||||
Args:
|
||||
df: DataFrame with company data
|
||||
save_to_db: Whether to save to database
|
||||
batch_size: Number of companies to process concurrently (default: 10)
|
||||
"""
|
||||
results = []
|
||||
db = None
|
||||
if save_to_db:
|
||||
db = get_db_session()
|
||||
|
||||
try:
|
||||
# Process rows in batches asynchronously
|
||||
batch_size = 20 # Adjust batch size as needed
|
||||
rows = [(idx, row) for idx, row in df.iterrows()]
|
||||
total_rows = len(df)
|
||||
print(
|
||||
f"\n🚀 Starting to process {total_rows} companies with batch size {batch_size}..."
|
||||
)
|
||||
|
||||
for i in range(0, len(rows), batch_size):
|
||||
batch = rows[i : i + batch_size]
|
||||
# Process in batches
|
||||
for batch_start in range(0, total_rows, batch_size):
|
||||
batch_end = min(batch_start + batch_size, total_rows)
|
||||
print(
|
||||
f"\n🔄 Processing batch {batch_start + 1}-{batch_end} of {total_rows}..."
|
||||
)
|
||||
|
||||
# Process batch asynchronously
|
||||
tasks = [
|
||||
self._process_row(row, idx, is_investor=False) for idx, row in batch
|
||||
]
|
||||
# Create tasks for concurrent processing
|
||||
tasks = []
|
||||
for idx in range(batch_start, batch_end):
|
||||
row = df.iloc[idx]
|
||||
task = self._process_single_company(idx, row, total_rows)
|
||||
tasks.append(task)
|
||||
|
||||
# Process batch concurrently
|
||||
batch_results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Handle results from batch
|
||||
for (idx, row), result in zip(batch, batch_results):
|
||||
if isinstance(result, Exception):
|
||||
print(f"Error processing row {idx}: {result}")
|
||||
if db:
|
||||
db.rollback()
|
||||
continue
|
||||
# Filter out None results and exceptions, then save to database
|
||||
for company_data in batch_results:
|
||||
if company_data and not isinstance(company_data, Exception):
|
||||
results.append(company_data)
|
||||
|
||||
if result:
|
||||
# Convert dict to CompanyData if needed
|
||||
if isinstance(result, dict):
|
||||
company_data = CompanyData(**result)
|
||||
else:
|
||||
company_data = result
|
||||
|
||||
companies.append(company_data)
|
||||
|
||||
# Save to database if requested
|
||||
# Save to database
|
||||
if save_to_db and db:
|
||||
try:
|
||||
saved_company = self._save_company_to_db(
|
||||
saved_company = self._save_parsed_company_to_db(
|
||||
db, company_data
|
||||
)
|
||||
db.commit()
|
||||
print(
|
||||
f"✅ Saved company '{saved_company.name}' to database"
|
||||
)
|
||||
if saved_company:
|
||||
print(
|
||||
f" ✅ Saved {company_data['name']} to database (ID: {saved_company.id})"
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f" ❌ Failed to save {company_data['name']} to database"
|
||||
)
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
print(f"❌ Failed to save company to database: {e}")
|
||||
print(
|
||||
f" ❌ Database error for {company_data['name']}: {e}"
|
||||
)
|
||||
elif isinstance(company_data, Exception):
|
||||
print(f" ❌ Exception occurred: {company_data}")
|
||||
|
||||
print(
|
||||
f"Completed batch {i // batch_size + 1} of {(len(rows) + batch_size - 1) // batch_size}"
|
||||
)
|
||||
# Commit batch to database
|
||||
if save_to_db and db:
|
||||
try:
|
||||
db.commit()
|
||||
print(f"💾 Committed batch {batch_start + 1}-{batch_end}")
|
||||
except Exception as e:
|
||||
db.rollback()
|
||||
print(f"❌ Failed to commit batch: {e}")
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error processing row {idx}: {e}")
|
||||
print(f"❌ Fatal error in parse_companies: {e}")
|
||||
if db:
|
||||
db.rollback()
|
||||
finally:
|
||||
if db:
|
||||
db.close()
|
||||
|
||||
return companies
|
||||
print(f"\n🎉 Completed! Processed {len(results)}/{total_rows} companies")
|
||||
return results
|
||||
|
||||
|
||||
# async def main():
|
||||
|
||||
+100
-41
@@ -2,13 +2,18 @@ import os
|
||||
from typing import List
|
||||
|
||||
from db.db import DATABASE_URL, get_db
|
||||
from db.models import InvestorTable
|
||||
from db.models import FundTable, InvestorTable
|
||||
from langchain import hub
|
||||
from langchain_community.agent_toolkits import SQLDatabaseToolkit
|
||||
from langchain_community.utilities import SQLDatabase
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langgraph.prebuilt import create_react_agent
|
||||
from schemas.py_schemas import InvestorData, InvestorList
|
||||
from schemas.router_schemas import (
|
||||
CompanyMinimal,
|
||||
InvestmentResponse,
|
||||
PaginatedResponse,
|
||||
SectorMinimal,
|
||||
)
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
# Connect to SQLite
|
||||
@@ -21,16 +26,16 @@ class QueryProcessor:
|
||||
self.llm = ChatOpenAI(
|
||||
api_key=os.getenv("OPENROUTER_API_KEY"),
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="openai/gpt-4o-mini",
|
||||
model="x-ai/grok-4-fast",
|
||||
temperature=0,
|
||||
)
|
||||
self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)
|
||||
# Update system message to specifically request only investor IDs
|
||||
# Update system message to specifically request only fund IDs
|
||||
system_message_updated = (
|
||||
prompt_template.format(dialect="SQLite", top_k=5)
|
||||
+ "\n\nIMPORTANT: You must ONLY return the investor IDs (id field) that match the user's criteria. "
|
||||
+ "\n\nIMPORTANT: You must ONLY return the fund IDs (id field from the funds table) that match the user's criteria. "
|
||||
+ "Do NOT return any other information, explanations, or data. "
|
||||
+ "Your response should be ONLY a comma-separated list of numbers representing the investor IDs. "
|
||||
+ "Your response should be ONLY a comma-separated list of numbers representing the fund IDs. "
|
||||
+ "Example format: 1, 5, 12, 23"
|
||||
)
|
||||
self.agent = create_react_agent(
|
||||
@@ -39,9 +44,9 @@ class QueryProcessor:
|
||||
prompt=system_message_updated,
|
||||
)
|
||||
|
||||
def process_query(self, question: str) -> InvestorList:
|
||||
"""Process a query using the LLM and return investor data."""
|
||||
# Let the LLM handle all database interactions and filtering to get IDs
|
||||
def process_query(self, question: str) -> PaginatedResponse[InvestmentResponse]:
|
||||
"""Process a query using the LLM and return investment response data."""
|
||||
# Let the LLM handle all database interactions and filtering to get fund IDs
|
||||
response = self.agent.invoke(
|
||||
{"messages": [("user", question)]},
|
||||
)
|
||||
@@ -51,68 +56,122 @@ class QueryProcessor:
|
||||
response["messages"][-1].content if response.get("messages") else ""
|
||||
)
|
||||
|
||||
# Extract investor IDs from the AI response
|
||||
investor_ids = self._extract_investor_ids_from_response(ai_response)
|
||||
# Extract fund IDs from the AI response
|
||||
fund_ids = self._extract_fund_ids_from_response(ai_response)
|
||||
|
||||
# Fetch full investor data using the IDs
|
||||
return self._fetch_investors_by_ids(investor_ids)
|
||||
# Fetch full fund data with investor relationships using the IDs
|
||||
return self._fetch_funds_by_ids(fund_ids)
|
||||
|
||||
def _extract_investor_ids_from_response(self, ai_response: str) -> List[int]:
|
||||
"""Extract investor IDs from AI response."""
|
||||
def _extract_fund_ids_from_response(self, ai_response: str) -> List[int]:
|
||||
"""Extract fund IDs from AI response."""
|
||||
import re
|
||||
|
||||
investor_ids = []
|
||||
fund_ids = []
|
||||
try:
|
||||
# Try multiple patterns to extract IDs from the response
|
||||
# Pattern 1: Simple numbers (assuming they are IDs)
|
||||
numbers = re.findall(r"\b\d+\b", ai_response)
|
||||
investor_ids = [int(num) for num in numbers]
|
||||
fund_ids = [int(num) for num in numbers]
|
||||
|
||||
# Pattern 2: If response contains explicit ID references
|
||||
id_matches = re.findall(r"\bid[:\s]*(\d+)", ai_response.lower())
|
||||
if id_matches:
|
||||
investor_ids = [int(id_str) for id_str in id_matches]
|
||||
fund_ids = [int(id_str) for id_str in id_matches]
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error extracting IDs from response: {e}")
|
||||
return []
|
||||
|
||||
return investor_ids
|
||||
return fund_ids
|
||||
|
||||
def _fetch_investors_by_ids(self, investor_ids: List[int]) -> InvestorList:
|
||||
"""Fetch investors with all their relationships from the database using IDs."""
|
||||
if not investor_ids:
|
||||
return InvestorList(investors=[])
|
||||
def _fetch_funds_by_ids(
|
||||
self, fund_ids: List[int]
|
||||
) -> PaginatedResponse[InvestmentResponse]:
|
||||
"""Fetch funds with all their relationships from the database using fund IDs.
|
||||
Constructs response similar to read_investors but starting from funds."""
|
||||
if not fund_ids:
|
||||
return PaginatedResponse(
|
||||
items=[],
|
||||
total=0,
|
||||
page=1,
|
||||
page_size=len(fund_ids) if fund_ids else 10,
|
||||
total_pages=0,
|
||||
)
|
||||
|
||||
# Get database session
|
||||
db_session = next(get_db())
|
||||
|
||||
try:
|
||||
# Build query with all relationships loaded
|
||||
query = (
|
||||
db_session.query(InvestorTable)
|
||||
# Query funds with all necessary relationships loaded
|
||||
funds = (
|
||||
db_session.query(FundTable)
|
||||
.options(
|
||||
selectinload(InvestorTable.portfolio_companies),
|
||||
selectinload(InvestorTable.team_members),
|
||||
selectinload(InvestorTable.sectors),
|
||||
selectinload(FundTable.investor).selectinload(
|
||||
InvestorTable.portfolio_companies
|
||||
),
|
||||
selectinload(FundTable.investor).selectinload(
|
||||
InvestorTable.team_members
|
||||
),
|
||||
selectinload(FundTable.investor).selectinload(
|
||||
InvestorTable.sectors
|
||||
),
|
||||
selectinload(FundTable.investment_stages),
|
||||
selectinload(FundTable.sectors),
|
||||
)
|
||||
.filter(InvestorTable.id.in_(investor_ids))
|
||||
.filter(FundTable.id.in_(fund_ids))
|
||||
.all()
|
||||
)
|
||||
|
||||
investors = query.all()
|
||||
# Transform to InvestmentResponse format (one row per fund)
|
||||
investment_responses = []
|
||||
for fund in funds:
|
||||
investor = fund.investor
|
||||
|
||||
# Transform to InvestorData format
|
||||
investor_data_list = []
|
||||
for investor in investors:
|
||||
investor_data = InvestorData(
|
||||
investor=investor,
|
||||
portfolio_companies=investor.portfolio_companies,
|
||||
team_members=investor.team_members,
|
||||
sectors=investor.sectors,
|
||||
# Get top 3 portfolio companies (id and name only)
|
||||
portfolio_companies = [
|
||||
CompanyMinimal(id=company.id, name=company.name)
|
||||
for company in investor.portfolio_companies[:3]
|
||||
]
|
||||
|
||||
# Get stage focus as comma-separated string
|
||||
stage_focus = (
|
||||
", ".join([stage.name for stage in fund.investment_stages])
|
||||
if fund.investment_stages
|
||||
else None
|
||||
)
|
||||
investor_data_list.append(investor_data)
|
||||
|
||||
return InvestorList(investors=investor_data_list)
|
||||
# Get top 3 sectors from fund (id and name only)
|
||||
fund_sectors = [
|
||||
SectorMinimal(id=sector.id, name=sector.name)
|
||||
for sector in (fund.sectors[:3] if fund.sectors else [])
|
||||
]
|
||||
|
||||
investment_response = InvestmentResponse(
|
||||
id=investor.id,
|
||||
name=f"{investor.name} - {fund.fund_name}"
|
||||
if fund.fund_name
|
||||
else investor.name,
|
||||
aum=investor.aum,
|
||||
check_size_lower=fund.check_size_lower,
|
||||
check_size_upper=fund.check_size_upper,
|
||||
geographic_focus=fund.geographic_focus,
|
||||
stage_focus=stage_focus,
|
||||
portfolio_companies=portfolio_companies,
|
||||
sectors=fund_sectors,
|
||||
compatibility_score=1.0,
|
||||
)
|
||||
investment_responses.append(investment_response)
|
||||
|
||||
total_count = len(investment_responses)
|
||||
total_pages = 1 if total_count > 0 else 0
|
||||
|
||||
return PaginatedResponse(
|
||||
items=investment_responses,
|
||||
total=total_count,
|
||||
page=1,
|
||||
page_size=total_count,
|
||||
total_pages=total_pages,
|
||||
)
|
||||
|
||||
finally:
|
||||
db_session.close()
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Binary file not shown.
@@ -0,0 +1,315 @@
|
||||
import logging
|
||||
import re
|
||||
import unicodedata
|
||||
|
||||
import pandas as pd
|
||||
from models import CompanyTable, InvestorTable, SectorTable, engine, init_database
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
# Set up logging
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Import the schema
|
||||
init_database()
|
||||
|
||||
|
||||
# ===================== Ingesting Original Data =====================#
|
||||
def parse_investor_names(investor_names_str):
|
||||
"""Parse comma-separated investor names and return a list"""
|
||||
if pd.isna(investor_names_str) or investor_names_str == "":
|
||||
return []
|
||||
|
||||
# Split by comma and clean whitespace
|
||||
# investors = [name.strip() for name in str(investor_names_str).split(",")]
|
||||
investors = [
|
||||
clean_name(name.strip()) for name in str(investor_names_str).split(",")
|
||||
]
|
||||
return [investor for investor in investors if investor]
|
||||
|
||||
|
||||
def parse_industries(industries_str):
|
||||
"""Parse comma-separated industries and return a list"""
|
||||
if pd.isna(industries_str) or industries_str == "":
|
||||
return []
|
||||
|
||||
# Split by comma and clean whitespace
|
||||
industries = [industry.strip() for industry in str(industries_str).split(",")]
|
||||
return [industry for industry in industries if industry]
|
||||
|
||||
|
||||
def clean_special_characters(text):
|
||||
"""Clean special characters from text, converting to ASCII equivalents"""
|
||||
if not text:
|
||||
return text
|
||||
|
||||
# First remove ellipses and other problematic patterns
|
||||
text = str(text).replace("...", "").replace("..", "")
|
||||
|
||||
# Normalize unicode characters to their closest ASCII equivalents
|
||||
normalized = unicodedata.normalize("NFKD", text)
|
||||
|
||||
# Remove accents and convert to ASCII
|
||||
ascii_text = normalized.encode("ascii", "ignore").decode("ascii")
|
||||
|
||||
# Remove any remaining non-alphanumeric characters except spaces, hyphens, and periods
|
||||
cleaned = re.sub(r"[^a-zA-Z0-9\s\-\.]", "", ascii_text)
|
||||
|
||||
# Clean up multiple spaces
|
||||
cleaned = re.sub(r"\s+", " ", cleaned).strip()
|
||||
|
||||
return cleaned
|
||||
|
||||
|
||||
def clean_string(value):
|
||||
"""Clean string values, converting empty/null/nan/0 to None and removing special characters"""
|
||||
if (
|
||||
pd.isna(value)
|
||||
or value == ""
|
||||
or str(value).lower() in ["nan", "null", "none", "0", "0.0"]
|
||||
):
|
||||
return None
|
||||
|
||||
# First clean special characters
|
||||
cleaned = clean_special_characters(str(value).strip())
|
||||
|
||||
# Check if result is just "0" after cleaning
|
||||
if cleaned in ["0", "0.0", "null", "nan", "none"]:
|
||||
return None
|
||||
|
||||
return cleaned if cleaned else None
|
||||
|
||||
|
||||
def clean_name(value):
|
||||
"""Clean names (companies, investors) with special character handling"""
|
||||
if (
|
||||
pd.isna(value)
|
||||
or value == ""
|
||||
or str(value).lower() in ["nan", "null", "none", "0", "0.0"]
|
||||
):
|
||||
return None
|
||||
|
||||
# Clean special characters but be more permissive for names
|
||||
text = str(value).strip()
|
||||
# First remove ellipses and other problematic patterns
|
||||
# text = text.replace("...", "").replace("..", "")
|
||||
|
||||
# Normalize unicode characters
|
||||
normalized = unicodedata.normalize("NFKD", text)
|
||||
|
||||
# Convert to ASCII but keep more characters for business names
|
||||
ascii_text = normalized.encode("ascii", "ignore").decode("ascii")
|
||||
|
||||
# Allow alphanumeric, spaces, hyphens, periods, parentheses, and ampersands
|
||||
cleaned = re.sub(r"[^a-zA-Z0-9\s\-\.\(\)&]", "", ascii_text)
|
||||
|
||||
# Clean up multiple spaces
|
||||
cleaned = re.sub(r"\s+", " ", cleaned).strip()
|
||||
|
||||
# Remove any trailing or leading periods
|
||||
cleaned = cleaned.strip(".")
|
||||
|
||||
cleaned = cleaned.replace("..", "").replace("...", "")
|
||||
# Check if result is just "0" after cleaning
|
||||
if cleaned in ["0", "0.0", "null", "nan", "none"]:
|
||||
return None
|
||||
|
||||
return cleaned if cleaned else None
|
||||
|
||||
|
||||
def clean_integer(value):
|
||||
"""Clean integer values, converting empty/null/nan/0 to None"""
|
||||
if pd.isna(value) or str(value).lower() in ["nan", "null", "none", "", "0", "0.0"]:
|
||||
return None
|
||||
try:
|
||||
cleaned_val = int(float(value))
|
||||
return cleaned_val if cleaned_val > 0 else None
|
||||
except (ValueError, TypeError):
|
||||
return None
|
||||
|
||||
|
||||
def parse_website(website_str: str):
|
||||
try:
|
||||
_, end = website_str.split(":")
|
||||
|
||||
if end == "0":
|
||||
return None
|
||||
return "https:" + end
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def ingest_data():
|
||||
# Create database engine and session
|
||||
Session = sessionmaker(bind=engine)
|
||||
session = Session()
|
||||
|
||||
# Load CSV files
|
||||
print("Loading CSV files...")
|
||||
companies_df = pd.read_csv("companies.csv")
|
||||
investors_df = pd.read_csv("investors.csv")
|
||||
|
||||
print(f"📊 Companies CSV: {len(companies_df)} rows")
|
||||
print(f"📊 Investors CSV: {len(investors_df)} rows")
|
||||
|
||||
# Step 1: Ingest Investors
|
||||
print("\n🔄 Step 1: Ingesting Investors...")
|
||||
investors_processed = 0
|
||||
|
||||
for index, row in investors_df.iterrows():
|
||||
try:
|
||||
investor_name = clean_name(row.get("Filtered investor names", ""))
|
||||
|
||||
if investor_name:
|
||||
# Check if investor already exists
|
||||
existing_investor = (
|
||||
session.query(InvestorTable).filter_by(name=investor_name).first()
|
||||
)
|
||||
if not existing_investor:
|
||||
investor = InvestorTable(
|
||||
name=investor_name,
|
||||
description=clean_string(row.get("Business model", "")),
|
||||
headquarters=clean_string(row.get("HQ", "")),
|
||||
website=parse_website(str(row.get("Website", "")).strip()),
|
||||
number_of_investments=clean_integer(
|
||||
row.get("Number of investments")
|
||||
),
|
||||
)
|
||||
session.add(investor)
|
||||
investors_processed += 1
|
||||
|
||||
if investors_processed % 1000 == 0:
|
||||
session.commit()
|
||||
print(f" Committed {investors_processed} investors")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing investor {index}: {e}")
|
||||
continue
|
||||
|
||||
session.commit()
|
||||
print(f"✅ Investors completed: {investors_processed} processed")
|
||||
|
||||
# Step 2: Ingest Companies and Rounds
|
||||
print("\n🔄 Step 2: Ingesting Companies and Sectors...")
|
||||
companies_processed = 0
|
||||
sectors_created = set()
|
||||
|
||||
for index, row in companies_df.iterrows():
|
||||
try:
|
||||
# Process company
|
||||
company_name = clean_name(row.get("Organization Name", ""))
|
||||
if not company_name:
|
||||
continue
|
||||
|
||||
# Check if company already exists
|
||||
existing_company = (
|
||||
session.query(CompanyTable).filter_by(name=company_name).first()
|
||||
)
|
||||
if existing_company:
|
||||
company = existing_company
|
||||
else:
|
||||
# Create company
|
||||
company = CompanyTable(
|
||||
name=company_name,
|
||||
description=clean_string(row.get("Organization Description", "")),
|
||||
location=clean_string(row.get("Organization Location", "")),
|
||||
industry=clean_string(row.get("Organization Industries", "")),
|
||||
website=clean_string(row.get("Organization Website", "")),
|
||||
)
|
||||
session.add(company)
|
||||
session.flush() # Get the company ID
|
||||
companies_processed += 1
|
||||
|
||||
# Process investor relationships
|
||||
investor_names_str = row.get("Investor Names", "")
|
||||
if pd.notna(investor_names_str) and investor_names_str:
|
||||
investor_names = parse_investor_names(investor_names_str)
|
||||
|
||||
for investor_name in investor_names:
|
||||
# Find investor in database
|
||||
investor = (
|
||||
session.query(InvestorTable)
|
||||
.filter_by(name=investor_name.strip())
|
||||
.first()
|
||||
)
|
||||
|
||||
if investor:
|
||||
# Add investor-company relationship
|
||||
if company not in investor.portfolio_companies:
|
||||
investor.portfolio_companies.append(company)
|
||||
else:
|
||||
print("This company has an investor not in DB:", investor_name)
|
||||
|
||||
# Process sectors/industries
|
||||
industries_str = row.get("Organization Industries", "")
|
||||
if pd.notna(industries_str) and industries_str:
|
||||
industries = parse_industries(industries_str)
|
||||
|
||||
for industry_name in industries:
|
||||
industry_name = industry_name.strip()
|
||||
if industry_name:
|
||||
# Check if sector exists
|
||||
sector = (
|
||||
session.query(SectorTable)
|
||||
.filter_by(name=industry_name)
|
||||
.first()
|
||||
)
|
||||
if not sector:
|
||||
sector = SectorTable(name=industry_name)
|
||||
session.add(sector)
|
||||
session.flush()
|
||||
sectors_created.add(industry_name)
|
||||
|
||||
# Add company-sector relationship
|
||||
if sector not in company.sectors:
|
||||
company.sectors.append(sector)
|
||||
|
||||
# Commit every 100 companies
|
||||
if companies_processed % 100 == 0 and companies_processed > 0:
|
||||
session.commit()
|
||||
print(f" Processed {companies_processed} companies...")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing company {index}: {e}")
|
||||
session.rollback()
|
||||
continue
|
||||
|
||||
# Step 3: Link investors to sectors based on portfolio companies
|
||||
print("\n🔄 Step 3: Linking Investors to Sectors...")
|
||||
investors_linked_to_sectors = 0
|
||||
all_investors = session.query(InvestorTable).all()
|
||||
for investor in all_investors:
|
||||
sectors = set()
|
||||
for company in investor.portfolio_companies:
|
||||
for sector in company.sectors:
|
||||
sectors.add(sector)
|
||||
# Add sectors to investor if not already present
|
||||
for sector in sectors:
|
||||
if sector not in investor.sectors:
|
||||
investor.sectors.append(sector)
|
||||
if sectors:
|
||||
investors_linked_to_sectors += 1
|
||||
session.commit()
|
||||
print(f"✅ Linked {investors_linked_to_sectors} investors to sectors")
|
||||
|
||||
# Final commit
|
||||
session.commit()
|
||||
|
||||
# Final counts
|
||||
final_investors = session.query(InvestorTable).count()
|
||||
final_companies = session.query(CompanyTable).count()
|
||||
final_sectors = session.query(SectorTable).count()
|
||||
|
||||
print("\n🎉 Ingestion Complete!")
|
||||
print(f" Investors: {final_investors}")
|
||||
print(f" Companies: {final_companies}")
|
||||
print(f" Sectors: {final_sectors}")
|
||||
|
||||
session.close()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
ingest_data()
|
||||
# print(clean_name("A... Energi"))
|
||||
# print(clean_name("B.. Tech"))
|
||||
# print(clean_name("A... Energi"))
|
||||
@@ -0,0 +1,381 @@
|
||||
import enum
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import Depends
|
||||
from sqlalchemy import (
|
||||
Column,
|
||||
DateTime,
|
||||
ForeignKey,
|
||||
Integer,
|
||||
String,
|
||||
Table,
|
||||
Text,
|
||||
create_engine,
|
||||
func,
|
||||
)
|
||||
from sqlalchemy.ext.declarative import declarative_base
|
||||
from sqlalchemy.orm import Session, declarative_mixin, relationship, sessionmaker
|
||||
from sqlalchemy.types import JSON, Enum
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
# Database configuration
|
||||
# DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///./investors.db")
|
||||
|
||||
# Create engine
|
||||
engine = create_engine("sqlite:///./investors.db", echo=False)
|
||||
|
||||
# Create session factory
|
||||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
|
||||
|
||||
|
||||
def get_db():
|
||||
db = SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
|
||||
|
||||
db_dependency = Annotated[Session, Depends(get_db)]
|
||||
|
||||
|
||||
def init_database():
|
||||
"""Initialize the database by creating all tables"""
|
||||
Base.metadata.create_all(bind=engine)
|
||||
|
||||
|
||||
def get_session_sync() -> Session:
|
||||
"""Get a database session for synchronous operations"""
|
||||
return SessionLocal()
|
||||
|
||||
|
||||
def get_db_session():
|
||||
"""Get a database session for direct use."""
|
||||
return SessionLocal()
|
||||
|
||||
|
||||
@declarative_mixin
|
||||
class TimestampMixin:
|
||||
created_at = Column(
|
||||
DateTime(timezone=True), server_default=func.now(), nullable=False
|
||||
)
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now())
|
||||
|
||||
|
||||
class InvestmentStage(enum.Enum):
|
||||
SEED = "SEED"
|
||||
SERIES_A = "SERIES_A"
|
||||
SERIES_B = "SERIES_B"
|
||||
SERIES_C = "SERIES_C"
|
||||
GROWTH = "GROWTH"
|
||||
LATE_STAGE = "LATE_STAGE"
|
||||
|
||||
|
||||
# Association table for many-to-many relationship between investors and companies
|
||||
investor_company_association = Table(
|
||||
"investor_companies",
|
||||
Base.metadata,
|
||||
Column("investor_id", Integer, ForeignKey("investors.id")),
|
||||
Column("company_id", Integer, ForeignKey("companies.id")),
|
||||
)
|
||||
|
||||
|
||||
# Association table for investor-sector many-to-many
|
||||
investor_sector_association = Table(
|
||||
"investor_sectors",
|
||||
Base.metadata,
|
||||
Column("investor_id", Integer, ForeignKey("investors.id")),
|
||||
Column("sector_id", Integer, ForeignKey("sectors.id")),
|
||||
)
|
||||
|
||||
|
||||
company_sector_association = Table(
|
||||
"company_sector",
|
||||
Base.metadata,
|
||||
Column("company_id", Integer, ForeignKey("companies.id")),
|
||||
Column("sector_id", Integer, ForeignKey("sectors.id")),
|
||||
)
|
||||
|
||||
project_sector_association = Table(
|
||||
"project_sector",
|
||||
Base.metadata,
|
||||
Column("project_id", Integer, ForeignKey("projects.id")),
|
||||
Column("sector_id", Integer, ForeignKey("sectors.id")),
|
||||
)
|
||||
|
||||
project_investor_association = Table(
|
||||
"project_investors",
|
||||
Base.metadata,
|
||||
Column("project_id", Integer, ForeignKey("projects.id")),
|
||||
Column("investor_id", Integer, ForeignKey("investors.id")),
|
||||
)
|
||||
|
||||
project_company_association = Table(
|
||||
"project_companies",
|
||||
Base.metadata,
|
||||
Column("project_id", Integer, ForeignKey("projects.id")),
|
||||
Column("company_id", Integer, ForeignKey("companies.id")),
|
||||
)
|
||||
|
||||
# Association table for investor-stage many-to-many
|
||||
investor_stage_association = Table(
|
||||
"investor_stages",
|
||||
Base.metadata,
|
||||
Column("investor_id", Integer, ForeignKey("investors.id")),
|
||||
Column("stage_id", Integer, ForeignKey("investment_stages.id")),
|
||||
)
|
||||
|
||||
# Association table for fund-stage many-to-many
|
||||
fund_investment_stages_association = Table(
|
||||
"fund_investment_stages",
|
||||
Base.metadata,
|
||||
Column("fund_id", Integer, ForeignKey("funds.id")),
|
||||
Column("stage_id", Integer, ForeignKey("investment_stages.id")),
|
||||
)
|
||||
|
||||
# Association table for fund-sector many-to-many
|
||||
fund_sectors_association = Table(
|
||||
"fund_sectors",
|
||||
Base.metadata,
|
||||
Column("fund_id", Integer, ForeignKey("funds.id")),
|
||||
Column("sector_id", Integer, ForeignKey("sectors.id")),
|
||||
)
|
||||
|
||||
|
||||
class InvestorTable(Base, TimestampMixin):
|
||||
__tablename__ = "investors"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, nullable=False)
|
||||
description = Column(Text, nullable=True)
|
||||
|
||||
# Basic investor info
|
||||
website = Column(String, nullable=True)
|
||||
headquarters = Column(String, nullable=True)
|
||||
|
||||
# AUM fields
|
||||
aum = Column(Integer, nullable=True) # Store as integer for numerical filtering
|
||||
aum_as_of_date = Column(String, nullable=True)
|
||||
aum_source_url = Column(String, nullable=True)
|
||||
|
||||
# Check size (deprecated in favor of fund-level data, but keeping for backward compatibility)
|
||||
check_size_lower = Column(Integer, nullable=True)
|
||||
check_size_upper = Column(Integer, nullable=True)
|
||||
|
||||
# Geographic focus (deprecated in favor of fund-level, but keeping for backward compatibility)
|
||||
geographic_focus = Column(String, nullable=True)
|
||||
|
||||
# Investment thesis and portfolio
|
||||
investment_thesis = Column(JSON, nullable=True) # Array of thesis statements
|
||||
portfolio_highlights = Column(
|
||||
JSON, nullable=True
|
||||
) # Array of portfolio company names
|
||||
linked_documents = Column(JSON, nullable=True) # Array of document URLs
|
||||
|
||||
# Research metadata
|
||||
researcher_notes = Column(Text, nullable=True)
|
||||
missing_important_fields = Column(
|
||||
JSON, nullable=True
|
||||
) # Array of missing field names
|
||||
sources = Column(JSON, nullable=True) # JSON object with source URLs
|
||||
|
||||
# Portfolio info
|
||||
number_of_investments = Column(Integer, nullable=True)
|
||||
|
||||
# Relationships
|
||||
team_members = relationship(
|
||||
"InvestorMember", back_populates="investor", cascade="all, delete-orphan"
|
||||
)
|
||||
funds = relationship(
|
||||
"FundTable", back_populates="investor", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
# Many-to-many relationship with investment stages
|
||||
investment_stages = relationship(
|
||||
"InvestmentStageTable",
|
||||
secondary=investor_stage_association,
|
||||
back_populates="investors",
|
||||
)
|
||||
|
||||
# Relationship to portfolio companies
|
||||
portfolio_companies = relationship(
|
||||
"CompanyTable",
|
||||
secondary=investor_company_association,
|
||||
back_populates="investors",
|
||||
)
|
||||
|
||||
sectors = relationship(
|
||||
"SectorTable",
|
||||
secondary=investor_sector_association,
|
||||
back_populates="investors",
|
||||
)
|
||||
|
||||
projects = relationship(
|
||||
"ProjectTable",
|
||||
secondary=project_investor_association,
|
||||
back_populates="investors",
|
||||
)
|
||||
|
||||
|
||||
class InvestorMember(Base, TimestampMixin):
|
||||
__tablename__ = "investor_members"
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, nullable=False)
|
||||
role = Column(String, nullable=True)
|
||||
title = Column(String, nullable=True) # Alternative to role
|
||||
email = Column(String, nullable=True)
|
||||
source_url = Column(String, nullable=True) # URL where member info was found
|
||||
|
||||
investor_id = Column(Integer, ForeignKey("investors.id"))
|
||||
investor = relationship("InvestorTable", back_populates="team_members")
|
||||
|
||||
|
||||
class FundTable(Base, TimestampMixin):
|
||||
__tablename__ = "funds"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
investor_id = Column(Integer, ForeignKey("investors.id"), nullable=False)
|
||||
|
||||
# Fund details
|
||||
fund_name = Column(String, nullable=True)
|
||||
fund_size = Column(
|
||||
Integer, nullable=True
|
||||
) # Store as integer for numerical filtering
|
||||
fund_size_source_url = Column(String, nullable=True)
|
||||
|
||||
# Check size range (parsed from estimated_investment_size by LLM)
|
||||
check_size_lower = Column(Integer, nullable=True)
|
||||
check_size_upper = Column(Integer, nullable=True)
|
||||
|
||||
source_url = Column(String, nullable=True)
|
||||
source_provider = Column(String, nullable=True) # e.g., "Perplexity"
|
||||
|
||||
# Geographic focus as simple string
|
||||
geographic_focus = Column(String, nullable=True)
|
||||
|
||||
# Relationships
|
||||
investor = relationship("InvestorTable", back_populates="funds")
|
||||
investment_stages = relationship(
|
||||
"InvestmentStageTable",
|
||||
secondary=fund_investment_stages_association,
|
||||
back_populates="funds",
|
||||
)
|
||||
sectors = relationship(
|
||||
"SectorTable",
|
||||
secondary=fund_sectors_association,
|
||||
back_populates="funds",
|
||||
)
|
||||
|
||||
|
||||
class InvestmentStageTable(Base, TimestampMixin):
|
||||
__tablename__ = "investment_stages"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, nullable=False, unique=True)
|
||||
|
||||
# Relationships
|
||||
investors = relationship(
|
||||
"InvestorTable",
|
||||
secondary=investor_stage_association,
|
||||
back_populates="investment_stages",
|
||||
)
|
||||
funds = relationship(
|
||||
"FundTable",
|
||||
secondary=fund_investment_stages_association,
|
||||
back_populates="investment_stages",
|
||||
)
|
||||
|
||||
|
||||
class CompanyTable(Base, TimestampMixin):
|
||||
__tablename__ = "companies"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, nullable=False)
|
||||
industry = Column(String, nullable=True)
|
||||
location = Column(String, nullable=True)
|
||||
description = Column(String, nullable=True)
|
||||
founded_year = Column(Integer, nullable=True)
|
||||
website = Column(String, nullable=True)
|
||||
|
||||
members = relationship(
|
||||
"CompanyMember", back_populates="company", cascade="all, delete-orphan"
|
||||
)
|
||||
# Relationship back to investors
|
||||
investors = relationship(
|
||||
"InvestorTable",
|
||||
secondary=investor_company_association,
|
||||
back_populates="portfolio_companies",
|
||||
)
|
||||
|
||||
sectors = relationship(
|
||||
"SectorTable", secondary=company_sector_association, back_populates="companies"
|
||||
)
|
||||
|
||||
projects = relationship(
|
||||
"ProjectTable",
|
||||
secondary=project_company_association,
|
||||
back_populates="companies",
|
||||
)
|
||||
|
||||
|
||||
class CompanyMember(Base, TimestampMixin):
|
||||
__tablename__ = "company_members"
|
||||
id = Column(Integer, primary_key=True)
|
||||
name = Column(String)
|
||||
linkedin = Column(String, nullable=True)
|
||||
role = Column(String, nullable=True)
|
||||
company_id = Column(Integer, ForeignKey("companies.id"), nullable=False)
|
||||
|
||||
company = relationship("CompanyTable", back_populates="members")
|
||||
|
||||
|
||||
class SectorTable(Base, TimestampMixin):
|
||||
__tablename__ = "sectors"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, nullable=False)
|
||||
|
||||
# Relationships
|
||||
investors = relationship(
|
||||
"InvestorTable",
|
||||
secondary=investor_sector_association,
|
||||
back_populates="sectors",
|
||||
)
|
||||
companies = relationship(
|
||||
"CompanyTable", secondary=company_sector_association, back_populates="sectors"
|
||||
)
|
||||
projects = relationship(
|
||||
"ProjectTable", secondary=project_sector_association, back_populates="sector"
|
||||
)
|
||||
funds = relationship(
|
||||
"FundTable",
|
||||
secondary=fund_sectors_association,
|
||||
back_populates="sectors",
|
||||
)
|
||||
|
||||
|
||||
class ProjectTable(Base, TimestampMixin):
|
||||
__tablename__ = "projects"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
name = Column(String, nullable=False)
|
||||
valuation = Column(Integer, nullable=True)
|
||||
|
||||
stage = Column(Enum(InvestmentStage), nullable=True)
|
||||
location = Column(String, nullable=True)
|
||||
description = Column(Text, nullable=True)
|
||||
start_date = Column(DateTime, nullable=True)
|
||||
end_date = Column(DateTime, nullable=True)
|
||||
|
||||
sector = relationship(
|
||||
"SectorTable", secondary=project_sector_association, back_populates="projects"
|
||||
)
|
||||
investors = relationship(
|
||||
"InvestorTable",
|
||||
secondary=project_investor_association,
|
||||
back_populates="projects",
|
||||
)
|
||||
companies = relationship(
|
||||
"CompanyTable", secondary=project_company_association, back_populates="projects"
|
||||
)
|
||||
Binary file not shown.
Reference in New Issue
Block a user