Refactor investor-related schemas and models; update database configuration and enhance investor processing logic

This commit is contained in:
bolade
2025-09-02 15:51:35 +01:00
parent 65b5df3a43
commit 7b58834316
15 changed files with 258 additions and 51 deletions
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
+2 -2
View File
@@ -1,6 +1,6 @@
from fastapi.routing import apirouter from fastapi import APIRouter
router = apirouter() router = APIRouter()
@router.get("/investors") @router.get("/investors")
def read_investors(): def read_investors():
Binary file not shown.
Binary file not shown.
+1 -1
View File
@@ -9,7 +9,7 @@ from sqlalchemy.orm import Session, sessionmaker
Base = declarative_base() Base = declarative_base()
# Database configuration # Database configuration
DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///investors.db") DATABASE_URL = os.getenv("DATABASE_URL", "sqlite:///investors_2.db")
# Create engine # Create engine
engine = create_engine(DATABASE_URL, echo=False) engine = create_engine(DATABASE_URL, echo=False)
+18 -4
View File
@@ -33,6 +33,8 @@ investor_sector_association = Table(
Column("investor_id", Integer, ForeignKey("investors.id")), Column("investor_id", Integer, ForeignKey("investors.id")),
Column("sector_id", Integer, ForeignKey("sectors.id")), Column("sector_id", Integer, ForeignKey("sectors.id")),
) )
class InvestorTable(Base): class InvestorTable(Base):
__tablename__ = "investors" __tablename__ = "investors"
@@ -42,7 +44,7 @@ class InvestorTable(Base):
aum = Column(Integer, nullable=False) # Assets Under Management aum = Column(Integer, nullable=False) # Assets Under Management
check_size_lower = Column(Integer, nullable=False) # Lower bound check_size_lower = Column(Integer, nullable=False) # Lower bound
check_size_upper = Column(Integer, nullable=False) # Upper bound check_size_upper = Column(Integer, nullable=False) # Upper bound
geography = Column(String, nullable=False) geographic_focus = Column(String, nullable=False)
stage_focus = Column(Enum(InvestmentStage), nullable=False) stage_focus = Column(Enum(InvestmentStage), nullable=False)
number_of_investments = Column(Integer, default=0) number_of_investments = Column(Integer, default=0)
created_at = Column(DateTime, default=datetime.datetime.now(datetime.UTC)) created_at = Column(DateTime, default=datetime.datetime.now(datetime.UTC))
@@ -58,6 +60,12 @@ class InvestorTable(Base):
secondary=investor_company_association, secondary=investor_company_association,
back_populates="investors", back_populates="investors",
) )
team_members = relationship("InvestorTeamMember", back_populates="investor")
sectors = relationship(
"SectorTable",
secondary=investor_sector_association,
back_populates="investors",
)
class CompanyTable(Base): class CompanyTable(Base):
@@ -88,16 +96,22 @@ class SectorTable(Base):
__tablename__ = "sectors" __tablename__ = "sectors"
id = Column(Integer, primary_key=True, index=True) id = Column(Integer, primary_key=True, index=True)
name = Column(String, unique=True, nullable=False) name = Column(String, nullable=False)
# Add relationship back to investors
investors = relationship(
"InvestorTable",
secondary=investor_sector_association,
back_populates="sectors",
)
class InvestorTeamMember(Base): class InvestorTeamMember(Base):
__tablename__ = "investor_team"
id = Column(Integer, primary_key=True, index=True) id = Column(Integer, primary_key=True, index=True)
name = Column(String, nullable=False) name = Column(String, nullable=False)
role = Column(String, nullable=False) role = Column(String, nullable=False)
email = Column(String, unique=True, nullable=False) email = Column(String, nullable=False)
investor_id = Column(Integer, ForeignKey("investors.id")) investor_id = Column(Integer, ForeignKey("investors.id"))
investor = relationship("InvestorTable", back_populates="team_members") investor = relationship("InvestorTable", back_populates="team_members")
+4 -3
View File
@@ -1,14 +1,15 @@
import io import io
from api import investors
import pandas as pd import pandas as pd
from db.db import db_dependency from api import investors
from db.db import db_dependency, init_database
from fastapi import FastAPI, File, UploadFile from fastapi import FastAPI, File, UploadFile
from services.openrouter import InvestorProcessor from services.openrouter import InvestorProcessor
from services.querying import QueryProcessor from services.querying import QueryProcessor
app = FastAPI() app = FastAPI()
app.include_router(investors.router) app.include_router(investors.router)
# init_database() init_database()
@app.get("/") @app.get("/")
+77
View File
@@ -0,0 +1,77 @@
from pydantic import BaseModel
from datetime import datetime
from typing import List, Optional
from enum import Enum
class InvestmentStage(str, Enum):
    """Funding stage an investor focuses on.

    Members are also ``str`` instances, so each one compares equal to its
    lowercase value (for example ``InvestmentStage.SEED == "seed"``).
    """

    SEED = "seed"
    SERIES_A = "series_a"
    SERIES_B = "series_b"
    SERIES_C = "series_c"
    GROWTH = "growth"
    LATE_STAGE = "late_stage"
class SectorSchema(BaseModel):
    """Sector record: an integer ``id`` and a ``name``."""

    # Replaces the class-based ``Config`` (deprecated in pydantic v2).
    # ``from_attributes`` allows validation from attribute-bearing objects
    # such as ORM rows, not only from dicts.
    model_config = {"from_attributes": True}

    id: int
    name: str
class CompanySchema(BaseModel):
    """Company record.

    ``id``, ``name``, ``industry`` and ``location`` are required. The
    remaining fields may be omitted and then default to ``None``.
    """

    # Replaces the class-based ``Config`` (deprecated in pydantic v2).
    # ``from_attributes`` allows validation from attribute-bearing objects
    # such as ORM rows, not only from dicts.
    model_config = {"from_attributes": True}

    id: int
    name: str
    industry: str
    location: str
    # Explicit ``None`` defaults: in pydantic v2 an ``Optional[...]``
    # annotation without a default is still a required field.
    founded_year: Optional[int] = None
    website: Optional[str] = None
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None
class InvestorTeamMemberSchema(BaseModel):
    """Team member record: ``id``, ``name``, ``role`` and ``email``."""

    # Replaces the class-based ``Config`` (deprecated in pydantic v2).
    # ``from_attributes`` allows validation from attribute-bearing objects
    # such as ORM rows, not only from dicts.
    model_config = {"from_attributes": True}

    id: int
    name: str
    role: str
    email: str
class InvestorSchema(BaseModel):
    """Core investor record.

    ``description``, ``created_at`` and ``updated_at`` may be omitted and then
    default to ``None``; every other field is required.
    """

    # Replaces the class-based ``Config`` (deprecated in pydantic v2).
    # ``from_attributes`` allows validation from attribute-bearing objects
    # such as ORM rows, not only from dicts.
    model_config = {"from_attributes": True}

    id: int
    name: str
    # Explicit ``None`` default: in pydantic v2 an ``Optional[...]``
    # annotation without a default is still a required field.
    description: Optional[str] = None
    aum: int
    check_size_lower: int
    check_size_upper: int
    geographic_focus: str
    stage_focus: InvestmentStage
    number_of_investments: int
    created_at: Optional[datetime] = None
    updated_at: Optional[datetime] = None
class InvestorData(BaseModel):
    """Comprehensive investor data schema for LLM processing.

    Bundles one investor with its portfolio companies, team members and
    sectors. Each of the three lists defaults to empty when not supplied.
    """

    # Replaces the class-based ``Config`` (deprecated in pydantic v2).
    # ``from_attributes`` allows validation from attribute-bearing objects
    # such as ORM rows, not only from dicts.
    model_config = {"from_attributes": True}

    investor: InvestorSchema
    portfolio_companies: List[CompanySchema] = []
    team_members: List[InvestorTeamMemberSchema] = []
    sectors: List[SectorSchema] = []
class InvestorList(BaseModel):
    """Container holding a list of ``InvestorData`` entries under ``investors``."""

    investors: List[InvestorData]
Binary file not shown.
+150 -35
View File
@@ -3,15 +3,18 @@ from typing import List, Optional
import chromadb import chromadb
import pandas as pd import pandas as pd
from db.models import InvestorTable from db.models import CompanyTable, InvestorTable, InvestorTeamMember, SectorTable
from langchain_core.prompts import PromptTemplate from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
from pydantic_schemas import Investor, InvestorList from py_schemas import InvestorData
from pydantic import BaseModel
from settings import settings from settings import settings
# Add these imports for your databases
# from sqlalchemy.ext.asyncio import AsyncSession class InvestorList(BaseModel):
# from your_vector_db import VectorDBClient """Schema for LLM structured output"""
investor_list: List[InvestorData]
class InvestorProcessor: class InvestorProcessor:
@@ -25,22 +28,27 @@ class InvestorProcessor:
Given the following CSV data rows: Given the following CSV data rows:
{question} {question}
For each row, extract and structure the following fields: For each row, extract and structure the following fields for the investor:
- name: The investor's full name - name: The investor's full name
- description: Description of the investor
- aum: Assets under management (as integer, use 0 if not available) - aum: Assets under management (as integer, use 0 if not available)
- check_size: Investment check size (as string) - check_size_lower: Lower bound of investment check size (as integer)
- sector_focus: Sector focus (as string) - check_size_upper: Upper bound of investment check size (as integer)
- stage_focus: Investment stage focus (as string) - geographic_focus: Geographic region focus
- region: Geographic region (as string) - stage_focus: Investment stage focus (must be one of: seed, series_a, series_b, series_c, growth, late_stage)
- investment_thesis: Investment thesis (as string) - number_of_investments: Number of investments made (default 0)
- investor_description: Description of the investor (as string)
Also extract related data:
- portfolio_companies: List of companies they've invested in
- team_members: List of team members with name, role, email
- sectors: List of sectors they focus on
Important: Important:
- If a field is not available in the data, use appropriate default values (empty string for text fields, 0 for numbers) - If a field is not available, use appropriate defaults
- Ensure all text fields are properly escaped and contain no control characters - stage_focus must be one of the valid enum values
- Return clean, valid JSON only - Return clean, valid JSON only
Return the data as a structured list of investors.""" Return the data as a structured list of comprehensive investor data."""
self.prompt = PromptTemplate( self.prompt = PromptTemplate(
template=self.template, input_variables=["question"] template=self.template, input_variables=["question"]
@@ -49,7 +57,7 @@ Return the data as a structured list of investors."""
self.llm = ChatOpenAI( self.llm = ChatOpenAI(
api_key=settings.OPENROUTER_API_KEY, api_key=settings.OPENROUTER_API_KEY,
base_url="https://openrouter.ai/api/v1", base_url="https://openrouter.ai/api/v1",
model="openai/gpt-oss-120b:fre", model="google/gemini-2.5-flash-lite",
temperature=0, temperature=0,
) )
@@ -65,7 +73,9 @@ Return the data as a structured list of investors."""
}, },
) )
async def _process_batch(self, batch: pd.DataFrame, batch_idx: int) -> List: async def _process_batch(
self, batch: pd.DataFrame, batch_idx: int
) -> List[InvestorData]:
"""Process a single batch of data""" """Process a single batch of data"""
# Convert batch to string representation - clean the data # Convert batch to string representation - clean the data
batch_str = "" batch_str = ""
@@ -102,25 +112,101 @@ Return the data as a structured list of investors."""
print(f"Error processing batch {batch_idx + 1}: {e}") print(f"Error processing batch {batch_idx + 1}: {e}")
return [] return []
async def _save_to_sql(self, investors: List[Investor]) -> None: async def _save_to_sql(self, investor_data_list: List[InvestorData]) -> None:
"""Save investors to SQL database""" """Save investors and related data to SQL database"""
if not self.sql_session: if not self.sql_session:
return return
# Implement SQL saving logic here try:
for investor in investors: for investor_data in investor_data_list:
# Save investor
db_investor = InvestorTable( db_investor = InvestorTable(
name=investor.name, name=investor_data.investor.name,
aum=investor.aum, description=investor_data.investor.description,
check_size=investor.check_size, aum=investor_data.investor.aum,
sector_focus=investor.sector_focus, check_size_lower=investor_data.investor.check_size_lower,
stage_focus=investor.stage_focus, check_size_upper=investor_data.investor.check_size_upper,
region=investor.region, geographic_focus=investor_data.investor.geographic_focus,
stage_focus=investor_data.investor.stage_focus,
number_of_investments=investor_data.investor.number_of_investments,
) )
self.sql_session.add(db_investor) self.sql_session.add(db_investor)
self.sql_session.commit() self.sql_session.flush() # Get the ID
async def _save_to_vector_db(self, investors: List[Investor]) -> None: # Save sectors and create associations
for sector_data in investor_data.sectors:
# Check if sector exists, create if not
existing_sector = (
self.sql_session.query(SectorTable)
.filter(SectorTable.name == sector_data.name)
.first()
)
if not existing_sector:
db_sector = SectorTable(name=sector_data.name)
self.sql_session.add(db_sector)
self.sql_session.flush()
# Add sector to investor's sectors
db_investor.sectors.append(db_sector)
else:
# Add existing sector to investor if not already there
if existing_sector not in db_investor.sectors:
db_investor.sectors.append(existing_sector)
# Save companies and create portfolio associations
for company_data in investor_data.portfolio_companies:
# Check if company exists, create if not
existing_company = (
self.sql_session.query(CompanyTable)
.filter(CompanyTable.name == company_data.name)
.first()
)
if not existing_company:
db_company = CompanyTable(
name=company_data.name,
industry=company_data.industry,
location=company_data.location,
founded_year=company_data.founded_year,
website=company_data.website,
)
self.sql_session.add(db_company)
self.sql_session.flush()
# Add to investor's portfolio
db_investor.portfolio_companies.append(db_company)
else:
# Add existing company to portfolio if not already there
if existing_company not in db_investor.portfolio_companies:
db_investor.portfolio_companies.append(existing_company)
# Save team members
for team_member_data in investor_data.team_members:
# Check if team member exists
existing_member = (
self.sql_session.query(InvestorTeamMember)
.filter(InvestorTeamMember.email == team_member_data.email)
.first()
)
if not existing_member:
db_team_member = InvestorTeamMember(
name=team_member_data.name,
role=team_member_data.role,
email=team_member_data.email,
investor_id=db_investor.id,
)
self.sql_session.add(db_team_member)
self.sql_session.commit()
print(f"Successfully saved {len(investor_data_list)} investors to database")
except Exception as e:
self.sql_session.rollback()
print(f"Error saving to SQL database: {e}")
raise
async def _save_to_vector_db(self, investor_data_list: List[InvestorData]) -> None:
"""Save investors to vector database""" """Save investors to vector database"""
if not self.vector_db_client: if not self.vector_db_client:
return return
@@ -129,19 +215,47 @@ Return the data as a structured list of investors."""
metadatas = [] metadatas = []
ids = [] ids = []
for i, investor in enumerate(investors): for i, investor_data in enumerate(investor_data_list):
doc_text = f"{investor.investor_description}\nInvestment Thesis: {investor.investment_thesis}" investor = investor_data.investor
sectors = ", ".join([s.name for s in investor_data.sectors])
companies = ", ".join([c.name for c in investor_data.portfolio_companies])
doc_text = f"""
Investor: {investor.name}
Description: {investor.description or "N/A"}
AUM: ${investor.aum:,}
Check Size: ${investor.check_size_lower:,} - ${investor.check_size_upper:,}
Geographic Focus: {investor.geographic_focus}
Stage Focus: {investor.stage_focus.value}
Sectors: {sectors}
Portfolio Companies: {companies}
""".strip()
documents.append(doc_text) documents.append(doc_text)
metadatas.append({"name": investor.name}) metadatas.append(
ids.append(f"investor_{i}_{investor.name.replace(' ', '_')}") {
"name": investor.name,
"stage_focus": investor.stage_focus.value,
"geographic_focus": investor.geographic_focus,
"aum": investor.aum,
}
)
ids.append(
f"investor_{i}_{investor.name.replace(' ', '_').replace('/', '_')}"
)
if documents: if documents:
# Use add method with proper parameters try:
self.collection.add(documents=documents, metadatas=metadatas, ids=ids) self.collection.add(documents=documents, metadatas=metadatas, ids=ids)
print(
f"Successfully saved {len(documents)} investors to vector database"
)
except Exception as e:
print(f"Error saving to vector database: {e}")
async def process_csv( async def process_csv(
self, df: pd.DataFrame, batch_size: int = 10, max_concurrent: int = 10 self, df: pd.DataFrame, batch_size: int = 10, max_concurrent: int = 10
) -> List: ) -> List[InvestorData]:
"""Process CSV data in parallel batches and save to databases""" """Process CSV data in parallel batches and save to databases"""
results = [] results = []
@@ -172,6 +286,7 @@ Return the data as a structured list of investors."""
# Save to databases # Save to databases
if results: if results:
print(f"Successfully processed {len(results)} investors")
await self._save_to_sql(results) await self._save_to_sql(results)
await self._save_to_vector_db(results) await self._save_to_vector_db(results)
+2 -2
View File
@@ -6,7 +6,7 @@ from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase from langchain_community.utilities import SQLDatabase
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent from langgraph.prebuilt import create_react_agent
from pydantic_schemas import Investor, InvestorList from py_schemas import InvestorList
from settings import settings from settings import settings
# Connect to SQLite # Connect to SQLite
@@ -29,7 +29,7 @@ class QueryProcessor:
api_key=settings.OPENROUTER_API_KEY, api_key=settings.OPENROUTER_API_KEY,
base_url="https://openrouter.ai/api/v1", base_url="https://openrouter.ai/api/v1",
model="google/gemini-2.5-flash-lite", model="google/gemini-2.5-flash-lite",
temperature=0, temperature=0.3,
) )
self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm) self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)
self.agent = create_react_agent( self.agent = create_react_agent(