7 Commits

25 changed files with 1092 additions and 152 deletions
+4 -2
View File
@@ -8,8 +8,10 @@
/chroma_db /chroma_db
/*__pycache__*/ *__pycache__
/*.db /*.db
/*.cypython-* *.cypython
/preprocessor
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
-1
View File
@@ -32,7 +32,6 @@ db_dependency = Annotated[Session, Depends(get_db)]
def init_database(): def init_database():
"""Initialize the database by creating all tables""" """Initialize the database by creating all tables"""
Base.metadata.create_all(bind=engine) Base.metadata.create_all(bind=engine)
print("Database initialized successfully!")
def get_session_sync() -> Session: def get_session_sync() -> Session:
+80 -14
View File
@@ -1,10 +1,11 @@
import enum import enum
from db.db import Base
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, Table, Text, func from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, Table, Text, func
from sqlalchemy.orm import declarative_mixin, relationship from sqlalchemy.orm import declarative_mixin, relationship
from sqlalchemy.types import Enum from sqlalchemy.types import Enum
from db.db import Base
@declarative_mixin @declarative_mixin
class TimestampMixin: class TimestampMixin:
@@ -48,6 +49,27 @@ company_sector_association = Table(
Column("sector_id", Integer, ForeignKey("sectors.id")), Column("sector_id", Integer, ForeignKey("sectors.id")),
) )
# Many-to-many link table between projects and sectors. Used as the
# `secondary` of ProjectTable.sector and SectorTable.projects.
# Neither column is declared as a primary key and no unique constraint is
# defined, so the table itself does not prevent duplicate (project, sector) rows.
project_sector_association = Table(
    "project_sector",
    Base.metadata,
    Column("project_id", Integer, ForeignKey("projects.id")),
    Column("sector_id", Integer, ForeignKey("sectors.id")),
)

# Many-to-many link table between projects and investors. Used as the
# `secondary` of ProjectTable.investors and InvestorTable.projects.
project_investor_association = Table(
    "project_investors",
    Base.metadata,
    Column("project_id", Integer, ForeignKey("projects.id")),
    Column("investor_id", Integer, ForeignKey("investors.id")),
)

# Many-to-many link table between projects and companies. Used as the
# `secondary` of ProjectTable.companies and CompanyTable.projects.
project_company_association = Table(
    "project_companies",
    Base.metadata,
    Column("project_id", Integer, ForeignKey("projects.id")),
    Column("company_id", Integer, ForeignKey("companies.id")),
)
class InvestorTable(Base, TimestampMixin): class InvestorTable(Base, TimestampMixin):
__tablename__ = "investors" __tablename__ = "investors"
@@ -55,12 +77,14 @@ class InvestorTable(Base, TimestampMixin):
id = Column(Integer, primary_key=True, index=True) id = Column(Integer, primary_key=True, index=True)
name = Column(String, nullable=False) name = Column(String, nullable=False)
description = Column(Text, nullable=True) description = Column(Text, nullable=True)
aum = Column(Integer, nullable=False) # Assets Under Management aum = Column(Integer, nullable=True) # Assets Under Management
check_size_lower = Column(Integer, nullable=False) # Lower bound check_size_lower = Column(Integer, nullable=True) # Lower bound
check_size_upper = Column(Integer, nullable=False) # Upper bound check_size_upper = Column(Integer, nullable=True) # Upper bound
geographic_focus = Column(String, nullable=False) geographic_focus = Column(String, nullable=True)
stage_focus = Column(Enum(InvestmentStage), nullable=False) stage_focus = Column(Enum(InvestmentStage), nullable=True)
number_of_investments = Column(Integer, default=0) number_of_investments = Column(Integer, default=0, nullable=True)
team_members = relationship("InvestorMember", back_populates="investor")
# Relationship to portfolio companies # Relationship to portfolio companies
portfolio_companies = relationship( portfolio_companies = relationship(
@@ -68,20 +92,26 @@ class InvestorTable(Base, TimestampMixin):
secondary=investor_company_association, secondary=investor_company_association,
back_populates="investors", back_populates="investors",
) )
team_members = relationship("InvestorMember", back_populates="investor")
sectors = relationship( sectors = relationship(
"SectorTable", "SectorTable",
secondary=investor_sector_association, secondary=investor_sector_association,
back_populates="investors", back_populates="investors",
) )
projects = relationship(
"ProjectTable",
secondary=project_investor_association,
back_populates="investors",
)
class InvestorMember(Base, TimestampMixin): class InvestorMember(Base, TimestampMixin):
__tablename__ = "investor_members" __tablename__ = "investor_members"
id = Column(Integer, primary_key=True, index=True) id = Column(Integer, primary_key=True, index=True)
name = Column(String, nullable=False) name = Column(String, nullable=False)
role = Column(String, nullable=False) role = Column(String, nullable=True)
email = Column(String, nullable=False) email = Column(String, nullable=True)
investor_id = Column(Integer, ForeignKey("investors.id")) investor_id = Column(Integer, ForeignKey("investors.id"))
investor = relationship("InvestorTable", back_populates="team_members") investor = relationship("InvestorTable", back_populates="team_members")
@@ -92,8 +122,8 @@ class CompanyTable(Base, TimestampMixin):
id = Column(Integer, primary_key=True, index=True) id = Column(Integer, primary_key=True, index=True)
name = Column(String, nullable=False) name = Column(String, nullable=False)
industry = Column(String, nullable=False) industry = Column(String, nullable=True)
location = Column(String, nullable=False) location = Column(String, nullable=True)
description = Column(String, nullable=True) description = Column(String, nullable=True)
founded_year = Column(Integer, nullable=True) founded_year = Column(Integer, nullable=True)
website = Column(String, nullable=True) website = Column(String, nullable=True)
@@ -110,13 +140,19 @@ class CompanyTable(Base, TimestampMixin):
"SectorTable", secondary=company_sector_association, back_populates="companies" "SectorTable", secondary=company_sector_association, back_populates="companies"
) )
projects = relationship(
"ProjectTable",
secondary=project_company_association,
back_populates="companies",
)
class CompanyMember(Base, TimestampMixin): class CompanyMember(Base, TimestampMixin):
__tablename__ = "company_members" __tablename__ = "company_members"
id = Column(Integer, primary_key=True) id = Column(Integer, primary_key=True)
name = Column(String) name = Column(String)
linkedin = Column(String) linkedin = Column(String, nullable=True)
role = Column(String) role = Column(String, nullable=True)
company_id = Column(Integer, ForeignKey("companies.id"), nullable=False) company_id = Column(Integer, ForeignKey("companies.id"), nullable=False)
company = relationship("CompanyTable", back_populates="members") company = relationship("CompanyTable", back_populates="members")
@@ -138,3 +174,33 @@ class SectorTable(Base, TimestampMixin):
companies = relationship( companies = relationship(
"CompanyTable", secondary=company_sector_association, back_populates="sectors" "CompanyTable", secondary=company_sector_association, back_populates="sectors"
) )
projects = relationship(
"ProjectTable", secondary=project_sector_association, back_populates="sector"
)
class ProjectTable(Base, TimestampMixin):
    """ORM mapping for the ``projects`` table.

    Only ``name`` is mandatory; every other scalar column is nullable.
    Sectors, investors and companies are attached through the
    ``project_*`` association tables.
    """

    __tablename__ = "projects"

    # --- scalar columns -------------------------------------------------
    id = Column(Integer, primary_key=True, index=True)
    name = Column(String, nullable=False)
    valuation = Column(Integer, nullable=True)
    stage = Column(Enum(InvestmentStage), nullable=True)
    location = Column(String, nullable=True)
    description = Column(Text, nullable=True)
    start_date = Column(DateTime, nullable=True)
    end_date = Column(DateTime, nullable=True)

    # --- many-to-many relationships --------------------------------------
    # NOTE: despite the singular name, ``sector`` is a collection; it is the
    # counterpart of ``SectorTable.projects``.
    sector = relationship(
        "SectorTable",
        secondary=project_sector_association,
        back_populates="projects",
    )
    investors = relationship(
        "InvestorTable",
        secondary=project_investor_association,
        back_populates="projects",
    )
    companies = relationship(
        "CompanyTable",
        secondary=project_company_association,
        back_populates="projects",
    )
+15 -4
View File
@@ -1,16 +1,23 @@
import io import io
import pandas as pd import pandas as pd
from db.db import db_dependency, init_database from db.db import Base, db_dependency, engine
from dotenv import load_dotenv from dotenv import load_dotenv
from fastapi import FastAPI, File, Form, UploadFile from fastapi import FastAPI, File, Form, UploadFile
from pydantic import BaseModel from pydantic import BaseModel
from routers import companies, investors from routers import companies, investors, projects
from schemas.router_schemas import InvestorList from schemas.router_schemas import InvestorList
from services.llm_parser import InvestorProcessor from services.llm_parser import InvestorProcessor
from services.querying import QueryProcessor from services.querying import QueryProcessor
load_dotenv() load_dotenv()
def init_database() -> None:
    """Create the tables registered on ``Base.metadata`` using ``engine``."""
    metadata = Base.metadata
    metadata.create_all(bind=engine)
init_database() init_database()
app = FastAPI() app = FastAPI()
@@ -34,7 +41,9 @@ def health():
@app.post("/parse-csv", tags=["CSV Upload"], response_model=list[dict]) @app.post("/parse-csv", tags=["CSV Upload"], response_model=list[dict])
async def parse_csv(db: db_dependency, file: UploadFile = File(...), is_investor: int = Form(...)): async def parse_csv(
db: db_dependency, file: UploadFile = File(...), is_investor: int = Form(...)
):
# Read uploaded CSV with pandas # Read uploaded CSV with pandas
content = await file.read() content = await file.read()
df = pd.read_csv(io.StringIO(content.decode("utf-8"))) df = pd.read_csv(io.StringIO(content.decode("utf-8")))
@@ -69,7 +78,9 @@ async def query_investors(request: QueryRequest):
app.include_router(investors.router) app.include_router(investors.router)
app.include_router(companies.router) app.include_router(companies.router)
app.include_router(projects.router)
if __name__ == "__main__": if __name__ == "__main__":
import uvicorn import uvicorn
uvicorn.run(app="main:app", host="localhost", port=8000, reload=True) uvicorn.run(app="main:app", host="0.0.0.0", port=8585, reload=True)
Binary file not shown.
Binary file not shown.
Binary file not shown.
+4
View File
@@ -34,6 +34,10 @@ def read_companies(db: Session = Depends(get_db)):
"""Get all companies with their investor relationships""" """Get all companies with their investor relationships"""
companies = ( companies = (
db.query(CompanyTable) db.query(CompanyTable)
.filter(
CompanyTable.name.isnot(None),
CompanyTable.description.isnot(None)
)
.options( .options(
selectinload(CompanyTable.investors), selectinload(CompanyTable.investors),
selectinload(CompanyTable.members), selectinload(CompanyTable.members),
+87 -40
View File
@@ -1,11 +1,11 @@
from typing import List, Optional from typing import List, Optional
from db.db import get_db from db.db import get_db
from db.models import InvestorTable, SectorTable from db.models import InvestorTable, SectorTable
from fastapi import APIRouter, Depends, HTTPException, Query from fastapi import APIRouter, Depends, HTTPException, Query
from schemas.router_schemas import InvestmentStage, InvestorData
from pydantic import BaseModel from pydantic import BaseModel
from schemas.router_schemas import InvestmentStage, InvestorData
from services.querying import QueryProcessor
from sqlalchemy.orm import Session, selectinload from sqlalchemy.orm import Session, selectinload
router = APIRouter(tags=["Investor Routes"]) router = APIRouter(tags=["Investor Routes"])
@@ -181,26 +181,16 @@ def create_investor(investor: InvestorCreate, db: Session = Depends(get_db)):
) )
@router.put("/investors/{investor_id}", response_model=InvestorData) @router.get("/investors/{investor_id}/similar", response_model=List[InvestorData])
def update_investor( def find_similar_investors(
investor_id: int, investor: InvestorUpdate, db: Session = Depends(get_db) investor_id: int,
limit: int = Query(10, description="Maximum number of similar investors to return"),
db: Session = Depends(get_db)
): ):
"""Update an existing investor""" """Find investors similar to a given investor based on characteristics"""
db_investor = (
db.query(InvestorTable).filter(InvestorTable.id == investor_id).first() # Get the target investor
) target_investor = (
if not db_investor:
raise HTTPException(status_code=404, detail="Investor not found")
update_data = investor.dict(exclude_unset=True)
for field, value in update_data.items():
setattr(db_investor, field, value)
db.commit()
db.refresh(db_investor)
# Reload with relationships
investor_with_relations = (
db.query(InvestorTable) db.query(InvestorTable)
.options( .options(
selectinload(InvestorTable.portfolio_companies), selectinload(InvestorTable.portfolio_companies),
@@ -211,24 +201,81 @@ def update_investor(
.first() .first()
) )
# Transform to InvestorData format if not target_investor:
return InvestorData(
investor=investor_with_relations,
portfolio_companies=investor_with_relations.portfolio_companies,
team_members=investor_with_relations.team_members,
sectors=investor_with_relations.sectors,
)
@router.delete("/investors/{investor_id}")
def delete_investor(investor_id: int, db: Session = Depends(get_db)):
"""Delete an investor"""
db_investor = (
db.query(InvestorTable).filter(InvestorTable.id == investor_id).first()
)
if not db_investor:
raise HTTPException(status_code=404, detail="Investor not found") raise HTTPException(status_code=404, detail="Investor not found")
db.delete(db_investor) # Get target investor's sector IDs for comparison
db.commit() target_sector_ids = {sector.id for sector in target_investor.sectors}
return {"message": "Investor deleted successfully"}
# Query all other investors with their relationships
candidates = (
db.query(InvestorTable)
.options(
selectinload(InvestorTable.portfolio_companies),
selectinload(InvestorTable.team_members),
selectinload(InvestorTable.sectors),
)
.filter(InvestorTable.id != investor_id)
.all()
)
# Calculate similarity scores
scored_investors = []
for candidate in candidates:
score = 0
# Stage focus match (30 points)
if candidate.stage_focus == target_investor.stage_focus:
score += 30
# Geographic focus match (20 points for exact, 10 for partial)
if candidate.geographic_focus and target_investor.geographic_focus:
if candidate.geographic_focus.lower() == target_investor.geographic_focus.lower():
score += 20
elif (candidate.geographic_focus.lower() in target_investor.geographic_focus.lower() or
target_investor.geographic_focus.lower() in candidate.geographic_focus.lower()):
score += 10
# Check size overlap (20 points max)
if (candidate.check_size_lower and candidate.check_size_upper and
target_investor.check_size_lower and target_investor.check_size_upper):
# Calculate overlap percentage
overlap_start = max(candidate.check_size_lower, target_investor.check_size_lower)
overlap_end = min(candidate.check_size_upper, target_investor.check_size_upper)
if overlap_end > overlap_start:
overlap = overlap_end - overlap_start
target_range = target_investor.check_size_upper - target_investor.check_size_lower
overlap_ratio = overlap / target_range if target_range > 0 else 0
score += int(20 * overlap_ratio)
# AUM similarity (15 points max)
if candidate.aum and target_investor.aum:
aum_diff = abs(candidate.aum - target_investor.aum)
max_aum = max(candidate.aum, target_investor.aum)
similarity_ratio = 1 - (aum_diff / max_aum) if max_aum > 0 else 0
score += int(15 * similarity_ratio)
# Sector overlap (30 points max)
candidate_sector_ids = {sector.id for sector in candidate.sectors}
if target_sector_ids and candidate_sector_ids:
common_sectors = target_sector_ids.intersection(candidate_sector_ids)
overlap_ratio = len(common_sectors) / len(target_sector_ids)
score += int(30 * overlap_ratio)
if score > 0: # Only include investors with some similarity
scored_investors.append((score, candidate))
# Sort by score (descending) and take top N
scored_investors.sort(key=lambda x: x[0], reverse=True)
similar_investors = [inv for score, inv in scored_investors[:limit]]
# Transform to InvestorData format
return [
InvestorData(
investor=inv,
portfolio_companies=inv.portfolio_companies,
team_members=inv.team_members,
sectors=inv.sectors,
)
for inv in similar_investors
]
+447
View File
@@ -0,0 +1,447 @@
from typing import List, Optional
from db.db import get_db
from db.models import (
CompanyTable,
InvestorTable,
ProjectTable,
SectorTable,
)
from fastapi import APIRouter, Depends, HTTPException, Query
from schemas.project_schemas import (
InvestmentStage,
ProjectCreate,
ProjectData,
ProjectUpdate,
)
from sqlalchemy.orm import Session, selectinload
# Router holding every project endpoint defined in this module; all of its
# routes are tagged "Project Routes".
router = APIRouter(tags=["Project Routes"])
@router.get("/projects", response_model=List[ProjectData])
def read_projects(db: Session = Depends(get_db)):
    """Return every project with its sectors, investors and companies."""
    eager_loads = (
        selectinload(ProjectTable.sector),
        selectinload(ProjectTable.investors),
        selectinload(ProjectTable.companies),
    )
    rows = db.query(ProjectTable).options(*eager_loads).all()

    # Wrap each ORM row in the response schema.
    return [
        ProjectData(
            project=row,
            sector=row.sector,
            investors=row.investors,
            companies=row.companies,
        )
        for row in rows
    ]
# The ``:int`` convertor restricts this route to numeric ids. Without it the
# pattern also matches GET /projects/filter, which is registered later in this
# module, so that request would land here and fail integer validation instead
# of reaching the filter endpoint.
@router.get("/projects/{project_id:int}", response_model=ProjectData)
def read_project(project_id: int, db: Session = Depends(get_db)):
    """Return one project, with its related data, by primary key.

    Args:
        project_id: Primary key of the project to load.
        db: Database session provided by the ``get_db`` dependency.

    Returns:
        ProjectData built from the project and its sectors, investors and
        companies.

    Raises:
        HTTPException: 404 when no project has the given id.
    """
    project = (
        db.query(ProjectTable)
        .options(
            selectinload(ProjectTable.sector),
            selectinload(ProjectTable.investors),
            selectinload(ProjectTable.companies),
        )
        .filter(ProjectTable.id == project_id)
        .first()
    )

    if not project:
        raise HTTPException(status_code=404, detail="Project not found")

    return ProjectData(
        project=project,
        sector=project.sector,
        investors=project.investors,
        companies=project.companies,
    )
@router.post("/projects", response_model=ProjectData)
def create_project(project: ProjectCreate, db: Session = Depends(get_db)):
    """Insert a new project and return it with its (initially empty) relations.

    Args:
        project: Validated payload holding the project's scalar fields.
        db: Database session provided by the ``get_db`` dependency.

    Returns:
        ProjectData for the freshly created row.
    """
    # model_dump() is the Pydantic v2 replacement for the deprecated .dict().
    db_project = ProjectTable(**project.model_dump())
    db.add(db_project)
    db.commit()
    db.refresh(db_project)

    # Reload with the relationships eagerly loaded for the response.
    db_project = (
        db.query(ProjectTable)
        .options(
            selectinload(ProjectTable.sector),
            selectinload(ProjectTable.investors),
            selectinload(ProjectTable.companies),
        )
        .filter(ProjectTable.id == db_project.id)
        .first()
    )

    return ProjectData(
        project=db_project,
        sector=db_project.sector,
        investors=db_project.investors,
        companies=db_project.companies,
    )
@router.put("/projects/{project_id}", response_model=ProjectData)
def update_project(
    project_id: int, project: ProjectUpdate, db: Session = Depends(get_db)
):
    """Apply a partial update to an existing project.

    Only fields present in the request body are written; omitted fields keep
    their stored value.

    Args:
        project_id: Primary key of the project to change.
        project: Payload whose explicitly set fields are copied onto the row.
        db: Database session provided by the ``get_db`` dependency.

    Returns:
        ProjectData for the updated row.

    Raises:
        HTTPException: 404 when the project does not exist; 422 when the
            payload explicitly sets ``name`` to null.
    """
    db_project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
    if not db_project:
        raise HTTPException(status_code=404, detail="Project not found")

    # model_dump() is the Pydantic v2 replacement for the deprecated .dict().
    update_data = project.model_dump(exclude_unset=True)

    # ``projects.name`` is NOT NULL; reject an explicit null here instead of
    # letting the commit fail with a database error.
    if "name" in update_data and update_data["name"] is None:
        raise HTTPException(status_code=422, detail="Project name cannot be null")

    for key, value in update_data.items():
        setattr(db_project, key, value)

    db.commit()
    db.refresh(db_project)

    # Reload with the relationships eagerly loaded for the response.
    db_project = (
        db.query(ProjectTable)
        .options(
            selectinload(ProjectTable.sector),
            selectinload(ProjectTable.investors),
            selectinload(ProjectTable.companies),
        )
        .filter(ProjectTable.id == project_id)
        .first()
    )

    return ProjectData(
        project=db_project,
        sector=db_project.sector,
        investors=db_project.investors,
        companies=db_project.companies,
    )
@router.delete("/projects/{project_id}")
def delete_project(project_id: int, db: Session = Depends(get_db)):
    """Remove the project with the given id.

    Raises:
        HTTPException: 404 when the project does not exist.
    """
    lookup = db.query(ProjectTable).filter(ProjectTable.id == project_id)
    doomed = lookup.first()
    if not doomed:
        raise HTTPException(status_code=404, detail="Project not found")

    db.delete(doomed)
    db.commit()
    return {"message": "Project deleted successfully"}
# NOTE(review): this static path is registered after the dynamic GET route
# "/projects/{project_id}" declared earlier in this module. Unless that route
# restricts its parameter, it captures "/projects/filter" first - verify.
@router.get("/projects/filter", response_model=List[ProjectData])
def filter_projects(
    stage: Optional[InvestmentStage] = Query(
        None, description="Filter by project stage"
    ),
    min_valuation: Optional[int] = Query(None, description="Minimum valuation"),
    max_valuation: Optional[int] = Query(None, description="Maximum valuation"),
    location: Optional[str] = Query(None, description="Location (partial match)"),
    sector: Optional[str] = Query(None, description="Sector name (partial match)"),
    investor_name: Optional[str] = Query(
        None, description="Investor name (partial match)"
    ),
    company_name: Optional[str] = Query(
        None, description="Company name (partial match)"
    ),
    db: Session = Depends(get_db),
):
    """Return the projects matching every supplied criterion.

    Criteria left unset (or given as an empty string) are ignored. Text
    criteria are case-insensitive substring matches; the valuation bounds
    are inclusive.
    """
    selection = db.query(ProjectTable).options(
        selectinload(ProjectTable.sector),
        selectinload(ProjectTable.investors),
        selectinload(ProjectTable.companies),
    )

    # Criteria on the project's own columns.
    if stage:
        selection = selection.filter(ProjectTable.stage == stage)
    if min_valuation is not None:
        selection = selection.filter(ProjectTable.valuation >= min_valuation)
    if max_valuation is not None:
        selection = selection.filter(ProjectTable.valuation <= max_valuation)
    if location:
        location_pattern = f"%{location}%"
        selection = selection.filter(ProjectTable.location.ilike(location_pattern))

    # Criteria on related tables; each one needs a join through its relationship.
    if sector:
        sector_pattern = f"%{sector}%"
        selection = selection.join(ProjectTable.sector)
        selection = selection.filter(SectorTable.name.ilike(sector_pattern))
    if investor_name:
        investor_pattern = f"%{investor_name}%"
        selection = selection.join(ProjectTable.investors)
        selection = selection.filter(InvestorTable.name.ilike(investor_pattern))
    if company_name:
        company_pattern = f"%{company_name}%"
        selection = selection.join(ProjectTable.companies)
        selection = selection.filter(CompanyTable.name.ilike(company_pattern))

    # Wrap each matching ORM row in the response schema.
    return [
        ProjectData(
            project=match,
            sector=match.sector,
            investors=match.investors,
            companies=match.companies,
        )
        for match in selection.all()
    ]
# Association management routes
@router.post("/projects/{project_id}/investors/{investor_id}")
def add_investor_to_project(
    project_id: int, investor_id: int, db: Session = Depends(get_db)
):
    """Link an existing investor to an existing project.

    Raises:
        HTTPException: 404 when the project or the investor is missing;
            400 when the two are already linked.
    """
    target = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
    if not target:
        raise HTTPException(status_code=404, detail="Project not found")

    backer = db.query(InvestorTable).filter(InvestorTable.id == investor_id).first()
    if not backer:
        raise HTTPException(status_code=404, detail="Investor not found")

    already_linked = backer in target.investors
    if already_linked:
        raise HTTPException(
            status_code=400, detail="Investor already associated with project"
        )

    target.investors.append(backer)
    db.commit()
    return {"message": "Investor added to project successfully"}
@router.delete("/projects/{project_id}/investors/{investor_id}")
def remove_investor_from_project(
    project_id: int, investor_id: int, db: Session = Depends(get_db)
):
    """Unlink an investor from a project.

    Raises:
        HTTPException: 404 when the project or the investor is missing;
            400 when the two are not currently linked.
    """
    target = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
    if not target:
        raise HTTPException(status_code=404, detail="Project not found")

    backer = db.query(InvestorTable).filter(InvestorTable.id == investor_id).first()
    if not backer:
        raise HTTPException(status_code=404, detail="Investor not found")

    currently_linked = backer in target.investors
    if not currently_linked:
        raise HTTPException(
            status_code=400, detail="Investor not associated with project"
        )

    target.investors.remove(backer)
    db.commit()
    return {"message": "Investor removed from project successfully"}
@router.post("/projects/{project_id}/companies/{company_id}")
def add_company_to_project(
    project_id: int, company_id: int, db: Session = Depends(get_db)
):
    """Link an existing company to an existing project.

    Raises:
        HTTPException: 404 when the project or the company is missing;
            400 when the two are already linked.
    """
    target = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
    if not target:
        raise HTTPException(status_code=404, detail="Project not found")

    firm = db.query(CompanyTable).filter(CompanyTable.id == company_id).first()
    if not firm:
        raise HTTPException(status_code=404, detail="Company not found")

    already_linked = firm in target.companies
    if already_linked:
        raise HTTPException(
            status_code=400, detail="Company already associated with project"
        )

    target.companies.append(firm)
    db.commit()
    return {"message": "Company added to project successfully"}
@router.delete("/projects/{project_id}/companies/{company_id}")
def remove_company_from_project(
    project_id: int, company_id: int, db: Session = Depends(get_db)
):
    """Unlink a company from a project.

    Raises:
        HTTPException: 404 when the project or the company is missing;
            400 when the two are not currently linked.
    """
    target = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
    if not target:
        raise HTTPException(status_code=404, detail="Project not found")

    firm = db.query(CompanyTable).filter(CompanyTable.id == company_id).first()
    if not firm:
        raise HTTPException(status_code=404, detail="Company not found")

    currently_linked = firm in target.companies
    if not currently_linked:
        raise HTTPException(
            status_code=400, detail="Company not associated with project"
        )

    target.companies.remove(firm)
    db.commit()
    return {"message": "Company removed from project successfully"}
@router.post("/projects/{project_id}/sectors/{sector_id}")
def add_sector_to_project(
    project_id: int, sector_id: int, db: Session = Depends(get_db)
):
    """Link an existing sector to an existing project.

    Raises:
        HTTPException: 404 when the project or the sector is missing;
            400 when the two are already linked.
    """
    target = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
    if not target:
        raise HTTPException(status_code=404, detail="Project not found")

    vertical = db.query(SectorTable).filter(SectorTable.id == sector_id).first()
    if not vertical:
        raise HTTPException(status_code=404, detail="Sector not found")

    # ``ProjectTable.sector`` is a collection even though its name is singular.
    already_linked = vertical in target.sector
    if already_linked:
        raise HTTPException(
            status_code=400, detail="Sector already associated with project"
        )

    target.sector.append(vertical)
    db.commit()
    return {"message": "Sector added to project successfully"}
@router.delete("/projects/{project_id}/sectors/{sector_id}")
def remove_sector_from_project(
    project_id: int, sector_id: int, db: Session = Depends(get_db)
):
    """Unlink a sector from a project.

    Raises:
        HTTPException: 404 when the project or the sector is missing;
            400 when the two are not currently linked.
    """
    target = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
    if not target:
        raise HTTPException(status_code=404, detail="Project not found")

    vertical = db.query(SectorTable).filter(SectorTable.id == sector_id).first()
    if not vertical:
        raise HTTPException(status_code=404, detail="Sector not found")

    # ``ProjectTable.sector`` is a collection even though its name is singular.
    currently_linked = vertical in target.sector
    if not currently_linked:
        raise HTTPException(
            status_code=400, detail="Sector not associated with project"
        )

    target.sector.remove(vertical)
    db.commit()
    return {"message": "Sector removed from project successfully"}
# Bulk association management
@router.post("/projects/{project_id}/investors")
def add_multiple_investors_to_project(
    project_id: int, investor_ids: List[int], db: Session = Depends(get_db)
):
    """Link several investors to a project in one request.

    Investors that are already linked are skipped silently. Ids repeated in
    the request are treated as a single id.

    Args:
        project_id: Primary key of the project to extend.
        investor_ids: Ids of the investors to link (request body).
        db: Database session provided by the ``get_db`` dependency.

    Returns:
        A message reporting how many new links were created.

    Raises:
        HTTPException: 404 when the project, or any requested investor,
            does not exist.
    """
    project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")

    # De-duplicate first: the query returns each row once, so comparing its
    # length with the raw list would report a false 404 for repeated ids.
    unique_ids = list(dict.fromkeys(investor_ids))
    investors = db.query(InvestorTable).filter(InvestorTable.id.in_(unique_ids)).all()
    if len(investors) != len(unique_ids):
        raise HTTPException(status_code=404, detail="One or more investors not found")

    # Add associations (only if not already associated).
    added_count = 0
    for investor in investors:
        if investor not in project.investors:
            project.investors.append(investor)
            added_count += 1

    db.commit()
    return {"message": f"Added {added_count} investors to project successfully"}
@router.post("/projects/{project_id}/companies")
def add_multiple_companies_to_project(
    project_id: int, company_ids: List[int], db: Session = Depends(get_db)
):
    """Link several companies to a project in one request.

    Companies that are already linked are skipped silently. Ids repeated in
    the request are treated as a single id.

    Args:
        project_id: Primary key of the project to extend.
        company_ids: Ids of the companies to link (request body).
        db: Database session provided by the ``get_db`` dependency.

    Returns:
        A message reporting how many new links were created.

    Raises:
        HTTPException: 404 when the project, or any requested company,
            does not exist.
    """
    project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
    if not project:
        raise HTTPException(status_code=404, detail="Project not found")

    # De-duplicate first: the query returns each row once, so comparing its
    # length with the raw list would report a false 404 for repeated ids.
    unique_ids = list(dict.fromkeys(company_ids))
    companies = db.query(CompanyTable).filter(CompanyTable.id.in_(unique_ids)).all()
    if len(companies) != len(unique_ids):
        raise HTTPException(status_code=404, detail="One or more companies not found")

    # Add associations (only if not already associated).
    added_count = 0
    for company in companies:
        if company not in project.companies:
            project.companies.append(company)
            added_count += 1

    db.commit()
    return {"message": f"Added {added_count} companies to project successfully"}
Binary file not shown.
Binary file not shown.
+117
View File
@@ -0,0 +1,117 @@
from datetime import datetime
from enum import Enum
from typing import List, Optional
from pydantic import BaseModel
class InvestmentStage(str, Enum):
    """Funding stages used by the project schemas.

    Each member is also a ``str`` whose value equals its name.
    """

    SEED = "SEED"
    SERIES_A = "SERIES_A"
    SERIES_B = "SERIES_B"
    SERIES_C = "SERIES_C"
    GROWTH = "GROWTH"
    LATE_STAGE = "LATE_STAGE"
class SectorSchema(BaseModel):
    """Sector as embedded in project responses: its id and name."""

    id: int
    name: str

    class Config:
        # Lets the schema be populated from object attributes (ORM rows).
        from_attributes = True
class InvestorSchema(BaseModel):
    """Investor as embedded in project responses.

    Every field except ``id`` and ``name`` may be null, mirroring the
    nullable columns of the ``investors`` table.
    """

    id: int
    name: str
    description: Optional[str]
    aum: int | None
    check_size_lower: int | None
    check_size_upper: int | None
    geographic_focus: str | None
    # ``investors.stage_focus`` is nullable, so a missing stage must be
    # accepted; otherwise any project linked to such an investor fails
    # response validation.
    stage_focus: InvestmentStage | None
    number_of_investments: int | None
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None

    class Config:
        # Lets the schema be populated from object attributes (ORM rows).
        from_attributes = True
class CompanySchema(BaseModel):
    """Company as embedded in project responses."""

    id: int
    name: str

    # Must be supplied, but may be null.
    industry: Optional[str]
    location: Optional[str]
    description: str | None
    founded_year: int | None
    website: str | None

    # Default to null when not supplied.
    created_at: datetime | None = None
    updated_at: datetime | None = None

    class Config:
        # Lets the schema be populated from object attributes (ORM rows).
        from_attributes = True
class ProjectSchema(BaseModel):
    """Scalar fields of a project as returned by the API."""

    id: int
    name: str

    # Must be supplied, but may be null.
    valuation: Optional[int]
    stage: Optional[InvestmentStage]
    location: Optional[str]
    description: str | None
    start_date: datetime | None
    end_date: datetime | None

    # Default to null when not supplied.
    created_at: datetime | None = None
    updated_at: datetime | None = None

    class Config:
        # Lets the schema be populated from object attributes (ORM rows).
        from_attributes = True
class ProjectCreate(BaseModel):
    """Request body for creating a project; only ``name`` is required."""

    name: str
    valuation: int | None = None
    stage: InvestmentStage | None = None
    location: str | None = None
    description: str | None = None
    start_date: datetime | None = None
    end_date: datetime | None = None
class ProjectUpdate(BaseModel):
    """Request body for a partial project update; every field is optional."""

    name: str | None = None
    valuation: int | None = None
    stage: InvestmentStage | None = None
    location: str | None = None
    description: str | None = None
    start_date: datetime | None = None
    end_date: datetime | None = None
class ProjectData(BaseModel):
    """Full project payload: the project plus its related records."""

    project: ProjectSchema
    # A list despite the singular field name.
    sector: list[SectorSchema]
    investors: list[InvestorSchema]
    companies: list[CompanySchema]

    class Config:
        # Lets the schema be populated from object attributes (ORM rows).
        from_attributes = True
class ProjectInvestorAssociation(BaseModel):
    """Pair of ids identifying one project-investor link."""

    project_id: int
    investor_id: int
class ProjectCompanyAssociation(BaseModel):
    """Pair of ids identifying one project-company link."""

    project_id: int
    company_id: int
class ProjectSectorAssociation(BaseModel):
    """Pair of ids identifying one project-sector link."""

    project_id: int
    sector_id: int
+311 -66
View File
@@ -1,7 +1,7 @@
from enum import Enum from enum import Enum
from typing import List, Optional from typing import List, Optional
from pydantic import BaseModel, field_validator from pydantic import BaseModel, Field, field_validator
class InvestmentStage(str, Enum): class InvestmentStage(str, Enum):
@@ -14,98 +14,343 @@ class InvestmentStage(str, Enum):
class SectorSchema(BaseModel): class SectorSchema(BaseModel):
id: int """
name: str Expert parser: Only extract sector information if clearly identifiable.
Leave name empty if uncertain about the sector classification.
"""
class Config: id: Optional[int] = Field(
from_attributes = True default=None,
ge=0,
description="Sector ID, must be 0 or greater. Use 0 if uncertain.",
)
name: Optional[str] = Field(
default=None,
description="Sector name. Leave empty string if not clearly identifiable from the data.",
)
@field_validator("name", mode="before")
class InvestorMemberSchema(BaseModel):
id: int
name: str
role: str
email: str
investor_id: int
class Config:
from_attributes = True
class CompanyMemberSchema(BaseModel):
id: int
name: Optional[str] = None
linkedin: Optional[str] = None
role: Optional[str] = None
company_id: int
class Config:
from_attributes = True
class CompanySchema(BaseModel):
id: int
name: str
industry: str
location: str
description: Optional[str] = None # Fixed typo from 'nullabel'
founded_year: Optional[int] = None # Changed from str to int to match model
website: Optional[str] = None
@field_validator("founded_year", mode="before")
@classmethod @classmethod
def validate_founded_year(cls, v): def empty_string_to_none(cls, v):
if v is None or v == "Not Available" or v == "": """Convert empty strings to None"""
if v == "" or (isinstance(v, str) and v.strip() == ""):
return None
return v
@field_validator("id", mode="before")
@classmethod
def zero_to_none(cls, v):
"""Convert 0 to None for optional id field"""
if v == 0:
return None return None
if isinstance(v, str):
try:
return int(v)
except ValueError:
return None
return v return v
class Config: class Config:
from_attributes = True from_attributes = True
class InvestorMemberSchema(BaseModel):
    """
    Expert parser: Only extract team member information if clearly identifiable.
    Leave fields empty if uncertain about the member details.
    """

    # The class docstring and every Field description are kept verbatim:
    # Pydantic copies them into the generated JSON schema, so they are part
    # of the runtime output rather than plain comments.
    #
    # All fields are optional. The "before" validators at the bottom fold the
    # placeholder values (0 for integers, blank text for strings) into None.
    id: Optional[int] = Field(
        default=None,
        ge=0,
        description="Member ID, must be 0 or greater. Use 0 if uncertain.",
    )
    name: Optional[str] = Field(
        default=None,
        description="Team member name. Leave empty string if not clearly identifiable.",
    )
    role: Optional[str] = Field(
        default=None,
        description="Team member role/title. Leave empty string if not clearly identifiable.",
    )
    email: Optional[str] = Field(
        default=None,
        description="Team member email. Leave empty string if not clearly identifiable or not provided.",
    )
    investor_id: Optional[int] = Field(
        default=None,
        ge=0,
        description="Investor ID, must be 0 or greater. Use 0 if uncertain.",
    )

    @field_validator("name", "role", "email", mode="before")
    @classmethod
    def empty_string_to_none(cls, v):
        """Return None for empty or whitespace-only strings, else the input unchanged."""
        is_blank = v == "" or (isinstance(v, str) and not v.strip())
        return None if is_blank else v

    @field_validator("id", "investor_id", mode="before")
    @classmethod
    def zero_to_none(cls, v):
        """Return None when the raw value equals 0, else the input unchanged."""
        return None if v == 0 else v

    class Config:
        # Pydantic option: allow validation from an object's attributes.
        from_attributes = True
class CompanyMemberSchema(BaseModel):
    """
    Expert parser: Only extract company member information if clearly identifiable.
    Leave fields empty if uncertain about the member details.
    """

    # The class docstring and every Field description are kept verbatim:
    # Pydantic copies them into the generated JSON schema, so they are part
    # of the runtime output rather than plain comments.
    #
    # All fields are optional. The "before" validators at the bottom fold the
    # placeholder values (0 for integers, blank text for strings) into None.
    id: Optional[int] = Field(
        default=None,
        ge=0,
        description="Member ID, must be 0 or greater. Use 0 if uncertain.",
    )
    name: Optional[str] = Field(
        default=None,
        description="Company member name. Leave empty if not clearly identifiable.",
    )
    linkedin: Optional[str] = Field(
        default=None,
        description="LinkedIn profile URL. Leave empty if not provided or uncertain.",
    )
    role: Optional[str] = Field(
        default=None,
        description="Company member role/title. Leave empty if not clearly identifiable.",
    )
    company_id: Optional[int] = Field(
        default=None,
        ge=0,
        description="Company ID, must be 0 or greater. Use 0 if uncertain.",
    )

    @field_validator("name", "linkedin", "role", mode="before")
    @classmethod
    def empty_string_to_none(cls, v):
        """Return None for empty or whitespace-only strings, else the input unchanged."""
        is_blank = v == "" or (isinstance(v, str) and not v.strip())
        return None if is_blank else v

    @field_validator("id", "company_id", mode="before")
    @classmethod
    def zero_to_none(cls, v):
        """Return None when the raw value equals 0, else the input unchanged."""
        return None if v == 0 else v

    class Config:
        # Pydantic option: allow validation from an object's attributes.
        from_attributes = True
class CompanySchema(BaseModel):
    """
    Expert parser: Only extract company information if clearly identifiable.
    Leave optional fields empty if uncertain. Integer values must be 0 or greater.
    """

    # The class docstring and every Field description are kept verbatim:
    # Pydantic copies them into the generated JSON schema, so they are part
    # of the runtime output rather than plain comments.
    #
    # All fields are optional. The "before" validators at the bottom fold the
    # placeholder values (0 for integers, blank text for strings) into None.
    id: Optional[int] = Field(
        default=None,
        ge=0,
        description="Company ID, must be 0 or greater. Use 0 if uncertain.",
    )
    name: Optional[str] = Field(
        default=None,
        description="Company name. Leave empty string if not clearly identifiable.",
    )
    industry: Optional[str] = Field(
        default=None,
        description="Company industry/sector. Leave empty string if not clearly identifiable.",
    )
    location: Optional[str] = Field(
        default=None,
        description="Company location/address. Leave empty string if not clearly identifiable.",
    )
    description: Optional[str] = Field(
        default=None,
        description="Company description. Leave empty if not clearly available or uncertain.",
    )
    founded_year: Optional[int] = Field(
        default=None,
        ge=0,
        description="Year company was founded, must be 0 or greater. Leave None if not clearly identifiable or uncertain.",
    )
    website: Optional[str] = Field(
        default=None,
        description="Company website URL. Leave empty if not provided or uncertain.",
    )

    @field_validator(
        "name", "industry", "location", "description", "website", mode="before"
    )
    @classmethod
    def empty_string_to_none(cls, v):
        """Convert empty or whitespace-only strings to None."""
        if v == "" or (isinstance(v, str) and v.strip() == ""):
            return None
        return v

    @field_validator("id", "founded_year", mode="before")
    @classmethod
    def zero_to_none(cls, v):
        """Convert 0 to None for the ``id`` and ``founded_year`` fields."""
        if v == 0:
            return None
        return v

    @field_validator("founded_year", mode="before")
    @classmethod
    def validate_founded_year(cls, v):
        """Normalise a raw founding-year value to a positive int or None.

        Args:
            v: Raw input for ``founded_year`` (any type).

        Returns:
            A positive ``int`` year, or ``None`` when the value is missing, a
            placeholder ("Not Available", "Unknown", blank), not numeric,
            fractional, zero or negative.
        """
        # bool is a subclass of int, but True/False is never a founding year.
        if v is None or isinstance(v, bool):
            return None
        if isinstance(v, str):
            text = v.strip()
            if text in ("", "Not Available", "Unknown"):
                return None
            try:
                v = int(text)
            except ValueError:
                return None
        elif isinstance(v, float):
            # Accept integral floats such as 2015.0; reject NaN, infinities
            # and fractional values.
            if not v.is_integer():
                return None
            v = int(v)
        elif not isinstance(v, int):
            return None
        # 0 means "unknown" in this schema (see zero_to_none). Checking it
        # here too keeps the string "0" from slipping through as year 0,
        # whichever order the two validators run in.
        return v if v > 0 else None

    class Config:
        # Pydantic option: allow validation from an object's attributes.
        from_attributes = True
class InvestorSchema(BaseModel): class InvestorSchema(BaseModel):
id: int """
name: str Expert parser: Only extract investor information if clearly identifiable.
description: Optional[str] = None Leave optional fields empty if uncertain. All numeric values must be 0 or greater.
aum: int """
check_size_lower: int
check_size_upper: int id: Optional[int] = Field(
geographic_focus: str default=None,
stage_focus: InvestmentStage ge=0,
number_of_investments: int = 0 description="Investor ID, must be 0 or greater. Use 0 if uncertain.",
)
name: Optional[str] = Field(
default=None,
description="Investor name. Do not return any special characters, Just the name as a string.",
)
description: Optional[str] = Field(
default=None,
description="Investor description. Leave empty if not clearly available or uncertain.",
)
aum: Optional[int] = Field(
default=None,
ge=0,
description="Assets Under Management in USD, must be 0 or greater. Use 0 if not clearly identifiable or uncertain.",
)
check_size_lower: Optional[int] = Field(
default=None,
ge=0,
description="Lower bound of typical investment check size in USD, must be 0 or greater. Use 0 if not clearly identifiable.",
)
check_size_upper: Optional[int] = Field(
default=None,
ge=0,
description="Upper bound of typical investment check size in USD, must be 0 or greater. Use 0 if not clearly identifiable.",
)
geographic_focus: Optional[str] = Field(
default=None,
description="Geographic investment focus. Do not return any special characters, Just locations separated by commas. Leave empty if not clearly identifiable.",
)
stage_focus: InvestmentStage = Field(
default=InvestmentStage.SEED,
description="Investment stage focus. Use SEED as default if uncertain.",
)
number_of_investments: Optional[int] = Field(
default=None,
ge=0,
description="Total number of investments made, must be 0 or greater. Use 0 if not clearly identifiable.",
)
@field_validator("name", "description", "geographic_focus", mode="before")
@classmethod
def empty_string_to_none(cls, v):
"""Convert empty strings to None"""
if v == "" or (isinstance(v, str) and v.strip() == ""):
return None
return v
@field_validator(
"id",
"aum",
"check_size_lower",
"check_size_upper",
"number_of_investments",
mode="before",
)
@classmethod
def zero_to_none(cls, v):
"""Convert 0 to None for optional integer fields"""
if v == 0:
return None
return v
class Config: class Config:
from_attributes = True from_attributes = True
class InvestorData(BaseModel): class InvestorData(BaseModel):
"""Comprehensive investor data schema for LLM processing""" """
Expert parser: Comprehensive investor data schema for LLM processing.
Only populate fields with clearly identifiable information. Leave lists empty if uncertain.
"""
investor: InvestorSchema investor: InvestorSchema = Field(
portfolio_companies: List[CompanySchema] = [] description="Core investor information. Only populate with clearly identifiable data."
team_members: List[InvestorMemberSchema] = [] # Changed from TeamMember )
sectors: List[SectorSchema] = [] portfolio_companies: List[CompanySchema] = Field(
default=[],
description="List of portfolio companies. Leave empty if not clearly identifiable.",
)
team_members: List[InvestorMemberSchema] = Field(
default=[],
description="List of team members. Leave empty if not clearly identifiable.",
)
sectors: List[SectorSchema] = Field(
default=[],
description="List of investment sectors. Leave empty if not clearly identifiable.",
)
class Config: class Config:
from_attributes = True from_attributes = True
class CompanyData(BaseModel): # Renamed from CompaniesData for consistency class CompanyData(BaseModel):
company: CompanySchema """
sectors: List[SectorSchema] = [] Expert parser: Comprehensive company data schema for LLM processing.
members: List[CompanyMemberSchema] = [] # Changed to match model relationship name Only populate fields with clearly identifiable information. Leave lists empty if uncertain.
investors: List[InvestorSchema] = [] """
company: CompanySchema = Field(
description="Core company information. Only populate with clearly identifiable data."
)
sectors: List[SectorSchema] = Field(
default=[],
description="List of company sectors. Leave empty if not clearly identifiable.",
)
members: List[CompanyMemberSchema] = Field(
default=[],
description="List of company members. Leave empty if not clearly identifiable.",
)
investors: List[InvestorSchema] = Field(
default=[],
description="List of investors. Leave empty if not clearly identifiable.",
)
class Config: class Config:
from_attributes = True from_attributes = True
class InvestorList(BaseModel): class InvestorList(BaseModel):
investors: List[InvestorData] = [] """Expert parser: List of investors with clearly identifiable information only."""
investors: List[InvestorData] = Field(
default=[],
description="List of investors. Leave empty if no clearly identifiable investors.",
)
+20 -18
View File
@@ -25,50 +25,51 @@ class SectorSchema(BaseModel):
class InvestorMemberSchema(BaseModel): class InvestorMemberSchema(BaseModel):
id: int id: int
name: str name: str
role: str role: str | None
email: str email: str | None
class Config: class Config:
from_attributes = True from_attributes = True
class CompanyMemberSchema(BaseModel): class CompanyMemberSchema(BaseModel):
id: int id: int
name: Optional[str] = None name: Optional[str]
linkedin: Optional[str] = None linkedin: Optional[str]
role: Optional[str] = None role: Optional[str]
company_id: int company_id: int
class Config: class Config:
from_attributes = True from_attributes = True
class CompanySchema(BaseModel): class CompanySchema(BaseModel):
id: int id: int
name: str name: str
industry: str industry: str | None
location: str location: str | None
description: Optional[str] description: Optional[str]
founded_year: Optional[int] founded_year: Optional[int]
website: Optional[str] website: Optional[str]
created_at: Optional[datetime] created_at: Optional[datetime] = None
updated_at: Optional[datetime] updated_at: Optional[datetime] = None
class Config: class Config:
from_attributes = True from_attributes = True
class InvestorSchema(BaseModel): class InvestorSchema(BaseModel):
id: int id: int
name: str name: str
description: Optional[str] description: Optional[str]
aum: int aum: int | None
check_size_lower: int check_size_lower: int | None
check_size_upper: int check_size_upper: int | None
geographic_focus: str geographic_focus: str | None
stage_focus: InvestmentStage stage_focus: InvestmentStage
number_of_investments: int number_of_investments: int | None
created_at: Optional[datetime] created_at: Optional[datetime] = None
updated_at: Optional[datetime] updated_at: Optional[datetime] = None
class Config: class Config:
from_attributes = True from_attributes = True
@@ -95,5 +96,6 @@ class CompanyData(BaseModel): # Renamed from CompaniesData for consistency
class Config: class Config:
from_attributes = True from_attributes = True
class InvestorList(BaseModel): class InvestorList(BaseModel):
investors: List[InvestorData] investors: List[InvestorData]
Binary file not shown.
Binary file not shown.
Binary file not shown.
+5 -5
View File
@@ -21,7 +21,7 @@ class InvestorProcessor:
self.llm = ChatOpenAI( self.llm = ChatOpenAI(
api_key=os.getenv("OPENROUTER_API_KEY"), api_key=os.getenv("OPENROUTER_API_KEY"),
base_url="https://openrouter.ai/api/v1", base_url="https://openrouter.ai/api/v1",
model="openai/gpt-5-nano", model="openai/gpt-4o-mini",
temperature=0, temperature=0,
) )
@@ -176,14 +176,14 @@ class InvestorProcessor:
async def parse_investors(self, df, save_to_db: bool = True): async def parse_investors(self, df, save_to_db: bool = True):
"""Parse investors from DataFrame and optionally save to database""" """Parse investors from DataFrame and optionally save to database"""
investors = [] investors = []
df = df[20:]
db = None db = None
if save_to_db: if save_to_db:
db = get_db_session() db = get_db_session()
try: try:
# Process rows in batches asynchronously # Process rows in batches asynchronously
batch_size = 15 # Adjust batch size as needed batch_size = 20 # Adjust batch size as needed
rows = [(idx, row) for idx, row in df.iterrows()] rows = [(idx, row) for idx, row in df.iterrows()]
for i in range(0, len(rows), batch_size): for i in range(0, len(rows), batch_size):
@@ -244,14 +244,14 @@ class InvestorProcessor:
async def parse_companies(self, df, save_to_db: bool = True): async def parse_companies(self, df, save_to_db: bool = True):
"""Parse companies from DataFrame and optionally save to database""" """Parse companies from DataFrame and optionally save to database"""
companies = [] companies = []
df = df[20:]
db = None db = None
if save_to_db: if save_to_db:
db = get_db_session() db = get_db_session()
try: try:
# Process rows in batches asynchronously # Process rows in batches asynchronously
batch_size = 15 # Adjust batch size as needed batch_size = 20 # Adjust batch size as needed
rows = [(idx, row) for idx, row in df.iterrows()] rows = [(idx, row) for idx, row in df.iterrows()]
for i in range(0, len(rows), batch_size): for i in range(0, len(rows), batch_size):
+2 -2
View File
@@ -21,8 +21,8 @@ class QueryProcessor:
self.llm = ChatOpenAI( self.llm = ChatOpenAI(
api_key=os.getenv("OPENROUTER_API_KEY"), api_key=os.getenv("OPENROUTER_API_KEY"),
base_url="https://openrouter.ai/api/v1", base_url="https://openrouter.ai/api/v1",
model="openai/gpt-5-nano", model="openai/gpt-4o-mini",
temperature=0.3, temperature=0,
) )
self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm) self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)
# Update system message to specifically request only investor IDs # Update system message to specifically request only investor IDs