Compare commits

3 Commits

16 changed files with 873 additions and 51 deletions
Binary file not shown.
Binary file not shown.
+68 -2
View File
@@ -1,10 +1,11 @@
import enum
from db.db import Base
from sqlalchemy import Column, DateTime, ForeignKey, Integer, String, Table, Text, func
from sqlalchemy.orm import declarative_mixin, relationship
from sqlalchemy.types import Enum
from db.db import Base
@declarative_mixin
class TimestampMixin:
@@ -48,6 +49,27 @@ company_sector_association = Table(
Column("sector_id", Integer, ForeignKey("sectors.id")),
)
project_sector_association = Table(
"project_sector",
Base.metadata,
Column("project_id", Integer, ForeignKey("projects.id")),
Column("sector_id", Integer, ForeignKey("sectors.id")),
)
project_investor_association = Table(
"project_investors",
Base.metadata,
Column("project_id", Integer, ForeignKey("projects.id")),
Column("investor_id", Integer, ForeignKey("investors.id")),
)
project_company_association = Table(
"project_companies",
Base.metadata,
Column("project_id", Integer, ForeignKey("projects.id")),
Column("company_id", Integer, ForeignKey("companies.id")),
)
class InvestorTable(Base, TimestampMixin):
__tablename__ = "investors"
@@ -62,19 +84,27 @@ class InvestorTable(Base, TimestampMixin):
stage_focus = Column(Enum(InvestmentStage), nullable=True)
number_of_investments = Column(Integer, default=0, nullable=True)
team_members = relationship("InvestorMember", back_populates="investor")
# Relationship to portfolio companies
portfolio_companies = relationship(
"CompanyTable",
secondary=investor_company_association,
back_populates="investors",
)
team_members = relationship("InvestorMember", back_populates="investor")
sectors = relationship(
"SectorTable",
secondary=investor_sector_association,
back_populates="investors",
)
projects = relationship(
"ProjectTable",
secondary=project_investor_association,
back_populates="investors",
)
class InvestorMember(Base, TimestampMixin):
__tablename__ = "investor_members"
@@ -110,6 +140,12 @@ class CompanyTable(Base, TimestampMixin):
"SectorTable", secondary=company_sector_association, back_populates="companies"
)
projects = relationship(
"ProjectTable",
secondary=project_company_association,
back_populates="companies",
)
class CompanyMember(Base, TimestampMixin):
__tablename__ = "company_members"
@@ -138,3 +174,33 @@ class SectorTable(Base, TimestampMixin):
companies = relationship(
"CompanyTable", secondary=company_sector_association, back_populates="sectors"
)
projects = relationship(
"ProjectTable", secondary=project_sector_association, back_populates="sector"
)
class ProjectTable(Base, TimestampMixin):
__tablename__ = "projects"
id = Column(Integer, primary_key=True, index=True)
name = Column(String, nullable=False)
valuation = Column(Integer, nullable=True)
stage = Column(Enum(InvestmentStage), nullable=True)
location = Column(String, nullable=True)
description = Column(Text, nullable=True)
start_date = Column(DateTime, nullable=True)
end_date = Column(DateTime, nullable=True)
sector = relationship(
"SectorTable", secondary=project_sector_association, back_populates="projects"
)
investors = relationship(
"InvestorTable",
secondary=project_investor_association,
back_populates="projects",
)
companies = relationship(
"CompanyTable", secondary=project_company_association, back_populates="projects"
)
+4 -2
View File
@@ -5,7 +5,7 @@ from db.db import Base, db_dependency, engine
from dotenv import load_dotenv
from fastapi import FastAPI, File, Form, UploadFile
from pydantic import BaseModel
from routers import companies, investors
from routers import companies, investors, projects
from schemas.router_schemas import InvestorList
from services.llm_parser import InvestorProcessor
from services.querying import QueryProcessor
@@ -78,7 +78,9 @@ async def query_investors(request: QueryRequest):
app.include_router(investors.router)
app.include_router(companies.router)
app.include_router(projects.router)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app="main:app", host="localhost", port=8000, reload=True)
uvicorn.run(app="main:app", host="0.0.0.0", port=8585, reload=True)
Binary file not shown.
Binary file not shown.
+4
View File
@@ -34,6 +34,10 @@ def read_companies(db: Session = Depends(get_db)):
"""Get all companies with their investor relationships"""
companies = (
db.query(CompanyTable)
.filter(
CompanyTable.name.isnot(None),
CompanyTable.description.isnot(None)
)
.options(
selectinload(CompanyTable.investors),
selectinload(CompanyTable.members),
+71 -1
View File
@@ -7,6 +7,7 @@ from fastapi import APIRouter, Depends, HTTPException, Query
from schemas.router_schemas import InvestmentStage, InvestorData
from pydantic import BaseModel
from sqlalchemy.orm import Session, selectinload
from services.querying import QueryProcessor
router = APIRouter(tags=["Investor Routes"])
@@ -231,4 +232,73 @@ def delete_investor(investor_id: int, db: Session = Depends(get_db)):
db.delete(db_investor)
db.commit()
return {"message": "Investor deleted successfully"}
return {"message": "Investor deleted successfully"}
@router.get("/investors/{investor_id}/similar", response_model=List[InvestorData])
def find_similar_investors(investor_id: int, db: Session = Depends(get_db)):
"""Find investors similar to a given investor"""
# First, get the target investor
target_investor = (
db.query(InvestorTable)
.options(
selectinload(InvestorTable.portfolio_companies),
selectinload(InvestorTable.team_members),
selectinload(InvestorTable.sectors),
)
.filter(InvestorTable.id == investor_id)
.first()
)
if not target_investor:
raise HTTPException(status_code=404, detail="Investor not found")
# Build query to find similar investors
query = db.query(InvestorTable).options(
selectinload(InvestorTable.portfolio_companies),
selectinload(InvestorTable.team_members),
selectinload(InvestorTable.sectors),
).filter(InvestorTable.id != investor_id) # Exclude the target investor
# Filter by same stage focus
query = query.filter(InvestorTable.stage_focus == target_investor.stage_focus)
# Filter by similar geographic focus (partial match)
query = query.filter(InvestorTable.geographic_focus.ilike(f"%{target_investor.geographic_focus}%"))
# Filter by overlapping check size ranges
query = query.filter(
InvestorTable.check_size_upper >= target_investor.check_size_lower,
InvestorTable.check_size_lower <= target_investor.check_size_upper
)
# Filter by similar AUM (within 50% range)
aum_lower = int(target_investor.aum * 0.5)
aum_upper = int(target_investor.aum * 1.5)
query = query.filter(
InvestorTable.aum >= aum_lower,
InvestorTable.aum <= aum_upper
)
# Filter by common sectors
target_sector_names = [sector.name for sector in target_investor.sectors]
if target_sector_names:
query = query.join(InvestorTable.sectors).filter(
SectorTable.name.in_(target_sector_names)
)
investors = query.all()
# Transform to InvestorData format
investor_data_list = []
for investor in investors:
investor_data = InvestorData(
investor=investor,
portfolio_companies=investor.portfolio_companies,
team_members=investor.team_members,
sectors=investor.sectors,
)
investor_data_list.append(investor_data)
return investor_data_list
+447
View File
@@ -0,0 +1,447 @@
from typing import List, Optional
from db.db import get_db
from db.models import (
CompanyTable,
InvestorTable,
ProjectTable,
SectorTable,
)
from fastapi import APIRouter, Depends, HTTPException, Query
from schemas.project_schemas import (
InvestmentStage,
ProjectCreate,
ProjectData,
ProjectUpdate,
)
from sqlalchemy.orm import Session, selectinload
router = APIRouter(tags=["Project Routes"])
@router.get("/projects", response_model=List[ProjectData])
def read_projects(db: Session = Depends(get_db)):
"""Get all projects with their related data"""
projects = (
db.query(ProjectTable)
.options(
selectinload(ProjectTable.sector),
selectinload(ProjectTable.investors),
selectinload(ProjectTable.companies),
)
.all()
)
# Transform ProjectTable objects to ProjectData format
project_data_list = []
for project in projects:
project_data = ProjectData(
project=project,
sector=project.sector,
investors=project.investors,
companies=project.companies,
)
project_data_list.append(project_data)
return project_data_list
@router.get("/projects/{project_id}", response_model=ProjectData)
def read_project(project_id: int, db: Session = Depends(get_db)):
"""Get a specific project by ID"""
project = (
db.query(ProjectTable)
.options(
selectinload(ProjectTable.sector),
selectinload(ProjectTable.investors),
selectinload(ProjectTable.companies),
)
.filter(ProjectTable.id == project_id)
.first()
)
if not project:
raise HTTPException(status_code=404, detail="Project not found")
return ProjectData(
project=project,
sector=project.sector,
investors=project.investors,
companies=project.companies,
)
@router.post("/projects", response_model=ProjectData)
def create_project(project: ProjectCreate, db: Session = Depends(get_db)):
"""Create a new project"""
db_project = ProjectTable(**project.dict())
db.add(db_project)
db.commit()
db.refresh(db_project)
# Reload with relationships
db_project = (
db.query(ProjectTable)
.options(
selectinload(ProjectTable.sector),
selectinload(ProjectTable.investors),
selectinload(ProjectTable.companies),
)
.filter(ProjectTable.id == db_project.id)
.first()
)
return ProjectData(
project=db_project,
sector=db_project.sector,
investors=db_project.investors,
companies=db_project.companies,
)
@router.put("/projects/{project_id}", response_model=ProjectData)
def update_project(
project_id: int, project: ProjectUpdate, db: Session = Depends(get_db)
):
"""Update an existing project"""
db_project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
if not db_project:
raise HTTPException(status_code=404, detail="Project not found")
# Update only provided fields
update_data = project.dict(exclude_unset=True)
for key, value in update_data.items():
setattr(db_project, key, value)
db.commit()
db.refresh(db_project)
# Reload with relationships
db_project = (
db.query(ProjectTable)
.options(
selectinload(ProjectTable.sector),
selectinload(ProjectTable.investors),
selectinload(ProjectTable.companies),
)
.filter(ProjectTable.id == project_id)
.first()
)
return ProjectData(
project=db_project,
sector=db_project.sector,
investors=db_project.investors,
companies=db_project.companies,
)
@router.delete("/projects/{project_id}")
def delete_project(project_id: int, db: Session = Depends(get_db)):
"""Delete a project"""
db_project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
if not db_project:
raise HTTPException(status_code=404, detail="Project not found")
db.delete(db_project)
db.commit()
return {"message": "Project deleted successfully"}
@router.get("/projects/filter", response_model=List[ProjectData])
def filter_projects(
stage: Optional[InvestmentStage] = Query(
None, description="Filter by project stage"
),
min_valuation: Optional[int] = Query(None, description="Minimum valuation"),
max_valuation: Optional[int] = Query(None, description="Maximum valuation"),
location: Optional[str] = Query(None, description="Location (partial match)"),
sector: Optional[str] = Query(None, description="Sector name (partial match)"),
investor_name: Optional[str] = Query(
None, description="Investor name (partial match)"
),
company_name: Optional[str] = Query(
None, description="Company name (partial match)"
),
db: Session = Depends(get_db),
):
"""Filter projects based on various criteria"""
# Start with base query
query = db.query(ProjectTable).options(
selectinload(ProjectTable.sector),
selectinload(ProjectTable.investors),
selectinload(ProjectTable.companies),
)
# Apply filters
if stage:
query = query.filter(ProjectTable.stage == stage)
if min_valuation is not None:
query = query.filter(ProjectTable.valuation >= min_valuation)
if max_valuation is not None:
query = query.filter(ProjectTable.valuation <= max_valuation)
if location:
query = query.filter(ProjectTable.location.ilike(f"%{location}%"))
if sector:
query = query.join(ProjectTable.sector).filter(
SectorTable.name.ilike(f"%{sector}%")
)
if investor_name:
query = query.join(ProjectTable.investors).filter(
InvestorTable.name.ilike(f"%{investor_name}%")
)
if company_name:
query = query.join(ProjectTable.companies).filter(
CompanyTable.name.ilike(f"%{company_name}%")
)
projects = query.all()
# Transform to ProjectData format
project_data_list = []
for project in projects:
project_data = ProjectData(
project=project,
sector=project.sector,
investors=project.investors,
companies=project.companies,
)
project_data_list.append(project_data)
return project_data_list
# Association management routes
@router.post("/projects/{project_id}/investors/{investor_id}")
def add_investor_to_project(
project_id: int, investor_id: int, db: Session = Depends(get_db)
):
"""Add an investor to a project"""
# Check if project exists
project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
# Check if investor exists
investor = db.query(InvestorTable).filter(InvestorTable.id == investor_id).first()
if not investor:
raise HTTPException(status_code=404, detail="Investor not found")
# Check if association already exists
if investor in project.investors:
raise HTTPException(
status_code=400, detail="Investor already associated with project"
)
# Add association
project.investors.append(investor)
db.commit()
return {"message": "Investor added to project successfully"}
@router.delete("/projects/{project_id}/investors/{investor_id}")
def remove_investor_from_project(
project_id: int, investor_id: int, db: Session = Depends(get_db)
):
"""Remove an investor from a project"""
# Check if project exists
project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
# Check if investor exists
investor = db.query(InvestorTable).filter(InvestorTable.id == investor_id).first()
if not investor:
raise HTTPException(status_code=404, detail="Investor not found")
# Check if association exists
if investor not in project.investors:
raise HTTPException(
status_code=400, detail="Investor not associated with project"
)
# Remove association
project.investors.remove(investor)
db.commit()
return {"message": "Investor removed from project successfully"}
@router.post("/projects/{project_id}/companies/{company_id}")
def add_company_to_project(
project_id: int, company_id: int, db: Session = Depends(get_db)
):
"""Add a company to a project"""
# Check if project exists
project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
# Check if company exists
company = db.query(CompanyTable).filter(CompanyTable.id == company_id).first()
if not company:
raise HTTPException(status_code=404, detail="Company not found")
# Check if association already exists
if company in project.companies:
raise HTTPException(
status_code=400, detail="Company already associated with project"
)
# Add association
project.companies.append(company)
db.commit()
return {"message": "Company added to project successfully"}
@router.delete("/projects/{project_id}/companies/{company_id}")
def remove_company_from_project(
project_id: int, company_id: int, db: Session = Depends(get_db)
):
"""Remove a company from a project"""
# Check if project exists
project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
# Check if company exists
company = db.query(CompanyTable).filter(CompanyTable.id == company_id).first()
if not company:
raise HTTPException(status_code=404, detail="Company not found")
# Check if association exists
if company not in project.companies:
raise HTTPException(
status_code=400, detail="Company not associated with project"
)
# Remove association
project.companies.remove(company)
db.commit()
return {"message": "Company removed from project successfully"}
@router.post("/projects/{project_id}/sectors/{sector_id}")
def add_sector_to_project(
project_id: int, sector_id: int, db: Session = Depends(get_db)
):
"""Add a sector to a project"""
# Check if project exists
project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
# Check if sector exists
sector = db.query(SectorTable).filter(SectorTable.id == sector_id).first()
if not sector:
raise HTTPException(status_code=404, detail="Sector not found")
# Check if association already exists
if sector in project.sector:
raise HTTPException(
status_code=400, detail="Sector already associated with project"
)
# Add association
project.sector.append(sector)
db.commit()
return {"message": "Sector added to project successfully"}
@router.delete("/projects/{project_id}/sectors/{sector_id}")
def remove_sector_from_project(
project_id: int, sector_id: int, db: Session = Depends(get_db)
):
"""Remove a sector from a project"""
# Check if project exists
project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
# Check if sector exists
sector = db.query(SectorTable).filter(SectorTable.id == sector_id).first()
if not sector:
raise HTTPException(status_code=404, detail="Sector not found")
# Check if association exists
if sector not in project.sector:
raise HTTPException(
status_code=400, detail="Sector not associated with project"
)
# Remove association
project.sector.remove(sector)
db.commit()
return {"message": "Sector removed from project successfully"}
# Bulk association management
@router.post("/projects/{project_id}/investors")
def add_multiple_investors_to_project(
project_id: int, investor_ids: List[int], db: Session = Depends(get_db)
):
"""Add multiple investors to a project"""
# Check if project exists
project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
# Get all investors
investors = db.query(InvestorTable).filter(InvestorTable.id.in_(investor_ids)).all()
if len(investors) != len(investor_ids):
raise HTTPException(status_code=404, detail="One or more investors not found")
# Add associations (only if not already associated)
added_count = 0
for investor in investors:
if investor not in project.investors:
project.investors.append(investor)
added_count += 1
db.commit()
return {"message": f"Added {added_count} investors to project successfully"}
@router.post("/projects/{project_id}/companies")
def add_multiple_companies_to_project(
project_id: int, company_ids: List[int], db: Session = Depends(get_db)
):
"""Add multiple companies to a project"""
# Check if project exists
project = db.query(ProjectTable).filter(ProjectTable.id == project_id).first()
if not project:
raise HTTPException(status_code=404, detail="Project not found")
# Get all companies
companies = db.query(CompanyTable).filter(CompanyTable.id.in_(company_ids)).all()
if len(companies) != len(company_ids):
raise HTTPException(status_code=404, detail="One or more companies not found")
# Add associations (only if not already associated)
added_count = 0
for company in companies:
if company not in project.companies:
project.companies.append(company)
added_count += 1
db.commit()
return {"message": f"Added {added_count} companies to project successfully"}
Binary file not shown.
+117
View File
@@ -0,0 +1,117 @@
from datetime import datetime
from enum import Enum
from typing import List, Optional
from pydantic import BaseModel
class InvestmentStage(str, Enum):
SEED = "SEED"
SERIES_A = "SERIES_A"
SERIES_B = "SERIES_B"
SERIES_C = "SERIES_C"
GROWTH = "GROWTH"
LATE_STAGE = "LATE_STAGE"
class SectorSchema(BaseModel):
id: int
name: str
class Config:
from_attributes = True
class InvestorSchema(BaseModel):
id: int
name: str
description: Optional[str]
aum: int | None
check_size_lower: int | None
check_size_upper: int | None
geographic_focus: str | None
stage_focus: InvestmentStage
number_of_investments: int | None
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
class Config:
from_attributes = True
class CompanySchema(BaseModel):
id: int
name: str
industry: str | None
location: str | None
description: Optional[str]
founded_year: Optional[int]
website: Optional[str]
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
class Config:
from_attributes = True
class ProjectSchema(BaseModel):
id: int
name: str
valuation: int | None
stage: InvestmentStage | None
location: str | None
description: Optional[str]
start_date: Optional[datetime]
end_date: Optional[datetime]
created_at: Optional[datetime] = None
updated_at: Optional[datetime] = None
class Config:
from_attributes = True
class ProjectCreate(BaseModel):
name: str
valuation: Optional[int] = None
stage: Optional[InvestmentStage] = None
location: Optional[str] = None
description: Optional[str] = None
start_date: Optional[datetime] = None
end_date: Optional[datetime] = None
class ProjectUpdate(BaseModel):
name: Optional[str] = None
valuation: Optional[int] = None
stage: Optional[InvestmentStage] = None
location: Optional[str] = None
description: Optional[str] = None
start_date: Optional[datetime] = None
end_date: Optional[datetime] = None
class ProjectData(BaseModel):
"""Comprehensive project data schema"""
project: ProjectSchema
sector: List[SectorSchema]
investors: List[InvestorSchema]
companies: List[CompanySchema]
class Config:
from_attributes = True
class ProjectInvestorAssociation(BaseModel):
project_id: int
investor_id: int
class ProjectCompanyAssociation(BaseModel):
project_id: int
company_id: int
class ProjectSectorAssociation(BaseModel):
project_id: int
sector_id: int
+159 -43
View File
@@ -19,13 +19,32 @@ class SectorSchema(BaseModel):
Leave name empty if uncertain about the sector classification.
"""
id: int = Field(
ge=0, description="Sector ID, must be 0 or greater. Use 0 if uncertain."
id: Optional[int] = Field(
default=None,
ge=0,
description="Sector ID, must be 0 or greater. Use 0 if uncertain.",
)
name: str = Field(
description="Sector name. Leave empty string if not clearly identifiable from the data."
name: Optional[str] = Field(
default=None,
description="Sector name. Leave empty string if not clearly identifiable from the data.",
)
@field_validator("name", mode="before")
@classmethod
def empty_string_to_none(cls, v):
"""Convert empty strings to None"""
if v == "" or (isinstance(v, str) and v.strip() == ""):
return None
return v
@field_validator("id", mode="before")
@classmethod
def zero_to_none(cls, v):
"""Convert 0 to None for optional id field"""
if v == 0:
return None
return v
class Config:
from_attributes = True
@@ -36,22 +55,45 @@ class InvestorMemberSchema(BaseModel):
Leave fields empty if uncertain about the member details.
"""
id: int = Field(
ge=0, description="Member ID, must be 0 or greater. Use 0 if uncertain."
id: Optional[int] = Field(
default=None,
ge=0,
description="Member ID, must be 0 or greater. Use 0 if uncertain.",
)
name: str = Field(
description="Team member name. Leave empty string if not clearly identifiable."
name: Optional[str] = Field(
default=None,
description="Team member name. Leave empty string if not clearly identifiable.",
)
role: str = Field(
description="Team member role/title. Leave empty string if not clearly identifiable."
role: Optional[str] = Field(
default=None,
description="Team member role/title. Leave empty string if not clearly identifiable.",
)
email: str = Field(
description="Team member email. Leave empty string if not clearly identifiable or not provided."
email: Optional[str] = Field(
default=None,
description="Team member email. Leave empty string if not clearly identifiable or not provided.",
)
investor_id: int = Field(
ge=0, description="Investor ID, must be 0 or greater. Use 0 if uncertain."
investor_id: Optional[int] = Field(
default=None,
ge=0,
description="Investor ID, must be 0 or greater. Use 0 if uncertain.",
)
@field_validator("name", "role", "email", mode="before")
@classmethod
def empty_string_to_none(cls, v):
"""Convert empty strings to None"""
if v == "" or (isinstance(v, str) and v.strip() == ""):
return None
return v
@field_validator("id", "investor_id", mode="before")
@classmethod
def zero_to_none(cls, v):
"""Convert 0 to None for optional integer fields"""
if v == 0:
return None
return v
class Config:
from_attributes = True
@@ -62,25 +104,45 @@ class CompanyMemberSchema(BaseModel):
Leave fields empty if uncertain about the member details.
"""
id: int = Field(
ge=0, description="Member ID, must be 0 or greater. Use 0 if uncertain."
id: Optional[int] = Field(
default=None,
ge=0,
description="Member ID, must be 0 or greater. Use 0 if uncertain.",
)
name: Optional[str] = Field(
default="",
default=None,
description="Company member name. Leave empty if not clearly identifiable.",
)
linkedin: Optional[str] = Field(
default="",
default=None,
description="LinkedIn profile URL. Leave empty if not provided or uncertain.",
)
role: Optional[str] = Field(
default="",
default=None,
description="Company member role/title. Leave empty if not clearly identifiable.",
)
company_id: int = Field(
ge=0, description="Company ID, must be 0 or greater. Use 0 if uncertain."
company_id: Optional[int] = Field(
default=None,
ge=0,
description="Company ID, must be 0 or greater. Use 0 if uncertain.",
)
@field_validator("name", "linkedin", "role", mode="before")
@classmethod
def empty_string_to_none(cls, v):
"""Convert empty strings to None"""
if v == "" or (isinstance(v, str) and v.strip() == ""):
return None
return v
@field_validator("id", "company_id", mode="before")
@classmethod
def zero_to_none(cls, v):
"""Convert 0 to None for optional integer fields"""
if v == 0:
return None
return v
class Config:
from_attributes = True
@@ -91,20 +153,25 @@ class CompanySchema(BaseModel):
Leave optional fields empty if uncertain. Integer values must be 0 or greater.
"""
id: int = Field(
ge=0, description="Company ID, must be 0 or greater. Use 0 if uncertain."
id: Optional[int] = Field(
default=None,
ge=0,
description="Company ID, must be 0 or greater. Use 0 if uncertain.",
)
name: str = Field(
description="Company name. Leave empty string if not clearly identifiable."
name: Optional[str] = Field(
default=None,
description="Company name. Leave empty string if not clearly identifiable.",
)
industry: str = Field(
description="Company industry/sector. Leave empty string if not clearly identifiable."
industry: Optional[str] = Field(
default=None,
description="Company industry/sector. Leave empty string if not clearly identifiable.",
)
location: str = Field(
description="Company location/address. Leave empty string if not clearly identifiable."
location: Optional[str] = Field(
default=None,
description="Company location/address. Leave empty string if not clearly identifiable.",
)
description: Optional[str] = Field(
default="",
default=None,
description="Company description. Leave empty if not clearly available or uncertain.",
)
founded_year: Optional[int] = Field(
@@ -113,10 +180,28 @@ class CompanySchema(BaseModel):
description="Year company was founded, must be 0 or greater. Leave None if not clearly identifiable or uncertain.",
)
website: Optional[str] = Field(
default="",
default=None,
description="Company website URL. Leave empty if not provided or uncertain.",
)
@field_validator(
"name", "industry", "location", "description", "website", mode="before"
)
@classmethod
def empty_string_to_none(cls, v):
"""Convert empty strings to None"""
if v == "" or (isinstance(v, str) and v.strip() == ""):
return None
return v
@field_validator("id", "founded_year", mode="before")
@classmethod
def zero_to_none(cls, v):
"""Convert 0 to None for founded_year"""
if v == 0:
return None
return v
@field_validator("founded_year", mode="before")
@classmethod
def validate_founded_year(cls, v):
@@ -141,40 +226,71 @@ class InvestorSchema(BaseModel):
Leave optional fields empty if uncertain. All numeric values must be 0 or greater.
"""
id: int = Field(
ge=0, description="Investor ID, must be 0 or greater. Use 0 if uncertain."
id: Optional[int] = Field(
default=None,
ge=0,
description="Investor ID, must be 0 or greater. Use 0 if uncertain.",
)
name: str = Field(
description="Investor name. Do not return any special characters, Just the name as a string."
name: Optional[str] = Field(
default=None,
description="Investor name. Do not return any special characters, Just the name as a string.",
)
description: Optional[str] = Field(
default="",
default=None,
description="Investor description. Leave empty if not clearly available or uncertain.",
)
aum: int | None = Field(
aum: Optional[int] = Field(
default=None,
ge=0,
description="Assets Under Management in USD, must be 0 or greater. Use 0 if not clearly identifiable or uncertain.",
)
check_size_lower: int | None = Field(
check_size_lower: Optional[int] = Field(
default=None,
ge=0,
description="Lower bound of typical investment check size in USD, must be 0 or greater. Use 0 if not clearly identifiable.",
)
check_size_upper: int | None = Field(
check_size_upper: Optional[int] = Field(
default=None,
ge=0,
description="Upper bound of typical investment check size in USD, must be 0 or greater. Use 0 if not clearly identifiable.",
)
geographic_focus: str | None = Field(
geographic_focus: Optional[str] = Field(
default=None,
description="Geographic investment focus. Do not return any special characters, Just locations separated by commas. Leave empty if not clearly identifiable.",
)
stage_focus: InvestmentStage = Field(
description="Investment stage focus. Use SEED as default if uncertain."
default=InvestmentStage.SEED,
description="Investment stage focus. Use SEED as default if uncertain.",
)
number_of_investments: int | None = Field(
number_of_investments: Optional[int] = Field(
default=None,
ge=0,
default=0,
description="Total number of investments made, must be 0 or greater. Use 0 if not clearly identifiable.",
)
@field_validator("name", "description", "geographic_focus", mode="before")
@classmethod
def empty_string_to_none(cls, v):
"""Convert empty strings to None"""
if v == "" or (isinstance(v, str) and v.strip() == ""):
return None
return v
@field_validator(
"id",
"aum",
"check_size_lower",
"check_size_upper",
"number_of_investments",
mode="before",
)
@classmethod
def zero_to_none(cls, v):
"""Convert 0 to None for optional integer fields"""
if v == 0:
return None
return v
class Config:
from_attributes = True
Binary file not shown.
Binary file not shown.
+1 -1
View File
@@ -21,7 +21,7 @@ class InvestorProcessor:
self.llm = ChatOpenAI(
api_key=os.getenv("OPENROUTER_API_KEY"),
base_url="https://openrouter.ai/api/v1",
model="openai/gpt-5-nano",
model="openai/gpt-4o-mini",
temperature=0,
)
+2 -2
View File
@@ -21,8 +21,8 @@ class QueryProcessor:
self.llm = ChatOpenAI(
api_key=os.getenv("OPENROUTER_API_KEY"),
base_url="https://openrouter.ai/api/v1",
model="openai/gpt-5-nano",
temperature=0.3,
model="openai/gpt-4o-mini",
temperature=0,
)
self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)
# Update system message to specifically request only investor IDs