Refactor investor-related schemas and models; update database configuration and enhance investor processing logic

This commit is contained in:
bolade
2025-09-02 15:51:35 +01:00
parent 65b5df3a43
commit 7b58834316
15 changed files with 258 additions and 51 deletions
Binary file not shown.
+154 -39
View File
@@ -3,15 +3,18 @@ from typing import List, Optional
import chromadb
import pandas as pd
from db.models import InvestorTable
from db.models import CompanyTable, InvestorTable, InvestorTeamMember, SectorTable
from langchain_core.prompts import PromptTemplate
from langchain_openai import ChatOpenAI
from pydantic_schemas import Investor, InvestorList
from py_schemas import InvestorData
from pydantic import BaseModel
from settings import settings
# Add these imports for your databases
# from sqlalchemy.ext.asyncio import AsyncSession
# from your_vector_db import VectorDBClient
class InvestorList(BaseModel):
"""Schema for LLM structured output"""
investor_list: List[InvestorData]
class InvestorProcessor:
@@ -25,22 +28,27 @@ class InvestorProcessor:
Given the following CSV data rows:
{question}
For each row, extract and structure the following fields:
For each row, extract and structure the following fields for the investor:
- name: The investor's full name
- description: Description of the investor
- aum: Assets under management (as integer, use 0 if not available)
- check_size: Investment check size (as string)
- sector_focus: Sector focus (as string)
- stage_focus: Investment stage focus (as string)
- region: Geographic region (as string)
- investment_thesis: Investment thesis (as string)
- investor_description: Description of the investor (as string)
- check_size_lower: Lower bound of investment check size (as integer)
- check_size_upper: Upper bound of investment check size (as integer)
- geographic_focus: Geographic region focus
- stage_focus: Investment stage focus (must be one of: seed, series_a, series_b, series_c, growth, late_stage)
- number_of_investments: Number of investments made (default 0)
Also extract related data:
- portfolio_companies: List of companies they've invested in
- team_members: List of team members with name, role, email
- sectors: List of sectors they focus on
Important:
- If a field is not available in the data, use appropriate default values (empty string for text fields, 0 for numbers)
- Ensure all text fields are properly escaped and contain no control characters
- If a field is not available, use appropriate defaults
- stage_focus must be one of the valid enum values
- Return clean, valid JSON only
Return the data as a structured list of investors."""
Return the data as a structured list of comprehensive investor data."""
self.prompt = PromptTemplate(
template=self.template, input_variables=["question"]
@@ -49,7 +57,7 @@ Return the data as a structured list of investors."""
self.llm = ChatOpenAI(
api_key=settings.OPENROUTER_API_KEY,
base_url="https://openrouter.ai/api/v1",
model="openai/gpt-oss-120b:fre",
model="google/gemini-2.5-flash-lite",
temperature=0,
)
@@ -65,7 +73,9 @@ Return the data as a structured list of investors."""
},
)
async def _process_batch(self, batch: pd.DataFrame, batch_idx: int) -> List:
async def _process_batch(
self, batch: pd.DataFrame, batch_idx: int
) -> List[InvestorData]:
"""Process a single batch of data"""
# Convert batch to string representation - clean the data
batch_str = ""
@@ -102,25 +112,101 @@ Return the data as a structured list of investors."""
print(f"Error processing batch {batch_idx + 1}: {e}")
return []
async def _save_to_sql(self, investors: List[Investor]) -> None:
"""Save investors to SQL database"""
async def _save_to_sql(self, investor_data_list: List[InvestorData]) -> None:
"""Save investors and related data to SQL database"""
if not self.sql_session:
return
# Implement SQL saving logic here
for investor in investors:
db_investor = InvestorTable(
name=investor.name,
aum=investor.aum,
check_size=investor.check_size,
sector_focus=investor.sector_focus,
stage_focus=investor.stage_focus,
region=investor.region,
)
self.sql_session.add(db_investor)
self.sql_session.commit()
try:
for investor_data in investor_data_list:
# Save investor
db_investor = InvestorTable(
name=investor_data.investor.name,
description=investor_data.investor.description,
aum=investor_data.investor.aum,
check_size_lower=investor_data.investor.check_size_lower,
check_size_upper=investor_data.investor.check_size_upper,
geographic_focus=investor_data.investor.geographic_focus,
stage_focus=investor_data.investor.stage_focus,
number_of_investments=investor_data.investor.number_of_investments,
)
self.sql_session.add(db_investor)
self.sql_session.flush() # Get the ID
async def _save_to_vector_db(self, investors: List[Investor]) -> None:
# Save sectors and create associations
for sector_data in investor_data.sectors:
# Check if sector exists, create if not
existing_sector = (
self.sql_session.query(SectorTable)
.filter(SectorTable.name == sector_data.name)
.first()
)
if not existing_sector:
db_sector = SectorTable(name=sector_data.name)
self.sql_session.add(db_sector)
self.sql_session.flush()
# Add sector to investor's sectors
db_investor.sectors.append(db_sector)
else:
# Add existing sector to investor if not already there
if existing_sector not in db_investor.sectors:
db_investor.sectors.append(existing_sector)
# Save companies and create portfolio associations
for company_data in investor_data.portfolio_companies:
# Check if company exists, create if not
existing_company = (
self.sql_session.query(CompanyTable)
.filter(CompanyTable.name == company_data.name)
.first()
)
if not existing_company:
db_company = CompanyTable(
name=company_data.name,
industry=company_data.industry,
location=company_data.location,
founded_year=company_data.founded_year,
website=company_data.website,
)
self.sql_session.add(db_company)
self.sql_session.flush()
# Add to investor's portfolio
db_investor.portfolio_companies.append(db_company)
else:
# Add existing company to portfolio if not already there
if existing_company not in db_investor.portfolio_companies:
db_investor.portfolio_companies.append(existing_company)
# Save team members
for team_member_data in investor_data.team_members:
# Check if team member exists
existing_member = (
self.sql_session.query(InvestorTeamMember)
.filter(InvestorTeamMember.email == team_member_data.email)
.first()
)
if not existing_member:
db_team_member = InvestorTeamMember(
name=team_member_data.name,
role=team_member_data.role,
email=team_member_data.email,
investor_id=db_investor.id,
)
self.sql_session.add(db_team_member)
self.sql_session.commit()
print(f"Successfully saved {len(investor_data_list)} investors to database")
except Exception as e:
self.sql_session.rollback()
print(f"Error saving to SQL database: {e}")
raise
async def _save_to_vector_db(self, investor_data_list: List[InvestorData]) -> None:
"""Save investors to vector database"""
if not self.vector_db_client:
return
@@ -129,19 +215,47 @@ Return the data as a structured list of investors."""
metadatas = []
ids = []
for i, investor in enumerate(investors):
doc_text = f"{investor.investor_description}\nInvestment Thesis: {investor.investment_thesis}"
for i, investor_data in enumerate(investor_data_list):
investor = investor_data.investor
sectors = ", ".join([s.name for s in investor_data.sectors])
companies = ", ".join([c.name for c in investor_data.portfolio_companies])
doc_text = f"""
Investor: {investor.name}
Description: {investor.description or "N/A"}
AUM: ${investor.aum:,}
Check Size: ${investor.check_size_lower:,} - ${investor.check_size_upper:,}
Geographic Focus: {investor.geographic_focus}
Stage Focus: {investor.stage_focus.value}
Sectors: {sectors}
Portfolio Companies: {companies}
""".strip()
documents.append(doc_text)
metadatas.append({"name": investor.name})
ids.append(f"investor_{i}_{investor.name.replace(' ', '_')}")
metadatas.append(
{
"name": investor.name,
"stage_focus": investor.stage_focus.value,
"geographic_focus": investor.geographic_focus,
"aum": investor.aum,
}
)
ids.append(
f"investor_{i}_{investor.name.replace(' ', '_').replace('/', '_')}"
)
if documents:
# Use add method with proper parameters
self.collection.add(documents=documents, metadatas=metadatas, ids=ids)
try:
self.collection.add(documents=documents, metadatas=metadatas, ids=ids)
print(
f"Successfully saved {len(documents)} investors to vector database"
)
except Exception as e:
print(f"Error saving to vector database: {e}")
async def process_csv(
self, df: pd.DataFrame, batch_size: int = 10, max_concurrent: int = 10
) -> List:
) -> List[InvestorData]:
"""Process CSV data in parallel batches and save to databases"""
results = []
@@ -172,6 +286,7 @@ Return the data as a structured list of investors."""
# Save to databases
if results:
print(f"Successfully processed {len(results)} investors")
await self._save_to_sql(results)
await self._save_to_vector_db(results)
+2 -2
View File
@@ -6,7 +6,7 @@ from langchain_community.agent_toolkits import SQLDatabaseToolkit
from langchain_community.utilities import SQLDatabase
from langchain_openai import ChatOpenAI
from langgraph.prebuilt import create_react_agent
from pydantic_schemas import Investor, InvestorList
from py_schemas import InvestorList
from settings import settings
# Connect to SQLite
@@ -29,7 +29,7 @@ class QueryProcessor:
api_key=settings.OPENROUTER_API_KEY,
base_url="https://openrouter.ai/api/v1",
model="google/gemini-2.5-flash-lite",
temperature=0,
temperature=0.3,
)
self.toolkit = SQLDatabaseToolkit(db=db, llm=self.llm)
self.agent = create_react_agent(