Refactor investor-related schemas and models; update database configuration and enhance investor processing logic
This commit is contained in:
+154
-39
@@ -3,15 +3,18 @@ from typing import List, Optional
|
||||
|
||||
import chromadb
|
||||
import pandas as pd
|
||||
from db.models import InvestorTable
|
||||
from db.models import CompanyTable, InvestorTable, InvestorTeamMember, SectorTable
|
||||
from langchain_core.prompts import PromptTemplate
|
||||
from langchain_openai import ChatOpenAI
|
||||
from pydantic_schemas import Investor, InvestorList
|
||||
from py_schemas import InvestorData
|
||||
from pydantic import BaseModel
|
||||
from settings import settings
|
||||
|
||||
# Add these imports for your databases
|
||||
# from sqlalchemy.ext.asyncio import AsyncSession
|
||||
# from your_vector_db import VectorDBClient
|
||||
|
||||
class InvestorList(BaseModel):
|
||||
"""Schema for LLM structured output"""
|
||||
|
||||
investor_list: List[InvestorData]
|
||||
|
||||
|
||||
class InvestorProcessor:
|
||||
@@ -25,22 +28,27 @@ class InvestorProcessor:
|
||||
Given the following CSV data rows:
|
||||
{question}
|
||||
|
||||
For each row, extract and structure the following fields:
|
||||
For each row, extract and structure the following fields for the investor:
|
||||
- name: The investor's full name
|
||||
- description: Description of the investor
|
||||
- aum: Assets under management (as integer, use 0 if not available)
|
||||
- check_size: Investment check size (as string)
|
||||
- sector_focus: Sector focus (as string)
|
||||
- stage_focus: Investment stage focus (as string)
|
||||
- region: Geographic region (as string)
|
||||
- investment_thesis: Investment thesis (as string)
|
||||
- investor_description: Description of the investor (as string)
|
||||
- check_size_lower: Lower bound of investment check size (as integer)
|
||||
- check_size_upper: Upper bound of investment check size (as integer)
|
||||
- geographic_focus: Geographic region focus
|
||||
- stage_focus: Investment stage focus (must be one of: seed, series_a, series_b, series_c, growth, late_stage)
|
||||
- number_of_investments: Number of investments made (default 0)
|
||||
|
||||
Also extract related data:
|
||||
- portfolio_companies: List of companies they've invested in
|
||||
- team_members: List of team members with name, role, email
|
||||
- sectors: List of sectors they focus on
|
||||
|
||||
Important:
|
||||
- If a field is not available in the data, use appropriate default values (empty string for text fields, 0 for numbers)
|
||||
- Ensure all text fields are properly escaped and contain no control characters
|
||||
- If a field is not available, use appropriate defaults
|
||||
- stage_focus must be one of the valid enum values
|
||||
- Return clean, valid JSON only
|
||||
|
||||
Return the data as a structured list of investors."""
|
||||
Return the data as a structured list of comprehensive investor data."""
|
||||
|
||||
self.prompt = PromptTemplate(
|
||||
template=self.template, input_variables=["question"]
|
||||
@@ -49,7 +57,7 @@ Return the data as a structured list of investors."""
|
||||
self.llm = ChatOpenAI(
|
||||
api_key=settings.OPENROUTER_API_KEY,
|
||||
base_url="https://openrouter.ai/api/v1",
|
||||
model="openai/gpt-oss-120b:fre",
|
||||
model="google/gemini-2.5-flash-lite",
|
||||
temperature=0,
|
||||
)
|
||||
|
||||
@@ -65,7 +73,9 @@ Return the data as a structured list of investors."""
|
||||
},
|
||||
)
|
||||
|
||||
async def _process_batch(self, batch: pd.DataFrame, batch_idx: int) -> List:
|
||||
async def _process_batch(
|
||||
self, batch: pd.DataFrame, batch_idx: int
|
||||
) -> List[InvestorData]:
|
||||
"""Process a single batch of data"""
|
||||
# Convert batch to string representation - clean the data
|
||||
batch_str = ""
|
||||
@@ -102,25 +112,101 @@ Return the data as a structured list of investors."""
|
||||
print(f"Error processing batch {batch_idx + 1}: {e}")
|
||||
return []
|
||||
|
||||
async def _save_to_sql(self, investors: List[Investor]) -> None:
|
||||
"""Save investors to SQL database"""
|
||||
async def _save_to_sql(self, investor_data_list: List[InvestorData]) -> None:
|
||||
"""Save investors and related data to SQL database"""
|
||||
if not self.sql_session:
|
||||
return
|
||||
|
||||
# Implement SQL saving logic here
|
||||
for investor in investors:
|
||||
db_investor = InvestorTable(
|
||||
name=investor.name,
|
||||
aum=investor.aum,
|
||||
check_size=investor.check_size,
|
||||
sector_focus=investor.sector_focus,
|
||||
stage_focus=investor.stage_focus,
|
||||
region=investor.region,
|
||||
)
|
||||
self.sql_session.add(db_investor)
|
||||
self.sql_session.commit()
|
||||
try:
|
||||
for investor_data in investor_data_list:
|
||||
# Save investor
|
||||
db_investor = InvestorTable(
|
||||
name=investor_data.investor.name,
|
||||
description=investor_data.investor.description,
|
||||
aum=investor_data.investor.aum,
|
||||
check_size_lower=investor_data.investor.check_size_lower,
|
||||
check_size_upper=investor_data.investor.check_size_upper,
|
||||
geographic_focus=investor_data.investor.geographic_focus,
|
||||
stage_focus=investor_data.investor.stage_focus,
|
||||
number_of_investments=investor_data.investor.number_of_investments,
|
||||
)
|
||||
self.sql_session.add(db_investor)
|
||||
self.sql_session.flush() # Get the ID
|
||||
|
||||
async def _save_to_vector_db(self, investors: List[Investor]) -> None:
|
||||
# Save sectors and create associations
|
||||
for sector_data in investor_data.sectors:
|
||||
# Check if sector exists, create if not
|
||||
existing_sector = (
|
||||
self.sql_session.query(SectorTable)
|
||||
.filter(SectorTable.name == sector_data.name)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not existing_sector:
|
||||
db_sector = SectorTable(name=sector_data.name)
|
||||
self.sql_session.add(db_sector)
|
||||
self.sql_session.flush()
|
||||
# Add sector to investor's sectors
|
||||
db_investor.sectors.append(db_sector)
|
||||
else:
|
||||
# Add existing sector to investor if not already there
|
||||
if existing_sector not in db_investor.sectors:
|
||||
db_investor.sectors.append(existing_sector)
|
||||
|
||||
# Save companies and create portfolio associations
|
||||
for company_data in investor_data.portfolio_companies:
|
||||
# Check if company exists, create if not
|
||||
existing_company = (
|
||||
self.sql_session.query(CompanyTable)
|
||||
.filter(CompanyTable.name == company_data.name)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not existing_company:
|
||||
db_company = CompanyTable(
|
||||
name=company_data.name,
|
||||
industry=company_data.industry,
|
||||
location=company_data.location,
|
||||
founded_year=company_data.founded_year,
|
||||
website=company_data.website,
|
||||
)
|
||||
self.sql_session.add(db_company)
|
||||
self.sql_session.flush()
|
||||
|
||||
# Add to investor's portfolio
|
||||
db_investor.portfolio_companies.append(db_company)
|
||||
else:
|
||||
# Add existing company to portfolio if not already there
|
||||
if existing_company not in db_investor.portfolio_companies:
|
||||
db_investor.portfolio_companies.append(existing_company)
|
||||
|
||||
# Save team members
|
||||
for team_member_data in investor_data.team_members:
|
||||
# Check if team member exists
|
||||
existing_member = (
|
||||
self.sql_session.query(InvestorTeamMember)
|
||||
.filter(InvestorTeamMember.email == team_member_data.email)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not existing_member:
|
||||
db_team_member = InvestorTeamMember(
|
||||
name=team_member_data.name,
|
||||
role=team_member_data.role,
|
||||
email=team_member_data.email,
|
||||
investor_id=db_investor.id,
|
||||
)
|
||||
self.sql_session.add(db_team_member)
|
||||
|
||||
self.sql_session.commit()
|
||||
print(f"Successfully saved {len(investor_data_list)} investors to database")
|
||||
|
||||
except Exception as e:
|
||||
self.sql_session.rollback()
|
||||
print(f"Error saving to SQL database: {e}")
|
||||
raise
|
||||
|
||||
async def _save_to_vector_db(self, investor_data_list: List[InvestorData]) -> None:
|
||||
"""Save investors to vector database"""
|
||||
if not self.vector_db_client:
|
||||
return
|
||||
@@ -129,19 +215,47 @@ Return the data as a structured list of investors."""
|
||||
metadatas = []
|
||||
ids = []
|
||||
|
||||
for i, investor in enumerate(investors):
|
||||
doc_text = f"{investor.investor_description}\nInvestment Thesis: {investor.investment_thesis}"
|
||||
for i, investor_data in enumerate(investor_data_list):
|
||||
investor = investor_data.investor
|
||||
sectors = ", ".join([s.name for s in investor_data.sectors])
|
||||
companies = ", ".join([c.name for c in investor_data.portfolio_companies])
|
||||
|
||||
doc_text = f"""
|
||||
Investor: {investor.name}
|
||||
Description: {investor.description or "N/A"}
|
||||
AUM: ${investor.aum:,}
|
||||
Check Size: ${investor.check_size_lower:,} - ${investor.check_size_upper:,}
|
||||
Geographic Focus: {investor.geographic_focus}
|
||||
Stage Focus: {investor.stage_focus.value}
|
||||
Sectors: {sectors}
|
||||
Portfolio Companies: {companies}
|
||||
""".strip()
|
||||
|
||||
documents.append(doc_text)
|
||||
metadatas.append({"name": investor.name})
|
||||
ids.append(f"investor_{i}_{investor.name.replace(' ', '_')}")
|
||||
metadatas.append(
|
||||
{
|
||||
"name": investor.name,
|
||||
"stage_focus": investor.stage_focus.value,
|
||||
"geographic_focus": investor.geographic_focus,
|
||||
"aum": investor.aum,
|
||||
}
|
||||
)
|
||||
ids.append(
|
||||
f"investor_{i}_{investor.name.replace(' ', '_').replace('/', '_')}"
|
||||
)
|
||||
|
||||
if documents:
|
||||
# Use add method with proper parameters
|
||||
self.collection.add(documents=documents, metadatas=metadatas, ids=ids)
|
||||
try:
|
||||
self.collection.add(documents=documents, metadatas=metadatas, ids=ids)
|
||||
print(
|
||||
f"Successfully saved {len(documents)} investors to vector database"
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error saving to vector database: {e}")
|
||||
|
||||
async def process_csv(
|
||||
self, df: pd.DataFrame, batch_size: int = 10, max_concurrent: int = 10
|
||||
) -> List:
|
||||
) -> List[InvestorData]:
|
||||
"""Process CSV data in parallel batches and save to databases"""
|
||||
results = []
|
||||
|
||||
@@ -172,6 +286,7 @@ Return the data as a structured list of investors."""
|
||||
|
||||
# Save to databases
|
||||
if results:
|
||||
print(f"Successfully processed {len(results)} investors")
|
||||
await self._save_to_sql(results)
|
||||
await self._save_to_vector_db(results)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user